Added multi_selector to shifted_correlation.

This commit is contained in:
Sebastian Kloth 2023-07-22 11:57:17 +02:00
parent bc078f890b
commit 69715fd2a7
2 changed files with 48 additions and 13 deletions

View File

@ -162,7 +162,8 @@ class CoordinateFrame(np.ndarray):
if self.mode != "nojump": if self.mode != "nojump":
if self.mode is not None: if self.mode is not None:
logger.warn( logger.warn(
"Combining Nojump with other Coordinate modes is not supported and may cause unexpected results." "Combining Nojump with other Coordinate modes is not supported and "
"may cause unexpected results."
) )
frame = nojump(self) frame = nojump(self)
frame.mode = "nojump" frame.mode = "nojump"

View File

@ -99,12 +99,13 @@ def shifted_correlation(
function, function,
frames, frames,
selector=None, selector=None,
multi_selector=None,
segments=10, segments=10,
skip=0.1, skip=0.1,
window=0.5, window=0.5,
average=True, average=True,
points=100, points=100,
nodes=8, nodes=1,
): ):
""" """
Calculate the time series for a correlation function. Calculate the time series for a correlation function.
@ -118,7 +119,13 @@ def shifted_correlation(
selector (opt.): selector (opt.):
A function that returns the indices depending on A function that returns the indices depending on
the staring frame for which particles the the staring frame for which particles the
correlation should be calculated correlation should be calculated.
Can not be used with multi_selector.
multi_selector (opt.):
A function that returns multiple lists of indices depending on
the staring frame for which particles the
correlation should be calculated.
Can not be used with selector.
segments (int, opt.): segments (int, opt.):
The number of segments the time window will be The number of segments the time window will be
shifted shifted
@ -150,14 +157,9 @@ def shifted_correlation(
>>> time, data = shifted_correlation(msd, coords) >>> time, data = shifted_correlation(msd, coords)
""" """
def get_correlation(start_frame, idx, selector=None): def get_correlation(frames, start_frame, index, shifted_idx):
shifted_idx = idx + start_frame
if selector:
index = selector(frames[start_frame])
else:
index = np.arange(len(frames[start_frame]))
if len(index) == 0: if len(index) == 0:
return np.zeros(len(idx)) correlation = np.zeros(len(shifted_idx))
else: else:
start = frames[start_frame][index] start = frames[start_frame][index]
correlation = np.array( correlation = np.array(
@ -165,6 +167,32 @@ def shifted_correlation(
) )
return correlation return correlation
def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None):
shifted_idx = idx + start_frame
if selector is None and multi_selector is None:
index = np.arange(len(frames[start_frame]))
return get_correlation(frames, start_frame, index, shifted_idx)
elif selector is not None and multi_selector is not None:
raise ValueError(
"selector and multi_selector can not be used at the same time"
)
elif selector is not None:
index = selector(frames[start_frame])
return get_correlation(frames, start_frame, index, shifted_idx)
elif multi_selector is not None:
indices = multi_selector(frames[start_frame])
correlation = []
for index in indices:
correlation.append(
get_correlation(
frames, start_frame, index, shifted_idx
)
)
return correlation
if 1 - skip < window: if 1 - skip < window:
window = 1 - skip window = 1 - skip
@ -186,7 +214,9 @@ def shifted_correlation(
if nodes == 1: if nodes == 1:
result = np.array( result = np.array(
[ [
get_correlation(start_frame, idx=idx, selector=selector) apply_selector(start_frame, frames=frames, idx=idx,
selector=selector,
multi_selector=multi_selector)
for start_frame in start_frames for start_frame in start_frames
] ]
) )
@ -198,7 +228,11 @@ def shifted_correlation(
try: try:
result = np.array( result = np.array(
pool.map( pool.map(
partial(get_correlation, idx=idx, selector=selector), start_frames partial(apply_selector,
frames=frames,
idx=idx,
selector=selector,
multi_selector=multi_selector), start_frames
) )
) )
finally: finally: