Added multi_selector to shifted_correlation.

This commit is contained in:
Sebastian Kloth 2023-07-22 11:57:17 +02:00
parent bc078f890b
commit 69715fd2a7
2 changed files with 48 additions and 13 deletions

View File

@ -162,7 +162,8 @@ class CoordinateFrame(np.ndarray):
if self.mode != "nojump":
if self.mode is not None:
logger.warn(
"Combining Nojump with other Coordinate modes is not supported and may cause unexpected results."
"Combining Nojump with other Coordinate modes is not supported and "
"may cause unexpected results."
)
frame = nojump(self)
frame.mode = "nojump"

View File

@ -99,12 +99,13 @@ def shifted_correlation(
function,
frames,
selector=None,
multi_selector=None,
segments=10,
skip=0.1,
window=0.5,
average=True,
points=100,
nodes=8,
nodes=1,
):
"""
Calculate the time series for a correlation function.
@ -118,7 +119,13 @@ def shifted_correlation(
selector (opt.):
A function that returns the indices depending on
the staring frame for which particles the
correlation should be calculated
correlation should be calculated.
Can not be used with multi_selector.
multi_selector (opt.):
A function that returns multiple lists of indices depending on
the staring frame for which particles the
correlation should be calculated.
Can not be used with selector.
segments (int, opt.):
The number of segments the time window will be
shifted
@ -150,14 +157,9 @@ def shifted_correlation(
>>> time, data = shifted_correlation(msd, coords)
"""
def get_correlation(start_frame, idx, selector=None):
shifted_idx = idx + start_frame
if selector:
index = selector(frames[start_frame])
else:
index = np.arange(len(frames[start_frame]))
def get_correlation(frames, start_frame, index, shifted_idx):
if len(index) == 0:
return np.zeros(len(idx))
correlation = np.zeros(len(shifted_idx))
else:
start = frames[start_frame][index]
correlation = np.array(
@ -165,6 +167,32 @@ def shifted_correlation(
)
return correlation
def apply_selector(start_frame, frames, idx, selector=None, multi_selector=None):
shifted_idx = idx + start_frame
if selector is None and multi_selector is None:
index = np.arange(len(frames[start_frame]))
return get_correlation(frames, start_frame, index, shifted_idx)
elif selector is not None and multi_selector is not None:
raise ValueError(
"selector and multi_selector can not be used at the same time"
)
elif selector is not None:
index = selector(frames[start_frame])
return get_correlation(frames, start_frame, index, shifted_idx)
elif multi_selector is not None:
indices = multi_selector(frames[start_frame])
correlation = []
for index in indices:
correlation.append(
get_correlation(
frames, start_frame, index, shifted_idx
)
)
return correlation
if 1 - skip < window:
window = 1 - skip
@ -175,7 +203,7 @@ def shifted_correlation(
num=segments,
endpoint=False,
dtype=int,
)
)
)
num_frames = int(len(frames) * window)
@ -186,7 +214,9 @@ def shifted_correlation(
if nodes == 1:
result = np.array(
[
get_correlation(start_frame, idx=idx, selector=selector)
apply_selector(start_frame, frames=frames, idx=idx,
selector=selector,
multi_selector=multi_selector)
for start_frame in start_frames
]
)
@ -198,7 +228,11 @@ def shifted_correlation(
try:
result = np.array(
pool.map(
partial(get_correlation, idx=idx, selector=selector), start_frames
partial(apply_selector,
frames=frames,
idx=idx,
selector=selector,
multi_selector=multi_selector), start_frames
)
)
finally: