Added type hints, moved imports and classes around
This commit is contained in:
parent
8bad2861fa
commit
c00fc78f23
@ -2,13 +2,8 @@
|
|||||||
Module that provides different readers for trajectory files.
|
Module that provides different readers for trajectory files.
|
||||||
|
|
||||||
It also provides a common interface layer between the file IO packages,
|
It also provides a common interface layer between the file IO packages,
|
||||||
namely pygmx and mdanalysis, and mdevaluate.
|
namely mdanalysis, and mdevaluate.
|
||||||
"""
|
"""
|
||||||
from .checksum import checksum
|
|
||||||
from .logging import logger
|
|
||||||
from . import atoms
|
|
||||||
|
|
||||||
from functools import lru_cache
|
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
import os
|
import os
|
||||||
from os import path
|
from os import path
|
||||||
@ -19,9 +14,19 @@ import re
|
|||||||
import itertools
|
import itertools
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import MDAnalysis as mdanalysis
|
import numpy.typing as npt
|
||||||
|
import MDAnalysis
|
||||||
from scipy import sparse
|
from scipy import sparse
|
||||||
|
|
||||||
|
from .checksum import checksum
|
||||||
|
from .logging import logger
|
||||||
|
from . import atoms
|
||||||
|
from .coordinates import Coordinates
|
||||||
|
|
||||||
|
CSR_ATTRS = ("data", "indices", "indptr")
|
||||||
|
NOJUMP_MAGIC = 2016
|
||||||
|
Group_RE = re.compile("\[ ([-+\w]+) \]")
|
||||||
|
|
||||||
|
|
||||||
class NojumpError(Exception):
|
class NojumpError(Exception):
|
||||||
pass
|
pass
|
||||||
@ -31,11 +36,49 @@ class WrongTopologyError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BaseReader:
|
||||||
|
"""Base class for trajectory readers."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filename(self):
|
||||||
|
return self.rd.filename
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nojump_matrices(self):
|
||||||
|
if self._nojump_matrices is None:
|
||||||
|
raise NojumpError("Nojump Data not available: {}".format(self.filename))
|
||||||
|
return self._nojump_matrices
|
||||||
|
|
||||||
|
@nojump_matrices.setter
|
||||||
|
def nojump_matrices(self, mats):
|
||||||
|
self._nojump_matrices = mats
|
||||||
|
|
||||||
|
def __init__(self, rd):
|
||||||
|
self.rd = rd
|
||||||
|
self._nojump_matrices = None
|
||||||
|
if path.exists(nojump_load_filename(self)):
|
||||||
|
load_nojump_matrices(self)
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.rd[item]
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.rd)
|
||||||
|
|
||||||
|
def __checksum__(self):
|
||||||
|
cache = array("L", self.rd._xdr.offsets.tobytes())
|
||||||
|
return checksum(self.filename, str(cache))
|
||||||
|
|
||||||
|
|
||||||
def open_with_mdanalysis(
|
def open_with_mdanalysis(
|
||||||
topology, trajectory, index_file=None, charges=None, masses=None
|
topology: str,
|
||||||
):
|
trajectory: str,
|
||||||
|
index_file: str = None,
|
||||||
|
charges: npt.ArrayLike = None,
|
||||||
|
masses: npt.ArrayLike = None,
|
||||||
|
) -> (atoms.Atoms, BaseReader):
|
||||||
"""Open the topology and trajectory with mdanalysis."""
|
"""Open the topology and trajectory with mdanalysis."""
|
||||||
uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
|
uni = MDAnalysis.Universe(topology, trajectory, convert_units=False)
|
||||||
reader = BaseReader(uni.trajectory)
|
reader = BaseReader(uni.trajectory)
|
||||||
reader.universe = uni
|
reader.universe = uni
|
||||||
if topology.endswith(".tpr"):
|
if topology.endswith(".tpr"):
|
||||||
@ -60,15 +103,12 @@ def open_with_mdanalysis(
|
|||||||
return atms, reader
|
return atms, reader
|
||||||
|
|
||||||
|
|
||||||
group_re = re.compile("\[ ([-+\w]+) \]")
|
def load_indices(index_file: str):
|
||||||
|
|
||||||
|
|
||||||
def load_indices(index_file):
|
|
||||||
indices = {}
|
indices = {}
|
||||||
index_array = None
|
index_array = None
|
||||||
with open(index_file) as idx_file:
|
with open(index_file) as idx_file:
|
||||||
for line in idx_file:
|
for line in idx_file:
|
||||||
m = group_re.search(line)
|
m = Group_RE.search(line)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
group_name = m.group(1)
|
group_name = m.group(1)
|
||||||
index_array = indices.get(group_name, [])
|
index_array = indices.get(group_name, [])
|
||||||
@ -82,7 +122,7 @@ def load_indices(index_file):
|
|||||||
return indices
|
return indices
|
||||||
|
|
||||||
|
|
||||||
def is_writeable(fname):
|
def is_writeable(fname: str):
|
||||||
"""Test if a directory is actually writeable, by writing a temporary file."""
|
"""Test if a directory is actually writeable, by writing a temporary file."""
|
||||||
fdir = os.path.dirname(fname)
|
fdir = os.path.dirname(fname)
|
||||||
ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
|
ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
|
||||||
@ -100,7 +140,7 @@ def is_writeable(fname):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def nojump_load_filename(reader):
|
def nojump_load_filename(reader: BaseReader):
|
||||||
directory, fname = path.split(reader.filename)
|
directory, fname = path.split(reader.filename)
|
||||||
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
||||||
if not is_writeable(directory):
|
if not is_writeable(directory):
|
||||||
@ -116,7 +156,7 @@ def nojump_load_filename(reader):
|
|||||||
return full_path
|
return full_path
|
||||||
else:
|
else:
|
||||||
user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
|
user_data_dir = os.path.join("/data/", os.environ["HOME"].split("/")[-1])
|
||||||
full_path_fallback = os.path.join(
|
full_path = os.path.join(
|
||||||
os.path.join(user_data_dir, ".mdevaluate/nojump"),
|
os.path.join(user_data_dir, ".mdevaluate/nojump"),
|
||||||
directory.lstrip("/"),
|
directory.lstrip("/"),
|
||||||
".{}.nojump.npz".format(fname),
|
".{}.nojump.npz".format(fname),
|
||||||
@ -124,7 +164,7 @@ def nojump_load_filename(reader):
|
|||||||
return full_path
|
return full_path
|
||||||
|
|
||||||
|
|
||||||
def nojump_save_filename(reader):
|
def nojump_save_filename(reader: BaseReader):
|
||||||
directory, fname = path.split(reader.filename)
|
directory, fname = path.split(reader.filename)
|
||||||
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
full_path = path.join(directory, ".{}.nojump.npz".format(fname))
|
||||||
if is_writeable(directory):
|
if is_writeable(directory):
|
||||||
@ -145,11 +185,7 @@ def nojump_save_filename(reader):
|
|||||||
return full_path_fallback
|
return full_path_fallback
|
||||||
|
|
||||||
|
|
||||||
CSR_ATTRS = ("data", "indices", "indptr")
|
def parse_jumps(trajectory: Coordinates):
|
||||||
NOJUMP_MAGIC = 2016
|
|
||||||
|
|
||||||
|
|
||||||
def parse_jumps(trajectory):
|
|
||||||
prev = trajectory[0].whole
|
prev = trajectory[0].whole
|
||||||
box = prev.box.diagonal()
|
box = prev.box.diagonal()
|
||||||
SparseData = namedtuple("SparseData", ["data", "row", "col"])
|
SparseData = namedtuple("SparseData", ["data", "row", "col"])
|
||||||
@ -173,28 +209,28 @@ def parse_jumps(trajectory):
|
|||||||
return jump_data
|
return jump_data
|
||||||
|
|
||||||
|
|
||||||
def generate_nojump_matrixes(trajectory):
|
def generate_nojump_matrices(trajectory: Coordinates):
|
||||||
"""
|
"""
|
||||||
Create the matrixes with pbc jumps for a trajectory.
|
Create the matrices with pbc jumps for a trajectory.
|
||||||
"""
|
"""
|
||||||
logger.info("generate Nojump Matrixes for: {}".format(trajectory))
|
logger.info("generate Nojump matrices for: {}".format(trajectory))
|
||||||
|
|
||||||
jump_data = parse_jumps(trajectory)
|
jump_data = parse_jumps(trajectory)
|
||||||
N = len(trajectory)
|
N = len(trajectory)
|
||||||
M = len(trajectory[0])
|
M = len(trajectory[0])
|
||||||
|
|
||||||
trajectory.frames.nojump_matrixes = tuple(
|
trajectory.frames.nojump_matrices = tuple(
|
||||||
sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
|
sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
|
||||||
for m in jump_data
|
for m in jump_data
|
||||||
)
|
)
|
||||||
save_nojump_matrixes(trajectory.frames)
|
save_nojump_matrices(trajectory.frames)
|
||||||
|
|
||||||
|
|
||||||
def save_nojump_matrixes(reader, matrixes=None):
|
def save_nojump_matrices(reader: BaseReader, matrices: npt.ArrayLike = None):
|
||||||
if matrixes is None:
|
if matrices is None:
|
||||||
matrixes = reader.nojump_matrixes
|
matrices = reader.nojump_matrices
|
||||||
data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
|
data = {"checksum": checksum(NOJUMP_MAGIC, checksum(reader))}
|
||||||
for d, mat in enumerate(matrixes):
|
for d, mat in enumerate(matrices):
|
||||||
data["shape"] = mat.shape
|
data["shape"] = mat.shape
|
||||||
for attr in CSR_ATTRS:
|
for attr in CSR_ATTRS:
|
||||||
data["{}_{}".format(attr, d)] = getattr(mat, attr)
|
data["{}_{}".format(attr, d)] = getattr(mat, attr)
|
||||||
@ -202,18 +238,19 @@ def save_nojump_matrixes(reader, matrixes=None):
|
|||||||
np.savez(nojump_save_filename(reader), **data)
|
np.savez(nojump_save_filename(reader), **data)
|
||||||
|
|
||||||
|
|
||||||
def load_nojump_matrixes(reader):
|
def load_nojump_matrices(reader: BaseReader):
|
||||||
zipname = nojump_load_filename(reader)
|
zipname = nojump_load_filename(reader)
|
||||||
try:
|
try:
|
||||||
data = np.load(zipname, allow_pickle=True)
|
data = np.load(zipname, allow_pickle=True)
|
||||||
except (AttributeError, BadZipFile, OSError):
|
except (AttributeError, BadZipFile, OSError):
|
||||||
# npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
|
# npz-files can be corrupted, probably a bug for big arrays saved with
|
||||||
|
# savez_compressed?
|
||||||
logger.info("Removing zip-File: %s", zipname)
|
logger.info("Removing zip-File: %s", zipname)
|
||||||
os.remove(nojump_load_filename(reader))
|
os.remove(nojump_load_filename(reader))
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
|
if data["checksum"] == checksum(NOJUMP_MAGIC, checksum(reader)):
|
||||||
reader.nojump_matrixes = tuple(
|
reader.nojump_matrices = tuple(
|
||||||
sparse.csr_matrix(
|
sparse.csr_matrix(
|
||||||
tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
|
tuple(data["{}_{}".format(attr, d)] for attr in CSR_ATTRS),
|
||||||
shape=data["shape"],
|
shape=data["shape"],
|
||||||
@ -221,7 +258,7 @@ def load_nojump_matrixes(reader):
|
|||||||
for d in range(3)
|
for d in range(3)
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Loaded Nojump Matrixes: {}".format(nojump_load_filename(reader))
|
"Loaded Nojump matrices: {}".format(nojump_load_filename(reader))
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
|
logger.info("Invlaid Nojump Data: {}".format(nojump_load_filename(reader)))
|
||||||
@ -231,53 +268,19 @@ def load_nojump_matrixes(reader):
|
|||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
def correct_nojump_matrixes_for_whole(trajectory):
|
def correct_nojump_matrices_for_whole(trajectory: Coordinates):
|
||||||
reader = trajectory.frames
|
reader = trajectory.frames
|
||||||
frame = trajectory[0]
|
frame = trajectory[0]
|
||||||
box = frame.box.diagonal()
|
box = frame.box.diagonal()
|
||||||
cor = ((frame - frame.whole) / box).round().astype(np.int8)
|
cor = ((frame - frame.whole) / box).round().astype(np.int8)
|
||||||
for d in range(3):
|
for d in range(3):
|
||||||
reader.nojump_matrixes[d][0] = cor[:, d]
|
reader.nojump_matrices[d][0] = cor[:, d]
|
||||||
save_nojump_matrixes(reader)
|
save_nojump_matrices(reader)
|
||||||
|
|
||||||
|
|
||||||
def energy_reader(file):
|
def energy_reader(file: str):
|
||||||
"""Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
|
"""Reads a gromacs energy file with mdanalysis and returns an auxiliary file.
|
||||||
Args:
|
Args:
|
||||||
file: Filename of the energy file
|
file: Filename of the energy file
|
||||||
"""
|
"""
|
||||||
return mdanalysis.auxiliary.EDR.EDRReader(file)
|
return MDAnalysis.auxiliary.EDR.EDRReader(file)
|
||||||
|
|
||||||
|
|
||||||
class BaseReader:
|
|
||||||
"""Base class for trajectory readers."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def filename(self):
|
|
||||||
return self.rd.filename
|
|
||||||
|
|
||||||
@property
|
|
||||||
def nojump_matrixes(self):
|
|
||||||
if self._nojump_matrixes is None:
|
|
||||||
raise NojumpError("Nojump Data not available: {}".format(self.filename))
|
|
||||||
return self._nojump_matrixes
|
|
||||||
|
|
||||||
@nojump_matrixes.setter
|
|
||||||
def nojump_matrixes(self, mats):
|
|
||||||
self._nojump_matrixes = mats
|
|
||||||
|
|
||||||
def __init__(self, rd):
|
|
||||||
self.rd = rd
|
|
||||||
self._nojump_matrixes = None
|
|
||||||
if path.exists(nojump_load_filename(self)):
|
|
||||||
load_nojump_matrixes(self)
|
|
||||||
|
|
||||||
def __getitem__(self, item):
|
|
||||||
return self.rd[item]
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
return len(self.rd)
|
|
||||||
|
|
||||||
def __checksum__(self):
|
|
||||||
cache = array("L", self.rd._xdr.offsets.tobytes())
|
|
||||||
return checksum(self.filename, str(cache))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user