Added type hints, moved imports and classes around
This commit is contained in:
parent
8bad2861fa
commit
c00fc78f23
@ -2,13 +2,8 @@
|
||||
Module that provides different readers for trajectory files.
|
||||
|
||||
It also provides a common interface layer between the file IO packages,
|
||||
namely pygmx and mdanalysis, and mdevaluate.
|
||||
namely mdanalysis, and mdevaluate.
|
||||
"""
|
||||
from .checksum import checksum
|
||||
from .logging import logger
|
||||
from . import atoms
|
||||
|
||||
from functools import lru_cache
|
||||
from collections import namedtuple
|
||||
import os
|
||||
from os import path
|
||||
@ -19,9 +14,19 @@ import re
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
import MDAnalysis as mdanalysis
|
||||
import numpy.typing as npt
|
||||
import MDAnalysis
|
||||
from scipy import sparse
|
||||
|
||||
from .checksum import checksum
|
||||
from .logging import logger
|
||||
from . import atoms
|
||||
from .coordinates import Coordinates
|
||||
|
||||
CSR_ATTRS = ("data", "indices", "indptr")
|
||||
NOJUMP_MAGIC = 2016
|
||||
Group_RE = re.compile("\[ ([-+\w]+) \]")
|
||||
|
||||
|
||||
class NojumpError(Exception):
|
||||
pass
|
||||
@ -31,11 +36,49 @@ class WrongTopologyError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class BaseReader:
|
||||
"""Base class for trajectory readers."""
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self.rd.filename
|
||||
|
||||
@property
|
||||
def nojump_matrices(self):
|
||||
if self._nojump_matrices is None:
|
||||
raise NojumpError("Nojump Data not available: {}".format(self.filename))
|
||||
return self._nojump_matrices
|
||||
|
||||
@nojump_matrices.setter
|
||||
def nojump_matrices(self, mats):
|
||||
self._nojump_matrices = mats
|
||||
|
||||
def __init__(self, rd):
|
||||
self.rd = rd
|
||||
self._nojump_matrices = None
|
||||
if path.exists(nojump_load_filename(self)):
|
||||
load_nojump_matrices(self)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.rd[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.rd)
|
||||
|
||||
def __checksum__(self):
|
||||
cache = array("L", self.rd._xdr.offsets.tobytes())
|
||||
return checksum(self.filename, str(cache))
|
||||
|
||||
|
||||
def open_with_mdanalysis(
|
||||
topology, trajectory, index_file=None, charges=None, masses=None
|
||||
):
|
||||
topology: str,
|
||||
trajectory: str,
|
||||
index_file: str = None,
|
||||
charges: npt.ArrayLike = None,
|
||||
masses: npt.ArrayLike = None,
|
||||
) -> (atoms.Atoms, BaseReader):
|
||||
"""Open the topology and trajectory with mdanalysis."""
|
||||
uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
|
||||
uni = MDAnalysis.Universe(topology, trajectory, convert_units=False)
|
||||
reader = BaseReader(uni.trajectory)
|
||||
reader.universe = uni
|
||||
if topology.endswith(".tpr"):
|
||||
@ -60,15 +103,12 @@ def open_with_mdanalysis(
|
||||
return atms, reader
|
||||
|
||||
|
||||
group_re = re.compile("\[ ([-+\w]+) \]")
|
||||
|
||||
|
||||
def load_indices(index_file):
|
||||
def load_indices(index_file: str):
|
||||
indices = {}
|
||||
index_array = None
|
||||
with open(index_file) as idx_file:
|
||||
for line in idx_file:
|
||||
m = group_re.search(line)
|
||||
m = Group_RE.search(line)
|
||||
if m is not None:
|
||||
group_name = m.group(1)
|
||||
index_array = indices.get(group_name, [])
|
||||
@ -82,7 +122,7 @@ def load_indices(index_file):
|
||||
return indices
|
||||
|
||||
|
||||
def is_writeable(fname):
|
||||
def is_writeable(fname: str):
|
||||
"""Test if a directory is actually writeable, by writing a temporary file."""
|
||||
fdir = os.path.dirname(fname)
|
||||
ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
|
||||
@ -100,7 +140,7 @@ def is_writeable(fname):
|
||||
return False
|
||||
|
||||
|
||||
def nojump_load_filename(reader):
|
||||
def nojump_load_filename(reader: BaseReader):
|
||||
directory, fname = path.split(reader.filename)
|
||||
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
||||
if not is_writeable(directory):
|
||||
@ -116,7 +156,7 @@ def nojump_load_filename(reader):
|
||||
return full_path
|
||||
else:
|
||||
user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
|
||||
full_path_fallback = os.path.join(
|
||||
full_path = os.path.join(
|
||||
os.path.join(user_data_dir, ".mdevaluate/nojump"),
|
||||
directory.lstrip("/"),
|
||||
".{}.nojump.npz".format(fname),
|
||||
@ -124,7 +164,7 @@ def nojump_load_filename(reader):
|
||||
return full_path
|
||||
|
||||
|
||||
def nojump_save_filename(reader):
|
||||
def nojump_save_filename(reader: BaseReader):
|
||||
directory, fname = path.split(reader.filename)
|
||||
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
||||
if is_writeable(directory):
|
||||
@ -145,11 +185,7 @@ def nojump_save_filename(reader):
|
||||
return full_path_fallback
|
||||
|
||||
|
||||
CSR_ATTRS = ("data", "indices", "indptr")
|
||||
NOJUMP_MAGIC = 2016
|
||||
|
||||
|
||||
def parse_jumps(trajectory):
|
||||
def parse_jumps(trajectory: Coordinates):
|
||||
prev = trajectory[0].whole
|
||||
box = prev.box.diagonal()
|
||||
SparseData = namedtuple("SparseData", ["data", "row", "col"])
|
||||
@ -173,28 +209,28 @@ def parse_jumps(trajectory):
|
||||
return jump_data
|
||||
|
||||
|
||||
def generate_nojump_matrixes(trajectory):
|
||||
def generate_nojump_matrices(trajectory: Coordinates):
|
||||
"""
|
||||
Create the matrixes with pbc jumps for a trajectory.
|
||||
Create the matrices with pbc jumps for a trajectory.
|
||||
"""
|
||||
logger.info("generate Nojump Matrixes for: {}".format(trajectory))
|
||||
logger.info("generate Nojump matrices for: {}".format(trajectory))
|
||||
|
||||
jump_data = parse_jumps(trajectory)
|
||||
N = len(trajectory)
|
||||
M = len(trajectory[0])
|
||||
|
||||
trajectory.frames.nojump_matrixes = tuple(
|
||||
trajectory.frames.nojump_matrices = tuple(
|
||||
sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
|
||||
for m in jump_data
|
||||
)
|
||||
save_nojump_matrixes(trajectory.frames)
|
||||
save_nojump_matrices(trajectory.frames)
|
||||
|
||||
|
||||
def save_nojump_matrixes(reader, matrixes=None):
|
||||
if matrixes is None:
|
||||
matrixes = reader.nojump_matrixes
|
||||
def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None):
|
||||
if matrices is None:
|
||||
matrices = reader.nojump_matrices
|
||||
data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
|
||||
for d, mat in enumerate(matrixes):
|
||||
for d, mat in enumerate(matrices):
|
||||
data["shape"] = mat.shape
|
||||
for attr in CSR_ATTRS:
|
||||
data["{}_{}".format(attr, d)] = getattr(mat, attr)
|
||||
@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None):
|
||||
np.savez(nojump_save_filename(reader), **data)
|
||||
|
||||
|
||||
def load_nojump_matrixes(reader):
|
||||
def load_nojump_matrices(reader: BaseReader):
|
||||
zipname = nojump_load_filename(reader)
|
||||
try:
|
||||
data = np.load(zipname, allow_pickle=True)
|
||||
except (AttributeError, BadZipFile, OSError):
|
||||
# npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
|
||||
# npz-files can be corrupted, probably a bug for big arrays saved with
|
||||
# savez_compressed?
|
||||
logger.info("Removing zip-File: %s", zipname)
|
||||
os.remove(nojump_load_filename(reader))
|
||||
return
|
||||
try:
|
||||
if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
|
||||
reader.nojump_matrixes = tuple(
|
||||
reader.nojump_matrices = tuple(
|
||||
sparse.csr_matrix(
|
||||
tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
|
||||
shape=data["shape"],
|
||||
@ -221,7 +258,7 @@ def load_nojump_matrixes(reader):
|
||||
for d in range(3)
|
||||
)
|
||||
logger.info(
|
||||
"Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader))
|
||||
"Loaded Nojump matrices: {}".format(nojump_load_filename(reader))
|
||||
)
|
||||
else:
|
||||
logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
|
||||
@ -231,53 +268,19 @@ def load_nojump_matrixes(reader):
|
||||
return
|
||||
|
||||
|
||||
def correct_nojump_matrixes_for_whole(trajectory):
|
||||
def correct_nojump_matrices_for_whole(trajectory: Coordinates):
|
||||
reader = trajectory.frames
|
||||
frame = trajectory[0]
|
||||
box = frame.box.diagonal()
|
||||
cor = ((frame - frame.whole) / box).round().astype(np.int8)
|
||||
for d in range(3):
|
||||
reader.nojump_matrixes[d][0] = cor[:, d]
|
||||
save_nojump_matrixes(reader)
|
||||
reader.nojump_matrices[d][0] = cor[:, d]
|
||||
save_nojump_matrices(reader)
|
||||
|
||||
|
||||
def energy_reader(file):
|
||||
def energy_reader(file: str):
|
||||
"""Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
|
||||
Args:
|
||||
file: Filename of the energy file
|
||||
"""
|
||||
return mdanalysis.auxiliary.EDR.EDRReader(file)
|
||||
|
||||
|
||||
class BaseReader:
|
||||
"""Base class for trajectory readers."""
|
||||
|
||||
@property
|
||||
def filename(self):
|
||||
return self.rd.filename
|
||||
|
||||
@property
|
||||
def nojump_matrixes(self):
|
||||
if self._nojump_matrixes is None:
|
||||
raise NojumpError("Nojump Data not available: {}".format(self.filename))
|
||||
return self._nojump_matrixes
|
||||
|
||||
@nojump_matrixes.setter
|
||||
def nojump_matrixes(self, mats):
|
||||
self._nojump_matrixes = mats
|
||||
|
||||
def __init__(self, rd):
|
||||
self.rd = rd
|
||||
self._nojump_matrixes = None
|
||||
if path.exists(nojump_load_filename(self)):
|
||||
load_nojump_matrixes(self)
|
||||
|
||||
def __getitem__(self, item):
|
||||
return self.rd[item]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.rd)
|
||||
|
||||
def __checksum__(self):
|
||||
cache = array("L", self.rd._xdr.offsets.tobytes())
|
||||
return checksum(self.filename, str(cache))
|
||||
return MDAnalysis.auxiliary.EDR.EDRReader(file)
|
||||
|
Loading…
Reference in New Issue
Block a user