Added type hints, moved imports and classes around
This commit is contained in:
		| @@ -2,13 +2,8 @@ | |||||||
| Module that provides different readers for trajectory files. | Module that provides different readers for trajectory files. | ||||||
|  |  | ||||||
| It also provides a common interface layer between the file IO packages, | It also provides a common interface layer between the file IO packages, | ||||||
| namely pygmx and mdanalysis, and mdevaluate. | namely mdanalysis and mdevaluate. | ||||||
| """ | """ | ||||||
| from .checksum import checksum |  | ||||||
| from .logging import logger |  | ||||||
| from . import atoms |  | ||||||
|  |  | ||||||
| from functools import lru_cache |  | ||||||
| from collections import namedtuple | from collections import namedtuple | ||||||
| import os | import os | ||||||
| from os import path | from os import path | ||||||
| @@ -19,9 +14,19 @@ import re | |||||||
| import itertools | import itertools | ||||||
|  |  | ||||||
| import numpy as np | import numpy as np | ||||||
| import MDAnalysis as mdanalysis | import numpy.typing as npt | ||||||
|  | import MDAnalysis | ||||||
| from scipy import sparse | from scipy import sparse | ||||||
|  |  | ||||||
|  | from .checksum import checksum | ||||||
|  | from .logging import logger | ||||||
|  | from . import atoms | ||||||
|  | from .coordinates import Coordinates | ||||||
|  |  | ||||||
|  | CSR_ATTRS = ("data", "indices", "indptr") | ||||||
|  | NOJUMP_MAGIC = 2016 | ||||||
|  | Group_RE = re.compile("\[ ([-+\w]+) \]") | ||||||
|  |  | ||||||
|  |  | ||||||
| class NojumpError(Exception): | class NojumpError(Exception): | ||||||
|     pass |     pass | ||||||
| @@ -31,11 +36,49 @@ class WrongTopologyError(Exception): | |||||||
|     pass |     pass | ||||||
|  |  | ||||||
|  |  | ||||||
class BaseReader:
    """Base class for trajectory readers.

    Wraps a low-level mdanalysis trajectory reader and lazily attaches the
    nojump (periodic-boundary jump) matrices when a cached file exists.
    """

    def __init__(self, rd):
        # ``rd`` is the wrapped low-level trajectory reader.
        self.rd = rd
        self._nojump_matrices = None
        # Load cached nojump data if a matching cache file is on disk.
        if path.exists(nojump_load_filename(self)):
            load_nojump_matrices(self)

    @property
    def filename(self):
        """Filename of the wrapped trajectory."""
        return self.rd.filename

    @property
    def nojump_matrices(self):
        """Nojump matrices; raises NojumpError when not available."""
        if self._nojump_matrices is None:
            raise NojumpError("Nojump Data not available: {}".format(self.filename))
        return self._nojump_matrices

    @nojump_matrices.setter
    def nojump_matrices(self, mats):
        self._nojump_matrices = mats

    def __getitem__(self, item):
        # Delegate frame access to the wrapped reader.
        return self.rd[item]

    def __len__(self):
        return len(self.rd)

    def __checksum__(self):
        # Hash the XDR frame offsets together with the filename so the
        # checksum changes whenever the trajectory content changes.
        offset_cache = array("L", self.rd._xdr.offsets.tobytes())
        return checksum(self.filename, str(offset_cache))
|  |  | ||||||
|  |  | ||||||
| def open_with_mdanalysis( | def open_with_mdanalysis( | ||||||
|     topology, trajectory, index_file=None, charges=None, masses=None |     topology: str, | ||||||
| ): |     trajectory: str, | ||||||
|  |     index_file: str = None, | ||||||
|  |     charges: npt.ArrayLike = None, | ||||||
|  |     masses: npt.ArrayLike = None, | ||||||
|  | ) -> (atoms.Atoms, BaseReader): | ||||||
|     """Open the topology and trajectory with mdanalysis.""" |     """Open the topology and trajectory with mdanalysis.""" | ||||||
|     uni = mdanalysis.Universe(topology, trajectory, convert_units=False) |     uni = MDAnalysis.Universe(topology, trajectory, convert_units=False) | ||||||
|     reader = BaseReader(uni.trajectory) |     reader = BaseReader(uni.trajectory) | ||||||
|     reader.universe = uni |     reader.universe = uni | ||||||
|     if topology.endswith(".tpr"): |     if topology.endswith(".tpr"): | ||||||
| @@ -60,15 +103,12 @@ def open_with_mdanalysis( | |||||||
|     return atms, reader |     return atms, reader | ||||||
|  |  | ||||||
|  |  | ||||||
| group_re = re.compile("\[ ([-+\w]+) \]") | def load_indices(index_file: str): | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_indices(index_file): |  | ||||||
|     indices = {} |     indices = {} | ||||||
|     index_array = None |     index_array = None | ||||||
|     with open(index_file) as idx_file: |     with open(index_file) as idx_file: | ||||||
|         for line in idx_file: |         for line in idx_file: | ||||||
|             m = group_re.search(line) |             m = Group_RE.search(line) | ||||||
|             if m is not None: |             if m is not None: | ||||||
|                 group_name = m.group(1) |                 group_name = m.group(1) | ||||||
|                 index_array = indices.get(group_name, []) |                 index_array = indices.get(group_name, []) | ||||||
| @@ -82,7 +122,7 @@ def load_indices(index_file): | |||||||
|     return indices |     return indices | ||||||
|  |  | ||||||
|  |  | ||||||
| def is_writeable(fname): | def is_writeable(fname: str): | ||||||
|     """Test if a directory is actually writeable, by writing a temporary file.""" |     """Test if a directory is actually writeable, by writing a temporary file.""" | ||||||
|     fdir = os.path.dirname(fname) |     fdir = os.path.dirname(fname) | ||||||
|     ftmp = os.path.join(fdir, str(np.random.randint(999999999))) |     ftmp = os.path.join(fdir, str(np.random.randint(999999999))) | ||||||
| @@ -100,7 +140,7 @@ def is_writeable(fname): | |||||||
|     return False |     return False | ||||||
|  |  | ||||||
|  |  | ||||||
| def nojump_load_filename(reader): | def nojump_load_filename(reader: BaseReader): | ||||||
|     directory, fname = path.split(reader.filename) |     directory, fname = path.split(reader.filename) | ||||||
|     full_path = path.join(directory, ".{}.nojump.npz".format(fname)) |     full_path = path.join(directory, ".{}.nojump.npz".format(fname)) | ||||||
|     if not is_writeable(directory): |     if not is_writeable(directory): | ||||||
| @@ -116,7 +156,7 @@ def nojump_load_filename(reader): | |||||||
|         return full_path |         return full_path | ||||||
|     else: |     else: | ||||||
|         user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1]) |         user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1]) | ||||||
|         full_path_fallback = os.path.join( |         full_path = os.path.join( | ||||||
|             os.path.join(user_data_dir, ".mdevaluate/nojump"), |             os.path.join(user_data_dir, ".mdevaluate/nojump"), | ||||||
|             directory.lstrip("/"), |             directory.lstrip("/"), | ||||||
|             ".{}.nojump.npz".format(fname), |             ".{}.nojump.npz".format(fname), | ||||||
| @@ -124,7 +164,7 @@ def nojump_load_filename(reader): | |||||||
|         return full_path |         return full_path | ||||||
|  |  | ||||||
|  |  | ||||||
| def nojump_save_filename(reader): | def nojump_save_filename(reader: BaseReader): | ||||||
|     directory, fname = path.split(reader.filename) |     directory, fname = path.split(reader.filename) | ||||||
|     full_path = path.join(directory, ".{}.nojump.npz".format(fname)) |     full_path = path.join(directory, ".{}.nojump.npz".format(fname)) | ||||||
|     if is_writeable(directory): |     if is_writeable(directory): | ||||||
| @@ -145,11 +185,7 @@ def nojump_save_filename(reader): | |||||||
|         return full_path_fallback |         return full_path_fallback | ||||||
|  |  | ||||||
|  |  | ||||||
| CSR_ATTRS = ("data", "indices", "indptr") | def parse_jumps(trajectory: Coordinates): | ||||||
| NOJUMP_MAGIC = 2016 |  | ||||||
|  |  | ||||||
|  |  | ||||||
| def parse_jumps(trajectory): |  | ||||||
|     prev = trajectory[0].whole |     prev = trajectory[0].whole | ||||||
|     box = prev.box.diagonal() |     box = prev.box.diagonal() | ||||||
|     SparseData = namedtuple("SparseData", ["data", "row", "col"]) |     SparseData = namedtuple("SparseData", ["data", "row", "col"]) | ||||||
| @@ -173,28 +209,28 @@ def parse_jumps(trajectory): | |||||||
|     return jump_data |     return jump_data | ||||||
|  |  | ||||||
|  |  | ||||||
def generate_nojump_matrices(trajectory: Coordinates):
    """Create the sparse matrices of periodic-boundary jumps for a trajectory.

    One CSR matrix of shape (frames, atoms) is built per spatial dimension,
    attached to ``trajectory.frames`` and written to the nojump cache file.
    """
    logger.info("generate Nojump matrices for: {}".format(trajectory))

    jump_data = parse_jumps(trajectory)
    n_frames = len(trajectory)
    n_atoms = len(trajectory[0])

    matrices = []
    for dim_data in jump_data:
        matrices.append(
            sparse.csr_matrix(
                (np.array(dim_data.data), (dim_data.row, dim_data.col)),
                shape=(n_frames, n_atoms),
            )
        )
    trajectory.frames.nojump_matrices = tuple(matrices)
    save_nojump_matrices(trajectory.frames)
|  |  | ||||||
|  |  | ||||||
| def save_nojump_matrixes(reader, matrixes=None): | def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None): | ||||||
|     if matrixes is None: |     if matrices is None: | ||||||
|         matrixes = reader.nojump_matrixes |         matrices = reader.nojump_matrices | ||||||
|     data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))} |     data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))} | ||||||
|     for d, mat in enumerate(matrixes): |     for d, mat in enumerate(matrices): | ||||||
|         data["shape"] = mat.shape |         data["shape"] = mat.shape | ||||||
|         for attr in CSR_ATTRS: |         for attr in CSR_ATTRS: | ||||||
|             data["{}_{}".format(attr, d)] = getattr(mat, attr) |             data["{}_{}".format(attr, d)] = getattr(mat, attr) | ||||||
| @@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None): | |||||||
|     np.savez(nojump_save_filename(reader), **data) |     np.savez(nojump_save_filename(reader), **data) | ||||||
|  |  | ||||||
|  |  | ||||||
| def load_nojump_matrixes(reader): | def load_nojump_matrices(reader: BaseReader): | ||||||
|     zipname = nojump_load_filename(reader) |     zipname = nojump_load_filename(reader) | ||||||
|     try: |     try: | ||||||
|         data = np.load(zipname, allow_pickle=True) |         data = np.load(zipname, allow_pickle=True) | ||||||
|     except (AttributeError, BadZipFile, OSError): |     except (AttributeError, BadZipFile, OSError): | ||||||
|         # npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed? |         # npz-files can be corrupted, probably a bug for big arrays saved with | ||||||
|  |         # savez_compressed? | ||||||
|         logger.info("Removing zip-File: %s", zipname) |         logger.info("Removing zip-File: %s", zipname) | ||||||
|         os.remove(nojump_load_filename(reader)) |         os.remove(nojump_load_filename(reader)) | ||||||
|         return |         return | ||||||
|     try: |     try: | ||||||
|         if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)): |         if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)): | ||||||
|             reader.nojump_matrixes = tuple( |             reader.nojump_matrices = tuple( | ||||||
|                 sparse.csr_matrix( |                 sparse.csr_matrix( | ||||||
|                     tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS), |                     tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS), | ||||||
|                     shape=data["shape"], |                     shape=data["shape"], | ||||||
| @@ -221,7 +258,7 @@ def load_nojump_matrixes(reader): | |||||||
|                 for d in range(3) |                 for d in range(3) | ||||||
|             ) |             ) | ||||||
|             logger.info( |             logger.info( | ||||||
|                 "Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader)) |                 "Loaded Nojump matrices: {}".format(nojump_load_filename(reader)) | ||||||
|             ) |             ) | ||||||
|         else: |         else: | ||||||
|             logger.info("Invalid Nojump Data: {}".format(nojump_load_filename(reader))) |             logger.info("Invalid Nojump Data: {}".format(nojump_load_filename(reader))) | ||||||
| @@ -231,53 +268,19 @@ def load_nojump_matrixes(reader): | |||||||
|         return |         return | ||||||
|  |  | ||||||
|  |  | ||||||
def correct_nojump_matrices_for_whole(trajectory: Coordinates):
    """Correct row 0 of the nojump matrices for a "whole" first frame.

    Computes, per atom and dimension, the integer number of box lengths by
    which the first frame differs from its ``whole`` representation, and
    stores those shifts in the first row of each per-dimension nojump
    matrix before saving the matrices back to the cache file.
    """
    reader = trajectory.frames
    frame = trajectory[0]
    box = frame.box.diagonal()
    # Integer box-length shifts between the raw and the whole frame.
    cor = ((frame - frame.whole) / box).round().astype(np.int8)
    for d in range(3):
        reader.nojump_matrices[d][0] = cor[:, d]
    save_nojump_matrices(reader)
|  |  | ||||||
|  |  | ||||||
def energy_reader(file: str):
    """Read a gromacs energy (.edr) file with mdanalysis.

    Args:
        file: Filename of the energy file.

    Returns:
        An ``MDAnalysis.auxiliary.EDR.EDRReader`` for the given file.
    """
    return MDAnalysis.auxiliary.EDR.EDRReader(file)
|  |  | ||||||
|  |  | ||||||
| class BaseReader: |  | ||||||
|     """Base class for trajectory readers.""" |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def filename(self): |  | ||||||
|         return self.rd.filename |  | ||||||
|  |  | ||||||
|     @property |  | ||||||
|     def nojump_matrixes(self): |  | ||||||
|         if self._nojump_matrixes is None: |  | ||||||
|             raise NojumpError("Nojump Data not available: {}".format(self.filename)) |  | ||||||
|         return self._nojump_matrixes |  | ||||||
|  |  | ||||||
|     @nojump_matrixes.setter |  | ||||||
|     def nojump_matrixes(self, mats): |  | ||||||
|         self._nojump_matrixes = mats |  | ||||||
|  |  | ||||||
|     def __init__(self, rd): |  | ||||||
|         self.rd = rd |  | ||||||
|         self._nojump_matrixes = None |  | ||||||
|         if path.exists(nojump_load_filename(self)): |  | ||||||
|             load_nojump_matrixes(self) |  | ||||||
|  |  | ||||||
|     def __getitem__(self, item): |  | ||||||
|         return self.rd[item] |  | ||||||
|  |  | ||||||
|     def __len__(self): |  | ||||||
|         return len(self.rd) |  | ||||||
|  |  | ||||||
|     def __checksum__(self): |  | ||||||
|         cache = array("L", self.rd._xdr.offsets.tobytes()) |  | ||||||
|         return checksum(self.filename, str(cache)) |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user