From 87ffa1e67ed11f1c9a2574bd58748ea54ab456c8 Mon Sep 17 00:00:00 2001 From: Sebastian Kloth Date: Tue, 16 Jan 2024 13:39:48 +0100 Subject: [PATCH] Fixed to find dimension of selector output --- src/mdevaluate/correlation.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/mdevaluate/correlation.py b/src/mdevaluate/correlation.py index 786f57b..a8b3265 100644 --- a/src/mdevaluate/correlation.py +++ b/src/mdevaluate/correlation.py @@ -22,7 +22,7 @@ def log_indices(first: int, last: int, num: int = 100) -> np.ndarray: def shifted_correlation( function: Callable, frames: Coordinates, - selector: ArrayLike = None, + selector: Optional[Callable] = None, segments: int = 10, skip: float = 0.1, window: float = 0.5, @@ -102,7 +102,12 @@ def shifted_correlation( if len(index) == 0: return np.zeros(len(shifted_idx)) - elif isinstance(index[0], int) or isinstance(index[0], bool): + elif ( + isinstance(index[0], int) + or isinstance(index[0], bool) + or isinstance(index[0], np.integer) + or isinstance(index[0], np.bool_) + ): return get_correlation(frames, start_frame, index, shifted_idx) else: correlations = [] @@ -110,7 +115,12 @@ def shifted_correlation( if len(ind) == 0: correlations.append(np.zeros(len(shifted_idx))) - elif isinstance(ind[0], int) or isinstance(ind[0], bool): + elif ( + isinstance(ind[0], int) + or isinstance(ind[0], bool) + or isinstance(ind[0], np.integer) + or isinstance(ind[0], np.bool_) + ): correlations.append( get_correlation(frames, start_frame, ind, shifted_idx) )