diff --git a/src/mdevaluate/correlation.py b/src/mdevaluate/correlation.py index 7e89b3d..b601570 100644 --- a/src/mdevaluate/correlation.py +++ b/src/mdevaluate/correlation.py @@ -13,12 +13,97 @@ from .pbc import pbc_diff, pbc_points from .coordinates import Coordinates, CoordinateFrame, displacements_without_drift -def log_indices(first: int, last: int, num: int = 100) -> np.ndarray: - ls = np.logspace(0, np.log10(last - first + 1), num=num) - return np.unique(np.int_(ls) - 1 + first) +def _is_multi_selector(selection): + if len(selection) == 0: + return False + elif ( + isinstance(selection[0], int) + or isinstance(selection[0], bool) + or isinstance(selection[0], np.integer) + or isinstance(selection[0], np.bool_) + ): + return False + else: + for indices in selection: + if len(indices) == 0: + continue + elif ( + isinstance(indices[0], int) + or isinstance(indices[0], bool) + or isinstance(indices[0], np.integer) + or isinstance(indices[0], np.bool_) + ): + return True + else: + raise ValueError( + "selector has more than two dimensions or does not " + "contain int or bool types" + ) + + +def _calc_correlation( + frames: Coordinates, + start_frame: CoordinateFrame, + function: Callable, + selection: np.ndarray, + shifted_idx: np.ndarray, +) -> np.ndarray: + if len(selection) == 0: + correlation = np.zeros(len(shifted_idx)) + else: + start = start_frame[selection] + correlation = np.array( + [ + function(start, frames[frame_index][selection]) + for frame_index in shifted_idx + ] + ) + return correlation + + +def _calc_correlation_multi( + frames: Coordinates, + start_frame: CoordinateFrame, + function: Callable, + selection: np.ndarray, + shifted_idx: np.ndarray, +) -> np.ndarray: + correlations = np.zeros((len(selection), len(shifted_idx))) + for i, frame_index in enumerate(shifted_idx): + frame = frames[frame_index] + for j, current_selection in enumerate(selection): + if len(selection) == 0: + correlations[j, i] = 0 + else: + correlations[j, i] = function( + start_frame[current_selection], frame[current_selection] + ) + return correlations + + +def _average_correlation(result): + averaged_result = [] + for n in range(result.shape[1]): + clean_result = [] + for entry in result[:, n]: + if np.all(entry == 0): + continue + else: + clean_result.append(entry) + averaged_result.append(np.average(np.array(clean_result), axis=0)) + return np.array(averaged_result) + + +def _average_correlation_multi(result): + clean_result = [] + for entry in result: + if np.all(entry == 0): + continue + else: + clean_result.append(entry) + return np.average(np.array(clean_result), axis=0) -@autosave_data(2) def shifted_correlation( function: Callable, frames: Coordinates, @@ -28,113 +113,82 @@ def shifted_correlation( window: float = 0.5, average: bool = True, points: int = 100, -) -> (np.ndarray, np.ndarray): +) -> tuple[np.ndarray, np.ndarray]: + """Compute a time-dependent correlation function for a given trajectory. + + To improve statistics, multiple (possibly overlapping) windows will be + layed over the whole trajectory and the correlation is computed for them separately. + The start frames of the windows are spaced linearly over the valid region of + the trajectory (skipping frames in the beginning given by skip parameter). + + The points within each window are spaced logarithmically. + + Only a certain subset of the given atoms may be selected for each window + individually using a selector function. + + Note that this function is specifically optimized for multi selectors, which select + multiple selection sets per window, for which the correlation is to be computed + separately. + + + Arguments + --------- + function: + The (correlation) function to evaluate. + Should be of the form (CoordinateFrame, CoordinateFrame) -> float + + frames: + Trajectory to evaluate on + + selector: (optional) + Selection function so select only certain selection sets for each start frame. + Should be of the form + (CoordinateFrame) -> list[A] + where A is something you can index an ndarray with. + For example a list of indices or a bool array. + Must return the same number of selection sets for every frame. + + segments: + Number of start frames + + skip: + Percentage of trajectory to skip from the start + + window: + Length of each segment given as percentage of trajectory + + average: + Whether to return averaged results. + See below for details on the returned ndarray. + + points: + Number of points per segment + + + Returns + ------- + times: ndarray + 1d array of time differences to start frame + result: ndarray + 2d ndarray of averaged (or non-averaged) correlations. + + When average==True (default) the returned array will be of the shape (S, P) + where S is the number of selection sets and P the number of points per window. + For selection sets that where empty for all start frames all data points will be + zero. + + When average==False the returned array will be of shape (W, S) with + dtype=object. The elements are either ndarrays of shape (P,) containing the + correlation data for the specific window and selection set or None if the + corresponding selection set was empty. + W is the number of segments (windows). + S and P are the same as for average==True. + """ - Calculate the time series for a correlation function. - - The times at which the correlation is calculated are determined by - a logarithmic distribution. - - Args: - function: The function that should be correlated - frames: The coordinates of the simulation data - selector (opt.): - A function that returns the indices depending on - the staring frame for which particles the - correlation should be calculated. - segments (int, opt.): - The number of segments the time window will be - shifted - skip (float, opt.): - The fraction of the trajectory that will be skipped - at the beginning, if this is None the start index - of the frames slice will be used, which defaults - to 0.1. - window (float, opt.): - The fraction of the simulation the time series will - cover - average (bool, opt.): - If True, returns averaged correlation function - points (int, opt.): - The number of timeshifts for which the correlation - should be calculated - Returns: - tuple: - A list of length N that contains the timeshiftes of the frames at which - the time series was calculated and a numpy array of shape (segments, N) - that holds the (non-avaraged) correlation data - - Example: - Calculating the mean square displacement of a coordinate object - named ``coords``: - - >>> time, data = shifted_correlation(msd, coords) - """ - - def get_correlation( - frames: CoordinateFrame, - start_frame: CoordinateFrame, - index: np.ndarray, - shifted_idx: np.ndarray, - ) -> np.ndarray: - if len(index) == 0: - correlation = np.zeros(len(shifted_idx)) - else: - start = frames[start_frame][index] - correlation = np.array( - [function(start, frames[frame][index]) for frame in shifted_idx] - ) - return correlation - - def apply_selector( - start_frame: CoordinateFrame, - frames: CoordinateFrame, - idx: np.ndarray, - selector: Optional[Callable] = None, - ): - shifted_idx = idx + start_frame - - if selector is None: - index = np.arange(len(frames[start_frame])) - return get_correlation(frames, start_frame, index, shifted_idx) - else: - index = selector(frames[start_frame]) - if len(index) == 0: - return np.zeros(len(shifted_idx)) - - elif ( - isinstance(index[0], int) - or isinstance(index[0], bool) - or isinstance(index[0], np.integer) - or isinstance(index[0], np.bool_) - ): - return get_correlation(frames, start_frame, index, shifted_idx) - else: - correlations = [] - for ind in index: - if len(ind) == 0: - correlations.append(np.zeros(len(shifted_idx))) - - elif ( - isinstance(ind[0], int) - or isinstance(ind[0], bool) - or isinstance(ind[0], np.integer) - or isinstance(ind[0], np.bool_) - ): - correlations.append( - get_correlation(frames, start_frame, ind, shifted_idx) - ) - else: - raise ValueError( - "selector has more than two dimensions or does not " - "contain int or bool types" - ) - return correlations - if 1 - skip < window: window = 1 - skip - start_frames = np.unique( + start_frame_indices = np.unique( np.linspace( len(frames) * skip, len(frames) * (1 - window), @@ -144,28 +198,44 @@ def shifted_correlation( ) ) - num_frames = int(len(frames) * window) - ls = np.logspace(0, np.log10(num_frames + 1), num=points) - idx = np.unique(np.int_(ls) - 1) - t = np.array([frames[i].time for i in idx]) - frames[0].time - - result = np.array( - [ - apply_selector(start_frame, frames=frames, idx=idx, selector=selector) - for start_frame in start_frames - ] + num_frames_per_window = int(len(frames) * window) + logspaced_indices = np.logspace(0, np.log10(num_frames_per_window + 1), num=points) + logspaced_indices = np.unique(np.int_(logspaced_indices) - 1) + logspaced_time = ( + np.array([frames[i].time for i in logspaced_indices]) - frames[0].time ) + if selector is None: + multi_selector = False + else: + selection = selector(frames[0]) + multi_selector = _is_multi_selector(selection) + + result = [] + for start_frame_index in start_frame_indices: + shifted_idx = logspaced_indices + start_frame_index + start_frame = frames[start_frame_index] + if selector is None: + selection = np.arange(len(start_frame)) + else: + selection = selector(start_frame) + if multi_selector: + result_segment = _calc_correlation_multi( + frames, start_frame, function, selection, shifted_idx + ) + else: + result_segment = _calc_correlation( + frames, start_frame, function, selection, shifted_idx + ) + result.append(result_segment) + result = np.array(result) + if average: - clean_result = [] - for entry in result: - if np.all(entry == 0): - continue - else: - clean_result.append(entry) - result = np.array(clean_result) - result = np.average(result, axis=0) - return t, result + if multi_selector: + result = _average_correlation_multi(result) + else: + result = _average_correlation(result) + return logspaced_time, result def msd( @@ -184,11 +254,11 @@ def msd( if axis == "all": return (displacements**2).sum(axis=1).mean() elif axis == "xy" or axis == "yx": - return (displacements[:, [0, 1]]**2).sum(axis=1).mean() + return (displacements[:, [0, 1]] ** 2).sum(axis=1).mean() elif axis == "xz" or axis == "zx": - return (displacements[:, [0, 2]]**2).sum(axis=1).mean() + return (displacements[:, [0, 2]] ** 2).sum(axis=1).mean() elif axis == "yz" or axis == "zy": - return (displacements[:, [1, 2]]**2).sum(axis=1).mean() + return (displacements[:, [1, 2]] ** 2).sum(axis=1).mean() elif axis == "x": return (displacements[:, 0] ** 2).mean() elif axis == "y": @@ -218,13 +288,13 @@ def isf( distance = (displacements**2).sum(axis=1) ** 0.5 return np.sinc(distance * q / np.pi).mean() elif axis == "xy" or axis == "yx": - distance = (displacements[:, [0, 1]]**2).sum(axis=1) ** 0.5 + distance = (displacements[:, [0, 1]] ** 2).sum(axis=1) ** 0.5 return np.real(jn(0, distance * q)).mean() elif axis == "xz" or axis == "zx": - distance = (displacements[:, [0, 2]]**2).sum(axis=1) ** 0.5 + distance = (displacements[:, [0, 2]] ** 2).sum(axis=1) ** 0.5 return np.real(jn(0, distance * q)).mean() elif axis == "yz" or axis == "zy": - distance = (displacements[:, [1, 2]]**2).sum(axis=1) ** 0.5 + distance = (displacements[:, [1, 2]] ** 2).sum(axis=1) ** 0.5 return np.real(jn(0, distance * q)).mean() elif axis == "x": distance = np.abs(displacements[:, 0]) @@ -278,11 +348,11 @@ def van_hove_self( if axis == "all": delta_r = (vectors**2).sum(axis=1) ** 0.5 elif axis == "xy" or axis == "yx": - delta_r = (vectors[:, [0, 1]]**2).sum(axis=1) ** 0.5 + delta_r = (vectors[:, [0, 1]] ** 2).sum(axis=1) ** 0.5 elif axis == "xz" or axis == "zx": - delta_r = (vectors[:, [0, 2]]**2).sum(axis=1) ** 0.5 + delta_r = (vectors[:, [0, 2]] ** 2).sum(axis=1) ** 0.5 elif axis == "yz" or axis == "zy": - delta_r = (vectors[:, [1, 2]]**2).sum(axis=1) ** 0.5 + delta_r = (vectors[:, [1, 2]] ** 2).sum(axis=1) ** 0.5 elif axis == "x": delta_r = np.abs(vectors[:, 0]) elif axis == "y": @@ -445,13 +515,13 @@ def non_gaussian_parameter( r = (vectors**2).sum(axis=1) dimensions = 3 elif axis == "xy" or axis == "yx": - r = (vectors[:, [0, 1]]**2).sum(axis=1) + r = (vectors[:, [0, 1]] ** 2).sum(axis=1) dimensions = 2 elif axis == "xz" or axis == "zx": - r = (vectors[:, [0, 2]]**2).sum(axis=1) + r = (vectors[:, [0, 2]] ** 2).sum(axis=1) dimensions = 2 elif axis == "yz" or axis == "zy": - r = (vectors[:, [1, 2]]**2).sum(axis=1) + r = (vectors[:, [1, 2]] ** 2).sum(axis=1) dimensions = 2 elif axis == "x": r = vectors[:, 0] ** 2