Added missing functions from mdeval_skloth

This commit is contained in:
Sebastian Kloth 2023-11-06 15:53:56 +01:00
parent 9c253a2f58
commit 118ed8e6af
2 changed files with 29 additions and 2 deletions

View File

@ -82,6 +82,17 @@ def radial_selector(frame, coordinates, rmin, rmax):
return mask2indices(selector) return mask2indices(selector)
def selector_radial_cylindrical(atoms, r_min, r_max, origin=None):
box = atoms.box
atoms = atoms % np.diag(box)
if origin is None:
origin = [box[0, 0] / 2, box[1, 1] / 2, box[2, 2] / 2]
r_vec = (atoms - origin)[:, :2]
r = np.linalg.norm(r_vec, axis=1)
index = np.argwhere((r >= r_min) * (r < r_max))
return index.flatten()
def spatial_selector(frame, transform, rmin, rmax): def spatial_selector(frame, transform, rmin, rmax):
""" """
Select a subset of atoms which have a radius between rmin and rmax. Select a subset of atoms which have a radius between rmin and rmax.
@ -411,7 +422,7 @@ def map_coordinates(func):
@map_coordinates @map_coordinates
def center_of_masses(coordinates, atoms, shear: bool = False): def center_of_masses(coordinates, atoms=None, shear: bool = False):
""" """
Example: Example:
rd = XTCReader('t.xtc') rd = XTCReader('t.xtc')
@ -419,6 +430,8 @@ def center_of_masses(coordinates, atoms, shear: bool = False):
com = centers_of_mass(coordinates, (1.0, 2.0, 1.0, 3.0)) com = centers_of_mass(coordinates, (1.0, 2.0, 1.0, 3.0))
""" """
if atoms is None:
atoms = list(range(len(coordinates)))
res_ids = coordinates.residue_ids[atoms] res_ids = coordinates.residue_ids[atoms]
masses = coordinates.masses[atoms] masses = coordinates.masses[atoms]
if shear: if shear:
@ -427,7 +440,7 @@ def center_of_masses(coordinates, atoms, shear: bool = False):
sort_ind = res_ids.argsort(kind="stable") sort_ind = res_ids.argsort(kind="stable")
i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1]) i = np.concatenate([[0], np.where(np.diff(res_ids[sort_ind]) > 0)[0] + 1])
coms = coords[sort_ind[i]][res_ids - min(res_ids)] coms = coords[sort_ind[i]][res_ids - min(res_ids)]
cor = md.pbc.pbc_diff(coords, coms, box) cor = pbc_diff(coords, coms, box)
coords = coms + cor coords = coms + cor
else: else:
coords = coordinates.whole[atoms] coords = coordinates.whole[atoms]

View File

@ -2,6 +2,7 @@
Collection of utility functions. Collection of utility functions.
""" """
import functools import functools
from time import time
from types import FunctionType from types import FunctionType
import numpy as np import numpy as np
@ -486,3 +487,16 @@ def fibonacci_sphere(samples=1000):
points.append((x, y, z)) points.append((x, y, z))
return np.array(points) return np.array(points)
def timing(function):
@functools.wraps(function)
def wrap(*args, **kw):
start_time = time()
result = function(*args, **kw)
end_time = time()
time_needed = end_time - start_time
print(f"Finished in {int(time_needed // 60)} min " f"{int(time_needed % 60)} s")
return result
return wrap