First draft of a python wrapper for gromacs library
This commit is contained in:
495
pygmx/reader.py
Normal file
495
pygmx/reader.py
Normal file
@ -0,0 +1,495 @@
|
||||
import xdrlib
|
||||
import io
|
||||
import os
|
||||
# from .coordinates import decompress
|
||||
from ..utils import hash_anything, merge_hashes
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
|
||||
TRR_MAGIC = 1993
|
||||
XTC_MAGIC = 1995
|
||||
INDEX_MAGIC = 2015
|
||||
|
||||
def index_filename_for_xtc(xtcfile):
|
||||
"""
|
||||
Get the filename for the index file of a xtc file.
|
||||
|
||||
In general, the index file is located in the same directory as the xtc file.
|
||||
If the xtc file is located in a read-only directory (for the current user)
|
||||
and does not exist, the index file will be instead located in a subdirectory
|
||||
of ~/.xtcindex, of the current user.
|
||||
|
||||
The directory for user index files can be changed by setting the environment
|
||||
variable 'XTCINDEXDIR'.
|
||||
|
||||
Example:
|
||||
xtcfile = '/data/niels/test.xtc'
|
||||
|
||||
# case 1: index file exists, or current user is niels
|
||||
index_filename_for_xtc(xtcfile) == '/data/niels/.test.xtc.xtcindex'
|
||||
|
||||
# case 2: index file doesn't exist, and nor write permission in /data/niels
|
||||
index_filename_for_xtc(xtcfile) == '~/.xtcindex/data/niels/.test.xtc.xtcindex'
|
||||
|
||||
Warning:
|
||||
At this point, the index file is not checked for validity. If an invalid
|
||||
index file is present in the xtc files directory, this will be used and an
|
||||
error will be risen by XTCReader.
|
||||
|
||||
Warning:
|
||||
On most systems, the home directory is on the local drive so that the indexing
|
||||
has probably to do be done on every system if it can not be saved in the directory
|
||||
of the xtc file.
|
||||
"""
|
||||
base = os.path.basename(xtcfile)
|
||||
filename = "." + base + ".xtcindex"
|
||||
|
||||
dir = os.path.abspath(os.path.dirname(xtcfile))
|
||||
|
||||
if not os.path.exists(os.path.join(dir, filename)) and not os.access(dir, os.W_OK):
|
||||
if 'XTCINDEXDIR' in os.environ:
|
||||
index_dir = os.environ['XTCINDEXDIR']
|
||||
else:
|
||||
index_dir = os.path.join(os.environ['HOME'], '.xtcindex')
|
||||
dir = os.path.join(index_dir, dir.lstrip('/'))
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
|
||||
return os.path.join(dir, filename)
|
||||
|
||||
|
||||
class NumpyUnpacker(xdrlib.Unpacker):
|
||||
|
||||
def unpack_float_array(self, n_items):
|
||||
i = self.get_position()
|
||||
j = i + 4 * n_items
|
||||
self.set_position(j)
|
||||
data = self.get_buffer()[i:j]
|
||||
|
||||
if len(data) < 4:
|
||||
raise EOFError
|
||||
|
||||
ret = np.frombuffer(data, '>f')
|
||||
if len(ret) != n_items:
|
||||
raise EOFError
|
||||
return ret
|
||||
|
||||
def unpack_double_array(self, n_items):
|
||||
i = self.get_position()
|
||||
j = i + 8 * n_items
|
||||
self.set_position(j)
|
||||
data = self.get_buffer()[i:j]
|
||||
|
||||
if len(data) < 8:
|
||||
raise EOFError
|
||||
|
||||
ret = np.frombuffer(data, '>d')
|
||||
if len(ret) != n_items:
|
||||
raise EOFError
|
||||
return ret
|
||||
|
||||
|
||||
class InvalidMagicException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class InvalidIndexException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class UnknownLenError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class SubscriptableReader:
|
||||
|
||||
def __init__(self, fd):
|
||||
self.fd = fd
|
||||
self.len = os.stat(self.fd.fileno()).st_size
|
||||
|
||||
def __getitem__(self, r):
|
||||
if isinstance(r, slice):
|
||||
if r.start is not None:
|
||||
if r.start < 0:
|
||||
self.fd.seek(r.start, io.SEEK_END)
|
||||
else:
|
||||
self.fd.seek(r.start, io.SEEK_SET)
|
||||
|
||||
if r.step is not None:
|
||||
raise NotImplementedError
|
||||
|
||||
return self.fd.read(r.stop - r.start)
|
||||
else:
|
||||
self.fd.seek(r, io.SEEK_SET)
|
||||
return self.fd.read(1)
|
||||
|
||||
def __len__(self):
|
||||
return self.len
|
||||
|
||||
|
||||
class XTCFrame:
|
||||
__slots__ = ['_coordinates', 'index', 'time', 'box', '_compressed_coordinates']
|
||||
|
||||
def __init__(self):
|
||||
self._coordinates = None
|
||||
|
||||
@property
|
||||
def coordinates(self):
|
||||
if self._coordinates is None:
|
||||
self._coordinates = self._compressed_coordinates()
|
||||
return self._coordinates
|
||||
|
||||
|
||||
class BaseReader:
|
||||
|
||||
def __init__(self, file):
|
||||
self.fd = open(file, 'rb')
|
||||
self.filename = file
|
||||
|
||||
self.reader = SubscriptableReader(self.fd)
|
||||
self.xdr_reader = NumpyUnpacker(self.reader)
|
||||
|
||||
def __exit__(self, *_):
|
||||
self.fd.close()
|
||||
|
||||
def __enter__(self, *_):
|
||||
return self
|
||||
|
||||
def get_position(self):
|
||||
return self.fd.tell()
|
||||
|
||||
def set_position(self, pos):
|
||||
return self.fd.seek(pos)
|
||||
|
||||
def skip_frames(self, num):
|
||||
for i in range(num):
|
||||
self.skip_frame()
|
||||
|
||||
def skip_frame(self):
|
||||
self._read_header()
|
||||
self._read_frame()
|
||||
|
||||
def dump_frame(self):
|
||||
raise NotImplemented
|
||||
|
||||
def _read_header(self):
|
||||
raise NotImplemented
|
||||
|
||||
def _read_frame(self):
|
||||
raise NotImplemented
|
||||
|
||||
XTC_HEADER_SIZE = 4 * 4
|
||||
XTC_FRAME_HEADER_SIZE = 9 * 4 + 10 * 4
|
||||
|
||||
|
||||
class XTCReader(BaseReader):
|
||||
|
||||
len_available = False
|
||||
|
||||
def __init__(self, xtcfile, load_cache_file=True):
|
||||
super().__init__(xtcfile)
|
||||
self._cache = [0]
|
||||
self._times = None
|
||||
self.current_frame = 0
|
||||
|
||||
index_file_name = index_filename_for_xtc(xtcfile)
|
||||
if load_cache_file:
|
||||
self.load_index(index_file_name)
|
||||
|
||||
def load_index(self, filename):
|
||||
xtc_stat = os.stat(self.filename)
|
||||
c_time = int(xtc_stat.st_ctime)
|
||||
m_time = int(xtc_stat.st_mtime)
|
||||
size = xtc_stat.st_size
|
||||
|
||||
with open(filename, 'rb') as index_file_fd:
|
||||
# TODO: Is NumpyUnpacker even necessary at this point?
|
||||
# Seems like xdrlib.Unpacker would be sufficient here ...
|
||||
reader = NumpyUnpacker(SubscriptableReader(index_file_fd))
|
||||
|
||||
if reader.unpack_hyper() != INDEX_MAGIC:
|
||||
raise InvalidMagicException
|
||||
if reader.unpack_hyper() != c_time:
|
||||
raise InvalidIndexException
|
||||
if reader.unpack_hyper() != m_time:
|
||||
raise InvalidIndexException
|
||||
if reader.unpack_hyper() != size:
|
||||
raise InvalidIndexException
|
||||
|
||||
self._cache = []
|
||||
self._times = []
|
||||
try:
|
||||
while True:
|
||||
self._cache.append(reader.unpack_hyper())
|
||||
self._times.append(reader.unpack_float())
|
||||
except EOFError:
|
||||
pass
|
||||
self.len_available = True
|
||||
|
||||
def _raw_header(self):
|
||||
if len(self._cache) == self.current_frame:
|
||||
self._cache.append(self.get_position())
|
||||
return self.fd.read(XTC_HEADER_SIZE)
|
||||
|
||||
def _raw_frame(self):
|
||||
frame_header = self.fd.read(XTC_FRAME_HEADER_SIZE)
|
||||
blob_size = xdrlib.Unpacker(frame_header[-4:]).unpack_int()
|
||||
blob_size = (blob_size + 3) // 4 * 4 # Padding to 4 bytes
|
||||
frame_blob = self.fd.read(blob_size)
|
||||
|
||||
self.current_frame += 1
|
||||
return frame_header + frame_blob
|
||||
|
||||
def _unpack_header(self, raw_header):
|
||||
unpacker = xdrlib.Unpacker(raw_header)
|
||||
|
||||
magic = unpacker.unpack_int()
|
||||
|
||||
if magic != XTC_MAGIC:
|
||||
raise InvalidMagicException
|
||||
|
||||
n_atoms = unpacker.unpack_int()
|
||||
step = unpacker.unpack_int()
|
||||
time = unpacker.unpack_float()
|
||||
|
||||
return n_atoms, step, time
|
||||
|
||||
def _read_header(self):
|
||||
raw_header = self._raw_header()
|
||||
return self._unpack_header(raw_header)
|
||||
|
||||
def _unpack_frame(self, raw_frame):
|
||||
unpacker = xdrlib.Unpacker(raw_frame)
|
||||
|
||||
raw_box = unpacker.unpack_farray(9, unpacker.unpack_float)
|
||||
box = np.array(raw_box).reshape(3, 3)
|
||||
num_coords = unpacker.unpack_int()
|
||||
precision = unpacker.unpack_float()
|
||||
maxint = unpacker.unpack_farray(3, unpacker.unpack_int)
|
||||
minint = unpacker.unpack_farray(3, unpacker.unpack_int)
|
||||
smallindex = unpacker.unpack_int()
|
||||
|
||||
blob_len = unpacker.unpack_int()
|
||||
blob = unpacker.unpack_fopaque(blob_len)
|
||||
|
||||
return box, precision, num_coords, minint, maxint, smallindex, blob
|
||||
|
||||
def _read_frame(self):
|
||||
raw_frame = self._raw_frame()
|
||||
return self._unpack_frame(raw_frame)
|
||||
|
||||
def decode_frame(self, header=None, body=None):
|
||||
n_atoms, step, time = header or self._read_header()
|
||||
box, precision, num_coords, minint, maxint, smallindex, blob = body or self._read_frame()
|
||||
coordinates = partial(decompress, num_coords, precision, minint, maxint, smallindex, blob)
|
||||
|
||||
frame = XTCFrame()
|
||||
frame._compressed_coordinates = coordinates
|
||||
frame.index = step
|
||||
frame.time = time
|
||||
frame.box = box
|
||||
return frame
|
||||
|
||||
def dump_frame(self):
|
||||
"""
|
||||
:return: Tuple: The binary data of the frame, the frame itself
|
||||
"""
|
||||
|
||||
raw_header = self._raw_header()
|
||||
header = self._unpack_header(raw_header)
|
||||
|
||||
raw_body = self._raw_frame()
|
||||
body = self._unpack_frame(raw_body)
|
||||
|
||||
return (raw_header + raw_body, self.decode_frame(header, body))
|
||||
|
||||
def cached_position(self, item):
|
||||
try:
|
||||
return self._cache[item]
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
def __getitem__(self, item):
|
||||
position = self.cached_position(item)
|
||||
|
||||
if position is not None:
|
||||
self.set_position(position)
|
||||
self.current_frame = item
|
||||
return self.decode_frame()
|
||||
# TODO: Use elif and one single return at the end
|
||||
# Also: skip_frames is in fact not implemented at all!
|
||||
# TODO: Catch EOFError and raise IndexError
|
||||
if self.current_frame <= item:
|
||||
self.skip_frames(item - self.current_frame)
|
||||
return self.decode_frame()
|
||||
else:
|
||||
self.set_position(0)
|
||||
self.current_frame = 0
|
||||
self.skip_frames(item)
|
||||
return self.decode_frame()
|
||||
|
||||
def __len__(self):
|
||||
if self.len_available:
|
||||
return len(self._cache)
|
||||
raise UnknownLenError
|
||||
|
||||
def times_of_indices(self, indices):
|
||||
return [self[i].time for i in indices]
|
||||
|
||||
def __hash__(self):
|
||||
return merge_hashes(hash(self.filename), hash_anything(self._cache))
|
||||
|
||||
|
||||
class TRRHeader:
|
||||
__slots__ = ['ir_size', 'e_size', 'box_size', 'vir_size', 'pres_size', 'top_size', 'sym_size', 'x_size', 'v_size',
|
||||
'f_size', 'n_atoms', 'step', 'nre', 't', 'λ', 'is_double']
|
||||
|
||||
@property
|
||||
def frame_size(self):
|
||||
return self.ir_size + \
|
||||
self.e_size + \
|
||||
self.box_size + \
|
||||
self.vir_size + \
|
||||
self.top_size + \
|
||||
self.sym_size + \
|
||||
self.x_size + \
|
||||
self.v_size + \
|
||||
self.f_size
|
||||
|
||||
|
||||
class TRRFrame:
|
||||
__slots__ = ['header', 'box', 'x', 'v', 'f']
|
||||
|
||||
|
||||
class TRRReader(BaseReader):
|
||||
|
||||
def __iter__(self):
|
||||
return TRRIterator(self.filename)
|
||||
|
||||
def _unpack_float(self, is_double):
|
||||
if is_double:
|
||||
return self.xdr_reader.unpack_double()
|
||||
else:
|
||||
return self.xdr_reader.unpack_float()
|
||||
|
||||
def _unpack_np_float(self, is_double, *dim):
|
||||
total = np.product(dim)
|
||||
|
||||
if is_double:
|
||||
box = self.xdr_reader.unpack_double_array(total)
|
||||
else:
|
||||
box = self.xdr_reader.unpack_float_array(total)
|
||||
|
||||
box = box.reshape(*dim)
|
||||
return box
|
||||
|
||||
def _read_header(self):
|
||||
|
||||
data = self.fd.read(8) # 4 Magic + 4 VersionStringLen
|
||||
unpacker = NumpyUnpacker(data)
|
||||
|
||||
magic = unpacker.unpack_int()
|
||||
|
||||
if magic != TRR_MAGIC:
|
||||
raise InvalidMagicException
|
||||
|
||||
version_string_len = unpacker.unpack_int()
|
||||
data = self.fd.read(version_string_len + 55) # 4 Magic + 4 VersionStringLen
|
||||
unpacker.reset(data)
|
||||
|
||||
version = unpacker.unpack_string()
|
||||
assert version == b'GMX_trn_file'
|
||||
|
||||
header = TRRHeader()
|
||||
|
||||
header.ir_size = unpacker.unpack_int()
|
||||
header.e_size = unpacker.unpack_int()
|
||||
header.box_size = unpacker.unpack_int()
|
||||
header.vir_size = unpacker.unpack_int()
|
||||
header.pres_size = unpacker.unpack_int()
|
||||
header.top_size = unpacker.unpack_int()
|
||||
header.sym_size = unpacker.unpack_int()
|
||||
header.x_size = unpacker.unpack_int()
|
||||
header.v_size = unpacker.unpack_int()
|
||||
header.f_size = unpacker.unpack_int()
|
||||
header.n_atoms = unpacker.unpack_int()
|
||||
header.step = unpacker.unpack_int()
|
||||
header.nre = unpacker.unpack_int()
|
||||
|
||||
if header.x_size is not None:
|
||||
header.is_double = (header.x_size // header.n_atoms) == 8
|
||||
elif header.v_size is not None:
|
||||
header.is_double = (header.v_size // header.n_atoms) == 8
|
||||
elif header.f_size is not None:
|
||||
header.is_double = (header.f_size // header.n_atoms) == 8
|
||||
|
||||
if header.is_double:
|
||||
data = self.fd.read(16)
|
||||
else:
|
||||
data = self.fd.read(8)
|
||||
unpacker.reset(data)
|
||||
self.xdr_reader = unpacker
|
||||
|
||||
header.t = self._unpack_float(header.is_double)
|
||||
header.λ = self._unpack_float(header.is_double)
|
||||
|
||||
return header
|
||||
|
||||
def _read_frame(self, header):
|
||||
frame = TRRFrame()
|
||||
frame.header = header
|
||||
|
||||
data = self.fd.read(header.frame_size)
|
||||
self.xdr_reader = NumpyUnpacker(data)
|
||||
|
||||
frame.box = None
|
||||
if header.box_size:
|
||||
frame.box = self._unpack_np_float(header.is_double, 3, 3)
|
||||
|
||||
if header.vir_size:
|
||||
for i in range(9):
|
||||
self._unpack_float(header.is_double)
|
||||
|
||||
if header.pres_size:
|
||||
for i in range(9):
|
||||
self._unpack_float(header.is_double)
|
||||
|
||||
frame.x = None
|
||||
if header.x_size:
|
||||
frame.x = self._unpack_np_float(header.is_double, header.n_atoms, 3)
|
||||
|
||||
frame.v = None
|
||||
if header.v_size:
|
||||
frame.v = self._unpack_np_float(header.is_double, header.n_atoms, 3)
|
||||
|
||||
frame.f = None
|
||||
if header.f_size:
|
||||
frame.f = self._unpack_np_float(header.is_double, header.n_atoms, 3)
|
||||
|
||||
return frame
|
||||
|
||||
def decode_frame(self):
|
||||
h = self._read_header()
|
||||
return self._read_frame(h)
|
||||
|
||||
def get_position(self):
|
||||
return self.fd.tell()
|
||||
|
||||
def set_position(self, pos):
|
||||
self.fd.seek(pos)
|
||||
|
||||
|
||||
class TRRIterator():
|
||||
|
||||
def __init__(self, file):
|
||||
self.__reader = TRRReader(file)
|
||||
self.__reader.__enter__()
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
return self.__reader.decode_frame()
|
||||
except EOFError as err:
|
||||
self.__reader.__exit__()
|
||||
raise StopIteration
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
Reference in New Issue
Block a user