Added multi_selector to shifted_correlation.
This commit is contained in:
		| @@ -162,7 +162,8 @@ class CoordinateFrame(np.ndarray): | |||||||
|         if self.mode != "nojump": |         if self.mode != "nojump": | ||||||
|             if self.mode is not None: |             if self.mode is not None: | ||||||
|                 logger.warn( |                 logger.warn( | ||||||
|                     "Combining Nojump with other Coordinate modes is not supported and may cause unexpected results." |                     "Combining Nojump with other Coordinate modes is not supported and " | ||||||
|  |                     "may cause unexpected results." | ||||||
|                 ) |                 ) | ||||||
|             frame = nojump(self) |             frame = nojump(self) | ||||||
|             frame.mode = "nojump" |             frame.mode = "nojump" | ||||||
|   | |||||||
| @@ -99,12 +99,13 @@ def shifted_correlation( | |||||||
|     function, |     function, | ||||||
|     frames, |     frames, | ||||||
|     selector=None, |     selector=None, | ||||||
|  |     multi_selector=None, | ||||||
|     segments=10, |     segments=10, | ||||||
|     skip=0.1, |     skip=0.1, | ||||||
|     window=0.5, |     window=0.5, | ||||||
|     average=True, |     average=True, | ||||||
|     points=100, |     points=100, | ||||||
|     nodes=8, |     nodes=1, | ||||||
| ): | ): | ||||||
|     """ |     """ | ||||||
|     Calculate the time series for a correlation function. |     Calculate the time series for a correlation function. | ||||||
| @@ -118,7 +119,13 @@ def shifted_correlation( | |||||||
|         selector (opt.): |         selector (opt.): | ||||||
|                     A function that returns the indices depending on |                     A function that returns the indices depending on | ||||||
|                     the staring frame for which particles the |                     the staring frame for which particles the | ||||||
|                     correlation should be calculated |                     correlation should be calculated. | ||||||
|  |                     Can not be used with multi_selector. | ||||||
|  |         multi_selector (opt.): | ||||||
|  |                     A function that returns multiple lists of indices depending on | ||||||
|  |                     the staring frame for which particles the | ||||||
|  |                     correlation should be calculated. | ||||||
|  |                     Can not be used with selector. | ||||||
|         segments (int, opt.): |         segments (int, opt.): | ||||||
|                     The number of segments the time window will be |                     The number of segments the time window will be | ||||||
|                     shifted |                     shifted | ||||||
| @@ -150,14 +157,9 @@ def shifted_correlation( | |||||||
|         >>> time, data = shifted_correlation(msd, coords) |         >>> time, data = shifted_correlation(msd, coords) | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def get_correlation(start_frame, idx, selector=None): |     def get_correlation(frames, start_frame, index, shifted_idx): | ||||||
|         shifted_idx = idx + start_frame |  | ||||||
|         if selector: |  | ||||||
|             index = selector(frames[start_frame]) |  | ||||||
|         else: |  | ||||||
|             index = np.arange(len(frames[start_frame])) |  | ||||||
|         if len(index) == 0: |         if len(index) == 0: | ||||||
|             return np.zeros(len(idx)) |             correlation = np.zeros(len(shifted_idx)) | ||||||
|         else: |         else: | ||||||
|             start = frames[start_frame][index] |             start = frames[start_frame][index] | ||||||
|             correlation = np.array( |             correlation = np.array( | ||||||
| @@ -165,6 +167,32 @@ def shifted_correlation( | |||||||
|             ) |             ) | ||||||
|         return correlation |         return correlation | ||||||
|  |  | ||||||
|  |     def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None): | ||||||
|  |         shifted_idx = idx + start_frame | ||||||
|  |         if selector is None and multi_selector is None: | ||||||
|  |             index = np.arange(len(frames[start_frame])) | ||||||
|  |             return get_correlation(frames, start_frame, index, shifted_idx) | ||||||
|  |  | ||||||
|  |         elif selector is not None and multi_selector is not None: | ||||||
|  |             raise ValueError( | ||||||
|  |                 "selector and multi_selector can not be used at the same time" | ||||||
|  |             ) | ||||||
|  |  | ||||||
|  |         elif selector is not None: | ||||||
|  |             index = selector(frames[start_frame]) | ||||||
|  |             return get_correlation(frames, start_frame, index, shifted_idx) | ||||||
|  |  | ||||||
|  |         elif multi_selector is not None: | ||||||
|  |             indices = multi_selector(frames[start_frame]) | ||||||
|  |             correlation = [] | ||||||
|  |             for index in indices: | ||||||
|  |                 correlation.append( | ||||||
|  |                     get_correlation( | ||||||
|  |                         frames, start_frame, index, shifted_idx | ||||||
|  |                     ) | ||||||
|  |                 ) | ||||||
|  |             return correlation | ||||||
|  |  | ||||||
|     if 1 - skip < window: |     if 1 - skip < window: | ||||||
|         window = 1 - skip |         window = 1 - skip | ||||||
|  |  | ||||||
| @@ -175,7 +203,7 @@ def shifted_correlation( | |||||||
|             num=segments, |             num=segments, | ||||||
|             endpoint=False, |             endpoint=False, | ||||||
|             dtype=int, |             dtype=int, | ||||||
|         ) |             ) | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     num_frames = int(len(frames) * window) |     num_frames = int(len(frames) * window) | ||||||
| @@ -186,7 +214,9 @@ def shifted_correlation( | |||||||
|     if nodes == 1: |     if nodes == 1: | ||||||
|         result = np.array( |         result = np.array( | ||||||
|             [ |             [ | ||||||
|                 get_correlation(start_frame, idx=idx, selector=selector) |                 apply_selector(start_frame, frames=frames, idx=idx, | ||||||
|  |                                selector=selector, | ||||||
|  |                                multi_selector=multi_selector) | ||||||
|                 for start_frame in start_frames |                 for start_frame in start_frames | ||||||
|             ] |             ] | ||||||
|         ) |         ) | ||||||
| @@ -198,7 +228,11 @@ def shifted_correlation( | |||||||
|         try: |         try: | ||||||
|             result = np.array( |             result = np.array( | ||||||
|                 pool.map( |                 pool.map( | ||||||
|                     partial(get_correlation, idx=idx, selector=selector), start_frames |                     partial(apply_selector, | ||||||
|  |                             frames=frames, | ||||||
|  |                             idx=idx, | ||||||
|  |                             selector=selector, | ||||||
|  |                             multi_selector=multi_selector), start_frames | ||||||
|                 ) |                 ) | ||||||
|             ) |             ) | ||||||
|         finally: |         finally: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user