199 lines
6.4 KiB
Python
199 lines
6.4 KiB
Python
import os
|
|
import functools
|
|
import inspect
|
|
from typing import Optional, Callable, Iterable
|
|
|
|
import numpy as np
|
|
from .checksum import checksum
|
|
from .logging_util import logger
|
|
|
|
# Directory for autosave files; None means autosave is disabled. A relative
# path is resolved against the directory of the trajectory file
# (see get_directory).
autosave_directory: Optional[str] = None

# Whether previously saved results may be loaded instead of recomputed.
load_autosave_data = False

# Log autosave messages at INFO (True) or DEBUG (False) level, see notify().
verbose_print = True

# Fallback root for autosave files when the default location is not
# writeable. os.path.expanduser("~") yields $HOME when it is set, but unlike
# os.environ["HOME"] it does not raise KeyError at import time when HOME is
# unset (e.g. on Windows or in stripped-down environments).
user_autosave_directory = os.path.join(
    os.path.expanduser("~"), ".mdevaluate/autosave"
)
|
|
|
|
|
|
def notify(msg: str):
    """Log ``msg`` at INFO level if verbose output is on, otherwise at DEBUG level."""
    log = logger.info if verbose_print else logger.debug
    log(msg)
|
|
|
|
|
|
def enable(dir: str, load_data: bool = True, verbose: bool = True):
    """
    Turn on automatic saving of results of functions decorated with
    :func:`autosave_data`.

    Args:
        dir: Directory in which the results are saved. A relative path is
            joined to the directory of the trajectory file (see
            :func:`get_directory`).
        load_data (opt., bool): Whether stored results may also be loaded
            instead of being recomputed.
        verbose (opt., bool): Whether autosave messages are logged at INFO
            (True) or DEBUG (False) level.
    """
    global autosave_directory, load_autosave_data, verbose_print
    autosave_directory = dir
    load_autosave_data = load_data
    # Must be set before notify() so the message below already uses the new level.
    verbose_print = verbose
    notify("Enabled autosave in directory: {}".format(dir))
|
|
|
|
|
|
def disable():
    """
    Turn autosave off.

    Afterwards decorated functions are simply called, without reading from or
    writing to the autosave directory. The verbosity setting is left as is.
    """
    global load_autosave_data, autosave_directory
    load_autosave_data = False
    autosave_directory = None
|
|
|
|
|
|
class disabled:
    """
    A context manager that disables the autosave module within its context.

    On exit the complete previous configuration (directory, loading flag and
    verbosity) is restored, also when the context is left by an exception.
    Exceptions are not suppressed.

    Example:
        import mdevaluate as md
        md.autosave.enable('data')
        with md.autosave.disabled():
            # Autosave functionality is disabled within this context.
            md.correlation.shifted_correlation(
                ...
            )

        # After the context is exited, autosave will work as before.
    """

    def __enter__(self):
        # Remember the whole configuration, not only the directory, so that
        # the load/verbosity flags chosen by the user survive the context.
        self._autosave_directory = autosave_directory
        self._load_autosave_data = load_autosave_data
        self._verbose_print = verbose_print
        disable()
        return self

    def __exit__(self, *args):
        global autosave_directory, load_autosave_data, verbose_print
        # Restore directly instead of calling enable(): enable() would reset
        # load_data/verbose to their defaults, and if autosave was off before
        # entering it would log a misleading "Enabled autosave" message.
        autosave_directory = self._autosave_directory
        load_autosave_data = self._load_autosave_data
        verbose_print = self._verbose_print
        # Do not swallow exceptions raised inside the context.
        return False
|
|
|
|
|
|
def get_directory(reader):
    """
    Get the autosave directory for a trajectory reader.

    The directory is ``autosave_directory`` joined to the directory of
    ``reader.filename`` and is created if necessary. If it cannot be created
    or is not a writeable directory, a mirror of that path below
    ``user_autosave_directory`` is used (and created) instead.

    Args:
        reader: Object with a ``filename`` attribute.

    Returns:
        Path of the directory in which autosave files should be written.
    """
    outdir = os.path.dirname(reader.filename)
    savedir = os.path.join(outdir, autosave_directory)
    try:
        # exist_ok avoids a race with concurrent processes creating the same
        # directory between an existence check and makedirs().
        os.makedirs(savedir, exist_ok=True)
    except OSError:
        # PermissionError is only one failure mode; e.g. a read-only file
        # system raises a plain OSError. Either way, the check below then
        # switches to the user directory.
        pass
    if not (os.path.isdir(savedir) and os.access(savedir, os.W_OK)):
        savedir = os.path.join(user_autosave_directory, savedir.lstrip("/"))
        logger.info(
            "Switched autosave directory to {}, "
            "since original location is not writeable.".format(savedir)
        )
        os.makedirs(savedir, exist_ok=True)
    return savedir
|
|
|
|
|
|
def get_filename(function, checksum, description, *args):
    """
    Get the autosave filename for a specific function call.

    Args:
        function: The decorated function; its ``__name__`` starts the filename.
        checksum: Checksum of the call; not used for the filename.
        description: Free-text part of the filename.
        *args: Relevant arguments of the call. An argument with a
            ``__name__`` (or a ``functools.partial`` of such a callable) adds
            that name to the filename. An argument with a non-empty
            ``description`` extends the description. The last argument with a
            ``frames`` attribute determines the directory via
            :func:`get_directory`.

    Returns:
        Path of the ``.npz`` file for this call.

    Raises:
        ValueError: If no argument has a ``frames`` attribute, so that no
            directory can be determined.
    """
    func_desc = function.__name__
    savedir = None
    for arg in args:
        if hasattr(arg, "__name__"):
            func_desc += "_{}".format(arg.__name__)
        elif isinstance(arg, functools.partial):
            func_desc += "_{}".format(arg.func.__name__)

        if hasattr(arg, "frames"):
            savedir = get_directory(arg.frames)

        if hasattr(arg, "description") and arg.description != "":
            description += "_{}".format(arg.description)

    if savedir is None:
        # Previously this surfaced as an opaque UnboundLocalError.
        raise ValueError(
            "Cannot determine the autosave directory for {}: none of the "
            "relevant arguments has a 'frames' attribute.".format(function.__name__)
        )
    filename = "{}_{}.npz".format(func_desc.strip("_"), description.strip("_"))
    return os.path.join(savedir, filename)
|
|
|
|
|
|
def verify_file(filename, checksum):
    """
    Verify if the file matches the function call.

    Args:
        filename: Path of an autosave ``.npz`` file.
        checksum: Checksum of the current call.

    Returns:
        True only if the file exists, can be read as an npz archive, contains a
        ``checksum`` entry, and that entry equals ``checksum``. Otherwise False,
        so a missing or corrupt file leads to recomputation (and overwriting)
        instead of an error.
    """
    # Local imports: only needed for the exception types of unreadable files.
    import pickle
    import zipfile

    # A missing file must never verify. The old sentinel value 0 did verify
    # whenever the checksum itself happened to be 0.
    if not os.path.exists(filename):
        return False
    try:
        # The context manager closes the archive's file handle.
        with np.load(filename, allow_pickle=True) as data:
            if "checksum" not in data:
                return False
            file_checksum = data["checksum"]
    except (OSError, EOFError, ValueError, pickle.UnpicklingError, zipfile.BadZipFile):
        # E.g. a file truncated by an interrupted save.
        return False
    return bool(file_checksum == checksum)
|
|
|
|
|
|
def save_data(filename, checksum, data):
    """
    Write a result together with the checksum of its call to ``filename``.

    The result is stored under the key ``data`` and the checksum under
    ``checksum`` using ``np.savez``. If ``np.array`` rejects the result with a
    ValueError (e.g. for some inhomogeneous nested sequences), the result is
    stored as a one-dimensional object array with one entry per element.
    """
    notify("Saving result to file: {}".format(filename))
    try:
        payload = np.array(data)
    except ValueError:
        payload = np.empty(len(data), dtype=object)
        payload[:] = data
    np.savez(filename, checksum=checksum, data=payload)
|
|
|
|
|
|
def load_data(filename):
    """
    Load data from a npz file.

    Files written by :func:`save_data` hold the result under the key ``data``;
    it is returned as stored. Files without that key are treated as a legacy
    layout in which the result is spread over positional ``arr_<i>`` entries
    (numpy's naming for unnamed ``np.savez`` arrays). These entries are
    returned as a tuple in positional order, and the file is rewritten in the
    current layout.
    """
    notify("Loading result from file: {}".format(filename))
    # The context manager closes the file handle that np.load keeps open for
    # npz archives.
    with np.load(filename, allow_pickle=True) as fdata:
        if "data" in fdata:
            return fdata["data"]
        # Sort numerically: a plain string sort puts "arr_10" before "arr_2".
        legacy_keys = sorted(
            (key for key in fdata if key.startswith("arr_") and key[4:].isdigit()),
            key=lambda key: int(key[4:]),
        )
        data = tuple(fdata[key] for key in legacy_keys)
        file_checksum = fdata["checksum"]
    # Rewrite only after the archive has been closed.
    save_data(filename, file_checksum, data)
    return data
|
|
|
|
|
|
def autosave_data(
    nargs: int, kwargs_keys: Optional[Iterable[str]] = None, version: Optional[str] = None
) -> Callable:
    """
    Enable autosaving of results for a function.

    The decorated function accepts two extra keyword arguments, which are
    consumed by the wrapper and not passed on:

    * ``description`` (default ``""``): additional part of the filename.
    * ``autoload`` (default ``True``): if false, the result is always
      recomputed. Loading additionally requires it to be enabled module-wide
      (``load_data`` of :func:`enable`).

    If autosave is disabled, the function is simply called.

    Args:
        nargs: Number of leading positional args which are relevant for the
            calculation. They are also recognised when passed by keyword.
        kwargs_keys (opt.):
            List of keyword arguments which are relevant for the calculation.
        version (opt.):
            An optional version number of the decorated function, which replaces the
            checksum of the function code, hence the checksum does not depend on the
            function code.
    """
    # Materialise once: a one-shot iterable would otherwise be exhausted by the
    # first call, silently dropping these keys from all later checksums.
    extra_keys = tuple(kwargs_keys) if kwargs_keys is not None else ()

    def decorator_function(function):
        # Names of the relevant positional parameters. They are also looked up
        # in kwargs, since otherwise they would be ignored when passed via
        # keyword and different calls could share one checksum and file.
        posargs_keys = list(inspect.signature(function).parameters)[:nargs]
        relevant_keys = (*posargs_keys, *extra_keys)

        @functools.wraps(function)
        def autosave(*args, **kwargs):
            description = kwargs.pop("description", "")
            autoload = kwargs.pop("autoload", True) and load_autosave_data
            if autosave_directory is None:
                return function(*args, **kwargs)

            relevant_args = list(args[:nargs])
            relevant_args.extend(kwargs[key] for key in relevant_keys if key in kwargs)

            if version is None:
                csum = legacy_csum = checksum(function, *relevant_args)
            else:
                csum = checksum(version, *relevant_args)
                # Also accept files written while the function had no version.
                legacy_csum = checksum(function, *relevant_args)

            filename = get_filename(function, csum, description, *relevant_args)
            if autoload and (
                verify_file(filename, csum) or verify_file(filename, legacy_csum)
            ):
                return load_data(filename)

            result = function(*args, **kwargs)
            save_data(filename, csum, result)
            return result

        return autosave

    return decorator_function
|