Added tests for shifted_correlation

This commit is contained in:
Sebastian Kloth 2024-06-14 10:10:37 +02:00
parent c09549902a
commit 7d8c5d849d

57
test/test_correlation.py Normal file
View File

@ -0,0 +1,57 @@
import os
import pytest
import mdevaluate
from mdevaluate import correlation
import numpy as np
@pytest.fixture
def trajectory(request):
return mdevaluate.open(os.path.join(os.path.dirname(__file__), "data/water"))
def test_shifted_correlation(trajectory):
test_array = np.array([100, 82, 65, 49, 39, 29, 20, 13, 7])
OW = trajectory.subset(atom_name="OW")
t, result = correlation.shifted_correlation(
correlation.isf, OW, segments=10, skip=0.1, points=10
)
assert (np.array(result * 100, dtype=int) == test_array).all()
def test_shifted_correlation_no_average(trajectory):
t, result = correlation.shifted_correlation(
correlation.isf, trajectory, segments=10, skip=0.1, points=5, average=False
)
assert result.shape == (10, 5)
def test_shifted_correlation_selector(trajectory):
test_array = np.array([100, 82, 64, 48, 37, 28, 19, 11, 5])
def selector(frame):
index = np.argwhere((frame[:, 0] >= 0) * (frame[:, 0] < 1))
return index.flatten()
OW = trajectory.subset(atom_name="OW")
t, result = correlation.shifted_correlation(
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
)
assert (np.array(result * 100, dtype=int) == test_array).all()
def test_shifted_correlation_multi_selector(trajectory):
def selector(frame):
indices = []
for i in range(3):
x = frame[:, 0].flatten()
index = np.argwhere((x >= i) * (x < i + 1))
indices.append(index.flatten())
return indices
OW = trajectory.subset(atom_name="OW")
t, result = correlation.shifted_correlation(
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
)
assert result.shape == (3, 9)