diff --git a/src/mdevaluate/correlation.py b/src/mdevaluate/correlation.py index 7d2eee6..d78c43a 100644 --- a/src/mdevaluate/correlation.py +++ b/src/mdevaluate/correlation.py @@ -99,20 +99,18 @@ def shifted_correlation( return get_correlation(frames, start_frame, index, shifted_idx) else: index = selector(frames[start_frame]) - if index.ndim == 1: + if len(index) == 0: + return np.zeros(len(idx)) + + if isinstance(index[0], int) or isinstance(index[0], bool): return get_correlation(frames, start_frame, index, shifted_idx) - elif index.ndim == 2: + else: correlations = [] for ind in index: correlations.append( get_correlation(frames, start_frame, ind, shifted_idx) ) return correlations - else: - raise ValueError( - f"Index list of selector has {index.ndim} dimensions, " - "but should have 1 or 2" - ) if 1 - skip < window: window = 1 - skip