Added type hints, moved imports and classes around

This commit is contained in:
Sebastian Kloth 2023-12-26 11:50:16 +01:00
parent 8bad2861fa
commit c00fc78f23

View File

@@ -2,13 +2,8 @@
Module that provides different readers for trajectory files.
It also provides a common interface layer between the file IO packages,
namely pygmx and mdanalysis, and mdevaluate.
namely mdanalysis, and mdevaluate.
"""
from .checksum import checksum
from .logging import logger
from . import atoms
from functools import lru_cache
from collections import namedtuple
import os
from os import path
@@ -19,9 +14,19 @@ import re
import itertools
import numpy as np
import MDAnalysis as mdanalysis
import numpy.typing as npt
import MDAnalysis
from scipy import sparse
from .checksum import checksum
from .logging import logger
from . import atoms
from .coordinates import Coordinates
CSR_ATTRS = ("data", "indices", "indptr")
NOJUMP_MAGIC = 2016
Group_RE = re.compile("\[ ([-+\w]+) \]")
class NojumpError(Exception):
pass
@@ -31,11 +36,49 @@ class WrongTopologyError(Exception):
pass
class BaseReader:
"""Base class for trajectory readers."""
@property
def filename(self):
return self.rd.filename
@property
def nojump_matrices(self):
if self._nojump_matrices is None:
raise NojumpError("Nojump Data not available: {}".format(self.filename))
return self._nojump_matrices
@nojump_matrices.setter
def nojump_matrices(self, mats):
self._nojump_matrices = mats
def __init__(self, rd):
self.rd = rd
self._nojump_matrices = None
if path.exists(nojump_load_filename(self)):
load_nojump_matrices(self)
def __getitem__(self, item):
return self.rd[item]
def __len__(self):
return len(self.rd)
def __checksum__(self):
cache = array("L", self.rd._xdr.offsets.tobytes())
return checksum(self.filename, str(cache))
def open_with_mdanalysis(
topology, trajectory, index_file=None, charges=None, masses=None
):
topology: str,
trajectory: str,
index_file: str = None,
charges: npt.ArrayLike = None,
masses: npt.ArrayLike = None,
) -> (atoms.Atoms, BaseReader):
"""Open the topology and trajectory with mdanalysis."""
uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
uni = MDAnalysis.Universe(topology, trajectory, convert_units=False)
reader = BaseReader(uni.trajectory)
reader.universe = uni
if topology.endswith(".tpr"):
@@ -60,15 +103,12 @@ def open_with_mdanalysis(
return atms, reader
group_re = re.compile("\[ ([-+\w]+) \]")
def load_indices(index_file):
def load_indices(index_file: str):
indices = {}
index_array = None
with open(index_file) as idx_file:
for line in idx_file:
m = group_re.search(line)
m = Group_RE.search(line)
if m is not None:
group_name = m.group(1)
index_array = indices.get(group_name, [])
@@ -82,7 +122,7 @@ def load_indices(index_file):
return indices
def is_writeable(fname):
def is_writeable(fname: str):
"""Test if a directory is actually writeable, by writing a temporary file."""
fdir = os.path.dirname(fname)
ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
@@ -100,7 +140,7 @@ def is_writeable(fname):
return False
def nojump_load_filename(reader):
def nojump_load_filename(reader: BaseReader):
directory, fname = path.split(reader.filename)
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
if not is_writeable(directory):
@@ -116,7 +156,7 @@ def nojump_load_filename(reader):
return full_path
else:
user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
full_path_fallback = os.path.join(
full_path = os.path.join(
os.path.join(user_data_dir, ".mdevaluate/nojump"),
directory.lstrip("/"),
".{}.nojump.npz".format(fname),
@@ -124,7 +164,7 @@ def nojump_load_filename(reader):
return full_path
def nojump_save_filename(reader):
def nojump_save_filename(reader: BaseReader):
directory, fname = path.split(reader.filename)
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
if is_writeable(directory):
@@ -145,11 +185,7 @@ def nojump_save_filename(reader):
return full_path_fallback
CSR_ATTRS = ("data", "indices", "indptr")
NOJUMP_MAGIC = 2016
def parse_jumps(trajectory):
def parse_jumps(trajectory: Coordinates):
prev = trajectory[0].whole
box = prev.box.diagonal()
SparseData = namedtuple("SparseData", ["data", "row", "col"])
@@ -173,28 +209,28 @@ def parse_jumps(trajectory):
return jump_data
def generate_nojump_matrixes(trajectory):
def generate_nojump_matrices(trajectory: Coordinates):
"""
Create the matrixes with pbc jumps for a trajectory.
Create the matrices with pbc jumps for a trajectory.
"""
logger.info("generate Nojump Matrixes for: {}".format(trajectory))
logger.info("generate Nojump matrices for: {}".format(trajectory))
jump_data = parse_jumps(trajectory)
N = len(trajectory)
M = len(trajectory[0])
trajectory.frames.nojump_matrixes = tuple(
trajectory.frames.nojump_matrices = tuple(
sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
for m in jump_data
)
save_nojump_matrixes(trajectory.frames)
save_nojump_matrices(trajectory.frames)
def save_nojump_matrixes(reader, matrixes=None):
if matrixes is None:
matrixes = reader.nojump_matrixes
def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None):
if matrices is None:
matrices = reader.nojump_matrices
data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
for d, mat in enumerate(matrixes):
for d, mat in enumerate(matrices):
data["shape"] = mat.shape
for attr in CSR_ATTRS:
data["{}_{}".format(attr, d)] = getattr(mat, attr)
@@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None):
np.savez(nojump_save_filename(reader), **data)
def load_nojump_matrixes(reader):
def load_nojump_matrices(reader: BaseReader):
zipname = nojump_load_filename(reader)
try:
data = np.load(zipname, allow_pickle=True)
except (AttributeError, BadZipFile, OSError):
# npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
# npz-files can be corrupted, probably a bug for big arrays saved with
# savez_compressed?
logger.info("Removing zip-File: %s", zipname)
os.remove(nojump_load_filename(reader))
return
try:
if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
reader.nojump_matrixes = tuple(
reader.nojump_matrices = tuple(
sparse.csr_matrix(
tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
shape=data["shape"],
@@ -221,7 +258,7 @@ def load_nojump_matrixes(reader):
for d in range(3)
)
logger.info(
"Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader))
"Loaded Nojump matrices: {}".format(nojump_load_filename(reader))
)
else:
logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
@@ -231,53 +268,19 @@ def load_nojump_matrixes(reader):
return
def correct_nojump_matrixes_for_whole(trajectory):
def correct_nojump_matrices_for_whole(trajectory: Coordinates):
reader = trajectory.frames
frame = trajectory[0]
box = frame.box.diagonal()
cor = ((frame - frame.whole) / box).round().astype(np.int8)
for d in range(3):
reader.nojump_matrixes[d][0] = cor[:, d]
save_nojump_matrixes(reader)
reader.nojump_matrices[d][0] = cor[:, d]
save_nojump_matrices(reader)
def energy_reader(file):
def energy_reader(file: str):
"""Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
Args:
file: Filename of the energy file
"""
return mdanalysis.auxiliary.EDR.EDRReader(file)
class BaseReader:
"""Base class for trajectory readers."""
@property
def filename(self):
return self.rd.filename
@property
def nojump_matrixes(self):
if self._nojump_matrixes is None:
raise NojumpError("Nojump Data not available: {}".format(self.filename))
return self._nojump_matrixes
@nojump_matrixes.setter
def nojump_matrixes(self, mats):
self._nojump_matrixes = mats
def __init__(self, rd):
self.rd = rd
self._nojump_matrixes = None
if path.exists(nojump_load_filename(self)):
load_nojump_matrixes(self)
def __getitem__(self, item):
return self.rd[item]
def __len__(self):
return len(self.rd)
def __checksum__(self):
cache = array("L", self.rd._xdr.offsets.tobytes())
return checksum(self.filename, str(cache))
return MDAnalysis.auxiliary.EDR.EDRReader(file)