mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
Turn AFE protocols into multiple units (#1776)
* Split the AFE protocol units into setup, simulation, and analysis.
This commit is contained in:
@@ -16,8 +16,12 @@ Protocol API specification
|
||||
:toctree: generated/
|
||||
|
||||
AbsoluteBindingProtocol
|
||||
AbsoluteBindingComplexUnit
|
||||
AbsoluteBindingSolventUnit
|
||||
ABFEComplexAnalysisUnit
|
||||
ABFEComplexSetupUnit
|
||||
ABFEComplexSimUnit
|
||||
ABFESolventAnalysisUnit
|
||||
ABFESolventSetupUnit
|
||||
ABFESolventSimUnit
|
||||
AbsoluteBindingProtocolResult
|
||||
|
||||
Protocol Settings
|
||||
|
||||
@@ -16,8 +16,12 @@ Protocol API specification
|
||||
:toctree: generated/
|
||||
|
||||
AbsoluteSolvationProtocol
|
||||
AbsoluteSolvationVacuumUnit
|
||||
AbsoluteSolvationSolventUnit
|
||||
AHFESolventAnalysisUnit
|
||||
AHFESolventSetupUnit
|
||||
AHFESolventSimUnit
|
||||
AHFEVacuumAnalysisUnit
|
||||
AHFEVacuumSetupUnit
|
||||
AHFEVacuumSimUnit
|
||||
AbsoluteSolvationProtocolResult
|
||||
|
||||
Protocol Settings
|
||||
|
||||
26
news/multi-unit-afe.rst
Normal file
26
news/multi-unit-afe.rst
Normal file
@@ -0,0 +1,26 @@
|
||||
**Added:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Changed:**
|
||||
|
||||
* The absolute free energy protocols have been broken into multiple
|
||||
protocol units, allowing for setup, run, and analysis to happen
|
||||
separately in the future when relevant changes to protocol execution are
|
||||
made (PR #1776).
|
||||
|
||||
**Deprecated:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Removed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Fixed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Security:**
|
||||
|
||||
* <news item>
|
||||
@@ -6,16 +6,24 @@ Run absolute free energy calculations using OpenMM and OpenMMTools.
|
||||
"""
|
||||
|
||||
from .abfe_units import (
|
||||
AbsoluteBindingComplexUnit,
|
||||
AbsoluteBindingSolventUnit,
|
||||
ABFEComplexAnalysisUnit,
|
||||
ABFEComplexSetupUnit,
|
||||
ABFEComplexSimUnit,
|
||||
ABFESolventAnalysisUnit,
|
||||
ABFESolventSetupUnit,
|
||||
ABFESolventSimUnit,
|
||||
)
|
||||
from .afe_protocol_results import (
|
||||
AbsoluteBindingProtocolResult,
|
||||
AbsoluteSolvationProtocolResult,
|
||||
)
|
||||
from .ahfe_units import (
|
||||
AbsoluteSolvationSolventUnit,
|
||||
AbsoluteSolvationVacuumUnit,
|
||||
AHFESolventAnalysisUnit,
|
||||
AHFESolventSetupUnit,
|
||||
AHFESolventSimUnit,
|
||||
AHFEVacuumAnalysisUnit,
|
||||
AHFEVacuumSetupUnit,
|
||||
AHFEVacuumSimUnit,
|
||||
)
|
||||
from .equil_binding_afe_method import (
|
||||
AbsoluteBindingProtocol,
|
||||
@@ -30,11 +38,19 @@ __all__ = [
|
||||
"AbsoluteSolvationProtocol",
|
||||
"AbsoluteSolvationSettings",
|
||||
"AbsoluteSolvationProtocolResult",
|
||||
"AbsoluteVacuumUnit",
|
||||
"AbsoluteSolventUnit",
|
||||
"AHFESolventSetupUnit",
|
||||
"AHFESolventSimUnit",
|
||||
"AHFESolventAnalysisUnit",
|
||||
"AHFEVacuumSetupUnit",
|
||||
"AHFEVacuumSimUnit",
|
||||
"AHFEVacuumAnalysisUnit",
|
||||
"AbsoluteBindingProtocol",
|
||||
"AbsoluteBindingSettings",
|
||||
"AbsoluteBindingProtocolResult",
|
||||
"AbsoluteBindingComplexUnit",
|
||||
"AbsoluteBindingSolventUnit",
|
||||
"ABFEComplexSetupUnit",
|
||||
"ABFEComplexSimUnit",
|
||||
"ABFEComplexAnalysisUnit",
|
||||
"ABFESolventSetupUnit",
|
||||
"ABFESolventSimUnit",
|
||||
"ABFESolventAnalysisUnit",
|
||||
]
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
"""ABFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.abfe_units`
|
||||
========================================================================
|
||||
This module defines the ProtocolUnits for the
|
||||
@@ -23,7 +21,7 @@ from openff.units.openmm import to_openmm
|
||||
from openmm import System
|
||||
from openmm import unit as ommunit
|
||||
from openmm.app import Topology as omm_topology
|
||||
from openmmtools.states import GlobalParameterState, ThermodynamicState
|
||||
from openmmtools.states import ThermodynamicState
|
||||
from rdkit import Chem
|
||||
|
||||
from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
@@ -36,18 +34,16 @@ from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGe
|
||||
from openfe.protocols.restraint_utils.openmm import omm_restraints
|
||||
from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint
|
||||
|
||||
from .base_afe_units import BaseAbsoluteUnit
|
||||
from .base_afe_units import (
|
||||
BaseAbsoluteMultiStateAnalysisUnit,
|
||||
BaseAbsoluteMultiStateSimulationUnit,
|
||||
BaseAbsoluteSetupUnit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
"""
|
||||
Protocol Unit for the complex phase of an absolute binding free energy
|
||||
"""
|
||||
|
||||
simtype = "complex"
|
||||
|
||||
class ComplexComponentsMixin:
|
||||
def _get_components(self):
|
||||
"""
|
||||
Get the relevant components for a complex transformation.
|
||||
@@ -75,7 +71,9 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
# Similarly we don't need to check prot_comp
|
||||
return alchem_comps, solv_comp, prot_comp, off_comps
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
|
||||
class ComplexSettingsMixin:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a complex transformation.
|
||||
|
||||
@@ -97,7 +95,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
* output_settings: MultiStateOutputSettings
|
||||
* restraint_settings: BaseRestraintSettings
|
||||
"""
|
||||
prot_settings = self._inputs["protocol"].settings
|
||||
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
|
||||
|
||||
settings = {}
|
||||
settings["forcefield_settings"] = prot_settings.forcefield_settings
|
||||
@@ -116,6 +114,15 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class ABFEComplexSetupUnit(ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteSetupUnit):
|
||||
"""
|
||||
Setup unit for the complex phase of absolute binding free energy
|
||||
transformations.
|
||||
"""
|
||||
|
||||
simtype = "complex"
|
||||
|
||||
@staticmethod
|
||||
def _get_mda_universe(
|
||||
topology: omm_topology,
|
||||
@@ -261,7 +268,6 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
comp_resids: dict[Component, npt.NDArray],
|
||||
settings: dict[str, SettingsBaseModel],
|
||||
) -> tuple[
|
||||
GlobalParameterState,
|
||||
Quantity,
|
||||
System,
|
||||
geometry.HostGuestRestraintGeometry,
|
||||
@@ -295,9 +301,6 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
|
||||
Returns
|
||||
-------
|
||||
restraint_parameter_state : RestraintParameterState
|
||||
A RestraintParameterState object that defines the control
|
||||
parameter for the restraint.
|
||||
correction : openff.units.Quantity
|
||||
The standard state correction for the restraint.
|
||||
system : openmm.System
|
||||
@@ -380,10 +383,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
rest_geom,
|
||||
)
|
||||
|
||||
# Get the GlobalParameterState for the restraint
|
||||
restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0)
|
||||
return (
|
||||
restraint_parameter_state,
|
||||
correction,
|
||||
# Remove the thermostat, otherwise you'll get an
|
||||
# Andersen thermostat by default!
|
||||
@@ -392,13 +392,28 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit):
|
||||
)
|
||||
|
||||
|
||||
class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
|
||||
class ABFEComplexSimUnit(
|
||||
ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
|
||||
):
|
||||
"""
|
||||
Protocol Unit for the solvent phase of an absolute binding free energy
|
||||
Multi-state simulation (e.g. multi replica methods like Hamiltonian
|
||||
replica exchange) unit for the complex phase of absolute binding
|
||||
free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
simtype = "complex"
|
||||
|
||||
|
||||
class ABFEComplexAnalysisUnit(ComplexSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
|
||||
"""
|
||||
Analysis unit for multi-state simulations with the complex phase
|
||||
of absolute binding free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "complex"
|
||||
|
||||
|
||||
class SolventComponentsMixin:
|
||||
def _get_components(self):
|
||||
"""
|
||||
Get the relevant components for a solvent transformation.
|
||||
@@ -426,7 +441,9 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
|
||||
# Similarly we don't need to check prot_comp just return None
|
||||
return alchem_comps, solv_comp, None, off_comps
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
|
||||
class SolventSettingsMixin:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a solvent transformation.
|
||||
|
||||
@@ -447,7 +464,7 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
|
||||
* simulation_settings : MultiStateSimulationSettings
|
||||
* output_settings: MultiStateOutputSettings
|
||||
"""
|
||||
prot_settings = self._inputs["protocol"].settings
|
||||
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
|
||||
|
||||
settings = {}
|
||||
settings["forcefield_settings"] = prot_settings.forcefield_settings
|
||||
@@ -464,3 +481,33 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit):
|
||||
settings["output_settings"] = prot_settings.solvent_output_settings
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class ABFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit):
|
||||
"""
|
||||
Setup unit for the solvent phase of absolute binding free energy
|
||||
transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
|
||||
class ABFESolventSimUnit(
|
||||
SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
|
||||
):
|
||||
"""
|
||||
Multi-state simulation (e.g. multi replica methods like Hamiltonian
|
||||
replica exchange) unit for the solvent phase of absolute binding
|
||||
free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
|
||||
class ABFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
|
||||
"""
|
||||
Analysis unit for multi-state simulations with the solvent phase
|
||||
of absolute binding free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
@@ -214,8 +214,8 @@ class AbsoluteProtocolResultMixin:
|
||||
for key in [self.bound_state, self.unbound_state]:
|
||||
for pus in self.data[key].values(): # type: ignore[attr-defined]
|
||||
states = get_replica_state(
|
||||
pus[0].outputs["nc"],
|
||||
pus[0].outputs["last_checkpoint"],
|
||||
pus[0].outputs["trajectory"],
|
||||
pus[0].outputs["checkpoint"],
|
||||
)
|
||||
replica_states[key].append(states)
|
||||
|
||||
@@ -295,7 +295,9 @@ class AbsoluteProtocolResultMixin:
|
||||
|
||||
|
||||
class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
|
||||
"""Dict-like container for the output of a AbsoluteSolvationProtocol"""
|
||||
"""
|
||||
Protocol results with the output of a AbsoluteSolvationProtocol
|
||||
"""
|
||||
|
||||
bound_state = "solvent"
|
||||
unbound_state = "vacuum"
|
||||
@@ -375,7 +377,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResul
|
||||
|
||||
|
||||
class AbsoluteBindingProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
|
||||
"""Dict-like container for the output of a AbsoluteBindingProtocol"""
|
||||
"""
|
||||
Protocol results with the output of a AbsoluteBindingProtocol.
|
||||
"""
|
||||
|
||||
bound_state = "complex"
|
||||
unbound_state = "solvent"
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
"""AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units`
|
||||
========================================================================
|
||||
"""
|
||||
AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units`
|
||||
=====================================================================
|
||||
|
||||
This module defines the ProtocolUnits for the
|
||||
:class:`AbsoluteSolvationProtocol`.
|
||||
"""
|
||||
@@ -15,18 +15,16 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
)
|
||||
|
||||
from ..openmm_utils import system_validation
|
||||
from .base_afe_units import BaseAbsoluteUnit
|
||||
from .base_afe_units import (
|
||||
BaseAbsoluteMultiStateAnalysisUnit,
|
||||
BaseAbsoluteMultiStateSimulationUnit,
|
||||
BaseAbsoluteSetupUnit,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
|
||||
"""
|
||||
Protocol Unit for the vacuum phase of an absolute solvation free energy
|
||||
"""
|
||||
|
||||
simtype = "vacuum"
|
||||
|
||||
class VacuumComponentsMixin:
|
||||
def _get_components(self):
|
||||
"""
|
||||
Get the relevant components for a vacuum transformation.
|
||||
@@ -59,7 +57,9 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
|
||||
# (of stateA since we enforce only one disappearing ligand)
|
||||
return alchem_comps, None, prot_comp, off_comps
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
|
||||
class VacuumSettingsMixin:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a vacuum transformation.
|
||||
|
||||
@@ -80,7 +80,7 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
|
||||
* simulation_settings : SimulationSettings
|
||||
* output_settings: MultiStateOutputSettings
|
||||
"""
|
||||
prot_settings = self._inputs["protocol"].settings
|
||||
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
|
||||
|
||||
settings = {}
|
||||
settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings
|
||||
@@ -99,13 +99,37 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
|
||||
return settings
|
||||
|
||||
|
||||
class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
|
||||
class AHFEVacuumSetupUnit(VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteSetupUnit):
|
||||
"""
|
||||
Protocol Unit for the solvent phase of an absolute solvation free energy
|
||||
Setup unit for the vacuum phase of absolute hydration free energy
|
||||
transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
simtype = "vacuum"
|
||||
|
||||
|
||||
class AHFEVacuumSimUnit(
|
||||
VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
|
||||
):
|
||||
"""
|
||||
Multi-state simulation (e.g. multi replica methods like Hamiltonian
|
||||
replica exchange) unit for the vacuum phase of absolute hydration
|
||||
free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "vacuum"
|
||||
|
||||
|
||||
class AHFEVacuumAnalysisUnit(VacuumSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
|
||||
"""
|
||||
Analysis unit for multi-state simulations with the vacuum phase
|
||||
of absolute hydration free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "vacuum"
|
||||
|
||||
|
||||
class SolventComponentsMixin:
|
||||
def _get_components(self):
|
||||
"""
|
||||
Get the relevant components for a solvent transformation.
|
||||
@@ -134,7 +158,9 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
|
||||
# disallowed on create
|
||||
return alchem_comps, solv_comp, prot_comp, off_comps
|
||||
|
||||
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
|
||||
class SolventSettingsMixin:
|
||||
def _get_settings(self) -> dict[str, SettingsBaseModel]:
|
||||
"""
|
||||
Extract the relevant settings for a solvent transformation.
|
||||
|
||||
@@ -155,7 +181,7 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
|
||||
* simulation_settings : MultiStateSimulationSettings
|
||||
* output_settings: MultiStateOutputSettings
|
||||
"""
|
||||
prot_settings = self._inputs["protocol"].settings
|
||||
prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined]
|
||||
|
||||
settings = {}
|
||||
settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings
|
||||
@@ -172,3 +198,33 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
|
||||
settings["output_settings"] = prot_settings.solvent_output_settings
|
||||
|
||||
return settings
|
||||
|
||||
|
||||
class AHFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit):
|
||||
"""
|
||||
Setup unit for the solvent phase of absolute hydration free energy
|
||||
transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
|
||||
class AHFESolventSimUnit(
|
||||
SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
|
||||
):
|
||||
"""
|
||||
Multi-state simulation (e.g. multi replica methods like Hamiltonian
|
||||
replica exchange) unit for the solvent phase of absolute hydration
|
||||
free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
|
||||
class AHFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
|
||||
"""
|
||||
Analysis unit for multi-state simulations with the solvent phase
|
||||
of absolute hydration free energy transformations.
|
||||
"""
|
||||
|
||||
simtype = "solvent"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -55,9 +55,19 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
OpenMMEngineSettings,
|
||||
OpenMMSolvationSettings,
|
||||
)
|
||||
from openfe.protocols.openmm_utils import settings_validation, system_validation
|
||||
from openfe.protocols.openmm_utils import (
|
||||
settings_validation,
|
||||
system_validation,
|
||||
)
|
||||
|
||||
from .abfe_units import AbsoluteBindingComplexUnit, AbsoluteBindingSolventUnit
|
||||
from .abfe_units import (
|
||||
ABFEComplexAnalysisUnit,
|
||||
ABFEComplexSetupUnit,
|
||||
ABFEComplexSimUnit,
|
||||
ABFESolventAnalysisUnit,
|
||||
ABFESolventSetupUnit,
|
||||
ABFESolventSimUnit,
|
||||
)
|
||||
from .afe_protocol_results import AbsoluteBindingProtocolResult
|
||||
|
||||
due.cite(
|
||||
@@ -422,36 +432,58 @@ class AbsoluteBindingProtocol(gufe.Protocol):
|
||||
|
||||
# Get the name of the alchemical species
|
||||
alchname = alchem_comps["stateA"][0].name
|
||||
unit_classes = {
|
||||
"solvent": {
|
||||
"setup": ABFESolventSetupUnit,
|
||||
"simulation": ABFESolventSimUnit,
|
||||
"analysis": ABFESolventAnalysisUnit,
|
||||
},
|
||||
"complex": {
|
||||
"setup": ABFEComplexSetupUnit,
|
||||
"simulation": ABFEComplexSimUnit,
|
||||
"analysis": ABFEComplexAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
# Create list units for complex and solvent transforms
|
||||
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "complex": []}
|
||||
|
||||
solvent_units = [
|
||||
AbsoluteBindingSolventUnit(
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(f"Absolute Binding, {alchname} solvent leg: repeat {i} generation 0"),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
for phase in ["solvent", "complex"]:
|
||||
for i in range(self.settings.protocol_repeats):
|
||||
repeat_id = int(uuid.uuid4())
|
||||
|
||||
complex_units = [
|
||||
AbsoluteBindingComplexUnit(
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(f"Absolute Binding, {alchname} complex leg: repeat {i} generation 0"),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
setup = unit_classes[phase]["setup"](
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"ABFE Setup: {alchname} {phase} leg: repeat {i} generation 0",
|
||||
)
|
||||
|
||||
return solvent_units + complex_units
|
||||
simulation = unit_classes[phase]["simulation"](
|
||||
protocol=self,
|
||||
# only need state A & alchem comps
|
||||
stateA=stateA,
|
||||
alchemical_components=alchem_comps,
|
||||
setup_results=setup,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"ABFE Simulation: {alchname} {phase} leg: repeat {i} generation 0",
|
||||
)
|
||||
|
||||
analysis = unit_classes[phase]["analysis"](
|
||||
protocol=self,
|
||||
setup_results=setup,
|
||||
simulation_results=simulation,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"ABFE Analysis: {alchname} {phase} leg, repeat {i} generation 0",
|
||||
)
|
||||
|
||||
protocol_units[phase] += [setup, simulation, analysis]
|
||||
|
||||
return protocol_units["solvent"] + protocol_units["complex"]
|
||||
|
||||
def _gather(
|
||||
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
|
||||
@@ -463,7 +495,7 @@ class AbsoluteBindingProtocol(gufe.Protocol):
|
||||
for d in protocol_dag_results:
|
||||
pu: gufe.ProtocolUnitResult
|
||||
for pu in d.protocol_unit_results:
|
||||
if not pu.ok():
|
||||
if ("Analysis" not in pu.name) or (not pu.ok()):
|
||||
continue
|
||||
if pu.outputs["simtype"] == "solvent":
|
||||
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
|
||||
|
||||
@@ -63,8 +63,12 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
from ..openmm_utils import settings_validation, system_validation
|
||||
from .afe_protocol_results import AbsoluteSolvationProtocolResult
|
||||
from .ahfe_units import (
|
||||
AbsoluteSolvationSolventUnit,
|
||||
AbsoluteSolvationVacuumUnit,
|
||||
AHFESolventAnalysisUnit,
|
||||
AHFESolventSetupUnit,
|
||||
AHFESolventSimUnit,
|
||||
AHFEVacuumAnalysisUnit,
|
||||
AHFEVacuumSetupUnit,
|
||||
AHFEVacuumSimUnit,
|
||||
)
|
||||
|
||||
due.cite(
|
||||
@@ -445,36 +449,58 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
|
||||
# Get the name of the alchemical species
|
||||
alchname = alchem_comps["stateA"][0].name
|
||||
|
||||
# Create list units for vacuum and solvent transforms
|
||||
solvent_units = [
|
||||
AbsoluteSolvationSolventUnit(
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(f"Absolute Solvation, {alchname} solvent leg: repeat {i} generation 0"),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
unit_classes = {
|
||||
"solvent": {
|
||||
"setup": AHFESolventSetupUnit,
|
||||
"simulation": AHFESolventSimUnit,
|
||||
"analysis": AHFESolventAnalysisUnit,
|
||||
},
|
||||
"vacuum": {
|
||||
"setup": AHFEVacuumSetupUnit,
|
||||
"simulation": AHFEVacuumSimUnit,
|
||||
"analysis": AHFEVacuumAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
vacuum_units = [
|
||||
AbsoluteSolvationVacuumUnit(
|
||||
# These don't really reflect the actual transform
|
||||
# Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ?
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=int(uuid.uuid4()),
|
||||
name=(f"Absolute Solvation, {alchname} vacuum leg: repeat {i} generation 0"),
|
||||
)
|
||||
for i in range(self.settings.protocol_repeats)
|
||||
]
|
||||
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "vacuum": []}
|
||||
|
||||
return solvent_units + vacuum_units
|
||||
for phase in ["solvent", "vacuum"]:
|
||||
for i in range(self.settings.protocol_repeats):
|
||||
repeat_id = int(uuid.uuid4())
|
||||
|
||||
setup = unit_classes[phase]["setup"](
|
||||
protocol=self,
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
alchemical_components=alchem_comps,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"AHFE Setup: {alchname} {phase} leg: repeat {i} generation 0",
|
||||
)
|
||||
|
||||
simulation = unit_classes[phase]["simulation"](
|
||||
protocol=self,
|
||||
# only need state A & alchem comps
|
||||
stateA=stateA,
|
||||
alchemical_components=alchem_comps,
|
||||
setup_results=setup,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"AHFE Simulation: {alchname} {phase} leg: repeat {i} generation 0",
|
||||
)
|
||||
|
||||
analysis = unit_classes[phase]["analysis"](
|
||||
protocol=self,
|
||||
setup_results=setup,
|
||||
simulation_results=simulation,
|
||||
generation=0,
|
||||
repeat_id=repeat_id,
|
||||
name=f"AHFE Analysis: {alchname} {phase} leg, repeat {i} generation 0",
|
||||
)
|
||||
|
||||
protocol_units[phase] += [setup, simulation, analysis]
|
||||
|
||||
return protocol_units["solvent"] + protocol_units["vacuum"]
|
||||
|
||||
def _gather(
|
||||
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
|
||||
@@ -486,7 +512,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
|
||||
for d in protocol_dag_results:
|
||||
pu: gufe.ProtocolUnitResult
|
||||
for pu in d.protocol_unit_results:
|
||||
if not pu.ok():
|
||||
if ("Analysis" not in pu.name) or (not pu.ok()):
|
||||
continue
|
||||
if pu.outputs["simtype"] == "solvent":
|
||||
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
import os
|
||||
import pathlib
|
||||
|
||||
from gufe.settings.typing import NanometerArrayQuantity
|
||||
from openff.units import Quantity
|
||||
from openmm import Vec3
|
||||
from openmm import unit as ommunit
|
||||
|
||||
|
||||
def serialize(item, filename: pathlib.Path):
|
||||
"""
|
||||
@@ -62,3 +67,23 @@ def deserialize(filename: pathlib.Path):
|
||||
item = XmlSerializer.deserialize(serialized_thing)
|
||||
|
||||
return item
|
||||
|
||||
|
||||
def make_vec3_box(dimensions: NanometerArrayQuantity) -> Vec3:
|
||||
"""
|
||||
Convert an OpenFF box dimensions Quantity back into Vec3 format.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dimensions : gufe.settings.typing.NanometerArrayQuantity
|
||||
United array to turn to Vec3 format.
|
||||
|
||||
Returns
|
||||
-------
|
||||
openmm.Vec3
|
||||
The input array in Vec3 format.
|
||||
"""
|
||||
return [
|
||||
Vec3(float(row[0]), float(row[1]), float(row[2])) * ommunit.nanometer
|
||||
for row in dimensions.m_as("nanometer")
|
||||
]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
import openmm
|
||||
import pooch
|
||||
import pytest
|
||||
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
|
||||
from openff.units import Quantity, unit
|
||||
from openff.units.openmm import from_openmm
|
||||
from openmm import Platform
|
||||
@@ -326,6 +327,20 @@ def get_available_openmm_platforms() -> set[str]:
|
||||
return working_platforms
|
||||
|
||||
|
||||
class ModGufeTokenizableTestsMixin(GufeTokenizableTestsMixin):
|
||||
"""
|
||||
A modified gufe tokenizable tests mixin which allows
|
||||
for repr to be lazily evaluated.
|
||||
"""
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
def compute_energy(
|
||||
system: openmm.System,
|
||||
positions: openmm.unit.Quantity,
|
||||
|
||||
@@ -17,8 +17,8 @@ from openmmtools.alchemy import (
|
||||
)
|
||||
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteBindingComplexUnit,
|
||||
from openfe.protocols.openmm_afe.abfe_units import (
|
||||
ABFEComplexSetupUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings
|
||||
from openfe.protocols.openmm_utils.serialization import deserialize
|
||||
@@ -147,11 +147,11 @@ class TestT4EnergiesRegression:
|
||||
|
||||
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
|
||||
complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
|
||||
complex_units = [u for u in dag.protocol_units if isinstance(u, ABFEComplexSetupUnit)]
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
data = complex_units[0].run(dry=True)["debug"]
|
||||
return data
|
||||
results = complex_units[0].run(dry=True)
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def get_energy_components(
|
||||
@@ -182,7 +182,7 @@ class TestT4EnergiesRegression:
|
||||
energies_ref = self.get_energy_components(
|
||||
t4_reference_system,
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_val,
|
||||
lambda_val,
|
||||
lambda_val,
|
||||
@@ -191,7 +191,7 @@ class TestT4EnergiesRegression:
|
||||
energies_val = self.get_energy_components(
|
||||
t4_validation_data["alchem_system"],
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_val,
|
||||
lambda_val,
|
||||
lambda_val,
|
||||
@@ -223,7 +223,7 @@ class TestT4EnergiesRegression:
|
||||
energies = self.get_energy_components(
|
||||
t4_validation_data["alchem_system"],
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_sterics=1.0,
|
||||
lambda_electrostatics=1.0,
|
||||
lambda_restraints=1.0,
|
||||
@@ -236,7 +236,7 @@ class TestT4EnergiesRegression:
|
||||
energies = self.get_energy_components(
|
||||
t4_validation_data["alchem_system"],
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_sterics=1.0,
|
||||
lambda_electrostatics=1.0,
|
||||
lambda_restraints=0.0,
|
||||
@@ -249,7 +249,7 @@ class TestT4EnergiesRegression:
|
||||
energies = self.get_energy_components(
|
||||
t4_validation_data["alchem_system"],
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_sterics=1.0,
|
||||
lambda_electrostatics=0.0,
|
||||
lambda_restraints=0.0,
|
||||
@@ -270,7 +270,7 @@ class TestT4EnergiesRegression:
|
||||
energies = self.get_energy_components(
|
||||
t4_validation_data["alchem_system"],
|
||||
t4_validation_data["alchem_indices"],
|
||||
t4_validation_data["positions"],
|
||||
t4_validation_data["debug_positions"],
|
||||
lambda_sterics=0.0,
|
||||
lambda_electrostatics=0.0,
|
||||
lambda_restraints=0.0,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
from importlib import resources
|
||||
from math import sqrt
|
||||
from unittest import mock
|
||||
|
||||
@@ -23,9 +22,7 @@ from openmm import (
|
||||
)
|
||||
from openmm import unit as ommunit
|
||||
from openmmtools.alchemy import (
|
||||
AbsoluteAlchemicalFactory,
|
||||
AlchemicalRegion,
|
||||
AlchemicalState,
|
||||
)
|
||||
from openmmtools.multistate.multistatesampler import MultiStateSampler
|
||||
from openmmtools.tests.test_alchemy import (
|
||||
@@ -38,11 +35,16 @@ import openfe
|
||||
from openfe import ChemicalSystem, SmallMoleculeComponent, SolventComponent
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteBindingComplexUnit,
|
||||
AbsoluteBindingProtocol,
|
||||
AbsoluteBindingSolventUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings
|
||||
from openfe.protocols.openmm_afe.abfe_units import (
|
||||
ABFEComplexSetupUnit,
|
||||
ABFEComplexSimUnit,
|
||||
ABFESolventSetupUnit,
|
||||
ABFESolventSimUnit,
|
||||
)
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
@@ -75,37 +77,53 @@ def test_serialize_protocol(default_settings):
|
||||
assert protocol == ret
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_complex_dag, tmpdir):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
def test_repeat_units(benzene_modifications, T4_protein_component):
|
||||
protocol = openmm_afe.AbsoluteBindingProtocol(
|
||||
settings=openmm_afe.AbsoluteBindingProtocol.default_settings()
|
||||
)
|
||||
|
||||
dag_units = benzene_complex_dag.protocol_units
|
||||
stateA = gufe.ChemicalSystem(
|
||||
{
|
||||
"protein": T4_protein_component,
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"solvent": gufe.SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
):
|
||||
results = []
|
||||
for u in dag_units:
|
||||
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
|
||||
results.append(ret)
|
||||
stateB = gufe.ChemicalSystem(
|
||||
{
|
||||
"protein": T4_protein_component,
|
||||
"solvent": gufe.SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
solv_repeats = set()
|
||||
complex_repeats = set()
|
||||
for ret in results:
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
if ret.outputs["simtype"] == "complex":
|
||||
complex_repeats.add(ret.outputs["repeat_id"])
|
||||
else:
|
||||
solv_repeats.add(ret.outputs["repeat_id"])
|
||||
# Repeat ids are random ints so just check their lengths
|
||||
assert len(complex_repeats) == len(solv_repeats) == 3
|
||||
dag = protocol.create(
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
# 6 protocol unit, 3 per repeat
|
||||
pus = list(dag.protocol_units)
|
||||
assert len(pus) == 18
|
||||
|
||||
# Check info for each repeat
|
||||
for phase in ["solvent", "complex"]:
|
||||
setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
|
||||
sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
|
||||
analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
# Should be 3 of each set
|
||||
assert len(setup) == len(sim) == len(analysis) == 3
|
||||
|
||||
# Check that the dag chain is correct
|
||||
for analysis_pu in analysis:
|
||||
repeat_id = analysis_pu.inputs["repeat_id"]
|
||||
setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0]
|
||||
sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0]
|
||||
assert analysis_pu.inputs["setup_results"] == setup_pu
|
||||
assert analysis_pu.inputs["simulation_results"] == sim_pu
|
||||
assert sim_pu.inputs["setup_results"] == setup_pu
|
||||
|
||||
|
||||
def test_create_independent_repeat_ids(benzene_modifications, T4_protein_component):
|
||||
@@ -137,9 +155,12 @@ def test_create_independent_repeat_ids(benzene_modifications, T4_protein_compone
|
||||
repeat_ids = set()
|
||||
|
||||
for dag in dags:
|
||||
# 3 sets of 6 units
|
||||
assert len(list(dag.protocol_units)) == 18
|
||||
for u in dag.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
|
||||
# squashed by repeat_id, that's 2 sets of 6
|
||||
assert len(repeat_ids) == 12
|
||||
|
||||
|
||||
@@ -149,7 +170,7 @@ def test_mda_universe_error():
|
||||
when calling the mda Universe getter.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="No positions to create"):
|
||||
_ = openmm_afe.AbsoluteBindingComplexUnit._get_mda_universe(
|
||||
_ = openmm_afe.ABFEComplexSetupUnit._get_mda_universe(
|
||||
topology="foo", positions=None, trajectory=None
|
||||
)
|
||||
|
||||
@@ -201,17 +222,32 @@ class TestT4LysozymeDryRun:
|
||||
)
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def complex_units(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
|
||||
def complex_setup_units(self, dag):
|
||||
return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"])
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def solvent_units(self, dag):
|
||||
return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingSolventUnit)]
|
||||
def complex_sim_units(self, dag):
|
||||
return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["sim"])
|
||||
|
||||
def test_number_of_units(self, dag, complex_units, solvent_units):
|
||||
assert len(list(dag.protocol_units)) == 2
|
||||
assert len(complex_units) == 1
|
||||
assert len(solvent_units) == 1
|
||||
@pytest.fixture(scope="class")
|
||||
def solvent_setup_units(self, dag):
|
||||
return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["setup"])
|
||||
|
||||
@pytest.fixture(scope="class")
|
||||
def solvent_sim_units(self, dag):
|
||||
return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["sim"])
|
||||
|
||||
def test_number_of_units(
|
||||
self,
|
||||
dag,
|
||||
complex_setup_units,
|
||||
complex_sim_units,
|
||||
solvent_setup_units,
|
||||
solvent_sim_units,
|
||||
):
|
||||
assert len(list(dag.protocol_units)) == 6
|
||||
assert len(complex_setup_units) == len(complex_sim_units) == 1
|
||||
assert len(solvent_setup_units) == len(solvent_sim_units) == 1
|
||||
|
||||
def _assert_force_num(self, system, forcetype, number):
|
||||
forces = [f for f in system.getForces() if isinstance(f, forcetype)]
|
||||
@@ -356,83 +392,99 @@ class TestT4LysozymeDryRun:
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
def test_complex_dry_run(self, complex_units, settings, tmpdir):
|
||||
def test_complex_dry_run(self, complex_setup_units, complex_sim_units, settings, tmpdir):
|
||||
with tmpdir.as_cwd():
|
||||
data = complex_units[0].run(dry=True, verbose=True)["debug"]
|
||||
setup_results = complex_setup_units[0].run(dry=True, verbose=True)
|
||||
sim_results = complex_sim_units[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=True,
|
||||
dry=True,
|
||||
)
|
||||
|
||||
# Check the sampler
|
||||
self._verify_sampler(data["sampler"], complexed=True, settings=settings)
|
||||
self._verify_sampler(sim_results["sampler"], complexed=True, settings=settings)
|
||||
|
||||
# Check the alchemical system
|
||||
self._assert_expected_alchemical_forces(
|
||||
data["alchem_system"], complexed=True, settings=settings
|
||||
setup_results["alchem_system"], complexed=True, settings=settings
|
||||
)
|
||||
self._test_dodecahedron_vectors(data["alchem_system"])
|
||||
self._test_dodecahedron_vectors(setup_results["alchem_system"])
|
||||
|
||||
# Check the alchemical indices
|
||||
expected_indices = [i + self.num_complex_atoms for i in range(self.num_solvent_atoms)]
|
||||
assert expected_indices == data["alchem_indices"]
|
||||
assert expected_indices == setup_results["alchem_indices"]
|
||||
|
||||
# Check the non-alchemical system
|
||||
self._assert_expected_nonalchemical_forces(data["system"], settings)
|
||||
self._test_dodecahedron_vectors(data["system"])
|
||||
self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings)
|
||||
self._test_dodecahedron_vectors(setup_results["standard_system"])
|
||||
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
|
||||
assert_allclose(
|
||||
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(data["system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()),
|
||||
)
|
||||
|
||||
# Check the PDB
|
||||
pdb = mdt.load_pdb("alchemical_system.pdb")
|
||||
pdb = mdt.load_pdb(setup_results["pdb_structure"])
|
||||
assert pdb.n_atoms == self.num_all_not_water
|
||||
|
||||
# Check energies
|
||||
alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"])
|
||||
alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"])
|
||||
self._test_energies(
|
||||
reference_system=data["system"],
|
||||
alchemical_system=data["alchem_system"],
|
||||
reference_system=setup_results["standard_system"],
|
||||
alchemical_system=setup_results["alchem_system"],
|
||||
alchemical_regions=alchem_region,
|
||||
positions=data["positions"],
|
||||
positions=setup_results["debug_positions"],
|
||||
)
|
||||
|
||||
def test_solvent_dry_run(self, solvent_units, settings, tmpdir):
|
||||
def test_solvent_dry_run(self, solvent_setup_units, solvent_sim_units, settings, tmpdir):
|
||||
with tmpdir.as_cwd():
|
||||
data = solvent_units[0].run(dry=True, verbose=True)["debug"]
|
||||
setup_results = solvent_setup_units[0].run(dry=True, verbose=True)
|
||||
sim_results = solvent_sim_units[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
|
||||
# Check the sampler
|
||||
self._verify_sampler(data["sampler"], complexed=False, settings=settings)
|
||||
self._verify_sampler(sim_results["sampler"], complexed=False, settings=settings)
|
||||
|
||||
# Check the alchemical system
|
||||
self._assert_expected_alchemical_forces(
|
||||
data["alchem_system"], complexed=False, settings=settings
|
||||
setup_results["alchem_system"], complexed=False, settings=settings
|
||||
)
|
||||
self._test_cubic_vectors(data["alchem_system"])
|
||||
self._test_cubic_vectors(setup_results["alchem_system"])
|
||||
|
||||
# Check the alchemical indices
|
||||
expected_indices = [i for i in range(self.num_solvent_atoms)]
|
||||
assert expected_indices == data["alchem_indices"]
|
||||
assert expected_indices == setup_results["alchem_indices"]
|
||||
|
||||
# Check the non-alchemical system
|
||||
self._assert_expected_nonalchemical_forces(data["system"], settings)
|
||||
self._test_cubic_vectors(data["system"])
|
||||
self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings)
|
||||
self._test_cubic_vectors(setup_results["standard_system"])
|
||||
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
|
||||
assert_allclose(
|
||||
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(data["system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()),
|
||||
from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()),
|
||||
)
|
||||
|
||||
# Check the PDB
|
||||
pdb = mdt.load_pdb("alchemical_system.pdb")
|
||||
pdb = mdt.load_pdb(setup_results["pdb_structure"])
|
||||
assert pdb.n_atoms == self.num_solvent_atoms
|
||||
|
||||
# Check energies
|
||||
alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"])
|
||||
alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"])
|
||||
|
||||
self._test_energies(
|
||||
reference_system=data["system"],
|
||||
alchemical_system=data["alchem_system"],
|
||||
reference_system=setup_results["standard_system"],
|
||||
alchemical_system=setup_results["alchem_system"],
|
||||
alchemical_regions=alchem_region,
|
||||
positions=data["positions"],
|
||||
positions=setup_results["debug_positions"],
|
||||
)
|
||||
|
||||
|
||||
@@ -510,15 +562,17 @@ def test_user_charges(benzene_modifications, T4_protein_component, tmpdir):
|
||||
|
||||
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
|
||||
complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)]
|
||||
complex_setup_units = _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
data = complex_units[0].run(dry=True)["debug"]
|
||||
results = complex_setup_units[0].run(dry=True)
|
||||
|
||||
system_nbf = [f for f in data["system"].getForces() if isinstance(f, NonbondedForce)][0]
|
||||
system_nbf = [
|
||||
f for f in results["standard_system"].getForces() if isinstance(f, NonbondedForce)
|
||||
][0]
|
||||
alchem_system_nbf = [
|
||||
f
|
||||
for f in data["alchem_system"].getForces()
|
||||
for f in results["alchem_system"].getForces()
|
||||
if isinstance(f, NonbondedForce)
|
||||
][0] # fmt: skip
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import gzip
|
||||
import itertools
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import gufe
|
||||
@@ -14,25 +15,78 @@ import openfe
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
def test_gather(benzene_complex_dag, tmpdir):
|
||||
# check that .gather behaves as expected
|
||||
|
||||
@pytest.fixture()
|
||||
def patcher():
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFESolventSetupUnit.run",
|
||||
return_value={
|
||||
"system": Path("system.xml.bz2"),
|
||||
"positions": Path("positions.npy"),
|
||||
"pdb_structure": Path("hybrid_system.pdb"),
|
||||
"selection_indices": np.zeros(100),
|
||||
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
|
||||
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
|
||||
"restraint_geometry": None,
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexSetupUnit.run",
|
||||
return_value={
|
||||
"system": Path("system.xml.bz2"),
|
||||
"positions": Path("positions.npy"),
|
||||
"pdb_structure": Path("hybrid_system.pdb"),
|
||||
"selection_indices": np.zeros(100),
|
||||
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
|
||||
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
|
||||
"restraint_geometry": True,
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.base_afe_units.np.load",
|
||||
return_value=np.zeros(100),
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.base_afe_units.deserialize",
|
||||
return_value="foo",
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexSimUnit.run",
|
||||
return_value={
|
||||
"trajectory": Path("file.nc"),
|
||||
"checkpoint": Path("chk.chk"),
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFESolventSimUnit.run",
|
||||
return_value={
|
||||
"trajectory": Path("file.nc"),
|
||||
"checkpoint": Path("chk.chk"),
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFEComplexAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.abfe_units.ABFESolventAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
):
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_complex_dag,
|
||||
shared_basedir=tmpdir,
|
||||
scratch_basedir=tmpdir,
|
||||
keep_shared=True,
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
def test_gather(benzene_complex_dag, patcher, tmpdir):
|
||||
# check that .gather behaves as expected
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_complex_dag,
|
||||
shared_basedir=tmpdir,
|
||||
scratch_basedir=tmpdir,
|
||||
keep_shared=True,
|
||||
)
|
||||
|
||||
protocol = openmm_afe.AbsoluteBindingProtocol(
|
||||
settings=openmm_afe.AbsoluteBindingProtocol.default_settings(),
|
||||
@@ -43,6 +97,47 @@ def test_gather(benzene_complex_dag, tmpdir):
|
||||
assert isinstance(res, openmm_afe.AbsoluteBindingProtocolResult)
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_complex_dag, patcher, tmpdir):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
|
||||
dag_units = benzene_complex_dag.protocol_units
|
||||
|
||||
for phase in ["solvent", "complex"]:
|
||||
setup_results = {}
|
||||
sim_results = {}
|
||||
analysis_results = {}
|
||||
|
||||
setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"])
|
||||
sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"])
|
||||
a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
for u in setup_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir))
|
||||
|
||||
for u in sim_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
sim_results[rid] = u.execute(
|
||||
context=gufe.Context(tmpdir, tmpdir),
|
||||
setup_results=setup_results[rid],
|
||||
)
|
||||
|
||||
for u in a_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
analysis_results[rid] = u.execute(
|
||||
context=gufe.Context(tmpdir, tmpdir),
|
||||
setup_results=setup_results[rid],
|
||||
simulation_results=sim_results[rid],
|
||||
)
|
||||
|
||||
for results in [setup_results, sim_results, analysis_results]:
|
||||
for ret in results.values():
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
|
||||
assert len(setup_results) == len(sim_results) == len(analysis_results) == 3
|
||||
|
||||
|
||||
class TestProtocolResult:
|
||||
@pytest.fixture()
|
||||
def protocolresult(self, abfe_transformation_json_path):
|
||||
@@ -62,7 +157,7 @@ class TestProtocolResult:
|
||||
est = protocolresult.get_estimate()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(-21.71, abs=0.01)
|
||||
assert est.m == pytest.approx(-21.35, abs=0.01)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
@@ -70,7 +165,7 @@ class TestProtocolResult:
|
||||
est = protocolresult.get_uncertainty()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(0.73, abs=0.01)
|
||||
assert est.m == pytest.approx(1.04, abs=0.01)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
@@ -176,12 +271,12 @@ class TestProtocolResult:
|
||||
assert isinstance(geom[0], BoreschRestraintGeometry)
|
||||
assert geom[0].guest_atoms == [1779, 1778, 1777]
|
||||
assert geom[0].host_atoms == [880, 865, 864]
|
||||
assert pytest.approx(geom[0].r_aA0) == 1.083558 * offunit.nanometer
|
||||
assert pytest.approx(geom[0].theta_A0) == 0.6786444 * offunit.radian
|
||||
assert pytest.approx(geom[0].theta_B0) == 1.649905 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_A0) == -0.3640583 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_B0) == 1.892376 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_C0) == -0.6106747 * offunit.radian
|
||||
assert pytest.approx(geom[0].r_aA0, rel=1e-2) == 1.083558 * offunit.nanometer
|
||||
assert pytest.approx(geom[0].theta_A0, rel=1e-2) == 0.711876 * offunit.radian
|
||||
assert pytest.approx(geom[0].theta_B0, rel=1e-2) == 1.687366 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_A0, rel=1e-2) == -0.2164231 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_B0, rel=1e-2) == 1.892376 * offunit.radian
|
||||
assert pytest.approx(geom[0].phi_C0, rel=1e-2) == -0.522031870 * offunit.radian
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"key, expected_size",
|
||||
|
||||
@@ -96,16 +96,36 @@ def test_openmm_run_engine(
|
||||
r = openfe.execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True)
|
||||
|
||||
assert r.ok()
|
||||
for pur in r.protocol_unit_results:
|
||||
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
checkpoint = pur.outputs["last_checkpoint"]
|
||||
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
assert (unit_shared / checkpoint).exists()
|
||||
nc = pur.outputs["nc"]
|
||||
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
|
||||
assert nc.exists()
|
||||
|
||||
# Check outputs of solvent & complex results
|
||||
for phase in ["solvent", "complex"]:
|
||||
purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase]
|
||||
|
||||
# get the path to the simulation unit shared dict
|
||||
for pur in purs:
|
||||
if "Simulation" in pur.name:
|
||||
sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert sim_shared.exists()
|
||||
assert pathlib.Path(sim_shared).is_dir()
|
||||
|
||||
# check the analysis outputs
|
||||
for pur in purs:
|
||||
if "Analysis" not in pur.name:
|
||||
continue
|
||||
|
||||
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
|
||||
# Does the checkpoint file exist?
|
||||
checkpoint = pur.outputs["checkpoint"]
|
||||
assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
assert checkpoint.exists()
|
||||
|
||||
# Does the trajectory file exist?
|
||||
nc = pur.outputs["trajectory"]
|
||||
assert nc == sim_shared / f"{pur.outputs['simtype']}.nc"
|
||||
assert nc.exists()
|
||||
|
||||
# Test results methods that need files present
|
||||
results = protocol.gather([r])
|
||||
|
||||
@@ -3,16 +3,21 @@
|
||||
import gzip
|
||||
|
||||
import pytest
|
||||
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
|
||||
|
||||
import openfe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteBindingComplexUnit,
|
||||
ABFEComplexAnalysisUnit,
|
||||
ABFEComplexSetupUnit,
|
||||
ABFEComplexSimUnit,
|
||||
ABFESolventAnalysisUnit,
|
||||
ABFESolventSetupUnit,
|
||||
ABFESolventSimUnit,
|
||||
AbsoluteBindingProtocol,
|
||||
AbsoluteBindingProtocolResult,
|
||||
AbsoluteBindingSolventUnit,
|
||||
)
|
||||
|
||||
from ..conftest import ModGufeTokenizableTestsMixin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol():
|
||||
@@ -33,18 +38,40 @@ def protocol_units(protocol, benzene_complex_system, T4_protein_component):
|
||||
return list(pus.protocol_units)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, AbsoluteBindingSolventUnit):
|
||||
def _filter_units(pus, classtype):
|
||||
for pu in pus:
|
||||
if isinstance(pu, classtype):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, AbsoluteBindingComplexUnit):
|
||||
return pu
|
||||
def complex_protocol_setup_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFEComplexSetupUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_protocol_sim_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFEComplexSimUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def complex_protocol_analysis_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFEComplexAnalysisUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_setup_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFESolventSetupUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_sim_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFESolventSimUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_analysis_unit(protocol_units):
|
||||
return _filter_units(protocol_units, ABFESolventAnalysisUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -54,7 +81,7 @@ def protocol_result(abfe_transformation_json_path):
|
||||
return pr
|
||||
|
||||
|
||||
class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin):
|
||||
class TestAbsoluteBindingProtocol(ModGufeTokenizableTestsMixin):
|
||||
cls = AbsoluteBindingProtocol
|
||||
key = None
|
||||
repr = "AbsoluteBindingProtocol-"
|
||||
@@ -63,49 +90,68 @@ class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin):
|
||||
def instance(self, protocol):
|
||||
return protocol
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestAbsoluteBindingSolventUnit(GufeTokenizableTestsMixin):
|
||||
cls = AbsoluteBindingSolventUnit
|
||||
repr = "AbsoluteBindingSolventUnit(Absolute Binding, benzene solvent leg"
|
||||
class TestABFESolventSetupUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFESolventSetupUnit
|
||||
repr = "ABFESolventSetupUnit(ABFE Setup: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, solvent_protocol_unit):
|
||||
return solvent_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
def instance(self, solvent_protocol_setup_unit):
|
||||
return solvent_protocol_setup_unit
|
||||
|
||||
|
||||
class TestAbsoluteBindingComplexUnit(GufeTokenizableTestsMixin):
|
||||
cls = AbsoluteBindingComplexUnit
|
||||
repr = "AbsoluteBindingComplexUnit(Absolute Binding, benzene complex leg"
|
||||
class TestABFESolventSimUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFESolventSimUnit
|
||||
repr = "ABFESolventSimUnit(ABFE Simulation: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, complex_protocol_unit):
|
||||
return complex_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
def instance(self, solvent_protocol_sim_unit):
|
||||
return solvent_protocol_sim_unit
|
||||
|
||||
|
||||
class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin):
|
||||
class TestABFESolventAnalysisUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFESolventAnalysisUnit
|
||||
repr = "ABFESolventAnalysisUnit(ABFE Analysis: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, solvent_protocol_analysis_unit):
|
||||
return solvent_protocol_analysis_unit
|
||||
|
||||
|
||||
class TestABFEComplexSetupUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFEComplexSetupUnit
|
||||
repr = "ABFEComplexSetupUnit(ABFE Setup: benzene complex leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, complex_protocol_setup_unit):
|
||||
return complex_protocol_setup_unit
|
||||
|
||||
|
||||
class TestABFEComplexSimUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFEComplexSimUnit
|
||||
repr = "ABFEComplexSimUnit(ABFE Simulation: benzene complex leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, complex_protocol_sim_unit):
|
||||
return complex_protocol_sim_unit
|
||||
|
||||
|
||||
class TestABFEComplexAnalysisUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = ABFEComplexAnalysisUnit
|
||||
repr = "ABFEComplexAnalysisUnit(ABFE Analysis: benzene complex leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, complex_protocol_analysis_unit):
|
||||
return complex_protocol_analysis_unit
|
||||
|
||||
|
||||
class TestAbsoluteBindingProtocolResult(ModGufeTokenizableTestsMixin):
|
||||
cls = AbsoluteBindingProtocolResult
|
||||
key = None
|
||||
repr = "AbsoluteBindingProtocolResult-"
|
||||
@@ -113,10 +159,3 @@ class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin):
|
||||
@pytest.fixture()
|
||||
def instance(self, protocol_result):
|
||||
return protocol_result
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
30
openfe/tests/protocols/openmm_abfe/utils.py
Normal file
30
openfe/tests/protocols/openmm_abfe/utils.py
Normal file
@@ -0,0 +1,30 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
from openfe.protocols.openmm_afe.abfe_units import (
|
||||
ABFEComplexAnalysisUnit,
|
||||
ABFEComplexSetupUnit,
|
||||
ABFEComplexSimUnit,
|
||||
ABFESolventAnalysisUnit,
|
||||
ABFESolventSetupUnit,
|
||||
ABFESolventSimUnit,
|
||||
)
|
||||
|
||||
UNIT_TYPES = {
|
||||
"solvent": {
|
||||
"setup": ABFESolventSetupUnit,
|
||||
"sim": ABFESolventSimUnit,
|
||||
"analysis": ABFESolventAnalysisUnit,
|
||||
},
|
||||
"complex": {
|
||||
"setup": ABFEComplexSetupUnit,
|
||||
"sim": ABFEComplexSimUnit,
|
||||
"analysis": ABFEComplexAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_units(protocol_units, unit_type):
|
||||
"""
|
||||
Helper method to extract setup units.
|
||||
"""
|
||||
return [pu for pu in protocol_units if isinstance(pu, unit_type)]
|
||||
@@ -1,12 +1,9 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import itertools
|
||||
import json
|
||||
import sys
|
||||
from math import sqrt
|
||||
from unittest import mock
|
||||
|
||||
import gufe
|
||||
import mdtraj as mdt
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -18,7 +15,6 @@ from openmm import (
|
||||
CustomNonbondedForce,
|
||||
HarmonicAngleForce,
|
||||
HarmonicBondForce,
|
||||
MonteCarloBarostat,
|
||||
NonbondedForce,
|
||||
PeriodicTorsionForce,
|
||||
)
|
||||
@@ -29,16 +25,15 @@ from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteSolvationProtocol,
|
||||
AbsoluteSolvationSolventUnit,
|
||||
AbsoluteSolvationVacuumUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils import system_validation
|
||||
from openfe.protocols.openmm_utils.charge_generation import (
|
||||
HAS_ESPALOMA_CHARGE,
|
||||
HAS_NAGL,
|
||||
HAS_OPENEYE,
|
||||
)
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def protocol_dry_settings():
|
||||
@@ -68,6 +63,40 @@ def test_serialize_protocol(default_settings):
|
||||
assert protocol == ret
|
||||
|
||||
|
||||
def test_repeat_units(benzene_system):
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(
|
||||
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings()
|
||||
)
|
||||
|
||||
dag = protocol.create(
|
||||
stateA=benzene_system,
|
||||
stateB=ChemicalSystem({"solvent": SolventComponent()}),
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
# 6 protocol unit, 3 per repeat
|
||||
pus = list(dag.protocol_units)
|
||||
assert len(pus) == 18
|
||||
|
||||
# Check info for each repeat
|
||||
for phase in ["solvent", "vacuum"]:
|
||||
setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
|
||||
sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
|
||||
analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
# Should be 3 of each set
|
||||
assert len(setup) == len(sim) == len(analysis) == 3
|
||||
|
||||
# Check that the dag chain is correct
|
||||
for analysis_pu in analysis:
|
||||
repeat_id = analysis_pu.inputs["repeat_id"]
|
||||
setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0]
|
||||
sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0]
|
||||
assert analysis_pu.inputs["setup_results"] == setup_pu
|
||||
assert analysis_pu.inputs["simulation_results"] == sim_pu
|
||||
assert sim_pu.inputs["setup_results"] == setup_pu
|
||||
|
||||
|
||||
def test_create_independent_repeat_ids(benzene_system):
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(
|
||||
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings()
|
||||
@@ -88,9 +117,12 @@ def test_create_independent_repeat_ids(benzene_system):
|
||||
repeat_ids = set()
|
||||
|
||||
for dag in dags:
|
||||
# 3 sets of 6 units
|
||||
assert len(list(dag.protocol_units)) == 18
|
||||
for u in dag.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
|
||||
# squashed by repeat_id, that's 2 sets of 6
|
||||
assert len(repeat_ids) == 12
|
||||
|
||||
|
||||
@@ -143,7 +175,7 @@ def _verify_alchemical_sterics_force_parameters(
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"])
|
||||
def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir):
|
||||
def test_setup_dry_sim_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir):
|
||||
protocol_dry_settings.vacuum_simulation_settings.sampler_method = method
|
||||
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
|
||||
@@ -161,21 +193,32 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 2
|
||||
assert len(prot_units) == 6
|
||||
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
|
||||
vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"])
|
||||
|
||||
assert len(vac_unit) == 1
|
||||
assert len(sol_unit) == 1
|
||||
assert len(vac_setup_unit) == 1
|
||||
assert len(vac_sim_unit) == 1
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
debug = vac_unit[0].run(dry=True)["debug"]
|
||||
vac_sampler = debug["sampler"]
|
||||
assert not vac_sampler.is_periodic
|
||||
setup_results = vac_setup_unit[0].run(dry=True)
|
||||
sim_results = vac_sim_unit[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
|
||||
sampler = sim_results["sampler"]
|
||||
assert isinstance(sampler, MultiStateSampler)
|
||||
assert not sampler.is_periodic
|
||||
assert sampler._thermodynamic_states[0].barostat is None
|
||||
|
||||
# standard system
|
||||
system = debug["system"]
|
||||
system = setup_results["standard_system"]
|
||||
assert system.getNumParticles() == 12
|
||||
assert len(system.getForces()) == 4
|
||||
_assert_num_forces(system, NonbondedForce, 1)
|
||||
@@ -184,7 +227,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
|
||||
_assert_num_forces(system, PeriodicTorsionForce, 1)
|
||||
|
||||
# alchemical system
|
||||
alchem_system = debug["alchem_system"]
|
||||
alchem_system = setup_results["alchem_system"]
|
||||
assert alchem_system.getNumParticles() == 12
|
||||
assert len(alchem_system.getForces()) == 12
|
||||
_assert_num_forces(alchem_system, NonbondedForce, 1)
|
||||
@@ -212,7 +255,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd
|
||||
[0.35, 2.2, 1.5, 0, False],
|
||||
],
|
||||
)
|
||||
def test_alchemical_settings_dry_run_vacuum(
|
||||
def test_alchemical_settings_setup_vacuum(
|
||||
alpha, a, b, c, correction, benzene_system, protocol_dry_settings, tmpdir
|
||||
):
|
||||
"""
|
||||
@@ -238,18 +281,18 @@ def test_alchemical_settings_dry_run_vacuum(
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 2
|
||||
assert len(prot_units) == 6
|
||||
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
|
||||
vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"])
|
||||
|
||||
assert len(vac_unit) == 1
|
||||
assert len(sol_unit) == 1
|
||||
assert len(vac_setup_unit) == 1
|
||||
assert len(vac_sim_unit) == 1
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
debug = vac_unit[0].run(dry=True)["debug"]
|
||||
results = vac_setup_unit[0].run(dry=True)
|
||||
|
||||
alchem_system = debug["alchem_system"]
|
||||
alchem_system = results["alchem_system"]
|
||||
_assert_num_forces(alchem_system, NonbondedForce, 1)
|
||||
_assert_num_forces(alchem_system, CustomNonbondedForce, 4)
|
||||
_assert_num_forces(alchem_system, CustomBondForce, 4)
|
||||
@@ -291,16 +334,16 @@ def test_confgen_fail_AFE(benzene_system, protocol_dry_settings, tmpdir):
|
||||
mapping=None,
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
|
||||
vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
with mock.patch("rdkit.Chem.AllChem.EmbedMultipleConfs", return_value=0):
|
||||
vac_sampler = vac_unit[0].run(dry=True)["debug"]["sampler"]
|
||||
|
||||
assert vac_sampler
|
||||
# If this worked, the system will have been built
|
||||
system = vac_setup_unit[0].run(dry=True)["alchem_system"]
|
||||
assert system
|
||||
|
||||
|
||||
def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
|
||||
def test_setup_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
|
||||
protocol_dry_settings.solvent_output_settings.output_indices = "resname UNK"
|
||||
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
|
||||
@@ -318,19 +361,25 @@ def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir):
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 2
|
||||
sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
|
||||
sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
|
||||
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)]
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
|
||||
assert len(vac_unit) == 1
|
||||
assert len(sol_unit) == 1
|
||||
assert len(sol_setup_unit) == len(sol_sim_unit) == 1
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
|
||||
setup_results = sol_setup_unit[0].run(dry=True)
|
||||
sim_results = sol_sim_unit[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
sol_sampler = sim_results["sampler"]
|
||||
assert sol_sampler.is_periodic
|
||||
|
||||
pdb = mdt.load_pdb("hybrid_system.pdb")
|
||||
pdb = mdt.load_pdb(setup_results["pdb_structure"])
|
||||
assert pdb.n_atoms == 12
|
||||
|
||||
|
||||
@@ -363,14 +412,23 @@ def test_dry_run_vsite_fail(benzene_system, tmpdir, protocol_dry_settings):
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
|
||||
sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
setup_results = sol_setup_unit[0].run(dry=True)
|
||||
with pytest.raises(ValueError, match="are unstable"):
|
||||
_ = sol_unit[0].run(dry=True)
|
||||
sim_results = sol_sim_unit[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
|
||||
|
||||
def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir):
|
||||
def test_setup_dry_sim_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir):
|
||||
protocol_dry_settings.vacuum_forcefield_settings.forcefields = [
|
||||
"amber/ff14SB.xml", # ff14SB protein force field
|
||||
"amber/tip4pew_standard.xml", # FF we are testsing with the fun VS
|
||||
@@ -399,10 +457,20 @@ def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdi
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
|
||||
sol_sim_units = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
|
||||
setup_results = sol_setup_units[0].run(dry=True)
|
||||
sim_results = sol_sim_units[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
sol_sampler = sim_results["sampler"]
|
||||
assert sol_sampler.is_periodic
|
||||
|
||||
|
||||
@@ -425,11 +493,11 @@ def test_dry_run_solv_benzene_noncubic(benzene_system, protocol_dry_settings, tm
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)]
|
||||
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
sampler = sol_unit[0].run(dry=True)["debug"]["sampler"]
|
||||
system = sampler._thermodynamic_states[0].system
|
||||
results = sol_setup_units[0].run(dry=True)
|
||||
system = results["alchem_system"]
|
||||
|
||||
vectors = system.getDefaultPeriodicBoxVectors()
|
||||
width = float(from_openmm(vectors)[0][0].to("nanometer").m)
|
||||
@@ -486,13 +554,13 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s
|
||||
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0]
|
||||
sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)][0]
|
||||
vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
|
||||
sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"])
|
||||
|
||||
# check sol_unit charges
|
||||
with tmpdir.as_cwd():
|
||||
sampler = sol_unit.run(dry=True)["debug"]["sampler"]
|
||||
system = sampler._thermodynamic_states[0].system
|
||||
results = sol_setup_units[0].run(dry=True)
|
||||
system = results["alchem_system"]
|
||||
nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)]
|
||||
|
||||
assert len(nonbond) == 1
|
||||
@@ -506,8 +574,8 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s
|
||||
|
||||
# check vac_unit charges
|
||||
with tmpdir.as_cwd():
|
||||
sampler = vac_unit.run(dry=True)["debug"]["sampler"]
|
||||
system = sampler._thermodynamic_states[0].system
|
||||
results = vac_setup_units[0].run(dry=True)
|
||||
system = results["alchem_system"]
|
||||
nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)]
|
||||
assert len(nonbond) == 4
|
||||
|
||||
@@ -572,12 +640,12 @@ def test_dry_run_charge_backends(
|
||||
dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0]
|
||||
vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"])
|
||||
|
||||
# check vac_unit charges
|
||||
with tmpdir.as_cwd():
|
||||
sampler = vac_unit.run(dry=True)["debug"]["sampler"]
|
||||
system = sampler._thermodynamic_states[0].system
|
||||
results = vac_setup_units[0].run(dry=True)
|
||||
system = results["alchem_system"]
|
||||
nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)]
|
||||
assert len(nonbond) == 4
|
||||
|
||||
@@ -609,187 +677,6 @@ def benzene_solvation_dag(benzene_system, protocol_dry_settings):
|
||||
return protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_solvation_dag, tmpdir):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
|
||||
dag_units = benzene_solvation_dag.protocol_units
|
||||
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
):
|
||||
results = []
|
||||
for u in dag_units:
|
||||
ret = u.execute(context=gufe.Context(tmpdir, tmpdir))
|
||||
results.append(ret)
|
||||
|
||||
solv_repeats = set()
|
||||
vac_repeats = set()
|
||||
for ret in results:
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
if ret.outputs["simtype"] == "vacuum":
|
||||
vac_repeats.add(ret.outputs["repeat_id"])
|
||||
else:
|
||||
solv_repeats.add(ret.outputs["repeat_id"])
|
||||
# Repeat ids are random ints so just check their lengths
|
||||
assert len(vac_repeats) == len(solv_repeats) == 3
|
||||
|
||||
|
||||
def test_gather(benzene_solvation_dag, tmpdir):
|
||||
# check that .gather behaves as expected
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run",
|
||||
return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"},
|
||||
),
|
||||
):
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_solvation_dag,
|
||||
shared_basedir=tmpdir,
|
||||
scratch_basedir=tmpdir,
|
||||
keep_shared=True,
|
||||
)
|
||||
|
||||
protocol = AbsoluteSolvationProtocol(
|
||||
settings=AbsoluteSolvationProtocol.default_settings(),
|
||||
)
|
||||
|
||||
res = protocol.gather([dagres])
|
||||
|
||||
assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult)
|
||||
|
||||
|
||||
class TestProtocolResult:
|
||||
@pytest.fixture()
|
||||
def protocolresult(self, afe_solv_transformation_json):
|
||||
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
return pr
|
||||
|
||||
def test_reload_protocol_result(self, afe_solv_transformation_json):
|
||||
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
assert pr
|
||||
|
||||
def test_get_estimate(self, protocolresult):
|
||||
est = protocolresult.get_estimate()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(-2.47, abs=0.5)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_uncertainty(self, protocolresult):
|
||||
est = protocolresult.get_uncertainty()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(0.2, abs=0.2)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_individual(self, protocolresult):
|
||||
inds = protocolresult.get_individual_estimates()
|
||||
|
||||
assert isinstance(inds, dict)
|
||||
assert isinstance(inds["solvent"], list)
|
||||
assert isinstance(inds["vacuum"], list)
|
||||
assert len(inds["solvent"]) == len(inds["vacuum"]) == 3
|
||||
for e, u in itertools.chain(inds["solvent"], inds["vacuum"]):
|
||||
assert e.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
assert u.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_forwards_etc(self, key, protocolresult):
|
||||
far = protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
assert isinstance(far, dict)
|
||||
assert isinstance(far[key], list)
|
||||
far1 = far[key][0]
|
||||
assert isinstance(far1, dict)
|
||||
|
||||
for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]:
|
||||
assert k in far1
|
||||
|
||||
if k == "fractions":
|
||||
assert isinstance(far1[k], np.ndarray)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_frwd_reverse_none_return(self, key, protocolresult):
|
||||
# fetch the first result of type key
|
||||
data = [i for i in protocolresult.data[key].values()][0][0]
|
||||
# set the output to None
|
||||
data.outputs["forward_and_reverse_energies"] = None
|
||||
|
||||
# now fetch the analysis results and expect a warning
|
||||
wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
|
||||
with pytest.warns(UserWarning, match=wmsg):
|
||||
protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_overlap_matrices(self, key, protocolresult):
|
||||
ovp = protocolresult.get_overlap_matrices()
|
||||
|
||||
assert isinstance(ovp, dict)
|
||||
assert isinstance(ovp[key], list)
|
||||
assert len(ovp[key]) == 3
|
||||
|
||||
ovp1 = ovp[key][0]
|
||||
assert isinstance(ovp1["matrix"], np.ndarray)
|
||||
assert ovp1["matrix"].shape == (14, 14)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_replica_transition_statistics(self, key, protocolresult):
|
||||
rpx = protocolresult.get_replica_transition_statistics()
|
||||
|
||||
assert isinstance(rpx, dict)
|
||||
assert isinstance(rpx[key], list)
|
||||
assert len(rpx[key]) == 3
|
||||
rpx1 = rpx[key][0]
|
||||
assert "eigenvalues" in rpx1
|
||||
assert "matrix" in rpx1
|
||||
assert rpx1["eigenvalues"].shape == (14,)
|
||||
assert rpx1["matrix"].shape == (14, 14)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_equilibration_iterations(self, key, protocolresult):
|
||||
eq = protocolresult.equilibration_iterations()
|
||||
|
||||
assert isinstance(eq, dict)
|
||||
assert isinstance(eq[key], list)
|
||||
assert len(eq[key]) == 3
|
||||
assert all(isinstance(v, float) for v in eq[key])
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_production_iterations(self, key, protocolresult):
|
||||
prod = protocolresult.production_iterations()
|
||||
|
||||
assert isinstance(prod, dict)
|
||||
assert isinstance(prod[key], list)
|
||||
assert len(prod[key]) == 3
|
||||
assert all(isinstance(v, float) for v in prod[key])
|
||||
|
||||
def test_filenotfound_replica_states(self, protocolresult):
|
||||
errmsg = "File could not be found"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocolresult.get_replica_states()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"positions_write_frequency,velocities_write_frequency",
|
||||
[
|
||||
@@ -821,7 +708,6 @@ def test_dry_run_vacuum_write_frequency(
|
||||
stateB = ChemicalSystem({"solvent": SolventComponent()})
|
||||
|
||||
# Create DAG from protocol, get the vacuum and solvent units
|
||||
# and eventually dry run the first solvent unit
|
||||
dag = protocol.create(
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
@@ -829,11 +715,23 @@ def test_dry_run_vacuum_write_frequency(
|
||||
)
|
||||
prot_units = list(dag.protocol_units)
|
||||
|
||||
assert len(prot_units) == 2
|
||||
assert len(prot_units) == 6
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
for u in prot_units:
|
||||
sampler = u.run(dry=True)["debug"]["sampler"]
|
||||
for phase in ["solvent", "vacuum"]:
|
||||
setup_units = _get_units(prot_units, UNIT_TYPES[phase]["setup"])
|
||||
sim_units = _get_units(prot_units, UNIT_TYPES[phase]["sim"])
|
||||
|
||||
with tmpdir.as_cwd():
|
||||
setup_results = setup_units[0].run(dry=True)
|
||||
sim_results = sim_units[0].run(
|
||||
system=setup_results["alchem_system"],
|
||||
positions=setup_results["debug_positions"],
|
||||
selection_indices=setup_results["selection_indices"],
|
||||
box_vectors=setup_results["box_vectors"],
|
||||
alchemical_restraints=False,
|
||||
dry=True,
|
||||
)
|
||||
sampler = sim_results["sampler"]
|
||||
reporter = sampler._reporter
|
||||
if positions_write_frequency:
|
||||
assert reporter.position_interval == positions_write_frequency.m
|
||||
|
||||
278
openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py
Normal file
278
openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import itertools
|
||||
import json
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import gufe
|
||||
import numpy as np
|
||||
import pytest
|
||||
from openff.units import unit as offunit
|
||||
|
||||
import openfe
|
||||
from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols import openmm_afe
|
||||
|
||||
from .utils import UNIT_TYPES, _get_units
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def protocol_dry_settings():
|
||||
settings = openmm_afe.AbsoluteSolvationProtocol.default_settings()
|
||||
settings.vacuum_engine_settings.compute_platform = None
|
||||
settings.solvent_engine_settings.compute_platform = None
|
||||
settings.protocol_repeats = 1
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def benzene_solvation_dag(benzene_system, protocol_dry_settings):
|
||||
protocol_dry_settings.protocol_repeats = 3
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings)
|
||||
|
||||
stateA = benzene_system
|
||||
|
||||
stateB = ChemicalSystem({"solvent": SolventComponent()})
|
||||
|
||||
return protocol.create(stateA=stateA, stateB=stateB, mapping=None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patcher():
|
||||
with (
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventSetupUnit.run",
|
||||
return_value={
|
||||
"system": Path("system.xml.bz2"),
|
||||
"positions": Path("positions.npy"),
|
||||
"pdb_structure": Path("hybrid_system.pdb"),
|
||||
"selection_indices": np.zeros(100),
|
||||
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
|
||||
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
|
||||
"restraint_geometry": None,
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSetupUnit.run",
|
||||
return_value={
|
||||
"system": Path("system.xml.bz2"),
|
||||
"positions": Path("positions.npy"),
|
||||
"pdb_structure": Path("hybrid_system.pdb"),
|
||||
"selection_indices": np.zeros(100),
|
||||
"box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm,
|
||||
"standard_state_correction": 0 * offunit.kilocalorie_per_mole,
|
||||
"restraint_geometry": None,
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.base_afe_units.np.load",
|
||||
return_value=np.zeros(100),
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.base_afe_units.deserialize",
|
||||
return_value="foo",
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventSimUnit.run",
|
||||
return_value={
|
||||
"trajectory": Path("file.nc"),
|
||||
"checkpoint": Path("chk.chk"),
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSimUnit.run",
|
||||
return_value={
|
||||
"trajectory": Path("file.nc"),
|
||||
"checkpoint": Path("chk.chk"),
|
||||
},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFESolventAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
mock.patch(
|
||||
"openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumAnalysisUnit.run",
|
||||
return_value={"foo": "bar"},
|
||||
),
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def test_gather(benzene_solvation_dag, patcher, tmpdir):
|
||||
# check that .gather behaves as expected
|
||||
dagres = gufe.protocols.execute_DAG(
|
||||
benzene_solvation_dag,
|
||||
shared_basedir=tmpdir,
|
||||
scratch_basedir=tmpdir,
|
||||
keep_shared=True,
|
||||
)
|
||||
|
||||
protocol = openmm_afe.AbsoluteSolvationProtocol(
|
||||
settings=openmm_afe.AbsoluteSolvationProtocol.default_settings(),
|
||||
)
|
||||
|
||||
res = protocol.gather([dagres])
|
||||
|
||||
assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult)
|
||||
|
||||
|
||||
def test_unit_tagging(benzene_solvation_dag, patcher, tmpdir):
|
||||
# test that executing the units includes correct gen and repeat info
|
||||
|
||||
dag_units = benzene_solvation_dag.protocol_units
|
||||
|
||||
for phase in ["solvent", "vacuum"]:
|
||||
setup_results = {}
|
||||
sim_results = {}
|
||||
analysis_results = {}
|
||||
|
||||
setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"])
|
||||
sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"])
|
||||
a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"])
|
||||
|
||||
for u in setup_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir))
|
||||
|
||||
for u in sim_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
sim_results[rid] = u.execute(
|
||||
context=gufe.Context(tmpdir, tmpdir),
|
||||
setup_results=setup_results[rid],
|
||||
)
|
||||
|
||||
for u in a_units:
|
||||
rid = u.inputs["repeat_id"]
|
||||
analysis_results[rid] = u.execute(
|
||||
context=gufe.Context(tmpdir, tmpdir),
|
||||
setup_results=setup_results[rid],
|
||||
simulation_results=sim_results[rid],
|
||||
)
|
||||
|
||||
for results in [setup_results, sim_results, analysis_results]:
|
||||
for ret in results.values():
|
||||
assert isinstance(ret, gufe.ProtocolUnitResult)
|
||||
assert ret.outputs["generation"] == 0
|
||||
|
||||
assert len(setup_results) == len(sim_results) == len(analysis_results) == 3
|
||||
|
||||
|
||||
class TestProtocolResult:
|
||||
@pytest.fixture()
|
||||
def protocolresult(self, afe_solv_transformation_json):
|
||||
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openfe.ProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
return pr
|
||||
|
||||
def test_reload_protocol_result(self, afe_solv_transformation_json):
|
||||
d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
|
||||
|
||||
pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"])
|
||||
|
||||
assert pr
|
||||
|
||||
def test_get_estimate(self, protocolresult):
|
||||
est = protocolresult.get_estimate()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(-2.47, abs=0.5)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_uncertainty(self, protocolresult):
|
||||
est = protocolresult.get_uncertainty()
|
||||
|
||||
assert est
|
||||
assert est.m == pytest.approx(0.2, abs=0.2)
|
||||
assert isinstance(est, offunit.Quantity)
|
||||
assert est.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
def test_get_individual(self, protocolresult):
|
||||
inds = protocolresult.get_individual_estimates()
|
||||
|
||||
assert isinstance(inds, dict)
|
||||
assert isinstance(inds["solvent"], list)
|
||||
assert isinstance(inds["vacuum"], list)
|
||||
assert len(inds["solvent"]) == len(inds["vacuum"]) == 3
|
||||
for e, u in itertools.chain(inds["solvent"], inds["vacuum"]):
|
||||
assert e.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
assert u.is_compatible_with(offunit.kilojoule_per_mole)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_forwards_etc(self, key, protocolresult):
|
||||
far = protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
assert isinstance(far, dict)
|
||||
assert isinstance(far[key], list)
|
||||
far1 = far[key][0]
|
||||
assert isinstance(far1, dict)
|
||||
|
||||
for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]:
|
||||
assert k in far1
|
||||
|
||||
if k == "fractions":
|
||||
assert isinstance(far1[k], np.ndarray)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_frwd_reverse_none_return(self, key, protocolresult):
|
||||
# fetch the first result of type key
|
||||
data = [i for i in protocolresult.data[key].values()][0][0]
|
||||
# set the output to None
|
||||
data.outputs["forward_and_reverse_energies"] = None
|
||||
|
||||
# now fetch the analysis results and expect a warning
|
||||
wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
|
||||
with pytest.warns(UserWarning, match=wmsg):
|
||||
protocolresult.get_forward_and_reverse_energy_analysis()
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_overlap_matrices(self, key, protocolresult):
|
||||
ovp = protocolresult.get_overlap_matrices()
|
||||
|
||||
assert isinstance(ovp, dict)
|
||||
assert isinstance(ovp[key], list)
|
||||
assert len(ovp[key]) == 3
|
||||
|
||||
ovp1 = ovp[key][0]
|
||||
assert isinstance(ovp1["matrix"], np.ndarray)
|
||||
assert ovp1["matrix"].shape == (14, 14)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_get_replica_transition_statistics(self, key, protocolresult):
|
||||
rpx = protocolresult.get_replica_transition_statistics()
|
||||
|
||||
assert isinstance(rpx, dict)
|
||||
assert isinstance(rpx[key], list)
|
||||
assert len(rpx[key]) == 3
|
||||
rpx1 = rpx[key][0]
|
||||
assert "eigenvalues" in rpx1
|
||||
assert "matrix" in rpx1
|
||||
assert rpx1["eigenvalues"].shape == (14,)
|
||||
assert rpx1["matrix"].shape == (14, 14)
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_equilibration_iterations(self, key, protocolresult):
|
||||
eq = protocolresult.equilibration_iterations()
|
||||
|
||||
assert isinstance(eq, dict)
|
||||
assert isinstance(eq[key], list)
|
||||
assert len(eq[key]) == 3
|
||||
assert all(isinstance(v, float) for v in eq[key])
|
||||
|
||||
@pytest.mark.parametrize("key", ["solvent", "vacuum"])
|
||||
def test_production_iterations(self, key, protocolresult):
|
||||
prod = protocolresult.production_iterations()
|
||||
|
||||
assert isinstance(prod, dict)
|
||||
assert isinstance(prod[key], list)
|
||||
assert len(prod[key]) == 3
|
||||
assert all(isinstance(v, float) for v in prod[key])
|
||||
|
||||
def test_filenotfound_replica_states(self, protocolresult):
|
||||
errmsg = "File could not be found"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocolresult.get_replica_states()
|
||||
@@ -79,16 +79,36 @@ def test_openmm_run_engine(
|
||||
r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True)
|
||||
|
||||
assert r.ok()
|
||||
for pur in r.protocol_unit_results:
|
||||
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
checkpoint = pur.outputs["last_checkpoint"]
|
||||
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
assert (unit_shared / checkpoint).exists()
|
||||
nc = pur.outputs["nc"]
|
||||
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
|
||||
assert nc.exists()
|
||||
|
||||
# Check outputs of solvent & vacuum results
|
||||
for phase in ["solvent", "vacuum"]:
|
||||
purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase]
|
||||
|
||||
# get the path to the simulation unit shared dict
|
||||
for pur in purs:
|
||||
if "Simulation" in pur.name:
|
||||
sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert sim_shared.exists()
|
||||
assert pathlib.Path(sim_shared).is_dir()
|
||||
|
||||
# check the analysis outputs
|
||||
for pur in purs:
|
||||
if "Analysis" not in pur.name:
|
||||
continue
|
||||
|
||||
unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
|
||||
# Does the checkpoint file exist?
|
||||
checkpoint = pur.outputs["checkpoint"]
|
||||
assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
|
||||
assert checkpoint.exists()
|
||||
|
||||
# Does the trajectory file exist?
|
||||
nc = pur.outputs["trajectory"]
|
||||
assert nc == sim_shared / f"{pur.outputs['simtype']}.nc"
|
||||
assert nc.exists()
|
||||
|
||||
# Test results methods that need files present
|
||||
results = protocol.gather([r])
|
||||
|
||||
@@ -4,10 +4,19 @@ import json
|
||||
|
||||
import gufe
|
||||
import pytest
|
||||
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
|
||||
|
||||
import openfe
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AHFESolventAnalysisUnit,
|
||||
AHFESolventSetupUnit,
|
||||
AHFESolventSimUnit,
|
||||
AHFEVacuumAnalysisUnit,
|
||||
AHFEVacuumSetupUnit,
|
||||
AHFEVacuumSimUnit,
|
||||
)
|
||||
|
||||
from ..conftest import ModGufeTokenizableTestsMixin
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -27,18 +36,40 @@ def protocol_units(protocol, benzene_system):
|
||||
return list(pus.protocol_units)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_afe.AbsoluteSolvationSolventUnit):
|
||||
def _filter_units(pus, classtype):
|
||||
for pu in pus:
|
||||
if isinstance(pu, classtype):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vacuum_protocol_unit(protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_afe.AbsoluteSolvationVacuumUnit):
|
||||
return pu
|
||||
def solvent_protocol_setup_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFESolventSetupUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_sim_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFESolventSimUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_analysis_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFESolventAnalysisUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vacuum_protocol_setup_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFEVacuumSetupUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vacuum_protocol_sim_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFEVacuumSimUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vacuum_protocol_analysis_unit(protocol_units):
|
||||
return _filter_units(protocol_units, AHFEVacuumAnalysisUnit)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,7 +79,7 @@ def protocol_result(afe_solv_transformation_json):
|
||||
return pr
|
||||
|
||||
|
||||
class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
|
||||
class TestAbsoluteSolvationProtocol(ModGufeTokenizableTestsMixin):
|
||||
cls = openmm_afe.AbsoluteSolvationProtocol
|
||||
key = None
|
||||
repr = "AbsoluteSolvationProtocol-"
|
||||
@@ -57,49 +88,68 @@ class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
|
||||
def instance(self, protocol):
|
||||
return protocol
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_afe.AbsoluteSolvationSolventUnit
|
||||
repr = "AbsoluteSolvationSolventUnit(Absolute Solvation, benzene solvent leg"
|
||||
class TestAHFESolventSetupUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFESolventSetupUnit
|
||||
repr = "AHFESolventSetupUnit(AHFE Setup: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, solvent_protocol_unit):
|
||||
return solvent_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
def instance(self, solvent_protocol_setup_unit):
|
||||
return solvent_protocol_setup_unit
|
||||
|
||||
|
||||
class TestAbsoluteSolvationVacuumUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_afe.AbsoluteSolvationVacuumUnit
|
||||
repr = "AbsoluteSolvationVacuumUnit(Absolute Solvation, benzene vacuum leg"
|
||||
class TestAHFESolventSimUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFESolventSimUnit
|
||||
repr = "AHFESolventSimUnit(AHFE Simulation: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, vacuum_protocol_unit):
|
||||
return vacuum_protocol_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
def instance(self, solvent_protocol_sim_unit):
|
||||
return solvent_protocol_sim_unit
|
||||
|
||||
|
||||
class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
|
||||
class TestAHFESolventAnalysisUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFESolventAnalysisUnit
|
||||
repr = "AHFESolventAnalysisUnit(AHFE Analysis: benzene solvent leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, solvent_protocol_analysis_unit):
|
||||
return solvent_protocol_analysis_unit
|
||||
|
||||
|
||||
class TestAHFEVacuumSetupUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFEVacuumSetupUnit
|
||||
repr = "AHFEVacuumSetupUnit(AHFE Setup: benzene vacuum leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, vacuum_protocol_setup_unit):
|
||||
return vacuum_protocol_setup_unit
|
||||
|
||||
|
||||
class TestAHFEVacuumSimUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFEVacuumSimUnit
|
||||
repr = "AHFEVacuumSimUnit(AHFE Simulation: benzene vacuum leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, vacuum_protocol_sim_unit):
|
||||
return vacuum_protocol_sim_unit
|
||||
|
||||
|
||||
class TestAHFEVacuumAnalysisUnit(ModGufeTokenizableTestsMixin):
|
||||
cls = AHFEVacuumAnalysisUnit
|
||||
repr = "AHFEVacuumAnalysisUnit(AHFE Analysis: benzene vacuum leg"
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, vacuum_protocol_analysis_unit):
|
||||
return vacuum_protocol_analysis_unit
|
||||
|
||||
|
||||
class TestAbsoluteSolvationProtocolResult(ModGufeTokenizableTestsMixin):
|
||||
cls = openmm_afe.AbsoluteSolvationProtocolResult
|
||||
key = None
|
||||
repr = "AbsoluteSolvationProtocolResult-"
|
||||
@@ -107,10 +157,3 @@ class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
|
||||
@pytest.fixture()
|
||||
def instance(self, protocol_result):
|
||||
return protocol_result
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
Overwrites the base `test_repr` call.
|
||||
"""
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
@@ -8,15 +8,8 @@ from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols import openmm_afe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteSolvationProtocol,
|
||||
AbsoluteSolvationSolventUnit,
|
||||
AbsoluteSolvationVacuumUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils import system_validation
|
||||
from openfe.protocols.openmm_utils.charge_generation import (
|
||||
HAS_ESPALOMA_CHARGE,
|
||||
HAS_NAGL,
|
||||
HAS_OPENEYE,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
||||
31
openfe/tests/protocols/openmm_ahfe/utils.py
Normal file
31
openfe/tests/protocols/openmm_ahfe/utils.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
from openfe.protocols.openmm_afe import (
|
||||
AbsoluteSolvationProtocol,
|
||||
AHFESolventAnalysisUnit,
|
||||
AHFESolventSetupUnit,
|
||||
AHFESolventSimUnit,
|
||||
AHFEVacuumAnalysisUnit,
|
||||
AHFEVacuumSetupUnit,
|
||||
AHFEVacuumSimUnit,
|
||||
)
|
||||
|
||||
UNIT_TYPES = {
|
||||
"solvent": {
|
||||
"setup": AHFESolventSetupUnit,
|
||||
"sim": AHFESolventSimUnit,
|
||||
"analysis": AHFESolventAnalysisUnit,
|
||||
},
|
||||
"vacuum": {
|
||||
"setup": AHFEVacuumSetupUnit,
|
||||
"sim": AHFEVacuumSimUnit,
|
||||
"analysis": AHFEVacuumAnalysisUnit,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_units(protocol_units, unit_type):
|
||||
"""
|
||||
Helper method to extract setup units.
|
||||
"""
|
||||
return [pu for pu in protocol_units if isinstance(pu, unit_type)]
|
||||
@@ -33,7 +33,12 @@ def _get_name(result: dict) -> str:
|
||||
"""
|
||||
|
||||
solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
|
||||
name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
|
||||
try:
|
||||
name = solvent_data["inputs"]["setup_results"]["inputs"]["alchemical_components"]["stateA"][
|
||||
0
|
||||
]["molprops"]["ofe-name"]
|
||||
except KeyError:
|
||||
name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
|
||||
|
||||
return str(name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user