From 14ce7fd3eab778f0197bee53267da3e180a5119c Mon Sep 17 00:00:00 2001 From: hannahbaumann Date: Mon, 1 Jun 2026 15:40:09 +0200 Subject: [PATCH] Some updates --- .../protocols/openmm_septop/base_units.py | 53 +++++++++++++------ 1 file changed, 36 insertions(+), 17 deletions(-) diff --git a/src/openfe/protocols/openmm_septop/base_units.py b/src/openfe/protocols/openmm_septop/base_units.py index 40b8e2d5..8564ff27 100644 --- a/src/openfe/protocols/openmm_septop/base_units.py +++ b/src/openfe/protocols/openmm_septop/base_units.py @@ -21,6 +21,7 @@ from typing import Any, Literal, Optional import gufe import matplotlib.pyplot as plt +import MDAnalysis as mda import numpy as np import numpy.typing as npt import openmm @@ -1516,6 +1517,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", ) -> dict[str, Any]: """ Run structural analysis for the complex phase. @@ -1536,6 +1538,9 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): RDKit molecule for ligand A, used for symmetry-corrected RMSD. rdmol_B : Chem.Mol RDKit molecule for ligand B, used for symmetry-corrected RMSD. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations. Default: "protein and name CA". Returns ------- @@ -1562,13 +1567,13 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): "protein_2D_RMSD": [], "time_ps": None, } - + u_top = mda.Universe(pdb_file) for state_idx in range(n_lambda): - u = create_universe_single_state(pdb_file, ds, state=state_idx) - prot = u.select_atoms("protein and name CA") - lig_A = u.atoms[ligand_A_indices] - lig_B = u.atoms[ligand_B_indices] - apply_complex_alignment_transformations(u, protein=prot, ligands=[lig_A, lig_B]) + universe = create_universe_single_state(u_top._topology, ds, state=state_idx) + prot = universe.select_atoms(protein_selection) + lig_A = universe.atoms[ligand_A_indices] + lig_B = universe.atoms[ligand_B_indices] + apply_complex_alignment_transformations(universe, protein=prot, ligands=[lig_A, lig_B]) if prot: prot_rmsd2d = Protein2DRMSD(prot).run(step=skip) @@ -1585,7 +1590,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): data[f"{label}_COM_drift"].append(lig_drift.results.com_drift) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt + data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt return data @@ -1638,21 +1643,21 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): "ligand_B_RMSD": [], "time_ps": None, } - + u_top = mda.Universe(pdb_file) for state_idx in range(n_lambda): for label, indices, rdmol in [ ("ligand_A", ligand_A_indices, rdmol_A), ("ligand_B", ligand_B_indices, rdmol_B), ]: - u = create_universe_single_state(pdb_file, ds, state=state_idx) - lig = u.atoms[indices] - apply_ligand_alignment_transformations(u, ligand=lig) + universe = create_universe_single_state(u_top._topology, ds, state=state_idx) + lig = universe.atoms[indices] + apply_ligand_alignment_transformations(universe, ligand=lig) lig_rmsd = SymmetryCorrectedLigandRMSD(lig, rdmol=rdmol).run(step=skip) data[f"{label}_RMSD"].append(lig_rmsd.results.rmsd) if data["time_ps"] is None: - data["time_ps"] = np.arange(len(u.trajectory))[::skip] * u.trajectory.dt + data["time_ps"] = np.arange(len(universe.trajectory))[::skip] * universe.trajectory.dt return data @@ -1668,6 +1673,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", ) -> dict[str, str | pathlib.Path]: """ Run structural analysis using ``openfe-analysis``. @@ -1675,11 +1681,12 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): Parameters ---------- pdb_file : pathlib.Path - Path to the subsampled PDB file. + Path to the PDB file. trj_file : pathlib.Path Path to the trajectory NetCDF file. output_directory : pathlib.Path - Directory where plots and the NPZ file will be written. + The output directory where plots and the data NPZ file + will be stored. dry : bool Whether or not we are running a dry run. simtype : str @@ -1693,6 +1700,10 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): RDKit molecule for ligand A, used for symmetry-corrected RMSD. rdmol_B : Chem.Mol RDKit molecule for ligand B, used for symmetry-corrected RMSD. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations in the complex phase. + Ignored for the solvent phase. Default "protein and name CA". Returns ------- @@ -1705,7 +1716,6 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): try: with nc.Dataset(trj_file) as ds: - # Frame stride: aim for ~500 frames per state if hasattr(ds, "PositionInterval"): n_frames = len(range(0, ds.dimensions["iteration"].size, ds.PositionInterval)) else: @@ -1721,6 +1731,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): ligand_B_indices, rdmol_A, rdmol_B, + protein_selection=protein_selection, ) else: data = cls._run_solvent_analysis( @@ -1736,6 +1747,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): except Exception as e: return {"structural_analysis_error": str(e)} + # Generate relevant plots if not a dry run if not dry: time_ps = data.get("time_ps", []) @@ -1755,6 +1767,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): fig.savefig(output_directory / f"{label}_COM_drift.png") plt.close(fig) + # Write out an NPZ with all the relevant analysis data npz_file = output_directory / "structural_analysis.npz" npz_data = {k: np.asarray(v, dtype=np.float32) for k, v in data.items() if v is not None} np.savez_compressed(npz_file, **npz_data) @@ -1770,6 +1783,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): ligand_B_indices: list[int], rdmol_A: Chem.Mol, rdmol_B: Chem.Mol, + protein_selection: str = "protein and name CA", dry: bool = False, verbose: bool = True, scratch_basepath: pathlib.Path | None = None, @@ -1793,6 +1807,10 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): RDKit molecule for ligand A. rdmol_B : Chem.Mol RDKit molecule for ligand B. + protein_selection : str + MDAnalysis selection string for the protein atoms used for + alignment and RMSD calculations in the complex phase. + Ignored for the solvent phase. Default "protein and name CA". dry : bool Do a dry run of the calculation, creating all necessary hybrid system components (topology, system, sampler, etc...) but without @@ -1821,7 +1839,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): settings = self._get_settings() # Energies analysis - if verbose: + if self.verbose: self.logger.info("Analyzing energies") energy_analysis = self._analyze_multistate_energies( @@ -1832,7 +1850,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): dry=dry, ) - if verbose: + if self.verbose: self.logger.info("Analyzing structural outputs") structural_analysis = self._structural_analysis( @@ -1845,6 +1863,7 @@ class BaseSepTopAnalysisUnit(gufe.ProtocolUnit, SepTopUnitMixin): ligand_B_indices=ligand_B_indices, rdmol_A=rdmol_A, rdmol_B=rdmol_B, + protein_selection=protein_selection, ) return energy_analysis | structural_analysis