Added tests for shifted_correlation
This commit is contained in:
parent
c09549902a
commit
7d8c5d849d
57
test/test_correlation.py
Normal file
57
test/test_correlation.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
import os
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mdevaluate
|
||||||
|
from mdevaluate import correlation
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def trajectory(request):
|
||||||
|
return mdevaluate.open(os.path.join(os.path.dirname(__file__), "data/water"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_shifted_correlation(trajectory):
|
||||||
|
test_array = np.array([100, 82, 65, 49, 39, 29, 20, 13, 7])
|
||||||
|
OW = trajectory.subset(atom_name="OW")
|
||||||
|
t, result = correlation.shifted_correlation(
|
||||||
|
correlation.isf, OW, segments=10, skip=0.1, points=10
|
||||||
|
)
|
||||||
|
assert (np.array(result * 100, dtype=int) == test_array).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_shifted_correlation_no_average(trajectory):
|
||||||
|
t, result = correlation.shifted_correlation(
|
||||||
|
correlation.isf, trajectory, segments=10, skip=0.1, points=5, average=False
|
||||||
|
)
|
||||||
|
assert result.shape == (10, 5)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shifted_correlation_selector(trajectory):
|
||||||
|
test_array = np.array([100, 82, 64, 48, 37, 28, 19, 11, 5])
|
||||||
|
|
||||||
|
def selector(frame):
|
||||||
|
index = np.argwhere((frame[:, 0] >= 0) * (frame[:, 0] < 1))
|
||||||
|
return index.flatten()
|
||||||
|
|
||||||
|
OW = trajectory.subset(atom_name="OW")
|
||||||
|
t, result = correlation.shifted_correlation(
|
||||||
|
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
|
||||||
|
)
|
||||||
|
assert (np.array(result * 100, dtype=int) == test_array).all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_shifted_correlation_multi_selector(trajectory):
|
||||||
|
def selector(frame):
|
||||||
|
indices = []
|
||||||
|
for i in range(3):
|
||||||
|
x = frame[:, 0].flatten()
|
||||||
|
index = np.argwhere((x >= i) * (x < i + 1))
|
||||||
|
indices.append(index.flatten())
|
||||||
|
return indices
|
||||||
|
|
||||||
|
OW = trajectory.subset(atom_name="OW")
|
||||||
|
t, result = correlation.shifted_correlation(
|
||||||
|
correlation.isf, OW, segments=10, skip=0.1, points=10, selector=selector
|
||||||
|
)
|
||||||
|
assert result.shape == (3, 9)
|
Loading…
Reference in New Issue
Block a user