From 69715fd2a736359b6eba0e5203f9016afd4bbd65 Mon Sep 17 00:00:00 2001 From: Sebastian Kloth Date: Sat, 22 Jul 2023 11:57:17 +0200 Subject: [PATCH] Added multi_selector to shifted_correlation. --- mdevaluate/coordinates.py | 3 +- mdevaluate/correlation.py | 58 +++++++++++++++++++++++++++++++-------- 2 files changed, 48 insertions(+), 13 deletions(-) diff --git a/mdevaluate/coordinates.py b/mdevaluate/coordinates.py index a5b9f21..fb99b93 100755 --- a/mdevaluate/coordinates.py +++ b/mdevaluate/coordinates.py @@ -162,7 +162,8 @@ class CoordinateFrame(np.ndarray): if self.mode != "nojump": if self.mode is not None: logger.warn( - "Combining Nojump with other Coordinate modes is not supported and may cause unexpected results." + "Combining Nojump with other Coordinate modes is not supported and " + "may cause unexpected results." ) frame = nojump(self) frame.mode = "nojump" diff --git a/mdevaluate/correlation.py b/mdevaluate/correlation.py index 34b273f..bce734c 100644 --- a/mdevaluate/correlation.py +++ b/mdevaluate/correlation.py @@ -99,12 +99,13 @@ def shifted_correlation( function, frames, selector=None, + multi_selector=None, segments=10, skip=0.1, window=0.5, average=True, points=100, - nodes=8, + nodes=1, ): """ Calculate the time series for a correlation function. @@ -118,7 +119,13 @@ def shifted_correlation( selector (opt.): A function that returns the indices depending on the staring frame for which particles the - correlation should be calculated + correlation should be calculated. + Can not be used with multi_selector. + multi_selector (opt.): + A function that returns multiple lists of indices depending on + the staring frame for which particles the + correlation should be calculated. + Can not be used with selector. segments (int, opt.): The number of segments the time window will be shifted @@ -150,14 +157,9 @@ def shifted_correlation( >>> time, data = shifted_correlation(msd, coords) """ - def get_correlation(start_frame, idx, selector=None): - shifted_idx = idx + start_frame - if selector: - index = selector(frames[start_frame]) - else: - index = np.arange(len(frames[start_frame])) + def get_correlation(frames, start_frame, index, shifted_idx): if len(index) == 0: - return np.zeros(len(idx)) + correlation = np.zeros(len(shifted_idx)) else: start = frames[start_frame][index] correlation = np.array( @@ -165,6 +167,32 @@ def shifted_correlation( ) return correlation + def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None): + shifted_idx = idx + start_frame + if selector is None and multi_selector is None: + index = np.arange(len(frames[start_frame])) + return get_correlation(frames, start_frame, index, shifted_idx) + + elif selector is not None and multi_selector is not None: + raise ValueError( + "selector and multi_selector can not be used at the same time" + ) + + elif selector is not None: + index = selector(frames[start_frame]) + return get_correlation(frames, start_frame, index, shifted_idx) + + elif multi_selector is not None: + indices = multi_selector(frames[start_frame]) + correlation = [] + for index in indices: + correlation.append( + get_correlation( + frames, start_frame, index, shifted_idx + ) + ) + return correlation + if 1 - skip < window: window = 1 - skip @@ -175,7 +203,7 @@ def shifted_correlation( num=segments, endpoint=False, dtype=int, - ) + ) ) num_frames = int(len(frames) * window) @@ -186,7 +214,9 @@ def shifted_correlation( if nodes == 1: result = np.array( [ - get_correlation(start_frame, idx=idx, selector=selector) + apply_selector(start_frame, frames=frames, idx=idx, + selector=selector, + multi_selector=multi_selector) for start_frame in start_frames ] ) @@ -198,7 +228,11 @@ def shifted_correlation( try: result = np.array( pool.map( - partial(get_correlation, idx=idx, selector=selector), start_frames + partial(apply_selector, + frames=frames, + idx=idx, + selector=selector, + multi_selector=multi_selector), start_frames ) ) finally: