Added type hints

This commit is contained in:
Sebastian Kloth 2023-12-28 12:42:43 +01:00
parent cc08f5ae50
commit 94d67496ba

View File

@ -1,54 +1,55 @@
from collections import OrderedDict
from typing import Optional, Union
import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy.spatial import cKDTree
from itertools import product
from .logging import logger
from .coordinates import CoordinateFrame
def pbc_diff(v1, v2=None, box=None):
def pbc_diff(
coords_a: NDArray, coords_b: NDArray, box: Optional[NDArray] = None
) -> NDArray:
if box is None:
out = v1 - v2
out = coords_a - coords_b
elif len(getattr(box, "shape", [])) == 1:
out = pbc_diff_rect(v1, v2, box)
out = pbc_diff_rect(coords_a, coords_b, box)
elif len(getattr(box, "shape", [])) == 2:
out = pbc_diff_tric(v1, v2, box)
out = pbc_diff_tric(coords_a, coords_b, box)
else:
raise NotImplementedError("cannot handle box")
return out
def pbc_diff_rect(v1, v2, box):
def pbc_diff_rect(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
"""
Calculate the difference of two vectors, considering periodic boundary conditions.
"""
if v2 is None:
v = v1
else:
v = v1 - v2
v = coords_a - coords_b
s = v / box
v = box * (s - s.round())
v = box * (s - np.round(s))
return v
def pbc_diff_tric(v1, v2=None, box=None):
def pbc_diff_tric(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
"""
difference vector for arbitrary pbc
Difference vector for arbitrary pbc
Args:
box_matrix: CoordinateFrame.box
"""
if len(box.shape) == 1:
box = np.diag(box)
if v1.shape == (3,):
v1 = v1.reshape((1, 3)) # quick 'n dirty
if v2.shape == (3,):
v2 = v2.reshape((1, 3))
if coords_a.shape == (3,):
coords_a = coords_a.reshape((1, 3)) # quick 'n dirty
if coords_b.shape == (3,):
coords_b = coords_b.reshape((1, 3))
if box is not None:
r3 = np.subtract(v1, v2)
r3 = np.subtract(coords_a, coords_b)
r2 = np.subtract(
r3,
(np.rint(np.divide(r3[:, 2], box[2][2])))[:, np.newaxis]
@ -65,68 +66,17 @@ def pbc_diff_tric(v1, v2=None, box=None):
* box[0][np.newaxis, :],
)
else:
v = v1 - v2
v = coords_a - coords_b
return v
def pbc_dist(a1, a2, box=None):
return ((pbc_diff(a1, a2, box) ** 2).sum(axis=1)) ** 0.5
def pbc_dist(
atoms_a: CoordinateFrame, atoms_b: CoordinateFrame, box: Optional[NDArray] = None
) -> ArrayLike:
return ((pbc_diff(atoms_a, atoms_b, box) ** 2).sum(axis=1)) ** 0.5
def pbc_extend(c, box):
"""
in: c is frame, box is frame.box
out: all atoms in frame and their perio. image (shape => array(len(c)*27,3))
"""
c = np.asarray(c)
if c.shape == (3,):
c = c.reshape((1, 3)) # quick 'n dirty
comb = np.array(
[np.asarray(i) for i in product([0, -1, 1], [0, -1, 1], [0, -1, 1])]
)
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
return c[:, np.newaxis, :] + b_vectors
def pbc_kdtree(v1, box, leafsize=32, compact_nodes=False, balanced_tree=False):
"""
kd_tree with periodic images
box - whole matrix
rest: optional optimization
"""
r0 = cKDTree(
pbc_extend(v1, box).reshape((-1, 3)), leafsize, compact_nodes, balanced_tree
)
return r0
def pbc_kdtree_query(v1, v2, box, n=1):
"""
kd_tree query with periodic images
"""
r0, r1 = pbc_kdtree(v1, box).query(v2, n)
r1 = r1 // 27
return r0, r1
def pbc_backfold_rect(act_frame, box_matrix):
"""
mimics "trjconv ... -pbc atom -ur rect"
folds coords of act_frame in cuboid
"""
af = np.asarray(act_frame)
if af.shape == (3,):
act_frame = act_frame.reshape((1, 3)) # quick 'n dirty
b = box_matrix
c = np.diag(b) / 2
af = pbc_diff(np.zeros((1, 3)), af - c, b)
return af + c
def pbc_backfold_compact(act_frame, box_matrix):
def pbc_backfold_compact(act_frame: NDArray, box_matrix: NDArray) -> NDArray:
"""
mimics "trjconv ... -pbc atom -ur compact"
@ -146,11 +96,11 @@ def pbc_backfold_compact(act_frame, box_matrix):
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
sc = c[:, np.newaxis, :] + b_vectors
w = np.argsort((((sc) - ctr) ** 2).sum(2), 1)[:, 0]
w = np.argsort(((sc - ctr) ** 2).sum(2), 1)[:, 0]
return sc[range(shape[0]), w]
def whole(frame):
def whole(frame: CoordinateFrame) -> CoordinateFrame:
"""
Apply ``-pbc whole`` to a CoordinateFrame.
"""
@ -177,7 +127,7 @@ def whole(frame):
NOJUMP_CACHESIZE = 128
def nojump(frame, usecache=True):
def nojump(frame: CoordinateFrame, usecache: bool = True) -> CoordinateFrame:
"""
Return the nojump coordinates of a frame, based on a jump matrix.
"""
@ -226,15 +176,21 @@ def nojump(frame, usecache=True):
return frame - delta
def pbc_points(coordinates, box, thickness=0, index=False, shear=False):
def pbc_points(
coordinates: CoordinateFrame,
thickness: Optional[float] = None,
index: bool = False,
shear: bool = False,
) -> Union[NDArray, tuple[NDArray, NDArray]]:
"""
Returns the points their first periodic images. Does not fold
them back into the box.
Thickness 0 means all 27 boxes. Positive means the box+thickness.
Negative values mean that less than the box is returned.
index=True also returns the indices with indices of images being their
originals values.
original values.
"""
box = coordinates.box
if shear:
box[2, 0] = box[2, 0] % box[0, 0]
# Shifts the box images in the other directions if they moved more than
@ -249,7 +205,7 @@ def pbc_points(coordinates, box, thickness=0, index=False, shear=False):
coordinates_pbc = np.concatenate([coordinates + v @ box for v in grid], axis=0)
size = np.diag(box)
if thickness != 0:
if thickness is not None:
mask = np.all(coordinates_pbc > -thickness, axis=1)
coordinates_pbc = coordinates_pbc[mask]
indices = indices[mask]