"""
Module that provides different readers for trajectory files.
It also provides a common interface layer between the file IO packages,
namely pygmx and mdanalysis, and mdevaluate.
"""
import builtins
import itertools
import os
import re
import subprocess
import warnings
from array import array
from collections import namedtuple
from functools import lru_cache
from os import path
from zipfile import BadZipFile

import numpy as np
import pandas as pd
import MDAnalysis as mdanalysis
from dask import delayed, __version__ as DASK_VERSION
from scipy import sparse

from . import atoms
from .checksum import checksum
from .logging import logger


class NojumpError(Exception):
    pass


class WrongTopologyError(Exception):
    pass


def open_with_mdanalysis(topology, trajectory, cached=False, index_file=None,
                         charges=None, masses=None):
    """Open the topology and trajectory with mdanalysis."""
    uni = mdanalysis.Universe(topology, trajectory, convert_units=False)
    if cached is not False:
        if cached is True:
            maxsize = 128
        else:
            maxsize = cached
        reader = CachedReader(uni.trajectory, maxsize)
    else:
        reader = BaseReader(uni.trajectory)
    reader.universe = uni
    if topology.endswith('.tpr'):
        charges = uni.atoms.charges
        masses = uni.atoms.masses
    elif topology.endswith('.gro'):
        # .gro files carry no charges or masses; keep the values passed in.
        pass
    else:
        raise WrongTopologyError('Topology file should end with ".tpr" or ".gro"')
    indices = None
    if index_file:
        indices = load_indices(index_file)
    atms = atoms.Atoms(
        np.stack((uni.atoms.resids, uni.atoms.resnames, uni.atoms.names), axis=1),
        charges=charges, masses=masses, indices=indices
    ).subset()
    return atms, reader
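
# A minimal usage sketch (the file names are hypothetical, not part of this module):
#
#     atms, reader = open_with_mdanalysis('topol.tpr', 'traj.xtc', cached=True)
#     first_frame = reader[0]
#
# With a ".gro" topology, charges and masses have to be supplied explicitly,
# since they cannot be read from the file.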

group_re = re.compile(r'\[ ([-+\w]+) \]')


def load_indices(index_file):
    indices = {}
    index_array = None
    with open(index_file) as idx_file:
        for line in idx_file:
            m = group_re.search(line)
            if m is not None:
                group_name = m.group(1)
                index_array = indices.get(group_name, [])
                indices[group_name] = index_array
            else:
                elements = line.strip().split('\t')
                elements = [x.split(' ') for x in elements]
                elements = itertools.chain(*elements)  # make a flat iterator
                elements = [x for x in elements if x != '']
                index_array += [int(x) - 1 for x in elements]
    return indices
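
# The parser expects the GROMACS .ndx index format, e.g. (hypothetical content):
#
#     [ OW ]
#     1 4 7 10
#     [ HW ]
#     2 3 5 6
#
# Group names are taken from the "[ ... ]" headers and the 1-based atom numbers
# below them are converted to 0-based indices.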


def is_writeable(fname):
    """Test if a directory is actually writeable, by writing a temporary file."""
    fdir = os.path.dirname(fname)
    ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
    while os.path.exists(ftmp):
        ftmp = os.path.join(fdir, str(np.random.randint(999999999)))
    if os.access(fdir, os.W_OK):
        try:
            with builtins.open(ftmp, 'w'):
                pass
            os.remove(ftmp)
            return True
        except PermissionError:
            pass
    return False


def nojump_load_filename(reader):
    """Return the filename from which nojump matrices of a reader are loaded."""
    directory, fname = path.split(reader.filename)
    full_path = path.join(directory, '.{}.nojump.npz'.format(fname))
    if not is_writeable(directory):
        user_data_dir = os.path.join("/data/",
                                     os.environ['HOME'].split("/")[-1])
        full_path_fallback = os.path.join(
            os.path.join(user_data_dir, '.mdevaluate/nojump'),
            directory.lstrip('/'),
            '.{}.nojump.npz'.format(fname)
        )
        if os.path.exists(full_path_fallback):
            return full_path_fallback
    # Prefer the nojump file next to the trajectory if it exists or can be written.
    if os.path.exists(full_path) or is_writeable(directory):
        return full_path
    else:
        user_data_dir = os.path.join("/data/",
                                     os.environ['HOME'].split("/")[-1])
        full_path_fallback = os.path.join(
            os.path.join(user_data_dir, '.mdevaluate/nojump'),
            directory.lstrip('/'),
            '.{}.nojump.npz'.format(fname)
        )
        return full_path_fallback


def nojump_save_filename(reader):
    """Return the filename to which nojump matrices of a reader are saved."""
    directory, fname = path.split(reader.filename)
    full_path = path.join(directory, '.{}.nojump.npz'.format(fname))
    if is_writeable(directory):
        return full_path
    else:
        user_data_dir = os.path.join("/data/",
                                     os.environ['HOME'].split("/")[-1])
        full_path_fallback = os.path.join(
            os.path.join(user_data_dir, '.mdevaluate/nojump'),
            directory.lstrip('/'),
            '.{}.nojump.npz'.format(fname)
        )
        logger.info('Saving nojump to {}, since original location is not writeable.'
                    .format(full_path_fallback))
        os.makedirs(os.path.dirname(full_path_fallback), exist_ok=True)
        return full_path_fallback
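
# If the trajectory directory is not writeable, the nojump data is stored in a
# per-user fallback tree that mirrors the original path, e.g. (hypothetical):
#
#     /data/<user>/.mdevaluate/nojump/<original/trajectory/dir>/.traj.xtc.nojump.npz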


CSR_ATTRS = ('data', 'indices', 'indptr')
NOJUMP_MAGIC = 2016


def parse_jumps(trajectory):
    prev = trajectory[0].whole
    box = prev.box.diagonal()
    SparseData = namedtuple('SparseData', ['data', 'row', 'col'])
    jump_data = (
        SparseData(data=array('b'), row=array('l'), col=array('l')),
        SparseData(data=array('b'), row=array('l'), col=array('l')),
        SparseData(data=array('b'), row=array('l'), col=array('l'))
    )
    for i, curr in enumerate(trajectory):
        if i % 500 == 0:
            logger.debug('Parse jumps Step: %d', i)
        delta = ((curr - prev) / box).round().astype(np.int8)
        prev = curr
        for d in range(3):
            col, = np.where(delta[:, d] != 0)
            jump_data[d].col.extend(col)
            jump_data[d].row.extend([i] * len(col))
            jump_data[d].data.extend(delta[col, d])
    return jump_data


def generate_nojump_matrixes(trajectory):
    """
    Create the matrices of periodic-boundary jumps for a trajectory.
    """
    logger.info('generate Nojump Matrixes for: {}'.format(trajectory))
    jump_data = parse_jumps(trajectory)
    N = len(trajectory)
    M = len(trajectory[0])
    trajectory.frames.nojump_matrixes = tuple(
        sparse.csr_matrix((np.array(m.data), (m.row, m.col)), shape=(N, M))
        for m in jump_data
    )
    save_nojump_matrixes(trajectory.frames)
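
# The result is one scipy CSR matrix per spatial dimension, each of shape
# (number of frames, number of atoms): entry (i, j) is the signed number of box
# lengths that atom j jumped between frame i-1 and frame i, so summing the rows
# up to a given frame recovers the total periodic-image offset of every atom.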


def save_nojump_matrixes(reader, matrixes=None):
    if matrixes is None:
        matrixes = reader.nojump_matrixes
    data = {'checksum': checksum(NOJUMP_MAGIC, checksum(reader))}
    for d, mat in enumerate(matrixes):
        data['shape'] = mat.shape
        for attr in CSR_ATTRS:
            data['{}_{}'.format(attr, d)] = getattr(mat, attr)
    np.savez(nojump_save_filename(reader), **data)


def load_nojump_matrixes(reader):
    zipname = nojump_load_filename(reader)
    try:
        data = np.load(zipname, allow_pickle=True)
    except (AttributeError, BadZipFile, OSError):
        # npz files can be corrupted, probably a bug for big arrays saved with savez_compressed?
        logger.info('Removing zip-File: %s', zipname)
        os.remove(nojump_load_filename(reader))
        return
    try:
        if data['checksum'] == checksum(NOJUMP_MAGIC, checksum(reader)):
            reader.nojump_matrixes = tuple(
                sparse.csr_matrix(
                    tuple(data['{}_{}'.format(attr, d)] for attr in CSR_ATTRS),
                    shape=data['shape']
                )
                for d in range(3)
            )
            logger.info('Loaded Nojump Matrixes: {}'.format(nojump_load_filename(reader)))
        else:
            logger.info('Invalid Nojump Data: {}'.format(nojump_load_filename(reader)))
    except KeyError:
        logger.info('Removing zip-File: %s', zipname)
        os.remove(nojump_load_filename(reader))
        return


def correct_nojump_matrixes_for_whole(trajectory):
    reader = trajectory.frames
    frame = trajectory[0]
    box = frame.box.diagonal()
    cor = ((frame - frame.whole) / box).round().astype(np.int8)
    for d in range(3):
        reader.nojump_matrixes[d][0] = cor[:, d]
    save_nojump_matrixes(reader)


def energy_reader(file, energies=None):
    """Read a GROMACS energy file and return the data as a pandas DataFrame.

    Args:
        file: Filename of the energy file
        energies (opt.): Specify energies to extract from the energy file
    """
    if energies is None:
        energies = np.arange(1, 100).astype('str')
    directory = file.rsplit("/", 1)[0]
    ps = subprocess.Popen(("echo", *energies), stdout=subprocess.PIPE)
    try:
        subprocess.run(("gmx", "energy", "-f", file, "-o",
                        f"{directory}/tmp.xvg", "-quiet", "-nobackup"),
                       stdin=ps.stdout)
    except FileNotFoundError:
        print("No GROMACS found!")
    ps.wait()
    labels = []
    is_legend = False
    with open(f"{directory}/tmp.xvg") as f:
        for i, line in enumerate(f):
            if line.split(" ")[0] == "@":
                if re.search(r"s\d+", line.split()[1]):
                    is_legend = True
                    labels.append(line.split('"')[1])
            elif is_legend:
                header = i
                break
    data = np.loadtxt(f"{directory}/tmp.xvg", skiprows=header)
    df = pd.DataFrame({"Time": data[:, 0]})
    for i, label in enumerate(labels):
        tmp_df = pd.DataFrame({label: data[:, i + 1]})
        df = pd.concat([df, tmp_df], axis=1)
    subprocess.run(("rm", f"{directory}/tmp.xvg"))
    return df
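
# A minimal usage sketch (file name and energy selection are hypothetical); this
# shells out to "gmx energy", so a working GROMACS installation is assumed:
#
#     df = energy_reader('md/energy.edr', energies=('1', '2'))
#     df['Time']  # time column, followed by one column per extracted energy term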


class BaseReader:
    """Base class for trajectory readers."""

    @property
    def filename(self):
        return self.rd.filename

    @property
    def nojump_matrixes(self):
        if self._nojump_matrixes is None:
            raise NojumpError('Nojump Data not available: {}'.format(self.filename))
        return self._nojump_matrixes

    @nojump_matrixes.setter
    def nojump_matrixes(self, mats):
        self._nojump_matrixes = mats

    def __init__(self, rd):
        """
        Args:
            rd: The underlying trajectory reader (e.g. an MDAnalysis trajectory).
        """
        self.rd = rd
        self._nojump_matrixes = None
        if path.exists(nojump_load_filename(self)):
            load_nojump_matrixes(self)

    def __getitem__(self, item):
        return self.rd[item]

    def __len__(self):
        return len(self.rd)

    def __checksum__(self):
        if hasattr(self.rd, 'cache'):
            # Has a pygmx reader
            return checksum(self.filename, str(self.rd.cache))
        elif hasattr(self.rd, '_xdr'):
            # Has an mdanalysis reader
            cache = array('L', self.rd._xdr.offsets.tobytes())
            return checksum(self.filename, str(cache))


class CachedReader(BaseReader):
    """A reader that has a least-recently-used cache for frames."""

    @property
    def cache_info(self):
        """Get information about the lru cache."""
        return self._get_item.cache_info()

    def clear_cache(self):
        """Clear the cache of the frames."""
        self._get_item.cache_clear()

    def __init__(self, rd, maxsize):
        """
        Args:
            rd: The underlying trajectory reader that will be wrapped.
            maxsize: Maximum size of the lru_cache or None for infinite cache.
        """
        super().__init__(rd)
        self._get_item = lru_cache(maxsize=maxsize)(self._get_item)

    def _get_item(self, item):
        """Buffer function for lru_cache, since __getitem__ can not be cached."""
        return super().__getitem__(item)

    def __getitem__(self, item):
        return self._get_item(item)
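
# Usage sketch (assumed, not part of the module): wrap a frame reader to
# memoize frame access, e.g.
#
#     reader = CachedReader(uni.trajectory, maxsize=64)
#     reader[10]          # first access reads the frame from disk
#     reader[10]          # second access is served from the lru cache
#     reader.cache_info   # hits, misses and current size of the cache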


class DelayedReader(BaseReader):

    @property
    def filename(self):
        if self.rd is not None:
            return self.rd.filename
        else:
            return self._filename

    def __init__(self, filename, reindex=False, ignore_index_timestamps=False):
        super().__init__(filename, reindex=False, ignore_index_timestamps=False)
        self.natoms = len(self.rd[0].coordinates)
        self.cache = self.rd.cache
        self._filename = self.rd.filename
        self.rd = None

    def __len__(self):
        return len(self.cache)

    def _get_item(self, frame):
        return read_xtcframe_delayed(self.filename, self.cache[frame], self.natoms)

    def __getitem__(self, frame):
        return self._get_item(frame)