diff --git a/news/validate-septop.rst b/news/validate-septop.rst new file mode 100644 index 00000000..45745f38 --- /dev/null +++ b/news/validate-septop.rst @@ -0,0 +1,26 @@ +**Added:** + +* The `validate` method for the SepTopProtocol has been implemented. + This means that settings and system validation can mostly be done prior + to Protocol execuation by calling + `SepTopProtocol.validate(stateA, stateB, mapping=None)`. + +**Changed:** + +* + +**Deprecated:** + +* + +**Removed:** + +* + +**Fixed:** + +* + +**Security:** + +* diff --git a/src/openfe/protocols/openmm_septop/equil_septop_method.py b/src/openfe/protocols/openmm_septop/equil_septop_method.py index bf75076c..828fd9cd 100644 --- a/src/openfe/protocols/openmm_septop/equil_septop_method.py +++ b/src/openfe/protocols/openmm_septop/equil_septop_method.py @@ -33,6 +33,7 @@ from __future__ import annotations import logging import uuid +import warnings from collections import defaultdict from typing import Any, Iterable, Optional, Union @@ -307,14 +308,14 @@ class SepTopProtocol(gufe.Protocol): return protocol_settings @staticmethod - def _validate_complex_endstates( + def _validate_endstates( stateA: ChemicalSystem, stateB: ChemicalSystem, ) -> None: """ - A complex transformation is defined (in terms of gufe components) + A complex relative transformation is defined (in terms of gufe components) as starting from one or more ligands and a protein in solvent and - ending up in a state with one less ligand. + ending up in a state with one ligand that is different. Parameters ---------- @@ -328,76 +329,55 @@ class SepTopProtocol(gufe.Protocol): ValueError If there is no SolventComponent and no ProteinComponent in either stateA or stateB. + If there are no or more than one alchemical components in state A. + If there are no or more than one alchemical components in state B. + If there are any alchemical components that are not SmallMoleculeComponents. + If a change in net charge between the alchemical components is detected. """ # check that there is a protein component - if not any(isinstance(comp, ProteinComponent) for comp in stateA.values()): + if not stateA.contains(ProteinComponent): errmsg = "No ProteinComponent found in stateA" raise ValueError(errmsg) - if not any(isinstance(comp, ProteinComponent) for comp in stateB.values()): + if not stateB.contains(ProteinComponent): errmsg = "No ProteinComponent found in stateB" raise ValueError(errmsg) + # check that there is only one protein component + system_validation.validate_protein(stateA) + system_validation.validate_protein(stateB) + # check that there is a SolventComponent - if not any(isinstance(comp, SolventComponent) for comp in stateA.values()): + if not stateA.contains(SolventComponent): errmsg = "No SolventComponent found in stateA" raise ValueError(errmsg) - if not any(isinstance(comp, SolventComponent) for comp in stateB.values()): + if not stateB.contains(SolventComponent): errmsg = "No SolventComponent found in stateB" raise ValueError(errmsg) - @staticmethod - def _validate_alchemical_components(alchemical_components: dict[str, list[Component]]) -> None: - """ - Checks that the ChemicalSystem alchemical components are correct. + # Check the difference between the endstates + diff = stateA.component_diff(stateB) - Parameters - ---------- - alchemical_components : Dict[str, list[Component]] - Dictionary containing the alchemical components for - stateA and stateB. - - Raises - ------ - ValueError - * If there are no or more than one alchemical components in state A. - * If there are no or more than one alchemical components in state B. - * If there are any alchemical components that are not - SmallMoleculeComponents - * If a change in net charge between the alchemical components is detected. - - Notes - ----- - * Currently doesn't support alchemical components which are not - SmallMoleculeComponents. - * Currently doesn't support more than one alchemical component - being desolvated. - """ - - # Crash out if there are less or more than one alchemical components - # in state A and B - for state in ["stateA", "stateB"]: - n = len(alchemical_components[state]) - if n != 1: - raise ValueError( - "Exactly one alchemical component must be present in " - f"{state}. Found {n} alchemical components." + for i, state in enumerate(["stateA", "stateB"]): + # Error if there isn't exactly one alchemical component + if len(diff[i]) != 1: + errmsg = ( + "Only one alchemical species is supported. " + f"Number of unique components found in {state}: {len(diff[i])}." ) + raise ValueError(errmsg) - # Crash out if any of the alchemical components are not - # SmallMoleculeComponent - for state in ["stateA", "stateB"]: - for comp in alchemical_components[state]: - if not isinstance(comp, SmallMoleculeComponent): - raise ValueError( - "Only SmallMoleculeComponent alchemical species are supported." - ) + # Error if the component isn't an SMC + if not isinstance(diff[i][0], SmallMoleculeComponent): + errmsg = ( + "Only transforming SmallMoleculeComponents are supported " + f"by this Protocol. Found a {type(diff[i][0])}." + ) + raise ValueError(errmsg) - # Raise an error if there is a change in netcharge - _check_alchemical_charge_difference( - alchemical_components["stateA"][0], alchemical_components["stateB"][0] - ) + # Raise an error if there is a change in net charge + _check_alchemical_charge_difference(diff[0][0], diff[1][0]) @staticmethod def _validate_lambda_schedule( @@ -419,8 +399,10 @@ class SepTopProtocol(gufe.Protocol): ValueError If the number of lambda windows differs for electrostatics and sterics. If the number of replicas does not match the number of lambda windows. - Warnings - If there are non-zero values for restraints (lambda_restraints). + + TODO + ---- + Add a warning if all the lambda restraints are zero? Issue #1945. """ lambda_elec_A = lambda_settings.lambda_elec_A @@ -474,32 +456,36 @@ class SepTopProtocol(gufe.Protocol): f"State {state}: lambda {idx}: elec {e} vdW {v}" ) - def _create( + def _validate( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # TODO: extensions - if extends: # pragma: no-cover - raise NotImplementedError("Can't extend simulations yet") + mapping: gufe.ComponentMapping | list[gufe.ComponentMappping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend + if extends: + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped + # around a try/except for that error type + raise ValueError("Can't extend simulations yet") - # Validate components and get alchemical components + # Check the mappping + if mapping is not None: + wmsg = "A mapping was passed but is not used by this Protocol" + warnings.warn(wmsg) + + # Validate end states system_validation.validate_chemical_system(stateA) system_validation.validate_chemical_system(stateB) - self._validate_complex_endstates(stateA, stateB) - alchem_comps = system_validation.get_alchemical_components( - stateA, - stateB, - ) - self._validate_alchemical_components(alchem_comps) + self._validate_endstates(stateA, stateB) # Validate the lambda schedule self._validate_lambda_schedule( self.settings.solvent_lambda_settings, self.settings.solvent_simulation_settings, ) + self._validate_lambda_schedule( self.settings.complex_lambda_settings, self.settings.complex_simulation_settings, @@ -514,15 +500,30 @@ class SepTopProtocol(gufe.Protocol): settings_validation.validate_openmm_solvation_settings( self.settings.solvent_solvation_settings ) - - # Validate protein component - system_validation.validate_protein(stateA) + settings_validation.validate_openmm_solvation_settings( + self.settings.complex_solvation_settings + ) # Validate the barostat used in combination with the protein component system_validation.validate_barostat( stateA, self.settings.complex_integrator_settings.barostat ) + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> list[gufe.ProtocolUnit]: + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # Get the alchemical components + alchem_comps = system_validation.get_alchemical_components( + stateA, + stateB, + ) + # Create list units for complex and solvent transforms def create_setup_units(unit_cls, leg): return [ diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py index 42aaf132..2f559b97 100644 --- a/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_protocol.py @@ -40,11 +40,6 @@ from openfe.protocols.openmm_septop import ( SepTopSolventRunUnit, SepTopSolventSetupUnit, ) -from openfe.protocols.openmm_septop.equil_septop_method import ( - _check_alchemical_charge_difference, -) -from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings -from openfe.protocols.openmm_utils import system_validation from openfe.protocols.openmm_utils.serialization import deserialize from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry from openfe.tests.protocols.conftest import compute_energy @@ -79,191 +74,6 @@ def default_settings(): return s -def test_create_default_settings(): - settings = SepTopProtocol.default_settings() - assert settings - - -@pytest.mark.parametrize( - "val", - [ - {"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]}, - {"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]}, - {"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]}, - ], -) -def test_incorrect_window_settings(val, default_settings): - errmsg = "Lambda windows must be between 0 and 1." - lambda_settings = default_settings.complex_lambda_settings - with pytest.raises(ValueError, match=errmsg): - lambda_settings.lambda_elec_A = val["elec"] - lambda_settings.lambda_vdw_A = val["vdw"] - lambda_settings.lambda_restraints_A = val["restraints"] - - -@pytest.mark.parametrize( - "val", - [ - { - "elec": [0.0, 0.1, 0.0], - "vdw": [0.0, 1.0, 1.0], - "restraints": [0.0, 1.0, 1.0], - }, - { - "elec": [0.0, 0.0, 0.0], - "vdw": [0.0, 1.0, 0.0], - "restraints": [0.0, 1.0, 1.0], - }, - { - "elec": [0.0, 0.0, 0.0], - "vdw": [0.0, 1.0, 1.0], - "restraints": [0.0, 1.0, 0.0], - }, - ], -) -def test_monotonic_lambda_windows_A(val, default_settings): - errmsg = "The lambda schedule for ligand A" - lambda_settings = default_settings.complex_lambda_settings - - with pytest.raises(ValueError, match=errmsg): - lambda_settings.lambda_elec_A = val["elec"] - lambda_settings.lambda_vdw_A = val["vdw"] - lambda_settings.lambda_restraints_A = val["restraints"] - - -@pytest.mark.parametrize( - "val", - [ - { - "elec": [1.0, 0.1, 1.0], - "vdw": [1.0, 1.0, 1.0], - "restraints": [1.0, 1.0, 1.0], - }, - { - "elec": [1.0, 1.0, 1.0], - "vdw": [1.0, 0.0, 1.0], - "restraints": [1.0, 1.0, 1.0], - }, - { - "elec": [1.0, 1.0, 1.0], - "vdw": [1.0, 1.0, 1.0], - "restraints": [1.0, 0.0, 1.0], - }, - ], -) -def test_monotonic_lambda_windows_B(val, default_settings): - errmsg = "The lambda schedule for ligand B" - lambda_settings = default_settings.complex_lambda_settings - - with pytest.raises(ValueError, match=errmsg): - lambda_settings.lambda_elec_B = val["elec"] - lambda_settings.lambda_vdw_B = val["vdw"] - lambda_settings.lambda_restraints_B = val["restraints"] - - -def test_output_induces_not_all(default_settings): - errmsg = "Equilibration simulations need to output the full system" - - with pytest.raises(ValueError, match=errmsg): - default_settings.complex_equil_output_settings.output_indices = "no water" - - -@pytest.mark.parametrize( - "val", - [ - { - "elec_A": [1.0, 1.0], - "vdw_A": [0.0, 1.0], - "restraints_A": [0.0, 0.0], - "elec_B": [1.0, 1.0], - "vdw_B": [1.0, 1.0], - "restraints_B": [0.0, 0.0], - }, - ], -) -def test_validate_lambda_schedule_nreplicas(val, default_settings): - default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"] - default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"] - default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"] - default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"] - default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"] - default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"] - n_replicas = 3 - default_settings.complex_simulation_settings.n_replicas = n_replicas - errmsg = ( - f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(val['vdw_A'])}" - ) - with pytest.raises(ValueError, match=errmsg): - SepTopProtocol._validate_lambda_schedule( - default_settings.complex_lambda_settings, - default_settings.complex_simulation_settings, - ) - - -@pytest.mark.parametrize( - "val", - [ - {"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]}, - ], -) -def test_validate_lambda_schedule_nwindows(val, default_settings): - default_settings.complex_lambda_settings.lambda_elec_A = val["elec"] - default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"] - default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"] - n_replicas = 3 - default_settings.complex_simulation_settings.n_replicas = n_replicas - errmsg = ( - "Components elec, vdw, and restraints must have equal amount of lambda " - "windows. Got 3 and 19 elec lambda windows" - ) - with pytest.raises(ValueError, match=errmsg): - SepTopProtocol._validate_lambda_schedule( - default_settings.complex_lambda_settings, - default_settings.complex_simulation_settings, - ) - - -@pytest.mark.parametrize( - "val", - [ - { - "elec_A": [0.0, 1.0], - "vdw_A": [1.0, 1.0], - "restraints_A": [0.0, 0.0], - "elec_B": [1.0, 1.0], - "vdw_B": [1.0, 1.0], - "restraints_B": [0.0, 0.0], - }, - ], -) -def test_validate_lambda_schedule_nakedcharge(val, default_settings): - default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"] - default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"] - default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"] - default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"] - default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"] - default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"] - n_replicas = 2 - default_settings.complex_simulation_settings.n_replicas = n_replicas - default_settings.solvent_simulation_settings.n_replicas = n_replicas - errmsg = ( - "There are states along this lambda schedule " - "where there are atoms with charges but no LJ " - "interactions: State A: l" - ) - with pytest.raises(ValueError, match=errmsg): - SepTopProtocol._validate_lambda_schedule( - default_settings.complex_lambda_settings, - default_settings.complex_simulation_settings, - ) - with pytest.raises(ValueError, match=errmsg): - SepTopProtocol._validate_lambda_schedule( - default_settings.complex_lambda_settings, - default_settings.solvent_simulation_settings, - ) - - def test_create_default_protocol(default_settings): # this is roughly how it should be created protocol = SepTopProtocol( @@ -315,172 +125,6 @@ def test_create_independent_repeat_ids( assert len(repeat_ids) == 24 -def test_check_alchem_charge_diff(charged_benzene_modifications): - errmsg = "A charge difference of 1" - with pytest.raises(ValueError, match=errmsg): - _check_alchemical_charge_difference( - charged_benzene_modifications["benzene"], - charged_benzene_modifications["benzoic_acid"], - ) - - -def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings): - protocol = SepTopProtocol( - settings=default_settings, - ) - stateA = ChemicalSystem( - { - "benzene": charged_benzene_modifications["benzene"], - "protein": T4_protein_component, - "solvent": SolventComponent(), - } - ) - - stateB = ChemicalSystem( - { - "benzoic": charged_benzene_modifications["benzoic_acid"], - "protein": T4_protein_component, - "solvent": SolventComponent(), - } - ) - errmsg = "A charge difference of 1" - with pytest.raises(ValueError, match=errmsg): - protocol.create( - stateA=stateA, - stateB=stateB, - mapping=None, - ) - - -@pytest.mark.parametrize( - "fail_endstate, system_A, system_B", - [ - ("stateA", "benzene_system", "benzene_complex_system"), - ("stateB", "benzene_complex_system", "benzene_system"), - ], -) -def test_validate_complex_endstates_protcomp(request, system_A, system_B, fail_endstate): - with pytest.raises(ValueError, match="No ProteinComponent found"): - SepTopProtocol._validate_complex_endstates( - request.getfixturevalue(system_A), - request.getfixturevalue(system_B), - ) - - -@pytest.fixture -def T4L_benzene_vacuum(benzene_modifications, T4_protein_component): - return openfe.ChemicalSystem( - { - "benzene": benzene_modifications["benzene"], - "protein": T4_protein_component, - } - ) - - -@pytest.mark.parametrize( - "fail_endstate, system_A, system_B", - [ - ("stateA", "T4L_benzene_vacuum", "benzene_complex_system"), - ("stateB", "benzene_complex_system", "T4L_benzene_vacuum"), - ], -) -def test_validate_complex_endstates_nosolvcomp( - request, - system_A, - system_B, - fail_endstate, -): - with pytest.raises(ValueError, match="No SolventComponent found"): - SepTopProtocol._validate_complex_endstates( - request.getfixturevalue(system_A), - request.getfixturevalue(system_B), - ) - - -@pytest.fixture -def T4L_system(T4_protein_component): - return openfe.ChemicalSystem( - { - "solvent": openfe.SolventComponent(), - "protein": T4_protein_component, - } - ) - - -@pytest.mark.parametrize( - "fail_endstate, system_A, system_B", - [ - ("stateA", "T4L_system", "benzene_complex_system"), - ("stateB", "benzene_complex_system", "T4L_system"), - ], -) -def test_validate_alchem_comps_missing( - request, - system_A, - system_B, - fail_endstate, -): - alchem_comps = system_validation.get_alchemical_components( - request.getfixturevalue(system_A), - request.getfixturevalue(system_B), - ) - - with pytest.raises( - ValueError, - match=f"one alchemical component must be present in {fail_endstate}.", - ): - SepTopProtocol._validate_alchemical_components(alchem_comps) - - -def test_validate_alchem_comps_toomanyA( - benzene_modifications, - T4_protein_component, -): - stateA = ChemicalSystem( - { - "benzene": benzene_modifications["benzene"], - "toluene": benzene_modifications["toluene"], - "protein": T4_protein_component, - "solvent": SolventComponent(), - } - ) - - stateB = ChemicalSystem( - { - "phenol": benzene_modifications["phenol"], - "protein": T4_protein_component, - "solvent": SolventComponent(), - } - ) - - alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - - assert len(alchem_comps["stateA"]) == 2 - - assert len(alchem_comps["stateB"]) == 1 - - with pytest.raises(ValueError, match="present in stateA. Found 2 alchemical components."): - SepTopProtocol._validate_alchemical_components(alchem_comps) - - -def test_validate_alchem_nonsmc( - benzene_modifications, - T4_protein_component, -): - stateA = ChemicalSystem( - {"benzene": benzene_modifications["benzene"], "solvent": SolventComponent()} - ) - - stateB = ChemicalSystem( - {"benzene": benzene_modifications["benzene"], "protein": T4_protein_component} - ) - - alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - - with pytest.raises(ValueError, match="Only SmallMoleculeComponent alchemical"): - SepTopProtocol._validate_alchemical_components(alchem_comps) - - # Tests for the alchemical systems. This tests were modified from # femto (https://github.com/Psivant/femto/tree/main) def compute_interaction_energy( @@ -1816,28 +1460,3 @@ class TestA2AMembraneDryRun: # Check the PDB pdb = md.load_pdb("alchemical_system.pdb") assert pdb.n_atoms == (self.num_ligand_atoms_A + self.num_ligand_atoms_B) - - -def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings): - settings = SepTopProtocol._adaptive_settings( - toluene_complex_system, toluene_complex_system, default_settings - ) - - assert isinstance(settings, SepTopSettings) - # Should use default barostat since no ProteinMembraneComponent - assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat" - - -def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands): - stateA = ChemicalSystem( - { - "ligandA": a2a_ligands[0], - "protein": a2a_protein_membrane_component, - "solvent": SolventComponent(), - } - ) - - settings = SepTopProtocol._adaptive_settings(stateA, stateA) - assert isinstance(settings, SepTopSettings) - # Barostat should have been updated - assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat" diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_settings.py b/src/openfe/tests/protocols/openmm_septop/test_septop_settings.py new file mode 100644 index 00000000..8acd6b76 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_settings.py @@ -0,0 +1,139 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest + +from openfe import ChemicalSystem, SolventComponent +from openfe.protocols.openmm_septop import ( + SepTopProtocol, +) +from openfe.protocols.openmm_septop.equil_septop_settings import SepTopSettings + + +@pytest.fixture() +def protocol_dry_settings(): + # a set of settings for dry run tests + s = SepTopProtocol.default_settings() + s.engine_settings.compute_platform = None + s.protocol_repeats = 1 + return s + + +@pytest.fixture() +def default_settings(): + s = SepTopProtocol.default_settings() + return s + + +def test_create_default_settings(): + settings = SepTopProtocol.default_settings() + assert settings + + +@pytest.mark.parametrize( + "val", + [ + {"elec": [0.0, -1], "vdw": [0.0, 1.0], "restraints": [0.0, 1.0]}, + {"elec": [0.0, 1], "vdw": [0.0, 1.5], "restraints": [0.0, 1.0]}, + {"elec": [0.0, 1], "vdw": [0.0, 1], "restraints": [-0.1, 1.0]}, + ], +) +def test_incorrect_window_settings(val, default_settings): + errmsg = "Lambda windows must be between 0 and 1." + lambda_settings = default_settings.complex_lambda_settings + with pytest.raises(ValueError, match=errmsg): + lambda_settings.lambda_elec_A = val["elec"] + lambda_settings.lambda_vdw_A = val["vdw"] + lambda_settings.lambda_restraints_A = val["restraints"] + + +@pytest.mark.parametrize( + "val", + [ + { + "elec": [0.0, 0.1, 0.0], + "vdw": [0.0, 1.0, 1.0], + "restraints": [0.0, 1.0, 1.0], + }, + { + "elec": [0.0, 0.0, 0.0], + "vdw": [0.0, 1.0, 0.0], + "restraints": [0.0, 1.0, 1.0], + }, + { + "elec": [0.0, 0.0, 0.0], + "vdw": [0.0, 1.0, 1.0], + "restraints": [0.0, 1.0, 0.0], + }, + ], +) +def test_monotonic_lambda_windows_A(val, default_settings): + errmsg = "The lambda schedule for ligand A" + lambda_settings = default_settings.complex_lambda_settings + + with pytest.raises(ValueError, match=errmsg): + lambda_settings.lambda_elec_A = val["elec"] + lambda_settings.lambda_vdw_A = val["vdw"] + lambda_settings.lambda_restraints_A = val["restraints"] + + +@pytest.mark.parametrize( + "val", + [ + { + "elec": [1.0, 0.1, 1.0], + "vdw": [1.0, 1.0, 1.0], + "restraints": [1.0, 1.0, 1.0], + }, + { + "elec": [1.0, 1.0, 1.0], + "vdw": [1.0, 0.0, 1.0], + "restraints": [1.0, 1.0, 1.0], + }, + { + "elec": [1.0, 1.0, 1.0], + "vdw": [1.0, 1.0, 1.0], + "restraints": [1.0, 0.0, 1.0], + }, + ], +) +def test_monotonic_lambda_windows_B(val, default_settings): + errmsg = "The lambda schedule for ligand B" + lambda_settings = default_settings.complex_lambda_settings + + with pytest.raises(ValueError, match=errmsg): + lambda_settings.lambda_elec_B = val["elec"] + lambda_settings.lambda_vdw_B = val["vdw"] + lambda_settings.lambda_restraints_B = val["restraints"] + + +def test_output_induces_not_all(default_settings): + errmsg = "Equilibration simulations need to output the full system" + + with pytest.raises(ValueError, match=errmsg): + default_settings.complex_equil_output_settings.output_indices = "no water" + + +def test_adaptive_settings_no_protein_membrane(toluene_complex_system, default_settings): + settings = SepTopProtocol._adaptive_settings( + toluene_complex_system, toluene_complex_system, default_settings + ) + + assert isinstance(settings, SepTopSettings) + # Should use default barostat since no ProteinMembraneComponent + assert settings.complex_integrator_settings.barostat == "MonteCarloBarostat" + + +def test_adaptive_settings_with_protein_membrane(a2a_protein_membrane_component, a2a_ligands): + stateA = ChemicalSystem( + { + "ligandA": a2a_ligands[0], + "protein": a2a_protein_membrane_component, + "solvent": SolventComponent(), + } + ) + + settings = SepTopProtocol._adaptive_settings(stateA, stateA) + assert isinstance(settings, SepTopSettings) + # Barostat should have been updated + assert settings.complex_integrator_settings.barostat == "MonteCarloMembraneBarostat" diff --git a/src/openfe/tests/protocols/openmm_septop/test_septop_validation.py b/src/openfe/tests/protocols/openmm_septop/test_septop_validation.py new file mode 100644 index 00000000..c6e45ad7 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_septop/test_septop_validation.py @@ -0,0 +1,295 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import pytest + +import openfe +from openfe import ChemicalSystem, SolventComponent +from openfe.protocols.openmm_septop import ( + SepTopProtocol, +) +from openfe.protocols.openmm_septop.equil_septop_method import ( + _check_alchemical_charge_difference, +) +from openfe.protocols.openmm_utils import system_validation + + +@pytest.fixture() +def default_settings(): + s = SepTopProtocol.default_settings() + return s + + +@pytest.mark.parametrize( + "val", + [ + { + "elec_A": [1.0, 1.0], + "vdw_A": [0.0, 1.0], + "restraints_A": [0.0, 0.0], + "elec_B": [1.0, 1.0], + "vdw_B": [1.0, 1.0], + "restraints_B": [0.0, 0.0], + }, + ], +) +def test_validate_lambda_schedule_nreplicas(val, default_settings): + default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"] + default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"] + default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"] + default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"] + default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"] + default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"] + n_replicas = 3 + default_settings.complex_simulation_settings.n_replicas = n_replicas + errmsg = ( + f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(val['vdw_A'])}" + ) + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_lambda_schedule( + default_settings.complex_lambda_settings, + default_settings.complex_simulation_settings, + ) + + +@pytest.mark.parametrize( + "val", + [ + {"elec": [1.0, 1.0, 1.0], "vdw": [0.0, 1.0], "restraints": [0.0, 0.0]}, + ], +) +def test_validate_lambda_schedule_nwindows(val, default_settings): + default_settings.complex_lambda_settings.lambda_elec_A = val["elec"] + default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw"] + default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints"] + n_replicas = 3 + default_settings.complex_simulation_settings.n_replicas = n_replicas + errmsg = ( + "Components elec, vdw, and restraints must have equal amount of lambda " + "windows. Got 3 and 19 elec lambda windows" + ) + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_lambda_schedule( + default_settings.complex_lambda_settings, + default_settings.complex_simulation_settings, + ) + + +@pytest.mark.parametrize( + "val", + [ + { + "elec_A": [0.0, 1.0], + "vdw_A": [1.0, 1.0], + "restraints_A": [0.0, 0.0], + "elec_B": [1.0, 1.0], + "vdw_B": [1.0, 1.0], + "restraints_B": [0.0, 0.0], + }, + ], +) +def test_validate_lambda_schedule_nakedcharge(val, default_settings): + default_settings.complex_lambda_settings.lambda_elec_A = val["elec_A"] + default_settings.complex_lambda_settings.lambda_vdw_A = val["vdw_A"] + default_settings.complex_lambda_settings.lambda_restraints_A = val["restraints_A"] + default_settings.complex_lambda_settings.lambda_elec_B = val["elec_B"] + default_settings.complex_lambda_settings.lambda_vdw_B = val["vdw_B"] + default_settings.complex_lambda_settings.lambda_restraints_B = val["restraints_B"] + n_replicas = 2 + default_settings.complex_simulation_settings.n_replicas = n_replicas + default_settings.solvent_simulation_settings.n_replicas = n_replicas + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + "interactions: State A: l" + ) + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_lambda_schedule( + default_settings.complex_lambda_settings, + default_settings.complex_simulation_settings, + ) + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_lambda_schedule( + default_settings.complex_lambda_settings, + default_settings.solvent_simulation_settings, + ) + + +def test_check_alchem_charge_diff(charged_benzene_modifications): + errmsg = "A charge difference of 1" + with pytest.raises(ValueError, match=errmsg): + _check_alchemical_charge_difference( + charged_benzene_modifications["benzene"], + charged_benzene_modifications["benzoic_acid"], + ) + + +def test_charge_error_create(charged_benzene_modifications, T4_protein_component, default_settings): + protocol = SepTopProtocol( + settings=default_settings, + ) + stateA = ChemicalSystem( + { + "benzene": charged_benzene_modifications["benzene"], + "protein": T4_protein_component, + "solvent": SolventComponent(), + } + ) + + stateB = ChemicalSystem( + { + "benzoic": charged_benzene_modifications["benzoic_acid"], + "protein": T4_protein_component, + "solvent": SolventComponent(), + } + ) + errmsg = "A charge difference of 1" + with pytest.raises(ValueError, match=errmsg): + protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + +@pytest.mark.parametrize( + "fail_endstate, system_A, system_B", + [ + ("stateA", "benzene_system", "benzene_complex_system"), + ("stateB", "benzene_complex_system", "benzene_system"), + ], +) +def test_validate_endstates_protcomp(request, system_A, system_B, fail_endstate): + with pytest.raises(ValueError, match="No ProteinComponent found"): + SepTopProtocol._validate_endstates( + request.getfixturevalue(system_A), + request.getfixturevalue(system_B), + ) + + +@pytest.fixture +def T4L_benzene_vacuum(benzene_modifications, T4_protein_component): + return openfe.ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "protein": T4_protein_component, + } + ) + + +@pytest.mark.parametrize( + "fail_endstate, system_A, system_B", + [ + ("stateA", "T4L_benzene_vacuum", "benzene_complex_system"), + ("stateB", "benzene_complex_system", "T4L_benzene_vacuum"), + ], +) +def test_validate_endstates_nosolvcomp( + request, + system_A, + system_B, + fail_endstate, +): + with pytest.raises(ValueError, match="No SolventComponent found"): + SepTopProtocol._validate_endstates( + request.getfixturevalue(system_A), + request.getfixturevalue(system_B), + ) + + +@pytest.fixture +def T4L_system(T4_protein_component): + return openfe.ChemicalSystem( + { + "solvent": openfe.SolventComponent(), + "protein": T4_protein_component, + } + ) + + +@pytest.mark.parametrize( + "fail_endstate, system_A, system_B", + [ + ("stateA", "T4L_system", "benzene_complex_system"), + ("stateB", "benzene_complex_system", "T4L_system"), + ], +) +def test_validate_alchem_comps_missing( + request, + system_A, + system_B, + fail_endstate, +): + errmsg = ( + "Only one alchemical species is supported. " + f"Number of unique components found in {fail_endstate}" + ) + + with pytest.raises( + ValueError, + match=errmsg, + ): + SepTopProtocol._validate_endstates( + request.getfixturevalue(system_A), + request.getfixturevalue(system_B), + ) + + +def test_validate_alchem_comps_toomanyA( + benzene_modifications, + T4_protein_component, +): + stateA = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "toluene": benzene_modifications["toluene"], + "protein": T4_protein_component, + "solvent": SolventComponent(), + } + ) + + stateB = ChemicalSystem( + { + "phenol": benzene_modifications["phenol"], + "protein": T4_protein_component, + "solvent": SolventComponent(), + } + ) + + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + + assert len(alchem_comps["stateA"]) == 2 + + assert len(alchem_comps["stateB"]) == 1 + + errmsg = ( + "Only one alchemical species is supported. Number of unique components found in stateA: 2." + ) + + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_endstates(stateA, stateB) + + +def test_validate_alchem_nonsmc( + benzene_modifications, + T4_protein_component, +): + stateA = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "protein": T4_protein_component, + "solvent": SolventComponent(neutralize=False), + } + ) + + stateB = ChemicalSystem( + { + "benzene": benzene_modifications["benzene"], + "protein": T4_protein_component, + "solvent": SolventComponent(), + } + ) + + errmsg = "Only transforming SmallMoleculeComponents are supported by this Protocol." + with pytest.raises(ValueError, match=errmsg): + SepTopProtocol._validate_endstates(stateA, stateB)