Combined selector and multi_selector in shifted_correlation and added type hints

This commit is contained in:
Sebastian Kloth 2023-12-26 16:38:55 +01:00
parent 476f7167b4
commit e91de71787

View File

@ -1,9 +1,8 @@
from typing import Callable
from typing import Callable, Optional
import numpy as np
from numpy.typing import ArrayLike
from scipy.special import legendre
from itertools import chain
import dask.array as darray
from functools import partial
from scipy.spatial import KDTree
@ -14,29 +13,22 @@ from .pbc import pbc_diff, pbc_points
from .coordinates import Coordinates, CoordinateFrame, displacements_without_drift
def log_indices(first, last, num=100):
def log_indices(first: int, last: int, num: int = 100) -> np.ndarray:
ls = np.logspace(0, np.log10(last - first + 1), num=num)
return np.unique(np.int_(ls) - 1 + first)
def correlation(function, frames):
iterator = iter(frames)
start_frame = next(iterator)
return map(lambda f: function(start_frame, f), chain([start_frame], iterator))
@autosave_data(2)
def shifted_correlation(
function: Callable,
frames: Coordinates,
selector: ArrayLike = None,
multi_selector: ArrayLike = None,
segments: int = 10,
skip: float = 0.1,
window: float = 0.5,
average: bool = True,
points: int = 100,
):
) -> (np.ndarray, np.ndarray):
"""
Calculate the time series for a correlation function.
@ -50,12 +42,6 @@ def shifted_correlation(
A function that returns the indices depending on
the staring frame for which particles the
correlation should be calculated.
Can not be used with multi_selector.
multi_selector (opt.):
A function that returns multiple lists of indices depending on
the staring frame for which particles the
correlation should be calculated.
Can not be used with selector.
segments (int, opt.):
The number of segments the time window will be
shifted
@ -79,12 +65,18 @@ def shifted_correlation(
that holds the (non-avaraged) correlation data
Example:
Calculating the mean square displacement of a coordinates object named ``coords``:
Calculating the mean square displacement of a coordinate object
named ``coords``:
>>> time, data = shifted_correlation(msd, coords)
"""
def get_correlation(frames, start_frame, index, shifted_idx):
def get_correlation(
frames: CoordinateFrame,
start_frame: CoordinateFrame,
index: np.ndarray,
shifted_idx: np.ndarray,
) -> np.ndarray:
if len(index) == 0:
correlation = np.zeros(len(shifted_idx))
else:
@ -94,29 +86,33 @@ def shifted_correlation(
)
return correlation
def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None):
def apply_selector(
start_frame: CoordinateFrame,
frames: CoordinateFrame,
idx: np.ndarray,
selector: Optional[Callable] = None,
):
shifted_idx = idx + start_frame
if selector is None and multi_selector is None:
if selector is None:
index = np.arange(len(frames[start_frame]))
return get_correlation(frames, start_frame, index, shifted_idx)
elif selector is not None and multi_selector is not None:
raise ValueError(
"selector and multi_selector can not be used at the same time"
)
elif selector is not None:
else:
index = selector(frames[start_frame])
if len(index.shape) == 1:
return get_correlation(frames, start_frame, index, shifted_idx)
elif multi_selector is not None:
indices = multi_selector(frames[start_frame])
correlation = []
for index in indices:
correlation.append(
get_correlation(frames, start_frame, index, shifted_idx)
elif len(index.shape) == 2:
correlations = []
for ind in index:
correlations.append(
get_correlation(frames, start_frame, ind, shifted_idx)
)
return correlations
else:
raise ValueError(
f"Index list of selector has {len(index.shape)} dimensions, "
"but should have 1 or 2"
)
return correlation
if 1 - skip < window:
window = 1 - skip
@ -138,13 +134,7 @@ def shifted_correlation(
result = np.array(
[
apply_selector(
start_frame,
frames=frames,
idx=idx,
selector=selector,
multi_selector=multi_selector,
)
apply_selector(start_frame, frames=frames, idx=idx, selector=selector)
for start_frame in start_frames
]
)
@ -196,8 +186,6 @@ def isf(
"""
Incoherent intermediate scattering function. To specify q, use
water_isf = functools.partial(isf, q=22.77) # q has the value 22.77 nm^-1
:param q: length of scattering vector
"""
if trajectory is None:
displacements = start_frame - end_frame
@ -219,19 +207,21 @@ def isf(
raise ValueError('Parameter axis has to be ether "all", "x", "y", or "z"!')
def rotational_autocorrelation(onset, frame, order=2):
def rotational_autocorrelation(
start_frame: CoordinateFrame, end_frame: CoordinateFrame, order: int = 2
) -> float:
"""
Compute the rotational autocorrelation of the legendre polynomial for the
given vectors.
Args:
onset, frame: CoordinateFrames of vectors
start_frame, end_frame: CoordinateFrames of vectors
order (opt.): Order of the legendre polynomial.
Returns:
Scalar value of the correlation function.
"""
scalar_prod = (onset * frame).sum(axis=-1)
scalar_prod = (start_frame * end_frame).sum(axis=-1)
poly = legendre(order)
return poly(scalar_prod).mean()
@ -242,9 +232,9 @@ def van_hove_self(
bins: ArrayLike,
trajectory: Coordinates = None,
axis: str = "all",
):
) -> np.ndarray:
r"""
Compute the self part of the Van Hove autocorrelation function.
Compute the self-part of the Van Hove autocorrelation function.
..math::
G(r, t) = \sum_i \delta(|\vec r_i(0) - \vec r_i(t)| - r)
@ -269,8 +259,13 @@ def van_hove_self(
def van_hove_distinct(
onset, frame, bins, box=None, use_dask=True, comp=False, bincount=True
):
start_frame: CoordinateFrame,
end_frame: CoordinateFrame,
bins: ArrayLike,
box: ArrayLike = None,
use_dask: bool = True,
comp: bool = False,
) -> np.ndarray:
r"""
Compute the distinct part of the Van Hove autocorrelation function.
@ -278,17 +273,19 @@ def van_hove_distinct(
G(r, t) = \sum_{i, j} \delta(|\vec r_i(0) - \vec r_j(t)| - r)
"""
if box is None:
box = onset.box.diagonal()
box = start_frame.box.diagonal()
dimension = len(box)
N = len(onset)
N = len(start_frame)
if use_dask:
onset = darray.from_array(onset, chunks=(500, dimension)).reshape(
start_frame = darray.from_array(start_frame, chunks=(500, dimension)).reshape(
1, N, dimension
)
frame = darray.from_array(frame, chunks=(500, dimension)).reshape(
end_frame = darray.from_array(end_frame, chunks=(500, dimension)).reshape(
N, 1, dimension
)
dist = ((pbc_diff(onset, frame, box) ** 2).sum(axis=-1) ** 0.5).ravel()
dist = (
(pbc_diff(start_frame, end_frame, box) ** 2).sum(axis=-1) ** 0.5
).ravel()
if np.diff(bins).std() < 1e6:
dx = bins[0] - bins[1]
hist = darray.bincount((dist // dx).astype(int), minlength=(len(bins) - 1))
@ -301,15 +298,18 @@ def van_hove_distinct(
minlength = len(bins) - 1
def f(x):
d = (pbc_diff(x, frame, box) ** 2).sum(axis=-1) ** 0.5
d = (pbc_diff(x, end_frame, box) ** 2).sum(axis=-1) ** 0.5
return np.bincount((d // dx).astype(int), minlength=minlength)[
:minlength
]
hist = sum(f(x) for x in onset)
hist = sum(f(x) for x in start_frame)
else:
dist = (
pbc_diff(onset.reshape(1, -1, 3), frame.reshape(-1, 1, 3), box) ** 2
pbc_diff(
start_frame.reshape(1, -1, 3), end_frame.reshape(-1, 1, 3), box
)
** 2
).sum(axis=-1) ** 0.5
hist = histogram(dist, bins=bins)[0]
return hist / N
@ -320,17 +320,18 @@ def overlap(
end_frame: CoordinateFrame,
radius: float = 0.1,
mode: str = "self",
):
) -> float:
"""
Compute the overlap with a reference configuration defined in a CoordinatesTree.
Args:
onset: Initial frame, this is only used to get the frame index
frame: The current configuration
start_frame: Initial frame, this is only used to get the frame index
end_frame: The current configuration
radius: The cutoff radius for the overlap
mode: Select between "indifferent", "self" or "distict" part of the overlap
This function is intended to be used with :func:`shifted_correlation`.
As usual the first two arguments are used internally and the remaining ones
As usual, the first two arguments are used internally, and the remaining ones
should be defined with :func:`functools.partial`.
If the overlap of a subset of the system should be calculated, this has to be
@ -359,11 +360,13 @@ def overlap(
return np.sum(index != index_dist) / len(start_frame)
def coherent_scattering_function(onset, frame, q):
def coherent_scattering_function(
start_frame: CoordinateFrame, end_frame: CoordinateFrame, q: float
) -> np.ndarray:
"""
Calculate the coherent scattering function.
"""
box = onset.box.diagonal()
box = start_frame.box.diagonal()
dimension = len(box)
def scfunc(x, y):
@ -381,7 +384,7 @@ def coherent_scattering_function(onset, frame, q):
else:
return np.sin(x) / x
return coherent_sum(scfunc, onset.pbc, frame.pbc) / len(onset)
return coherent_sum(scfunc, start_frame.pbc, end_frame.pbc) / len(start_frame)
def non_gaussian_parameter(
@ -391,7 +394,7 @@ def non_gaussian_parameter(
axis: str = "all",
) -> float:
"""
Calculate the Non-Gaussian parameter :
Calculate the non-Gaussian parameter.
..math:
\alpha_2 (t) =
\frac{3}{5}\frac{\langle r_i^4(t)\rangle}{\langle r_i^2(t)\rangle^2} - 1