58 lines
1.7 KiB
Python
58 lines
1.7 KiB
Python
import os
|
|
import pytest
|
|
|
|
import mdevaluate
|
|
from mdevaluate import correlation
|
|
import numpy as np
|
|
|
|
|
|
@pytest.fixture
def trajectory(request):
    """Open the ``data/water`` entry located next to this test module.

    The path is resolved relative to the directory of this file and handed
    to ``mdevaluate.open``; whatever that call returns is the fixture value.
    """
    module_dir = os.path.dirname(__file__)
    water_path = os.path.join(module_dir, "data/water")
    return mdevaluate.open(water_path)
|
|
|
|
|
|
def test_shifted_correlation(trajectory):
    """Check ``shifted_correlation`` of ``correlation.isf`` against stored values.

    The correlation is evaluated on the subset of atoms named ``"OW"``. The
    result is multiplied by 100 and cast to ``int`` before being compared
    element-wise with the reference array.
    """
    ow_atoms = trajectory.subset(atom_name="OW")
    _, values = correlation.shifted_correlation(
        correlation.isf,
        ow_atoms,
        segments=10,
        skip=0.1,
        points=10,
    )

    expected = np.array([100, 82, 65, 49, 39, 29, 20, 13, 7])
    # The int cast truncates, so the comparison is on whole percent values.
    as_percent = np.array(values * 100, dtype=int)
    assert (as_percent == expected).all()
|
|
|
|
|
|
def test_shifted_correlation_no_average(trajectory):
    """Check the result shape of ``shifted_correlation`` with ``average=False``.

    The call uses the full trajectory with ``segments=10`` and ``points=5``
    and expects a result of shape ``(10, 5)``.
    """
    options = {"segments": 10, "skip": 0.1, "points": 5, "average": False}
    _, values = correlation.shifted_correlation(
        correlation.isf, trajectory, **options
    )
    assert values.shape == (10, 5)
|
|
|
|
|
|
def test_shifted_correlation_selector(trajectory):
    """Check ``shifted_correlation`` with a ``selector`` against stored values.

    The selector picks the entries of a frame whose first column lies in
    ``[0, 1)``. The result is multiplied by 100 and cast to ``int`` before
    being compared element-wise with the reference array.
    """

    def selector(frame):
        """Return the flat indices of rows with ``0 <= frame[:, 0] < 1``."""
        first_column = frame[:, 0]
        # Multiplying the two comparison results combines both conditions.
        in_range = (first_column >= 0) * (first_column < 1)
        return np.argwhere(in_range).flatten()

    ow_atoms = trajectory.subset(atom_name="OW")
    _, values = correlation.shifted_correlation(
        correlation.isf,
        ow_atoms,
        segments=10,
        skip=0.1,
        points=10,
        selector=selector,
    )

    expected = np.array([100, 82, 64, 48, 37, 28, 19, 11, 5])
    # The int cast truncates, so the comparison is on whole percent values.
    as_percent = np.array(values * 100, dtype=int)
    assert (as_percent == expected).all()
|
|
|
|
|
|
def test_shifted_correlation_multi_selector(trajectory):
    """Check the result shape when the selector returns several index sets.

    The selector returns three index arrays, one for each interval
    ``[i, i + 1)`` with ``i`` in ``0, 1, 2``, applied to the first column of
    a frame. The result is expected to have shape ``(3, 9)``.
    """

    def selector(frame):
        """Return a list of three flat index arrays, one per unit interval."""
        x = frame[:, 0].flatten()
        return [
            np.argwhere((x >= lower) * (x < lower + 1)).flatten()
            for lower in range(3)
        ]

    ow_atoms = trajectory.subset(atom_name="OW")
    _, values = correlation.shifted_correlation(
        correlation.isf,
        ow_atoms,
        segments=10,
        skip=0.1,
        points=10,
        selector=selector,
    )
    assert values.shape == (3, 9)
|