Removed try except block from shifted_correlation.

This commit is contained in:
sebastiankloth 2023-06-16 08:41:19 +02:00
parent f83b0c0cfb
commit a4b1105c54

View File

@ -10,10 +10,12 @@ from .utils import filon_fourier_transformation, coherent_sum, histogram
from .pbc import pbc_diff from .pbc import pbc_diff
from .logging import logger from .logging import logger
def set_has_counter(func): def set_has_counter(func):
func.has_counter = True func.has_counter = True
return func return func
def log_indices(first, last, num=100): def log_indices(first, last, num=100):
ls = np.logspace(0, np.log10(last - first + 1), num=num) ls = np.logspace(0, np.log10(last - first + 1), num=num)
return np.unique(np.int_(ls) - 1 + first) return np.unique(np.int_(ls) - 1 + first)
@ -26,7 +28,6 @@ def correlation(function, frames):
def subensemble_correlation(selector_function, correlation_function=correlation): def subensemble_correlation(selector_function, correlation_function=correlation):
def c(function, frames): def c(function, frames):
iterator = iter(frames) iterator = iter(frames)
start_frame = next(iterator) start_frame = next(iterator)
@ -128,7 +129,7 @@ def shifted_correlation(function, frames, selector=None, segments=10,
Calculating the mean square displacement of a coordinates object named ``coords``: Calculating the mean square displacement of a coordinates object named ``coords``:
>>> time, data = shifted_correlation(msd, coords) >>> time, data = shifted_correlation(msd, coords)
""" """
def get_correlation(start_frame, idx, selector=None): def get_correlation(start_frame, idx, selector=None):
shifted_idx = idx + start_frame shifted_idx = idx + start_frame
if selector: if selector:
@ -156,26 +157,22 @@ def shifted_correlation(function, frames, selector=None, segments=10,
idx = np.unique(np.int_(ls) - 1) idx = np.unique(np.int_(ls) - 1)
t = np.array([frames[i].time for i in idx]) - frames[0].time t = np.array([frames[i].time for i in idx]) - frames[0].time
if nodes==1: if nodes == 1:
result = np.array([get_correlation(start_frame, idx=idx, result = np.array([get_correlation(start_frame, idx=idx,
selector=selector) selector=selector)
for start_frame in start_frames]) for start_frame in start_frames])
else: else:
pool = ProcessPool(nodes=nodes) pool = ProcessPool(nodes=nodes)
try: result = np.array(pool.map(partial(get_correlation, idx=idx,
result = np.array(pool.map(partial(get_correlation, idx=idx, selector=selector),
selector=selector), start_frames))
start_frames)) pool.terminate()
except Exception: pool.restart()
logger.warning("Something went wrong while calculating the shifted correlation!")
finally:
pool.terminate()
pool.restart()
if average == True: if average == True:
clean_result = [] clean_result = []
for entry in result: for entry in result:
if np.all(entry==0): if np.all(entry == 0):
continue continue
else: else:
clean_result.append(entry) clean_result.append(entry)
@ -206,14 +203,15 @@ def isf(start, frame, q, box=None):
def rotational_autocorrelation(onset, frame, order=2): def rotational_autocorrelation(onset, frame, order=2):
""" """
Compute the rotaional autocorrelation of the legendre polynamial for the given vectors. Compute the rotational autocorrelation of the legendre polynomial for the
given vectors.
Args: Args:
onset, frame: CoordinateFrames of vectors onset, frame: CoordinateFrames of vectors
order (opt.): Order of the legendre polynomial. order (opt.): Order of the legendre polynomial.
Returns: Returns:
Skalar value of the correltaion function. Scalar value of the correlation function.
""" """
scalar_prod = (onset * frame).sum(axis=-1) scalar_prod = (onset * frame).sum(axis=-1)
poly = legendre(order) poly = legendre(order)