From c40ea052b85035f3f689dc06d02d0f4d3028f2e6 Mon Sep 17 00:00:00 2001 From: Sebastian Kloth Date: Fri, 8 Dec 2023 17:20:06 +0100 Subject: [PATCH] Added new function for calculating the occupation matrix --- src/mdevaluate/free_energy_landscape.py | 53 ++++++++++++++++++++++++- test/test_free_energy_landscape.py | 44 ++++++++++---------- 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/src/mdevaluate/free_energy_landscape.py b/src/mdevaluate/free_energy_landscape.py index fc1753a..429f951 100644 --- a/src/mdevaluate/free_energy_landscape.py +++ b/src/mdevaluate/free_energy_landscape.py @@ -9,10 +9,61 @@ import cmath import pandas as pd import multiprocessing as mp - VALID_GEOMETRY = {"cylindrical", "slab"} +def occupation_matrix(trajectory, edge_length=0.05, segments=1000, skip=0.1, nodes=8): + frame_indices = np.unique( + np.int_(np.linspace(len(trajectory) * skip, len(trajectory) - 1, num=segments)) + ) + + box = trajectory[0].box + x_bins = np.arange(0, box[0][0] + edge_length, edge_length) + y_bins = np.arange(0, box[1][1] + edge_length, edge_length) + z_bins = np.arange(0, box[2][2] + edge_length, edge_length) + bins = [x_bins, y_bins, z_bins] + # Trajectory is split for parallel computing + size = math.ceil(len(frame_indices) / nodes) + indices = [ + np.arange(len(frame_indices))[i : i + size] + for i in range(0, len(frame_indices), size) + ] + pool = mp.Pool(nodes) + results = pool.map( + partial(_calc_histogram, trajectory=trajectory, bins=bins), indices + ) + pool.close() + matbin = np.sum(results, axis=0) + x = (x_bins[:-1] + x_bins[1:]) / 2 + y = (y_bins[:-1] + y_bins[1:]) / 2 + z = (z_bins[:-1] + z_bins[1:]) / 2 + + coords = np.array(np.meshgrid(x, y, z, indexing="ij")) + coords = np.array([x.flatten() for x in coords]) + matbin_new = matbin.flatten() + occupation_df = pd.DataFrame( + {"x": coords[0], "y": coords[1], "z": coords[2], "occupation": matbin_new} + ) + occupation_df = occupation_df.query("occupation != 0") + return occupation_df + + +def _calc_histogram(numberlist, trajectory, bins): + matbin = None + for index in range(0, len(numberlist), 1000): + try: + indices = numberlist[index : index + 1000] + except IndexError: + indices = numberlist[index:] + frames = np.concatenate(np.array([trajectory.pbc[i] for i in indices])) + hist, _ = np.histogramdd(frames, bins=bins) + if matbin is None: + matbin = hist + else: + matbin += hist + return matbin + + def get_fel( traj, path, diff --git a/test/test_free_energy_landscape.py b/test/test_free_energy_landscape.py index 843bcb5..44b131f 100644 --- a/test/test_free_energy_landscape.py +++ b/test/test_free_energy_landscape.py @@ -9,39 +9,39 @@ from mdevaluate import free_energy_landscape as fel @pytest.fixture def trajectory(request): - return mdevaluate.open(os.path.join(os.path.dirname(__file__), 'data/pore')) + return mdevaluate.open(os.path.join(os.path.dirname(__file__), "data/pore")) def test_get_fel(trajectory): test_array = np.array( [ 0.0, - 13.162354034697204, - 5.327100985208421, - 9.558746399158396, - 4.116475238453127, - 6.305715728953043, - 3.231102391108276, - 5.896478799115712, - 8.381981206446293, - 5.1191684352849816, - 5.361112857237105, - 8.053932845998895, - 6.895396051256847, - 7.588888886900885, - 11.223429636542576, - 3.779149304024221, - 40.64319010769286, - 93.1120609754045, - 136.99287780099627, - 171.4403749377496, + 12.87438176, + 4.95868203, + 11.02055197, + 5.44195534, + 6.73933442, + 3.30971789, + 6.10424055, + 8.56153733, + 5.45777331, + 5.64545817, + 8.42100423, + 6.28132121, + 7.4777172, + 11.64839354, + 4.52566354, + 40.84730838, + 93.86241602, + 140.3039937, + 173.55970021, ] ) oxygens_water = trajectory.subset(atom_name="OW", residue_name="SOL") r, energy_differences = fel.get_fel( oxygens_water, - os.path.join(os.path.dirname(__file__), 'data/pore'), + os.path.join(os.path.dirname(__file__), "data/pore"), "cylindrical", 225, edge=0.05, @@ -51,4 +51,4 @@ def test_get_fel(trajectory): overwrite=True, ) - assert (energy_differences == test_array).all() + assert (np.round(energy_differences) == np.round(test_array)).all()