diff --git a/src/mdevaluate/extra/chill.py b/src/mdevaluate/extra/chill.py index f626915..11ec8b2 100644 --- a/src/mdevaluate/extra/chill.py +++ b/src/mdevaluate/extra/chill.py @@ -11,7 +11,7 @@ from mdevaluate.coordinates import CoordinateFrame, Coordinates from mdevaluate.pbc import pbc_points -def a_ij(atoms: ArrayLike, N: int = 4, l: int = 3) -> tuple[NDArray, NDArray]: +def calc_aij(atoms: ArrayLike, N: int = 4, l: int = 3) -> tuple[NDArray, NDArray]: tree = KDTree(atoms) dist, indices = tree.query(atoms, N + 1) @@ -84,18 +84,18 @@ def count_ice_types(iceTypes: NDArray) -> NDArray: def selector_ice( - start_frame: CoordinateFrame, - traj: Coordinates, + oxygen_atoms_water: CoordinateFrame, chosen_ice_types: ArrayLike, combined: bool = True, + next_neighbor_distance: float = 0.35, ) -> NDArray: - oxygen = traj.subset(atom_name="OW") - atoms = oxygen[start_frame.step] - atoms = atoms % np.diag(atoms.box) - atoms_PBC = pbc_points(atoms, thickness=1) - aij, indices = a_ij(atoms_PBC) + atoms = oxygen_atoms_water + atoms_PBC = pbc_points(atoms, thickness=next_neighbor_distance * 2.2) + aij, indices = calc_aij(atoms_PBC) tree = KDTree(atoms_PBC) - neighbors = tree.query_ball_point(atoms_PBC, 0.35, return_length=True) + neighbors = tree.query_ball_point( + atoms_PBC, next_neighbor_distance, return_length=True + ) - 1 index_SOL = atoms_PBC.tolist().index(atoms[0].tolist()) index_SOL = np.arange(index_SOL, index_SOL + len(atoms)) ice_Types = classify_ice(aij, indices, neighbors, index_SOL) @@ -117,9 +117,9 @@ def selector_ice( def ice_types(trajectory: Coordinates, segments: int = 10000) -> pd.DataFrame: def ice_types_distribution(frame: CoordinateFrame, selector: Callable) -> NDArray: atoms_PBC = pbc_points(frame, thickness=1) - aij, indices = a_ij(atoms_PBC) + aij, indices = calc_aij(atoms_PBC) tree = KDTree(atoms_PBC) - neighbors = tree.query_ball_point(atoms_PBC, 0.35, return_length=True) + neighbors = tree.query_ball_point(atoms_PBC, 0.35, return_length=True) - 1 index = selector(frame, atoms_PBC) ice_types_data = classify_ice(aij, indices, neighbors, index) ice_parts_data = count_ice_types(ice_types_data)