diff --git a/src/mdevaluate/distribution.py b/src/mdevaluate/distribution.py index d3d03ca..7f0cd90 100644 --- a/src/mdevaluate/distribution.py +++ b/src/mdevaluate/distribution.py @@ -1,3 +1,5 @@ +from typing import Callable, Optional + import numpy as np from numpy.typing import ArrayLike from scipy import spatial @@ -17,8 +19,14 @@ from .pbc import pbc_diff, pbc_points from .logging import logger -@autosave_data(nargs=2, kwargs_keys=("coordinates_b",), version="time_average-1") -def time_average(function, coordinates, coordinates_b=None, pool=None): +@autosave_data(nargs=2, kwargs_keys=("coordinates_b",)) +def time_average( + function: Callable, + coordinates: Coordinates, + coordinates_b: Optional[Coordinates] = None, + skip: float = 0.1, + segments: int = 100, +) -> np.ndarray: """ Compute the time average of a function. @@ -27,35 +35,23 @@ def time_average(function, coordinates, coordinates_b=None, pool=None): The function that will be averaged, it has to accept exactly one argument which is the current atom set coordinates: The coordinates object of the simulation - pool (multiprocessing.Pool, opt.): - A multiprocessing pool which will be used for cocurrent calculation of the - averaged function - + coordinates_b: Additional coordinates object of the simulation + skip: + segments: """ - if pool is not None: - _map = pool.imap + frame_indices = np.unique( + np.int_( + np.linspace(len(coordinates) * skip, len(coordinates) - 1, num=segments) + ) + ) + if coordinates_b is None: + result = [function(coordinates[frame_index]) for frame_index in frame_indices] else: - _map = map - - number_of_averages = 0 - result = 0 - - if coordinates_b is not None: - if coordinates._slice != coordinates_b._slice: - logger.warning("Different slice for coordinates and coordinates_b.") - coordinate_iter = (iter(coordinates), iter(coordinates_b)) - else: - coordinate_iter = (iter(coordinates),) - - evaluated = _map(function, *coordinate_iter) - - for ev in evaluated: - number_of_averages += 1 - result += ev - if number_of_averages % 100 == 0: - logger.debug("time_average: %d", number_of_averages) - - return result / number_of_averages + result = [ + function(coordinates[frame_index], coordinates_b[frame_index]) + for frame_index in frame_indices + ] + return np.mean(result, axis=0) def time_histogram(function, coordinates, bins, hist_range, pool=None):