diff --git a/src/mdevaluate/pbc.py b/src/mdevaluate/pbc.py index f57a60d..88457cb 100644 --- a/src/mdevaluate/pbc.py +++ b/src/mdevaluate/pbc.py @@ -1,54 +1,55 @@ from collections import OrderedDict +from typing import Optional, Union import numpy as np +from numpy.typing import ArrayLike, NDArray from scipy.spatial import cKDTree from itertools import product from .logging import logger +from .coordinates import CoordinateFrame -def pbc_diff(v1, v2=None, box=None): +def pbc_diff( + coords_a: NDArray, coords_b: NDArray, box: Optional[NDArray] = None +) -> NDArray: if box is None: - out = v1 - v2 + out = coords_a - coords_b elif len(getattr(box, "shape", [])) == 1: - out = pbc_diff_rect(v1, v2, box) + out = pbc_diff_rect(coords_a, coords_b, box) elif len(getattr(box, "shape", [])) == 2: - out = pbc_diff_tric(v1, v2, box) + out = pbc_diff_tric(coords_a, coords_b, box) else: raise NotImplementedError("cannot handle box") return out -def pbc_diff_rect(v1, v2, box): +def pbc_diff_rect(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray: """ Calculate the difference of two vectors, considering periodic boundary conditions. """ - if v2 is None: - v = v1 - else: - v = v1 - v2 - + v = coords_a - coords_b s = v / box - v = box * (s - s.round()) + v = box * (s - np.round(s)) return v -def pbc_diff_tric(v1, v2=None, box=None): +def pbc_diff_tric(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray: """ - difference vector for arbitrary pbc + Difference vector for arbitrary pbc Args: box_matrix: CoordinateFrame.box """ if len(box.shape) == 1: box = np.diag(box) - if v1.shape == (3,): - v1 = v1.reshape((1, 3)) # quick 'n dirty - if v2.shape == (3,): - v2 = v2.reshape((1, 3)) + if coords_a.shape == (3,): + coords_a = coords_a.reshape((1, 3)) # quick 'n dirty + if coords_b.shape == (3,): + coords_b = coords_b.reshape((1, 3)) if box is not None: - r3 = np.subtract(v1, v2) + r3 = np.subtract(coords_a, coords_b) r2 = np.subtract( r3, (np.rint(np.divide(r3[:, 2], box[2][2])))[:, np.newaxis] @@ -65,68 +66,17 @@ def pbc_diff_tric(v1, v2=None, box=None): * box[0][np.newaxis, :], ) else: - v = v1 - v2 + v = coords_a - coords_b return v -def pbc_dist(a1, a2, box=None): - return ((pbc_diff(a1, a2, box) ** 2).sum(axis=1)) ** 0.5 +def pbc_dist( + atoms_a: CoordinateFrame, atoms_b: CoordinateFrame, box: Optional[NDArray] = None +) -> ArrayLike: + return ((pbc_diff(atoms_a, atoms_b, box) ** 2).sum(axis=1)) ** 0.5 -def pbc_extend(c, box): - """ - in: c is frame, box is frame.box - out: all atoms in frame and their perio. image (shape => array(len(c)*27,3)) - """ - c = np.asarray(c) - if c.shape == (3,): - c = c.reshape((1, 3)) # quick 'n dirty - comb = np.array( - [np.asarray(i) for i in product([0, -1, 1], [0, -1, 1], [0, -1, 1])] - ) - b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :] - b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :] - return c[:, np.newaxis, :] + b_vectors - - -def pbc_kdtree(v1, box, leafsize=32, compact_nodes=False, balanced_tree=False): - """ - kd_tree with periodic images - box - whole matrix - rest: optional optimization - """ - r0 = cKDTree( - pbc_extend(v1, box).reshape((-1, 3)), leafsize, compact_nodes, balanced_tree - ) - return r0 - - -def pbc_kdtree_query(v1, v2, box, n=1): - """ - kd_tree query with periodic images - """ - r0, r1 = pbc_kdtree(v1, box).query(v2, n) - r1 = r1 // 27 - return r0, r1 - - -def pbc_backfold_rect(act_frame, box_matrix): - """ - mimics "trjconv ... -pbc atom -ur rect" - - folds coords of act_frame in cuboid - - """ - af = np.asarray(act_frame) - if af.shape == (3,): - act_frame = act_frame.reshape((1, 3)) # quick 'n dirty - b = box_matrix - c = np.diag(b) / 2 - af = pbc_diff(np.zeros((1, 3)), af - c, b) - return af + c - - -def pbc_backfold_compact(act_frame, box_matrix): +def pbc_backfold_compact(act_frame: NDArray, box_matrix: NDArray) -> NDArray: """ mimics "trjconv ... -pbc atom -ur compact" @@ -146,11 +96,11 @@ def pbc_backfold_compact(act_frame, box_matrix): b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :] b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :] sc = c[:, np.newaxis, :] + b_vectors - w = np.argsort((((sc) - ctr) ** 2).sum(2), 1)[:, 0] + w = np.argsort(((sc - ctr) ** 2).sum(2), 1)[:, 0] return sc[range(shape[0]), w] -def whole(frame): +def whole(frame: CoordinateFrame) -> CoordinateFrame: """ Apply ``-pbc whole`` to a CoordinateFrame. """ @@ -177,7 +127,7 @@ def whole(frame): NOJUMP_CACHESIZE = 128 -def nojump(frame, usecache=True): +def nojump(frame: CoordinateFrame, usecache: bool = True) -> CoordinateFrame: """ Return the nojump coordinates of a frame, based on a jump matrix. """ @@ -226,15 +176,21 @@ def nojump(frame, usecache=True): return frame - delta -def pbc_points(coordinates, box, thickness=0, index=False, shear=False): +def pbc_points( + coordinates: CoordinateFrame, + thickness: Optional[float] = None, + index: bool = False, + shear: bool = False, +) -> Union[NDArray, tuple[NDArray, NDArray]]: """ Returns the points their first periodic images. Does not fold them back into the box. Thickness 0 means all 27 boxes. Positive means the box+thickness. Negative values mean that less than the box is returned. index=True also returns the indices with indices of images being their - originals values. + original values. """ + box = coordinates.box if shear: box[2, 0] = box[2, 0] % box[0, 0] # Shifts the box images in the other directions if they moved more than @@ -249,7 +205,7 @@ def pbc_points(coordinates, box, thickness=0, index=False, shear=False): coordinates_pbc = np.concatenate([coordinates + v @ box for v in grid], axis=0) size = np.diag(box) - if thickness != 0: + if thickness is not None: mask = np.all(coordinates_pbc > -thickness, axis=1) coordinates_pbc = coordinates_pbc[mask] indices = indices[mask]