Combined selector and multi_selector in shifted_correlation and added type hints

This commit is contained in:
Sebastian Kloth 2023-12-26 16:38:55 +01:00
parent 476f7167b4
commit e91de71787

View File

@ -1,9 +1,8 @@
from typing import Callable from typing import Callable, Optional
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike
from scipy.special import legendre from scipy.special import legendre
from itertools import chain
import dask.array as darray import dask.array as darray
from functools import partial from functools import partial
from scipy.spatial import KDTree from scipy.spatial import KDTree
@ -14,29 +13,22 @@ from .pbc import pbc_diff, pbc_points
from .coordinates import Coordinates, CoordinateFrame, displacements_without_drift from .coordinates import Coordinates, CoordinateFrame, displacements_without_drift
def log_indices(first, last, num=100): def log_indices(first: int, last: int, num: int = 100) -> np.ndarray:
ls = np.logspace(0, np.log10(last - first + 1), num=num) ls = np.logspace(0, np.log10(last - first + 1), num=num)
return np.unique(np.int_(ls) - 1 + first) return np.unique(np.int_(ls) - 1 + first)
def correlation(function, frames):
iterator = iter(frames)
start_frame = next(iterator)
return map(lambda f: function(start_frame, f), chain([start_frame], iterator))
@autosave_data(2) @autosave_data(2)
def shifted_correlation( def shifted_correlation(
function: Callable, function: Callable,
frames: Coordinates, frames: Coordinates,
selector: ArrayLike = None, selector: ArrayLike = None,
multi_selector: ArrayLike = None,
segments: int = 10, segments: int = 10,
skip: float = 0.1, skip: float = 0.1,
window: float = 0.5, window: float = 0.5,
average: bool = True, average: bool = True,
points: int = 100, points: int = 100,
): ) -> (np.ndarray, np.ndarray):
""" """
Calculate the time series for a correlation function. Calculate the time series for a correlation function.
@ -44,18 +36,12 @@ def shifted_correlation(
a logarithmic distribution. a logarithmic distribution.
Args: Args:
function: The function that should be correlated function: The function that should be correlated
frames: The coordinates of the simulation data frames: The coordinates of the simulation data
selector (opt.): selector (opt.):
A function that returns the indices depending on A function that returns the indices depending on
the staring frame for which particles the the staring frame for which particles the
correlation should be calculated. correlation should be calculated.
Can not be used with multi_selector.
multi_selector (opt.):
A function that returns multiple lists of indices depending on
the staring frame for which particles the
correlation should be calculated.
Can not be used with selector.
segments (int, opt.): segments (int, opt.):
The number of segments the time window will be The number of segments the time window will be
shifted shifted
@ -79,12 +65,18 @@ def shifted_correlation(
that holds the (non-avaraged) correlation data that holds the (non-avaraged) correlation data
Example: Example:
Calculating the mean square displacement of a coordinates object named ``coords``: Calculating the mean square displacement of a coordinate object
named ``coords``:
>>> time, data = shifted_correlation(msd, coords) >>> time, data = shifted_correlation(msd, coords)
""" """
def get_correlation(frames, start_frame, index, shifted_idx): def get_correlation(
frames: CoordinateFrame,
start_frame: CoordinateFrame,
index: np.ndarray,
shifted_idx: np.ndarray,
) -> np.ndarray:
if len(index) == 0: if len(index) == 0:
correlation = np.zeros(len(shifted_idx)) correlation = np.zeros(len(shifted_idx))
else: else:
@ -94,29 +86,33 @@ def shifted_correlation(
) )
return correlation return correlation
def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None): def apply_selector(
start_frame: CoordinateFrame,
frames: CoordinateFrame,
idx: np.ndarray,
selector: Optional[Callable] = None,
):
shifted_idx = idx + start_frame shifted_idx = idx + start_frame
if selector is None and multi_selector is None:
if selector is None:
index = np.arange(len(frames[start_frame])) index = np.arange(len(frames[start_frame]))
return get_correlation(frames, start_frame, index, shifted_idx) return get_correlation(frames, start_frame, index, shifted_idx)
else:
elif selector is not None and multi_selector is not None:
raise ValueError(
"selector and multi_selector can not be used at the same time"
)
elif selector is not None:
index = selector(frames[start_frame]) index = selector(frames[start_frame])
return get_correlation(frames, start_frame, index, shifted_idx) if len(index.shape) == 1:
return get_correlation(frames, start_frame, index, shifted_idx)
elif multi_selector is not None: elif len(index.shape) == 2:
indices = multi_selector(frames[start_frame]) correlations = []
correlation = [] for ind in index:
for index in indices: correlations.append(
correlation.append( get_correlation(frames, start_frame, ind, shifted_idx)
get_correlation(frames, start_frame, index, shifted_idx) )
return correlations
else:
raise ValueError(
f"Index list of selector has {len(index.shape)} dimensions, "
"but should have 1 or 2"
) )
return correlation
if 1 - skip < window: if 1 - skip < window:
window = 1 - skip window = 1 - skip
@ -138,13 +134,7 @@ def shifted_correlation(
result = np.array( result = np.array(
[ [
apply_selector( apply_selector(start_frame, frames=frames, idx=idx, selector=selector)
start_frame,
frames=frames,
idx=idx,
selector=selector,
multi_selector=multi_selector,
)
for start_frame in start_frames for start_frame in start_frames
] ]
) )
@ -196,8 +186,6 @@ def isf(
""" """
Incoherent intermediate scattering function. To specify q, use Incoherent intermediate scattering function. To specify q, use
water_isf = functools.partial(isf, q=22.77) # q has the value 22.77 nm^-1 water_isf = functools.partial(isf, q=22.77) # q has the value 22.77 nm^-1
:param q: length of scattering vector
""" """
if trajectory is None: if trajectory is None:
displacements = start_frame - end_frame displacements = start_frame - end_frame
@ -219,19 +207,21 @@ def isf(
raise ValueError('Parameter axis has to be ether "all", "x", "y", or "z"!') raise ValueError('Parameter axis has to be ether "all", "x", "y", or "z"!')
def rotational_autocorrelation(onset, frame, order=2): def rotational_autocorrelation(
start_frame: CoordinateFrame, end_frame: CoordinateFrame, order: int = 2
) -> float:
""" """
Compute the rotational autocorrelation of the legendre polynomial for the Compute the rotational autocorrelation of the legendre polynomial for the
given vectors. given vectors.
Args: Args:
onset, frame: CoordinateFrames of vectors start_frame, end_frame: CoordinateFrames of vectors
order (opt.): Order of the legendre polynomial. order (opt.): Order of the legendre polynomial.
Returns: Returns:
Scalar value of the correlation function. Scalar value of the correlation function.
""" """
scalar_prod = (onset * frame).sum(axis=-1) scalar_prod = (start_frame * end_frame).sum(axis=-1)
poly = legendre(order) poly = legendre(order)
return poly(scalar_prod).mean() return poly(scalar_prod).mean()
@ -242,9 +232,9 @@ def van_hove_self(
bins: ArrayLike, bins: ArrayLike,
trajectory: Coordinates = None, trajectory: Coordinates = None,
axis: str = "all", axis: str = "all",
): ) -> np.ndarray:
r""" r"""
Compute the self part of the Van Hove autocorrelation function. Compute the self-part of the Van Hove autocorrelation function.
..math:: ..math::
G(r, t) = \sum_i \delta(|\vec r_i(0) - \vec r_i(t)| - r) G(r, t) = \sum_i \delta(|\vec r_i(0) - \vec r_i(t)| - r)
@ -269,8 +259,13 @@ def van_hove_self(
def van_hove_distinct( def van_hove_distinct(
onset, frame, bins, box=None, use_dask=True, comp=False, bincount=True start_frame: CoordinateFrame,
): end_frame: CoordinateFrame,
bins: ArrayLike,
box: ArrayLike = None,
use_dask: bool = True,
comp: bool = False,
) -> np.ndarray:
r""" r"""
Compute the distinct part of the Van Hove autocorrelation function. Compute the distinct part of the Van Hove autocorrelation function.
@ -278,17 +273,19 @@ def van_hove_distinct(
G(r, t) = \sum_{i, j} \delta(|\vec r_i(0) - \vec r_j(t)| - r) G(r, t) = \sum_{i, j} \delta(|\vec r_i(0) - \vec r_j(t)| - r)
""" """
if box is None: if box is None:
box = onset.box.diagonal() box = start_frame.box.diagonal()
dimension = len(box) dimension = len(box)
N = len(onset) N = len(start_frame)
if use_dask: if use_dask:
onset = darray.from_array(onset, chunks=(500, dimension)).reshape( start_frame = darray.from_array(start_frame, chunks=(500, dimension)).reshape(
1, N, dimension 1, N, dimension
) )
frame = darray.from_array(frame, chunks=(500, dimension)).reshape( end_frame = darray.from_array(end_frame, chunks=(500, dimension)).reshape(
N, 1, dimension N, 1, dimension
) )
dist = ((pbc_diff(onset, frame, box) ** 2).sum(axis=-1) ** 0.5).ravel() dist = (
(pbc_diff(start_frame, end_frame, box) ** 2).sum(axis=-1) ** 0.5
).ravel()
if np.diff(bins).std() < 1e6: if np.diff(bins).std() < 1e6:
dx = bins[0] - bins[1] dx = bins[0] - bins[1]
hist = darray.bincount((dist // dx).astype(int), minlength=(len(bins) - 1)) hist = darray.bincount((dist // dx).astype(int), minlength=(len(bins) - 1))
@ -301,15 +298,18 @@ def van_hove_distinct(
minlength = len(bins) - 1 minlength = len(bins) - 1
def f(x): def f(x):
d = (pbc_diff(x, frame, box) ** 2).sum(axis=-1) ** 0.5 d = (pbc_diff(x, end_frame, box) ** 2).sum(axis=-1) ** 0.5
return np.bincount((d // dx).astype(int), minlength=minlength)[ return np.bincount((d // dx).astype(int), minlength=minlength)[
:minlength :minlength
] ]
hist = sum(f(x) for x in onset) hist = sum(f(x) for x in start_frame)
else: else:
dist = ( dist = (
pbc_diff(onset.reshape(1, -1, 3), frame.reshape(-1, 1, 3), box) ** 2 pbc_diff(
start_frame.reshape(1, -1, 3), end_frame.reshape(-1, 1, 3), box
)
** 2
).sum(axis=-1) ** 0.5 ).sum(axis=-1) ** 0.5
hist = histogram(dist, bins=bins)[0] hist = histogram(dist, bins=bins)[0]
return hist / N return hist / N
@ -320,17 +320,18 @@ def overlap(
end_frame: CoordinateFrame, end_frame: CoordinateFrame,
radius: float = 0.1, radius: float = 0.1,
mode: str = "self", mode: str = "self",
): ) -> float:
""" """
Compute the overlap with a reference configuration defined in a CoordinatesTree. Compute the overlap with a reference configuration defined in a CoordinatesTree.
Args: Args:
onset: Initial frame, this is only used to get the frame index start_frame: Initial frame, this is only used to get the frame index
frame: The current configuration end_frame: The current configuration
radius: The cutoff radius for the overlap radius: The cutoff radius for the overlap
mode: Select between "indifferent", "self" or "distict" part of the overlap
This function is intended to be used with :func:`shifted_correlation`. This function is intended to be used with :func:`shifted_correlation`.
As usual the first two arguments are used internally and the remaining ones As usual, the first two arguments are used internally, and the remaining ones
should be defined with :func:`functools.partial`. should be defined with :func:`functools.partial`.
If the overlap of a subset of the system should be calculated, this has to be If the overlap of a subset of the system should be calculated, this has to be
@ -359,11 +360,13 @@ def overlap(
return np.sum(index != index_dist) / len(start_frame) return np.sum(index != index_dist) / len(start_frame)
def coherent_scattering_function(onset, frame, q): def coherent_scattering_function(
start_frame: CoordinateFrame, end_frame: CoordinateFrame, q: float
) -> np.ndarray:
""" """
Calculate the coherent scattering function. Calculate the coherent scattering function.
""" """
box = onset.box.diagonal() box = start_frame.box.diagonal()
dimension = len(box) dimension = len(box)
def scfunc(x, y): def scfunc(x, y):
@ -381,7 +384,7 @@ def coherent_scattering_function(onset, frame, q):
else: else:
return np.sin(x) / x return np.sin(x) / x
return coherent_sum(scfunc, onset.pbc, frame.pbc) / len(onset) return coherent_sum(scfunc, start_frame.pbc, end_frame.pbc) / len(start_frame)
def non_gaussian_parameter( def non_gaussian_parameter(
@ -391,7 +394,7 @@ def non_gaussian_parameter(
axis: str = "all", axis: str = "all",
) -> float: ) -> float:
""" """
Calculate the Non-Gaussian parameter : Calculate the non-Gaussian parameter.
..math: ..math:
\alpha_2 (t) = \alpha_2 (t) =
\frac{3}{5}\frac{\langle r_i^4(t)\rangle}{\langle r_i^2(t)\rangle^2} - 1 \frac{3}{5}\frac{\langle r_i^4(t)\rangle}{\langle r_i^2(t)\rangle^2} - 1