From 7d8c5d849dccaf7a4892b109ef2cf2353f5a7505 Mon Sep 17 00:00:00 2001 From: Sebastian Kloth Date: Fri, 14 Jun 2024 10:10:37 +0200 Subject: [PATCH] Added tests for shifted_correlation --- test/test_correlation.py | 57 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test/test_correlation.py diff --git a/test/test_correlation.py b/test/test_correlation.py new file mode 100644 index 0000000..d8dc345 --- /dev/null +++ b/test/test_correlation.py @@ -0,0 +1,57 @@ +import os +import pytest + +import mdevaluate +from mdevaluate import correlation +import numpy as np + + +@pytest.fixture +def trajectory(request): + return mdevaluate.open(os.path.join(os.path.dirname(__file__), "data/water")) + + +def test_shifted_correlation(trajectory): + test_array = np.array([100, 82, 65, 49, 39, 29, 20, 13, 7]) + OW = trajectory.subset(atom_name="OW") + t, result = correlation.shifted_correlation( + correlation.isf, OW, segments=10, skip=0.1, points=10 + ) + assert (np.array(result * 100, dtype=int) == test_array).all() + + +def test_shifted_correlation_no_average(trajectory): + t, result = correlation.shifted_correlation( + correlation.isf, trajectory, segments=10, skip=0.1, points=5, average=False + ) + assert result.shape == (10, 5) + + +def test_shifted_correlation_selector(trajectory): + test_array = np.array([100, 82, 64, 48, 37, 28, 19, 11, 5]) + + def selector(frame): + index = np.argwhere((frame[:, 0] >= 0) * (frame[:, 0] < 1)) + return index.flatten() + + OW = trajectory.subset(atom_name="OW") + t, result = correlation.shifted_correlation( + correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector + ) + assert (np.array(result * 100, dtype=int) == test_array).all() + + +def test_shifted_correlation_multi_selector(trajectory): + def selector(frame): + indices = [] + for i in range(3): + x = frame[:, 0].flatten() + index = np.argwhere((x >= i) * (x < i + 1)) + indices.append(index.flatten()) + return indices + + OW = trajectory.subset(atom_name="OW") + t, result = correlation.shifted_correlation( + correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector + ) + assert result.shape == (3, 9)