First pass at adding structural analysis to the SepTop protocol

This commit is contained in:
hannahbaumann
2026-05-28 12:04:36 +02:00
parent 53470827a4
commit ea6e58a6ae
2 changed files with 317 additions and 2 deletions

View File

@@ -20,6 +20,8 @@ import pathlib
from typing import Any, Literal, Optional
import gufe
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import openmm
import openmmtools
@@ -43,6 +45,7 @@ from openmmtools.states import (
ThermodynamicState,
create_thermodynamic_state_protocol,
)
from rdkit import Chem
import openfe
from openfe.protocols.openmm_afe.equil_afe_settings import (
@@ -1503,11 +1506,282 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
reporter.close()
return analyzer.unit_results_dict
@classmethod
def _run_complex_analysis(
    cls,
    ds,
    pdb_file: pathlib.Path,
    skip: int,
    ligand_A_indices: list[int],
    ligand_B_indices: list[int],
    rdmol_A: Chem.Mol,
    rdmol_B: Chem.Mol,
) -> dict[str, Any]:
    """
    Collect per-state structural metrics for the complex phase.

    For every state in ``ds`` a single-state universe is created, alignment
    transformations are applied using the protein C-alpha selection and the
    combined ligand atoms, and then the ligand RMSD and ligand COM drift
    analyses are run for both ligands. The protein 2D RMSD analysis is only
    run when the C-alpha selection is non-empty.

    Parameters
    ----------
    ds : netCDF4.Dataset
        Open NetCDF dataset for the multistate trajectory.
    pdb_file : pathlib.Path
        Path to the subsampled PDB file.
    skip : int
        Frame stride passed as ``step`` to every analysis run.
    ligand_A_indices : list[int]
        Atom indices of ligand A in the subsampled system.
    ligand_B_indices : list[int]
        Atom indices of ligand B in the subsampled system.
    rdmol_A : Chem.Mol
        RDKit molecule for ligand A, handed to the ligand RMSD analysis.
    rdmol_B : Chem.Mol
        RDKit molecule for ligand B, handed to the ligand RMSD analysis.

    Returns
    -------
    dict[str, Any]
        ``ligand_A_RMSD``, ``ligand_B_RMSD``, ``ligand_A_COM_drift``,
        ``ligand_B_COM_drift`` and ``protein_2D_RMSD`` are lists that
        receive one entry per analysed state. ``time_ps`` is filled once,
        from the first state's trajectory, and stays ``None`` when the
        dataset has no states.
    """
    from openfe_analysis.rmsd import (
        LigandCOMDrift,
        Protein2DRMSD,
        SymmetryCorrectedLigandRMSD,
    )
    from openfe_analysis.utils.apply_transformations import (
        apply_alignment_transformations,
    )
    from openfe_analysis.utils.universe_utils import (
        create_universe_single_state,
    )

    state_count = ds.dimensions["state"].size

    results: dict[str, Any] = {
        "ligand_A_RMSD": [],
        "ligand_B_RMSD": [],
        "ligand_A_COM_drift": [],
        "ligand_B_COM_drift": [],
        "protein_2D_RMSD": [],
        "time_ps": None,
    }

    for state in range(state_count):
        universe = create_universe_single_state(pdb_file, ds, state=state)

        calpha = universe.select_atoms("protein and name CA")
        group_A = universe.atoms[ligand_A_indices]
        group_B = universe.atoms[ligand_B_indices]

        # Alignment uses the protein selection plus both ligands together.
        apply_alignment_transformations(
            universe,
            protein=calpha,
            ligand=group_A + group_B,
        )

        # Protein analysis is skipped when no C-alpha atoms were selected.
        if calpha:
            results["protein_2D_RMSD"].append(
                Protein2DRMSD(calpha).run(step=skip).results.rmsd2d
            )

        per_ligand = (
            ("ligand_A", group_A, rdmol_A),
            ("ligand_B", group_B, rdmol_B),
        )
        for tag, group, mol in per_ligand:
            rmsd_run = SymmetryCorrectedLigandRMSD(group, rdmol=mol).run(
                step=skip
            )
            results[f"{tag}_RMSD"].append(rmsd_run.results.rmsd)

            drift_run = LigandCOMDrift(group).run(step=skip)
            results[f"{tag}_COM_drift"].append(drift_run.results.com_drift)

        # The time axis is only computed once, from the first state seen.
        if results["time_ps"] is None:
            frame_numbers = np.arange(len(universe.trajectory))[::skip]
            results["time_ps"] = frame_numbers * universe.trajectory.dt

    return results
@classmethod
def _run_solvent_analysis(
    cls,
    ds,
    pdb_file: pathlib.Path,
    skip: int,
    ligand_A_indices: list[int],
    ligand_B_indices: list[int],
    rdmol_A: Chem.Mol,
    rdmol_B: Chem.Mol,
) -> dict[str, Any]:
    """
    Collect per-state ligand RMSD data for the solvent phase.

    For every state in ``ds`` and for each of the two ligands, a fresh
    single-state universe is created, alignment transformations are
    applied using only that ligand, and the ligand RMSD analysis is run.

    Parameters
    ----------
    ds : netCDF4.Dataset
        Open NetCDF dataset for the multistate trajectory.
    pdb_file : pathlib.Path
        Path to the subsampled PDB file.
    skip : int
        Frame stride passed as ``step`` to every analysis run.
    ligand_A_indices : list[int]
        Atom indices of ligand A in the subsampled system.
    ligand_B_indices : list[int]
        Atom indices of ligand B in the subsampled system.
    rdmol_A : Chem.Mol
        RDKit molecule for ligand A, handed to the ligand RMSD analysis.
    rdmol_B : Chem.Mol
        RDKit molecule for ligand B, handed to the ligand RMSD analysis.

    Returns
    -------
    dict[str, Any]
        ``ligand_A_RMSD`` and ``ligand_B_RMSD`` are lists that receive one
        entry per analysed state. ``time_ps`` is filled once, from the
        first universe that is analysed, and stays ``None`` when the
        dataset has no states.
    """
    from openfe_analysis.rmsd import SymmetryCorrectedLigandRMSD
    from openfe_analysis.utils.apply_transformations import (
        apply_alignment_transformations,
    )
    from openfe_analysis.utils.universe_utils import (
        create_universe_single_state,
    )

    state_count = ds.dimensions["state"].size

    results: dict[str, Any] = {
        "ligand_A_RMSD": [],
        "ligand_B_RMSD": [],
        "time_ps": None,
    }

    per_ligand = (
        ("ligand_A", ligand_A_indices, rdmol_A),
        ("ligand_B", ligand_B_indices, rdmol_B),
    )

    for state in range(state_count):
        for tag, atom_indices, mol in per_ligand:
            # A new universe per ligand: each one gets its own alignment.
            universe = create_universe_single_state(pdb_file, ds, state=state)
            group = universe.atoms[atom_indices]
            apply_alignment_transformations(universe, ligand=group)

            rmsd_run = SymmetryCorrectedLigandRMSD(group, rdmol=mol).run(
                step=skip
            )
            results[f"{tag}_RMSD"].append(rmsd_run.results.rmsd)

            # The time axis is only computed once, from the first universe.
            if results["time_ps"] is None:
                frame_numbers = np.arange(len(universe.trajectory))[::skip]
                results["time_ps"] = frame_numbers * universe.trajectory.dt

    return results
@classmethod
def _structural_analysis(
    cls,
    pdb_file: pathlib.Path,
    trj_file: pathlib.Path,
    output_directory: pathlib.Path,
    dry: bool,
    simtype: str,
    ligand_A_indices: list[int],
    ligand_B_indices: list[int],
    rdmol_A: Chem.Mol,
    rdmol_B: Chem.Mol,
) -> dict[str, str | pathlib.Path]:
    """
    Run the structural analysis with ``openfe-analysis`` and store the output.

    The trajectory is opened, a frame stride is derived from the number of
    stored frames, and the phase-specific analysis is run. Any exception
    raised while the dataset is open is converted into an error entry in
    the returned dictionary. Plots are written unless ``dry`` is true, and
    the collected data is then saved to ``structural_analysis.npz``.

    Parameters
    ----------
    pdb_file : pathlib.Path
        Path to the subsampled PDB file.
    trj_file : pathlib.Path
        Path to the trajectory NetCDF file.
    output_directory : pathlib.Path
        Directory where plots and the NPZ file will be written.
    dry : bool
        Whether or not we are running a dry run; plots are skipped if so.
    simtype : str
        ``"complex"`` selects the complex-phase analysis; any other value
        selects the solvent-phase analysis.
    ligand_A_indices : list[int]
        Atom indices of ligand A in the subsampled system.
    ligand_B_indices : list[int]
        Atom indices of ligand B in the subsampled system.
    rdmol_A : Chem.Mol
        RDKit molecule for ligand A, forwarded to the phase analysis.
    rdmol_B : Chem.Mol
        RDKit molecule for ligand B, forwarded to the phase analysis.

    Returns
    -------
    dict[str, str | pathlib.Path]
        ``{"structural_analysis": <npz path>}`` on success, or
        ``{"structural_analysis_error": <message>}`` if the analysis
        raised.
    """
    import netCDF4 as nc
    from openfe_analysis import plotting

    try:
        with nc.Dataset(trj_file) as ds:
            # Frame stride: aim for ~500 frames per state
            iteration_count = ds.dimensions["iteration"].size
            if hasattr(ds, "PositionInterval"):
                # Only every PositionInterval-th iteration is counted.
                n_frames = len(
                    range(0, iteration_count, ds.PositionInterval)
                )
            else:
                n_frames = iteration_count
            skip = max(n_frames // 500, 1)

            analyse = (
                cls._run_complex_analysis
                if simtype == "complex"
                else cls._run_solvent_analysis
            )
            data = analyse(
                ds,
                pdb_file,
                skip,
                ligand_A_indices,
                ligand_B_indices,
                rdmol_A,
                rdmol_B,
            )
    except Exception as e:
        return {"structural_analysis_error": str(e)}

    if not dry:

        def _write_figure(figure, filename: str) -> None:
            # Save into the output directory, then release the figure.
            figure.savefig(output_directory / filename)
            plt.close(figure)

        time_ps = data.get("time_ps", [])

        protein_rmsd = data.get("protein_2D_RMSD")
        if protein_rmsd:
            _write_figure(
                plotting.plot_2D_rmsd(protein_rmsd),
                "protein_2D_RMSD.png",
            )

        for label in ("ligand_A", "ligand_B"):
            rmsd_series = data.get(f"{label}_RMSD")
            if rmsd_series:
                _write_figure(
                    plotting.plot_ligand_RMSD(time_ps, rmsd_series),
                    f"{label}_RMSD.png",
                )

            drift_series = data.get(f"{label}_COM_drift")
            if drift_series:
                _write_figure(
                    plotting.plot_ligand_COM_drift(time_ps, drift_series),
                    f"{label}_COM_drift.png",
                )

    # Entries that were never filled (still None) are left out of the file.
    arrays = {}
    for key, value in data.items():
        if value is None:
            continue
        arrays[key] = np.asarray(value, dtype=np.float32)

    npz_file = output_directory / "structural_analysis.npz"
    np.savez_compressed(npz_file, **arrays)

    return {"structural_analysis": npz_file}
def run(
self,
*,
trajectory: pathlib.Path,
checkpoint: pathlib.Path,
pdb_file: pathlib.Path,
ligand_A_indices: list[int],
ligand_B_indices: list[int],
rdmol_A: Chem.Mol,
rdmol_B: Chem.Mol,
dry: bool = False,
verbose: bool = True,
scratch_basepath: pathlib.Path | None = None,
@@ -1521,6 +1795,16 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
Path to the MultiStateReporter generated NetCDF file.
checkpoint : pathlib.Path
Path to the checkpoint file generated by MultiStateReporter.
pdb_file : pathlib.Path
Path to the subsampled PDB file.
ligand_A_indices : list[int]
Atom indices of ligand A in the subsampled system.
ligand_B_indices : list[int]
Atom indices of ligand B in the subsampled system.
rdmol_A : Chem.Mol
RDKit molecule for ligand A.
rdmol_B : Chem.Mol
RDKit molecule for ligand B.
dry : bool
Do a dry run of the calculation, creating all necessary hybrid
system components (topology, system, sampler, etc...) but without
@@ -1560,7 +1844,22 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
dry=dry,
)
return energy_analysis
if verbose:
self.logger.info("Analyzing structural outputs")
structural_analysis = self._structural_analysis(
pdb_file=pdb_file,
trj_file=trajectory,
output_directory=self.shared_basepath,
dry=dry,
simtype=self.simtype,
ligand_A_indices=ligand_A_indices,
ligand_B_indices=ligand_B_indices,
rdmol_A=rdmol_A,
rdmol_B=rdmol_B,
)
return energy_analysis | structural_analysis
def _execute(
self,
@@ -1579,14 +1878,22 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
trajectory = simulation.outputs["trajectory"]
checkpoint = simulation.outputs["checkpoint"]
alchem_comps = self._inputs["alchemical_components"]
rdmol_A = alchem_comps["stateA"][0].to_rdkit()
rdmol_B = alchem_comps["stateB"][0].to_rdkit()
outputs = self.run(
trajectory=trajectory,
checkpoint=checkpoint,
pdb_file=setup.outputs["subsampled_pdb_structure"],
ligand_A_indices=setup.outputs["ligand_A_indices"],
ligand_B_indices=setup.outputs["ligand_B_indices"],
rdmol_A=rdmol_A,
rdmol_B=rdmol_B,
scratch_basepath=ctx.scratch,
shared_basepath=ctx.shared,
)
# We re-include things here to make life easier when gathering results
if self.simtype == "complex":
previous_outputs = {
"standard_state_correction_A": setup.outputs["standard_state_correction_A"],

View File

@@ -855,6 +855,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
"restraint_geometry_B": restraint_geom_B.model_dump(),
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
"ligand_A_indices": atom_indices_AB_A,
"ligand_B_indices": atom_indices_AB_B,
}
else:
return {
@@ -870,6 +872,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
"positions": equil_positions_AB,
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
"ligand_A_indices": atom_indices_AB_A,
"ligand_B_indices": atom_indices_AB_B,
}
@@ -1133,6 +1137,8 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
"standard_state_correction": corr.to("kilocalorie_per_mole"),
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
"ligand_A_indices": atom_indices_AB_A,
"ligand_B_indices": atom_indices_AB_B,
}
else:
return {
@@ -1146,6 +1152,8 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
"positions": positions_AB,
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
"ligand_A_indices": atom_indices_AB_A,
"ligand_B_indices": atom_indices_AB_B,
}