Skip to content
1 change: 1 addition & 0 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ channels:
dependencies:
- click
- MDAnalysis
- MDAnalysisTests
- netCDF4
- openff-units
- pip
Expand Down
257 changes: 166 additions & 91 deletions src/openfe_analysis/rmsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import MDAnalysis as mda
import netCDF4 as nc
import numpy as np
import tqdm
from MDAnalysis.analysis import rms
from MDAnalysis.analysis import diffusionmap, rms
from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.transformations import unwrap
from numpy import typing as npt

from .reader import FEReader
from .transformations import Aligner, ClosestImageShift, NoJump
Expand Down Expand Up @@ -100,6 +99,130 @@ def make_Universe(top: pathlib.Path, trj: nc.Dataset, state: int) -> mda.Univers
return u


class Protein2DRMSD(AnalysisBase):
    """
    Flattened 2D RMSD matrix

    For all unique frame pairs ``(i, j)`` with ``i < j``, this analysis
    computes the RMSD between atomic coordinates after optimal alignment.

    After ``run()``, ``results.rmsd2d`` holds a 1D array with one value per
    frame pair, in the order produced by
    ``itertools.combinations(range(n_frames), 2)``.
    """

    def __init__(self, atomgroup, weights=None, **kwargs):
        """
        Parameters
        ----------
        atomgroup: AtomGroup
          Protein atoms (e.g. CA selection)
        weights: np.ndarray, optional
          Per-atom weights to use in the RMSD calculation. If ``None``,
          all atoms are weighted equally.
        **kwargs
          Passed through unchanged to ``AnalysisBase.__init__``.
        """
        super().__init__(atomgroup.universe.trajectory, **kwargs)

        self._weights = weights
        self._ag = atomgroup

    def _prepare(self):
        # Coordinates of every analysed frame are kept in memory until
        # ``_conclude``, so memory use grows with the number of frames.
        self._coords = []
        self.results.rmsd2d = []

    def _single_frame(self):
        self._coords.append(self._ag.positions)

    def _conclude(self):
        # Index the per-frame list directly instead of unpacking a 3D shape:
        # with zero analysed frames ``np.asarray([])`` is 1D and the
        # three-way shape unpacking would raise instead of yielding an
        # empty result.
        coords = self._coords
        pairs = itertools.combinations(range(len(coords)), 2)

        self.results.rmsd2d = np.asarray(
            [
                rms.rmsd(
                    coords[i],
                    coords[j],
                    self._weights,
                    center=True,
                    superposition=True,
                )
                for i, j in pairs
            ]
        )


class RMSDAnalysis(AnalysisBase):
    """
    1D RMSD time series for an AtomGroup.

    After ``run()``, ``results.rmsd`` holds one RMSD value per analysed frame.

    Parameters
    ----------
    atomgroup : MDAnalysis.AtomGroup
      Atoms to compute RMSD for.
    reference: Optional[MDAnalysis.AtomGroup]
      Reference AtomGroup; its positions are read once, when the run is
      prepared. If None, the first frame of the trajectory will be used.
    mass_weighted : bool, optional
      If True, compute mass-weighted RMSD.
    superposition : bool, optional
      Forwarded to ``rms.rmsd`` as its ``superposition`` argument
      (superimpose onto the reference before measuring). Default False.
    **kwargs
      Passed through unchanged to ``AnalysisBase.__init__``.
    """

    # NOTE(review): a PR comment attached to ``__init__`` mentions
    # ``_analysis_algorithm_is_parallelizable``. The attribute is not set
    # here -- presumably an open question about opting in to AnalysisBase's
    # parallel backends. TODO confirm before enabling.

    def __init__(
        self, atomgroup, reference=None, mass_weighted=False, superposition=False, **kwargs
    ):
        super().__init__(atomgroup.universe.trajectory, **kwargs)

        self._ag = atomgroup
        # Remember whether the caller supplied a reference, so that the
        # documented "first frame" default can be honoured in ``_prepare``.
        self._use_first_frame = reference is None
        self._reference = reference if reference is not None else self._ag
        self._mass_weighted = mass_weighted
        self._superposition = superposition

    def _reference_positions(self):
        """Return the coordinates every analysed frame is compared against."""
        if not self._use_first_frame:
            return self._reference.positions

        traj = self._ag.universe.trajectory
        current = traj.ts.frame
        if current == 0:
            # Already on the first frame: no seek needed.
            return self._ag.positions

        # The trajectory was left on another frame. Seek to frame 0 so the
        # reference matches the documented default, then restore the frame
        # the caller had selected.
        try:
            traj[0]
            return self._ag.positions.copy()
        finally:
            traj[current]

    def _prepare(self):
        self.results.rmsd = []

        self._reference_pos = self._reference_positions()

        if self._mass_weighted:
            # Weights are the masses scaled to a mean of one.
            self._weights = self._ag.masses / np.mean(self._ag.masses)
        else:
            self._weights = None

    def _single_frame(self):
        rmsd = rms.rmsd(
            self._ag.positions,
            self._reference_pos,
            self._weights,
            center=False,
            superposition=self._superposition,
        )
        self.results.rmsd.append(rmsd)

    def _conclude(self):
        self.results.rmsd = np.asarray(self.results.rmsd)


class LigandCOMDrift(AnalysisBase):
    """
    Ligand center-of-mass displacement from initial position.

    The initial center of mass is taken in ``_prepare``. After ``run()``,
    ``results.com_drift`` holds one distance per analysed frame.
    """

    def __init__(self, atomgroup, **kwargs):
        trajectory = atomgroup.universe.trajectory
        super().__init__(trajectory, **kwargs)

        self._ag = atomgroup

    def _prepare(self):
        self._initial_com = self._ag.center_of_mass()
        self.results.com_drift = []

    def _single_frame(self):
        # distance between start and current ligand position
        # ignores PBC, but we've already centered the traj
        current_com = self._ag.center_of_mass()
        self.results.com_drift.append(
            mda.lib.distances.calc_bonds(current_com, self._initial_com)
        )

    def _conclude(self):
        self.results.com_drift = np.asarray(self.results.com_drift)


def gather_rms_data(
pdb_topology: pathlib.Path, dataset: pathlib.Path, skip: Optional[int] = None
) -> dict[str, list[float]]:
Expand Down Expand Up @@ -161,8 +284,6 @@ def gather_rms_data(
# max against 1 to avoid skip=0 case
skip = max(n_frames // 500, 1)

pb = tqdm.tqdm(total=int(n_frames / skip) * n_lambda)

u_top = mda.Universe(pdb_topology)

for i in range(n_lambda):
Expand All @@ -173,93 +294,47 @@ def gather_rms_data(
prot = u.select_atoms("protein and name CA")
ligand = u.select_atoms("resname UNK")

# save coordinates for 2D RMSD matrix
# TODO: Some smart guard to avoid allocating a silly amount of memory?
prot2d = np.empty((len(u.trajectory[::skip]), len(prot), 3), dtype=np.float32)

prot_start = prot.positions
ligand_start = ligand.positions
ligand_initial_com = ligand.center_of_mass()
ligand_weights = ligand.masses / np.mean(ligand.masses)

this_protein_rmsd = []
this_ligand_rmsd = []
this_ligand_wander = []

for ts_i, ts in enumerate(u.trajectory[::skip]):
pb.update()

if prot:
prot2d[ts_i, :, :] = prot.positions
this_protein_rmsd.append(
rms.rmsd(
prot.positions,
prot_start,
None, # prot_weights,
center=False,
superposition=False,
)
)
if ligand:
this_ligand_rmsd.append(
rms.rmsd(
ligand.positions,
ligand_start,
ligand_weights,
center=False,
superposition=False,
)
)
this_ligand_wander.append(
# distance between start and current ligand position
# ignores PBC, but we've already centered the traj
mda.lib.distances.calc_bonds(ligand.center_of_mass(), ligand_initial_com)
)

if prot:
# can ignore weights here as it's all Ca
rmsd2d = twoD_RMSD(prot2d, w=None) # prot_weights)
output["protein_RMSD"].append(this_protein_rmsd)
output["protein_2D_RMSD"].append(rmsd2d)
if ligand:
output["ligand_RMSD"].append(this_ligand_rmsd)
output["ligand_wander"].append(this_ligand_wander)
prot_rmsd = RMSDAnalysis(prot).run(step=skip)
output["protein_RMSD"].append(prot_rmsd.results.rmsd)
# # Using the MDAnalysis RMSD class instead
# gs = ["protein and name CA"]
# prot_rmsd = rms.RMSD(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two RMSD classes are approximately equal in timing (on the test data)

# u, select="protein and name CA", groupselections=gs, weights="mass")
# prot_rmsd.run(step=skip)
# # The results contain:
# # - frame number
# # - time
# # - RMSD based on select (after superimposing)
# # - RMSD based on groupselections, one array per selection
# output["protein_RMSD"].append(prot_rmsd.results.rmsd.T[3])

prot_rmsd2d = Protein2DRMSD(prot).run(step=skip)
output["protein_2D_RMSD"].append(prot_rmsd2d.results.rmsd2d)
# # Using the MDAnalysis DistanceMatrix class
# prot_rmsd2d = diffusionmap.DistanceMatrix(u, select="protein and name CA")
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This MDA code is much slower, on the test data 10s vs. 0.4s.

# prot_rmsd2d.run(step=skip)
# dist_mat = prot_rmsd2d.results.dist_matrix
# i, j = np.triu_indices_from(dist_mat, k=1)
# flattened = dist_mat[i, j]
# output["protein_2D_RMSD"].append(flattened)

output["time(ps)"] = list(np.arange(len(u.trajectory))[::skip] * u.trajectory.dt)

return output


def twoD_RMSD(positions, w: Optional[npt.NDArray]) -> list[float]:
"""
Compute a flattened 2D RMSD matrix from a trajectory.

For all unique frame pairs ``(i, j)`` with ``i < j``, this function
computes the RMSD between atomic coordinates after optimal alignment.

Parameters
----------
positions : np.ndarray
Atomic coordinates for all frames in the trajectory.
w : np.ndarray, optional
Per-atom weights to use in the RMSD calculation. If ``None``,
all atoms are weighted equally.

Returns
-------
list of float
Flattened list of RMSD values corresponding to all frame pairs
``(i, j)`` with ``i < j``.
"""
nframes, _, _ = positions.shape

output = []

for i, j in itertools.combinations(range(nframes), 2):
posi, posj = positions[i], positions[j]

rmsd = rms.rmsd(posi, posj, w, center=True, superposition=True)

output.append(rmsd)
if ligand:
lig_rmsd = RMSDAnalysis(ligand, mass_weighted=True).run(step=skip)
output["ligand_RMSD"].append(lig_rmsd.results.rmsd)
# # Using the MDAnalysis RMSD class instead
# groupselections = ["resname UNK"]
# lig_rmsd = rms.RMSD(
# u,
# select="protein and name CA",
# groupselections=groupselections,
# weights="mass",
# )
# lig_rmsd.run(step=skip)
# output["ligand_RMSD"].append(lig_rmsd.results.rmsd.T[3])
lig_com_drift = LigandCOMDrift(ligand).run(step=skip)
output["ligand_wander"].append(lig_com_drift.results.com_drift)

output["time(ps)"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt

return output
80 changes: 80 additions & 0 deletions src/openfe_analysis/tests/test_rmsd_mda_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import MDAnalysis as mda
import pytest
from MDAnalysisTests.datafiles import DCD, PSF
from numpy.testing import assert_allclose, assert_almost_equal

from openfe_analysis.rmsd import RMSDAnalysis


@pytest.fixture
def mda_universe():
    """Universe built from the MDAnalysisTests ``PSF`` and ``DCD`` data files."""
    universe = mda.Universe(PSF, DCD)
    return universe


@pytest.fixture()
def correct_values():
    """Expected RMSD values for the two frames analysed in the CA-only tests."""
    expected = [0, 4.68953]
    return expected


@pytest.fixture()
def correct_values_mass():
    """Expected mass-weighted RMSD values for the all-atom test."""
    expected = [0, 4.74920]
    return expected


def test_rmsd(mda_universe, correct_values):
    """CA RMSD with superposition, sampling every 49th frame, matches the stored values."""
    prot = mda_universe.select_atoms("name CA")
    prot_rmsd = RMSDAnalysis(prot, superposition=True).run(step=49)
    assert_almost_equal(
        prot_rmsd.results.rmsd,
        correct_values,
        4,
        # The original concatenation dropped the space between "match" and "test".
        err_msg="error: rmsd profile should match test values",
    )


def test_rmsd_frames(mda_universe, correct_values):
    """Selecting frames 0 and 49 explicitly via ``frames=`` matches the stored values."""
    prot = mda_universe.select_atoms("name CA")
    prot_rmsd = RMSDAnalysis(prot, superposition=True).run(frames=[0, 49])
    assert_almost_equal(
        prot_rmsd.results.rmsd,
        correct_values,
        4,
        # The original concatenation dropped the space between "match" and "test".
        err_msg="error: rmsd profile should match test values",
    )


def test_rmsd_single_frame(mda_universe):
    """A run restricted to frame 5 (``start=5, stop=6``) yields a single RMSD value."""
    prot = mda_universe.select_atoms("name CA")
    prot_rmsd = RMSDAnalysis(prot, superposition=True).run(start=5, stop=6)
    single_frame = [0.91544906]
    assert_almost_equal(
        prot_rmsd.results.rmsd,
        single_frame,
        4,
        # The original concatenation dropped the space between "match" and "test".
        err_msg="error: rmsd profile should match test values",
    )


def test_mass_weighted(mda_universe, correct_values):
    """Mass-weighted CA RMSD equals the unweighted values."""
    # mass weighting the CA should give the same answer as weighting
    # equally because all CA have the same mass
    prot = mda_universe.select_atoms("name CA")
    prot_rmsd = RMSDAnalysis(prot, superposition=True, mass_weighted=True).run(step=49)

    assert_almost_equal(
        prot_rmsd.results.rmsd,
        correct_values,
        4,
        # The original message was missing the space between "match" and "test".
        err_msg="error: rmsd profile should match test values",
    )


def test_custom_weighted(mda_universe, correct_values_mass):
    """Mass-weighted RMSD over all atoms matches the stored mass-weighted values.

    Despite the name, no custom weight array is passed: the weighting is
    requested through ``mass_weighted=True``.
    """
    prot = mda_universe.select_atoms("all")
    prot_rmsd = RMSDAnalysis(prot, superposition=True, mass_weighted=True).run(step=49)
    assert_almost_equal(
        prot_rmsd.results.rmsd,
        correct_values_mass,
        4,
        # The original message was missing the space between "match" and "test".
        err_msg="error: rmsd profile should match test values",
    )
Loading