Added tests for shifted_correlation
This commit is contained in:
parent
c09549902a
commit
7d8c5d849d
57
test/test_correlation.py
Normal file
57
test/test_correlation.py
Normal file
@ -0,0 +1,57 @@
|
||||
import os
|
||||
import pytest
|
||||
|
||||
import mdevaluate
|
||||
from mdevaluate import correlation
|
||||
import numpy as np
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def trajectory(request):
|
||||
return mdevaluate.open(os.path.join(os.path.dirname(__file__), "data/water"))
|
||||
|
||||
|
||||
def test_shifted_correlation(trajectory):
|
||||
test_array = np.array([100, 82, 65, 49, 39, 29, 20, 13, 7])
|
||||
OW = trajectory.subset(atom_name="OW")
|
||||
t, result = correlation.shifted_correlation(
|
||||
correlation.isf, OW, segments=10, skip=0.1, points=10
|
||||
)
|
||||
assert (np.array(result * 100, dtype=int) == test_array).all()
|
||||
|
||||
|
||||
def test_shifted_correlation_no_average(trajectory):
|
||||
t, result = correlation.shifted_correlation(
|
||||
correlation.isf, trajectory, segments=10, skip=0.1, points=5, average=False
|
||||
)
|
||||
assert result.shape == (10, 5)
|
||||
|
||||
|
||||
def test_shifted_correlation_selector(trajectory):
|
||||
test_array = np.array([100, 82, 64, 48, 37, 28, 19, 11, 5])
|
||||
|
||||
def selector(frame):
|
||||
index = np.argwhere((frame[:, 0] >= 0) * (frame[:, 0] < 1))
|
||||
return index.flatten()
|
||||
|
||||
OW = trajectory.subset(atom_name="OW")
|
||||
t, result = correlation.shifted_correlation(
|
||||
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
|
||||
)
|
||||
assert (np.array(result * 100, dtype=int) == test_array).all()
|
||||
|
||||
|
||||
def test_shifted_correlation_multi_selector(trajectory):
|
||||
def selector(frame):
|
||||
indices = []
|
||||
for i in range(3):
|
||||
x = frame[:, 0].flatten()
|
||||
index = np.argwhere((x >= i) * (x < i + 1))
|
||||
indices.append(index.flatten())
|
||||
return indices
|
||||
|
||||
OW = trajectory.subset(atom_name="OW")
|
||||
t, result = correlation.shifted_correlation(
|
||||
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
|
||||
)
|
||||
assert result.shape == (3, 9)
|
Loading…
Reference in New Issue
Block a user