Added type hints
This commit is contained in:
parent
cc08f5ae50
commit
94d67496ba
@ -1,54 +1,55 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from numpy.typing import ArrayLike, NDArray
|
||||
|
||||
from scipy.spatial import cKDTree
|
||||
from itertools import product
|
||||
|
||||
from .logging import logger
|
||||
from .coordinates import CoordinateFrame
|
||||
|
||||
|
||||
def pbc_diff(v1, v2=None, box=None):
|
||||
def pbc_diff(
|
||||
coords_a: NDArray, coords_b: NDArray, box: Optional[NDArray] = None
|
||||
) -> NDArray:
|
||||
if box is None:
|
||||
out = v1 - v2
|
||||
out = coords_a - coords_b
|
||||
elif len(getattr(box, "shape", [])) == 1:
|
||||
out = pbc_diff_rect(v1, v2, box)
|
||||
out = pbc_diff_rect(coords_a, coords_b, box)
|
||||
elif len(getattr(box, "shape", [])) == 2:
|
||||
out = pbc_diff_tric(v1, v2, box)
|
||||
out = pbc_diff_tric(coords_a, coords_b, box)
|
||||
else:
|
||||
raise NotImplementedError("cannot handle box")
|
||||
return out
|
||||
|
||||
|
||||
def pbc_diff_rect(v1, v2, box):
|
||||
def pbc_diff_rect(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
|
||||
"""
|
||||
Calculate the difference of two vectors, considering periodic boundary conditions.
|
||||
"""
|
||||
if v2 is None:
|
||||
v = v1
|
||||
else:
|
||||
v = v1 - v2
|
||||
|
||||
v = coords_a - coords_b
|
||||
s = v / box
|
||||
v = box * (s - s.round())
|
||||
v = box * (s - np.round(s))
|
||||
return v
|
||||
|
||||
|
||||
def pbc_diff_tric(v1, v2=None, box=None):
|
||||
def pbc_diff_tric(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
|
||||
"""
|
||||
difference vector for arbitrary pbc
|
||||
Difference vector for arbitrary pbc
|
||||
|
||||
Args:
|
||||
box_matrix: CoordinateFrame.box
|
||||
"""
|
||||
if len(box.shape) == 1:
|
||||
box = np.diag(box)
|
||||
if v1.shape == (3,):
|
||||
v1 = v1.reshape((1, 3)) # quick 'n dirty
|
||||
if v2.shape == (3,):
|
||||
v2 = v2.reshape((1, 3))
|
||||
if coords_a.shape == (3,):
|
||||
coords_a = coords_a.reshape((1, 3)) # quick 'n dirty
|
||||
if coords_b.shape == (3,):
|
||||
coords_b = coords_b.reshape((1, 3))
|
||||
if box is not None:
|
||||
r3 = np.subtract(v1, v2)
|
||||
r3 = np.subtract(coords_a, coords_b)
|
||||
r2 = np.subtract(
|
||||
r3,
|
||||
(np.rint(np.divide(r3[:, 2], box[2][2])))[:, np.newaxis]
|
||||
@ -65,68 +66,17 @@ def pbc_diff_tric(v1, v2=None, box=None):
|
||||
* box[0][np.newaxis, :],
|
||||
)
|
||||
else:
|
||||
v = v1 - v2
|
||||
v = coords_a - coords_b
|
||||
return v
|
||||
|
||||
|
||||
def pbc_dist(a1, a2, box=None):
|
||||
return ((pbc_diff(a1, a2, box) ** 2).sum(axis=1)) ** 0.5
|
||||
def pbc_dist(
|
||||
atoms_a: CoordinateFrame, atoms_b: CoordinateFrame, box: Optional[NDArray] = None
|
||||
) -> ArrayLike:
|
||||
return ((pbc_diff(atoms_a, atoms_b, box) ** 2).sum(axis=1)) ** 0.5
|
||||
|
||||
|
||||
def pbc_extend(c, box):
|
||||
"""
|
||||
in: c is frame, box is frame.box
|
||||
out: all atoms in frame and their perio. image (shape => array(len(c)*27,3))
|
||||
"""
|
||||
c = np.asarray(c)
|
||||
if c.shape == (3,):
|
||||
c = c.reshape((1, 3)) # quick 'n dirty
|
||||
comb = np.array(
|
||||
[np.asarray(i) for i in product([0, -1, 1], [0, -1, 1], [0, -1, 1])]
|
||||
)
|
||||
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
|
||||
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
|
||||
return c[:, np.newaxis, :] + b_vectors
|
||||
|
||||
|
||||
def pbc_kdtree(v1, box, leafsize=32, compact_nodes=False, balanced_tree=False):
|
||||
"""
|
||||
kd_tree with periodic images
|
||||
box - whole matrix
|
||||
rest: optional optimization
|
||||
"""
|
||||
r0 = cKDTree(
|
||||
pbc_extend(v1, box).reshape((-1, 3)), leafsize, compact_nodes, balanced_tree
|
||||
)
|
||||
return r0
|
||||
|
||||
|
||||
def pbc_kdtree_query(v1, v2, box, n=1):
|
||||
"""
|
||||
kd_tree query with periodic images
|
||||
"""
|
||||
r0, r1 = pbc_kdtree(v1, box).query(v2, n)
|
||||
r1 = r1 // 27
|
||||
return r0, r1
|
||||
|
||||
|
||||
def pbc_backfold_rect(act_frame, box_matrix):
|
||||
"""
|
||||
mimics "trjconv ... -pbc atom -ur rect"
|
||||
|
||||
folds coords of act_frame in cuboid
|
||||
|
||||
"""
|
||||
af = np.asarray(act_frame)
|
||||
if af.shape == (3,):
|
||||
act_frame = act_frame.reshape((1, 3)) # quick 'n dirty
|
||||
b = box_matrix
|
||||
c = np.diag(b) / 2
|
||||
af = pbc_diff(np.zeros((1, 3)), af - c, b)
|
||||
return af + c
|
||||
|
||||
|
||||
def pbc_backfold_compact(act_frame, box_matrix):
|
||||
def pbc_backfold_compact(act_frame: NDArray, box_matrix: NDArray) -> NDArray:
|
||||
"""
|
||||
mimics "trjconv ... -pbc atom -ur compact"
|
||||
|
||||
@ -146,11 +96,11 @@ def pbc_backfold_compact(act_frame, box_matrix):
|
||||
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
|
||||
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
|
||||
sc = c[:, np.newaxis, :] + b_vectors
|
||||
w = np.argsort((((sc) - ctr) ** 2).sum(2), 1)[:, 0]
|
||||
w = np.argsort(((sc - ctr) ** 2).sum(2), 1)[:, 0]
|
||||
return sc[range(shape[0]), w]
|
||||
|
||||
|
||||
def whole(frame):
|
||||
def whole(frame: CoordinateFrame) -> CoordinateFrame:
|
||||
"""
|
||||
Apply ``-pbc whole`` to a CoordinateFrame.
|
||||
"""
|
||||
@ -177,7 +127,7 @@ def whole(frame):
|
||||
NOJUMP_CACHESIZE = 128
|
||||
|
||||
|
||||
def nojump(frame, usecache=True):
|
||||
def nojump(frame: CoordinateFrame, usecache: bool = True) -> CoordinateFrame:
|
||||
"""
|
||||
Return the nojump coordinates of a frame, based on a jump matrix.
|
||||
"""
|
||||
@ -226,15 +176,21 @@ def nojump(frame, usecache=True):
|
||||
return frame - delta
|
||||
|
||||
|
||||
def pbc_points(coordinates, box, thickness=0, index=False, shear=False):
|
||||
def pbc_points(
|
||||
coordinates: CoordinateFrame,
|
||||
thickness: Optional[float] = None,
|
||||
index: bool = False,
|
||||
shear: bool = False,
|
||||
) -> Union[NDArray, tuple[NDArray, NDArray]]:
|
||||
"""
|
||||
Returns the points their first periodic images. Does not fold
|
||||
them back into the box.
|
||||
Thickness 0 means all 27 boxes. Positive means the box+thickness.
|
||||
Negative values mean that less than the box is returned.
|
||||
index=True also returns the indices with indices of images being their
|
||||
originals values.
|
||||
original values.
|
||||
"""
|
||||
box = coordinates.box
|
||||
if shear:
|
||||
box[2, 0] = box[2, 0] % box[0, 0]
|
||||
# Shifts the box images in the other directions if they moved more than
|
||||
@ -249,7 +205,7 @@ def pbc_points(coordinates, box, thickness=0, index=False, shear=False):
|
||||
coordinates_pbc = np.concatenate([coordinates + v @ box for v in grid], axis=0)
|
||||
size = np.diag(box)
|
||||
|
||||
if thickness != 0:
|
||||
if thickness is not None:
|
||||
mask = np.all(coordinates_pbc > -thickness, axis=1)
|
||||
coordinates_pbc = coordinates_pbc[mask]
|
||||
indices = indices[mask]
|
||||
|
Loading…
Reference in New Issue
Block a user