From c00fc78f2360ca53d8ba602c0233a89539e0de85 Mon Sep 17 00:00:00 2001
From: Sebastian Kloth
Date: Tue, 26 Dec 2023 11:50:16 +0100
Subject: [PATCH] Added type hints, moved imports and classes around

---
 src/mdevaluate/reader.py | 155 ++++++++++++++++++++-------------------
 1 file changed, 79 insertions(+), 76 deletions(-)

diff --git a/src/mdevaluate/reader.py b/src/mdevaluate/reader.py
index 7152fc1..d4c4ae5 100755
--- a/src/mdevaluate/reader.py
+++ b/src/mdevaluate/reader.py
@@ -2,13 +2,8 @@
 Module that provides different readers for trajectory files.
 
 It also provides a common interface layer between the file IO packages,
-namely pygmx and mdanalysis, and mdevaluate.
+namely mdanalysis, and mdevaluate.
 """
-from .checksum import checksum
-from .logging import logger
-from . import atoms
-
-from functools import lru_cache
 from collections import namedtuple
 import os
 from os import path
@@ -19,9 +14,19 @@
 import re
 import itertools
 import numpy as np
-import MDAnalysis as mdanalysis
+import numpy.typing as npt
+import MDAnalysis
 from scipy import sparse
 
+from .checksum import checksum
+from .logging import logger
+from . import atoms
+from .coordinates import Coordinates
+
+CSR_ATTRS = ("data", "indices", "indptr")
+NOJUMP_MAGIC = 2016
+Group_RE = re.compile("\[ ([-+\w]+) \]")
+
 
 class NojumpError(Exception):
     pass
@@ -31,11 +36,49 @@ class WrongTopologyError(Exception):
     pass
 
 
+class BaseReader:
+    """Base class for trajectory readers."""
+
+    @property
+    def filename(self):
+        return self.rd.filename
+
+    @property
+    def nojump_matrices(self):
+        if self._nojump_matrices is None:
+            raise NojumpError("Nojump Data not available: {}".format(self.filename))
+        return self._nojump_matrices
+
+    @nojump_matrices.setter
+    def nojump_matrices(self, mats):
+        self._nojump_matrices = mats
+
+    def __init__(self, rd):
+        self.rd = rd
+        self._nojump_matrices = None
+        if path.exists(nojump_load_filename(self)):
+            load_nojump_matrices(self)
+
+    def __getitem__(self, item):
+        return self.rd[item]
+
+    def __len__(self):
+        return len(self.rd)
+
+    def __checksum__(self):
+        cache = array("L", self.rd._xdr.offsets.tobytes())
+        return checksum(self.filename, str(cache))
+
+
 def open_with_mdanalysis(
-    topology, trajectory, index_file=None, charges=None, masses=None
-):
+    topology: str,
+    trajectory: str,
+    index_file: str = None,
+    charges: npt.ArrayLike = None,
+    masses: npt.ArrayLike = None,
+) -> (atoms.Atoms, BaseReader):
     """Open the topology and trajectory with mdanalysis."""
-    uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
+    uni = MDAnalysis.Universe(topology, trajectory, convert_units=False)
     reader = BaseReader(uni.trajectory)
     reader.universe = uni
     if topology.endswith(".tpr"):
@@ -60,15 +103,12 @@ def open_with_mdanalysis(
     return atms, reader
 
 
-group_re = re.compile("\[ ([-+\w]+) \]")
-
-
-def load_indices(index_file):
+def load_indices(index_file: str):
     indices = {}
     index_array = None
     with open(index_file) as idx_file:
         for line in idx_file:
-            m = group_re.search(line)
+            m = Group_RE.search(line)
             if m is not None:
                 group_name = m.group(1)
                 index_array = indices.get(group_name, [])
@@ -82,7 +122,7 @@ def load_indices(index_file):
     return indices
 
 
-def is_writeable(fname):
+def is_writeable(fname: str):
     """Test if a directory is actually writeable, by writing a temporary file."""
     fdir = os.path.dirname(fname)
     ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
@@ -100,7 +140,7 @@ def is_writeable(fname):
     return False
 
 
-def nojump_load_filename(reader):
+def nojump_load_filename(reader: BaseReader):
     directory, fname = path.split(reader.filename)
     full_path = path.join(directory, ".{}.nojump.npz".format(fname))
     if not is_writeable(directory):
@@ -116,7 +156,7 @@ def nojump_load_filename(reader):
         return full_path
     else:
         user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
-        full_path_fallback = os.path.join(
+        full_path = os.path.join(
             os.path.join(user_data_dir, ".mdevaluate/nojump"),
             directory.lstrip("/"),
             ".{}.nojump.npz".format(fname),
@@ -124,7 +164,7 @@ def nojump_load_filename(reader):
         return full_path
 
 
-def nojump_save_filename(reader):
+def nojump_save_filename(reader: BaseReader):
     directory, fname = path.split(reader.filename)
     full_path = path.join(directory, ".{}.nojump.npz".format(fname))
     if is_writeable(directory):
@@ -145,11 +185,7 @@ def nojump_save_filename(reader):
         return full_path_fallback
 
 
-CSR_ATTRS = ("data", "indices", "indptr")
-NOJUMP_MAGIC = 2016
-
-
-def parse_jumps(trajectory):
+def parse_jumps(trajectory: Coordinates):
     prev = trajectory[0].whole
     box = prev.box.diagonal()
     SparseData = namedtuple("SparseData", ["data", "row", "col"])
@@ -173,28 +209,28 @@ def parse_jumps(trajectory):
     return jump_data
 
 
-def generate_nojump_matrixes(trajectory):
+def generate_nojump_matrices(trajectory: Coordinates):
     """
-    Create the matrixes with pbc jumps for a trajectory.
+    Create the matrices with pbc jumps for a trajectory.
     """
-    logger.info("generate Nojump Matrixes for: {}".format(trajectory))
+    logger.info("generate Nojump matrices for: {}".format(trajectory))
     jump_data = parse_jumps(trajectory)
     N = len(trajectory)
     M = len(trajectory[0])
 
-    trajectory.frames.nojump_matrixes = tuple(
+    trajectory.frames.nojump_matrices = tuple(
         sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
         for m in jump_data
     )
 
-    save_nojump_matrixes(trajectory.frames)
+    save_nojump_matrices(trajectory.frames)
 
 
-def save_nojump_matrixes(reader, matrixes=None):
-    if matrixes is None:
-        matrixes = reader.nojump_matrixes
+def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None):
+    if matrices is None:
+        matrices = reader.nojump_matrices
     data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
-    for d, mat in enumerate(matrixes):
+    for d, mat in enumerate(matrices):
         data["shape"] = mat.shape
         for attr in CSR_ATTRS:
             data["{}_{}".format(attr, d)] = getattr(mat, attr)
@@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None):
     np.savez(nojump_save_filename(reader), **data)
 
 
-def load_nojump_matrixes(reader):
+def load_nojump_matrices(reader: BaseReader):
     zipname = nojump_load_filename(reader)
     try:
         data = np.load(zipname, allow_pickle=True)
     except (AttributeError, BadZipFile, OSError):
-        # npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
+        # npz-files can be corrupted, probably a bug for big arrays saved with
+        # savez_compressed?
        logger.info("Removing zip-File: %s", zipname)
        os.remove(nojump_load_filename(reader))
        return
     try:
         if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
-            reader.nojump_matrixes = tuple(
+            reader.nojump_matrices = tuple(
                 sparse.csr_matrix(
                     tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
                     shape=data["shape"],
@@ -221,7 +258,7 @@ def load_nojump_matrixes(reader):
                 for d in range(3)
             )
             logger.info(
-                "Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader))
+                "Loaded Nojump matrices: {}".format(nojump_load_filename(reader))
             )
         else:
             logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
@@ -231,53 +268,19 @@ def load_nojump_matrixes(reader):
         return
 
 
-def correct_nojump_matrixes_for_whole(trajectory):
+def correct_nojump_matrices_for_whole(trajectory: Coordinates):
     reader = trajectory.frames
     frame = trajectory[0]
     box = frame.box.diagonal()
     cor = ((frame - frame.whole) / box).round().astype(np.int8)
     for d in range(3):
-        reader.nojump_matrixes[d][0] = cor[:, d]
-    save_nojump_matrixes(reader)
+        reader.nojump_matrices[d][0] = cor[:, d]
+    save_nojump_matrices(reader)
 
 
-def energy_reader(file):
+def energy_reader(file: str):
     """Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
     Args:
         file: Filename of the energy file
     """
-    return mdanalysis.auxiliary.EDR.EDRReader(file)
-
-
-class BaseReader:
-    """Base class for trajectory readers."""
-
-    @property
-    def filename(self):
-        return self.rd.filename
-
-    @property
-    def nojump_matrixes(self):
-        if self._nojump_matrixes is None:
-            raise NojumpError("Nojump Data not available: {}".format(self.filename))
-        return self._nojump_matrixes
-
-    @nojump_matrixes.setter
-    def nojump_matrixes(self, mats):
-        self._nojump_matrixes = mats
-
-    def __init__(self, rd):
-        self.rd = rd
-        self._nojump_matrixes = None
-        if path.exists(nojump_load_filename(self)):
-            load_nojump_matrixes(self)
-
-    def __getitem__(self, item):
-        return self.rd[item]
-
-    def __len__(self):
-        return len(self.rd)
-
-    def __checksum__(self):
-        cache = array("L", self.rd._xdr.offsets.tobytes())
-        return checksum(self.filename, str(cache))
+    return MDAnalysis.auxiliary.EDR.EDRReader(file)