diff --git a/src/mdevaluate/correlation.py b/src/mdevaluate/correlation.py index c6a2cbd..edcd942 100644 --- a/src/mdevaluate/correlation.py +++ b/src/mdevaluate/correlation.py @@ -1,9 +1,8 @@ -from typing import Callable +from typing import Callable, Optional import numpy as np from numpy.typing import ArrayLike from scipy.special import legendre -from itertools import chain import dask.array as darray from functools import partial from scipy.spatial import KDTree @@ -14,29 +13,22 @@ from .pbc import pbc_diff, pbc_points from .coordinates import Coordinates, CoordinateFrame, displacements_without_drift -def log_indices(first, last, num=100): +def log_indices(first: int, last: int, num: int = 100) -> np.ndarray: ls = np.logspace(0, np.log10(last - first + 1), num=num) return np.unique(np.int_(ls) - 1 + first) -def correlation(function, frames): - iterator = iter(frames) - start_frame = next(iterator) - return map(lambda f: function(start_frame, f), chain([start_frame], iterator)) - - @autosave_data(2) def shifted_correlation( function: Callable, frames: Coordinates, selector: ArrayLike = None, - multi_selector: ArrayLike = None, segments: int = 10, skip: float = 0.1, window: float = 0.5, average: bool = True, points: int = 100, -): +) -> (np.ndarray, np.ndarray): """ Calculate the time series for a correlation function. @@ -44,18 +36,12 @@ def shifted_correlation( a logarithmic distribution. Args: - function: The function that should be correlated - frames: The coordinates of the simulation data + function: The function that should be correlated + frames: The coordinates of the simulation data selector (opt.): A function that returns the indices depending on the staring frame for which particles the correlation should be calculated. - Can not be used with multi_selector. - multi_selector (opt.): - A function that returns multiple lists of indices depending on - the staring frame for which particles the - correlation should be calculated. - Can not be used with selector. segments (int, opt.): The number of segments the time window will be shifted @@ -79,12 +65,18 @@ def shifted_correlation( that holds the (non-avaraged) correlation data Example: - Calculating the mean square displacement of a coordinates object named ``coords``: + Calculating the mean square displacement of a coordinate object + named ``coords``: >>> time, data = shifted_correlation(msd, coords) """ - def get_correlation(frames, start_frame, index, shifted_idx): + def get_correlation( + frames: CoordinateFrame, + start_frame: CoordinateFrame, + index: np.ndarray, + shifted_idx: np.ndarray, + ) -> np.ndarray: if len(index) == 0: correlation = np.zeros(len(shifted_idx)) else: @@ -94,29 +86,33 @@ def shifted_correlation( ) return correlation - def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None): + def apply_selector( + start_frame: CoordinateFrame, + frames: CoordinateFrame, + idx: np.ndarray, + selector: Optional[Callable] = None, + ): shifted_idx = idx + start_frame - if selector is None and multi_selector is None: + + if selector is None: index = np.arange(len(frames[start_frame])) return get_correlation(frames, start_frame, index, shifted_idx) - - elif selector is not None and multi_selector is not None: - raise ValueError( - "selector and multi_selector can not be used at the same time" - ) - - elif selector is not None: + else: index = selector(frames[start_frame]) - return get_correlation(frames, start_frame, index, shifted_idx) - - elif multi_selector is not None: - indices = multi_selector(frames[start_frame]) - correlation = [] - for index in indices: - correlation.append( - get_correlation(frames, start_frame, index, shifted_idx) + if len(index.shape) == 1: + return get_correlation(frames, start_frame, index, shifted_idx) + elif len(index.shape) == 2: + correlations = [] + for ind in index: + correlations.append( + get_correlation(frames, start_frame, ind, shifted_idx) + ) + return correlations + else: + raise ValueError( + f"Index list of selector has {len(index.shape)} dimensions, " + "but should have 1 or 2" ) - return correlation if 1 - skip < window: window = 1 - skip @@ -138,13 +134,7 @@ def shifted_correlation( result = np.array( [ - apply_selector( - start_frame, - frames=frames, - idx=idx, - selector=selector, - multi_selector=multi_selector, - ) + apply_selector(start_frame, frames=frames, idx=idx, selector=selector) for start_frame in start_frames ] ) @@ -196,8 +186,6 @@ def isf( """ Incoherent intermediate scattering function. To specify q, use water_isf = functools.partial(isf, q=22.77) # q has the value 22.77 nm^-1 - - :param q: length of scattering vector """ if trajectory is None: displacements = start_frame - end_frame @@ -219,19 +207,21 @@ def isf( raise ValueError('Parameter axis has to be ether "all", "x", "y", or "z"!') -def rotational_autocorrelation(onset, frame, order=2): +def rotational_autocorrelation( + start_frame: CoordinateFrame, end_frame: CoordinateFrame, order: int = 2 +) -> float: """ Compute the rotational autocorrelation of the legendre polynomial for the given vectors. Args: - onset, frame: CoordinateFrames of vectors + start_frame, end_frame: CoordinateFrames of vectors order (opt.): Order of the legendre polynomial. Returns: Scalar value of the correlation function. """ - scalar_prod = (onset * frame).sum(axis=-1) + scalar_prod = (start_frame * end_frame).sum(axis=-1) poly = legendre(order) return poly(scalar_prod).mean() @@ -242,9 +232,9 @@ def van_hove_self( bins: ArrayLike, trajectory: Coordinates = None, axis: str = "all", -): +) -> np.ndarray: r""" - Compute the self part of the Van Hove autocorrelation function. + Compute the self-part of the Van Hove autocorrelation function. ..math:: G(r, t) = \sum_i \delta(|\vec r_i(0) - \vec r_i(t)| - r) @@ -269,8 +259,13 @@ def van_hove_self( def van_hove_distinct( - onset, frame, bins, box=None, use_dask=True, comp=False, bincount=True -): + start_frame: CoordinateFrame, + end_frame: CoordinateFrame, + bins: ArrayLike, + box: ArrayLike = None, + use_dask: bool = True, + comp: bool = False, +) -> np.ndarray: r""" Compute the distinct part of the Van Hove autocorrelation function. @@ -278,17 +273,19 @@ def van_hove_distinct( G(r, t) = \sum_{i, j} \delta(|\vec r_i(0) - \vec r_j(t)| - r) """ if box is None: - box = onset.box.diagonal() + box = start_frame.box.diagonal() dimension = len(box) - N = len(onset) + N = len(start_frame) if use_dask: - onset = darray.from_array(onset, chunks=(500, dimension)).reshape( + start_frame = darray.from_array(start_frame, chunks=(500, dimension)).reshape( 1, N, dimension ) - frame = darray.from_array(frame, chunks=(500, dimension)).reshape( + end_frame = darray.from_array(end_frame, chunks=(500, dimension)).reshape( N, 1, dimension ) - dist = ((pbc_diff(onset, frame, box) ** 2).sum(axis=-1) ** 0.5).ravel() + dist = ( + (pbc_diff(start_frame, end_frame, box) ** 2).sum(axis=-1) ** 0.5 + ).ravel() if np.diff(bins).std() < 1e6: dx = bins[0] - bins[1] hist = darray.bincount((dist // dx).astype(int), minlength=(len(bins) - 1)) @@ -301,15 +298,18 @@ def van_hove_distinct( minlength = len(bins) - 1 def f(x): - d = (pbc_diff(x, frame, box) ** 2).sum(axis=-1) ** 0.5 + d = (pbc_diff(x, end_frame, box) ** 2).sum(axis=-1) ** 0.5 return np.bincount((d // dx).astype(int), minlength=minlength)[ :minlength ] - hist = sum(f(x) for x in onset) + hist = sum(f(x) for x in start_frame) else: dist = ( - pbc_diff(onset.reshape(1, -1, 3), frame.reshape(-1, 1, 3), box) ** 2 + pbc_diff( + start_frame.reshape(1, -1, 3), end_frame.reshape(-1, 1, 3), box + ) + ** 2 ).sum(axis=-1) ** 0.5 hist = histogram(dist, bins=bins)[0] return hist / N @@ -320,17 +320,18 @@ def overlap( end_frame: CoordinateFrame, radius: float = 0.1, mode: str = "self", -): +) -> float: """ Compute the overlap with a reference configuration defined in a CoordinatesTree. Args: - onset: Initial frame, this is only used to get the frame index - frame: The current configuration + start_frame: Initial frame, this is only used to get the frame index + end_frame: The current configuration radius: The cutoff radius for the overlap + mode: Select between "indifferent", "self" or "distict" part of the overlap This function is intended to be used with :func:`shifted_correlation`. - As usual the first two arguments are used internally and the remaining ones + As usual, the first two arguments are used internally, and the remaining ones should be defined with :func:`functools.partial`. If the overlap of a subset of the system should be calculated, this has to be @@ -359,11 +360,13 @@ def overlap( return np.sum(index != index_dist) / len(start_frame) -def coherent_scattering_function(onset, frame, q): +def coherent_scattering_function( + start_frame: CoordinateFrame, end_frame: CoordinateFrame, q: float +) -> np.ndarray: """ Calculate the coherent scattering function. """ - box = onset.box.diagonal() + box = start_frame.box.diagonal() dimension = len(box) def scfunc(x, y): @@ -381,7 +384,7 @@ def coherent_scattering_function(onset, frame, q): else: return np.sin(x) / x - return coherent_sum(scfunc, onset.pbc, frame.pbc) / len(onset) + return coherent_sum(scfunc, start_frame.pbc, end_frame.pbc) / len(start_frame) def non_gaussian_parameter( @@ -391,7 +394,7 @@ def non_gaussian_parameter( axis: str = "all", ) -> float: """ - Calculate the Non-Gaussian parameter : + Calculate the non-Gaussian parameter. ..math: \alpha_2 (t) = \frac{3}{5}\frac{\langle r_i^4(t)\rangle}{\langle r_i^2(t)\rangle^2} - 1