mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
Add Analysis unit for Septop (#1937)
* Add an Analysis unit to SepTopProtocol. * Update PDB writing via MDTraj in AFE and SepTop Protocols to account for box dimensions with mdtraj_from_openmm utility. * Split up some SepTop tests into different files. * Updated the CLI gathering to account for post openfe v1.11 Analysis units --------- Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com>
This commit is contained in:
@@ -18,8 +18,10 @@ Protocol API specification
|
||||
SepTopProtocol
|
||||
SepTopComplexSetupUnit
|
||||
SepTopComplexRunUnit
|
||||
SepTopComplexAnalysisUnit
|
||||
SepTopSolventSetupUnit
|
||||
SepTopSolventRunUnit
|
||||
SepTopSolventAnalysisUnit
|
||||
SepTopProtocolResult
|
||||
|
||||
Protocol Settings
|
||||
|
||||
27
news/analysis-septop.rst
Normal file
27
news/analysis-septop.rst
Normal file
@@ -0,0 +1,27 @@
|
||||
**Added:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Changed:**
|
||||
|
||||
* The SepTopProtocol now has a dedicated Analysis unit.
|
||||
At the top level API, this does not change behaviour, but
|
||||
if you are directly interfacing with the ProtocolUnits, you
|
||||
will have to account for this change. The SepTopProtocolResult now
|
||||
solely uses the Analysis units. PR #1937
|
||||
|
||||
**Deprecated:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Removed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Fixed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Security:**
|
||||
|
||||
* <news item>
|
||||
@@ -23,7 +23,6 @@ import pathlib
|
||||
from typing import Any
|
||||
|
||||
import gufe
|
||||
import mdtraj as mdt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import openmm
|
||||
@@ -75,6 +74,9 @@ from openfe.protocols.openmm_utils import (
|
||||
system_creation,
|
||||
system_validation,
|
||||
)
|
||||
from openfe.protocols.openmm_utils.mdtraj_utils import (
|
||||
mdtraj_from_openmm,
|
||||
)
|
||||
from openfe.protocols.openmm_utils.omm_settings import (
|
||||
SettingsBaseModel,
|
||||
)
|
||||
@@ -643,16 +645,14 @@ class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin):
|
||||
selection_indices : npt.NDArray
|
||||
The indices of the subselected system.
|
||||
"""
|
||||
mdt_top = mdt.Topology.from_openmm(topology)
|
||||
selection_indices = mdt_top.select(output_selection)
|
||||
traj = mdtraj_from_openmm(topology, positions)
|
||||
|
||||
selection_indices = traj.topology.select(output_selection)
|
||||
|
||||
# Write out the subselected structure to PDB if not empty
|
||||
if len(selection_indices) > 0:
|
||||
traj = mdt.Trajectory(
|
||||
positions[selection_indices, :],
|
||||
mdt_top.subset(selection_indices),
|
||||
)
|
||||
traj.save_pdb(output_file)
|
||||
sub_traj = traj.atom_slice(selection_indices)
|
||||
sub_traj.save_pdb(output_file)
|
||||
|
||||
return selection_indices
|
||||
|
||||
@@ -1346,7 +1346,7 @@ class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin)
|
||||
self.shared_basepath / settings["output_settings"].checkpoint_storage_filename,
|
||||
]
|
||||
for fn in fns:
|
||||
os.remove(fn)
|
||||
fn.unlink()
|
||||
|
||||
def run(
|
||||
self,
|
||||
|
||||
@@ -487,7 +487,7 @@ class AbsoluteBindingProtocol(gufe.Protocol):
|
||||
|
||||
# Get the name of the alchemical species
|
||||
alchname = alchem_comps["stateA"][0].name
|
||||
unit_classes = {
|
||||
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
|
||||
"solvent": {
|
||||
"setup": ABFESolventSetupUnit,
|
||||
"simulation": ABFESolventSimUnit,
|
||||
|
||||
@@ -451,7 +451,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
|
||||
# Get the name of the alchemical species
|
||||
alchname = alchem_comps["stateA"][0].name
|
||||
|
||||
unit_classes = {
|
||||
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
|
||||
"solvent": {
|
||||
"setup": AHFESolventSetupUnit,
|
||||
"simulation": AHFESolventSimUnit,
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import Any
|
||||
|
||||
import gufe
|
||||
import matplotlib.pyplot as plt
|
||||
import mdtraj
|
||||
import mdtraj as mdt
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import openmm
|
||||
@@ -643,7 +643,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
|
||||
|
||||
def _subsample_topology(
|
||||
self,
|
||||
hybrid_topology: openmm.app.Topology,
|
||||
hybrid_topology: mdt.Topology,
|
||||
hybrid_positions: openmm.unit.Quantity,
|
||||
output_selection: str,
|
||||
output_filename: str,
|
||||
@@ -655,7 +655,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
hybrid_topology : openmm.app.Topology
|
||||
hybrid_topology : mdtraj.Topology
|
||||
The hybrid system topology to subsample.
|
||||
hybrid_positions : openmm.unit.Quantity
|
||||
The hybrid system positions.
|
||||
@@ -674,7 +674,8 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
|
||||
|
||||
TODO
|
||||
----
|
||||
Modify this to also store the full system.
|
||||
* Modify this to also store the full system.
|
||||
* Use the mdtraj_from_openmm utility.
|
||||
"""
|
||||
selection_indices = hybrid_topology.select(output_selection)
|
||||
|
||||
@@ -690,7 +691,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
|
||||
bfactors[np.isin(selection_indices, list(atom_classes["unique_new_atoms"]))] = 0.75
|
||||
|
||||
if len(selection_indices) > 0:
|
||||
traj = mdtraj.Trajectory(
|
||||
traj = mdt.Trajectory(
|
||||
hybrid_positions[selection_indices, :],
|
||||
hybrid_topology.subset(selection_indices),
|
||||
).save_pdb(
|
||||
|
||||
@@ -6,10 +6,12 @@ Run SepTop free energy calculations using OpenMM and OpenMMTools.
|
||||
"""
|
||||
|
||||
from .equil_septop_method import (
|
||||
SepTopComplexAnalysisUnit,
|
||||
SepTopComplexRunUnit,
|
||||
SepTopComplexSetupUnit,
|
||||
SepTopProtocol,
|
||||
SepTopProtocolResult,
|
||||
SepTopSolventAnalysisUnit,
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
@@ -25,4 +27,6 @@ __all__ = [
|
||||
"SepTopSolventSetupUnit",
|
||||
"SepTopSolventRunUnit",
|
||||
"SepTopComplexRunUnit",
|
||||
"SepTopSolventAnalysisUnit",
|
||||
"SeptopComplexAnalysisUnit",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -74,8 +74,10 @@ from ..restraint_utils.settings import (
|
||||
)
|
||||
from .septop_protocol_results import SepTopProtocolResult
|
||||
from .septop_units import (
|
||||
SepTopComplexAnalysisUnit,
|
||||
SepTopComplexRunUnit,
|
||||
SepTopComplexSetupUnit,
|
||||
SepTopSolventAnalysisUnit,
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
@@ -148,7 +150,7 @@ class SepTopProtocol(gufe.Protocol):
|
||||
:class:`openfe.protocols.openmm_septop.SepTopProtocolResult`
|
||||
:class:`openfe.protocols.openmm_septop.SepTopComplexSetupUnit`
|
||||
:class:`openfe.protocols.openmm_septop.SepTopComplexRunUnit`
|
||||
:class:`openfe.protocols.openmm_septop.SepTopSolventSetupUnit
|
||||
:class:`openfe.protocols.openmm_septop.SepTopSolventSetupUnit`
|
||||
:class:`openfe.protocols.openmm_septop.SepTopSolventRunUnit`
|
||||
"""
|
||||
|
||||
@@ -460,7 +462,7 @@ class SepTopProtocol(gufe.Protocol):
|
||||
self,
|
||||
stateA: ChemicalSystem,
|
||||
stateB: ChemicalSystem,
|
||||
mapping: gufe.ComponentMapping | list[gufe.ComponentMappping] | None,
|
||||
mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None,
|
||||
extends: gufe.ProtocolDAGResult | None = None,
|
||||
) -> None:
|
||||
# Check we're not trying to extend
|
||||
@@ -525,89 +527,96 @@ class SepTopProtocol(gufe.Protocol):
|
||||
)
|
||||
|
||||
# Create list units for complex and solvent transforms
|
||||
def create_setup_units(unit_cls, leg):
|
||||
return [
|
||||
unit_cls(
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(
|
||||
f"SepTop RBFE Setup, transformation {alchname_A} to "
|
||||
f"{alchname_B}, {leg} leg: repeat {i} generation 0"
|
||||
),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
|
||||
def create_run_units(unit_cls, leg, setup):
|
||||
return [
|
||||
unit_cls(
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
setup=setup[i],
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(
|
||||
f"SepTop RBFE Run, transformation {alchname_A} to "
|
||||
f"{alchname_B}, {leg} leg: repeat {i} generation 0"
|
||||
),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
|
||||
alchname_A = alchem_comps["stateA"][0].name
|
||||
alchname_B = alchem_comps["stateB"][0].name
|
||||
|
||||
solvent_setup = create_setup_units(SepTopSolventSetupUnit, "solvent")
|
||||
solvent_run = create_run_units(SepTopSolventRunUnit, "solvent", setup=solvent_setup)
|
||||
complex_setup = create_setup_units(SepTopComplexSetupUnit, "complex")
|
||||
complex_run = create_run_units(SepTopComplexRunUnit, "complex", setup=complex_setup)
|
||||
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
|
||||
"solvent": {
|
||||
"setup": SepTopSolventSetupUnit,
|
||||
"simulation": SepTopSolventRunUnit,
|
||||
"analysis": SepTopSolventAnalysisUnit,
|
||||
},
|
||||
"complex": {
|
||||
"setup": SepTopComplexSetupUnit,
|
||||
"simulation": SepTopComplexRunUnit,
|
||||
"analysis": SepTopComplexAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
return solvent_setup + solvent_run + complex_setup + complex_run
|
||||
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "complex": []}
|
||||
|
||||
for i in range(self.settings.protocol_repeats):
|
||||
repeat_id = int(uuid.uuid4())
|
||||
for phase in ["solvent", "complex"]:
|
||||
setup = unit_classes[phase]["setup"](
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=(
|
||||
f"SepTop RBFE Setup, transformation {alchname_A} to "
|
||||
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
|
||||
),
|
||||
)
|
||||
|
||||
simulation = unit_classes[phase]["simulation"](
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
setup=setup,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=(
|
||||
f"SepTop RBFE Run, transformation {alchname_A} to "
|
||||
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
|
||||
),
|
||||
)
|
||||
|
||||
analysis = unit_classes[phase]["analysis"](
|
||||
protocol=self,
|
||||
setup=setup,
|
||||
simulation=simulation,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=(
|
||||
f"SepTop RBFE Analysis, transformation {alchname_A} to "
|
||||
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
|
||||
),
|
||||
)
|
||||
|
||||
protocol_units[phase] += [setup, simulation, analysis]
|
||||
|
||||
return protocol_units["solvent"] + protocol_units["complex"]
|
||||
|
||||
def _gather(
|
||||
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
# result units will have a repeat_id and generation
|
||||
# first group according to repeat_id
|
||||
unsorted_solvent_repeats_setup = defaultdict(list)
|
||||
unsorted_solvent_repeats_run = defaultdict(list)
|
||||
unsorted_complex_repeats_setup = defaultdict(list)
|
||||
unsorted_complex_repeats_run = defaultdict(list)
|
||||
unsorted_solvent_repeats = defaultdict(list)
|
||||
unsorted_complex_repeats = defaultdict(list)
|
||||
|
||||
for d in protocol_dag_results:
|
||||
pu: gufe.ProtocolUnitResult
|
||||
for pu in d.protocol_unit_results:
|
||||
if not pu.ok():
|
||||
if ("Analysis" not in pu.name) or (not pu.ok()):
|
||||
continue
|
||||
if pu.outputs["simtype"] == "solvent":
|
||||
if "Run" in pu.name:
|
||||
unsorted_solvent_repeats_run[pu.outputs["repeat_id"]].append(pu)
|
||||
elif "Setup" in pu.name:
|
||||
unsorted_solvent_repeats_setup[pu.outputs["repeat_id"]].append(pu)
|
||||
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
|
||||
else:
|
||||
if "Run" in pu.name:
|
||||
unsorted_complex_repeats_run[pu.outputs["repeat_id"]].append(pu)
|
||||
elif "Setup" in pu.name:
|
||||
unsorted_complex_repeats_setup[pu.outputs["repeat_id"]].append(pu)
|
||||
unsorted_complex_repeats[pu.outputs["repeat_id"]].append(pu)
|
||||
|
||||
repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = {
|
||||
"solvent_setup": {},
|
||||
"solvent": {},
|
||||
"complex_setup": {},
|
||||
"complex": {},
|
||||
}
|
||||
for k, v in unsorted_solvent_repeats_setup.items():
|
||||
repeats["solvent_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
|
||||
for k, v in unsorted_solvent_repeats_run.items():
|
||||
|
||||
for k, v in unsorted_solvent_repeats.items():
|
||||
repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
|
||||
|
||||
for k, v in unsorted_complex_repeats_setup.items():
|
||||
repeats["complex_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
|
||||
for k, v in unsorted_complex_repeats_run.items():
|
||||
for k, v in unsorted_complex_repeats.items():
|
||||
repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
|
||||
return repeats
|
||||
|
||||
@@ -65,8 +65,6 @@ class SepTopProtocolResult(gufe.ProtocolResult):
|
||||
complex_dGs.append(
|
||||
(pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
|
||||
)
|
||||
|
||||
for pus in self.data["complex_setup"].values():
|
||||
complex_correction_dGs_A.append(
|
||||
(
|
||||
pus[0].outputs["standard_state_correction_A"],
|
||||
@@ -84,8 +82,6 @@ class SepTopProtocolResult(gufe.ProtocolResult):
|
||||
solv_dGs.append(
|
||||
(pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
|
||||
)
|
||||
|
||||
for pus in self.data["solvent_setup"].values():
|
||||
solv_correction_dGs.append(
|
||||
(
|
||||
pus[0].outputs["standard_state_correction"],
|
||||
@@ -417,8 +413,8 @@ class SepTopProtocolResult(gufe.ProtocolResult):
|
||||
for key in ["complex", "solvent"]:
|
||||
for pus in self.data[key].values():
|
||||
states = get_replica_state(
|
||||
pus[0].outputs["nc"],
|
||||
pus[0].outputs["last_checkpoint"],
|
||||
pus[0].outputs["trajectory"],
|
||||
pus[0].outputs["checkpoint"],
|
||||
)
|
||||
replica_states[key].append(states)
|
||||
|
||||
@@ -487,11 +483,11 @@ class SepTopProtocolResult(gufe.ProtocolResult):
|
||||
"""
|
||||
geometry_A = [
|
||||
BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_A"])
|
||||
for pus in self.data["complex_setup"].values()
|
||||
for pus in self.data["complex"].values()
|
||||
]
|
||||
geometry_B = [
|
||||
BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_B"])
|
||||
for pus in self.data["complex_setup"].values()
|
||||
for pus in self.data["complex"].values()
|
||||
]
|
||||
|
||||
return geometry_A, geometry_B
|
||||
|
||||
@@ -34,6 +34,7 @@ from openff.units.openmm import from_openmm, to_openmm
|
||||
from openmmtools.states import ThermodynamicState
|
||||
from rdkit import Chem
|
||||
|
||||
from openfe.protocols.openmm_utils import omm_compute
|
||||
from openfe.protocols.openmm_utils.serialization import serialize
|
||||
from openfe.protocols.restraint_utils import geometry
|
||||
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
|
||||
@@ -43,59 +44,25 @@ from openfe.protocols.restraint_utils.openmm.omm_restraints import (
|
||||
add_force_in_separate_group,
|
||||
)
|
||||
|
||||
from ..openmm_utils import settings_validation, system_validation
|
||||
from ..openmm_utils import (
|
||||
settings_validation,
|
||||
system_validation,
|
||||
)
|
||||
from ..openmm_utils.mdtraj_utils import mdtraj_from_openmm
|
||||
from ..restraint_utils.settings import (
|
||||
BoreschRestraintSettings,
|
||||
DistanceRestraintSettings,
|
||||
)
|
||||
from .base_units import BaseSepTopRunUnit, BaseSepTopSetupUnit, _pre_equilibrate
|
||||
from .base_units import (
|
||||
BaseSepTopAnalysisUnit,
|
||||
BaseSepTopRunUnit,
|
||||
BaseSepTopSetupUnit,
|
||||
_pre_equilibrate,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_mdtraj_from_openmm(
|
||||
omm_topology: openmm.app.Topology,
|
||||
omm_positions: openmm.unit.Quantity,
|
||||
):
|
||||
"""
|
||||
Get an mdtraj object from an OpenMM topology and positions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
omm_topology: openmm.app.Topology
|
||||
The OpenMM topology
|
||||
omm_positions: openmm.unit.Quantity
|
||||
The OpenMM positions
|
||||
|
||||
Returns
|
||||
-------
|
||||
mdtraj_system: md.Trajectory
|
||||
"""
|
||||
mdtraj_topology = md.Topology.from_openmm(omm_topology)
|
||||
positions_in_mdtraj_format = omm_positions.value_in_unit(omm_units.nanometers)
|
||||
|
||||
box = omm_topology.getPeriodicBoxVectors()
|
||||
x, y, z = [np.array(b._value) for b in box]
|
||||
lx = np.linalg.norm(x)
|
||||
ly = np.linalg.norm(y)
|
||||
lz = np.linalg.norm(z)
|
||||
# angle between y and z
|
||||
alpha = np.arccos(np.dot(y, z) / (ly * lz))
|
||||
# angle between x and z
|
||||
beta = np.arccos(np.dot(x, z) / (lx * lz))
|
||||
# angle between x and y
|
||||
gamma = np.arccos(np.dot(x, y) / (lx * ly))
|
||||
|
||||
mdtraj_system = md.Trajectory(
|
||||
positions_in_mdtraj_format,
|
||||
mdtraj_topology,
|
||||
unitcell_lengths=np.array([lx, ly, lz]),
|
||||
unitcell_angles=np.array([np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)]),
|
||||
)
|
||||
|
||||
return mdtraj_system
|
||||
|
||||
|
||||
class SepTopComplexMixin:
|
||||
"""
|
||||
A mixin to get the components and the settings for the Complex Units.
|
||||
@@ -132,7 +99,7 @@ class SepTopComplexMixin:
|
||||
|
||||
return alchem_comps, solv_comp, prot_comp, small_mols
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a complex transformation.
|
||||
|
||||
@@ -218,9 +185,9 @@ class SepTopSolventMixin:
|
||||
|
||||
return alchem_comps, solv_comp, None, small_mols
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a complex transformation.
|
||||
Extract the relevant settings for a solvent transformation.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -391,8 +358,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
updated_positions_B: openmm.unit.Quantity
|
||||
Updated positions of the complex B
|
||||
"""
|
||||
mdtraj_complex_A = _get_mdtraj_from_openmm(omm_topology_A, positions_A)
|
||||
mdtraj_complex_B = _get_mdtraj_from_openmm(omm_topology_B, positions_B)
|
||||
mdtraj_complex_A = mdtraj_from_openmm(omm_topology_A, positions_A)
|
||||
mdtraj_complex_B = mdtraj_from_openmm(omm_topology_B, positions_B)
|
||||
alignment_indices = SepTopComplexSetupUnit._get_selection_atom_indices(mdtraj_complex_A)
|
||||
imaged_complex_B = mdtraj_complex_B.image_molecules()
|
||||
imaged_complex_B.superpose(
|
||||
@@ -701,13 +668,15 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
# 0. General preparation tasks
|
||||
self._prepare(verbose, scratch_basepath, shared_basepath)
|
||||
|
||||
self.logger.info("Setting up SepTop complex system.")
|
||||
|
||||
# 1. Get components
|
||||
self.logger.info("Creating and setting up the OpenMM systems")
|
||||
alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components()
|
||||
smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps)
|
||||
|
||||
# 3. Get settings
|
||||
settings = self._handle_settings()
|
||||
settings = self._get_settings()
|
||||
|
||||
# 4. Assign partial charges
|
||||
self._assign_partial_charges(settings["charge_settings"], smc_comps_AB)
|
||||
@@ -740,11 +709,6 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
smc_comp_B_unique,
|
||||
settings,
|
||||
)
|
||||
# Virtual sites sanity check - ensure we restart velocities when
|
||||
# there are virtual sites in the system
|
||||
self.check_assign_velocities_with_virtual_site(
|
||||
omm_system_AB, settings["integrator_settings"]
|
||||
)
|
||||
|
||||
# Get the comp_resids of the AB system
|
||||
resids_A = list(itertools.chain(*comp_resids_A.values()))
|
||||
@@ -753,28 +717,38 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
comp_resids_AB = comp_resids_A | {alchem_comps["stateB"][0]: np.array(diff_resids)}
|
||||
|
||||
# 6. Pre-equilbrate System (for restraint selection)
|
||||
self.logger.info("Pre-equilibrating the systems")
|
||||
equil_positions_A, box_A = _pre_equilibrate(
|
||||
omm_system_A,
|
||||
omm_topology_A,
|
||||
positions_A,
|
||||
settings,
|
||||
"A",
|
||||
dry,
|
||||
self.shared_basepath,
|
||||
self.verbose,
|
||||
self.logger,
|
||||
platform = omm_compute.get_openmm_platform(
|
||||
platform_name=settings["engine_settings"].compute_platform,
|
||||
gpu_device_index=settings["engine_settings"].gpu_device_index,
|
||||
restrict_cpu_count=False,
|
||||
)
|
||||
|
||||
self.logger.info("Pre-equilibrating the systems")
|
||||
|
||||
equil_positions_A, box_A = _pre_equilibrate(
|
||||
system=omm_system_A,
|
||||
topology=omm_topology_A,
|
||||
positions=positions_A,
|
||||
settings=settings,
|
||||
endstate="A",
|
||||
dry=dry,
|
||||
shared_basepath=self.shared_basepath,
|
||||
platform=platform,
|
||||
verbose=self.verbose,
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
equil_positions_B, box_B = _pre_equilibrate(
|
||||
omm_system_B,
|
||||
omm_topology_B,
|
||||
positions_B,
|
||||
settings,
|
||||
"B",
|
||||
dry,
|
||||
self.shared_basepath,
|
||||
self.verbose,
|
||||
self.logger,
|
||||
system=omm_system_B,
|
||||
topology=omm_topology_B,
|
||||
positions=positions_B,
|
||||
settings=settings,
|
||||
endstate="B",
|
||||
dry=dry,
|
||||
shared_basepath=self.shared_basepath,
|
||||
platform=platform,
|
||||
verbose=self.verbose,
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
# 7. Get all the right atom indices for alignments
|
||||
@@ -832,20 +806,35 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
)
|
||||
|
||||
equil_positions_AB, box_AB = _pre_equilibrate(
|
||||
system,
|
||||
omm_topology_AB,
|
||||
positions_AB,
|
||||
settings,
|
||||
"AB",
|
||||
dry,
|
||||
self.shared_basepath,
|
||||
self.verbose,
|
||||
self.logger,
|
||||
system=system,
|
||||
topology=omm_topology_AB,
|
||||
positions=positions_AB,
|
||||
settings=settings,
|
||||
endstate="AB",
|
||||
dry=dry,
|
||||
platform=platform,
|
||||
shared_basepath=self.shared_basepath,
|
||||
verbose=self.verbose,
|
||||
logger=self.logger,
|
||||
)
|
||||
|
||||
# Update box vectors
|
||||
omm_topology_AB.setPeriodicBoxVectors(box_AB)
|
||||
|
||||
# Serialize system, state and integrator
|
||||
# Subselect system based on user inputs & write initial subsampled PDB
|
||||
sub_pdb_structure = self.shared_basepath / settings["output_settings"].output_structure
|
||||
selection_indices = self._subsample_topology(
|
||||
topology=omm_topology_AB,
|
||||
positions=positions_AB,
|
||||
output_selection=settings["output_settings"].output_indices,
|
||||
output_file=self.shared_basepath / settings["output_settings"].output_structure,
|
||||
)
|
||||
# The subsampled PDB may not have been written if selection_indices == 0
|
||||
# Issue #1942 - maybe move this to the method?
|
||||
if len(selection_indices) == 0:
|
||||
sub_pdb_structure = None
|
||||
|
||||
# Serialize the system and PDB topology
|
||||
system_outfile = self.shared_basepath / "system.xml.bz2"
|
||||
serialize(system, system_outfile)
|
||||
|
||||
@@ -864,21 +853,23 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
|
||||
"standard_state_correction_B": corr_B.to("kilocalorie_per_mole"),
|
||||
"restraint_geometry_A": restraint_geom_A.model_dump(),
|
||||
"restraint_geometry_B": restraint_geom_B.model_dump(),
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
# Add in various objects we can use to test the system
|
||||
"debug": {
|
||||
"system": system_outfile,
|
||||
"topology": topology_file,
|
||||
"system_A": omm_system_A,
|
||||
"system_B": omm_system_B,
|
||||
"system_AB": omm_system_AB,
|
||||
"restrained_system": system,
|
||||
"alchem_system": alchemical_system,
|
||||
"alchem_factory": alchemical_factory,
|
||||
"positions": equil_positions_AB,
|
||||
}
|
||||
"system": system_outfile,
|
||||
"topology": topology_file,
|
||||
"system_A": omm_system_A,
|
||||
"system_B": omm_system_B,
|
||||
"system_AB": omm_system_AB,
|
||||
"alchem_restrained_system": system,
|
||||
"alchem_system": alchemical_system,
|
||||
"alchem_factory": alchemical_factory,
|
||||
"positions": equil_positions_AB,
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
}
|
||||
|
||||
|
||||
@@ -1049,13 +1040,15 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
# 0. General preparation tasks
|
||||
self._prepare(verbose, scratch_basepath, shared_basepath)
|
||||
|
||||
self.logger.info("Setting up SepTop solvent system.")
|
||||
|
||||
# 1. Get components
|
||||
self.logger.info("Creating and setting up the OpenMM systems")
|
||||
alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components()
|
||||
smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps)
|
||||
|
||||
# 2. Get settings
|
||||
settings = self._handle_settings()
|
||||
settings = self._get_settings()
|
||||
|
||||
# 3. Assign partial charges
|
||||
self._assign_partial_charges(settings["charge_settings"], smc_comps_AB)
|
||||
@@ -1078,12 +1071,6 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
)
|
||||
) # fmt: skip
|
||||
|
||||
# Virtual sites sanity check - ensure we restart velocities when
|
||||
# there are virtual sites in the system
|
||||
self.check_assign_velocities_with_virtual_site(
|
||||
omm_system_AB, settings["integrator_settings"]
|
||||
)
|
||||
|
||||
# 6. Get atom indices for ligand A and ligand B and the solvent in the
|
||||
# system AB
|
||||
comp_atomids_AB = self._get_atom_indices(omm_topology_AB, comp_resids_AB)
|
||||
@@ -1116,16 +1103,27 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
positions_AB,
|
||||
)
|
||||
|
||||
# Write the full system PDB
|
||||
topology_file = self.shared_basepath / "topology.pdb"
|
||||
openmm.app.pdbfile.PDBFile.writeFile(
|
||||
omm_topology_AB, positions_AB, open(topology_file, "w")
|
||||
)
|
||||
|
||||
# ToDo: also apply REST
|
||||
# Subselect system based on user inputs & write initial subsampled PDB
|
||||
sub_pdb_structure = self.shared_basepath / settings["output_settings"].output_structure
|
||||
selection_indices = self._subsample_topology(
|
||||
topology=omm_topology_AB,
|
||||
positions=positions_AB,
|
||||
output_selection=settings["output_settings"].output_indices,
|
||||
output_file=self.shared_basepath / settings["output_settings"].output_structure,
|
||||
)
|
||||
# The subsampled PDB may not have been written if selection_indices == 0
|
||||
# Issue #1942 - maybe move this to the method?
|
||||
if len(selection_indices) == 0:
|
||||
sub_pdb_structure = None
|
||||
|
||||
# Serialize the system
|
||||
system_outfile = self.shared_basepath / "system.xml.bz2"
|
||||
|
||||
# Serialize system, state and integrator
|
||||
serialize(system, system_outfile)
|
||||
|
||||
if not dry:
|
||||
@@ -1133,19 +1131,21 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
|
||||
"system": system_outfile,
|
||||
"topology": topology_file,
|
||||
"standard_state_correction": corr.to("kilocalorie_per_mole"),
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
}
|
||||
else:
|
||||
return {
|
||||
# Add in various objects we can used to test the system
|
||||
"debug": {
|
||||
"system": system_outfile,
|
||||
"topology": topology_file,
|
||||
"system_AB": omm_system_AB,
|
||||
"restrained_system": system,
|
||||
"alchem_system": alchemical_system,
|
||||
"alchem_factory": alchemical_factory,
|
||||
"positions": positions_AB,
|
||||
}
|
||||
"system": system_outfile,
|
||||
"topology": topology_file,
|
||||
"system_AB": omm_system_AB,
|
||||
"alchem_restrained_system": system,
|
||||
"alchem_system": alchemical_system,
|
||||
"alchem_factory": alchemical_factory,
|
||||
"positions": positions_AB,
|
||||
"selection_indices": selection_indices,
|
||||
"subsampled_pdb_structure": sub_pdb_structure,
|
||||
}
|
||||
|
||||
|
||||
@@ -1218,3 +1218,19 @@ class SepTopComplexRunUnit(SepTopComplexMixin, BaseSepTopRunUnit):
|
||||
lambdas["lambda_restraints_B"] = lambda_restraints_B
|
||||
|
||||
return lambdas
|
||||
|
||||
|
||||
class SepTopSolventAnalysisUnit(SepTopSolventMixin, BaseSepTopAnalysisUnit):
|
||||
"""
|
||||
Protocol Unit for the analysis of the solvent phase of a relative SepTop free energy
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
|
||||
class SepTopComplexAnalysisUnit(SepTopComplexMixin, BaseSepTopAnalysisUnit):
|
||||
"""
|
||||
Protocol Unit for the analysis of the complex phase of a relative SepTop free energy
|
||||
"""
|
||||
|
||||
simtype = "complex"
|
||||
|
||||
57
src/openfe/protocols/openmm_utils/mdtraj_utils.py
Normal file
57
src/openfe/protocols/openmm_utils/mdtraj_utils.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
|
||||
import mdtraj as mdt
|
||||
import numpy as np
|
||||
import openmm
|
||||
from openmm import unit as omm_unit
|
||||
|
||||
|
||||
def mdtraj_from_openmm(
|
||||
omm_topology: openmm.app.Topology,
|
||||
omm_positions: openmm.unit.Quantity,
|
||||
):
|
||||
"""
|
||||
Get an mdtraj object from an OpenMM topology and positions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
omm_topology : openmm.app.Topology
|
||||
The OpenMM topology
|
||||
omm_positions : openmm.unit.Quantity
|
||||
The OpenMM positions
|
||||
|
||||
Returns
|
||||
-------
|
||||
mdtraj_trajectory : md.Trajectory
|
||||
"""
|
||||
mdtraj_topology = mdt.Topology.from_openmm(omm_topology)
|
||||
positions_in_mdtraj_format = omm_positions.value_in_unit(omm_unit.nanometers)
|
||||
|
||||
box = omm_topology.getPeriodicBoxVectors()
|
||||
if box is not None:
|
||||
x, y, z = [np.array(b._value) for b in box]
|
||||
lx = np.linalg.norm(x)
|
||||
ly = np.linalg.norm(y)
|
||||
lz = np.linalg.norm(z)
|
||||
# angle between y and z
|
||||
alpha = np.arccos(np.dot(y, z) / (ly * lz))
|
||||
# angle between x and z
|
||||
beta = np.arccos(np.dot(x, z) / (lx * lz))
|
||||
# angle between x and y
|
||||
gamma = np.arccos(np.dot(x, y) / (lx * ly))
|
||||
|
||||
unitcell_lengths = np.array([lx, ly, lz])
|
||||
unitcell_angles = np.array([np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)])
|
||||
else:
|
||||
unitcell_lengths = None
|
||||
unitcell_angles = None
|
||||
|
||||
mdtraj_trajectory = mdt.Trajectory(
|
||||
positions_in_mdtraj_format,
|
||||
mdtraj_topology,
|
||||
unitcell_lengths=unitcell_lengths,
|
||||
unitcell_angles=unitcell_angles,
|
||||
)
|
||||
|
||||
return mdtraj_trajectory
|
||||
@@ -1,3 +1,6 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
|
||||
Binary file not shown.
30
src/openfe/tests/protocols/openmm_septop/conftest.py
Normal file
30
src/openfe/tests/protocols/openmm_septop/conftest.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import gufe
|
||||
import pytest
|
||||
|
||||
from openfe.protocols.openmm_septop import SepTopProtocol
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def protocol_dry_settings():
|
||||
# a set of settings for dry run tests
|
||||
s = SepTopProtocol.default_settings()
|
||||
s.engine_settings.compute_platform = None
|
||||
s.protocol_repeats = 1
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def benzene_toluene_dag(
|
||||
benzene_complex_system,
|
||||
toluene_complex_system,
|
||||
protocol_dry_settings,
|
||||
):
|
||||
protocol = SepTopProtocol(settings=protocol_dry_settings)
|
||||
|
||||
return protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
@@ -33,10 +33,12 @@ from openmmtools.multistate.multistatesampler import MultiStateSampler
|
||||
import openfe.protocols.openmm_septop
|
||||
from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols.openmm_septop import (
|
||||
SepTopComplexAnalysisUnit,
|
||||
SepTopComplexRunUnit,
|
||||
SepTopComplexSetupUnit,
|
||||
SepTopProtocol,
|
||||
SepTopProtocolResult,
|
||||
SepTopSolventAnalysisUnit,
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
@@ -48,6 +50,8 @@ from openfe.tests.protocols.openmm_ahfe.test_ahfe_protocol import (
|
||||
_verify_alchemical_sterics_force_parameters,
|
||||
)
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
E_CHARGE = 1.602176634e-19 * openmm.unit.coulomb
|
||||
EPSILON0 = (
|
||||
1e-6
|
||||
@@ -59,15 +63,6 @@ EPSILON0 = (
|
||||
ONE_4PI_EPS0 = 1 / (4 * np.pi * EPSILON0) * EPSILON0.unit * 10.0 # nm -> angstrom
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def protocol_dry_settings():
|
||||
# a set of settings for dry run tests
|
||||
s = SepTopProtocol.default_settings()
|
||||
s.engine_settings.compute_platform = None
|
||||
s.protocol_repeats = 1
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def default_settings():
|
||||
s = SepTopProtocol.default_settings()
|
||||
@@ -93,6 +88,47 @@ def test_serialize_protocol(default_settings):
|
||||
assert protocol == ret
|
||||
|
||||
|
||||
def test_repeat_units(benzene_complex_system, toluene_complex_system, default_settings):
|
||||
default_settings.protocol_repeats = 3
|
||||
protocol = SepTopProtocol(
|
||||
settings=default_settings,
|
||||
)
|
||||
|
||||
dag = protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
# 6 protocol unit, 3 per repeat
|
||||
pus = list(dag.protocol_units)
|
||||
assert len(pus) == 18
|
||||
|
||||
# Check info for each repeat
|
||||
for phase in ["solvent", "complex"]:
|
||||
setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
|
||||
sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
|
||||
analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
# Should be 3 of each set
|
||||
assert len(setup) == 3
|
||||
assert len(sim) == 3
|
||||
assert len(analysis) == 3
|
||||
|
||||
# check that the dag chain is correct
|
||||
for analysis_pu in analysis:
|
||||
repeat_id = analysis_pu.inputs["repeat_id"]
|
||||
setup_pu = [
|
||||
s for s in setup if (s.inputs["repeat_id"] == repeat_id) and (s.simtype == phase)
|
||||
][0]
|
||||
sim_pu = [
|
||||
s for s in sim if (s.inputs["repeat_id"] == repeat_id) and (s.simtype == phase)
|
||||
][0]
|
||||
assert analysis_pu.inputs["setup"] == setup_pu
|
||||
assert analysis_pu.inputs["simulation"] == sim_pu
|
||||
assert sim_pu.inputs["setup"] == setup_pu
|
||||
|
||||
|
||||
def test_create_independent_repeat_ids(
|
||||
benzene_complex_system, toluene_complex_system, default_settings
|
||||
):
|
||||
@@ -104,25 +140,26 @@ def test_create_independent_repeat_ids(
|
||||
settings=default_settings,
|
||||
)
|
||||
|
||||
dag1 = protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
dag2 = protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
# print([u for u in dag1.protocol_units])
|
||||
repeat_ids = set()
|
||||
for u in dag1.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
for u in dag2.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
dags = []
|
||||
for i in range(2):
|
||||
dags.append(
|
||||
protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
)
|
||||
|
||||
# There are 4 units per repeat per DAG: 4 * 3 * 2 = 24
|
||||
assert len(repeat_ids) == 24
|
||||
repeat_ids = set()
|
||||
|
||||
for dag in dags:
|
||||
# 3 repeats of 6 units
|
||||
assert len(list(dag.protocol_units)) == 18
|
||||
for u in dag.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
|
||||
# one uuid per repeat, so should equal 6
|
||||
assert len(repeat_ids) == 6
|
||||
|
||||
|
||||
# Tests for the alchemical systems. This tests were modified from
|
||||
@@ -325,25 +362,10 @@ class TestNonbondedInteractions:
|
||||
assert_allclose(energy, from_openmm(expected_energy), rtol=1e-05)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def benzene_toluene_dag(
|
||||
benzene_complex_system,
|
||||
toluene_complex_system,
|
||||
protocol_dry_settings,
|
||||
):
|
||||
protocol = SepTopProtocol(settings=protocol_dry_settings)
|
||||
|
||||
return protocol.create(
|
||||
stateA=benzene_complex_system,
|
||||
stateB=toluene_complex_system,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
|
||||
def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
|
||||
prot_units = list(benzene_toluene_dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 4
|
||||
assert len(prot_units) == 6
|
||||
|
||||
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
|
||||
sol_run_unit = [u for u in prot_units if isinstance(u, SepTopSolventRunUnit)]
|
||||
@@ -356,7 +378,7 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
|
||||
|
||||
solv_setup_output = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
pdb = md.load_pdb(tmp_path / "topology.pdb")
|
||||
assert pdb.n_atoms == 1762
|
||||
central_atoms = np.array([[2, 19]], dtype=np.int32)
|
||||
@@ -367,10 +389,11 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
|
||||
solv_sampler = sol_run_unit[0].run(
|
||||
alchem_system,
|
||||
pdb_file,
|
||||
solv_setup_output["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)["debug"]["sampler"] # fmt: skip
|
||||
)["sampler"] # fmt: skip
|
||||
|
||||
assert solv_sampler.is_periodic
|
||||
assert isinstance(solv_sampler, MultiStateSampler)
|
||||
@@ -397,16 +420,17 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
|
||||
|
||||
complex_setup_output = complex_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(complex_setup_output["topology"]))
|
||||
alchem_system = deserialize(complex_setup_output["system"])
|
||||
complex_sampler = complex_run_unit[0].run(
|
||||
alchem_system,
|
||||
pdb_file,
|
||||
complex_setup_output["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)["debug"]["sampler"] # fmt: skip
|
||||
)["sampler"] # fmt: skip
|
||||
|
||||
assert complex_sampler.is_periodic
|
||||
assert isinstance(complex_sampler, MultiStateSampler)
|
||||
@@ -461,16 +485,17 @@ def test_dry_run_methods(
|
||||
|
||||
solv_setup_output = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
|
||||
alchem_system = deserialize(solv_setup_output["system"])
|
||||
solv_sampler = sol_run_unit[0].run(
|
||||
alchem_system,
|
||||
pdb_file,
|
||||
solv_setup_output["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)["debug"]["sampler"] # fmt: skip
|
||||
)["sampler"] # fmt: skip
|
||||
|
||||
assert isinstance(solv_sampler, MultiStateSampler)
|
||||
assert solv_sampler.is_periodic
|
||||
@@ -518,16 +543,17 @@ def test_dry_run_ligand_system_pressure(
|
||||
|
||||
solv_setup_output = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
|
||||
alchem_system = deserialize(solv_setup_output["system"])
|
||||
solv_sampler = sol_run_unit[0].run(
|
||||
alchem_system,
|
||||
pdb_file,
|
||||
solv_setup_output["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)["debug"]["sampler"] # fmt: skip
|
||||
)["sampler"] # fmt: skip
|
||||
|
||||
# at this point, the units will be in openmm units
|
||||
assert solv_sampler._thermodynamic_states[1].pressure == pressure * openmm.unit.bar
|
||||
@@ -561,9 +587,25 @@ def test_virtual_sites_no_reassign(
|
||||
dag_units = list(dag.protocol_units)
|
||||
# Only check the Solvent Unit
|
||||
solv_setup_unit = [u for u in dag_units if isinstance(u, SepTopSolventSetupUnit)]
|
||||
errmsg = "Simulations with virtual sites without velocity"
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
_ = solv_setup_unit[0].run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
|
||||
solv_run_unit = [u for u in dag_units if isinstance(u, SepTopSolventRunUnit)]
|
||||
|
||||
setup_results = solv_setup_unit[0].run(
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)
|
||||
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(setup_results["topology"]))
|
||||
|
||||
with pytest.raises(ValueError, match="are unstable"):
|
||||
_ = solv_run_unit[0].run(
|
||||
setup_results["alchem_system"],
|
||||
pdb_file,
|
||||
setup_results["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
) # fmt: skip
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -597,7 +639,7 @@ def test_dry_run_ligand_system_cutoff(
|
||||
|
||||
serialized_system = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]["system"]
|
||||
)["system"]
|
||||
system = deserialize(serialized_system)
|
||||
nbfs = [
|
||||
f
|
||||
@@ -636,7 +678,7 @@ def test_dry_run_benzene_toluene_tip4p(
|
||||
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 4
|
||||
assert len(prot_units) == 6
|
||||
|
||||
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
|
||||
sol_run_unit = [u for u in prot_units if isinstance(u, SepTopSolventRunUnit)]
|
||||
@@ -646,16 +688,17 @@ def test_dry_run_benzene_toluene_tip4p(
|
||||
|
||||
solv_setup_output = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
|
||||
alchem_system = deserialize(solv_setup_output["system"])
|
||||
solv_run = sol_run_unit[0].run(
|
||||
alchem_system,
|
||||
pdb_file,
|
||||
solv_setup_output["selection_indices"],
|
||||
dry=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)["debug"]["sampler"] # fmt: skip
|
||||
)["sampler"] # fmt: skip
|
||||
|
||||
assert solv_run.is_periodic
|
||||
|
||||
@@ -681,7 +724,7 @@ def test_dry_run_benzene_toluene_noncubic(
|
||||
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 4
|
||||
assert len(prot_units) == 6
|
||||
|
||||
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
|
||||
|
||||
@@ -689,7 +732,7 @@ def test_dry_run_benzene_toluene_noncubic(
|
||||
|
||||
solv_setup_output = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]
|
||||
)
|
||||
serialized_system = solv_setup_output["system"]
|
||||
system = deserialize(serialized_system)
|
||||
vectors = system.getDefaultPeriodicBoxVectors()
|
||||
@@ -779,7 +822,7 @@ def test_dry_run_solv_user_charges_benzene_toluene(
|
||||
# check sol_unit charges
|
||||
serialized_system = solv_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]["system"]
|
||||
)["system"]
|
||||
system = deserialize(serialized_system)
|
||||
nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)]
|
||||
assert len(nonbond) == 1
|
||||
@@ -799,7 +842,7 @@ def test_dry_run_solv_user_charges_benzene_toluene(
|
||||
# check complex_unit charges
|
||||
serialized_system = complex_setup_unit[0].run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)["debug"]["system"]
|
||||
)["system"]
|
||||
system = deserialize(serialized_system)
|
||||
nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)]
|
||||
assert len(nonbond) == 1
|
||||
@@ -840,6 +883,23 @@ def test_high_timestep(
|
||||
prot_units[0].run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
|
||||
|
||||
|
||||
def test_bad_sampler():
|
||||
class FakeSimSettings(gufe.settings.SettingsBaseModel):
|
||||
sampler_method: str = "foo bar"
|
||||
|
||||
errmsg = "Unknown sampler foo bar"
|
||||
with pytest.raises(AttributeError, match=errmsg):
|
||||
SepTopSolventRunUnit._get_sampler(
|
||||
integrator=None,
|
||||
reporter=None,
|
||||
simulation_settings=FakeSimSettings(),
|
||||
thermodynamic_settings=None,
|
||||
compound_states=None,
|
||||
sampler_states=None,
|
||||
platform=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def T4L_xml(
|
||||
benzene_complex_system,
|
||||
@@ -865,7 +925,7 @@ def T4L_xml(
|
||||
|
||||
tmp = tmp_path_factory.mktemp("xml_reg")
|
||||
|
||||
dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp)["debug"]
|
||||
dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp)
|
||||
|
||||
system = dryrun["system"]
|
||||
return deserialize(system)
|
||||
@@ -908,270 +968,6 @@ class TestT4LXmlRegression:
|
||||
assert a[2] == b[2]
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_toluene_dag, tmp_path):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
dag_units = benzene_toluene_dag.protocol_units
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopComplexSetupUnit.run", # fmt: skip
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopComplexRunUnit._execute", # fmt: skip
|
||||
return_value={
|
||||
"repeat_id": 0,
|
||||
"generation": 0,
|
||||
"simtype": "complex",
|
||||
"nc": "file.nc",
|
||||
"last_checkpoint": "chck.nc",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopSolventSetupUnit.run", # fmt: skip
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopSolventRunUnit._execute", # fmt: skip
|
||||
return_value={
|
||||
"repeat_id": 0,
|
||||
"generation": 0,
|
||||
"simtype": "solvent",
|
||||
"nc": "file.nc",
|
||||
"last_checkpoint": "chck.nc",
|
||||
},
|
||||
),
|
||||
):
|
||||
results = []
|
||||
for u in dag_units:
|
||||
ret = u.execute(context=gufe.Context(tmp_path, tmp_path))
|
||||
results.append(ret)
|
||||
solv_repeats = set()
|
||||
complex_repeats = set()
|
||||
for ret in results:
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
if ret.outputs["simtype"] == "complex":
|
||||
complex_repeats.add(ret.outputs["repeat_id"])
|
||||
else:
|
||||
solv_repeats.add(ret.outputs["repeat_id"])
|
||||
# Repeat ids are random ints so just check their lengths
|
||||
assert len(complex_repeats) == len(solv_repeats) == 2
|
||||
|
||||
|
||||
def test_gather(benzene_toluene_dag, tmp_path):
|
||||
# check that .gather behaves as expected
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopComplexSetupUnit.run", # fmt: skip
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopComplexRunUnit._execute", # fmt: skip
|
||||
return_value={
|
||||
"repeat_id": 0,
|
||||
"generation": 0,
|
||||
"simtype": "complex",
|
||||
"nc": "file.nc",
|
||||
"last_checkpoint": "chck.nc",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopSolventSetupUnit.run", # fmt: skip
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_septop.equil_septop_method"
|
||||
".SepTopSolventRunUnit._execute", # fmt: skip
|
||||
return_value={
|
||||
"repeat_id": 0,
|
||||
"generation": 0,
|
||||
"simtype": "solvent",
|
||||
"nc": "file.nc",
|
||||
"last_checkpoint": "chck.nc",
|
||||
},
|
||||
),
|
||||
):
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_toluene_dag,
|
||||
shared_basedir=tmp_path,
|
||||
scratch_basedir=tmp_path,
|
||||
keep_shared=True,
|
||||
)
|
||||
|
||||
protocol = SepTopProtocol(
|
||||
settings=SepTopProtocol.default_settings(),
|
||||
)
|
||||
|
||||
res = protocol.gather([dagres])
|
||||
|
||||
assert isinstance(res, openfe.protocols.openmm_septop.SepTopProtocolResult)
|
||||
|
||||
|
||||
class TestProtocolResult:
|
||||
@pytest.fixture()
|
||||
def protocolresult(self, septop_json):
|
||||
d = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
return pr
|
||||
|
||||
def test_reload_protocol_result(self, septop_json):
|
||||
d = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = SepTopProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
assert pr
|
||||
|
||||
def test_get_estimate(self, protocolresult):
|
||||
est = protocolresult.get_estimate()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(3.82, abs=0.1)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_uncertainty(self, protocolresult):
|
||||
est = protocolresult.get_uncertainty()
|
||||
|
||||
assert est.m == pytest.approx(0.0, abs=0.1)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_individual(self, protocolresult):
|
||||
inds = protocolresult.get_individual_estimates()
|
||||
|
||||
assert isinstance(inds, dict)
|
||||
assert isinstance(inds["solvent"], list)
|
||||
assert isinstance(inds["complex"], list)
|
||||
assert len(inds["solvent"]) == len(inds["complex"]) == 1
|
||||
for e, u in itertools.chain(inds["solvent"], inds["complex"]):
|
||||
assert e.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
assert u.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_get_forwards_etc(self, key, protocolresult):
|
||||
"""
|
||||
Due to the short simulation times, we expect the frwd/reverse
|
||||
analysis to be None.
|
||||
"""
|
||||
wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
|
||||
with pytest.warns(UserWarning, match=wmsg):
|
||||
far = protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
assert isinstance(far, dict)
|
||||
assert isinstance(far[key], list)
|
||||
for entry in far[key]:
|
||||
assert entry is None
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_get_overlap_matrices(self, key, protocolresult):
|
||||
ovp = protocolresult.get_overlap_matrices()
|
||||
|
||||
assert isinstance(ovp, dict)
|
||||
assert isinstance(ovp[key], list)
|
||||
assert len(ovp[key]) == 1
|
||||
|
||||
ovp1 = ovp[key][0]
|
||||
assert isinstance(ovp1["matrix"], np.ndarray)
|
||||
if key == "solvent":
|
||||
lambda_nr = 27
|
||||
else:
|
||||
lambda_nr = 19
|
||||
assert ovp1["matrix"].shape == (lambda_nr, lambda_nr)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_get_replica_transition_statistics(self, key, protocolresult):
|
||||
rpx = protocolresult.get_replica_transition_statistics()
|
||||
if key == "solvent":
|
||||
lambda_nr = 27
|
||||
else:
|
||||
lambda_nr = 19
|
||||
assert isinstance(rpx, dict)
|
||||
assert isinstance(rpx[key], list)
|
||||
assert len(rpx[key]) == 1
|
||||
rpx1 = rpx[key][0]
|
||||
assert "eigenvalues" in rpx1
|
||||
assert "matrix" in rpx1
|
||||
|
||||
assert rpx1["eigenvalues"].shape == (lambda_nr,)
|
||||
assert rpx1["matrix"].shape == (lambda_nr, lambda_nr)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_equilibration_iterations(self, key, protocolresult):
|
||||
eq = protocolresult.equilibration_iterations()
|
||||
|
||||
assert isinstance(eq, dict)
|
||||
assert isinstance(eq[key], list)
|
||||
assert len(eq[key]) == 1
|
||||
assert all(isinstance(v, float) for v in eq[key])
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_production_iterations(self, key, protocolresult):
|
||||
prod = protocolresult.production_iterations()
|
||||
|
||||
assert isinstance(prod, dict)
|
||||
assert isinstance(prod[key], list)
|
||||
assert len(prod[key]) == 1
|
||||
assert all(isinstance(v, float) for v in prod[key])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key, expected_size",
|
||||
[
|
||||
["solvent", 87],
|
||||
["complex", 1868],
|
||||
],
|
||||
)
|
||||
def test_selection_indices(self, key, protocolresult, expected_size):
|
||||
indices = protocolresult.selection_indices()
|
||||
|
||||
assert isinstance(indices, dict)
|
||||
assert isinstance(indices[key], list)
|
||||
for inds in indices[key]:
|
||||
assert isinstance(inds, np.ndarray)
|
||||
assert len(inds) == expected_size
|
||||
|
||||
def test_filenotfound_replica_states(self, protocolresult):
|
||||
errmsg = "File could not be found"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocolresult.get_replica_states()
|
||||
|
||||
def test_restraint_geometry(self, protocolresult):
|
||||
geom = protocolresult.restraint_geometries()
|
||||
assert isinstance(geom, tuple)
|
||||
assert len(geom) == 2
|
||||
assert isinstance(geom[0], list)
|
||||
assert isinstance(geom[0][0], BoreschRestraintGeometry)
|
||||
assert geom[0][0].guest_atoms == [1779, 1778, 1777]
|
||||
assert geom[0][0].host_atoms == [802, 801, 800]
|
||||
assert pytest.approx(geom[0][0].r_aA0) == 0.798936 * offunit.nanometer
|
||||
assert pytest.approx(geom[0][0].theta_A0) == 2.049091 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].theta_B0) == 1.221973 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_A0) == 0.956774 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_B0) == -1.217188 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_C0) == -1.068226 * offunit.radian
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
class TestA2AMembraneDryRun:
|
||||
solvent = SolventComponent(ion_concentration=0 * offunit.molar)
|
||||
@@ -1231,6 +1027,10 @@ class TestA2AMembraneDryRun:
|
||||
def complex_run_units(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, SepTopComplexRunUnit)]
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def complex_analysis_unit(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, SepTopComplexAnalysisUnit)]
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def solvent_setup_units(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, SepTopSolventSetupUnit)]
|
||||
@@ -1239,10 +1039,14 @@ class TestA2AMembraneDryRun:
|
||||
def solvent_run_units(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, SepTopSolventRunUnit)]
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def solvent_analysis_unit(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, SepTopSolventAnalysisUnit)]
|
||||
|
||||
def test_number_of_units(
|
||||
self, dag, complex_setup_units, complex_run_units, solvent_setup_units, solvent_run_units
|
||||
):
|
||||
assert len(list(dag.protocol_units)) == 4
|
||||
assert len(list(dag.protocol_units)) == 6
|
||||
assert len(complex_setup_units) == 1
|
||||
assert len(complex_run_units) == 1
|
||||
assert len(solvent_setup_units) == 1
|
||||
@@ -1400,18 +1204,21 @@ class TestA2AMembraneDryRun:
|
||||
adaptive_settings.complex_integrator_settings.barostat
|
||||
== "MonteCarloMembraneBarostat"
|
||||
)
|
||||
complex_setup_output = complex_setup_units[0].run(dry=True)["debug"]
|
||||
complex_setup_output = complex_setup_units[0].run(dry=True)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(complex_setup_output["topology"]))
|
||||
system = deserialize(complex_setup_output["system"])
|
||||
data = complex_run_units[0].run(system, pdb_file, dry=True)["debug"] # fmt: skip
|
||||
indices = complex_setup_output["selection_indices"]
|
||||
data = complex_run_units[0].run(system, pdb_file, indices, dry=True) # fmt: skip
|
||||
# Check the sampler
|
||||
self._verify_sampler(data["sampler"], complexed=True, settings=adaptive_settings)
|
||||
|
||||
# Check the alchemical system
|
||||
self._assert_expected_alchemical_forces(
|
||||
data["alchem_system"], complexed=True, settings=adaptive_settings
|
||||
complex_setup_output["alchem_restrained_system"],
|
||||
complexed=True,
|
||||
settings=adaptive_settings,
|
||||
)
|
||||
self._test_orthogonal_vectors(data["alchem_system"])
|
||||
self._test_orthogonal_vectors(complex_setup_output["alchem_restrained_system"])
|
||||
|
||||
# Check the non-alchemical system
|
||||
self._assert_expected_nonalchemical_forces(
|
||||
@@ -1420,7 +1227,9 @@ class TestA2AMembraneDryRun:
|
||||
self._test_orthogonal_vectors(complex_setup_output["system_AB"])
|
||||
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
|
||||
assert_allclose(
|
||||
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(
|
||||
complex_setup_output["alchem_restrained_system"].getDefaultPeriodicBoxVectors()
|
||||
),
|
||||
from_openmm(complex_setup_output["system_AB"].getDefaultPeriodicBoxVectors()),
|
||||
)
|
||||
|
||||
@@ -1433,23 +1242,24 @@ class TestA2AMembraneDryRun:
|
||||
|
||||
def test_solvent_dry_run(self, solvent_setup_units, solvent_run_units, settings, tmpdir):
|
||||
with tmpdir.as_cwd():
|
||||
solv_setup_output = solvent_setup_units[0].run(dry=True)["debug"]
|
||||
solv_setup_output = solvent_setup_units[0].run(dry=True)
|
||||
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
|
||||
system = deserialize(solv_setup_output["system"])
|
||||
data = solvent_run_units[0].run(system, pdb_file, dry=True)["debug"] # fmt: skip
|
||||
indices = solv_setup_output["selection_indices"]
|
||||
data = solvent_run_units[0].run(system, pdb_file, indices, dry=True) # fmt: skip
|
||||
|
||||
# Check the sampler
|
||||
self._verify_sampler(data["sampler"], complexed=False, settings=settings)
|
||||
|
||||
# Check the alchemical system
|
||||
self._assert_expected_alchemical_forces(
|
||||
data["alchem_system"], complexed=False, settings=settings
|
||||
solv_setup_output["alchem_restrained_system"], complexed=False, settings=settings
|
||||
)
|
||||
self._test_cubic_vectors(data["alchem_system"])
|
||||
self._test_cubic_vectors(solv_setup_output["alchem_restrained_system"])
|
||||
|
||||
# Check the alchemical indices
|
||||
expected_indices = [i for i in range(self.num_ligand_atoms_A + self.num_ligand_atoms_B)]
|
||||
assert expected_indices == data["selection_indices"].tolist()
|
||||
assert expected_indices == solv_setup_output["selection_indices"].tolist()
|
||||
|
||||
# Check the non-alchemical system
|
||||
self._assert_expected_nonalchemical_forces(
|
||||
|
||||
@@ -0,0 +1,338 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import itertools
|
||||
import json
|
||||
import math
|
||||
import pathlib
|
||||
from unittest import mock
|
||||
|
||||
import gufe
|
||||
import mdtraj as md
|
||||
import numpy as np
|
||||
import openmm
|
||||
import openmm.app
|
||||
import openmm.unit
|
||||
import pytest
|
||||
from numpy.testing import assert_allclose
|
||||
from openff.units import unit as offunit
|
||||
from openff.units.openmm import ensure_quantity, from_openmm, to_openmm
|
||||
from openmm import (
|
||||
CustomBondForce,
|
||||
CustomCompoundBondForce,
|
||||
CustomNonbondedForce,
|
||||
HarmonicAngleForce,
|
||||
HarmonicBondForce,
|
||||
MonteCarloBarostat,
|
||||
MonteCarloMembraneBarostat,
|
||||
NonbondedForce,
|
||||
PeriodicTorsionForce,
|
||||
)
|
||||
from openmmtools.alchemy import AbsoluteAlchemicalFactory, AlchemicalRegion
|
||||
from openmmtools.multistate.multistatesampler import MultiStateSampler
|
||||
|
||||
import openfe.protocols.openmm_septop
|
||||
from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols.openmm_septop import (
|
||||
SepTopComplexRunUnit,
|
||||
SepTopComplexSetupUnit,
|
||||
SepTopProtocol,
|
||||
SepTopProtocolResult,
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_septop.base_units import (
|
||||
BaseSepTopAnalysisUnit,
|
||||
BaseSepTopRunUnit,
|
||||
BaseSepTopSetupUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils.serialization import deserialize
|
||||
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
|
||||
from openfe.tests.protocols.conftest import compute_energy
|
||||
from openfe.tests.protocols.openmm_ahfe.test_ahfe_protocol import (
|
||||
_assert_num_forces,
|
||||
_verify_alchemical_sterics_force_parameters,
|
||||
)
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patcher():
|
||||
base_path = "openfe.protocols.openmm_septop.base_units"
|
||||
protocol_path = "openfe.protocols.openmm_septop.equil_septop_method"
|
||||
with (
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopComplexSetupUnit.run",
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
"standard_state_correction_A": 0 * offunit.kilocalorie_per_mole,
|
||||
"standard_state_correction_B": 0 * offunit.kilocalorie_per_mole,
|
||||
"restraint_geometry_A": None,
|
||||
"restraint_geometry_B": None,
|
||||
"selection_indices": np.array(
|
||||
[
|
||||
0,
|
||||
]
|
||||
),
|
||||
"subsampled_pdb_structure": "subsampled.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopSolventSetupUnit.run",
|
||||
return_value={
|
||||
"system": pathlib.Path("system.xml.bz2"),
|
||||
"topology": "topology.pdb",
|
||||
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
|
||||
"selection_indices": np.array(
|
||||
[
|
||||
0,
|
||||
]
|
||||
),
|
||||
"subsampled_pdb_structure": "subsampled.pdb",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopComplexRunUnit.run",
|
||||
return_value={
|
||||
"trajectory": "foo.nc",
|
||||
"checkpoint": "bar.nc",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopSolventRunUnit.run",
|
||||
return_value={
|
||||
"trajectory": "foo.nc",
|
||||
"checkpoint": "bar.nc",
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopComplexAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
mock.patch(
|
||||
f"{protocol_path}.SepTopSolventAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
mock.patch(
|
||||
f"{base_path}.deserialize",
|
||||
return_value="foo",
|
||||
),
|
||||
mock.patch(
|
||||
f"{base_path}.openmm.app.pdbfile.PDBFile",
|
||||
return_value="foo",
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_toluene_dag, patcher, tmp_path):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
dag_units = benzene_toluene_dag.protocol_units
|
||||
|
||||
for phase in ["solvent", "complex"]:
|
||||
setup_results = {}
|
||||
sim_results = {}
|
||||
analysis_results = {}
|
||||
|
||||
setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"])
|
||||
sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"])
|
||||
a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
for u in setup_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
setup_results[rid] = u.execute(context=gufe.Context(tmp_path, tmp_path))
|
||||
|
||||
for u in sim_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
sim_results[rid] = u.execute(
|
||||
context=gufe.Context(tmp_path, tmp_path),
|
||||
setup=setup_results[rid],
|
||||
)
|
||||
|
||||
for u in a_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
analysis_results[rid] = u.execute(
|
||||
context=gufe.Context(tmp_path, tmp_path),
|
||||
setup=setup_results[rid],
|
||||
simulation=sim_results[rid],
|
||||
)
|
||||
|
||||
for results in [setup_results, sim_results, analysis_results]:
|
||||
for ret in results.values():
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
|
||||
assert len(setup_results) == 1
|
||||
assert len(sim_results) == 1
|
||||
assert len(analysis_results) == 1
|
||||
|
||||
|
||||
def test_gather(benzene_toluene_dag, patcher, tmp_path):
|
||||
# check that .gather behaves as expected
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_toluene_dag,
|
||||
shared_basedir=tmp_path,
|
||||
scratch_basedir=tmp_path,
|
||||
keep_shared=True,
|
||||
)
|
||||
|
||||
protocol = SepTopProtocol(
|
||||
settings=SepTopProtocol.default_settings(),
|
||||
)
|
||||
|
||||
res = protocol.gather([dagres])
|
||||
|
||||
assert isinstance(res, openfe.protocols.openmm_septop.SepTopProtocolResult)
|
||||
|
||||
|
||||
class TestProtocolResult:
|
||||
@pytest.fixture()
|
||||
def protocolresult(self, septop_json):
|
||||
d = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
return pr
|
||||
|
||||
def test_reload_protocol_result(self, septop_json):
|
||||
d = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = SepTopProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
assert pr
|
||||
|
||||
def test_get_estimate(self, protocolresult):
|
||||
est = protocolresult.get_estimate()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(1.6, abs=0.1)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_uncertainty(self, protocolresult):
|
||||
est = protocolresult.get_uncertainty()
|
||||
|
||||
assert est.m == pytest.approx(0.0, abs=0.1)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_individual(self, protocolresult):
|
||||
inds = protocolresult.get_individual_estimates()
|
||||
|
||||
assert isinstance(inds, dict)
|
||||
assert isinstance(inds["solvent"], list)
|
||||
assert isinstance(inds["complex"], list)
|
||||
assert len(inds["solvent"]) == len(inds["complex"]) == 1
|
||||
for e, u in itertools.chain(inds["solvent"], inds["complex"]):
|
||||
assert e.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
assert u.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_forwards_etc(self, protocolresult):
|
||||
"""
|
||||
Due to the short simulation times, we expect the frwd/reverse
|
||||
analysis of the solvent to be None.
|
||||
"""
|
||||
wmsg = "were found in the forward and reverse dictionaries of the repeats of the solvent"
|
||||
with pytest.warns(UserWarning, match=wmsg):
|
||||
far = protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
assert isinstance(far, dict)
|
||||
for key in ["solvent", "complex"]:
|
||||
assert isinstance(far[key], list)
|
||||
|
||||
assert far["solvent"][0] is None
|
||||
|
||||
complex_keys = list(far["complex"][0].keys())
|
||||
|
||||
for key in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]:
|
||||
assert key in complex_keys
|
||||
assert len(far["complex"][0][key]) == 10
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_get_overlap_matrices(self, key, protocolresult):
|
||||
ovp = protocolresult.get_overlap_matrices()
|
||||
|
||||
assert isinstance(ovp, dict)
|
||||
assert isinstance(ovp[key], list)
|
||||
assert len(ovp[key]) == 1
|
||||
|
||||
ovp1 = ovp[key][0]
|
||||
assert isinstance(ovp1["matrix"], np.ndarray)
|
||||
if key == "solvent":
|
||||
lambda_nr = 27
|
||||
else:
|
||||
lambda_nr = 19
|
||||
assert ovp1["matrix"].shape == (lambda_nr, lambda_nr)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_get_replica_transition_statistics(self, key, protocolresult):
|
||||
rpx = protocolresult.get_replica_transition_statistics()
|
||||
if key == "solvent":
|
||||
lambda_nr = 27
|
||||
else:
|
||||
lambda_nr = 19
|
||||
assert isinstance(rpx, dict)
|
||||
assert isinstance(rpx[key], list)
|
||||
assert len(rpx[key]) == 1
|
||||
rpx1 = rpx[key][0]
|
||||
assert "eigenvalues" in rpx1
|
||||
assert "matrix" in rpx1
|
||||
|
||||
assert rpx1["eigenvalues"].shape == (lambda_nr,)
|
||||
assert rpx1["matrix"].shape == (lambda_nr, lambda_nr)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_equilibration_iterations(self, key, protocolresult):
|
||||
eq = protocolresult.equilibration_iterations()
|
||||
|
||||
assert isinstance(eq, dict)
|
||||
assert isinstance(eq[key], list)
|
||||
assert len(eq[key]) == 1
|
||||
assert all(isinstance(v, float) for v in eq[key])
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "complex"])
|
||||
def test_production_iterations(self, key, protocolresult):
|
||||
prod = protocolresult.production_iterations()
|
||||
|
||||
assert isinstance(prod, dict)
|
||||
assert isinstance(prod[key], list)
|
||||
assert len(prod[key]) == 1
|
||||
assert all(isinstance(v, float) for v in prod[key])
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key, expected_size",
|
||||
[
|
||||
["solvent", 87],
|
||||
["complex", 1868],
|
||||
],
|
||||
)
|
||||
def test_selection_indices(self, key, protocolresult, expected_size):
|
||||
indices = protocolresult.selection_indices()
|
||||
|
||||
assert isinstance(indices, dict)
|
||||
assert isinstance(indices[key], list)
|
||||
for inds in indices[key]:
|
||||
assert isinstance(inds, np.ndarray)
|
||||
assert len(inds) == expected_size
|
||||
|
||||
def test_filenotfound_replica_states(self, protocolresult):
|
||||
errmsg = "File could not be found"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocolresult.get_replica_states()
|
||||
|
||||
def test_restraint_geometry(self, protocolresult):
|
||||
geom = protocolresult.restraint_geometries()
|
||||
assert isinstance(geom, tuple)
|
||||
assert len(geom) == 2
|
||||
assert isinstance(geom[0], list)
|
||||
assert isinstance(geom[0][0], BoreschRestraintGeometry)
|
||||
assert geom[0][0].guest_atoms == [1779, 1778, 1777]
|
||||
assert geom[0][0].host_atoms == [802, 801, 800]
|
||||
assert pytest.approx(geom[0][0].r_aA0, abs=0.01) == 0.75 * offunit.nanometer
|
||||
assert pytest.approx(geom[0][0].theta_A0, abs=0.01) == 1.95 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].theta_B0, abs=0.01) == 1.33 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_A0, abs=0.01) == 1.01 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_B0, abs=0.01) == -1.24 * offunit.radian
|
||||
assert pytest.approx(geom[0][0].phi_C0, abs=0.01) == -1.08 * offunit.radian
|
||||
@@ -277,10 +277,10 @@ def test_openmm_run_engine(
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
if "SepTopComplexRunUnit" in pur.source_key.split("-") or "SepTopSolventRunUnit" in pur.source_key.split("-"): # fmt: skip
|
||||
checkpoint = pur.outputs["last_checkpoint"]
|
||||
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
checkpoint = pur.outputs["checkpoint"]
|
||||
assert checkpoint == unit_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
assert (unit_shared / checkpoint).exists()
|
||||
nc = pur.outputs["nc"]
|
||||
nc = pur.outputs["trajectory"]
|
||||
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
|
||||
assert nc.exists()
|
||||
|
||||
|
||||
@@ -38,6 +38,13 @@ def solvent_run_protocol_unit(protocol_units):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_analysis_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_septop.SepTopSolventAnalysisUnit):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_setup_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
@@ -52,6 +59,13 @@ def complex_run_protocol_unit(protocol_units):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_analysis_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_septop.SepTopComplexAnalysisUnit):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol_result(septop_json):
|
||||
d = json.loads(
|
||||
@@ -115,6 +129,23 @@ class TestSepTopSolventRunUnit(GufeTokenizableTestsMixin):
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestSepTopSolventAnalysisUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_septop.SepTopSolventAnalysisUnit
|
||||
repr = "SepTopSolventAnalysisUnit(SepTop RBFE Analysis, transformation benzene to toluene, solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, solvent_analysis_protocol_unit):
|
||||
return solvent_analysis_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestSepTopComplexSetupUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_septop.SepTopComplexSetupUnit
|
||||
repr = (
|
||||
@@ -151,6 +182,23 @@ class TestSepTopComplexRunUnit(GufeTokenizableTestsMixin):
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestSepTopComplexAnalysisUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_septop.SepTopComplexAnalysisUnit
|
||||
repr = "SepTopComplexAnalysisUnit(SepTop RBFE Analysis, transformation benzene to toluene, complex leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, complex_analysis_protocol_unit):
|
||||
return complex_analysis_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestSepTopProtocolResult(GufeTokenizableTestsMixin):
|
||||
cls = openmm_septop.SepTopProtocolResult
|
||||
key = None
|
||||
|
||||
30
src/openfe/tests/protocols/openmm_septop/utils.py
Normal file
30
src/openfe/tests/protocols/openmm_septop/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
from openfe.protocols.openmm_septop import (
|
||||
SepTopComplexAnalysisUnit,
|
||||
SepTopComplexRunUnit,
|
||||
SepTopComplexSetupUnit,
|
||||
SepTopSolventAnalysisUnit,
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
|
||||
UNIT_TYPES = {
|
||||
"solvent": {
|
||||
"setup": SepTopSolventSetupUnit,
|
||||
"sim": SepTopSolventRunUnit,
|
||||
"analysis": SepTopSolventAnalysisUnit,
|
||||
},
|
||||
"complex": {
|
||||
"setup": SepTopComplexSetupUnit,
|
||||
"sim": SepTopComplexRunUnit,
|
||||
"analysis": SepTopComplexAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_units(protocol_units, unit_type):
|
||||
"""
|
||||
Helper method to extract setup units.
|
||||
"""
|
||||
return [pu for pu in protocol_units if isinstance(pu, unit_type)]
|
||||
@@ -97,7 +97,23 @@ def _get_legs_from_result_jsons(
|
||||
for k in result["unit_results"].keys()
|
||||
if k.startswith("ProtocolUnitResult")
|
||||
] # fmt: skip
|
||||
|
||||
# In openfe v1.11+, we only want to pick up results from
|
||||
# the Analysis Unit. To ensure backwards compatibility,
|
||||
# we check if there are any analysis units. If so,
|
||||
# we set a flag and later exclude Setup and Run.
|
||||
has_analysis_units = any(
|
||||
["Analysis" in result["unit_results"][p]["source_key"] for p in proto_key]
|
||||
)
|
||||
|
||||
for p in proto_key:
|
||||
# Skip non-analysis units if we have any
|
||||
if has_analysis_units and (
|
||||
"Setup" in result["unit_results"][p]["source_key"]
|
||||
or "Run" in result["unit_results"][p]["source_key"]
|
||||
):
|
||||
continue
|
||||
|
||||
if "unit_estimate" in result["unit_results"][p]["outputs"]:
|
||||
simtype = result["unit_results"][p]["outputs"]["simtype"]
|
||||
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
|
||||
@@ -135,8 +151,13 @@ def _get_names(result: dict) -> tuple[str, str]:
|
||||
|
||||
solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
|
||||
|
||||
name_A = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
|
||||
name_B = solvent_data["inputs"]["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"]
|
||||
try:
|
||||
setup_data = solvent_data["inputs"]["setup"]["inputs"]
|
||||
except KeyError:
|
||||
setup_data = solvent_data["inputs"]
|
||||
|
||||
name_A = setup_data["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
|
||||
name_B = setup_data["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"]
|
||||
|
||||
return str(name_A), str(name_B)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user