Changed to load own nojump matrix first if not owner of trajectory

This commit is contained in:
sebastiankloth 2023-04-20 11:50:42 +02:00
parent eaacdbf4ad
commit 78f0b3d727

View File

@ -104,9 +104,17 @@ def is_writeable(fname):
return False return False
def nojump_filename(reader): def nojump_load_filename(reader):
directory, fname = path.split(reader.filename) directory, fname = path.split(reader.filename)
fname = path.join(directory, '.{}.nojump.npz'.format(fname)) fname = path.join(directory, '.{}.nojump.npz'.format(fname))
if not is_writeable(directory):
fname_fallback = os.path.join(
os.path.join(os.environ['HOME'], '.mdevaluate/nojump'),
directory.lstrip('/'),
'.{}.nojump.npz'.format(fname)
)
if os.path.exists(fname_fallback):
return fname_fallback
if os.path.exists(fname) or is_writeable(directory): if os.path.exists(fname) or is_writeable(directory):
return fname return fname
else: else:
@ -119,6 +127,22 @@ def nojump_filename(reader):
os.makedirs(os.path.dirname(fname), exist_ok=True) os.makedirs(os.path.dirname(fname), exist_ok=True)
return fname return fname
def nojump_save_filename(reader):
directory, fname = path.split(reader.filename)
fname = path.join(directory, '.{}.nojump.npz'.format(fname))
if is_writeable(directory):
return fname
else:
fname = os.path.join(
os.path.join(os.environ['HOME'], '.mdevaluate/nojump'),
directory.lstrip('/'),
'.{}.nojump.npz'.format(fname)
)
logger.info('Saving nojump to {}, since original location is not writeable.'.format(fname))
os.makedirs(os.path.dirname(fname), exist_ok=True)
return fname
CSR_ATTRS = ('data', 'indices', 'indptr') CSR_ATTRS = ('data', 'indices', 'indptr')
NOJUMP_MAGIC = 2016 NOJUMP_MAGIC = 2016
@ -173,17 +197,17 @@ def save_nojump_matrixes(reader, matrixes=None):
for attr in CSR_ATTRS: for attr in CSR_ATTRS:
data['{}_{}'.format(attr, d)] = getattr(mat, attr) data['{}_{}'.format(attr, d)] = getattr(mat, attr)
np.savez(nojump_filename(reader), **data) np.savez(nojump_save_filename(reader), **data)
def load_nojump_matrixes(reader): def load_nojump_matrixes(reader):
zipname = nojump_filename(reader) zipname = nojump_load_filename(reader)
try: try:
data = np.load(zipname, allow_pickle=True) data = np.load(zipname, allow_pickle=True)
except (AttributeError, BadZipFile, OSError): except (AttributeError, BadZipFile, OSError):
# npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed? # npz-files can be corrupted, propably a bug for big arrays saved with savez_compressed?
logger.info('Removing zip-File: %s', zipname) logger.info('Removing zip-File: %s', zipname)
os.remove(nojump_filename(reader)) os.remove(nojump_load_filename(reader))
return return
try: try:
if data['checksum'] == checksum(NOJUMP_MAGIC, checksum(reader)): if data['checksum'] == checksum(NOJUMP_MAGIC, checksum(reader)):
@ -194,12 +218,12 @@ def load_nojump_matrixes(reader):
) )
for d in range(3) for d in range(3)
) )
logger.info('Loaded Nojump Matrixes: {}'.format(nojump_filename(reader))) logger.info('Loaded Nojump Matrixes: {}'.format(nojump_load_filename(reader)))
else: else:
logger.info('Invlaid Nojump Data: {}'.format(nojump_filename(reader))) logger.info('Invlaid Nojump Data: {}'.format(nojump_load_filename(reader)))
except KeyError: except KeyError:
logger.info('Removing zip-File: %s', zipname) logger.info('Removing zip-File: %s', zipname)
os.remove(nojump_filename(reader)) os.remove(nojump_load_filename(reader))
return return
@ -277,7 +301,7 @@ class BaseReader:
""" """
self.rd = rd self.rd = rd
self._nojump_matrixes = None self._nojump_matrixes = None
if path.exists(nojump_filename(self)): if path.exists(nojump_load_filename(self)):
load_nojump_matrixes(self) load_nojump_matrixes(self)
def __getitem__(self, item): def __getitem__(self, item):