From 5fdc9c8698b2b009a0ec7522fd28c9f65c6241b3 Mon Sep 17 00:00:00 2001 From: Sebastian Kloth Date: Tue, 16 Jan 2024 13:38:44 +0100 Subject: [PATCH] Adjusted type hints --- src/mdevaluate/coordinates.py | 163 ++++++++++++++++++++++------------ 1 file changed, 108 insertions(+), 55 deletions(-) diff --git a/src/mdevaluate/coordinates.py b/src/mdevaluate/coordinates.py index bfeac00..fc50995 100755 --- a/src/mdevaluate/coordinates.py +++ b/src/mdevaluate/coordinates.py @@ -1,10 +1,10 @@ from functools import partial, wraps from copy import copy from .logging import logger -from typing import Optional, Callable +from typing import Optional, Callable, List, Tuple import numpy as np -import numpy.typing as npt +from numpy.typing import ArrayLike, NDArray from scipy.spatial import KDTree from .atoms import AtomSubset @@ -17,7 +17,7 @@ class UnknownCoordinatesMode(Exception): pass -class CoordinateFrame(np.ndarray): +class CoordinateFrame(NDArray): _known_modes = ("pbc", "whole", "nojump") @property @@ -99,7 +99,7 @@ class CoordinateFrame(np.ndarray): box=None, mode=None, ): - obj = np.ndarray.__new__(subtype, shape, dtype, buffer, offset, strides) + obj = NDArray.__new__(subtype, shape, dtype, buffer, offset, strides) obj.coordinates = coordinates obj.step = step @@ -319,7 +319,7 @@ class CoordinatesMap: return CoordinatesMap(self.coordinates.pbc, self.function) -def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray: +def rotate_axis(coords: ArrayLike, axis: ArrayLike) -> NDArray: """ Rotate a set of coordinates to a given axis. """ @@ -352,8 +352,8 @@ def rotate_axis(coords: npt.ArrayLike, axis: npt.ArrayLike) -> np.ndarray: def spherical_radius( - frame: CoordinateFrame, origin: Optional[npt.ArrayLike] = None -) -> np.ndarray: + frame: CoordinateFrame, origin: Optional[ArrayLike] = None +) -> NDArray: """ Transform a frame of cartesian coordinates into the spherical radius. If origin=None, the center of the box is taken as the coordinates' origin. @@ -363,7 +363,7 @@ def spherical_radius( return ((frame - origin) ** 2).sum(axis=-1) ** 0.5 -def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.ndarray): +def polar_coordinates(x: ArrayLike, y: ArrayLike) -> (NDArray, NDArray): """Convert cartesian to polar coordinates.""" radius = (x**2 + y**2) ** 0.5 phi = np.arctan2(y, x) @@ -371,8 +371,8 @@ def polar_coordinates(x: npt.ArrayLike, y: npt.ArrayLike) -> (np.ndarray, np.nda def spherical_coordinates( - x: npt.ArrayLike, y: npt.ArrayLike, z: npt.ArrayLike -) -> (np.ndarray, np.ndarray, np.ndarray): + x: ArrayLike, y: ArrayLike, z: ArrayLike +) -> (NDArray, NDArray, NDArray): """Convert cartesian to spherical coordinates.""" xy, phi = polar_coordinates(x, y) radius = (x**2 + y**2 + z**2) ** 0.5 @@ -384,8 +384,8 @@ def selector_radial_cylindrical( atoms: CoordinateFrame, r_min: float, r_max: float, - origin: Optional[npt.ArrayLike] = None, -) -> np.ndarray: + origin: Optional[ArrayLike] = None, +) -> NDArray: box = atoms.box atoms = atoms % np.diag(box) if origin is None: @@ -397,7 +397,7 @@ def selector_radial_cylindrical( def map_coordinates( - func: Callable[[CoordinateFrame, ...], np.ndarray] + func: Callable[[CoordinateFrame, ...], NDArray] ) -> Callable[..., CoordinatesMap]: @wraps(func) def wrapped(coordinates: Coordinates, **kwargs) -> CoordinatesMap: @@ -408,14 +408,14 @@ def map_coordinates( @map_coordinates def center_of_masses( - frame: CoordinateFrame, atoms=None, shear: bool = False -) -> np.ndarray: - if atoms is None: - atoms = list(range(len(frame))) - res_ids = frame.residue_ids[atoms] - masses = frame.masses[atoms] + frame: CoordinateFrame, atom_indices=None, shear: bool = False +) -> NDArray: + if atom_indices is None: + atom_indices = list(range(len(frame))) + res_ids = frame.residue_ids[atom_indices] + masses = frame.masses[atom_indices] if shear: - coords = frame[atoms] + coords = frame[atom_indices] box = frame.box sort_ind = res_ids.argsort(kind="stable") i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1]) @@ -423,7 +423,7 @@ def center_of_masses( cor = pbc_diff(coords, coms, box) coords = coms + cor else: - coords = frame.whole[atoms] + coords = frame.whole[atom_indices] mask = np.bincount(res_ids)[1:] != 0 positions = np.array( [ @@ -437,8 +437,8 @@ def center_of_masses( @map_coordinates def pore_coordinates( - frame: CoordinateFrame, origin: npt.ArrayLike, sym_axis: str = "z" -) -> np.ndarray: + frame: CoordinateFrame, origin: ArrayLike, sym_axis: str = "z" +) -> NDArray: """ Map coordinates of a pore simulation so the pore has cylindrical symmetry. @@ -459,17 +459,17 @@ def pore_coordinates( @map_coordinates def vectors( frame: CoordinateFrame, - atoms_indices_a: npt.ArrayLike, - atoms_indices_b: npt.ArrayLike, + atom_indices_a: ArrayLike, + atom_indices_b: ArrayLike, normed: bool = False, -) -> np.ndarray: +) -> NDArray: """ Compute the vectors between the atoms of two subsets. Args: frame: The Coordinates object the atoms will be taken from - atoms_indices_a: Mask or indices of the first atom subset - atoms_indices_b: Mask or indices of the second atom subset + atom_indices_a: Mask or indices of the first atom subset + atom_indices_b: Mask or indices of the second atom subset normed (opt.): If the vectors should be normed The definition of atoms_a/b can be any possible subript of a numpy array. @@ -492,10 +492,10 @@ def vectors( ]) """ box = frame.box - coords_a = frame[atoms_indices_a] + coords_a = frame[atom_indices_a] if len(coords_a.shape) > 2: coords_a = coords_a.mean(axis=0) - coords_b = frame[atoms_indices_b] + coords_b = frame[atom_indices_b] if len(coords_b.shape) > 2: coords_b = coords_b.mean(axis=0) vec = pbc_diff(coords_a, coords_b, box=box) @@ -507,8 +507,8 @@ def vectors( @map_coordinates def dipole_vector( - frame: CoordinateFrame, atom_indices: npt.ArrayLike, normed: bool = None -) -> np.ndarray: + frame: CoordinateFrame, atom_indices: ArrayLike, normed: bool = None +) -> NDArray: coords = frame.whole[atom_indices] res_ids = frame.residue_ids[atom_indices] charges = frame.charges[atom_indices] @@ -525,9 +525,9 @@ def dipole_vector( @map_coordinates def sum_dipole_vector( coordinates: CoordinateFrame, - atom_indices: npt.ArrayLike, + atom_indices: ArrayLike, normed: bool = True, -) -> np.ndarray: +) -> NDArray: coords = coordinates.whole[atom_indices] charges = coordinates.charges[atom_indices] dipole = np.array([c * charges for c in coords.T]).T @@ -539,11 +539,11 @@ def sum_dipole_vector( @map_coordinates def normal_vectors( frame: CoordinateFrame, - atom_indices_a: npt.ArrayLike, - atom_indices_b: npt.ArrayLike, - atom_indices_c: npt.ArrayLike, + atom_indices_a: ArrayLike, + atom_indices_b: ArrayLike, + atom_indices_c: ArrayLike, normed: bool = True, -) -> np.ndarray: +) -> NDArray: coords_a = frame[atom_indices_a] coords_b = frame[atom_indices_b] coords_c = frame[atom_indices_c] @@ -571,8 +571,8 @@ def displacements_without_drift( @map_coordinates def cylindrical_coordinates( - frame: CoordinateFrame, origin: npt.ArrayLike = None -) -> np.ndarray: + frame: CoordinateFrame, origin: ArrayLike = None +) -> NDArray: if origin is None: origin = np.diag(frame.box) / 2 x = frame[:, 0] - origin[0] @@ -584,10 +584,10 @@ def cylindrical_coordinates( def layer_of_atoms( - atoms: CoordinateFrame, - thickness: float, - plane_normal: npt.ArrayLike, - plane_offset: Optional[npt.ArrayLike] = np.array([0, 0, 0]), + atoms: CoordinateFrame, + thickness: float, + plane_normal: ArrayLike, + plane_offset: Optional[ArrayLike] = np.array([0, 0, 0]), ) -> np.array: if plane_offset is None: np.array([0, 0, 0]) @@ -597,13 +597,13 @@ def layer_of_atoms( def next_neighbors( - atoms: CoordinateFrame, - query_atoms: Optional[CoordinateFrame] = None, - number_of_neighbors: int = 1, - distance_upper_bound: float = np.inf, - distinct: bool = False, - **kwargs -) -> (np.ndarray, np.ndarray): + atoms: CoordinateFrame, + query_atoms: Optional[CoordinateFrame] = None, + number_of_neighbors: int = 1, + distance_upper_bound: float = np.inf, + distinct: bool = False, + **kwargs +) -> Tuple[List, List]: """ Find the N next neighbors of a set of atoms. @@ -634,7 +634,15 @@ def next_neighbors( query_atoms, number_of_neighbors + dnn, distance_upper_bound=distance_upper_bound, - ) + ) + distances = distances[:, dnn:] + indices = indices[:, dnn:] + distances_new = [] + indices_new = [] + for dist, ind in zip(distances, indices): + distances_new.append(dist[dist <= distance_upper_bound]) + indices_new.append(ind[dist <= distance_upper_bound]) + return distances_new, indices_new else: atoms_pbc, atoms_pbc_index = pbc_points( query_atoms, box, thickness=distance_upper_bound + 0.1, index=True, **kwargs @@ -644,7 +652,52 @@ def next_neighbors( query_atoms, number_of_neighbors + dnn, distance_upper_bound=distance_upper_bound, - ) - indices = atoms_pbc_index[indices] + ) + distances = distances[:, dnn:] + indices = indices[:, dnn:] + distances_new = [] + indices_new = [] + for dist, ind in zip(distances, indices): + distances_new.append(dist[dist <= distance_upper_bound]) + indices_new.append(atoms_pbc_index[ind[dist <= distance_upper_bound]]) + return distances_new, indices_new - return distances[:, dnn:], indices[:, dnn:] + +def number_of_neighbors( + atoms: CoordinateFrame, + query_atoms: Optional[CoordinateFrame] = None, + r_max: float = 1, + distinct: bool = False, + **kwargs +) -> Tuple[List, List]: + """ + Find the N next neighbors of a set of atoms. + + Args: + atoms: + The reference atoms and also the atoms which are queried if `query_atoms` + is net provided + query_atoms (opt.): If this is not None, these atoms will be queried + r_max (float, opt.): + Upper bound of the distance between neighbors + distinct (bool, opt.): + If this is true, the atoms and query atoms are taken as distinct sets of + atoms + """ + dnn = 0 + if query_atoms is None: + query_atoms = atoms + dnn = 1 + elif not distinct: + dnn = 1 + + box = atoms.box + if np.all(np.diag(np.diag(box)) == box): + atoms = atoms % np.diag(box) + tree = KDTree(atoms, boxsize=np.diag(box)) + else: + atoms_pbc = pbc_points(query_atoms, box, thickness=r_max + 0.1, **kwargs) + tree = KDTree(atoms_pbc) + + num_of_neighbors = tree.query_ball_point(query_atoms, r_max, return_length=True) + return num_of_neighbors - dnn