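Add protein_selection option to SepTop structural analysis and build per-state universes from a pre-loaded topology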

This commit is contained in:
hannahbaumann
2026-06-01 15:40:09 +02:00
parent e1c101ee73
commit 14ce7fd3ea

View File

@@ -21,6 +21,7 @@ from typing import Any, Literal, Optional
import gufe
import matplotlib.pyplot as plt
import MDAnalysis as mda
import numpy as np
import numpy.typing as npt
import openmm
@@ -1516,6 +1517,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
ligand_B_indices: list[int],
rdmol_A: Chem.Mol,
rdmol_B: Chem.Mol,
protein_selection: str = "protein and name CA",
) -> dict[str, Any]:
"""
Run structural analysis for the complex phase.
@@ -1536,6 +1538,9 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
RDKit molecule for ligand A, used for symmetry-corrected RMSD.
rdmol_B : Chem.Mol
RDKit molecule for ligand B, used for symmetry-corrected RMSD.
protein_selection : str
MDAnalysis selection string for the protein atoms used for
alignment and RMSD calculations. Default: "protein and name CA".
Returns
-------
@@ -1562,13 +1567,13 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
"protein_2D_RMSD": [],
"time_ps": None,
}
u_top = mda.Universe(pdb_file)
for state_idx in range(n_lambda):
u = create_universe_single_state(pdb_file, ds, state=state_idx)
prot = u.select_atoms("protein and name CA")
lig_A = u.atoms[ligand_A_indices]
lig_B = u.atoms[ligand_B_indices]
apply_complex_alignment_transformations(u, protein=prot, ligands=[lig_A, lig_B])
universe = create_universe_single_state(u_top._topology, ds, state=state_idx)
prot = universe.select_atoms(protein_selection)
lig_A = universe.atoms[ligand_A_indices]
lig_B = universe.atoms[ligand_B_indices]
apply_complex_alignment_transformations(universe, protein=prot, ligands=[lig_A, lig_B])
if prot:
prot_rmsd2d = Protein2DRMSD(prot).run(step=skip)
@@ -1585,7 +1590,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
data[f"{label}_COM_drift"].append(lig_drift.results.com_drift)
if data["time_ps"] is None:
data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt
data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt
return data
@@ -1638,21 +1643,21 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
"ligand_B_RMSD": [],
"time_ps": None,
}
u_top = mda.Universe(pdb_file)
for state_idx in range(n_lambda):
for label, indices, rdmol in [
("ligand_A", ligand_A_indices, rdmol_A),
("ligand_B", ligand_B_indices, rdmol_B),
]:
u = create_universe_single_state(pdb_file, ds, state=state_idx)
lig = u.atoms[indices]
apply_ligand_alignment_transformations(u, ligand=lig)
universe = create_universe_single_state(u_top._topology, ds, state=state_idx)
lig = universe.atoms[indices]
apply_ligand_alignment_transformations(universe, ligand=lig)
lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip)
data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd)
if data["time_ps"] is None:
data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt
data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt
return data
@@ -1668,6 +1673,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
ligand_B_indices: list[int],
rdmol_A: Chem.Mol,
rdmol_B: Chem.Mol,
protein_selection: str = "protein and name CA",
) -> dict[str, str | pathlib.Path]:
"""
Run structural analysis using ``openfe-analysis``.
@@ -1675,11 +1681,12 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
Parameters
----------
pdb_file : pathlib.Path
Path to the subsampled PDB file.
Path to the PDB file.
trj_file : pathlib.Path
Path to the trajectory NetCDF file.
output_directory : pathlib.Path
Directory where plots and the NPZ file will be written.
The output directory where plots and the data NPZ file
will be stored.
dry : bool
Whether or not we are running a dry run.
simtype : str
@@ -1693,6 +1700,10 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
RDKit molecule for ligand A, used for symmetry-corrected RMSD.
rdmol_B : Chem.Mol
RDKit molecule for ligand B, used for symmetry-corrected RMSD.
protein_selection : str
MDAnalysis selection string for the protein atoms used for
alignment and RMSD calculations in the complex phase.
Ignored for the solvent phase. Default: "protein and name CA".
Returns
-------
@@ -1705,7 +1716,6 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
try:
with nc.Dataset(trj_file) as ds:
# Frame stride: aim for ~500 frames per state
if hasattr(ds, "PositionInterval"):
n_frames = len(range(0, ds.dimensions["iteration"].size, ds.PositionInterval))
else:
@@ -1721,6 +1731,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
ligand_B_indices,
rdmol_A,
rdmol_B,
protein_selection=protein_selection,
)
else:
data = cls._run_solvent_analysis(
@@ -1736,6 +1747,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
except Exception as e:
return {"structural_analysis_error": str(e)}
# Generate relevant plots if not a dry run
if not dry:
time_ps = data.get("time_ps", [])
@@ -1755,6 +1767,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
fig.savefig(output_directory / f"{label}_COM_drift.png")
plt.close(fig)
# Write out an NPZ with all the relevant analysis data
npz_file = output_directory / "structural_analysis.npz"
npz_data = {k: np.asarray(v, dtype=np.float32) for k, v in data.items() if v is not None}
np.savez_compressed(npz_file, **npz_data)
@@ -1770,6 +1783,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
ligand_B_indices: list[int],
rdmol_A: Chem.Mol,
rdmol_B: Chem.Mol,
protein_selection: str = "protein and name CA",
dry: bool = False,
verbose: bool = True,
scratch_basepath: pathlib.Path | None = None,
@@ -1793,6 +1807,10 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
RDKit molecule for ligand A.
rdmol_B : Chem.Mol
RDKit molecule for ligand B.
protein_selection : str
MDAnalysis selection string for the protein atoms used for
alignment and RMSD calculations in the complex phase.
Ignored for the solvent phase. Default: "protein and name CA".
dry : bool
Do a dry run of the calculation, creating all necessary hybrid
system components (topology, system, sampler, etc...) but without
@@ -1821,7 +1839,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
settings = self._get_settings()
# Energies analysis
if verbose:
if self.verbose:
self.logger.info("Analyzing energies")
energy_analysis = self._analyze_multistate_energies(
@@ -1832,7 +1850,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
dry=dry,
)
if verbose:
if self.verbose:
self.logger.info("Analyzing structural outputs")
structural_analysis = self._structural_analysis(
@@ -1845,6 +1863,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin):
ligand_B_indices=ligand_B_indices,
rdmol_A=rdmol_A,
rdmol_B=rdmol_B,
protein_selection=protein_selection,
)
return energy_analysis | structural_analysis