Turn AFE protocols into multiple units (#1776)

* Split the AFE protocol units into setup, simulation, and analysis.
This commit is contained in:
Irfan Alibay
2026-01-22 13:38:57 +00:00
committed by GitHub
parent 999613b658
commit 27c0d79cd1
27 changed files with 2158 additions and 1150 deletions

View File

@@ -16,8 +16,12 @@ Protocol API specification
:toctree: generated/
AbsoluteBindingProtocol
AbsoluteBindingComplexUnit
AbsoluteBindingSolventUnit
ABFEComplexAnalysisUnit
ABFEComplexSetupUnit
ABFEComplexSimUnit
ABFESolventAnalysisUnit
ABFESolventSetupUnit
ABFESolventSimUnit
AbsoluteBindingProtocolResult
Protocol Settings

View File

@@ -16,8 +16,12 @@ Protocol API specification
:toctree: generated/
AbsoluteSolvationProtocol
AbsoluteSolvationVacuumUnit
AbsoluteSolvationSolventUnit
AHFESolventAnalysisUnit
AHFESolventSetupUnit
AHFESolventSimUnit
AHFEVacuumAnalysisUnit
AHFEVacuumSetupUnit
AHFEVacuumSimUnit
AbsoluteSolvationProtocolResult
Protocol Settings

26
news/multi-unit-afe.rst Normal file
View File

@@ -0,0 +1,26 @@
**Added:**
* <news item>
**Changed:**
* The absolute free energy protocols have been broken into multiple
protocol units, allowing for setup, run, and analysis to happen
separately in the future when relevant changes to protocol execution are
made (PR #1776).
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View File

@@ -6,16 +6,24 @@ Run absolute free energy calculations using OpenMM and OpenMMTools.
"""
from .abfe_units import (
AbsoluteBindingComplexUnit,
AbsoluteBindingSolventUnit,
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
from .afe_protocol_results import (
AbsoluteBindingProtocolResult,
AbsoluteSolvationProtocolResult,
)
from .ahfe_units import (
AbsoluteSolvationSolventUnit,
AbsoluteSolvationVacuumUnit,
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
from .equil_binding_afe_method import (
AbsoluteBindingProtocol,
@@ -30,11 +38,19 @@ __all__ = [
"AbsoluteSolvationProtocol",
"AbsoluteSolvationSettings",
"AbsoluteSolvationProtocolResult",
"AbsoluteVacuumUnit",
"AbsoluteSolventUnit",
"AHFESolventSetupUnit",
"AHFESolventSimUnit",
"AHFESolventAnalysisUnit",
"AHFEVacuumSetupUnit",
"AHFEVacuumSimUnit",
"AHFEVacuumAnalysisUnit",
"AbsoluteBindingProtocol",
"AbsoluteBindingSettings",
"AbsoluteBindingProtocolResult",
"AbsoluteBindingComplexUnit",
"AbsoluteBindingSolventUnit",
"ABFEComplexSetupUnit",
"ABFEComplexSimUnit",
"ABFEComplexAnalysisUnit",
"ABFESolventSetupUnit",
"ABFESolventSimUnit",
"ABFESolventAnalysisUnit",
]

View File

@@ -1,7 +1,5 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""ABFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.abfe_units`
========================================================================
This module defines the ProtocolUnits for the
@@ -23,7 +21,7 @@ from openff.units.openmm import to_openmm
from openmm import System
from openmm import unit as ommunit
from openmm.app import Topology as omm_topology
from openmmtools.states import GlobalParameterState, ThermodynamicState
from openmmtools.states import ThermodynamicState
from rdkit import Chem
from openfe.protocols.openmm_afe.equil_afe_settings import (
@@ -36,18 +34,16 @@ from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGe
from openfe.protocols.restraint_utils.openmm import omm_restraints
from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint
from .base_afe_units import BaseAbsoluteUnit
from .base_afe_units import (
BaseAbsoluteMultiStateAnalysisUnit,
BaseAbsoluteMultiStateSimulationUnit,
BaseAbsoluteSetupUnit,
)
logger = logging.getLogger(__name__)
class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
"""
Protocol Unit for the complex phase of an absolute binding free energy
"""
simtype = "complex"
class ComplexComponentsMixin:
def _get_components(self):
"""
Get the relevant components for a complex transformation.
@@ -75,7 +71,9 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
# Similarly we don't need to check prot_comp
return alchem_comps, solv_comp, prot_comp, off_comps
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
class ComplexSettingsMixin:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a complex transformation.
@@ -97,7 +95,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
* output_settings: MultiStateOutputSettings
* restraint_settings: BaseRestraintSettings
"""
prot_settings = self._inputs["protocol"].settings
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
settings = {}
settings["forcefield_settings"] = prot_settings.forcefield_settings
@@ -116,6 +114,15 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
return settings
class ABFEComplexSetupUnit(ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteSetupUnit):
"""
Setup unit for the complex phase of absolute binding free energy
transformations.
"""
simtype = "complex"
@staticmethod
def _get_mda_universe(
topology: omm_topology,
@@ -261,7 +268,6 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
comp_resids: dict[Component, npt.NDArray],
settings: dict[str, SettingsBaseModel],
) -> tuple[
GlobalParameterState,
Quantity,
System,
geometry.HostGuestRestraintGeometry,
@@ -295,9 +301,6 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
Returns
-------
restraint_parameter_state : RestraintParameterState
A RestraintParameterState object that defines the control
parameter for the restraint.
correction : openff.units.Quantity
The standard state correction for the restraint.
system : openmm.System
@@ -380,10 +383,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
rest_geom,
)
# Get the GlobalParameterState for the restraint
restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0)
return (
restraint_parameter_state,
correction,
# Remove the thermostat, otherwise you'll get an
# Andersen thermostat by default!
@@ -392,13 +392,28 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
)
class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
class ABFEComplexSimUnit(
ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
"""
Protocol Unit for the solvent phase of an absolute binding free energy
Multi-state simulation (e.g. multi replica methods like Hamiltonian
replica exchange) unit for the complex phase of absolute binding
free energy transformations.
"""
simtype = "solvent"
simtype = "complex"
class ABFEComplexAnalysisUnit(ComplexSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
"""
Analysis unit for multi-state simulations with the complex phase
of absolute binding free energy transformations.
"""
simtype = "complex"
class SolventComponentsMixin:
def _get_components(self):
"""
Get the relevant components for a solvent transformation.
@@ -426,7 +441,9 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
# Similarly we don't need to check prot_comp just return None
return alchem_comps, solv_comp, None, off_comps
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
class SolventSettingsMixin:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a solvent transformation.
@@ -447,7 +464,7 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
* simulation_settings : MultiStateSimulationSettings
* output_settings: MultiStateOutputSettings
"""
prot_settings = self._inputs["protocol"].settings
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
settings = {}
settings["forcefield_settings"] = prot_settings.forcefield_settings
@@ -464,3 +481,33 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
settings["output_settings"] = prot_settings.solvent_output_settings
return settings
class ABFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit):
"""
Setup unit for the solvent phase of absolute binding free energy
transformations.
"""
simtype = "solvent"
class ABFESolventSimUnit(
SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
"""
Multi-state simulation (e.g. multi replica methods like Hamiltonian
replica exchange) unit for the solvent phase of absolute binding
free energy transformations.
"""
simtype = "solvent"
class ABFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
"""
Analysis unit for multi-state simulations with the solvent phase
of absolute binding free energy transformations.
"""
simtype = "solvent"

View File

@@ -214,8 +214,8 @@ class AbsoluteProtocolResultMixin:
for key in [self.bound_state, self.unbound_state]:
for pus in self.data[key].values(): # type: ignore[attr-defined]
states = get_replica_state(
pus[0].outputs["nc"],
pus[0].outputs["last_checkpoint"],
pus[0].outputs["trajectory"],
pus[0].outputs["checkpoint"],
)
replica_states[key].append(states)
@@ -295,7 +295,9 @@ class AbsoluteProtocolResultMixin:
class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
"""Dict-like container for the output of a AbsoluteSolvationProtocol"""
"""
Protocol results with the output of a AbsoluteSolvationProtocol
"""
bound_state = "solvent"
unbound_state = "vacuum"
@@ -375,7 +377,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResul
class AbsoluteBindingProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
"""Dict-like container for the output of a AbsoluteBindingProtocol"""
"""
Protocol results with the output of a AbsoluteBindingProtocol.
"""
bound_state = "complex"
unbound_state = "solvent"

View File

@@ -1,9 +1,9 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units`
========================================================================
"""
AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units`
=====================================================================
This module defines the ProtocolUnits for the
:class:`AbsoluteSolvationProtocol`.
"""
@@ -15,18 +15,16 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
)
from ..openmm_utils import system_validation
from .base_afe_units import BaseAbsoluteUnit
from .base_afe_units import (
BaseAbsoluteMultiStateAnalysisUnit,
BaseAbsoluteMultiStateSimulationUnit,
BaseAbsoluteSetupUnit,
)
logger = logging.getLogger(__name__)
class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
"""
Protocol Unit for the vacuum phase of an absolute solvation free energy
"""
simtype = "vacuum"
class VacuumComponentsMixin:
def _get_components(self):
"""
Get the relevant components for a vacuum transformation.
@@ -59,7 +57,9 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
# (of stateA since we enforce only one disappearing ligand)
return alchem_comps, None, prot_comp, off_comps
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
class VacuumSettingsMixin:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a vacuum transformation.
@@ -80,7 +80,7 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
* simulation_settings : SimulationSettings
* output_settings: MultiStateOutputSettings
"""
prot_settings = self._inputs["protocol"].settings
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
settings = {}
settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings
@@ -99,13 +99,37 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
return settings
class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
class AHFEVacuumSetupUnit(VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteSetupUnit):
"""
Protocol Unit for the solvent phase of an absolute solvation free energy
Setup unit for the vacuum phase of absolute hydration free energy
transformations.
"""
simtype = "solvent"
simtype = "vacuum"
class AHFEVacuumSimUnit(
VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
"""
Multi-state simulation (e.g. multi replica methods like Hamiltonian
replica exchange) unit for the vacuum phase of absolute hydration
free energy transformations.
"""
simtype = "vacuum"
class AHFEVacuumAnalysisUnit(VacuumSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
"""
Analysis unit for multi-state simulations with the vacuum phase
of absolute hydration free energy transformations.
"""
simtype = "vacuum"
class SolventComponentsMixin:
def _get_components(self):
"""
Get the relevant components for a solvent transformation.
@@ -134,7 +158,9 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
# disallowed on create
return alchem_comps, solv_comp, prot_comp, off_comps
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
class SolventSettingsMixin:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a solvent transformation.
@@ -155,7 +181,7 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
* simulation_settings : MultiStateSimulationSettings
* output_settings: MultiStateOutputSettings
"""
prot_settings = self._inputs["protocol"].settings
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
settings = {}
settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings
@@ -172,3 +198,33 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
settings["output_settings"] = prot_settings.solvent_output_settings
return settings
class AHFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit):
"""
Setup unit for the solvent phase of absolute hydration free energy
transformations.
"""
simtype = "solvent"
class AHFESolventSimUnit(
SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
"""
Multi-state simulation (e.g. multi replica methods like Hamiltonian
replica exchange) unit for the solvent phase of absolute hydration
free energy transformations.
"""
simtype = "solvent"
class AHFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
"""
Analysis unit for multi-state simulations with the solvent phase
of absolute hydration free energy transformations.
"""
simtype = "solvent"

File diff suppressed because it is too large Load Diff

View File

@@ -55,9 +55,19 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
OpenMMEngineSettings,
OpenMMSolvationSettings,
)
from openfe.protocols.openmm_utils import settings_validation, system_validation
from openfe.protocols.openmm_utils import (
settings_validation,
system_validation,
)
from .abfe_units import AbsoluteBindingComplexUnit, AbsoluteBindingSolventUnit
from .abfe_units import (
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
from .afe_protocol_results import AbsoluteBindingProtocolResult
due.cite(
@@ -422,36 +432,58 @@ class AbsoluteBindingProtocol(gufe.Protocol):
# Get the name of the alchemical species
alchname = alchem_comps["stateA"][0].name
unit_classes = {
"solvent": {
"setup": ABFESolventSetupUnit,
"simulation": ABFESolventSimUnit,
"analysis": ABFESolventAnalysisUnit,
},
"complex": {
"setup": ABFEComplexSetupUnit,
"simulation": ABFEComplexSimUnit,
"analysis": ABFEComplexAnalysisUnit,
},
}
# Create list units for complex and solvent transforms
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "complex": []}
solvent_units = [
AbsoluteBindingSolventUnit(
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=int(uuid.uuid4()),
name=(f"Absolute Binding, {alchname} solvent leg: repeat {i} generation 0"),
)
for i in range(self.settings.protocol_repeats)
]
for phase in ["solvent", "complex"]:
for i in range(self.settings.protocol_repeats):
repeat_id = int(uuid.uuid4())
complex_units = [
AbsoluteBindingComplexUnit(
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=int(uuid.uuid4()),
name=(f"Absolute Binding, {alchname} complex leg: repeat {i} generation 0"),
)
for i in range(self.settings.protocol_repeats)
]
setup = unit_classes[phase]["setup"](
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=repeat_id,
name=f"ABFE Setup: {alchname} {phase} leg: repeat {i} generation 0",
)
return solvent_units + complex_units
simulation = unit_classes[phase]["simulation"](
protocol=self,
# only need state A & alchem comps
stateA=stateA,
alchemical_components=alchem_comps,
setup_results=setup,
generation=0,
repeat_id=repeat_id,
name=f"ABFE Simulation: {alchname} {phase} leg: repeat {i} generation 0",
)
analysis = unit_classes[phase]["analysis"](
protocol=self,
setup_results=setup,
simulation_results=simulation,
generation=0,
repeat_id=repeat_id,
name=f"ABFE Analysis: {alchname} {phase} leg, repeat {i} generation 0",
)
protocol_units[phase] += [setup, simulation, analysis]
return protocol_units["solvent"] + protocol_units["complex"]
def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
@@ -463,7 +495,7 @@ class AbsoluteBindingProtocol(gufe.Protocol):
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
if not pu.ok():
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
if pu.outputs["simtype"] == "solvent":
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)

View File

@@ -63,8 +63,12 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
from ..openmm_utils import settings_validation, system_validation
from .afe_protocol_results import AbsoluteSolvationProtocolResult
from .ahfe_units import (
AbsoluteSolvationSolventUnit,
AbsoluteSolvationVacuumUnit,
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
due.cite(
@@ -445,36 +449,58 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
# Get the name of the alchemical species
alchname = alchem_comps["stateA"][0].name
# Create list units for vacuum and solvent transforms
solvent_units = [
AbsoluteSolvationSolventUnit(
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=int(uuid.uuid4()),
name=(f"Absolute Solvation, {alchname} solvent leg: repeat {i} generation 0"),
)
for i in range(self.settings.protocol_repeats)
]
unit_classes = {
"solvent": {
"setup": AHFESolventSetupUnit,
"simulation": AHFESolventSimUnit,
"analysis": AHFESolventAnalysisUnit,
},
"vacuum": {
"setup": AHFEVacuumSetupUnit,
"simulation": AHFEVacuumSimUnit,
"analysis": AHFEVacuumAnalysisUnit,
},
}
vacuum_units = [
AbsoluteSolvationVacuumUnit(
# These don't really reflect the actual transform
# Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ?
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=int(uuid.uuid4()),
name=(f"Absolute Solvation, {alchname} vacuum leg: repeat {i} generation 0"),
)
for i in range(self.settings.protocol_repeats)
]
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "vacuum": []}
return solvent_units + vacuum_units
for phase in ["solvent", "vacuum"]:
for i in range(self.settings.protocol_repeats):
repeat_id = int(uuid.uuid4())
setup = unit_classes[phase]["setup"](
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=repeat_id,
name=f"AHFE Setup: {alchname} {phase} leg: repeat {i} generation 0",
)
simulation = unit_classes[phase]["simulation"](
protocol=self,
# only need state A & alchem comps
stateA=stateA,
alchemical_components=alchem_comps,
setup_results=setup,
generation=0,
repeat_id=repeat_id,
name=f"AHFE Simulation: {alchname} {phase} leg: repeat {i} generation 0",
)
analysis = unit_classes[phase]["analysis"](
protocol=self,
setup_results=setup,
simulation_results=simulation,
generation=0,
repeat_id=repeat_id,
name=f"AHFE Analysis: {alchname} {phase} leg, repeat {i} generation 0",
)
protocol_units[phase] += [setup, simulation, analysis]
return protocol_units["solvent"] + protocol_units["vacuum"]
def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
@@ -486,7 +512,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
if not pu.ok():
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
if pu.outputs["simtype"] == "solvent":
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)

View File

@@ -1,6 +1,11 @@
import os
import pathlib
from gufe.settings.typing import NanometerArrayQuantity
from openff.units import Quantity
from openmm import Vec3
from openmm import unit as ommunit
def serialize(item, filename: pathlib.Path):
"""
@@ -62,3 +67,23 @@ def deserialize(filename: pathlib.Path):
item = XmlSerializer.deserialize(serialized_thing)
return item
def make_vec3_box(dimensions: NanometerArrayQuantity) -> Vec3:
"""
Convert an OpenFF box dimensions Quantity back into Vec3 format.
Parameters
----------
dimensions : gufe.settings.typing.NanometerArrayQuantity
United array to turn to Vec3 format.
Returns
-------
openmm.Vec3
The input array in Vec3 format.
"""
return [
Vec3(float(row[0]), float(row[1]), float(row[2])) * ommunit.nanometer
for row in dimensions.m_as("nanometer")
]

View File

@@ -7,6 +7,7 @@ from typing import Optional
import openmm
import pooch
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
from openff.units import Quantity, unit
from openff.units.openmm import from_openmm
from openmm import Platform
@@ -326,6 +327,20 @@ def get_available_openmm_platforms() -> set[str]:
return working_platforms
class ModGufeTokenizableTestsMixin(GufeTokenizableTestsMixin):
"""
A modified gufe tokenizable tests mixin which allows
for repr to be lazily evaluated.
"""
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
def compute_energy(
system: openmm.System,
positions: openmm.unit.Quantity,

View File

@@ -17,8 +17,8 @@ from openmmtools.alchemy import (
)
from openfe.protocols import openmm_afe
from openfe.protocols.openmm_afe import (
AbsoluteBindingComplexUnit,
from openfe.protocols.openmm_afe.abfe_units import (
ABFEComplexSetupUnit,
)
from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings
from openfe.protocols.openmm_utils.serialization import deserialize
@@ -147,11 +147,11 @@ class TestT4EnergiesRegression:
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
complex_units = [u for u in dag.protocol_units if isinstance(u, ABFEComplexSetupUnit)]
with tmpdir.as_cwd():
data = complex_units[0].run(dry=True)["debug"]
return data
results = complex_units[0].run(dry=True)
return results
@staticmethod
def get_energy_components(
@@ -182,7 +182,7 @@ class TestT4EnergiesRegression:
energies_ref = self.get_energy_components(
t4_reference_system,
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_val,
lambda_val,
lambda_val,
@@ -191,7 +191,7 @@ class TestT4EnergiesRegression:
energies_val = self.get_energy_components(
t4_validation_data["alchem_system"],
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_val,
lambda_val,
lambda_val,
@@ -223,7 +223,7 @@ class TestT4EnergiesRegression:
energies = self.get_energy_components(
t4_validation_data["alchem_system"],
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_sterics=1.0,
lambda_electrostatics=1.0,
lambda_restraints=1.0,
@@ -236,7 +236,7 @@ class TestT4EnergiesRegression:
energies = self.get_energy_components(
t4_validation_data["alchem_system"],
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_sterics=1.0,
lambda_electrostatics=1.0,
lambda_restraints=0.0,
@@ -249,7 +249,7 @@ class TestT4EnergiesRegression:
energies = self.get_energy_components(
t4_validation_data["alchem_system"],
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_sterics=1.0,
lambda_electrostatics=0.0,
lambda_restraints=0.0,
@@ -270,7 +270,7 @@ class TestT4EnergiesRegression:
energies = self.get_energy_components(
t4_validation_data["alchem_system"],
t4_validation_data["alchem_indices"],
t4_validation_data["positions"],
t4_validation_data["debug_positions"],
lambda_sterics=0.0,
lambda_electrostatics=0.0,
lambda_restraints=0.0,

View File

@@ -1,6 +1,5 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from importlib import resources
from math import sqrt
from unittest import mock
@@ -23,9 +22,7 @@ from openmm import (
)
from openmm import unit as ommunit
from openmmtools.alchemy import (
AbsoluteAlchemicalFactory,
AlchemicalRegion,
AlchemicalState,
)
from openmmtools.multistate.multistatesampler import MultiStateSampler
from openmmtools.tests.test_alchemy import (
@@ -38,11 +35,16 @@ import openfe
from openfe import ChemicalSystem, SmallMoleculeComponent, SolventComponent
from openfe.protocols import openmm_afe
from openfe.protocols.openmm_afe import (
AbsoluteBindingComplexUnit,
AbsoluteBindingProtocol,
AbsoluteBindingSolventUnit,
)
from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings
from openfe.protocols.openmm_afe.abfe_units import (
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
from .utils import UNIT_TYPES, _get_units
@pytest.fixture()
@@ -75,37 +77,53 @@ def test_serialize_protocol(default_settings):
assert protocol == ret
def test_unit_tagging(benzene_complex_dag, tmpdir):
# test that executing the units includes correct gen and repeat info
def test_repeat_units(benzene_modifications, T4_protein_component):
protocol = openmm_afe.AbsoluteBindingProtocol(
settings=openmm_afe.AbsoluteBindingProtocol.default_settings()
)
dag_units = benzene_complex_dag.protocol_units
stateA = gufe.ChemicalSystem(
{
"protein": T4_protein_component,
"benzene": benzene_modifications["benzene"],
"solvent": gufe.SolventComponent(),
}
)
with (
mock.patch(
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
mock.patch(
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
):
results = []
for u in dag_units:
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
results.append(ret)
stateB = gufe.ChemicalSystem(
{
"protein": T4_protein_component,
"solvent": gufe.SolventComponent(),
}
)
solv_repeats = set()
complex_repeats = set()
for ret in results:
assert isinstance(ret, gufe.ProtocolUnitResult)
assert ret.outputs["generation"] == 0
if ret.outputs["simtype"] == "complex":
complex_repeats.add(ret.outputs["repeat_id"])
else:
solv_repeats.add(ret.outputs["repeat_id"])
# Repeat ids are random ints so just check their lengths
assert len(complex_repeats) == len(solv_repeats) == 3
dag = protocol.create(
stateA=stateA,
stateB=stateB,
mapping=None,
)
# 6 protocol unit, 3 per repeat
pus = list(dag.protocol_units)
assert len(pus) == 18
# Check info for each repeat
for phase in ["solvent", "complex"]:
setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])
# Should be 3 of each set
assert len(setup) == len(sim) == len(analysis) == 3
# Check that the dag chain is correct
for analysis_pu in analysis:
repeat_id = analysis_pu.inputs["repeat_id"]
setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0]
sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0]
assert analysis_pu.inputs["setup_results"] == setup_pu
assert analysis_pu.inputs["simulation_results"] == sim_pu
assert sim_pu.inputs["setup_results"] == setup_pu
def test_create_independent_repeat_ids(benzene_modifications, T4_protein_component):
@@ -137,9 +155,12 @@ def test_create_independent_repeat_ids(benzene_modifications, T4_protein_compone
repeat_ids = set()
for dag in dags:
# 3 sets of 6 units
assert len(list(dag.protocol_units)) == 18
for u in dag.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
# squashed by repeat_id, that's 2 sets of 6
assert len(repeat_ids) == 12
@@ -149,7 +170,7 @@ def test_mda_universe_error():
when calling the mda Universe getter.
"""
with pytest.raises(ValueError, match="No positions to create"):
_ = openmm_afe.AbsoluteBindingComplexUnit._get_mda_universe(
_ = openmm_afe.ABFEComplexSetupUnit._get_mda_universe(
topology="foo", positions=None, trajectory=None
)
@@ -201,17 +222,32 @@ class TestT4LysozymeDryRun:
)
@pytest.fixture(scope="class")
def complex_units(self, dag):
return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
def complex_setup_units(self, dag):
return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"])
@pytest.fixture(scope="class")
def solvent_units(self, dag):
return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingSolventUnit)]
def complex_sim_units(self, dag):
return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["sim"])
def test_number_of_units(self, dag, complex_units, solvent_units):
assert len(list(dag.protocol_units)) == 2
assert len(complex_units) == 1
assert len(solvent_units) == 1
@pytest.fixture(scope="class")
def solvent_setup_units(self, dag):
return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["setup"])
@pytest.fixture(scope="class")
def solvent_sim_units(self, dag):
return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["sim"])
def test_number_of_units(
self,
dag,
complex_setup_units,
complex_sim_units,
solvent_setup_units,
solvent_sim_units,
):
assert len(list(dag.protocol_units)) == 6
assert len(complex_setup_units) == len(complex_sim_units) == 1
assert len(solvent_setup_units) == len(solvent_sim_units) == 1
def _assert_force_num(self, system, forcetype, number):
forces = [f for f in system.getForces() if isinstance(f, forcetype)]
@@ -356,83 +392,99 @@ class TestT4LysozymeDryRun:
positions=positions,
)
def test_complex_dry_run(self, complex_units, settings, tmpdir):
def test_complex_dry_run(self, complex_setup_units, complex_sim_units, settings, tmpdir):
with tmpdir.as_cwd():
data = complex_units[0].run(dry=True, verbose=True)["debug"]
setup_results = complex_setup_units[0].run(dry=True, verbose=True)
sim_results = complex_sim_units[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=True,
dry=True,
)
# Check the sampler
self._verify_sampler(data["sampler"], complexed=True, settings=settings)
self._verify_sampler(sim_results["sampler"], complexed=True, settings=settings)
# Check the alchemical system
self._assert_expected_alchemical_forces(
data["alchem_system"], complexed=True, settings=settings
setup_results["alchem_system"], complexed=True, settings=settings
)
self._test_dodecahedron_vectors(data["alchem_system"])
self._test_dodecahedron_vectors(setup_results["alchem_system"])
# Check the alchemical indices
expected_indices = [i + self.num_complex_atoms for i in range(self.num_solvent_atoms)]
assert expected_indices == data["alchem_indices"]
assert expected_indices == setup_results["alchem_indices"]
# Check the non-alchemical system
self._assert_expected_nonalchemical_forces(data["system"], settings)
self._test_dodecahedron_vectors(data["system"])
self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings)
self._test_dodecahedron_vectors(setup_results["standard_system"])
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
assert_allclose(
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
from_openmm(data["system"].getDefaultPeriodicBoxVectors()),
from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()),
from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()),
)
# Check the PDB
pdb = mdt.load_pdb("alchemical_system.pdb")
pdb = mdt.load_pdb(setup_results["pdb_structure"])
assert pdb.n_atoms == self.num_all_not_water
# Check energies
alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"])
alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"])
self._test_energies(
reference_system=data["system"],
alchemical_system=data["alchem_system"],
reference_system=setup_results["standard_system"],
alchemical_system=setup_results["alchem_system"],
alchemical_regions=alchem_region,
positions=data["positions"],
positions=setup_results["debug_positions"],
)
def test_solvent_dry_run(self, solvent_units, settings, tmpdir):
def test_solvent_dry_run(self, solvent_setup_units, solvent_sim_units, settings, tmpdir):
with tmpdir.as_cwd():
data = solvent_units[0].run(dry=True, verbose=True)["debug"]
setup_results = solvent_setup_units[0].run(dry=True, verbose=True)
sim_results = solvent_sim_units[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
# Check the sampler
self._verify_sampler(data["sampler"], complexed=False, settings=settings)
self._verify_sampler(sim_results["sampler"], complexed=False, settings=settings)
# Check the alchemical system
self._assert_expected_alchemical_forces(
data["alchem_system"], complexed=False, settings=settings
setup_results["alchem_system"], complexed=False, settings=settings
)
self._test_cubic_vectors(data["alchem_system"])
self._test_cubic_vectors(setup_results["alchem_system"])
# Check the alchemical indices
expected_indices = [i for i in range(self.num_solvent_atoms)]
assert expected_indices == data["alchem_indices"]
assert expected_indices == setup_results["alchem_indices"]
# Check the non-alchemical system
self._assert_expected_nonalchemical_forces(data["system"], settings)
self._test_cubic_vectors(data["system"])
self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings)
self._test_cubic_vectors(setup_results["standard_system"])
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
assert_allclose(
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
from_openmm(data["system"].getDefaultPeriodicBoxVectors()),
from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()),
from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()),
)
# Check the PDB
pdb = mdt.load_pdb("alchemical_system.pdb")
pdb = mdt.load_pdb(setup_results["pdb_structure"])
assert pdb.n_atoms == self.num_solvent_atoms
# Check energies
alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"])
alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"])
self._test_energies(
reference_system=data["system"],
alchemical_system=data["alchem_system"],
reference_system=setup_results["standard_system"],
alchemical_system=setup_results["alchem_system"],
alchemical_regions=alchem_region,
positions=data["positions"],
positions=setup_results["debug_positions"],
)
@@ -510,15 +562,17 @@ def test_user_charges(benzene_modifications, T4_protein_component, tmpdir):
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
complex_setup_units = _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"])
with tmpdir.as_cwd():
data = complex_units[0].run(dry=True)["debug"]
results = complex_setup_units[0].run(dry=True)
system_nbf = [f for f in data["system"].getForces() if isinstance(f, NonbondedForce)][0]
system_nbf = [
f for f in results["standard_system"].getForces() if isinstance(f, NonbondedForce)
][0]
alchem_system_nbf = [
f
for f in data["alchem_system"].getForces()
for f in results["alchem_system"].getForces()
if isinstance(f, NonbondedForce)
][0] # fmt: skip

View File

@@ -3,6 +3,7 @@
import gzip
import itertools
import json
from pathlib import Path
from unittest import mock
import gufe
@@ -14,25 +15,78 @@ import openfe
from openfe.protocols import openmm_afe
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
from .utils import UNIT_TYPES, _get_units
def test_gather(benzene_complex_dag, tmpdir):
# check that .gather behaves as expected
@pytest.fixture()
def patcher():
with (
mock.patch(
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
"openfe.protocols.openmm_afe.abfe_units.ABFESolventSetupUnit.run",
return_value={
"system": Path("system.xml.bz2"),
"positions": Path("positions.npy"),
"pdb_structure": Path("hybrid_system.pdb"),
"selection_indices": np.zeros(100),
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
"restraint_geometry": None,
},
),
mock.patch(
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexSetupUnit.run",
return_value={
"system": Path("system.xml.bz2"),
"positions": Path("positions.npy"),
"pdb_structure": Path("hybrid_system.pdb"),
"selection_indices": np.zeros(100),
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
"restraint_geometry": True,
},
),
mock.patch(
"openfe.protocols.openmm_afe.base_afe_units.np.load",
return_value=np.zeros(100),
),
mock.patch(
"openfe.protocols.openmm_afe.base_afe_units.deserialize",
return_value="foo",
),
mock.patch(
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexSimUnit.run",
return_value={
"trajectory": Path("file.nc"),
"checkpoint": Path("chk.chk"),
},
),
mock.patch(
"openfe.protocols.openmm_afe.abfe_units.ABFESolventSimUnit.run",
return_value={
"trajectory": Path("file.nc"),
"checkpoint": Path("chk.chk"),
},
),
mock.patch(
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexAnalysisUnit.run",
return_value={"foo": "bar"},
),
mock.patch(
"openfe.protocols.openmm_afe.abfe_units.ABFESolventAnalysisUnit.run",
return_value={"foo": "bar"},
),
):
dagres = gufe.protocols.execute_DAG(
benzene_complex_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
keep_shared=True,
)
yield
def test_gather(benzene_complex_dag, patcher, tmpdir):
# check that .gather behaves as expected
dagres = gufe.protocols.execute_DAG(
benzene_complex_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
keep_shared=True,
)
protocol = openmm_afe.AbsoluteBindingProtocol(
settings=openmm_afe.AbsoluteBindingProtocol.default_settings(),
@@ -43,6 +97,47 @@ def test_gather(benzene_complex_dag, tmpdir):
assert isinstance(res, openmm_afe.AbsoluteBindingProtocolResult)
def test_unit_tagging(benzene_complex_dag, patcher, tmpdir):
# test that executing the units includes correct gen and repeat info
dag_units = benzene_complex_dag.protocol_units
for phase in ["solvent", "complex"]:
setup_results = {}
sim_results = {}
analysis_results = {}
setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"])
sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"])
a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"])
for u in setup_units:
rid = u.inputs["repeat_id"]
setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir))
for u in sim_units:
rid = u.inputs["repeat_id"]
sim_results[rid] = u.execute(
context=gufe.Context(tmpdir, tmpdir),
setup_results=setup_results[rid],
)
for u in a_units:
rid = u.inputs["repeat_id"]
analysis_results[rid] = u.execute(
context=gufe.Context(tmpdir, tmpdir),
setup_results=setup_results[rid],
simulation_results=sim_results[rid],
)
for results in [setup_results, sim_results, analysis_results]:
for ret in results.values():
assert isinstance(ret, gufe.ProtocolUnitResult)
assert ret.outputs["generation"] == 0
assert len(setup_results) == len(sim_results) == len(analysis_results) == 3
class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, abfe_transformation_json_path):
@@ -62,7 +157,7 @@ class TestProtocolResult:
est = protocolresult.get_estimate()
assert est
assert est.m == pytest.approx(-21.71, abs=0.01)
assert est.m == pytest.approx(-21.35, abs=0.01)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
@@ -70,7 +165,7 @@ class TestProtocolResult:
est = protocolresult.get_uncertainty()
assert est
assert est.m == pytest.approx(0.73, abs=0.01)
assert est.m == pytest.approx(1.04, abs=0.01)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
@@ -176,12 +271,12 @@ class TestProtocolResult:
assert isinstance(geom[0], BoreschRestraintGeometry)
assert geom[0].guest_atoms == [1779, 1778, 1777]
assert geom[0].host_atoms == [880, 865, 864]
assert pytest.approx(geom[0].r_aA0) == 1.083558 * offunit.nanometer
assert pytest.approx(geom[0].theta_A0) == 0.6786444 * offunit.radian
assert pytest.approx(geom[0].theta_B0) == 1.649905 * offunit.radian
assert pytest.approx(geom[0].phi_A0) == -0.3640583 * offunit.radian
assert pytest.approx(geom[0].phi_B0) == 1.892376 * offunit.radian
assert pytest.approx(geom[0].phi_C0) == -0.6106747 * offunit.radian
assert pytest.approx(geom[0].r_aA0, rel=1e-2) == 1.083558 * offunit.nanometer
assert pytest.approx(geom[0].theta_A0, rel=1e-2) == 0.711876 * offunit.radian
assert pytest.approx(geom[0].theta_B0, rel=1e-2) == 1.687366 * offunit.radian
assert pytest.approx(geom[0].phi_A0, rel=1e-2) == -0.2164231 * offunit.radian
assert pytest.approx(geom[0].phi_B0, rel=1e-2) == 1.892376 * offunit.radian
assert pytest.approx(geom[0].phi_C0, rel=1e-2) == -0.522031870 * offunit.radian
@pytest.mark.parametrize(
"key, expected_size",

View File

@@ -96,16 +96,36 @@ def test_openmm_run_engine(
r = openfe.execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True)
assert r.ok()
for pur in r.protocol_unit_results:
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
checkpoint = pur.outputs["last_checkpoint"]
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
assert (unit_shared / checkpoint).exists()
nc = pur.outputs["nc"]
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()
# Check outputs of solvent & complex results
for phase in ["solvent", "complex"]:
purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase]
# get the path to the simulation unit shared dict
for pur in purs:
if "Simulation" in pur.name:
sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert sim_shared.exists()
assert pathlib.Path(sim_shared).is_dir()
# check the analysis outputs
for pur in purs:
if "Analysis" not in pur.name:
continue
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
# Does the checkpoint file exist?
checkpoint = pur.outputs["checkpoint"]
assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
assert checkpoint.exists()
# Does the trajectory file exist?
nc = pur.outputs["trajectory"]
assert nc == sim_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()
# Test results methods that need files present
results = protocol.gather([r])

View File

@@ -3,16 +3,21 @@
import gzip
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import openfe
from openfe.protocols.openmm_afe import (
AbsoluteBindingComplexUnit,
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
AbsoluteBindingProtocol,
AbsoluteBindingProtocolResult,
AbsoluteBindingSolventUnit,
)
from ..conftest import ModGufeTokenizableTestsMixin
@pytest.fixture
def protocol():
@@ -33,18 +38,40 @@ def protocol_units(protocol, benzene_complex_system, T4_protein_component):
return list(pus.protocol_units)
@pytest.fixture
def solvent_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, AbsoluteBindingSolventUnit):
def _filter_units(pus, classtype):
for pu in pus:
if isinstance(pu, classtype):
return pu
@pytest.fixture
def complex_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, AbsoluteBindingComplexUnit):
return pu
def complex_protocol_setup_unit(protocol_units):
return _filter_units(protocol_units, ABFEComplexSetupUnit)
@pytest.fixture
def complex_protocol_sim_unit(protocol_units):
return _filter_units(protocol_units, ABFEComplexSimUnit)
@pytest.fixture
def complex_protocol_analysis_unit(protocol_units):
return _filter_units(protocol_units, ABFEComplexAnalysisUnit)
@pytest.fixture
def solvent_protocol_setup_unit(protocol_units):
return _filter_units(protocol_units, ABFESolventSetupUnit)
@pytest.fixture
def solvent_protocol_sim_unit(protocol_units):
return _filter_units(protocol_units, ABFESolventSimUnit)
@pytest.fixture
def solvent_protocol_analysis_unit(protocol_units):
return _filter_units(protocol_units, ABFESolventAnalysisUnit)
@pytest.fixture
@@ -54,7 +81,7 @@ def protocol_result(abfe_transformation_json_path):
return pr
class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin):
class TestAbsoluteBindingProtocol(ModGufeTokenizableTestsMixin):
cls = AbsoluteBindingProtocol
key = None
repr = "AbsoluteBindingProtocol-"
@@ -63,49 +90,68 @@ class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin):
def instance(self, protocol):
return protocol
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
class TestAbsoluteBindingSolventUnit(GufeTokenizableTestsMixin):
cls = AbsoluteBindingSolventUnit
repr = "AbsoluteBindingSolventUnit(Absolute Binding, benzene solvent leg"
class TestABFESolventSetupUnit(ModGufeTokenizableTestsMixin):
cls = ABFESolventSetupUnit
repr = "ABFESolventSetupUnit(ABFE Setup: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, solvent_protocol_unit):
return solvent_protocol_unit
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
def instance(self, solvent_protocol_setup_unit):
return solvent_protocol_setup_unit
class TestAbsoluteBindingComplexUnit(GufeTokenizableTestsMixin):
cls = AbsoluteBindingComplexUnit
repr = "AbsoluteBindingComplexUnit(Absolute Binding, benzene complex leg"
class TestABFESolventSimUnit(ModGufeTokenizableTestsMixin):
cls = ABFESolventSimUnit
repr = "ABFESolventSimUnit(ABFE Simulation: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, complex_protocol_unit):
return complex_protocol_unit
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
def instance(self, solvent_protocol_sim_unit):
return solvent_protocol_sim_unit
class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin):
class TestABFESolventAnalysisUnit(ModGufeTokenizableTestsMixin):
cls = ABFESolventAnalysisUnit
repr = "ABFESolventAnalysisUnit(ABFE Analysis: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, solvent_protocol_analysis_unit):
return solvent_protocol_analysis_unit
class TestABFEComplexSetupUnit(ModGufeTokenizableTestsMixin):
cls = ABFEComplexSetupUnit
repr = "ABFEComplexSetupUnit(ABFE Setup: benzene complex leg"
key = None
@pytest.fixture()
def instance(self, complex_protocol_setup_unit):
return complex_protocol_setup_unit
class TestABFEComplexSimUnit(ModGufeTokenizableTestsMixin):
cls = ABFEComplexSimUnit
repr = "ABFEComplexSimUnit(ABFE Simulation: benzene complex leg"
key = None
@pytest.fixture()
def instance(self, complex_protocol_sim_unit):
return complex_protocol_sim_unit
class TestABFEComplexAnalysisUnit(ModGufeTokenizableTestsMixin):
cls = ABFEComplexAnalysisUnit
repr = "ABFEComplexAnalysisUnit(ABFE Analysis: benzene complex leg"
key = None
@pytest.fixture()
def instance(self, complex_protocol_analysis_unit):
return complex_protocol_analysis_unit
class TestAbsoluteBindingProtocolResult(ModGufeTokenizableTestsMixin):
cls = AbsoluteBindingProtocolResult
key = None
repr = "AbsoluteBindingProtocolResult-"
@@ -113,10 +159,3 @@ class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin):
@pytest.fixture()
def instance(self, protocol_result):
return protocol_result
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)

View File

@@ -0,0 +1,30 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from openfe.protocols.openmm_afe.abfe_units import (
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
UNIT_TYPES = {
"solvent": {
"setup": ABFESolventSetupUnit,
"sim": ABFESolventSimUnit,
"analysis": ABFESolventAnalysisUnit,
},
"complex": {
"setup": ABFEComplexSetupUnit,
"sim": ABFEComplexSimUnit,
"analysis": ABFEComplexAnalysisUnit,
},
}
def _get_units(protocol_units, unit_type):
"""
Helper method to extract setup units.
"""
return [pu for pu in protocol_units if isinstance(pu, unit_type)]

View File

@@ -1,12 +1,9 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import itertools
import json
import sys
from math import sqrt
from unittest import mock
import gufe
import mdtraj as mdt
import numpy as np
import pytest
@@ -18,7 +15,6 @@ from openmm import (
CustomNonbondedForce,
HarmonicAngleForce,
HarmonicBondForce,
MonteCarloBarostat,
NonbondedForce,
PeriodicTorsionForce,
)
@@ -29,16 +25,15 @@ from openfe import ChemicalSystem, SolventComponent
from openfe.protocols import openmm_afe
from openfe.protocols.openmm_afe import (
AbsoluteSolvationProtocol,
AbsoluteSolvationSolventUnit,
AbsoluteSolvationVacuumUnit,
)
from openfe.protocols.openmm_utils import system_validation
from openfe.protocols.openmm_utils.charge_generation import (
HAS_ESPALOMA_CHARGE,
HAS_NAGL,
HAS_OPENEYE,
)
from .utils import UNIT_TYPES, _get_units
@pytest.fixture()
def protocol_dry_settings():
@@ -68,6 +63,40 @@ def test_serialize_protocol(default_settings):
assert protocol == ret
def test_repeat_units(benzene_system):
protocol = openmm_afe.AbsoluteSolvationProtocol(
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings()
)
dag = protocol.create(
stateA=benzene_system,
stateB=ChemicalSystem({"solvent": SolventComponent()}),
mapping=None,
)
# 6 protocol unit, 3 per repeat
pus = list(dag.protocol_units)
assert len(pus) == 18
# Check info for each repeat
for phase in ["solvent", "vacuum"]:
setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])
# Should be 3 of each set
assert len(setup) == len(sim) == len(analysis) == 3
# Check that the dag chain is correct
for analysis_pu in analysis:
repeat_id = analysis_pu.inputs["repeat_id"]
setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0]
sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0]
assert analysis_pu.inputs["setup_results"] == setup_pu
assert analysis_pu.inputs["simulation_results"] == sim_pu
assert sim_pu.inputs["setup_results"] == setup_pu
def test_create_independent_repeat_ids(benzene_system):
protocol = openmm_afe.AbsoluteSolvationProtocol(
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings()
@@ -88,9 +117,12 @@ def test_create_independent_repeat_ids(benzene_system):
repeat_ids = set()
for dag in dags:
# 3 sets of 6 units
assert len(list(dag.protocol_units)) == 18
for u in dag.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
# squashed by repeat_id, that's 2 sets of 6
assert len(repeat_ids) == 12
@@ -143,7 +175,7 @@ def _verify_alchemical_sterics_force_parameters(
@pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"])
def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir):
def test_setup_dry_sim_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir):
protocol_dry_settings.vacuum_simulation_settings.sampler_method = method
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
@@ -161,21 +193,32 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
)
prot_units = list(dag.protocol_units)
assert len(prot_units) == 2
assert len(prot_units) == 6
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"])
assert len(vac_unit) == 1
assert len(sol_unit) == 1
assert len(vac_setup_unit) == 1
assert len(vac_sim_unit) == 1
with tmpdir.as_cwd():
debug = vac_unit[0].run(dry=True)["debug"]
vac_sampler = debug["sampler"]
assert not vac_sampler.is_periodic
setup_results = vac_setup_unit[0].run(dry=True)
sim_results = vac_sim_unit[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
sampler = sim_results["sampler"]
assert isinstance(sampler, MultiStateSampler)
assert not sampler.is_periodic
assert sampler._thermodynamic_states[0].barostat is None
# standard system
system = debug["system"]
system = setup_results["standard_system"]
assert system.getNumParticles() == 12
assert len(system.getForces()) == 4
_assert_num_forces(system, NonbondedForce, 1)
@@ -184,7 +227,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
_assert_num_forces(system, PeriodicTorsionForce, 1)
# alchemical system
alchem_system = debug["alchem_system"]
alchem_system = setup_results["alchem_system"]
assert alchem_system.getNumParticles() == 12
assert len(alchem_system.getForces()) == 12
_assert_num_forces(alchem_system, NonbondedForce, 1)
@@ -212,7 +255,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
[0.35, 2.2, 1.5, 0, False],
],
)
def test_alchemical_settings_dry_run_vacuum(
def test_alchemical_settings_setup_vacuum(
alpha, a, b, c, correction, benzene_system, protocol_dry_settings, tmpdir
):
"""
@@ -238,18 +281,18 @@ def test_alchemical_settings_dry_run_vacuum(
)
prot_units = list(dag.protocol_units)
assert len(prot_units) == 2
assert len(prot_units) == 6
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"])
assert len(vac_unit) == 1
assert len(sol_unit) == 1
assert len(vac_setup_unit) == 1
assert len(vac_sim_unit) == 1
with tmpdir.as_cwd():
debug = vac_unit[0].run(dry=True)["debug"]
results = vac_setup_unit[0].run(dry=True)
alchem_system = debug["alchem_system"]
alchem_system = results["alchem_system"]
_assert_num_forces(alchem_system, NonbondedForce, 1)
_assert_num_forces(alchem_system, CustomNonbondedForce, 4)
_assert_num_forces(alchem_system, CustomBondForce, 4)
@@ -291,16 +334,16 @@ def test_confgen_fail_AFE(benzene_system, protocol_dry_settings, tmpdir):
mapping=None,
)
prot_units = list(dag.protocol_units)
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
with tmpdir.as_cwd():
with mock.patch("rdkit.Chem.AllChem.EmbedMultipleConfs", return_value=0):
vac_sampler = vac_unit[0].run(dry=True)["debug"]["sampler"]
assert vac_sampler
# If this worked, the system will have been built
system = vac_setup_unit[0].run(dry=True)["alchem_system"]
assert system
def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
def test_setup_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
protocol_dry_settings.solvent_output_settings.output_indices = "resname UNK"
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
@@ -318,19 +361,25 @@ def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
)
prot_units = list(dag.protocol_units)
assert len(prot_units) == 2
sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
assert len(vac_unit) == 1
assert len(sol_unit) == 1
assert len(sol_setup_unit) == len(sol_sim_unit) == 1
with tmpdir.as_cwd():
sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
setup_results = sol_setup_unit[0].run(dry=True)
sim_results = sol_sim_unit[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
sol_sampler = sim_results["sampler"]
assert sol_sampler.is_periodic
pdb = mdt.load_pdb("hybrid_system.pdb")
pdb = mdt.load_pdb(setup_results["pdb_structure"])
assert pdb.n_atoms == 12
@@ -363,14 +412,23 @@ def test_dry_run_vsite_fail(benzene_system, tmpdir, protocol_dry_settings):
)
prot_units = list(dag.protocol_units)
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
with tmpdir.as_cwd():
setup_results = sol_setup_unit[0].run(dry=True)
with pytest.raises(ValueError, match="are unstable"):
_ = sol_unit[0].run(dry=True)
sim_results = sol_sim_unit[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir):
def test_setup_dry_sim_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir):
protocol_dry_settings.vacuum_forcefield_settings.forcefields = [
"amber/ff14SB.xml", # ff14SB protein force field
"amber/tip4pew_standard.xml", # FF we are testsing with the fun VS
@@ -399,10 +457,20 @@ def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdi
)
prot_units = list(dag.protocol_units)
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
sol_sim_units = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
with tmpdir.as_cwd():
sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
setup_results = sol_setup_units[0].run(dry=True)
sim_results = sol_sim_units[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
sol_sampler = sim_results["sampler"]
assert sol_sampler.is_periodic
@@ -425,11 +493,11 @@ def test_dry_run_solv_benzene_noncubic(benzene_system, protocol_dry_settings, tm
)
prot_units = list(dag.protocol_units)
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
with tmpdir.as_cwd():
sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
system = sampler._thermodynamic_states[0].system
results = sol_setup_units[0].run(dry=True)
system = results["alchem_system"]
vectors = system.getDefaultPeriodicBoxVectors()
width = float(from_openmm(vectors)[0][0].to("nanometer").m)
@@ -486,13 +554,13 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
prot_units = list(dag.protocol_units)
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0]
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)][0]
vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
# check sol_unit charges
with tmpdir.as_cwd():
sampler = sol_unit.run(dry=True)["debug"]["sampler"]
system = sampler._thermodynamic_states[0].system
results = sol_setup_units[0].run(dry=True)
system = results["alchem_system"]
nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)]
assert len(nonbond) == 1
@@ -506,8 +574,8 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s
# check vac_unit charges
with tmpdir.as_cwd():
sampler = vac_unit.run(dry=True)["debug"]["sampler"]
system = sampler._thermodynamic_states[0].system
results = vac_setup_units[0].run(dry=True)
system = results["alchem_system"]
nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)]
assert len(nonbond) == 4
@@ -572,12 +640,12 @@ def test_dry_run_charge_backends(
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
prot_units = list(dag.protocol_units)
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0]
vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
# check vac_unit charges
with tmpdir.as_cwd():
sampler = vac_unit.run(dry=True)["debug"]["sampler"]
system = sampler._thermodynamic_states[0].system
results = vac_setup_units[0].run(dry=True)
system = results["alchem_system"]
nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)]
assert len(nonbond) == 4
@@ -609,187 +677,6 @@ def benzene_solvation_dag(benzene_system, protocol_dry_settings):
return protocol.create(stateA=stateA, stateB=stateB, mapping=None)
def test_unit_tagging(benzene_solvation_dag, tmpdir):
# test that executing the units includes correct gen and repeat info
dag_units = benzene_solvation_dag.protocol_units
with (
mock.patch(
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
mock.patch(
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
):
results = []
for u in dag_units:
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
results.append(ret)
solv_repeats = set()
vac_repeats = set()
for ret in results:
assert isinstance(ret, gufe.ProtocolUnitResult)
assert ret.outputs["generation"] == 0
if ret.outputs["simtype"] == "vacuum":
vac_repeats.add(ret.outputs["repeat_id"])
else:
solv_repeats.add(ret.outputs["repeat_id"])
# Repeat ids are random ints so just check their lengths
assert len(vac_repeats) == len(solv_repeats) == 3
def test_gather(benzene_solvation_dag, tmpdir):
# check that .gather behaves as expected
with (
mock.patch(
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
mock.patch(
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run",
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
),
):
dagres = gufe.protocols.execute_DAG(
benzene_solvation_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
keep_shared=True,
)
protocol = AbsoluteSolvationProtocol(
settings=AbsoluteSolvationProtocol.default_settings(),
)
res = protocol.gather([dagres])
assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult)
class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
return pr
def test_reload_protocol_result(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"])
assert pr
def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()
assert est
assert est.m == pytest.approx(-2.47, abs=0.5)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()
assert est
assert est.m == pytest.approx(0.2, abs=0.2)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()
assert isinstance(inds, dict)
assert isinstance(inds["solvent"], list)
assert isinstance(inds["vacuum"], list)
assert len(inds["solvent"]) == len(inds["vacuum"]) == 3
for e, u in itertools.chain(inds["solvent"], inds["vacuum"]):
assert e.is_compatible_with(offunit.kilojoule_per_mole)
assert u.is_compatible_with(offunit.kilojoule_per_mole)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_forwards_etc(self, key, protocolresult):
far = protocolresult.get_forward_and_reverse_energy_analysis()
assert isinstance(far, dict)
assert isinstance(far[key], list)
far1 = far[key][0]
assert isinstance(far1, dict)
for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]:
assert k in far1
if k == "fractions":
assert isinstance(far1[k], np.ndarray)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_frwd_reverse_none_return(self, key, protocolresult):
# fetch the first result of type key
data = [i for i in protocolresult.data[key].values()][0][0]
# set the output to None
data.outputs["forward_and_reverse_energies"] = None
# now fetch the analysis results and expect a warning
wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
with pytest.warns(UserWarning, match=wmsg):
protocolresult.get_forward_and_reverse_energy_analysis()
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_overlap_matrices(self, key, protocolresult):
ovp = protocolresult.get_overlap_matrices()
assert isinstance(ovp, dict)
assert isinstance(ovp[key], list)
assert len(ovp[key]) == 3
ovp1 = ovp[key][0]
assert isinstance(ovp1["matrix"], np.ndarray)
assert ovp1["matrix"].shape == (14, 14)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_replica_transition_statistics(self, key, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()
assert isinstance(rpx, dict)
assert isinstance(rpx[key], list)
assert len(rpx[key]) == 3
rpx1 = rpx[key][0]
assert "eigenvalues" in rpx1
assert "matrix" in rpx1
assert rpx1["eigenvalues"].shape == (14,)
assert rpx1["matrix"].shape == (14, 14)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_equilibration_iterations(self, key, protocolresult):
eq = protocolresult.equilibration_iterations()
assert isinstance(eq, dict)
assert isinstance(eq[key], list)
assert len(eq[key]) == 3
assert all(isinstance(v, float) for v in eq[key])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_production_iterations(self, key, protocolresult):
prod = protocolresult.production_iterations()
assert isinstance(prod, dict)
assert isinstance(prod[key], list)
assert len(prod[key]) == 3
assert all(isinstance(v, float) for v in prod[key])
def test_filenotfound_replica_states(self, protocolresult):
errmsg = "File could not be found"
with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()
@pytest.mark.parametrize(
"positions_write_frequency,velocities_write_frequency",
[
@@ -821,7 +708,6 @@ def test_dry_run_vacuum_write_frequency(
stateB = ChemicalSystem({"solvent": SolventComponent()})
# Create DAG from protocol, get the vacuum and solvent units
# and eventually dry run the first solvent unit
dag = protocol.create(
stateA=stateA,
stateB=stateB,
@@ -829,11 +715,23 @@ def test_dry_run_vacuum_write_frequency(
)
prot_units = list(dag.protocol_units)
assert len(prot_units) == 2
assert len(prot_units) == 6
with tmpdir.as_cwd():
for u in prot_units:
sampler = u.run(dry=True)["debug"]["sampler"]
for phase in ["solvent", "vacuum"]:
setup_units = _get_units(prot_units, UNIT_TYPES[phase]["setup"])
sim_units = _get_units(prot_units, UNIT_TYPES[phase]["sim"])
with tmpdir.as_cwd():
setup_results = setup_units[0].run(dry=True)
sim_results = sim_units[0].run(
system=setup_results["alchem_system"],
positions=setup_results["debug_positions"],
selection_indices=setup_results["selection_indices"],
box_vectors=setup_results["box_vectors"],
alchemical_restraints=False,
dry=True,
)
sampler = sim_results["sampler"]
reporter = sampler._reporter
if positions_write_frequency:
assert reporter.position_interval == positions_write_frequency.m

View File

@@ -0,0 +1,278 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import itertools
import json
from pathlib import Path
from unittest import mock
import gufe
import numpy as np
import pytest
from openff.units import unit as offunit
import openfe
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols import openmm_afe
from .utils import UNIT_TYPES, _get_units
@pytest.fixture()
def protocol_dry_settings():
settings = openmm_afe.AbsoluteSolvationProtocol.default_settings()
settings.vacuum_engine_settings.compute_platform = None
settings.solvent_engine_settings.compute_platform = None
settings.protocol_repeats = 1
return settings
@pytest.fixture
def benzene_solvation_dag(benzene_system, protocol_dry_settings):
protocol_dry_settings.protocol_repeats = 3
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
stateA = benzene_system
stateB = ChemicalSystem({"solvent": SolventComponent()})
return protocol.create(stateA=stateA, stateB=stateB, mapping=None)
@pytest.fixture
def patcher():
with (
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventSetupUnit.run",
return_value={
"system": Path("system.xml.bz2"),
"positions": Path("positions.npy"),
"pdb_structure": Path("hybrid_system.pdb"),
"selection_indices": np.zeros(100),
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
"restraint_geometry": None,
},
),
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSetupUnit.run",
return_value={
"system": Path("system.xml.bz2"),
"positions": Path("positions.npy"),
"pdb_structure": Path("hybrid_system.pdb"),
"selection_indices": np.zeros(100),
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
"restraint_geometry": None,
},
),
mock.patch(
"openfe.protocols.openmm_afe.base_afe_units.np.load",
return_value=np.zeros(100),
),
mock.patch(
"openfe.protocols.openmm_afe.base_afe_units.deserialize",
return_value="foo",
),
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventSimUnit.run",
return_value={
"trajectory": Path("file.nc"),
"checkpoint": Path("chk.chk"),
},
),
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSimUnit.run",
return_value={
"trajectory": Path("file.nc"),
"checkpoint": Path("chk.chk"),
},
),
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventAnalysisUnit.run",
return_value={"foo": "bar"},
),
mock.patch(
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumAnalysisUnit.run",
return_value={"foo": "bar"},
),
):
yield
def test_gather(benzene_solvation_dag, patcher, tmpdir):
# check that .gather behaves as expected
dagres = gufe.protocols.execute_DAG(
benzene_solvation_dag,
shared_basedir=tmpdir,
scratch_basedir=tmpdir,
keep_shared=True,
)
protocol = openmm_afe.AbsoluteSolvationProtocol(
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings(),
)
res = protocol.gather([dagres])
assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult)
def test_unit_tagging(benzene_solvation_dag, patcher, tmpdir):
# test that executing the units includes correct gen and repeat info
dag_units = benzene_solvation_dag.protocol_units
for phase in ["solvent", "vacuum"]:
setup_results = {}
sim_results = {}
analysis_results = {}
setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"])
sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"])
a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"])
for u in setup_units:
rid = u.inputs["repeat_id"]
setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir))
for u in sim_units:
rid = u.inputs["repeat_id"]
sim_results[rid] = u.execute(
context=gufe.Context(tmpdir, tmpdir),
setup_results=setup_results[rid],
)
for u in a_units:
rid = u.inputs["repeat_id"]
analysis_results[rid] = u.execute(
context=gufe.Context(tmpdir, tmpdir),
setup_results=setup_results[rid],
simulation_results=sim_results[rid],
)
for results in [setup_results, sim_results, analysis_results]:
for ret in results.values():
assert isinstance(ret, gufe.ProtocolUnitResult)
assert ret.outputs["generation"] == 0
assert len(setup_results) == len(sim_results) == len(analysis_results) == 3
class TestProtocolResult:
@pytest.fixture()
def protocolresult(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
return pr
def test_reload_protocol_result(self, afe_solv_transformation_json):
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"])
assert pr
def test_get_estimate(self, protocolresult):
est = protocolresult.get_estimate()
assert est
assert est.m == pytest.approx(-2.47, abs=0.5)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
def test_get_uncertainty(self, protocolresult):
est = protocolresult.get_uncertainty()
assert est
assert est.m == pytest.approx(0.2, abs=0.2)
assert isinstance(est, offunit.Quantity)
assert est.is_compatible_with(offunit.kilojoule_per_mole)
def test_get_individual(self, protocolresult):
inds = protocolresult.get_individual_estimates()
assert isinstance(inds, dict)
assert isinstance(inds["solvent"], list)
assert isinstance(inds["vacuum"], list)
assert len(inds["solvent"]) == len(inds["vacuum"]) == 3
for e, u in itertools.chain(inds["solvent"], inds["vacuum"]):
assert e.is_compatible_with(offunit.kilojoule_per_mole)
assert u.is_compatible_with(offunit.kilojoule_per_mole)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_forwards_etc(self, key, protocolresult):
far = protocolresult.get_forward_and_reverse_energy_analysis()
assert isinstance(far, dict)
assert isinstance(far[key], list)
far1 = far[key][0]
assert isinstance(far1, dict)
for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]:
assert k in far1
if k == "fractions":
assert isinstance(far1[k], np.ndarray)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_frwd_reverse_none_return(self, key, protocolresult):
# fetch the first result of type key
data = [i for i in protocolresult.data[key].values()][0][0]
# set the output to None
data.outputs["forward_and_reverse_energies"] = None
# now fetch the analysis results and expect a warning
wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
with pytest.warns(UserWarning, match=wmsg):
protocolresult.get_forward_and_reverse_energy_analysis()
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_overlap_matrices(self, key, protocolresult):
ovp = protocolresult.get_overlap_matrices()
assert isinstance(ovp, dict)
assert isinstance(ovp[key], list)
assert len(ovp[key]) == 3
ovp1 = ovp[key][0]
assert isinstance(ovp1["matrix"], np.ndarray)
assert ovp1["matrix"].shape == (14, 14)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_get_replica_transition_statistics(self, key, protocolresult):
rpx = protocolresult.get_replica_transition_statistics()
assert isinstance(rpx, dict)
assert isinstance(rpx[key], list)
assert len(rpx[key]) == 3
rpx1 = rpx[key][0]
assert "eigenvalues" in rpx1
assert "matrix" in rpx1
assert rpx1["eigenvalues"].shape == (14,)
assert rpx1["matrix"].shape == (14, 14)
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_equilibration_iterations(self, key, protocolresult):
eq = protocolresult.equilibration_iterations()
assert isinstance(eq, dict)
assert isinstance(eq[key], list)
assert len(eq[key]) == 3
assert all(isinstance(v, float) for v in eq[key])
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
def test_production_iterations(self, key, protocolresult):
prod = protocolresult.production_iterations()
assert isinstance(prod, dict)
assert isinstance(prod[key], list)
assert len(prod[key]) == 3
assert all(isinstance(v, float) for v in prod[key])
def test_filenotfound_replica_states(self, protocolresult):
errmsg = "File could not be found"
with pytest.raises(ValueError, match=errmsg):
protocolresult.get_replica_states()

View File

@@ -79,16 +79,36 @@ def test_openmm_run_engine(
r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True)
assert r.ok()
for pur in r.protocol_unit_results:
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
checkpoint = pur.outputs["last_checkpoint"]
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
assert (unit_shared / checkpoint).exists()
nc = pur.outputs["nc"]
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()
# Check outputs of solvent & vacuum results
for phase in ["solvent", "vacuum"]:
purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase]
# get the path to the simulation unit shared dict
for pur in purs:
if "Simulation" in pur.name:
sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert sim_shared.exists()
assert pathlib.Path(sim_shared).is_dir()
# check the analysis outputs
for pur in purs:
if "Analysis" not in pur.name:
continue
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
# Does the checkpoint file exist?
checkpoint = pur.outputs["checkpoint"]
assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
assert checkpoint.exists()
# Does the trajectory file exist?
nc = pur.outputs["trajectory"]
assert nc == sim_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()
# Test results methods that need files present
results = protocol.gather([r])

View File

@@ -4,10 +4,19 @@ import json
import gufe
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import openfe
from openfe.protocols import openmm_afe
from openfe.protocols.openmm_afe import (
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
from ..conftest import ModGufeTokenizableTestsMixin
@pytest.fixture
@@ -27,18 +36,40 @@ def protocol_units(protocol, benzene_system):
return list(pus.protocol_units)
@pytest.fixture
def solvent_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, openmm_afe.AbsoluteSolvationSolventUnit):
def _filter_units(pus, classtype):
for pu in pus:
if isinstance(pu, classtype):
return pu
@pytest.fixture
def vacuum_protocol_unit(protocol_units):
for pu in protocol_units:
if isinstance(pu, openmm_afe.AbsoluteSolvationVacuumUnit):
return pu
def solvent_protocol_setup_unit(protocol_units):
return _filter_units(protocol_units, AHFESolventSetupUnit)
@pytest.fixture
def solvent_protocol_sim_unit(protocol_units):
return _filter_units(protocol_units, AHFESolventSimUnit)
@pytest.fixture
def solvent_protocol_analysis_unit(protocol_units):
return _filter_units(protocol_units, AHFESolventAnalysisUnit)
@pytest.fixture
def vacuum_protocol_setup_unit(protocol_units):
return _filter_units(protocol_units, AHFEVacuumSetupUnit)
@pytest.fixture
def vacuum_protocol_sim_unit(protocol_units):
return _filter_units(protocol_units, AHFEVacuumSimUnit)
@pytest.fixture
def vacuum_protocol_analysis_unit(protocol_units):
return _filter_units(protocol_units, AHFEVacuumAnalysisUnit)
@pytest.fixture
@@ -48,7 +79,7 @@ def protocol_result(afe_solv_transformation_json):
return pr
class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
class TestAbsoluteSolvationProtocol(ModGufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocol
key = None
repr = "AbsoluteSolvationProtocol-"
@@ -57,49 +88,68 @@ class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
def instance(self, protocol):
return protocol
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationSolventUnit
repr = "AbsoluteSolvationSolventUnit(Absolute Solvation, benzene solvent leg"
class TestAHFESolventSetupUnit(ModGufeTokenizableTestsMixin):
cls = AHFESolventSetupUnit
repr = "AHFESolventSetupUnit(AHFE Setup: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, solvent_protocol_unit):
return solvent_protocol_unit
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
def instance(self, solvent_protocol_setup_unit):
return solvent_protocol_setup_unit
class TestAbsoluteSolvationVacuumUnit(GufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationVacuumUnit
repr = "AbsoluteSolvationVacuumUnit(Absolute Solvation, benzene vacuum leg"
class TestAHFESolventSimUnit(ModGufeTokenizableTestsMixin):
cls = AHFESolventSimUnit
repr = "AHFESolventSimUnit(AHFE Simulation: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, vacuum_protocol_unit):
return vacuum_protocol_unit
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
def instance(self, solvent_protocol_sim_unit):
return solvent_protocol_sim_unit
class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
class TestAHFESolventAnalysisUnit(ModGufeTokenizableTestsMixin):
cls = AHFESolventAnalysisUnit
repr = "AHFESolventAnalysisUnit(AHFE Analysis: benzene solvent leg"
key = None
@pytest.fixture()
def instance(self, solvent_protocol_analysis_unit):
return solvent_protocol_analysis_unit
class TestAHFEVacuumSetupUnit(ModGufeTokenizableTestsMixin):
cls = AHFEVacuumSetupUnit
repr = "AHFEVacuumSetupUnit(AHFE Setup: benzene vacuum leg"
key = None
@pytest.fixture()
def instance(self, vacuum_protocol_setup_unit):
return vacuum_protocol_setup_unit
class TestAHFEVacuumSimUnit(ModGufeTokenizableTestsMixin):
cls = AHFEVacuumSimUnit
repr = "AHFEVacuumSimUnit(AHFE Simulation: benzene vacuum leg"
key = None
@pytest.fixture()
def instance(self, vacuum_protocol_sim_unit):
return vacuum_protocol_sim_unit
class TestAHFEVacuumAnalysisUnit(ModGufeTokenizableTestsMixin):
cls = AHFEVacuumAnalysisUnit
repr = "AHFEVacuumAnalysisUnit(AHFE Analysis: benzene vacuum leg"
key = None
@pytest.fixture()
def instance(self, vacuum_protocol_analysis_unit):
return vacuum_protocol_analysis_unit
class TestAbsoluteSolvationProtocolResult(ModGufeTokenizableTestsMixin):
cls = openmm_afe.AbsoluteSolvationProtocolResult
key = None
repr = "AbsoluteSolvationProtocolResult-"
@@ -107,10 +157,3 @@ class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
@pytest.fixture()
def instance(self, protocol_result):
return protocol_result
def test_repr(self, instance):
"""
Overwrites the base `test_repr` call.
"""
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)

View File

@@ -8,15 +8,8 @@ from openfe import ChemicalSystem, SolventComponent
from openfe.protocols import openmm_afe
from openfe.protocols.openmm_afe import (
AbsoluteSolvationProtocol,
AbsoluteSolvationSolventUnit,
AbsoluteSolvationVacuumUnit,
)
from openfe.protocols.openmm_utils import system_validation
from openfe.protocols.openmm_utils.charge_generation import (
HAS_ESPALOMA_CHARGE,
HAS_NAGL,
HAS_OPENEYE,
)
@pytest.fixture()

View File

@@ -0,0 +1,31 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from openfe.protocols.openmm_afe import (
AbsoluteSolvationProtocol,
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
UNIT_TYPES = {
"solvent": {
"setup": AHFESolventSetupUnit,
"sim": AHFESolventSimUnit,
"analysis": AHFESolventAnalysisUnit,
},
"vacuum": {
"setup": AHFEVacuumSetupUnit,
"sim": AHFEVacuumSimUnit,
"analysis": AHFEVacuumAnalysisUnit,
},
}
def _get_units(protocol_units, unit_type):
"""
Helper method to extract setup units.
"""
return [pu for pu in protocol_units if isinstance(pu, unit_type)]

View File

@@ -33,7 +33,12 @@ def _get_name(result: dict) -> str:
"""
solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
try:
name = solvent_data["inputs"]["setup_results"]["inputs"]["alchemical_components"]["stateA"][
0
]["molprops"]["ofe-name"]
except KeyError:
name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
return str(name)