Merge branch 'main' into remove_perses

This commit is contained in:
Alyssa Travitz
2026-04-17 09:52:54 -05:00
committed by GitHub
5 changed files with 534 additions and 454 deletions

26
news/validate-septop.rst Normal file
View File

@@ -0,0 +1,26 @@
**Added:**
* The `validate` method for the SepTopProtocol has been implemented.
This means that settings and system validation can mostly be done prior
to Protocol execuation by calling
`SepTopProtocol.validate(stateA, stateB, mapping=None)`.
**Changed:**
* <news item>
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View File

@@ -33,6 +33,7 @@ from __future__ import annotations
import logging
import uuid
import warnings
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
@@ -307,14 +308,14 @@ class SepTopProtocol(gufe.Protocol):
return protocol_settings
@staticmethod
def _validate_complex_endstates(
def _validate_endstates(
stateA: ChemicalSystem,
stateB: ChemicalSystem,
) -> None:
"""
A complex transformation is defined (in terms of gufe components)
A complex relative transformation is defined (in terms of gufe components)
as starting from one or more ligands and a protein in solvent and
ending up in a state with one less ligand.
ending up in a state with one ligand that is different.
Parameters
----------
@@ -328,76 +329,55 @@ class SepTopProtocol(gufe.Protocol):
ValueError
If there is no SolventComponent and no ProteinComponent
in either stateA or stateB.
If there are no or more than one alchemical components in state A.
If there are no or more than one alchemical components in state B.
If there are any alchemical components that are not SmallMoleculeComponents.
If a change in net charge between the alchemical components is detected.
"""
# check that there is a protein component
if not any(isinstance(comp, ProteinComponent) for comp in stateA.values()):
if not stateA.contains(ProteinComponent):
errmsg = "No ProteinComponent found in stateA"
raise ValueError(errmsg)
if not any(isinstance(comp, ProteinComponent) for comp in stateB.values()):
if not stateB.contains(ProteinComponent):
errmsg = "No ProteinComponent found in stateB"
raise ValueError(errmsg)
# check that there is only one protein component
system_validation.validate_protein(stateA)
system_validation.validate_protein(stateB)
# check that there is a SolventComponent
if not any(isinstance(comp, SolventComponent) for comp in stateA.values()):
if not stateA.contains(SolventComponent):
errmsg = "No SolventComponent found in stateA"
raise ValueError(errmsg)
if not any(isinstance(comp, SolventComponent) for comp in stateB.values()):
if not stateB.contains(SolventComponent):
errmsg = "No SolventComponent found in stateB"
raise ValueError(errmsg)
@staticmethod
def _validate_alchemical_components(alchemical_components: dict[str, list[Component]]) -> None:
"""
Checks that the ChemicalSystem alchemical components are correct.
# Check the difference between the endstates
diff = stateA.component_diff(stateB)
Parameters
----------
alchemical_components : Dict[str, list[Component]]
Dictionary containing the alchemical components for
stateA and stateB.
Raises
------
ValueError
* If there are no or more than one alchemical components in state A.
* If there are no or more than one alchemical components in state B.
* If there are any alchemical components that are not
SmallMoleculeComponents
* If a change in net charge between the alchemical components is detected.
Notes
-----
* Currently doesn't support alchemical components which are not
SmallMoleculeComponents.
* Currently doesn't support more than one alchemical component
being desolvated.
"""
# Crash out if there are less or more than one alchemical components
# in state A and B
for state in ["stateA", "stateB"]:
n = len(alchemical_components[state])
if n != 1:
raise ValueError(
"Exactly one alchemical component must be present in "
f"{state}. Found {n} alchemical components."
for i, state in enumerate(["stateA", "stateB"]):
# Error if there isn't exactly one alchemical component
if len(diff[i]) != 1:
errmsg = (
"Only one alchemical species is supported. "
f"Number of unique components found in {state}: {len(diff[i])}."
)
raise ValueError(errmsg)
# Crash out if any of the alchemical components are not
# SmallMoleculeComponent
for state in ["stateA", "stateB"]:
for comp in alchemical_components[state]:
if not isinstance(comp, SmallMoleculeComponent):
raise ValueError(
"Only SmallMoleculeComponent alchemical species are supported."
)
# Error if the component isn't an SMC
if not isinstance(diff[i][0], SmallMoleculeComponent):
errmsg = (
"Only transforming SmallMoleculeComponents are supported "
f"by this Protocol. Found a {type(diff[i][0])}."
)
raise ValueError(errmsg)
# Raise an error if there is a change in netcharge
_check_alchemical_charge_difference(
alchemical_components["stateA"][0], alchemical_components["stateB"][0]
)
# Raise an error if there is a change in net charge
_check_alchemical_charge_difference(diff[0][0], diff[1][0])
@staticmethod
def _validate_lambda_schedule(
@@ -419,8 +399,10 @@ class SepTopProtocol(gufe.Protocol):
ValueError
If the number of lambda windows differs for electrostatics and sterics.
If the number of replicas does not match the number of lambda windows.
Warnings
If there are non-zero values for restraints (lambda_restraints).
TODO
----
Add a warning if all the lambda restraints are zero? Issue #1945.
"""
lambda_elec_A = lambda_settings.lambda_elec_A
@@ -474,32 +456,36 @@ class SepTopProtocol(gufe.Protocol):
f"State {state}: lambda {idx}: elec {e} vdW {v}"
)
def _create(
def _validate(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
# TODO: extensions
if extends: # pragma: no-cover
raise NotImplementedError("Can't extend simulations yet")
mapping: gufe.ComponentMapping | list[gufe.ComponentMappping] | None,
extends: gufe.ProtocolDAGResult | None = None,
) -> None:
# Check we're not trying to extend
if extends:
# This technically should be NotImplementedError
# but gufe.Protocol.validate calls `_validate` wrapped
# around a try/except for that error type
raise ValueError("Can't extend simulations yet")
# Validate components and get alchemical components
# Check the mappping
if mapping is not None:
wmsg = "A mapping was passed but is not used by this Protocol"
warnings.warn(wmsg)
# Validate end states
system_validation.validate_chemical_system(stateA)
system_validation.validate_chemical_system(stateB)
self._validate_complex_endstates(stateA, stateB)
alchem_comps = system_validation.get_alchemical_components(
stateA,
stateB,
)
self._validate_alchemical_components(alchem_comps)
self._validate_endstates(stateA, stateB)
# Validate the lambda schedule
self._validate_lambda_schedule(
self.settings.solvent_lambda_settings,
self.settings.solvent_simulation_settings,
)
self._validate_lambda_schedule(
self.settings.complex_lambda_settings,
self.settings.complex_simulation_settings,
@@ -514,15 +500,30 @@ class SepTopProtocol(gufe.Protocol):
settings_validation.validate_openmm_solvation_settings(
self.settings.solvent_solvation_settings
)
# Validate protein component
system_validation.validate_protein(stateA)
settings_validation.validate_openmm_solvation_settings(
self.settings.complex_solvation_settings
)
# Validate the barostat used in combination with the protein component
system_validation.validate_barostat(
stateA, self.settings.complex_integrator_settings.barostat
)
def _create(
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None,
extends: gufe.ProtocolDAGResult | None = None,
) -> list[gufe.ProtocolUnit]:
self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)
# Get the alchemical components
alchem_comps = system_validation.get_alchemical_components(
stateA,
stateB,
)
# Create list units for complex and solvent transforms
def create_setup_units(unit_cls, leg):
return [

View File

@@ -40,11 +40,6 @@ from openfe.protocols.openmm_septop import (
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
from openfe.protocols.openmm_septop.equil_septop_method import (
_check_alchemical_charge_difference,
)
from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings
from openfe.protocols.openmm_utils import system_validation
from openfe.protocols.openmm_utils.serialization import deserialize
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
from openfe.tests.protocols.conftest import compute_energy
@@ -79,191 +74,6 @@ def default_settings():
return s
def test_create_default_settings():
settings = SepTopProtocol.default_settings()
assert settings
@pytest.mark.parametrize(
"val",
[
{"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]},
{"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]},
{"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]},
],
)
def test_incorrect_window_settings(val, default_settings):
errmsg = "Lambda windows must be between 0 and 1."
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_A = val["elec"]
lambda_settings.lambda_vdw_A = val["vdw"]
lambda_settings.lambda_restraints_A = val["restraints"]
@pytest.mark.parametrize(
"val",
[
{
"elec": [0.0, 0.1, 0.0],
"vdw": [0.0, 1.0, 1.0],
"restraints": [0.0, 1.0, 1.0],
},
{
"elec": [0.0, 0.0, 0.0],
"vdw": [0.0, 1.0, 0.0],
"restraints": [0.0, 1.0, 1.0],
},
{
"elec": [0.0, 0.0, 0.0],
"vdw": [0.0, 1.0, 1.0],
"restraints": [0.0, 1.0, 0.0],
},
],
)
def test_monotonic_lambda_windows_A(val, default_settings):
errmsg = "The lambda schedule for ligand A"
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_A = val["elec"]
lambda_settings.lambda_vdw_A = val["vdw"]
lambda_settings.lambda_restraints_A = val["restraints"]
@pytest.mark.parametrize(
"val",
[
{
"elec": [1.0, 0.1, 1.0],
"vdw": [1.0, 1.0, 1.0],
"restraints": [1.0, 1.0, 1.0],
},
{
"elec": [1.0, 1.0, 1.0],
"vdw": [1.0, 0.0, 1.0],
"restraints": [1.0, 1.0, 1.0],
},
{
"elec": [1.0, 1.0, 1.0],
"vdw": [1.0, 1.0, 1.0],
"restraints": [1.0, 0.0, 1.0],
},
],
)
def test_monotonic_lambda_windows_B(val, default_settings):
errmsg = "The lambda schedule for ligand B"
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_B = val["elec"]
lambda_settings.lambda_vdw_B = val["vdw"]
lambda_settings.lambda_restraints_B = val["restraints"]
def test_output_induces_not_all(default_settings):
errmsg = "Equilibration simulations need to output the full system"
with pytest.raises(ValueError, match=errmsg):
default_settings.complex_equil_output_settings.output_indices = "no water"
@pytest.mark.parametrize(
"val",
[
{
"elec_A": [1.0, 1.0],
"vdw_A": [0.0, 1.0],
"restraints_A": [0.0, 0.0],
"elec_B": [1.0, 1.0],
"vdw_B": [1.0, 1.0],
"restraints_B": [0.0, 0.0],
},
],
)
def test_validate_lambda_schedule_nreplicas(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
n_replicas = 3
default_settings.complex_simulation_settings.n_replicas = n_replicas
errmsg = (
f"Number of replicas {n_replicas} does not equal the"
f" number of lambda windows {len(val['vdw_A'])}"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
@pytest.mark.parametrize(
"val",
[
{"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]},
],
)
def test_validate_lambda_schedule_nwindows(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"]
n_replicas = 3
default_settings.complex_simulation_settings.n_replicas = n_replicas
errmsg = (
"Components elec, vdw, and restraints must have equal amount of lambda "
"windows. Got 3 and 19 elec lambda windows"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
@pytest.mark.parametrize(
"val",
[
{
"elec_A": [0.0, 1.0],
"vdw_A": [1.0, 1.0],
"restraints_A": [0.0, 0.0],
"elec_B": [1.0, 1.0],
"vdw_B": [1.0, 1.0],
"restraints_B": [0.0, 0.0],
},
],
)
def test_validate_lambda_schedule_nakedcharge(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
n_replicas = 2
default_settings.complex_simulation_settings.n_replicas = n_replicas
default_settings.solvent_simulation_settings.n_replicas = n_replicas
errmsg = (
"There are states along this lambda schedule "
"where there are atoms with charges but no LJ "
"interactions: State A: l"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.solvent_simulation_settings,
)
def test_create_default_protocol(default_settings):
# this is roughly how it should be created
protocol = SepTopProtocol(
@@ -315,172 +125,6 @@ def test_create_independent_repeat_ids(
assert len(repeat_ids) == 24
def test_check_alchem_charge_diff(charged_benzene_modifications):
errmsg = "A charge difference of 1"
with pytest.raises(ValueError, match=errmsg):
_check_alchemical_charge_difference(
charged_benzene_modifications["benzene"],
charged_benzene_modifications["benzoic_acid"],
)
def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings):
protocol = SepTopProtocol(
settings=default_settings,
)
stateA = ChemicalSystem(
{
"benzene": charged_benzene_modifications["benzene"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
stateB = ChemicalSystem(
{
"benzoic": charged_benzene_modifications["benzoic_acid"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
errmsg = "A charge difference of 1"
with pytest.raises(ValueError, match=errmsg):
protocol.create(
stateA=stateA,
stateB=stateB,
mapping=None,
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "benzene_system", "benzene_complex_system"),
("stateB", "benzene_complex_system", "benzene_system"),
],
)
def test_validate_complex_endstates_protcomp(request, system_A, system_B, fail_endstate):
with pytest.raises(ValueError, match="No ProteinComponent found"):
SepTopProtocol._validate_complex_endstates(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
@pytest.fixture
def T4L_benzene_vacuum(benzene_modifications, T4_protein_component):
return openfe.ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"protein": T4_protein_component,
}
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "T4L_benzene_vacuum", "benzene_complex_system"),
("stateB", "benzene_complex_system", "T4L_benzene_vacuum"),
],
)
def test_validate_complex_endstates_nosolvcomp(
request,
system_A,
system_B,
fail_endstate,
):
with pytest.raises(ValueError, match="No SolventComponent found"):
SepTopProtocol._validate_complex_endstates(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
@pytest.fixture
def T4L_system(T4_protein_component):
return openfe.ChemicalSystem(
{
"solvent": openfe.SolventComponent(),
"protein": T4_protein_component,
}
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "T4L_system", "benzene_complex_system"),
("stateB", "benzene_complex_system", "T4L_system"),
],
)
def test_validate_alchem_comps_missing(
request,
system_A,
system_B,
fail_endstate,
):
alchem_comps = system_validation.get_alchemical_components(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
with pytest.raises(
ValueError,
match=f"one alchemical component must be present in {fail_endstate}.",
):
SepTopProtocol._validate_alchemical_components(alchem_comps)
def test_validate_alchem_comps_toomanyA(
benzene_modifications,
T4_protein_component,
):
stateA = ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"toluene": benzene_modifications["toluene"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
stateB = ChemicalSystem(
{
"phenol": benzene_modifications["phenol"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
assert len(alchem_comps["stateA"]) == 2
assert len(alchem_comps["stateB"]) == 1
with pytest.raises(ValueError, match="present in stateA. Found 2 alchemical components."):
SepTopProtocol._validate_alchemical_components(alchem_comps)
def test_validate_alchem_nonsmc(
benzene_modifications,
T4_protein_component,
):
stateA = ChemicalSystem(
{"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()}
)
stateB = ChemicalSystem(
{"benzene": benzene_modifications["benzene"], "protein": T4_protein_component}
)
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
with pytest.raises(ValueError, match="Only SmallMoleculeComponent alchemical"):
SepTopProtocol._validate_alchemical_components(alchem_comps)
# Tests for the alchemical systems. This tests were modified from
# femto (https://github.com/Psivant/femto/tree/main)
def compute_interaction_energy(
@@ -1816,28 +1460,3 @@ class TestA2AMembraneDryRun:
# Check the PDB
pdb = md.load_pdb("alchemical_system.pdb")
assert pdb.n_atoms == (self.num_ligand_atoms_A + self.num_ligand_atoms_B)
def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings):
settings = SepTopProtocol._adaptive_settings(
toluene_complex_system, toluene_complex_system, default_settings
)
assert isinstance(settings, SepTopSettings)
# Should use default barostat since no ProteinMembraneComponent
assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat"
def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands):
stateA = ChemicalSystem(
{
"ligandA": a2a_ligands[0],
"protein": a2a_protein_membrane_component,
"solvent": SolventComponent(),
}
)
settings = SepTopProtocol._adaptive_settings(stateA, stateA)
assert isinstance(settings, SepTopSettings)
# Barostat should have been updated
assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat"

View File

@@ -0,0 +1,139 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import pytest
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols.openmm_septop import (
SepTopProtocol,
)
from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings
@pytest.fixture()
def protocol_dry_settings():
# a set of settings for dry run tests
s = SepTopProtocol.default_settings()
s.engine_settings.compute_platform = None
s.protocol_repeats = 1
return s
@pytest.fixture()
def default_settings():
s = SepTopProtocol.default_settings()
return s
def test_create_default_settings():
settings = SepTopProtocol.default_settings()
assert settings
@pytest.mark.parametrize(
"val",
[
{"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]},
{"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]},
{"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]},
],
)
def test_incorrect_window_settings(val, default_settings):
errmsg = "Lambda windows must be between 0 and 1."
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_A = val["elec"]
lambda_settings.lambda_vdw_A = val["vdw"]
lambda_settings.lambda_restraints_A = val["restraints"]
@pytest.mark.parametrize(
"val",
[
{
"elec": [0.0, 0.1, 0.0],
"vdw": [0.0, 1.0, 1.0],
"restraints": [0.0, 1.0, 1.0],
},
{
"elec": [0.0, 0.0, 0.0],
"vdw": [0.0, 1.0, 0.0],
"restraints": [0.0, 1.0, 1.0],
},
{
"elec": [0.0, 0.0, 0.0],
"vdw": [0.0, 1.0, 1.0],
"restraints": [0.0, 1.0, 0.0],
},
],
)
def test_monotonic_lambda_windows_A(val, default_settings):
errmsg = "The lambda schedule for ligand A"
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_A = val["elec"]
lambda_settings.lambda_vdw_A = val["vdw"]
lambda_settings.lambda_restraints_A = val["restraints"]
@pytest.mark.parametrize(
"val",
[
{
"elec": [1.0, 0.1, 1.0],
"vdw": [1.0, 1.0, 1.0],
"restraints": [1.0, 1.0, 1.0],
},
{
"elec": [1.0, 1.0, 1.0],
"vdw": [1.0, 0.0, 1.0],
"restraints": [1.0, 1.0, 1.0],
},
{
"elec": [1.0, 1.0, 1.0],
"vdw": [1.0, 1.0, 1.0],
"restraints": [1.0, 0.0, 1.0],
},
],
)
def test_monotonic_lambda_windows_B(val, default_settings):
errmsg = "The lambda schedule for ligand B"
lambda_settings = default_settings.complex_lambda_settings
with pytest.raises(ValueError, match=errmsg):
lambda_settings.lambda_elec_B = val["elec"]
lambda_settings.lambda_vdw_B = val["vdw"]
lambda_settings.lambda_restraints_B = val["restraints"]
def test_output_induces_not_all(default_settings):
errmsg = "Equilibration simulations need to output the full system"
with pytest.raises(ValueError, match=errmsg):
default_settings.complex_equil_output_settings.output_indices = "no water"
def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings):
settings = SepTopProtocol._adaptive_settings(
toluene_complex_system, toluene_complex_system, default_settings
)
assert isinstance(settings, SepTopSettings)
# Should use default barostat since no ProteinMembraneComponent
assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat"
def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands):
stateA = ChemicalSystem(
{
"ligandA": a2a_ligands[0],
"protein": a2a_protein_membrane_component,
"solvent": SolventComponent(),
}
)
settings = SepTopProtocol._adaptive_settings(stateA, stateA)
assert isinstance(settings, SepTopSettings)
# Barostat should have been updated
assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat"

View File

@@ -0,0 +1,295 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import pytest
import openfe
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols.openmm_septop import (
SepTopProtocol,
)
from openfe.protocols.openmm_septop.equil_septop_method import (
_check_alchemical_charge_difference,
)
from openfe.protocols.openmm_utils import system_validation
@pytest.fixture()
def default_settings():
s = SepTopProtocol.default_settings()
return s
@pytest.mark.parametrize(
"val",
[
{
"elec_A": [1.0, 1.0],
"vdw_A": [0.0, 1.0],
"restraints_A": [0.0, 0.0],
"elec_B": [1.0, 1.0],
"vdw_B": [1.0, 1.0],
"restraints_B": [0.0, 0.0],
},
],
)
def test_validate_lambda_schedule_nreplicas(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
n_replicas = 3
default_settings.complex_simulation_settings.n_replicas = n_replicas
errmsg = (
f"Number of replicas {n_replicas} does not equal the"
f" number of lambda windows {len(val['vdw_A'])}"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
@pytest.mark.parametrize(
"val",
[
{"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]},
],
)
def test_validate_lambda_schedule_nwindows(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"]
n_replicas = 3
default_settings.complex_simulation_settings.n_replicas = n_replicas
errmsg = (
"Components elec, vdw, and restraints must have equal amount of lambda "
"windows. Got 3 and 19 elec lambda windows"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
@pytest.mark.parametrize(
"val",
[
{
"elec_A": [0.0, 1.0],
"vdw_A": [1.0, 1.0],
"restraints_A": [0.0, 0.0],
"elec_B": [1.0, 1.0],
"vdw_B": [1.0, 1.0],
"restraints_B": [0.0, 0.0],
},
],
)
def test_validate_lambda_schedule_nakedcharge(val, default_settings):
default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"]
default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"]
default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"]
default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"]
default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"]
default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"]
n_replicas = 2
default_settings.complex_simulation_settings.n_replicas = n_replicas
default_settings.solvent_simulation_settings.n_replicas = n_replicas
errmsg = (
"There are states along this lambda schedule "
"where there are atoms with charges but no LJ "
"interactions: State A: l"
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.complex_simulation_settings,
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_lambda_schedule(
default_settings.complex_lambda_settings,
default_settings.solvent_simulation_settings,
)
def test_check_alchem_charge_diff(charged_benzene_modifications):
errmsg = "A charge difference of 1"
with pytest.raises(ValueError, match=errmsg):
_check_alchemical_charge_difference(
charged_benzene_modifications["benzene"],
charged_benzene_modifications["benzoic_acid"],
)
def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings):
protocol = SepTopProtocol(
settings=default_settings,
)
stateA = ChemicalSystem(
{
"benzene": charged_benzene_modifications["benzene"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
stateB = ChemicalSystem(
{
"benzoic": charged_benzene_modifications["benzoic_acid"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
errmsg = "A charge difference of 1"
with pytest.raises(ValueError, match=errmsg):
protocol.create(
stateA=stateA,
stateB=stateB,
mapping=None,
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "benzene_system", "benzene_complex_system"),
("stateB", "benzene_complex_system", "benzene_system"),
],
)
def test_validate_endstates_protcomp(request, system_A, system_B, fail_endstate):
with pytest.raises(ValueError, match="No ProteinComponent found"):
SepTopProtocol._validate_endstates(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
@pytest.fixture
def T4L_benzene_vacuum(benzene_modifications, T4_protein_component):
return openfe.ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"protein": T4_protein_component,
}
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "T4L_benzene_vacuum", "benzene_complex_system"),
("stateB", "benzene_complex_system", "T4L_benzene_vacuum"),
],
)
def test_validate_endstates_nosolvcomp(
request,
system_A,
system_B,
fail_endstate,
):
with pytest.raises(ValueError, match="No SolventComponent found"):
SepTopProtocol._validate_endstates(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
@pytest.fixture
def T4L_system(T4_protein_component):
return openfe.ChemicalSystem(
{
"solvent": openfe.SolventComponent(),
"protein": T4_protein_component,
}
)
@pytest.mark.parametrize(
"fail_endstate, system_A, system_B",
[
("stateA", "T4L_system", "benzene_complex_system"),
("stateB", "benzene_complex_system", "T4L_system"),
],
)
def test_validate_alchem_comps_missing(
request,
system_A,
system_B,
fail_endstate,
):
errmsg = (
"Only one alchemical species is supported. "
f"Number of unique components found in {fail_endstate}"
)
with pytest.raises(
ValueError,
match=errmsg,
):
SepTopProtocol._validate_endstates(
request.getfixturevalue(system_A),
request.getfixturevalue(system_B),
)
def test_validate_alchem_comps_toomanyA(
benzene_modifications,
T4_protein_component,
):
stateA = ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"toluene": benzene_modifications["toluene"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
stateB = ChemicalSystem(
{
"phenol": benzene_modifications["phenol"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
assert len(alchem_comps["stateA"]) == 2
assert len(alchem_comps["stateB"]) == 1
errmsg = (
"Only one alchemical species is supported. Number of unique components found in stateA: 2."
)
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_endstates(stateA, stateB)
def test_validate_alchem_nonsmc(
benzene_modifications,
T4_protein_component,
):
stateA = ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"protein": T4_protein_component,
"solvent": SolventComponent(neutralize=False),
}
)
stateB = ChemicalSystem(
{
"benzene": benzene_modifications["benzene"],
"protein": T4_protein_component,
"solvent": SolventComponent(),
}
)
errmsg = "Only transforming SmallMoleculeComponents are supported by this Protocol."
with pytest.raises(ValueError, match=errmsg):
SepTopProtocol._validate_endstates(stateA, stateB)