mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
Merge branch 'main' into remove_perses
This commit is contained in:
26
news/validate-septop.rst
Normal file
26
news/validate-septop.rst
Normal file
@@ -0,0 +1,26 @@
|
||||
**Added:**
|
||||
|
||||
* The `validate` method for the SepTopProtocol has been implemented.
|
||||
This means that settings and system validation can mostly be done prior
|
||||
to Protocol execuation by calling
|
||||
`SepTopProtocol.validate(stateA, stateB, mapping=None)`.
|
||||
|
||||
**Changed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Deprecated:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Removed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Fixed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Security:**
|
||||
|
||||
* <news item>
|
||||
@@ -33,6 +33,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
@@ -307,14 +308,14 @@ class SepTopProtocol(gufe.Protocol):
|
||||
return protocol_settings
|
||||
|
||||
@staticmethod
|
||||
def _validate_complex_endstates(
|
||||
def _validate_endstates(
|
||||
stateA: ChemicalSystem,
|
||||
stateB: ChemicalSystem,
|
||||
) -> None:
|
||||
"""
|
||||
A complex transformation is defined (in terms of gufe components)
|
||||
A complex relative transformation is defined (in terms of gufe components)
|
||||
as starting from one or more ligands and a protein in solvent and
|
||||
ending up in a state with one less ligand.
|
||||
ending up in a state with one ligand that is different.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -328,76 +329,55 @@ class SepTopProtocol(gufe.Protocol):
|
||||
ValueError
|
||||
If there is no SolventComponent and no ProteinComponent
|
||||
in either stateA or stateB.
|
||||
If there are no or more than one alchemical components in state A.
|
||||
If there are no or more than one alchemical components in state B.
|
||||
If there are any alchemical components that are not SmallMoleculeComponents.
|
||||
If a change in net charge between the alchemical components is detected.
|
||||
"""
|
||||
# check that there is a protein component
|
||||
if not any(isinstance(comp, ProteinComponent) for comp in stateA.values()):
|
||||
if not stateA.contains(ProteinComponent):
|
||||
errmsg = "No ProteinComponent found in stateA"
|
||||
raise ValueError(errmsg)
|
||||
|
||||
if not any(isinstance(comp, ProteinComponent) for comp in stateB.values()):
|
||||
if not stateB.contains(ProteinComponent):
|
||||
errmsg = "No ProteinComponent found in stateB"
|
||||
raise ValueError(errmsg)
|
||||
|
||||
# check that there is only one protein component
|
||||
system_validation.validate_protein(stateA)
|
||||
system_validation.validate_protein(stateB)
|
||||
|
||||
# check that there is a SolventComponent
|
||||
if not any(isinstance(comp, SolventComponent) for comp in stateA.values()):
|
||||
if not stateA.contains(SolventComponent):
|
||||
errmsg = "No SolventComponent found in stateA"
|
||||
raise ValueError(errmsg)
|
||||
|
||||
if not any(isinstance(comp, SolventComponent) for comp in stateB.values()):
|
||||
if not stateB.contains(SolventComponent):
|
||||
errmsg = "No SolventComponent found in stateB"
|
||||
raise ValueError(errmsg)
|
||||
|
||||
@staticmethod
|
||||
def _validate_alchemical_components(alchemical_components: dict[str, list[Component]]) -> None:
|
||||
"""
|
||||
Checks that the ChemicalSystem alchemical components are correct.
|
||||
# Check the difference between the endstates
|
||||
diff = stateA.component_diff(stateB)
|
||||
|
||||
Parameters
|
||||
----------
|
||||
alchemical_components : Dict[str, list[Component]]
|
||||
Dictionary containing the alchemical components for
|
||||
stateA and stateB.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
* If there are no or more than one alchemical components in state A.
|
||||
* If there are no or more than one alchemical components in state B.
|
||||
* If there are any alchemical components that are not
|
||||
SmallMoleculeComponents
|
||||
* If a change in net charge between the alchemical components is detected.
|
||||
|
||||
Notes
|
||||
-----
|
||||
* Currently doesn't support alchemical components which are not
|
||||
SmallMoleculeComponents.
|
||||
* Currently doesn't support more than one alchemical component
|
||||
being desolvated.
|
||||
"""
|
||||
|
||||
# Crash out if there are less or more than one alchemical components
|
||||
# in state A and B
|
||||
for state in ["stateA", "stateB"]:
|
||||
n = len(alchemical_components[state])
|
||||
if n != 1:
|
||||
raise ValueError(
|
||||
"Exactly one alchemical component must be present in "
|
||||
f"{state}. Found {n} alchemical components."
|
||||
for i, state in enumerate(["stateA", "stateB"]):
|
||||
# Error if there isn't exactly one alchemical component
|
||||
if len(diff[i]) != 1:
|
||||
errmsg = (
|
||||
"Only one alchemical species is supported. "
|
||||
f"Number of unique components found in {state}: {len(diff[i])}."
|
||||
)
|
||||
raise ValueError(errmsg)
|
||||
|
||||
# Crash out if any of the alchemical components are not
|
||||
# SmallMoleculeComponent
|
||||
for state in ["stateA", "stateB"]:
|
||||
for comp in alchemical_components[state]:
|
||||
if not isinstance(comp, SmallMoleculeComponent):
|
||||
raise ValueError(
|
||||
"Only SmallMoleculeComponent alchemical species are supported."
|
||||
)
|
||||
# Error if the component isn't an SMC
|
||||
if not isinstance(diff[i][0], SmallMoleculeComponent):
|
||||
errmsg = (
|
||||
"Only transforming SmallMoleculeComponents are supported "
|
||||
f"by this Protocol. Found a {type(diff[i][0])}."
|
||||
)
|
||||
raise ValueError(errmsg)
|
||||
|
||||
# Raise an error if there is a change in netcharge
|
||||
_check_alchemical_charge_difference(
|
||||
alchemical_components["stateA"][0], alchemical_components["stateB"][0]
|
||||
)
|
||||
# Raise an error if there is a change in net charge
|
||||
_check_alchemical_charge_difference(diff[0][0], diff[1][0])
|
||||
|
||||
@staticmethod
|
||||
def _validate_lambda_schedule(
|
||||
@@ -419,8 +399,10 @@ class SepTopProtocol(gufe.Protocol):
|
||||
ValueError
|
||||
If the number of lambda windows differs for electrostatics and sterics.
|
||||
If the number of replicas does not match the number of lambda windows.
|
||||
Warnings
|
||||
If there are non-zero values for restraints (lambda_restraints).
|
||||
|
||||
TODO
|
||||
----
|
||||
Add a warning if all the lambda restraints are zero? Issue #1945.
|
||||
"""
|
||||
|
||||
lambda_elec_A = lambda_settings.lambda_elec_A
|
||||
@@ -474,32 +456,36 @@ class SepTopProtocol(gufe.Protocol):
|
||||
f"State {state}: lambda {idx}: elec {e} vdW {v}"
|
||||
)
|
||||
|
||||
def _create(
|
||||
def _validate(
|
||||
self,
|
||||
stateA: ChemicalSystem,
|
||||
stateB: ChemicalSystem,
|
||||
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
|
||||
extends: Optional[gufe.ProtocolDAGResult] = None,
|
||||
) -> list[gufe.ProtocolUnit]:
|
||||
# TODO: extensions
|
||||
if extends: # pragma: no-cover
|
||||
raise NotImplementedError("Can't extend simulations yet")
|
||||
mapping: gufe.ComponentMapping | list[gufe.ComponentMappping] | None,
|
||||
extends: gufe.ProtocolDAGResult | None = None,
|
||||
) -> None:
|
||||
# Check we're not trying to extend
|
||||
if extends:
|
||||
# This technically should be NotImplementedError
|
||||
# but gufe.Protocol.validate calls `_validate` wrapped
|
||||
# around a try/except for that error type
|
||||
raise ValueError("Can't extend simulations yet")
|
||||
|
||||
# Validate components and get alchemical components
|
||||
# Check the mappping
|
||||
if mapping is not None:
|
||||
wmsg = "A mapping was passed but is not used by this Protocol"
|
||||
warnings.warn(wmsg)
|
||||
|
||||
# Validate end states
|
||||
system_validation.validate_chemical_system(stateA)
|
||||
system_validation.validate_chemical_system(stateB)
|
||||
self._validate_complex_endstates(stateA, stateB)
|
||||
alchem_comps = system_validation.get_alchemical_components(
|
||||
stateA,
|
||||
stateB,
|
||||
)
|
||||
self._validate_alchemical_components(alchem_comps)
|
||||
self._validate_endstates(stateA, stateB)
|
||||
|
||||
# Validate the lambda schedule
|
||||
self._validate_lambda_schedule(
|
||||
self.settings.solvent_lambda_settings,
|
||||
self.settings.solvent_simulation_settings,
|
||||
)
|
||||
|
||||
self._validate_lambda_schedule(
|
||||
self.settings.complex_lambda_settings,
|
||||
self.settings.complex_simulation_settings,
|
||||
@@ -514,15 +500,30 @@ class SepTopProtocol(gufe.Protocol):
|
||||
settings_validation.validate_openmm_solvation_settings(
|
||||
self.settings.solvent_solvation_settings
|
||||
)
|
||||
|
||||
# Validate protein component
|
||||
system_validation.validate_protein(stateA)
|
||||
settings_validation.validate_openmm_solvation_settings(
|
||||
self.settings.complex_solvation_settings
|
||||
)
|
||||
|
||||
# Validate the barostat used in combination with the protein component
|
||||
system_validation.validate_barostat(
|
||||
stateA, self.settings.complex_integrator_settings.barostat
|
||||
)
|
||||
|
||||
def _create(
|
||||
self,
|
||||
stateA: ChemicalSystem,
|
||||
stateB: ChemicalSystem,
|
||||
mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None,
|
||||
extends: gufe.ProtocolDAGResult | None = None,
|
||||
) -> list[gufe.ProtocolUnit]:
|
||||
self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)
|
||||
|
||||
# Get the alchemical components
|
||||
alchem_comps = system_validation.get_alchemical_components(
|
||||
stateA,
|
||||
stateB,
|
||||
)
|
||||
|
||||
# Create list units for complex and solvent transforms
|
||||
def create_setup_units(unit_cls, leg):
|
||||
return [
|
||||
|
||||
@@ -40,11 +40,6 @@ from openfe.protocols.openmm_septop import (
|
||||
SepTopSolventRunUnit,
|
||||
SepTopSolventSetupUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_septop.equil_septop_method import (
|
||||
_check_alchemical_charge_difference,
|
||||
)
|
||||
from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings
|
||||
from openfe.protocols.openmm_utils import system_validation
|
||||
from openfe.protocols.openmm_utils.serialization import deserialize
|
||||
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
|
||||
from openfe.tests.protocols.conftest import compute_energy
|
||||
@@ -79,191 +74,6 @@ def default_settings():
|
||||
return s
|
||||
|
||||
|
||||
def test_create_default_settings():
|
||||
settings = SepTopProtocol.default_settings()
|
||||
assert settings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]},
|
||||
{"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]},
|
||||
{"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]},
|
||||
],
|
||||
)
|
||||
def test_incorrect_window_settings(val, default_settings):
|
||||
errmsg = "Lambda windows must be between 0 and 1."
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_A = val["elec"]
|
||||
lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec": [0.0, 0.1, 0.0],
|
||||
"vdw": [0.0, 1.0, 1.0],
|
||||
"restraints": [0.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [0.0, 0.0, 0.0],
|
||||
"vdw": [0.0, 1.0, 0.0],
|
||||
"restraints": [0.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [0.0, 0.0, 0.0],
|
||||
"vdw": [0.0, 1.0, 1.0],
|
||||
"restraints": [0.0, 1.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_monotonic_lambda_windows_A(val, default_settings):
|
||||
errmsg = "The lambda schedule for ligand A"
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_A = val["elec"]
|
||||
lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec": [1.0, 0.1, 1.0],
|
||||
"vdw": [1.0, 1.0, 1.0],
|
||||
"restraints": [1.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [1.0, 1.0, 1.0],
|
||||
"vdw": [1.0, 0.0, 1.0],
|
||||
"restraints": [1.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [1.0, 1.0, 1.0],
|
||||
"vdw": [1.0, 1.0, 1.0],
|
||||
"restraints": [1.0, 0.0, 1.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_monotonic_lambda_windows_B(val, default_settings):
|
||||
errmsg = "The lambda schedule for ligand B"
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_B = val["elec"]
|
||||
lambda_settings.lambda_vdw_B = val["vdw"]
|
||||
lambda_settings.lambda_restraints_B = val["restraints"]
|
||||
|
||||
|
||||
def test_output_induces_not_all(default_settings):
|
||||
errmsg = "Equilibration simulations need to output the full system"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
default_settings.complex_equil_output_settings.output_indices = "no water"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec_A": [1.0, 1.0],
|
||||
"vdw_A": [0.0, 1.0],
|
||||
"restraints_A": [0.0, 0.0],
|
||||
"elec_B": [1.0, 1.0],
|
||||
"vdw_B": [1.0, 1.0],
|
||||
"restraints_B": [0.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nreplicas(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
|
||||
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
|
||||
n_replicas = 3
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
f"Number of replicas {n_replicas} does not equal the"
|
||||
f" number of lambda windows {len(val['vdw_A'])}"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nwindows(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
n_replicas = 3
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
"Components elec, vdw, and restraints must have equal amount of lambda "
|
||||
"windows. Got 3 and 19 elec lambda windows"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec_A": [0.0, 1.0],
|
||||
"vdw_A": [1.0, 1.0],
|
||||
"restraints_A": [0.0, 0.0],
|
||||
"elec_B": [1.0, 1.0],
|
||||
"vdw_B": [1.0, 1.0],
|
||||
"restraints_B": [0.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nakedcharge(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
|
||||
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
|
||||
n_replicas = 2
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
default_settings.solvent_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
"There are states along this lambda schedule "
|
||||
"where there are atoms with charges but no LJ "
|
||||
"interactions: State A: l"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.solvent_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
def test_create_default_protocol(default_settings):
|
||||
# this is roughly how it should be created
|
||||
protocol = SepTopProtocol(
|
||||
@@ -315,172 +125,6 @@ def test_create_independent_repeat_ids(
|
||||
assert len(repeat_ids) == 24
|
||||
|
||||
|
||||
def test_check_alchem_charge_diff(charged_benzene_modifications):
|
||||
errmsg = "A charge difference of 1"
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
_check_alchemical_charge_difference(
|
||||
charged_benzene_modifications["benzene"],
|
||||
charged_benzene_modifications["benzoic_acid"],
|
||||
)
|
||||
|
||||
|
||||
def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings):
|
||||
protocol = SepTopProtocol(
|
||||
settings=default_settings,
|
||||
)
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"benzene": charged_benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{
|
||||
"benzoic": charged_benzene_modifications["benzoic_acid"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
errmsg = "A charge difference of 1"
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocol.create(
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "benzene_system", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "benzene_system"),
|
||||
],
|
||||
)
|
||||
def test_validate_complex_endstates_protcomp(request, system_A, system_B, fail_endstate):
|
||||
with pytest.raises(ValueError, match="No ProteinComponent found"):
|
||||
SepTopProtocol._validate_complex_endstates(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def T4L_benzene_vacuum(benzene_modifications, T4_protein_component):
|
||||
return openfe.ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "T4L_benzene_vacuum", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "T4L_benzene_vacuum"),
|
||||
],
|
||||
)
|
||||
def test_validate_complex_endstates_nosolvcomp(
|
||||
request,
|
||||
system_A,
|
||||
system_B,
|
||||
fail_endstate,
|
||||
):
|
||||
with pytest.raises(ValueError, match="No SolventComponent found"):
|
||||
SepTopProtocol._validate_complex_endstates(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def T4L_system(T4_protein_component):
|
||||
return openfe.ChemicalSystem(
|
||||
{
|
||||
"solvent": openfe.SolventComponent(),
|
||||
"protein": T4_protein_component,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "T4L_system", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "T4L_system"),
|
||||
],
|
||||
)
|
||||
def test_validate_alchem_comps_missing(
|
||||
request,
|
||||
system_A,
|
||||
system_B,
|
||||
fail_endstate,
|
||||
):
|
||||
alchem_comps = system_validation.get_alchemical_components(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=f"one alchemical component must be present in {fail_endstate}.",
|
||||
):
|
||||
SepTopProtocol._validate_alchemical_components(alchem_comps)
|
||||
|
||||
|
||||
def test_validate_alchem_comps_toomanyA(
|
||||
benzene_modifications,
|
||||
T4_protein_component,
|
||||
):
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"toluene": benzene_modifications["toluene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{
|
||||
"phenol": benzene_modifications["phenol"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
|
||||
|
||||
assert len(alchem_comps["stateA"]) == 2
|
||||
|
||||
assert len(alchem_comps["stateB"]) == 1
|
||||
|
||||
with pytest.raises(ValueError, match="present in stateA. Found 2 alchemical components."):
|
||||
SepTopProtocol._validate_alchemical_components(alchem_comps)
|
||||
|
||||
|
||||
def test_validate_alchem_nonsmc(
|
||||
benzene_modifications,
|
||||
T4_protein_component,
|
||||
):
|
||||
stateA = ChemicalSystem(
|
||||
{"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{"benzene": benzene_modifications["benzene"], "protein": T4_protein_component}
|
||||
)
|
||||
|
||||
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
|
||||
|
||||
with pytest.raises(ValueError, match="Only SmallMoleculeComponent alchemical"):
|
||||
SepTopProtocol._validate_alchemical_components(alchem_comps)
|
||||
|
||||
|
||||
# Tests for the alchemical systems. This tests were modified from
|
||||
# femto (https://github.com/Psivant/femto/tree/main)
|
||||
def compute_interaction_energy(
|
||||
@@ -1816,28 +1460,3 @@ class TestA2AMembraneDryRun:
|
||||
# Check the PDB
|
||||
pdb = md.load_pdb("alchemical_system.pdb")
|
||||
assert pdb.n_atoms == (self.num_ligand_atoms_A + self.num_ligand_atoms_B)
|
||||
|
||||
|
||||
def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings):
|
||||
settings = SepTopProtocol._adaptive_settings(
|
||||
toluene_complex_system, toluene_complex_system, default_settings
|
||||
)
|
||||
|
||||
assert isinstance(settings, SepTopSettings)
|
||||
# Should use default barostat since no ProteinMembraneComponent
|
||||
assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat"
|
||||
|
||||
|
||||
def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands):
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"ligandA": a2a_ligands[0],
|
||||
"protein": a2a_protein_membrane_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
settings = SepTopProtocol._adaptive_settings(stateA, stateA)
|
||||
assert isinstance(settings, SepTopSettings)
|
||||
# Barostat should have been updated
|
||||
assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat"
|
||||
|
||||
139
src/openfe/tests/protocols/openmm_septop/test_septop_settings.py
Normal file
139
src/openfe/tests/protocols/openmm_septop/test_septop_settings.py
Normal file
@@ -0,0 +1,139 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
|
||||
import pytest
|
||||
|
||||
from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols.openmm_septop import (
|
||||
SepTopProtocol,
|
||||
)
|
||||
from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def protocol_dry_settings():
|
||||
# a set of settings for dry run tests
|
||||
s = SepTopProtocol.default_settings()
|
||||
s.engine_settings.compute_platform = None
|
||||
s.protocol_repeats = 1
|
||||
return s
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def default_settings():
|
||||
s = SepTopProtocol.default_settings()
|
||||
return s
|
||||
|
||||
|
||||
def test_create_default_settings():
|
||||
settings = SepTopProtocol.default_settings()
|
||||
assert settings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]},
|
||||
{"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]},
|
||||
{"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]},
|
||||
],
|
||||
)
|
||||
def test_incorrect_window_settings(val, default_settings):
|
||||
errmsg = "Lambda windows must be between 0 and 1."
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_A = val["elec"]
|
||||
lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec": [0.0, 0.1, 0.0],
|
||||
"vdw": [0.0, 1.0, 1.0],
|
||||
"restraints": [0.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [0.0, 0.0, 0.0],
|
||||
"vdw": [0.0, 1.0, 0.0],
|
||||
"restraints": [0.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [0.0, 0.0, 0.0],
|
||||
"vdw": [0.0, 1.0, 1.0],
|
||||
"restraints": [0.0, 1.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_monotonic_lambda_windows_A(val, default_settings):
|
||||
errmsg = "The lambda schedule for ligand A"
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_A = val["elec"]
|
||||
lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec": [1.0, 0.1, 1.0],
|
||||
"vdw": [1.0, 1.0, 1.0],
|
||||
"restraints": [1.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [1.0, 1.0, 1.0],
|
||||
"vdw": [1.0, 0.0, 1.0],
|
||||
"restraints": [1.0, 1.0, 1.0],
|
||||
},
|
||||
{
|
||||
"elec": [1.0, 1.0, 1.0],
|
||||
"vdw": [1.0, 1.0, 1.0],
|
||||
"restraints": [1.0, 0.0, 1.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_monotonic_lambda_windows_B(val, default_settings):
|
||||
errmsg = "The lambda schedule for ligand B"
|
||||
lambda_settings = default_settings.complex_lambda_settings
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
lambda_settings.lambda_elec_B = val["elec"]
|
||||
lambda_settings.lambda_vdw_B = val["vdw"]
|
||||
lambda_settings.lambda_restraints_B = val["restraints"]
|
||||
|
||||
|
||||
def test_output_induces_not_all(default_settings):
|
||||
errmsg = "Equilibration simulations need to output the full system"
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
default_settings.complex_equil_output_settings.output_indices = "no water"
|
||||
|
||||
|
||||
def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings):
|
||||
settings = SepTopProtocol._adaptive_settings(
|
||||
toluene_complex_system, toluene_complex_system, default_settings
|
||||
)
|
||||
|
||||
assert isinstance(settings, SepTopSettings)
|
||||
# Should use default barostat since no ProteinMembraneComponent
|
||||
assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat"
|
||||
|
||||
|
||||
def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands):
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"ligandA": a2a_ligands[0],
|
||||
"protein": a2a_protein_membrane_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
settings = SepTopProtocol._adaptive_settings(stateA, stateA)
|
||||
assert isinstance(settings, SepTopSettings)
|
||||
# Barostat should have been updated
|
||||
assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat"
|
||||
@@ -0,0 +1,295 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import pytest
|
||||
|
||||
import openfe
|
||||
from openfe import ChemicalSystem, SolventComponent
|
||||
from openfe.protocols.openmm_septop import (
|
||||
SepTopProtocol,
|
||||
)
|
||||
from openfe.protocols.openmm_septop.equil_septop_method import (
|
||||
_check_alchemical_charge_difference,
|
||||
)
|
||||
from openfe.protocols.openmm_utils import system_validation
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def default_settings():
|
||||
s = SepTopProtocol.default_settings()
|
||||
return s
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec_A": [1.0, 1.0],
|
||||
"vdw_A": [0.0, 1.0],
|
||||
"restraints_A": [0.0, 0.0],
|
||||
"elec_B": [1.0, 1.0],
|
||||
"vdw_B": [1.0, 1.0],
|
||||
"restraints_B": [0.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nreplicas(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
|
||||
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
|
||||
n_replicas = 3
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
f"Number of replicas {n_replicas} does not equal the"
|
||||
f" number of lambda windows {len(val['vdw_A'])}"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nwindows(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"]
|
||||
n_replicas = 3
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
"Components elec, vdw, and restraints must have equal amount of lambda "
|
||||
"windows. Got 3 and 19 elec lambda windows"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"val",
|
||||
[
|
||||
{
|
||||
"elec_A": [0.0, 1.0],
|
||||
"vdw_A": [1.0, 1.0],
|
||||
"restraints_A": [0.0, 0.0],
|
||||
"elec_B": [1.0, 1.0],
|
||||
"vdw_B": [1.0, 1.0],
|
||||
"restraints_B": [0.0, 0.0],
|
||||
},
|
||||
],
|
||||
)
|
||||
def test_validate_lambda_schedule_nakedcharge(val, default_settings):
|
||||
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
|
||||
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
|
||||
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
|
||||
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
|
||||
n_replicas = 2
|
||||
default_settings.complex_simulation_settings.n_replicas = n_replicas
|
||||
default_settings.solvent_simulation_settings.n_replicas = n_replicas
|
||||
errmsg = (
|
||||
"There are states along this lambda schedule "
|
||||
"where there are atoms with charges but no LJ "
|
||||
"interactions: State A: l"
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.complex_simulation_settings,
|
||||
)
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_lambda_schedule(
|
||||
default_settings.complex_lambda_settings,
|
||||
default_settings.solvent_simulation_settings,
|
||||
)
|
||||
|
||||
|
||||
def test_check_alchem_charge_diff(charged_benzene_modifications):
|
||||
errmsg = "A charge difference of 1"
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
_check_alchemical_charge_difference(
|
||||
charged_benzene_modifications["benzene"],
|
||||
charged_benzene_modifications["benzoic_acid"],
|
||||
)
|
||||
|
||||
|
||||
def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings):
|
||||
protocol = SepTopProtocol(
|
||||
settings=default_settings,
|
||||
)
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"benzene": charged_benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{
|
||||
"benzoic": charged_benzene_modifications["benzoic_acid"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
errmsg = "A charge difference of 1"
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
protocol.create(
|
||||
stateA=stateA,
|
||||
stateB=stateB,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "benzene_system", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "benzene_system"),
|
||||
],
|
||||
)
|
||||
def test_validate_endstates_protcomp(request, system_A, system_B, fail_endstate):
|
||||
with pytest.raises(ValueError, match="No ProteinComponent found"):
|
||||
SepTopProtocol._validate_endstates(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def T4L_benzene_vacuum(benzene_modifications, T4_protein_component):
|
||||
return openfe.ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "T4L_benzene_vacuum", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "T4L_benzene_vacuum"),
|
||||
],
|
||||
)
|
||||
def test_validate_endstates_nosolvcomp(
|
||||
request,
|
||||
system_A,
|
||||
system_B,
|
||||
fail_endstate,
|
||||
):
|
||||
with pytest.raises(ValueError, match="No SolventComponent found"):
|
||||
SepTopProtocol._validate_endstates(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def T4L_system(T4_protein_component):
|
||||
return openfe.ChemicalSystem(
|
||||
{
|
||||
"solvent": openfe.SolventComponent(),
|
||||
"protein": T4_protein_component,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"fail_endstate, system_A, system_B",
|
||||
[
|
||||
("stateA", "T4L_system", "benzene_complex_system"),
|
||||
("stateB", "benzene_complex_system", "T4L_system"),
|
||||
],
|
||||
)
|
||||
def test_validate_alchem_comps_missing(
|
||||
request,
|
||||
system_A,
|
||||
system_B,
|
||||
fail_endstate,
|
||||
):
|
||||
errmsg = (
|
||||
"Only one alchemical species is supported. "
|
||||
f"Number of unique components found in {fail_endstate}"
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=errmsg,
|
||||
):
|
||||
SepTopProtocol._validate_endstates(
|
||||
request.getfixturevalue(system_A),
|
||||
request.getfixturevalue(system_B),
|
||||
)
|
||||
|
||||
|
||||
def test_validate_alchem_comps_toomanyA(
|
||||
benzene_modifications,
|
||||
T4_protein_component,
|
||||
):
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"toluene": benzene_modifications["toluene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{
|
||||
"phenol": benzene_modifications["phenol"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
|
||||
|
||||
assert len(alchem_comps["stateA"]) == 2
|
||||
|
||||
assert len(alchem_comps["stateB"]) == 1
|
||||
|
||||
errmsg = (
|
||||
"Only one alchemical species is supported. Number of unique components found in stateA: 2."
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_endstates(stateA, stateB)
|
||||
|
||||
|
||||
def test_validate_alchem_nonsmc(
|
||||
benzene_modifications,
|
||||
T4_protein_component,
|
||||
):
|
||||
stateA = ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(neutralize=False),
|
||||
}
|
||||
)
|
||||
|
||||
stateB = ChemicalSystem(
|
||||
{
|
||||
"benzene": benzene_modifications["benzene"],
|
||||
"protein": T4_protein_component,
|
||||
"solvent": SolventComponent(),
|
||||
}
|
||||
)
|
||||
|
||||
errmsg = "Only transforming SmallMoleculeComponents are supported by this Protocol."
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
SepTopProtocol._validate_endstates(stateA, stateB)
|
||||
Reference in New Issue
Block a user