Adjusted type hints

This commit is contained in:
Sebastian Kloth 2024-01-16 13:38:44 +01:00
parent 95b46c43be
commit 5fdc9c8698

View File

@ -1,10 +1,10 @@
from functools import partial, wraps from functools import partial, wraps
from copy import copy from copy import copy
from .logging import logger from .logging import logger
from typing import Optional, Callable from typing import Optional, Callable, List, Tuple
import numpy as np import numpy as np
import numpy.typing as npt from numpy.typing import ArrayLike, NDArray
from scipy.spatial import KDTree from scipy.spatial import KDTree
from .atoms import AtomSubset from .atoms import AtomSubset
@ -17,7 +17,7 @@ class UnknownCoordinatesMode(Exception):
pass pass
class CoordinateFrame(np.ndarray): class CoordinateFrame(NDArray):
_known_modes = ("pbc", "whole", "nojump") _known_modes = ("pbc", "whole", "nojump")
@property @property
@ -99,7 +99,7 @@ class CoordinateFrame(np.ndarray):
box=None, box=None,
mode=None, mode=None,
): ):
obj = np.ndarray.__new__(subtype, shape, dtype, buffer, offset, strides) obj = NDArray.__new__(subtype, shape, dtype, buffer, offset, strides)
obj.coordinates = coordinates obj.coordinates = coordinates
obj.step = step obj.step = step
@ -319,7 +319,7 @@ class CoordinatesMap:
return CoordinatesMap(self.coordinates.pbc, self.function) return CoordinatesMap(self.coordinates.pbc, self.function)
def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray: def rotate_axis(coords: ArrayLike, axis: ArrayLike) -> NDArray:
""" """
Rotate a set of coordinates to a given axis. Rotate a set of coordinates to a given axis.
""" """
@ -352,8 +352,8 @@ def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray:
def spherical_radius( def spherical_radius(
frame: CoordinateFrame, origin: Optional[npt.ArrayLike] = None frame: CoordinateFrame, origin: Optional[ArrayLike] = None
) -> np.ndarray: ) -> NDArray:
""" """
Transform a frame of cartesian coordinates into the spherical radius. Transform a frame of cartesian coordinates into the spherical radius.
If origin=None, the center of the box is taken as the coordinates' origin. If origin=None, the center of the box is taken as the coordinates' origin.
@ -363,7 +363,7 @@ def spherical_radius(
return ((frame - origin) ** 2).sum(axis=-1) ** 0.5 return ((frame - origin) ** 2).sum(axis=-1) ** 0.5
def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.ndarray): def polar_coordinates(x: ArrayLike, y: ArrayLike) -> (NDArray, NDArray):
"""Convert cartesian to polar coordinates.""" """Convert cartesian to polar coordinates."""
radius = (x**2 + y**2) ** 0.5 radius = (x**2 + y**2) ** 0.5
phi = np.arctan2(y, x) phi = np.arctan2(y, x)
@ -371,8 +371,8 @@ def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.nda
def spherical_coordinates( def spherical_coordinates(
x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike x: ArrayLike, y: ArrayLike, z: ArrayLike
) -> (np.ndarray, np.ndarray, np.ndarray): ) -> (NDArray, NDArray, NDArray):
"""Convert cartesian to spherical coordinates.""" """Convert cartesian to spherical coordinates."""
xy, phi = polar_coordinates(x, y) xy, phi = polar_coordinates(x, y)
radius = (x**2 + y**2 + z**2) ** 0.5 radius = (x**2 + y**2 + z**2) ** 0.5
@ -384,8 +384,8 @@ def selector_radial_cylindrical(
atoms: CoordinateFrame, atoms: CoordinateFrame,
r_min: float, r_min: float,
r_max: float, r_max: float,
origin: Optional[npt.ArrayLike] = None, origin: Optional[ArrayLike] = None,
) -> np.ndarray: ) -> NDArray:
box = atoms.box box = atoms.box
atoms = atoms % np.diag(box) atoms = atoms % np.diag(box)
if origin is None: if origin is None:
@ -397,7 +397,7 @@ def selector_radial_cylindrical(
def map_coordinates( def map_coordinates(
func: Callable[[CoordinateFrame, ...], np.ndarray] func: Callable[[CoordinateFrame, ...], NDArray]
) -> Callable[..., CoordinatesMap]: ) -> Callable[..., CoordinatesMap]:
@wraps(func) @wraps(func)
def wrapped(coordinates: Coordinates, **kwargs) -> CoordinatesMap: def wrapped(coordinates: Coordinates, **kwargs) -> CoordinatesMap:
@ -408,14 +408,14 @@ def map_coordinates(
@map_coordinates @map_coordinates
def center_of_masses( def center_of_masses(
frame: CoordinateFrame, atoms=None, shear: bool = False frame: CoordinateFrame, atom_indices=None, shear: bool = False
) -> np.ndarray: ) -> NDArray:
if atoms is None: if atom_indices is None:
atoms = list(range(len(frame))) atom_indices = list(range(len(frame)))
res_ids = frame.residue_ids[atoms] res_ids = frame.residue_ids[atom_indices]
masses = frame.masses[atoms] masses = frame.masses[atom_indices]
if shear: if shear:
coords = frame[atoms] coords = frame[atom_indices]
box = frame.box box = frame.box
sort_ind = res_ids.argsort(kind="stable") sort_ind = res_ids.argsort(kind="stable")
i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1]) i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1])
@ -423,7 +423,7 @@ def center_of_masses(
cor = pbc_diff(coords, coms, box) cor = pbc_diff(coords, coms, box)
coords = coms + cor coords = coms + cor
else: else:
coords = frame.whole[atoms] coords = frame.whole[atom_indices]
mask = np.bincount(res_ids)[1:] != 0 mask = np.bincount(res_ids)[1:] != 0
positions = np.array( positions = np.array(
[ [
@ -437,8 +437,8 @@ def center_of_masses(
@map_coordinates @map_coordinates
def pore_coordinates( def pore_coordinates(
frame: CoordinateFrame, origin: npt.ArrayLike, sym_axis: str = "z" frame: CoordinateFrame, origin: ArrayLike, sym_axis: str = "z"
) -> np.ndarray: ) -> NDArray:
""" """
Map coordinates of a pore simulation so the pore has cylindrical symmetry. Map coordinates of a pore simulation so the pore has cylindrical symmetry.
@ -459,17 +459,17 @@ def pore_coordinates(
@map_coordinates @map_coordinates
def vectors( def vectors(
frame: CoordinateFrame, frame: CoordinateFrame,
atoms_indices_a: npt.ArrayLike, atom_indices_a: ArrayLike,
atoms_indices_b: npt.ArrayLike, atom_indices_b: ArrayLike,
normed: bool = False, normed: bool = False,
) -> np.ndarray: ) -> NDArray:
""" """
Compute the vectors between the atoms of two subsets. Compute the vectors between the atoms of two subsets.
Args: Args:
frame: The Coordinates object the atoms will be taken from frame: The Coordinates object the atoms will be taken from
atoms_indices_a: Mask or indices of the first atom subset atom_indices_a: Mask or indices of the first atom subset
atoms_indices_b: Mask or indices of the second atom subset atom_indices_b: Mask or indices of the second atom subset
normed (opt.): If the vectors should be normed normed (opt.): If the vectors should be normed
The definition of atoms_a/b can be any possible subript of a numpy array. The definition of atoms_a/b can be any possible subript of a numpy array.
@ -492,10 +492,10 @@ def vectors(
]) ])
""" """
box = frame.box box = frame.box
coords_a = frame[atoms_indices_a] coords_a = frame[atom_indices_a]
if len(coords_a.shape) > 2: if len(coords_a.shape) > 2:
coords_a = coords_a.mean(axis=0) coords_a = coords_a.mean(axis=0)
coords_b = frame[atoms_indices_b] coords_b = frame[atom_indices_b]
if len(coords_b.shape) > 2: if len(coords_b.shape) > 2:
coords_b = coords_b.mean(axis=0) coords_b = coords_b.mean(axis=0)
vec = pbc_diff(coords_a, coords_b, box=box) vec = pbc_diff(coords_a, coords_b, box=box)
@ -507,8 +507,8 @@ def vectors(
@map_coordinates @map_coordinates
def dipole_vector( def dipole_vector(
frame: CoordinateFrame, atom_indices: npt.ArrayLike, normed: bool = None frame: CoordinateFrame, atom_indices: ArrayLike, normed: bool = None
) -> np.ndarray: ) -> NDArray:
coords = frame.whole[atom_indices] coords = frame.whole[atom_indices]
res_ids = frame.residue_ids[atom_indices] res_ids = frame.residue_ids[atom_indices]
charges = frame.charges[atom_indices] charges = frame.charges[atom_indices]
@ -525,9 +525,9 @@ def dipole_vector(
@map_coordinates @map_coordinates
def sum_dipole_vector( def sum_dipole_vector(
coordinates: CoordinateFrame, coordinates: CoordinateFrame,
atom_indices: npt.ArrayLike, atom_indices: ArrayLike,
normed: bool = True, normed: bool = True,
) -> np.ndarray: ) -> NDArray:
coords = coordinates.whole[atom_indices] coords = coordinates.whole[atom_indices]
charges = coordinates.charges[atom_indices] charges = coordinates.charges[atom_indices]
dipole = np.array([c * charges for c in coords.T]).T dipole = np.array([c * charges for c in coords.T]).T
@ -539,11 +539,11 @@ def sum_dipole_vector(
@map_coordinates @map_coordinates
def normal_vectors( def normal_vectors(
frame: CoordinateFrame, frame: CoordinateFrame,
atom_indices_a: npt.ArrayLike, atom_indices_a: ArrayLike,
atom_indices_b: npt.ArrayLike, atom_indices_b: ArrayLike,
atom_indices_c: npt.ArrayLike, atom_indices_c: ArrayLike,
normed: bool = True, normed: bool = True,
) -> np.ndarray: ) -> NDArray:
coords_a = frame[atom_indices_a] coords_a = frame[atom_indices_a]
coords_b = frame[atom_indices_b] coords_b = frame[atom_indices_b]
coords_c = frame[atom_indices_c] coords_c = frame[atom_indices_c]
@ -571,8 +571,8 @@ def displacements_without_drift(
@map_coordinates @map_coordinates
def cylindrical_coordinates( def cylindrical_coordinates(
frame: CoordinateFrame, origin: npt.ArrayLike = None frame: CoordinateFrame, origin: ArrayLike = None
) -> np.ndarray: ) -> NDArray:
if origin is None: if origin is None:
origin = np.diag(frame.box) / 2 origin = np.diag(frame.box) / 2
x = frame[:, 0] - origin[0] x = frame[:, 0] - origin[0]
@ -584,10 +584,10 @@ def cylindrical_coordinates(
def layer_of_atoms( def layer_of_atoms(
atoms: CoordinateFrame, atoms: CoordinateFrame,
thickness: float, thickness: float,
plane_normal: npt.ArrayLike, plane_normal: ArrayLike,
plane_offset: Optional[npt.ArrayLike] = np.array([0, 0, 0]), plane_offset: Optional[ArrayLike] = np.array([0, 0, 0]),
) -> np.array: ) -> np.array:
if plane_offset is None: if plane_offset is None:
np.array([0, 0, 0]) np.array([0, 0, 0])
@ -597,13 +597,13 @@ def layer_of_atoms(
def next_neighbors( def next_neighbors(
atoms: CoordinateFrame, atoms: CoordinateFrame,
query_atoms: Optional[CoordinateFrame] = None, query_atoms: Optional[CoordinateFrame] = None,
number_of_neighbors: int = 1, number_of_neighbors: int = 1,
distance_upper_bound: float = np.inf, distance_upper_bound: float = np.inf,
distinct: bool = False, distinct: bool = False,
**kwargs **kwargs
) -> (np.ndarray, np.ndarray): ) -> Tuple[List, List]:
""" """
Find the N next neighbors of a set of atoms. Find the N next neighbors of a set of atoms.
@ -634,7 +634,15 @@ def next_neighbors(
query_atoms, query_atoms,
number_of_neighbors + dnn, number_of_neighbors + dnn,
distance_upper_bound=distance_upper_bound, distance_upper_bound=distance_upper_bound,
) )
distances = distances[:, dnn:]
indices = indices[:, dnn:]
distances_new = []
indices_new = []
for dist, ind in zip(distances, indices):
distances_new.append(dist[dist <= distance_upper_bound])
indices_new.append(ind[dist <= distance_upper_bound])
return distances_new, indices_new
else: else:
atoms_pbc, atoms_pbc_index = pbc_points( atoms_pbc, atoms_pbc_index = pbc_points(
query_atoms, box, thickness=distance_upper_bound + 0.1, index=True, **kwargs query_atoms, box, thickness=distance_upper_bound + 0.1, index=True, **kwargs
@ -644,7 +652,52 @@ def next_neighbors(
query_atoms, query_atoms,
number_of_neighbors + dnn, number_of_neighbors + dnn,
distance_upper_bound=distance_upper_bound, distance_upper_bound=distance_upper_bound,
) )
indices = atoms_pbc_index[indices] distances = distances[:, dnn:]
indices = indices[:, dnn:]
distances_new = []
indices_new = []
for dist, ind in zip(distances, indices):
distances_new.append(dist[dist <= distance_upper_bound])
indices_new.append(atoms_pbc_index[ind[dist <= distance_upper_bound]])
return distances_new, indices_new
return distances[:, dnn:], indices[:, dnn:]
def number_of_neighbors(
atoms: CoordinateFrame,
query_atoms: Optional[CoordinateFrame] = None,
r_max: float = 1,
distinct: bool = False,
**kwargs
) -> Tuple[List, List]:
"""
Find the N next neighbors of a set of atoms.
Args:
atoms:
The reference atoms and also the atoms which are queried if `query_atoms`
is net provided
query_atoms (opt.): If this is not None, these atoms will be queried
r_max (float, opt.):
Upper bound of the distance between neighbors
distinct (bool, opt.):
If this is true, the atoms and query atoms are taken as distinct sets of
atoms
"""
dnn = 0
if query_atoms is None:
query_atoms = atoms
dnn = 1
elif not distinct:
dnn = 1
box = atoms.box
if np.all(np.diag(np.diag(box)) == box):
atoms = atoms % np.diag(box)
tree = KDTree(atoms, boxsize=np.diag(box))
else:
atoms_pbc = pbc_points(query_atoms, box, thickness=r_max + 0.1, **kwargs)
tree = KDTree(atoms_pbc)
num_of_neighbors = tree.query_ball_point(query_atoms, r_max, return_length=True)
return num_of_neighbors - dnn