Adjusted type hints

This commit is contained in:
Sebastian Kloth 2024-01-16 13:38:44 +01:00
parent 95b46c43be
commit 5fdc9c8698

View File

@ -1,10 +1,10 @@
from functools import partial, wraps
from copy import copy
from .logging import logger
from typing import Optional, Callable
from typing import Optional, Callable, List, Tuple
import numpy as np
import numpy.typing as npt
from numpy.typing import ArrayLike, NDArray
from scipy.spatial import KDTree
from .atoms import AtomSubset
@ -17,7 +17,7 @@ class UnknownCoordinatesMode(Exception):
pass
class CoordinateFrame(np.ndarray):
class CoordinateFrame(NDArray):
_known_modes = ("pbc", "whole", "nojump")
@property
@ -99,7 +99,7 @@ class CoordinateFrame(np.ndarray):
box=None,
mode=None,
):
obj = np.ndarray.__new__(subtype, shape, dtype, buffer, offset, strides)
obj = NDArray.__new__(subtype, shape, dtype, buffer, offset, strides)
obj.coordinates = coordinates
obj.step = step
@ -319,7 +319,7 @@ class CoordinatesMap:
return CoordinatesMap(self.coordinates.pbc, self.function)
def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray:
def rotate_axis(coords: ArrayLike, axis: ArrayLike) -> NDArray:
"""
Rotate a set of coordinates to a given axis.
"""
@ -352,8 +352,8 @@ def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray:
def spherical_radius(
frame: CoordinateFrame, origin: Optional[npt.ArrayLike] = None
) -> np.ndarray:
frame: CoordinateFrame, origin: Optional[ArrayLike] = None
) -> NDArray:
"""
Transform a frame of cartesian coordinates into the spherical radius.
If origin=None, the center of the box is taken as the coordinates' origin.
@ -363,7 +363,7 @@ def spherical_radius(
return ((frame - origin) ** 2).sum(axis=-1) ** 0.5
def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.ndarray):
def polar_coordinates(x: ArrayLike, y: ArrayLike) -> (NDArray, NDArray):
"""Convert cartesian to polar coordinates."""
radius = (x**2 + y**2) ** 0.5
phi = np.arctan2(y, x)
@ -371,8 +371,8 @@ def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.nda
def spherical_coordinates(
x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike
) -> (np.ndarray, np.ndarray, np.ndarray):
x: ArrayLike, y: ArrayLike, z: ArrayLike
) -> (NDArray, NDArray, NDArray):
"""Convert cartesian to spherical coordinates."""
xy, phi = polar_coordinates(x, y)
radius = (x**2 + y**2 + z**2) ** 0.5
@ -384,8 +384,8 @@ def selector_radial_cylindrical(
atoms: CoordinateFrame,
r_min: float,
r_max: float,
origin: Optional[npt.ArrayLike] = None,
) -> np.ndarray:
origin: Optional[ArrayLike] = None,
) -> NDArray:
box = atoms.box
atoms = atoms % np.diag(box)
if origin is None:
@ -397,7 +397,7 @@ def selector_radial_cylindrical(
def map_coordinates(
func: Callable[[CoordinateFrame, ...], np.ndarray]
func: Callable[[CoordinateFrame, ...], NDArray]
) -> Callable[..., CoordinatesMap]:
@wraps(func)
def wrapped(coordinates: Coordinates, **kwargs) -> CoordinatesMap:
@ -408,14 +408,14 @@ def map_coordinates(
@map_coordinates
def center_of_masses(
frame: CoordinateFrame, atoms=None, shear: bool = False
) -> np.ndarray:
if atoms is None:
atoms = list(range(len(frame)))
res_ids = frame.residue_ids[atoms]
masses = frame.masses[atoms]
frame: CoordinateFrame, atom_indices=None, shear: bool = False
) -> NDArray:
if atom_indices is None:
atom_indices = list(range(len(frame)))
res_ids = frame.residue_ids[atom_indices]
masses = frame.masses[atom_indices]
if shear:
coords = frame[atoms]
coords = frame[atom_indices]
box = frame.box
sort_ind = res_ids.argsort(kind="stable")
i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1])
@ -423,7 +423,7 @@ def center_of_masses(
cor = pbc_diff(coords, coms, box)
coords = coms + cor
else:
coords = frame.whole[atoms]
coords = frame.whole[atom_indices]
mask = np.bincount(res_ids)[1:] != 0
positions = np.array(
[
@ -437,8 +437,8 @@ def center_of_masses(
@map_coordinates
def pore_coordinates(
frame: CoordinateFrame, origin: npt.ArrayLike, sym_axis: str = "z"
) -> np.ndarray:
frame: CoordinateFrame, origin: ArrayLike, sym_axis: str = "z"
) -> NDArray:
"""
Map coordinates of a pore simulation so the pore has cylindrical symmetry.
@ -459,17 +459,17 @@ def pore_coordinates(
@map_coordinates
def vectors(
frame: CoordinateFrame,
atoms_indices_a: npt.ArrayLike,
atoms_indices_b: npt.ArrayLike,
atom_indices_a: ArrayLike,
atom_indices_b: ArrayLike,
normed: bool = False,
) -> np.ndarray:
) -> NDArray:
"""
Compute the vectors between the atoms of two subsets.
Args:
frame: The Coordinates object the atoms will be taken from
atoms_indices_a: Mask or indices of the first atom subset
atoms_indices_b: Mask or indices of the second atom subset
atom_indices_a: Mask or indices of the first atom subset
atom_indices_b: Mask or indices of the second atom subset
normed (opt.): If the vectors should be normed
The definition of atoms_a/b can be any possible subript of a numpy array.
@ -492,10 +492,10 @@ def vectors(
])
"""
box = frame.box
coords_a = frame[atoms_indices_a]
coords_a = frame[atom_indices_a]
if len(coords_a.shape) > 2:
coords_a = coords_a.mean(axis=0)
coords_b = frame[atoms_indices_b]
coords_b = frame[atom_indices_b]
if len(coords_b.shape) > 2:
coords_b = coords_b.mean(axis=0)
vec = pbc_diff(coords_a, coords_b, box=box)
@ -507,8 +507,8 @@ def vectors(
@map_coordinates
def dipole_vector(
frame: CoordinateFrame, atom_indices: npt.ArrayLike, normed: bool = None
) -> np.ndarray:
frame: CoordinateFrame, atom_indices: ArrayLike, normed: bool = None
) -> NDArray:
coords = frame.whole[atom_indices]
res_ids = frame.residue_ids[atom_indices]
charges = frame.charges[atom_indices]
@ -525,9 +525,9 @@ def dipole_vector(
@map_coordinates
def sum_dipole_vector(
coordinates: CoordinateFrame,
atom_indices: npt.ArrayLike,
atom_indices: ArrayLike,
normed: bool = True,
) -> np.ndarray:
) -> NDArray:
coords = coordinates.whole[atom_indices]
charges = coordinates.charges[atom_indices]
dipole = np.array([c * charges for c in coords.T]).T
@ -539,11 +539,11 @@ def sum_dipole_vector(
@map_coordinates
def normal_vectors(
frame: CoordinateFrame,
atom_indices_a: npt.ArrayLike,
atom_indices_b: npt.ArrayLike,
atom_indices_c: npt.ArrayLike,
atom_indices_a: ArrayLike,
atom_indices_b: ArrayLike,
atom_indices_c: ArrayLike,
normed: bool = True,
) -> np.ndarray:
) -> NDArray:
coords_a = frame[atom_indices_a]
coords_b = frame[atom_indices_b]
coords_c = frame[atom_indices_c]
@ -571,8 +571,8 @@ def displacements_without_drift(
@map_coordinates
def cylindrical_coordinates(
frame: CoordinateFrame, origin: npt.ArrayLike = None
) -> np.ndarray:
frame: CoordinateFrame, origin: ArrayLike = None
) -> NDArray:
if origin is None:
origin = np.diag(frame.box) / 2
x = frame[:, 0] - origin[0]
@ -586,8 +586,8 @@ def cylindrical_coordinates(
def layer_of_atoms(
atoms: CoordinateFrame,
thickness: float,
plane_normal: npt.ArrayLike,
plane_offset: Optional[npt.ArrayLike] = np.array([0, 0, 0]),
plane_normal: ArrayLike,
plane_offset: Optional[ArrayLike] = np.array([0, 0, 0]),
) -> np.array:
if plane_offset is None:
np.array([0, 0, 0])
@ -603,7 +603,7 @@ def next_neighbors(
distance_upper_bound: float = np.inf,
distinct: bool = False,
**kwargs
) -> (np.ndarray, np.ndarray):
) -> Tuple[List, List]:
"""
Find the N next neighbors of a set of atoms.
@ -635,6 +635,14 @@ def next_neighbors(
number_of_neighbors + dnn,
distance_upper_bound=distance_upper_bound,
)
distances = distances[:, dnn:]
indices = indices[:, dnn:]
distances_new = []
indices_new = []
for dist, ind in zip(distances, indices):
distances_new.append(dist[dist <= distance_upper_bound])
indices_new.append(ind[dist <= distance_upper_bound])
return distances_new, indices_new
else:
atoms_pbc, atoms_pbc_index = pbc_points(
query_atoms, box, thickness=distance_upper_bound + 0.1, index=True, **kwargs
@ -645,6 +653,51 @@ def next_neighbors(
number_of_neighbors + dnn,
distance_upper_bound=distance_upper_bound,
)
indices = atoms_pbc_index[indices]
distances = distances[:, dnn:]
indices = indices[:, dnn:]
distances_new = []
indices_new = []
for dist, ind in zip(distances, indices):
distances_new.append(dist[dist <= distance_upper_bound])
indices_new.append(atoms_pbc_index[ind[dist <= distance_upper_bound]])
return distances_new, indices_new
return distances[:, dnn:], indices[:, dnn:]
def number_of_neighbors(
atoms: CoordinateFrame,
query_atoms: Optional[CoordinateFrame] = None,
r_max: float = 1,
distinct: bool = False,
**kwargs
) -> Tuple[List, List]:
"""
Find the N next neighbors of a set of atoms.
Args:
atoms:
The reference atoms and also the atoms which are queried if `query_atoms`
is net provided
query_atoms (opt.): If this is not None, these atoms will be queried
r_max (float, opt.):
Upper bound of the distance between neighbors
distinct (bool, opt.):
If this is true, the atoms and query atoms are taken as distinct sets of
atoms
"""
dnn = 0
if query_atoms is None:
query_atoms = atoms
dnn = 1
elif not distinct:
dnn = 1
box = atoms.box
if np.all(np.diag(np.diag(box)) == box):
atoms = atoms % np.diag(box)
tree = KDTree(atoms, boxsize=np.diag(box))
else:
atoms_pbc = pbc_points(query_atoms, box, thickness=r_max + 0.1, **kwargs)
tree = KDTree(atoms_pbc)
num_of_neighbors = tree.query_ball_point(query_atoms, r_max, return_length=True)
return num_of_neighbors - dnn