Added type hints, moved imports and classes around

Sebastian Kloth 2023-12-26 11:50:16 +01:00
parent 8bad2861fa
commit c00fc78f23


@@ -2,13 +2,8 @@
 Module that provides different readers for trajectory files.
 It also provides a common interface layer between the file IO packages,
-namely pygmx and mdanalysis, and mdevaluate.
+namely mdanalysis, and mdevaluate.
 """
-from .checksum import checksum
-from .logging import logger
-from . import atoms
-from functools import lru_cache
 from collections import namedtuple
 import os
 from os import path
@@ -19,9 +14,19 @@ import re
 import itertools
 import numpy as np
-import MDAnalysis as mdanalysis
+import numpy.typing as npt
+import MDAnalysis
 from scipy import sparse

+from .checksum import checksum
+from .logging import logger
+from . import atoms
+from .coordinates import Coordinates
+
+CSR_ATTRS = ("data", "indices", "indptr")
+NOJUMP_MAGIC = 2016
+Group_RE = re.compile("\[ ([-+\w]+) \]")
+

 class NojumpError(Exception):
     pass
@@ -31,11 +36,49 @@ class WrongTopologyError(Exception):
     pass


+class BaseReader:
+    """Base class for trajectory readers."""
+
+    @property
+    def filename(self):
+        return self.rd.filename
+
+    @property
+    def nojump_matrices(self):
+        if self._nojump_matrices is None:
+            raise NojumpError("Nojump Data not available: {}".format(self.filename))
+        return self._nojump_matrices
+
+    @nojump_matrices.setter
+    def nojump_matrices(self, mats):
+        self._nojump_matrices = mats
+
+    def __init__(self, rd):
+        self.rd = rd
+        self._nojump_matrices = None
+        if path.exists(nojump_load_filename(self)):
+            load_nojump_matrices(self)
+
+    def __getitem__(self, item):
+        return self.rd[item]
+
+    def __len__(self):
+        return len(self.rd)
+
+    def __checksum__(self):
+        cache = array("L", self.rd._xdr.offsets.tobytes())
+        return checksum(self.filename, str(cache))
+
+
 def open_with_mdanalysis(
-    topology, trajectory, index_file=None, charges=None, masses=None
-):
+    topology: str,
+    trajectory: str,
+    index_file: str = None,
+    charges: npt.ArrayLike = None,
+    masses: npt.ArrayLike = None,
+) -> (atoms.Atoms, BaseReader):
     """Open the topology and trajectory with mdanalysis."""
-    uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
+    uni = MDAnalysis.Universe(topology, trajectory, convert_units=False)
     reader = BaseReader(uni.trajectory)
     reader.universe = uni
     if topology.endswith(".tpr"):
@@ -60,15 +103,12 @@ def open_with_mdanalysis(
     return atms, reader


-group_re = re.compile("\[ ([-+\w]+) \]")
-
-
-def load_indices(index_file):
+def load_indices(index_file: str):
     indices = {}
     index_array = None
     with open(index_file) as idx_file:
         for line in idx_file:
-            m = group_re.search(line)
+            m = Group_RE.search(line)
             if m is not None:
                 group_name = m.group(1)
                 index_array = indices.get(group_name, [])
@@ -82,7 +122,7 @@ def load_indices(index_file):
     return indices


-def is_writeable(fname):
+def is_writeable(fname: str):
     """Test if a directory is actually writeable, by writing a temporary file."""
     fdir = os.path.dirname(fname)
     ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
@@ -100,7 +140,7 @@ def is_writeable(fname):
         return False


-def nojump_load_filename(reader):
+def nojump_load_filename(reader: BaseReader):
     directory, fname = path.split(reader.filename)
     full_path = path.join(directory, ".{}.nojump.npz".format(fname))
     if not is_writeable(directory):
@@ -116,7 +156,7 @@ def nojump_load_filename(reader):
         return full_path
     else:
         user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
-        full_path_fallback = os.path.join(
+        full_path = os.path.join(
             os.path.join(user_data_dir, ".mdevaluate/nojump"),
             directory.lstrip("/"),
             ".{}.nojump.npz".format(fname),
@@ -124,7 +164,7 @@ def nojump_load_filename(reader):
         return full_path


-def nojump_save_filename(reader):
+def nojump_save_filename(reader: BaseReader):
     directory, fname = path.split(reader.filename)
     full_path = path.join(directory, ".{}.nojump.npz".format(fname))
     if is_writeable(directory):
@@ -145,11 +185,7 @@ def nojump_save_filename(reader):
         return full_path_fallback


-CSR_ATTRS = ("data", "indices", "indptr")
-NOJUMP_MAGIC = 2016
-
-
-def parse_jumps(trajectory):
+def parse_jumps(trajectory: Coordinates):
     prev = trajectory[0].whole
     box = prev.box.diagonal()
     SparseData = namedtuple("SparseData", ["data", "row", "col"])
@@ -173,28 +209,28 @@ def parse_jumps(trajectory):
     return jump_data


-def generate_nojump_matrixes(trajectory):
+def generate_nojump_matrices(trajectory: Coordinates):
     """
-    Create the matrixes with pbc jumps for a trajectory.
+    Create the matrices with pbc jumps for a trajectory.
     """
-    logger.info("generate Nojump Matrixes for: {}".format(trajectory))
+    logger.info("generate Nojump matrices for: {}".format(trajectory))
     jump_data = parse_jumps(trajectory)
     N = len(trajectory)
     M = len(trajectory[0])

-    trajectory.frames.nojump_matrixes = tuple(
+    trajectory.frames.nojump_matrices = tuple(
         sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
         for m in jump_data
     )
-    save_nojump_matrixes(trajectory.frames)
+    save_nojump_matrices(trajectory.frames)


-def save_nojump_matrixes(reader, matrixes=None):
-    if matrixes is None:
-        matrixes = reader.nojump_matrixes
+def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None):
+    if matrices is None:
+        matrices = reader.nojump_matrices
     data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
-    for d, mat in enumerate(matrixes):
+    for d, mat in enumerate(matrices):
         data["shape"] = mat.shape
         for attr in CSR_ATTRS:
             data["{}_{}".format(attr, d)] = getattr(mat, attr)
@@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None):
     np.savez(nojump_save_filename(reader), **data)


-def load_nojump_matrixes(reader):
+def load_nojump_matrices(reader: BaseReader):
     zipname = nojump_load_filename(reader)
     try:
         data = np.load(zipname, allow_pickle=True)
     except (AttributeError, BadZipFile, OSError):
-        # npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
+        # npz-files can be corrupted, probably a bug for big arrays saved with
+        # savez_compressed?
         logger.info("Removing zip-File: %s", zipname)
         os.remove(nojump_load_filename(reader))
         return
     try:
         if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
-            reader.nojump_matrixes = tuple(
+            reader.nojump_matrices = tuple(
                 sparse.csr_matrix(
                     tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
                     shape=data["shape"],
@@ -221,7 +258,7 @@ def load_nojump_matrixes(reader):
                 for d in range(3)
             )
             logger.info(
-                "Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader))
+                "Loaded Nojump matrices: {}".format(nojump_load_filename(reader))
             )
         else:
             logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
@@ -231,53 +268,19 @@ def load_nojump_matrixes(reader):
         return


-def correct_nojump_matrixes_for_whole(trajectory):
+def correct_nojump_matrices_for_whole(trajectory: Coordinates):
     reader = trajectory.frames
     frame = trajectory[0]
     box = frame.box.diagonal()
     cor = ((frame - frame.whole) / box).round().astype(np.int8)

     for d in range(3):
-        reader.nojump_matrixes[d][0] = cor[:, d]
+        reader.nojump_matrices[d][0] = cor[:, d]

-    save_nojump_matrixes(reader)
+    save_nojump_matrices(reader)


-def energy_reader(file):
+def energy_reader(file: str):
     """Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
     Args:
         file: Filename of the energy file
     """
-    return mdanalysis.auxiliary.EDR.EDRReader(file)
-
-
-class BaseReader:
-    """Base class for trajectory readers."""
-
-    @property
-    def filename(self):
-        return self.rd.filename
-
-    @property
-    def nojump_matrixes(self):
-        if self._nojump_matrixes is None:
-            raise NojumpError("Nojump Data not available: {}".format(self.filename))
-        return self._nojump_matrixes
-
-    @nojump_matrixes.setter
-    def nojump_matrixes(self, mats):
-        self._nojump_matrixes = mats
-
-    def __init__(self, rd):
-        self.rd = rd
-        self._nojump_matrixes = None
-        if path.exists(nojump_load_filename(self)):
-            load_nojump_matrixes(self)
-
-    def __getitem__(self, item):
-        return self.rd[item]
-
-    def __len__(self):
-        return len(self.rd)
-
-    def __checksum__(self):
-        cache = array("L", self.rd._xdr.offsets.tobytes())
-        return checksum(self.filename, str(cache))
+    return MDAnalysis.auxiliary.EDR.EDRReader(file)
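
For reference, a minimal usage sketch of the names touched by this commit (open_with_mdanalysis, the renamed nojump_matrices property, NojumpError). This is an illustration, not part of the commit: the import path mdevaluate.reader and the file names topol.tpr / traj.xtc are assumptions.

# Hypothetical usage sketch; the module path and file names below are
# placeholders, not taken from this commit.
from mdevaluate import reader

# Open topology and trajectory via MDAnalysis; per the new annotation this
# returns a tuple of (atoms.Atoms, BaseReader).
atms, rd = reader.open_with_mdanalysis("topol.tpr", "traj.xtc")

# The renamed nojump_matrices property raises NojumpError until the matrices
# have been generated (generate_nojump_matrices) or loaded from the .npz cache.
try:
    matrices = rd.nojump_matrices
except reader.NojumpError:
    matrices = None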