Adjusted type hints
This commit is contained in:
parent
95b46c43be
commit
5fdc9c8698
@ -1,10 +1,10 @@
|
||||
from functools import partial, wraps
|
||||
from copy import copy
|
||||
from .logging import logger
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, Callable, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
from numpy.typing import ArrayLike, NDArray
|
||||
from scipy.spatial import KDTree
|
||||
|
||||
from .atoms import AtomSubset
|
||||
@ -17,7 +17,7 @@ class UnknownCoordinatesMode(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class CoordinateFrame(np.ndarray):
|
||||
class CoordinateFrame(NDArray):
|
||||
_known_modes = ("pbc", "whole", "nojump")
|
||||
|
||||
@property
|
||||
@ -99,7 +99,7 @@ class CoordinateFrame(np.ndarray):
|
||||
box=None,
|
||||
mode=None,
|
||||
):
|
||||
obj = np.ndarray.__new__(subtype, shape, dtype, buffer, offset, strides)
|
||||
obj = NDArray.__new__(subtype, shape, dtype, buffer, offset, strides)
|
||||
|
||||
obj.coordinates = coordinates
|
||||
obj.step = step
|
||||
@ -319,7 +319,7 @@ class CoordinatesMap:
|
||||
return CoordinatesMap(self.coordinates.pbc, self.function)
|
||||
|
||||
|
||||
def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray:
|
||||
def rotate_axis(coords: ArrayLike, axis: ArrayLike) -> NDArray:
|
||||
"""
|
||||
Rotate a set of coordinates to a given axis.
|
||||
"""
|
||||
@ -352,8 +352,8 @@ def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray:
|
||||
|
||||
|
||||
def spherical_radius(
|
||||
frame: CoordinateFrame, origin: Optional[npt.ArrayLike] = None
|
||||
) -> np.ndarray:
|
||||
frame: CoordinateFrame, origin: Optional[ArrayLike] = None
|
||||
) -> NDArray:
|
||||
"""
|
||||
Transform a frame of cartesian coordinates into the spherical radius.
|
||||
If origin=None, the center of the box is taken as the coordinates' origin.
|
||||
@ -363,7 +363,7 @@ def spherical_radius(
|
||||
return ((frame - origin) ** 2).sum(axis=-1) ** 0.5
|
||||
|
||||
|
||||
def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.ndarray):
|
||||
def polar_coordinates(x: ArrayLike, y: ArrayLike) -> (NDArray, NDArray):
|
||||
"""Convert cartesian to polar coordinates."""
|
||||
radius = (x**2 + y**2) ** 0.5
|
||||
phi = np.arctan2(y, x)
|
||||
@ -371,8 +371,8 @@ def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.nda
|
||||
|
||||
|
||||
def spherical_coordinates(
|
||||
x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike
|
||||
) -> (np.ndarray, np.ndarray, np.ndarray):
|
||||
x: ArrayLike, y: ArrayLike, z: ArrayLike
|
||||
) -> (NDArray, NDArray, NDArray):
|
||||
"""Convert cartesian to spherical coordinates."""
|
||||
xy, phi = polar_coordinates(x, y)
|
||||
radius = (x**2 + y**2 + z**2) ** 0.5
|
||||
@ -384,8 +384,8 @@ def selector_radial_cylindrical(
|
||||
atoms: CoordinateFrame,
|
||||
r_min: float,
|
||||
r_max: float,
|
||||
origin: Optional[npt.ArrayLike] = None,
|
||||
) -> np.ndarray:
|
||||
origin: Optional[ArrayLike] = None,
|
||||
) -> NDArray:
|
||||
box = atoms.box
|
||||
atoms = atoms % np.diag(box)
|
||||
if origin is None:
|
||||
@ -397,7 +397,7 @@ def selector_radial_cylindrical(
|
||||
|
||||
|
||||
def map_coordinates(
|
||||
func: Callable[[CoordinateFrame, ...], np.ndarray]
|
||||
func: Callable[[CoordinateFrame, ...], NDArray]
|
||||
) -> Callable[..., CoordinatesMap]:
|
||||
@wraps(func)
|
||||
def wrapped(coordinates: Coordinates, **kwargs) -> CoordinatesMap:
|
||||
@ -408,14 +408,14 @@ def map_coordinates(
|
||||
|
||||
@map_coordinates
|
||||
def center_of_masses(
|
||||
frame: CoordinateFrame, atoms=None, shear: bool = False
|
||||
) -> np.ndarray:
|
||||
if atoms is None:
|
||||
atoms = list(range(len(frame)))
|
||||
res_ids = frame.residue_ids[atoms]
|
||||
masses = frame.masses[atoms]
|
||||
frame: CoordinateFrame, atom_indices=None, shear: bool = False
|
||||
) -> NDArray:
|
||||
if atom_indices is None:
|
||||
atom_indices = list(range(len(frame)))
|
||||
res_ids = frame.residue_ids[atom_indices]
|
||||
masses = frame.masses[atom_indices]
|
||||
if shear:
|
||||
coords = frame[atoms]
|
||||
coords = frame[atom_indices]
|
||||
box = frame.box
|
||||
sort_ind = res_ids.argsort(kind="stable")
|
||||
i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1])
|
||||
@ -423,7 +423,7 @@ def center_of_masses(
|
||||
cor = pbc_diff(coords, coms, box)
|
||||
coords = coms + cor
|
||||
else:
|
||||
coords = frame.whole[atoms]
|
||||
coords = frame.whole[atom_indices]
|
||||
mask = np.bincount(res_ids)[1:] != 0
|
||||
positions = np.array(
|
||||
[
|
||||
@ -437,8 +437,8 @@ def center_of_masses(
|
||||
|
||||
@map_coordinates
|
||||
def pore_coordinates(
|
||||
frame: CoordinateFrame, origin: npt.ArrayLike, sym_axis: str = "z"
|
||||
) -> np.ndarray:
|
||||
frame: CoordinateFrame, origin: ArrayLike, sym_axis: str = "z"
|
||||
) -> NDArray:
|
||||
"""
|
||||
Map coordinates of a pore simulation so the pore has cylindrical symmetry.
|
||||
|
||||
@ -459,17 +459,17 @@ def pore_coordinates(
|
||||
@map_coordinates
|
||||
def vectors(
|
||||
frame: CoordinateFrame,
|
||||
atoms_indices_a: npt.ArrayLike,
|
||||
atoms_indices_b: npt.ArrayLike,
|
||||
atom_indices_a: ArrayLike,
|
||||
atom_indices_b: ArrayLike,
|
||||
normed: bool = False,
|
||||
) -> np.ndarray:
|
||||
) -> NDArray:
|
||||
"""
|
||||
Compute the vectors between the atoms of two subsets.
|
||||
|
||||
Args:
|
||||
frame: The Coordinates object the atoms will be taken from
|
||||
atoms_indices_a: Mask or indices of the first atom subset
|
||||
atoms_indices_b: Mask or indices of the second atom subset
|
||||
atom_indices_a: Mask or indices of the first atom subset
|
||||
atom_indices_b: Mask or indices of the second atom subset
|
||||
normed (opt.): If the vectors should be normed
|
||||
|
||||
The definition of atoms_a/b can be any possible subript of a numpy array.
|
||||
@ -492,10 +492,10 @@ def vectors(
|
||||
])
|
||||
"""
|
||||
box = frame.box
|
||||
coords_a = frame[atoms_indices_a]
|
||||
coords_a = frame[atom_indices_a]
|
||||
if len(coords_a.shape) > 2:
|
||||
coords_a = coords_a.mean(axis=0)
|
||||
coords_b = frame[atoms_indices_b]
|
||||
coords_b = frame[atom_indices_b]
|
||||
if len(coords_b.shape) > 2:
|
||||
coords_b = coords_b.mean(axis=0)
|
||||
vec = pbc_diff(coords_a, coords_b, box=box)
|
||||
@ -507,8 +507,8 @@ def vectors(
|
||||
|
||||
@map_coordinates
|
||||
def dipole_vector(
|
||||
frame: CoordinateFrame, atom_indices: npt.ArrayLike, normed: bool = None
|
||||
) -> np.ndarray:
|
||||
frame: CoordinateFrame, atom_indices: ArrayLike, normed: bool = None
|
||||
) -> NDArray:
|
||||
coords = frame.whole[atom_indices]
|
||||
res_ids = frame.residue_ids[atom_indices]
|
||||
charges = frame.charges[atom_indices]
|
||||
@ -525,9 +525,9 @@ def dipole_vector(
|
||||
@map_coordinates
|
||||
def sum_dipole_vector(
|
||||
coordinates: CoordinateFrame,
|
||||
atom_indices: npt.ArrayLike,
|
||||
atom_indices: ArrayLike,
|
||||
normed: bool = True,
|
||||
) -> np.ndarray:
|
||||
) -> NDArray:
|
||||
coords = coordinates.whole[atom_indices]
|
||||
charges = coordinates.charges[atom_indices]
|
||||
dipole = np.array([c * charges for c in coords.T]).T
|
||||
@ -539,11 +539,11 @@ def sum_dipole_vector(
|
||||
@map_coordinates
|
||||
def normal_vectors(
|
||||
frame: CoordinateFrame,
|
||||
atom_indices_a: npt.ArrayLike,
|
||||
atom_indices_b: npt.ArrayLike,
|
||||
atom_indices_c: npt.ArrayLike,
|
||||
atom_indices_a: ArrayLike,
|
||||
atom_indices_b: ArrayLike,
|
||||
atom_indices_c: ArrayLike,
|
||||
normed: bool = True,
|
||||
) -> np.ndarray:
|
||||
) -> NDArray:
|
||||
coords_a = frame[atom_indices_a]
|
||||
coords_b = frame[atom_indices_b]
|
||||
coords_c = frame[atom_indices_c]
|
||||
@ -571,8 +571,8 @@ def displacements_without_drift(
|
||||
|
||||
@map_coordinates
|
||||
def cylindrical_coordinates(
|
||||
frame: CoordinateFrame, origin: npt.ArrayLike = None
|
||||
) -> np.ndarray:
|
||||
frame: CoordinateFrame, origin: ArrayLike = None
|
||||
) -> NDArray:
|
||||
if origin is None:
|
||||
origin = np.diag(frame.box) / 2
|
||||
x = frame[:, 0] - origin[0]
|
||||
@ -586,8 +586,8 @@ def cylindrical_coordinates(
|
||||
def layer_of_atoms(
|
||||
atoms: CoordinateFrame,
|
||||
thickness: float,
|
||||
plane_normal: npt.ArrayLike,
|
||||
plane_offset: Optional[npt.ArrayLike] = np.array([0, 0, 0]),
|
||||
plane_normal: ArrayLike,
|
||||
plane_offset: Optional[ArrayLike] = np.array([0, 0, 0]),
|
||||
) -> np.array:
|
||||
if plane_offset is None:
|
||||
np.array([0, 0, 0])
|
||||
@ -603,7 +603,7 @@ def next_neighbors(
|
||||
distance_upper_bound: float = np.inf,
|
||||
distinct: bool = False,
|
||||
**kwargs
|
||||
) -> (np.ndarray, np.ndarray):
|
||||
) -> Tuple[List, List]:
|
||||
"""
|
||||
Find the N next neighbors of a set of atoms.
|
||||
|
||||
@ -635,6 +635,14 @@ def next_neighbors(
|
||||
number_of_neighbors + dnn,
|
||||
distance_upper_bound=distance_upper_bound,
|
||||
)
|
||||
distances = distances[:, dnn:]
|
||||
indices = indices[:, dnn:]
|
||||
distances_new = []
|
||||
indices_new = []
|
||||
for dist, ind in zip(distances, indices):
|
||||
distances_new.append(dist[dist <= distance_upper_bound])
|
||||
indices_new.append(ind[dist <= distance_upper_bound])
|
||||
return distances_new, indices_new
|
||||
else:
|
||||
atoms_pbc, atoms_pbc_index = pbc_points(
|
||||
query_atoms, box, thickness=distance_upper_bound + 0.1, index=True, **kwargs
|
||||
@ -645,6 +653,51 @@ def next_neighbors(
|
||||
number_of_neighbors + dnn,
|
||||
distance_upper_bound=distance_upper_bound,
|
||||
)
|
||||
indices = atoms_pbc_index[indices]
|
||||
distances = distances[:, dnn:]
|
||||
indices = indices[:, dnn:]
|
||||
distances_new = []
|
||||
indices_new = []
|
||||
for dist, ind in zip(distances, indices):
|
||||
distances_new.append(dist[dist <= distance_upper_bound])
|
||||
indices_new.append(atoms_pbc_index[ind[dist <= distance_upper_bound]])
|
||||
return distances_new, indices_new
|
||||
|
||||
return distances[:, dnn:], indices[:, dnn:]
|
||||
|
||||
def number_of_neighbors(
|
||||
atoms: CoordinateFrame,
|
||||
query_atoms: Optional[CoordinateFrame] = None,
|
||||
r_max: float = 1,
|
||||
distinct: bool = False,
|
||||
**kwargs
|
||||
) -> Tuple[List, List]:
|
||||
"""
|
||||
Find the N next neighbors of a set of atoms.
|
||||
|
||||
Args:
|
||||
atoms:
|
||||
The reference atoms and also the atoms which are queried if `query_atoms`
|
||||
is net provided
|
||||
query_atoms (opt.): If this is not None, these atoms will be queried
|
||||
r_max (float, opt.):
|
||||
Upper bound of the distance between neighbors
|
||||
distinct (bool, opt.):
|
||||
If this is true, the atoms and query atoms are taken as distinct sets of
|
||||
atoms
|
||||
"""
|
||||
dnn = 0
|
||||
if query_atoms is None:
|
||||
query_atoms = atoms
|
||||
dnn = 1
|
||||
elif not distinct:
|
||||
dnn = 1
|
||||
|
||||
box = atoms.box
|
||||
if np.all(np.diag(np.diag(box)) == box):
|
||||
atoms = atoms % np.diag(box)
|
||||
tree = KDTree(atoms, boxsize=np.diag(box))
|
||||
else:
|
||||
atoms_pbc = pbc_points(query_atoms, box, thickness=r_max + 0.1, **kwargs)
|
||||
tree = KDTree(atoms_pbc)
|
||||
|
||||
num_of_neighbors = tree.query_ball_point(query_atoms, r_max, return_length=True)
|
||||
return num_of_neighbors - dnn
|
||||
|
Loading…
Reference in New Issue
Block a user