Files
python-store/store/utils.py
2022-04-11 15:38:09 +02:00

225 lines
7.0 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
from numbers import Number
from collections.abc import Iterable
from functools import wraps
import inspect
from glob import glob
import warnings
import pandas as pd
import numpy as np
import traceback
try:
import mdevaluate as md
except ImportError:
pass
from . import config
def nice_systems(df):
    """Replace raw bulk-system directory names with human-readable ratio labels."""
    label_map = {
        '.*bulk_3_1': '3:1',
        '.*bulk_1_3': '1:3',
        '.*bulk_1_9': '1:9',
        '.*bulk': '1:1',  # plain bulk last, so the specific patterns match first
    }
    return df.replace(list(label_map), list(label_map.values()), regex=True)
def numlike(x):
    """Return True if x is a numeric scalar (int, float, complex, bool, ...)."""
    return isinstance(x, Number)


def number_shorthand(x):
    """Check whether x is a valid number shorthand.

    Accepted forms:
      * a numeric scalar,
      * a string 'lo-hi' where both parts parse as floats (e.g. '1.5-2.5';
        negative limits are not supported since '-' is the separator),
      * an iterable of numeric scalars.

    Returns:
        bool: True if x matches one of the forms above, else False.
    """
    if numlike(x):
        return True
    if isinstance(x, str):
        limits = x.split('-')
        if len(limits) != 2:
            return False
        try:
            [float(a) for a in limits]  # validate both range limits
        except ValueError:  # fix: was a bare except
            return False
        return True
    if isinstance(x, Iterable):
        return all(numlike(v) for v in x)
    # fix: original fell through without a return here and yielded None
    return False
def lazy_eq(a, b):
    """Compare a and b for equality, treating incomparable objects as unequal."""
    try:
        result = a == b
    except (ValueError, TypeError):
        # e.g. misaligned pandas objects raise instead of returning False
        return False
    return result
def data_frame(**kwargs):
    """Build a DataFrame from keyword arguments.

    Scalar-only keyword arguments (which pandas rejects without an index)
    produce a single-row frame.
    """
    try:
        return pd.DataFrame(kwargs)
    except ValueError:
        # all-scalar input: give pandas an explicit one-row index
        return pd.DataFrame(kwargs, index=[0])
def traj_slice(N, skip, nr_averages):
    """Slice over N frames that skips the first `skip` fraction and yields
    roughly `nr_averages` evenly spaced frames (step is at least 1)."""
    start = int(skip * N)
    stride = int(N * (1 - skip) // nr_averages)
    if stride == 0:
        stride = 1
    return slice(start, N, stride)
def set_correlation_defaults(kwargs):
    """Fill in sensible defaults for shifted-correlation parameters.

    Every key of the 'correlation' config section that is missing from
    `kwargs` is set to its configured float value; `kwargs` is mutated
    in place and also returned.
    """
    section = config['correlation']
    for name in section.keys():
        kwargs.setdefault(name, section.getfloat(name))
    return kwargs
def merge_timeframes(left, right, on='time'):
    """Merge two dataframes with overlapping time scales.

    Rows of `left` take precedence; only rows of `right` whose `on` value
    does not occur in `left` are appended. The result is sorted by `on`
    with a fresh index.
    """
    combined = pd.merge(left, right, on=on, how='outer', indicator=True, suffixes=('_x', ''))
    right_only = combined.loc[combined['_merge'] == 'right_only', left.columns]
    out = pd.concat((left, right_only), ignore_index=True)
    return out.sort_values(by=on).reset_index(drop=True)
def open_sim(directory, maxcache=None):
    """Open a simulation in `directory` with mdevaluate.

    Topology and trajectory file patterns as well as the nojump setting are
    taken from the module-level `config` ('eval' section); frames are
    reindexed on load.

    NOTE(review): the `maxcache` parameter is currently unused.
    """
    return md.open(
        directory, topology=config['eval']['topology'], trajectory=config['eval']['trajectory'],
        reindex=True, nojump=config['eval']['nojump']
    )
def collect_short(func):
    """
    Decorator to run an analysis function for the given trajectory and associated short simulations.

    Args:
        short_subdir (opt.): Directory of short simulations, relative to the main trajectory file.

    Decorated functions will be evaluated for the given trajectory. Subsequently, this decorator will
    look for multiple simulations in subdirectories '../short/*' of the trajectory file. The results
    for these simulations are then averaged and merged with the result of the original trajectory.

    The simulations should be organized as follows, where the main trajectory, for which the function
    is called, is basedir/out/long.xtc. Results for short times are then obtained from the two trajectories
    located in basedir/short/100 and basedir/short/500.

        basedir/
          topol.tpr
          out/
            long.xtc
          short/
            100/
              topol.tpr
              out/
                short.xtc
            500/
              topol.tpr
              out/
                short.xtc
    """
    @wraps(func)
    def wrapped(trajectory, *args, **kwargs):
        # Result for the main (long) trajectory; short-time results are merged into this.
        res = func(trajectory, *args, **kwargs)
        indices = trajectory.atom_subset.indices
        description = trajectory.description
        directory = os.path.dirname(trajectory.frames.filename)
        # Glob pattern locating the short-simulation directories (from config).
        short_dir = os.path.abspath(os.path.join(directory, config['eval']['short_dir']))
        # Time between the first two frames of the long trajectory; used below
        # to choose the correlation window for the short runs.
        timestep = trajectory[1].time - trajectory[0].time
        params = inspect.signature(func).parameters
        # Functions with an 'other' parameter correlate two coordinate sets;
        # that argument must be re-mapped onto each short run as well.
        has_other = 'other' in params
        if has_other:
            args = list(args)
            # 'other' arrives positionally in *args; the -1 accounts for the
            # leading 'trajectory' parameter not being part of *args.
            other = args.pop(list(params).index('other') - 1)
            other_indices = other.atom_subset.indices
            other_description = other.description
        res_short = {}
        N = 0  # number of short simulations that loaded successfully
        for sd in glob(short_dir):
            md.logger.debug(sd)
            try:
                traj = open_sim(sd)
            except FileNotFoundError:
                # NOTE(review): message has a typo ('Unale' -> 'Unable');
                # left unchanged here since it is a runtime string.
                md.logger.info('Unale to load short simulation: %s', sd)
                continue
            if traj is None:
                print('sim=None')
                continue
            N += 1
            # Apply the same atom selection and description as the main trajectory.
            sim = traj.subset(indices=indices)
            sim.description = description
            if isinstance(trajectory, md.coordinates.CoordinatesMap):
                # Rebuild the coordinate mapping of the main trajectory on top
                # of the short run (mode first, then the mapping function).
                mode = trajectory.coordinates.mode
                if mode is not None:
                    sim = getattr(sim, mode)
                sim = md.coordinates.CoordinatesMap(sim, trajectory.function)
            else:
                mode = trajectory.mode
                if mode is not None:
                    sim = getattr(sim, mode)
            if has_other:
                # Same re-mapping for the second coordinate set.
                other_sim = traj.subset(indices=other_indices)
                other_sim.description = other_description
                if isinstance(other, md.coordinates.CoordinatesMap):
                    mode = other.coordinates.mode
                    if mode is not None:
                        other_sim = getattr(other_sim, mode)
                    other_sim = md.coordinates.CoordinatesMap(other_sim, other.function)
                else:
                    mode = other.mode
                    if mode is not None:
                        other_sim = getattr(other_sim, mode)
            # NOTE(review): this aliases (does not copy) kwargs, so the
            # window/skip/segments overrides below leak back into the caller's
            # dict -- confirm whether that is intentional.
            short_kwargs = kwargs
            # Window chosen so one long-trajectory timestep fits into 90% of
            # the short run's length, capped at half the trajectory.
            win = min(round(timestep / (sim[-1].time * 0.9), 3), 0.5)
            short_kwargs['window'] = win
            short_kwargs['skip'] = 0.01
            short_kwargs['segments'] = min(int((0.9 - win) * (len(sim) - 1)), 20)
            if has_other:
                sr = func(sim, other_sim, *args, **short_kwargs)
            else:
                sr = func(sim, *args, **short_kwargs)
            # Sum per-key DataFrame results; divided by N after the loop.
            for key, val in sr.items():
                if isinstance(val, pd.DataFrame):
                    if key in res_short:
                        res_short[key] += val
                    else:
                        res_short[key] = val
        # Average over all successfully evaluated short simulations
        # (res_short is only non-empty if N >= 1).
        for key in res_short:
            res_short[key] /= N
        if len(res_short) != 0:
            # merge_timeframes keeps the left (short-time) rows on overlap, so
            # long-trajectory rows only fill in the later times.
            res.update({key: merge_timeframes(sr, res[key]) for key, sr in res_short.items()})
        return res
    return wrapped
def enhanced_bins(x, y, k=3):
    """
    Determine enhanced bins for some x and y data.

    The bin size is reduced where the derivative of y is steep. The k
    parameter controls how strongly the bin size is influenced by the
    steepness of y.
    """
    slope = np.absolute(np.diff(y))
    coarse_width = np.diff(x).mean()  # widest allowed bin (flat regions)
    edges = [min(x)]
    x_end = max(x)
    while edges[-1] < x_end:
        # first sample strictly beyond the last edge decides the local slope
        nxt = np.where(x > edges[-1])[0][0]
        if nxt >= len(slope):
            break
        edges.append(edges[-1] + coarse_width / (1 + k * slope[nxt]))
    return np.array(edges)
def excess_entropy(r, gr, ρ=1):
    """Two-particle excess entropy from a radial distribution function g(r).

    Integrates (g ln g - (g - 1)) * r^2 over r, with number density ρ as
    prefactor. NaNs arising from g = 0 are treated as zero contribution.
    """
    with warnings.catch_warnings():
        # g -> 0 produces log(0) RuntimeWarnings; the resulting NaNs are zeroed below
        warnings.simplefilter('ignore', category=RuntimeWarning)
        integrand = (gr * np.log(gr) - (gr - 1)) * r ** 2
    integrand = np.nan_to_num(integrand)
    return -2 * np.pi * ρ * np.trapz(integrand, r)