Updated type hints

This commit is contained in:
Sebastian Kloth 2023-12-28 12:00:40 +01:00
parent 3cfdf79777
commit b66e920758

View File

@ -1,7 +1,7 @@
from typing import Callable, Optional, Union from typing import Callable, Optional, Union
import numpy as np import numpy as np
from numpy.typing import ArrayLike from numpy.typing import ArrayLike, NDArray
from scipy import spatial from scipy import spatial
from scipy.spatial import KDTree from scipy.spatial import KDTree
from scipy.sparse.csgraph import connected_components from scipy.sparse.csgraph import connected_components
@ -24,7 +24,7 @@ def time_average(
coordinates_b: Optional[Coordinates] = None, coordinates_b: Optional[Coordinates] = None,
skip: float = 0.1, skip: float = 0.1,
segments: int = 100, segments: int = 100,
) -> np.ndarray: ) -> NDArray:
""" """
Compute the time average of a function. Compute the time average of a function.
@ -58,7 +58,7 @@ def gr(
bins: Optional[ArrayLike] = None, bins: Optional[ArrayLike] = None,
remove_intra: bool = False, remove_intra: bool = False,
**kwargs **kwargs
) -> np.ndarray: ) -> NDArray:
r""" r"""
Compute the radial pair distribution of one or two sets of atoms. Compute the radial pair distribution of one or two sets of atoms.
@ -121,7 +121,7 @@ def gr(
def distance_distribution( def distance_distribution(
atoms: CoordinateFrame, bins: Optional[int, ArrayLike] atoms: CoordinateFrame, bins: Optional[int, ArrayLike]
) -> np.ndarray: ) -> NDArray:
connection_vectors = atoms[:-1, :] - atoms[1:, :] connection_vectors = atoms[:-1, :] - atoms[1:, :]
connection_lengths = (connection_vectors**2).sum(axis=1) ** 0.5 connection_lengths = (connection_vectors**2).sum(axis=1) ** 0.5
return np.histogram(connection_lengths, bins)[0] return np.histogram(connection_lengths, bins)[0]
@ -129,7 +129,7 @@ def distance_distribution(
def tetrahedral_order( def tetrahedral_order(
atoms: CoordinateFrame, reference_atoms: CoordinateFrame = None atoms: CoordinateFrame, reference_atoms: CoordinateFrame = None
) -> np.ndarray: ) -> NDArray:
if reference_atoms is None: if reference_atoms is None:
reference_atoms = atoms reference_atoms = atoms
indices = next_neighbors( indices = next_neighbors(
@ -175,7 +175,7 @@ def tetrahedral_order_distribution(
atoms: CoordinateFrame, atoms: CoordinateFrame,
reference_atoms: Optional[CoordinateFrame] = None, reference_atoms: Optional[CoordinateFrame] = None,
bins: Optional[ArrayLike] = None, bins: Optional[ArrayLike] = None,
) -> np.ndarray: ) -> NDArray:
assert bins is not None, "Bin edges of the distribution have to be specified." assert bins is not None, "Bin edges of the distribution have to be specified."
Q = tetrahedral_order(atoms, reference_atoms=reference_atoms) Q = tetrahedral_order(atoms, reference_atoms=reference_atoms)
return np.histogram(Q, bins=bins)[0] return np.histogram(Q, bins=bins)[0]
@ -186,7 +186,7 @@ def radial_density(
bins: Optional[ArrayLike] = None, bins: Optional[ArrayLike] = None,
symmetry_axis: ArrayLike = (0, 0, 1), symmetry_axis: ArrayLike = (0, 0, 1),
origin: Optional[ArrayLike] = None, origin: Optional[ArrayLike] = None,
) -> np.ndarray: ) -> NDArray:
""" """
Calculate the radial density distribution. Calculate the radial density distribution.
@ -223,7 +223,7 @@ def shell_density(
shell_thickness: float = 0.5, shell_thickness: float = 0.5,
symmetry_axis: ArrayLike = (0, 0, 1), symmetry_axis: ArrayLike = (0, 0, 1),
origin: Optional[ArrayLike] = None, origin: Optional[ArrayLike] = None,
) -> np.ndarray: ) -> NDArray:
""" """
Compute the density distribution on a cylindrical shell. Compute the density distribution on a cylindrical shell.
@ -258,7 +258,7 @@ def next_neighbor_distribution(
number_of_neighbors: int = 4, number_of_neighbors: int = 4,
bins: Optional[ArrayLike] = None, bins: Optional[ArrayLike] = None,
normed: bool = True, normed: bool = True,
) -> np.ndarray: ) -> NDArray:
""" """
Compute the distribution of next neighbors with the same residue name. Compute the distribution of next neighbors with the same residue name.
""" """
@ -281,7 +281,7 @@ def hbonds(
HA_lim: float = 0.35, HA_lim: float = 0.35,
min_cos: float = np.cos(30 * np.pi / 180), min_cos: float = np.cos(30 * np.pi / 180),
full_output: bool = False, full_output: bool = False,
) -> Union[np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray]]: ) -> Union[NDArray, tuple[NDArray, NDArray, NDArray]]:
""" """
Compute h-bond pairs Compute h-bond pairs
@ -303,7 +303,7 @@ def hbonds(
def dist_DltA( def dist_DltA(
D: CoordinateFrame, A: CoordinateFrame, max_dist: float = 0.35 D: CoordinateFrame, A: CoordinateFrame, max_dist: float = 0.35
) -> np.ndarray: ) -> NDArray:
ppoints, pind = pbc_points(D, thickness=max_dist + 0.1, index=True) ppoints, pind = pbc_points(D, thickness=max_dist + 0.1, index=True)
Dtree = spatial.cKDTree(ppoints) Dtree = spatial.cKDTree(ppoints)
Atree = spatial.cKDTree(A) Atree = spatial.cKDTree(A)
@ -315,7 +315,7 @@ def hbonds(
def dist_AltD( def dist_AltD(
D: CoordinateFrame, A: CoordinateFrame, max_dist: float = 0.35 D: CoordinateFrame, A: CoordinateFrame, max_dist: float = 0.35
) -> np.ndarray: ) -> NDArray:
ppoints, pind = pbc_points(A, thickness=max_dist + 0.1, index=True) ppoints, pind = pbc_points(A, thickness=max_dist + 0.1, index=True)
Atree = spatial.cKDTree(ppoints) Atree = spatial.cKDTree(ppoints)
Dtree = spatial.cKDTree(D) Dtree = spatial.cKDTree(D)
@ -357,7 +357,7 @@ def hbonds(
return pairs[is_bond] return pairs[is_bond]
def calc_cluster_sizes(atoms: CoordinateFrame, r_max: float = 0.35) -> np.ndarray: def calc_cluster_sizes(atoms: CoordinateFrame, r_max: float = 0.35) -> NDArray:
frame_PBC, indices_PBC = pbc_points(atoms, thickness=r_max + 0.1, index=True) frame_PBC, indices_PBC = pbc_points(atoms, thickness=r_max + 0.1, index=True)
tree = KDTree(frame_PBC) tree = KDTree(frame_PBC)
matrix = tree.sparse_distance_matrix(tree, r_max, output_type="ndarray") matrix = tree.sparse_distance_matrix(tree, r_max, output_type="ndarray")
@ -372,7 +372,7 @@ def calc_cluster_sizes(atoms: CoordinateFrame, r_max: float = 0.35) -> np.ndarra
return np.array(cluster_sizes).flatten() return np.array(cluster_sizes).flatten()
def gyration_radius(position: CoordinateFrame) -> np.ndarray: def gyration_radius(position: CoordinateFrame) -> NDArray:
r""" r"""
Calculates a list of all radii of gyration of all molecules given in the coordinate Calculates a list of all radii of gyration of all molecules given in the coordinate
frame, weighted with the masses of the individual atoms. frame, weighted with the masses of the individual atoms.