mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
First pass at adding structural analysis to the SepTop protocol
This commit is contained in:
@@ -20,6 +20,8 @@ import pathlib
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import gufe
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import openmm
|
||||
import openmmtools
|
||||
@@ -43,6 +45,7 @@ from openmmtools.states import (
|
||||
ThermodynamicState,
|
||||
create_thermodynamic_state_protocol,
|
||||
)
|
||||
from rdkit import Chem
|
||||
|
||||
import openfe
|
||||
from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
@@ -1503,11 +1506,282 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
|
||||
reporter.close()
|
||||
return analyzer.unit_results_dict
|
||||
|
||||
@classmethod
|
||||
def _run_complex_analysis(
|
||||
cls,
|
||||
ds,
|
||||
pdb_file: pathlib.Path,
|
||||
skip: int,
|
||||
ligand_A_indices: list[int],
|
||||
ligand_B_indices: list[int],
|
||||
rdmol_A: Chem.Mol,
|
||||
rdmol_B: Chem.Mol,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run structural analysis for the complex phase.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ds : netCDF4.Dataset
|
||||
Open NetCDF dataset for the multistate trajectory.
|
||||
pdb_file : pathlib.Path
|
||||
Path to the subsampled PDB file.
|
||||
skip : int
|
||||
Frame stride for analysis.
|
||||
ligand_A_indices : list[int]
|
||||
Atom indices of ligand A in the subsampled system.
|
||||
ligand_B_indices : list[int]
|
||||
Atom indices of ligand B in the subsampled system.
|
||||
rdmol_A : Chem.Mol
|
||||
RDKit molecule for ligand A, used for symmetry-corrected RMSD.
|
||||
rdmol_B : Chem.Mol
|
||||
RDKit molecule for ligand B, used for symmetry-corrected RMSD.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]
|
||||
Per-state lists for ligand RMSD, COM drift, protein 2D RMSD,
|
||||
and a single time_ps array.
|
||||
"""
|
||||
from openfe_analysis.rmsd import (
|
||||
LigandCOMDrift,
|
||||
Protein2DRMSD,
|
||||
SymmetryCorrectedLigandRMSD,
|
||||
)
|
||||
from openfe_analysis.utils.apply_transformations import (
|
||||
apply_alignment_transformations,
|
||||
)
|
||||
from openfe_analysis.utils.universe_utils import \
|
||||
create_universe_single_state
|
||||
|
||||
n_lambda = ds.dimensions["state"].size
|
||||
data: dict[str, Any] = {
|
||||
"ligand_A_RMSD": [],
|
||||
"ligand_B_RMSD": [],
|
||||
"ligand_A_COM_drift": [],
|
||||
"ligand_B_COM_drift": [],
|
||||
"protein_2D_RMSD": [],
|
||||
"time_ps": None,
|
||||
}
|
||||
|
||||
for state_idx in range(n_lambda):
|
||||
u = create_universe_single_state(pdb_file, ds, state=state_idx)
|
||||
prot = u.select_atoms("protein and name CA")
|
||||
lig_A = u.atoms[ligand_A_indices]
|
||||
lig_B = u.atoms[ligand_B_indices]
|
||||
apply_alignment_transformations(u, protein=prot,
|
||||
ligand=lig_A + lig_B)
|
||||
|
||||
if prot:
|
||||
prot_rmsd2d = Protein2DRMSD(prot).run(step=skip)
|
||||
data["protein_2D_RMSD"].append(prot_rmsd2d.results.rmsd2d)
|
||||
|
||||
for label, lig, rdmol in [
|
||||
("ligand_A", lig_A, rdmol_A),
|
||||
("ligand_B", lig_B, rdmol_B),
|
||||
]:
|
||||
lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(
|
||||
step=skip)
|
||||
data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd)
|
||||
|
||||
lig_drift = LigandCOMDrift(lig).run(step=skip)
|
||||
data[f"{label}_COM_drift"].append(lig_drift.results.com_drift)
|
||||
|
||||
if data["time_ps"] is None:
|
||||
data["time_ps"] = (
|
||||
np.arange(len(u.trajectory))[::skip] * u.trajectory.dt
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _run_solvent_analysis(
|
||||
cls,
|
||||
ds,
|
||||
pdb_file: pathlib.Path,
|
||||
skip: int,
|
||||
ligand_A_indices: list[int],
|
||||
ligand_B_indices: list[int],
|
||||
rdmol_A: Chem.Mol,
|
||||
rdmol_B: Chem.Mol,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Run structural analysis for the solvent phase.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ds : netCDF4.Dataset
|
||||
Open NetCDF dataset for the multistate trajectory.
|
||||
pdb_file : pathlib.Path
|
||||
Path to the subsampled PDB file.
|
||||
skip : int
|
||||
Frame stride for analysis.
|
||||
ligand_A_indices : list[int]
|
||||
Atom indices of ligand A in the subsampled system.
|
||||
ligand_B_indices : list[int]
|
||||
Atom indices of ligand B in the subsampled system.
|
||||
rdmol_A : Chem.Mol
|
||||
RDKit molecule for ligand A, used for symmetry-corrected RMSD.
|
||||
rdmol_B : Chem.Mol
|
||||
RDKit molecule for ligand B, used for symmetry-corrected RMSD.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, Any]
|
||||
Per-state lists for ligand RMSD and a single
|
||||
time_ps array.
|
||||
"""
|
||||
from openfe_analysis.rmsd import SymmetryCorrectedLigandRMSD
|
||||
from openfe_analysis.utils.apply_transformations import (
|
||||
apply_alignment_transformations,
|
||||
)
|
||||
from openfe_analysis.utils.universe_utils import \
|
||||
create_universe_single_state
|
||||
|
||||
n_lambda = ds.dimensions["state"].size
|
||||
data: dict[str, Any] = {
|
||||
"ligand_A_RMSD": [],
|
||||
"ligand_B_RMSD": [],
|
||||
"time_ps": None,
|
||||
}
|
||||
|
||||
for state_idx in range(n_lambda):
|
||||
for label, indices, rdmol in [
|
||||
("ligand_A", ligand_A_indices, rdmol_A),
|
||||
("ligand_B", ligand_B_indices, rdmol_B),
|
||||
]:
|
||||
u = create_universe_single_state(pdb_file, ds, state=state_idx)
|
||||
lig = u.atoms[indices]
|
||||
apply_alignment_transformations(u, ligand=lig)
|
||||
|
||||
lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(
|
||||
step=skip)
|
||||
data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd)
|
||||
|
||||
if data["time_ps"] is None:
|
||||
data["time_ps"] = (
|
||||
np.arange(len(u.trajectory))[
|
||||
::skip] * u.trajectory.dt
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def _structural_analysis(
|
||||
cls,
|
||||
pdb_file: pathlib.Path,
|
||||
trj_file: pathlib.Path,
|
||||
output_directory: pathlib.Path,
|
||||
dry: bool,
|
||||
simtype: str,
|
||||
ligand_A_indices: list[int],
|
||||
ligand_B_indices: list[int],
|
||||
rdmol_A: Chem.Mol,
|
||||
rdmol_B: Chem.Mol,
|
||||
) -> dict[str, str | pathlib.Path]:
|
||||
"""
|
||||
Run structural analysis using ``openfe-analysis``.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pdb_file : pathlib.Path
|
||||
Path to the subsampled PDB file.
|
||||
trj_file : pathlib.Path
|
||||
Path to the trajectory NetCDF file.
|
||||
output_directory : pathlib.Path
|
||||
Directory where plots and the NPZ file will be written.
|
||||
dry : bool
|
||||
Whether or not we are running a dry run.
|
||||
simtype : str
|
||||
Either ``"complex"`` or ``"solvent"``. Controls whether protein
|
||||
analyses are run and how alignment is applied.
|
||||
ligand_A_indices : list[int]
|
||||
Atom indices of ligand A in the subsampled system.
|
||||
ligand_B_indices : list[int]
|
||||
Atom indices of ligand B in the subsampled system.
|
||||
rdmol_A : Chem.Mol
|
||||
RDKit molecule for ligand A, used for symmetry-corrected RMSD.
|
||||
rdmol_B : Chem.Mol
|
||||
RDKit molecule for ligand B, used for symmetry-corrected RMSD.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict[str, str | pathlib.Path]
|
||||
Dictionary containing either the path to the NPZ file with the
|
||||
structural data, or the analysis error.
|
||||
"""
|
||||
import netCDF4 as nc
|
||||
from openfe_analysis import plotting
|
||||
|
||||
try:
|
||||
with nc.Dataset(trj_file) as ds:
|
||||
# Frame stride: aim for ~500 frames per state
|
||||
if hasattr(ds, "PositionInterval"):
|
||||
n_frames = len(
|
||||
range(0, ds.dimensions["iteration"].size,
|
||||
ds.PositionInterval)
|
||||
)
|
||||
else:
|
||||
n_frames = ds.dimensions["iteration"].size
|
||||
skip = max(n_frames // 500, 1)
|
||||
|
||||
if simtype == "complex":
|
||||
data = cls._run_complex_analysis(
|
||||
ds, pdb_file, skip,
|
||||
ligand_A_indices, ligand_B_indices,
|
||||
rdmol_A, rdmol_B,
|
||||
)
|
||||
else:
|
||||
data = cls._run_solvent_analysis(
|
||||
ds, pdb_file, skip,
|
||||
ligand_A_indices, ligand_B_indices,
|
||||
rdmol_A, rdmol_B,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return {"structural_analysis_error": str(e)}
|
||||
|
||||
if not dry:
|
||||
time_ps = data.get("time_ps", [])
|
||||
|
||||
if data.get("protein_2D_RMSD"):
|
||||
fig = plotting.plot_2D_rmsd(data["protein_2D_RMSD"])
|
||||
fig.savefig(output_directory / "protein_2D_RMSD.png")
|
||||
plt.close(fig)
|
||||
|
||||
for label in ["ligand_A", "ligand_B"]:
|
||||
if data.get(f"{label}_RMSD"):
|
||||
fig = plotting.plot_ligand_RMSD(time_ps,
|
||||
data[f"{label}_RMSD"])
|
||||
fig.savefig(output_directory / f"{label}_RMSD.png")
|
||||
plt.close(fig)
|
||||
|
||||
if data.get(f"{label}_COM_drift"):
|
||||
fig = plotting.plot_ligand_COM_drift(
|
||||
time_ps, data[f"{label}_COM_drift"]
|
||||
)
|
||||
fig.savefig(output_directory / f"{label}_COM_drift.png")
|
||||
plt.close(fig)
|
||||
|
||||
npz_file = output_directory / "structural_analysis.npz"
|
||||
npz_data = {
|
||||
k: np.asarray(v, dtype=np.float32)
|
||||
for k, v in data.items()
|
||||
if v is not None
|
||||
}
|
||||
np.savez_compressed(npz_file, **npz_data)
|
||||
return {"structural_analysis": npz_file}
|
||||
|
||||
def run(
|
||||
self,
|
||||
*,
|
||||
trajectory: pathlib.Path,
|
||||
checkpoint: pathlib.Path,
|
||||
pdb_file: pathlib.Path,
|
||||
ligand_A_indices: list[int],
|
||||
ligand_B_indices: list[int],
|
||||
rdmol_A: Chem.Mol,
|
||||
rdmol_B: Chem.Mol,
|
||||
dry: bool = False,
|
||||
verbose: bool = True,
|
||||
scratch_basepath: pathlib.Path | None = None,
|
||||
@@ -1521,6 +1795,16 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
|
||||
Path to the MultiStateReporter generated NetCDF file.
|
||||
checkpoint : pathlib.Path
|
||||
Path to the checkpoint file generated by MultiStateReporter.
|
||||
pdb_file : pathlib.Path
|
||||
Path to the subsampled PDB file.
|
||||
ligand_A_indices : list[int]
|
||||
Atom indices of ligand A in the subsampled system.
|
||||
ligand_B_indices : list[int]
|
||||
Atom indices of ligand B in the subsampled system.
|
||||
rdmol_A : Chem.Mol
|
||||
RDKit molecule for ligand A.
|
||||
rdmol_B : Chem.Mol
|
||||
RDKit molecule for ligand B.
|
||||
dry : bool
|
||||
Do a dry run of the calculation, creating all necessary hybrid
|
||||
system components (topology, system, sampler, etc...) but without
|
||||
@@ -1560,7 +1844,22 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
|
||||
dry=dry,
|
||||
)
|
||||
|
||||
return energy_analysis
|
||||
if verbose:
|
||||
self.logger.info("Analyzing structural outputs")
|
||||
|
||||
structural_analysis = self._structural_analysis(
|
||||
pdb_file=pdb_file,
|
||||
trj_file=trajectory,
|
||||
output_directory=self.shared_basepath,
|
||||
dry=dry,
|
||||
simtype=self.simtype,
|
||||
ligand_A_indices=ligand_A_indices,
|
||||
ligand_B_indices=ligand_B_indices,
|
||||
rdmol_A=rdmol_A,
|
||||
rdmol_B=rdmol_B,
|
||||
)
|
||||
|
||||
return energy_analysis | structural_analysis
|
||||
|
||||
def _execute(
|
||||
self,
|
||||
@@ -1579,14 +1878,22 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
|
||||
trajectory = simulation.outputs["trajectory"]
|
||||
checkpoint = simulation.outputs["checkpoint"]
|
||||
|
||||
alchem_comps = self._inputs["alchemical_components"]
|
||||
rdmol_A = alchem_comps["stateA"][0].to_rdkit()
|
||||
rdmol_B = alchem_comps["stateB"][0].to_rdkit()
|
||||
|
||||
outputs = self.run(
|
||||
trajectory=trajectory,
|
||||
checkpoint=checkpoint,
|
||||
pdb_file=setup.outputs["subsampled_pdb_structure"],
|
||||
ligand_A_indices=setup.outputs["ligand_A_indices"],
|
||||
ligand_B_indices=setup.outputs["ligand_B_indices"],
|
||||
rdmol_A=rdmol_A,
|
||||
rdmol_B=rdmol_B,
|
||||
scratch_basepath=ctx.scratch,
|
||||
shared_basepath=ctx.shared,
|
||||
)
|
||||
|
||||
# We re-include things here to make life easier when gathering results
|
||||
if self.simtype == "complex":
|
||||
previous_outputs = {
|
||||
"standard_state_correction_A": setup.outputs["standard_state_correction_A"],
|
||||
|
||||
@@ -855,6 +855,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
"restraint_geometry_B": restraint_geom_B.model_dump(),
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
"ligand_A_indices": atom_indices_AB_A,
|
||||
"ligand_B_indices": atom_indices_AB_B,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
@@ -870,6 +872,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
"positions": equil_positions_AB,
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
"ligand_A_indices": atom_indices_AB_A,
|
||||
"ligand_B_indices": atom_indices_AB_B,
|
||||
}
|
||||
|
||||
|
||||
@@ -1133,6 +1137,8 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
"standard_state_correction": corr.to("kilocalorie_per_mole"),
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
"ligand_A_indices": atom_indices_AB_A,
|
||||
"ligand_B_indices": atom_indices_AB_B,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
@@ -1146,6 +1152,8 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
"positions": positions_AB,
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
"ligand_A_indices": atom_indices_AB_A,
|
||||
"ligand_B_indices": atom_indices_AB_B,
|
||||
}
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user