Added type hints

This commit is contained in:
Sebastian Kloth 2023-12-28 12:42:43 +01:00
parent cc08f5ae50
commit 94d67496ba

View File

@ -1,54 +1,55 @@
from collections import OrderedDict from collections import OrderedDict
from typing import Optional, Union
import numpy as np import numpy as np
from numpy.typing import ArrayLike, NDArray
from scipy.spatial import cKDTree from scipy.spatial import cKDTree
from itertools import product from itertools import product
from .logging import logger from .logging import logger
from .coordinates import CoordinateFrame
def pbc_diff(v1, v2=None, box=None): def pbc_diff(
coords_a: NDArray, coords_b: NDArray, box: Optional[NDArray] = None
) -> NDArray:
if box is None: if box is None:
out = v1 - v2 out = coords_a - coords_b
elif len(getattr(box, "shape", [])) == 1: elif len(getattr(box, "shape", [])) == 1:
out = pbc_diff_rect(v1, v2, box) out = pbc_diff_rect(coords_a, coords_b, box)
elif len(getattr(box, "shape", [])) == 2: elif len(getattr(box, "shape", [])) == 2:
out = pbc_diff_tric(v1, v2, box) out = pbc_diff_tric(coords_a, coords_b, box)
else: else:
raise NotImplementedError("cannot handle box") raise NotImplementedError("cannot handle box")
return out return out
def pbc_diff_rect(v1, v2, box): def pbc_diff_rect(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
""" """
Calculate the difference of two vectors, considering periodic boundary conditions. Calculate the difference of two vectors, considering periodic boundary conditions.
""" """
if v2 is None: v = coords_a - coords_b
v = v1
else:
v = v1 - v2
s = v / box s = v / box
v = box * (s - s.round()) v = box * (s - np.round(s))
return v return v
def pbc_diff_tric(v1, v2=None, box=None): def pbc_diff_tric(coords_a: NDArray, coords_b: NDArray, box: NDArray) -> NDArray:
""" """
difference vector for arbitrary pbc Difference vector for arbitrary pbc
Args: Args:
box_matrix: CoordinateFrame.box box_matrix: CoordinateFrame.box
""" """
if len(box.shape) == 1: if len(box.shape) == 1:
box = np.diag(box) box = np.diag(box)
if v1.shape == (3,): if coords_a.shape == (3,):
v1 = v1.reshape((1, 3)) # quick 'n dirty coords_a = coords_a.reshape((1, 3)) # quick 'n dirty
if v2.shape == (3,): if coords_b.shape == (3,):
v2 = v2.reshape((1, 3)) coords_b = coords_b.reshape((1, 3))
if box is not None: if box is not None:
r3 = np.subtract(v1, v2) r3 = np.subtract(coords_a, coords_b)
r2 = np.subtract( r2 = np.subtract(
r3, r3,
(np.rint(np.divide(r3[:, 2], box[2][2])))[:, np.newaxis] (np.rint(np.divide(r3[:, 2], box[2][2])))[:, np.newaxis]
@ -65,68 +66,17 @@ def pbc_diff_tric(v1, v2=None, box=None):
* box[0][np.newaxis, :], * box[0][np.newaxis, :],
) )
else: else:
v = v1 - v2 v = coords_a - coords_b
return v return v
def pbc_dist(a1, a2, box=None): def pbc_dist(
return ((pbc_diff(a1, a2, box) ** 2).sum(axis=1)) ** 0.5 atoms_a: CoordinateFrame, atoms_b: CoordinateFrame, box: Optional[NDArray] = None
) -> ArrayLike:
return ((pbc_diff(atoms_a, atoms_b, box) ** 2).sum(axis=1)) ** 0.5
def pbc_extend(c, box): def pbc_backfold_compact(act_frame: NDArray, box_matrix: NDArray) -> NDArray:
"""
in: c is frame, box is frame.box
out: all atoms in frame and their perio. image (shape => array(len(c)*27,3))
"""
c = np.asarray(c)
if c.shape == (3,):
c = c.reshape((1, 3)) # quick 'n dirty
comb = np.array(
[np.asarray(i) for i in product([0, -1, 1], [0, -1, 1], [0, -1, 1])]
)
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
return c[:, np.newaxis, :] + b_vectors
def pbc_kdtree(v1, box, leafsize=32, compact_nodes=False, balanced_tree=False):
"""
kd_tree with periodic images
box - whole matrix
rest: optional optimization
"""
r0 = cKDTree(
pbc_extend(v1, box).reshape((-1, 3)), leafsize, compact_nodes, balanced_tree
)
return r0
def pbc_kdtree_query(v1, v2, box, n=1):
"""
kd_tree query with periodic images
"""
r0, r1 = pbc_kdtree(v1, box).query(v2, n)
r1 = r1 // 27
return r0, r1
def pbc_backfold_rect(act_frame, box_matrix):
"""
mimics "trjconv ... -pbc atom -ur rect"
folds coords of act_frame in cuboid
"""
af = np.asarray(act_frame)
if af.shape == (3,):
act_frame = act_frame.reshape((1, 3)) # quick 'n dirty
b = box_matrix
c = np.diag(b) / 2
af = pbc_diff(np.zeros((1, 3)), af - c, b)
return af + c
def pbc_backfold_compact(act_frame, box_matrix):
""" """
mimics "trjconv ... -pbc atom -ur compact" mimics "trjconv ... -pbc atom -ur compact"
@ -146,11 +96,11 @@ def pbc_backfold_compact(act_frame, box_matrix):
b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :] b_matrices = comb[:, :, np.newaxis] * box[np.newaxis, :, :]
b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :] b_vectors = b_matrices.sum(axis=1)[np.newaxis, :, :]
sc = c[:, np.newaxis, :] + b_vectors sc = c[:, np.newaxis, :] + b_vectors
w = np.argsort((((sc) - ctr) ** 2).sum(2), 1)[:, 0] w = np.argsort(((sc - ctr) ** 2).sum(2), 1)[:, 0]
return sc[range(shape[0]), w] return sc[range(shape[0]), w]
def whole(frame): def whole(frame: CoordinateFrame) -> CoordinateFrame:
""" """
Apply ``-pbc whole`` to a CoordinateFrame. Apply ``-pbc whole`` to a CoordinateFrame.
""" """
@ -177,7 +127,7 @@ def whole(frame):
NOJUMP_CACHESIZE = 128 NOJUMP_CACHESIZE = 128
def nojump(frame, usecache=True): def nojump(frame: CoordinateFrame, usecache: bool = True) -> CoordinateFrame:
""" """
Return the nojump coordinates of a frame, based on a jump matrix. Return the nojump coordinates of a frame, based on a jump matrix.
""" """
@ -226,15 +176,21 @@ def nojump(frame, usecache=True):
return frame - delta return frame - delta
def pbc_points(coordinates, box, thickness=0, index=False, shear=False): def pbc_points(
coordinates: CoordinateFrame,
thickness: Optional[float] = None,
index: bool = False,
shear: bool = False,
) -> Union[NDArray, tuple[NDArray, NDArray]]:
""" """
Returns the points their first periodic images. Does not fold Returns the points their first periodic images. Does not fold
them back into the box. them back into the box.
Thickness 0 means all 27 boxes. Positive means the box+thickness. Thickness 0 means all 27 boxes. Positive means the box+thickness.
Negative values mean that less than the box is returned. Negative values mean that less than the box is returned.
index=True also returns the indices with indices of images being their index=True also returns the indices with indices of images being their
originals values. original values.
""" """
box = coordinates.box
if shear: if shear:
box[2, 0] = box[2, 0] % box[0, 0] box[2, 0] = box[2, 0] % box[0, 0]
# Shifts the box images in the other directions if they moved more than # Shifts the box images in the other directions if they moved more than
@ -249,7 +205,7 @@ def pbc_points(coordinates, box, thickness=0, index=False, shear=False):
coordinates_pbc = np.concatenate([coordinates + v @ box for v in grid], axis=0) coordinates_pbc = np.concatenate([coordinates + v @ box for v in grid], axis=0)
size = np.diag(box) size = np.diag(box)
if thickness != 0: if thickness is not None:
mask = np.all(coordinates_pbc > -thickness, axis=1) mask = np.all(coordinates_pbc > -thickness, axis=1)
coordinates_pbc = coordinates_pbc[mask] coordinates_pbc = coordinates_pbc[mask]
indices = indices[mask] indices = indices[mask]