Plain MD restart support (#1884)

* split md protocol to setup and simulate, add restart support

* allow for resume in any stage

* add a single run dynamics function

* update restart to only look for checkpoints, split out remaining step logic, update tests

---------

Co-authored-by: Irfan Alibay <IAlibay@users.noreply.github.com>
This commit is contained in:
Josh Horton
2026-04-27 16:04:28 +01:00
committed by GitHub
parent be6cd88414
commit 909d29ad25
13 changed files with 1064 additions and 389 deletions

View File

@@ -16,7 +16,8 @@ Protocol API Specification
:toctree: generated/
PlainMDProtocol
PlainMDProtocolUnit
PlainMDSetupUnit
PlainMDSimulationUnit
PlainMDProtocolResult

24
news/resume-plainmd.rst Normal file
View File

@@ -0,0 +1,24 @@
**Added:**
* * Added API support for resuming the PlainMDProtocol.
PR #1884.
**Changed:**
* <news item>
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View File

@@ -22,10 +22,16 @@ zenodo_resume_data = dict(
fname="multistate_checkpoints.zip",
known_hash="md5:a6bdceff0c4a2f200538edb17c21d443",
)
zenodo_md_resume_data = dict(
base_url="doi:10.5281/zenodo.19694944",
fname="checkpoint.xml",
known_hash="md5:0f3957c263b5def8de727c5c419b31b5",
)
zenodo_data_registry = [
zenodo_rfe_simulation_nc,
zenodo_t4_lysozyme_traj,
zenodo_industry_benchmark_systems,
zenodo_resume_data,
zenodo_md_resume_data,
]

View File

@@ -65,7 +65,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
OpenFFPartialChargeSettings,
ThermoSettings,
)
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
from openfe.protocols.openmm_md.plain_md_methods import PlainMDSimulationUnit
from openfe.protocols.openmm_utils import (
charge_generation,
multistate_analysis,
@@ -320,9 +320,9 @@ class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin):
box = system.getDefaultPeriodicBoxVectors()
return positions, to_openmm(from_openmm(box))
# Use the _run_MD method from the PlainMDProtocolUnit
# Use the _run_MD method from the PlainMDSimulationUnit
# Should in-place modify the simulation
PlainMDProtocolUnit._run_MD(
PlainMDSimulationUnit._run_MD(
simulation=simulation,
positions=positions,
simulation_settings=settings["equil_simulation_settings"],

View File

@@ -9,12 +9,14 @@ from .plain_md_methods import (
PlainMDProtocol,
PlainMDProtocolResult,
PlainMDProtocolSettings,
PlainMDProtocolUnit,
PlainMDSetupUnit,
PlainMDSimulationUnit,
)
__all__ = [
"PlainMDProtocol",
"PlainMDProtocolSettings",
"PlainMDProtocolResult",
"PlainMDProtocolUnit",
"PlainMDSetupUnit",
"PlainMDSimulationUnit",
]

File diff suppressed because it is too large Load Diff

View File

@@ -56,7 +56,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
OpenMMSystemGeneratorFFSettings,
ThermoSettings,
)
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
from openfe.protocols.openmm_md.plain_md_methods import PlainMDSimulationUnit
from openfe.protocols.openmm_utils import omm_compute
from openfe.protocols.openmm_utils.omm_settings import SettingsBaseModel
from openfe.protocols.openmm_utils.serialization import deserialize
@@ -202,9 +202,9 @@ def _pre_equilibrate(
errmsg = f"Only 'A', 'B', and 'AB' are accepted as endstates. Got {endstate}"
raise ValueError(errmsg)
# Use the _run_MD method from the PlainMDProtocolUnit
# Use the _run_MD method from the PlainMDSimulationUnit
# Should in-place modify the simulation
PlainMDProtocolUnit._run_MD(
PlainMDSimulationUnit._run_MD(
simulation=simulation,
positions=positions,
simulation_settings=settings["equil_simulation_settings"],

View File

@@ -237,4 +237,4 @@ styrene
10 11 4 0 0 0 0
11 16 1 0 0 0 0
M END
$$$$
$$$$

View File

@@ -20,6 +20,7 @@ import openfe
from openfe.data._registry import (
POOCH_CACHE,
zenodo_industry_benchmark_systems,
zenodo_md_resume_data,
zenodo_resume_data,
zenodo_rfe_simulation_nc,
zenodo_t4_lysozyme_traj,
@@ -406,6 +407,19 @@ def septop_solv_checkpoint_path():
return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
pooch_md_resume_data = pooch.create(
path=POOCH_CACHE,
base_url=zenodo_md_resume_data["base_url"],
registry={zenodo_md_resume_data["fname"]: zenodo_md_resume_data["known_hash"]},
)
@pytest.fixture(scope="module")
def plain_md_checkpoint_path():
pooch_md_resume_data.fetch("checkpoint.xml")
return pathlib.Path(pooch.os_cache("openfe") / "checkpoint.xml")
@pytest.fixture(scope="session")
def available_platforms() -> set[str]:
return {

View File

@@ -7,8 +7,10 @@ import sys
from unittest import mock
import gufe
import numpy as np
import openmm
import pytest
from gufe import ChemicalSystem, SmallMoleculeComponent
from gufe import ChemicalSystem, LigandAtomMapping, SmallMoleculeComponent
from numpy.testing import assert_allclose
from openff.units import unit
from openff.units.openmm import from_openmm, to_openmm
@@ -22,8 +24,10 @@ from openfe.protocols import openmm_md
from openfe.protocols.openmm_md.plain_md_methods import (
PlainMDProtocol,
PlainMDProtocolResult,
PlainMDProtocolUnit,
PlainMDSetupUnit,
PlainMDSimulationUnit,
)
from openfe.protocols.openmm_utils import serialization
from openfe.protocols.openmm_utils.charge_generation import (
HAS_ESPALOMA_CHARGE,
HAS_NAGL,
@@ -40,6 +44,25 @@ def vac_settings():
return settings
@pytest.mark.parametrize(
"inputs, expected",
[
# inputs are current step count, nvt steps, npt steps and prod steps
# outputs are steps to run for nvt, npt, prod and if the production phase has started
pytest.param([50, 100, 100, 100], [50, 100, 100, False], id="nvt resuming"),
pytest.param([101, 100, 100, 100], [0, 99, 100, False], id="npt resuming"),
pytest.param([220, 100, 100, 100], [0, 0, 80, True], id="prod resuming"),
pytest.param([200, 100, 100, 100], [0, 0, 100, False], id="prod resuming not started"),
],
)
def test_get_remaining_steps(inputs, expected):
nvt, npt, prod, is_prod = PlainMDSimulationUnit._get_remaining_steps(*inputs)
assert nvt == expected[0]
assert npt == expected[1]
assert prod == expected[2]
assert is_prod == expected[3]
def test_create_default_settings():
settings = PlainMDProtocol.default_settings()
@@ -92,7 +115,7 @@ def test_create_independent_repeat_ids(benzene_system):
)
repeat_ids = set()
u: PlainMDProtocolUnit
u: PlainMDSetupUnit | PlainMDSimulationUnit
for u in dag1.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
for u in dag2.protocol_units:
@@ -104,14 +127,14 @@ def test_create_independent_repeat_ids(benzene_system):
def test_dry_run_default_vacuum(benzene_vacuum_system, vac_settings, tmp_path):
protocol = PlainMDProtocol(settings=vac_settings)
# create DAG from protocol and take first (and only) work unit from within
# create DAG from protocol and take the setup unit
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=None,
)
dag_unit = list(dag.protocol_units)[0]
result = dag_unit.run(
setup_unit = list(dag.protocol_units)[0]
result = setup_unit.run(
dry=True, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)
system = result["debug"]["system"]
@@ -138,22 +161,46 @@ def test_dry_run_logger_output(benzene_vacuum_system, vac_settings, tmp_path, ca
settings=vac_settings,
)
# create DAG from protocol and take first (and only) work unit from within
# create DAG from protocol
dag = protocol.create(
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=None,
)
dag_unit = list(dag.protocol_units)[0]
setup_unit = list(dag.protocol_units)[0]
caplog.set_level(logging.INFO)
dag_unit.run(dry=False, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
setup_results = setup_unit.run(
dry=False, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)
messages = [r.message for r in caplog.records]
assert "minimizing systems" in messages
assert "Running NVT equilibration" in messages
assert "Running NPT equilibration" in messages
assert "running production phase" in messages
assert "Creating system" in messages
# now run the production unit after extracting outputs from the setup unit
system = serialization.deserialize(setup_results["system"])
positions = np.load(setup_results["positions"]) * omm_unit.nanometers
topology = openmm.app.PDBFile(str(setup_results["system_pdb"])).getTopology()
equil_steps_nvt = setup_results["equil_steps_nvt"]
equil_steps_npt = setup_results["equil_steps_npt"]
prod_steps = setup_results["prod_steps"]
prod_unit = list(dag.protocol_units)[1]
prod_unit.run(
system=system,
positions=positions,
topology=topology,
equil_steps_nvt=equil_steps_nvt,
equil_steps_npt=equil_steps_npt,
prod_steps=prod_steps,
dry=False,
verbose=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)
messages = [r.message for r in caplog.records]
assert "Minimizing systems" in messages
assert "Running NVT equilibration for 250 steps" in messages
assert "Running NPT equilibration for 250 steps" in messages
assert "Running production phase for 250 steps" in messages
def test_dry_run_ffcache_none_vacuum(benzene_vacuum_system, vac_settings, tmp_path):
@@ -441,17 +488,14 @@ def test_hightimestep(benzene_vacuum_system, tmp_path):
settings.forcefield_settings.nonbonded_method = "nocutoff"
p = PlainMDProtocol(settings=settings)
dag = p.create(
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=None,
)
dag_unit = list(dag.protocol_units)[0]
errmsg = "too large for hydrogen mass"
# make sure this is triggered in validate
with pytest.raises(ValueError, match=errmsg):
dag_unit.run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
_ = p.create(
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=None,
)
def test_vaccuum_PME_error(benzene_vacuum_system):
@@ -459,6 +503,7 @@ def test_vaccuum_PME_error(benzene_vacuum_system):
settings=PlainMDProtocol.default_settings(),
)
errmsg = "PME cannot be used for vacuum transform"
# make sure this is triggered in validate
with pytest.raises(ValueError, match=errmsg):
_ = p.create(
stateA=benzene_vacuum_system,
@@ -486,6 +531,35 @@ def test_multiple_basesolvents_error(a2a_protein_membrane_component):
)
def test_states_not_matching_error(benzene_vacuum_system, toluene_vacuum_system):
p = PlainMDProtocol(settings=PlainMDProtocol.default_settings())
errmsg = "The two end states do not match."
with pytest.raises(ValueError, match=errmsg):
_ = p.create(
stateA=benzene_vacuum_system,
stateB=toluene_vacuum_system,
mapping=None,
)
def test_mapping_warning(benzene_vacuum_system, tmp_path):
settings = PlainMDProtocol.default_settings()
settings.forcefield_settings.nonbonded_method = "nocutoff"
p = PlainMDProtocol(settings=settings)
warnmsg = "A mapping was passed but is not used by this Protocol."
benzene = benzene_vacuum_system.components["ligand"]
with pytest.warns(match=warnmsg):
_ = p.create(
stateA=benzene_vacuum_system,
stateB=benzene_vacuum_system,
mapping=LigandAtomMapping(
componentA=benzene,
componentB=benzene,
componentA_to_componentB=dict((i, i) for i in range(12)),
),
)
@pytest.fixture
def solvent_protocol_dag(benzene_system):
settings = PlainMDProtocol.default_settings()
@@ -499,11 +573,20 @@ def solvent_protocol_dag(benzene_system):
)
def test_unit_tagging(solvent_protocol_dag, tmp_path):
def test_unit_tagging(benzene_system, tmp_path):
# test that executing the Units includes correct generation and repeat info
dag_units = solvent_protocol_dag.protocol_units
settings = PlainMDProtocol.default_settings()
settings.protocol_repeats = 3
protocol = PlainMDProtocol(settings=settings)
dag = protocol.create(
stateA=benzene_system,
stateB=benzene_system,
mapping=None,
)
dag_units = dag.protocol_units
with mock.patch(
"openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run",
"openfe.protocols.openmm_md.plain_md_methods.PlainMDSimulationUnit.run",
return_value={
"nc": "simulation.xtc",
"last_checkpoint": "checkpoint.chk",
@@ -511,8 +594,10 @@ def test_unit_tagging(solvent_protocol_dag, tmp_path):
):
results = []
for u in dag_units:
ret = u.execute(context=gufe.Context(tmp_path, tmp_path))
results.append(ret)
# just execute the setup unit so we don't have to pass the results though to the simulation unit
if isinstance(u, PlainMDSetupUnit):
ret = u.execute(context=gufe.Context(tmp_path, tmp_path))
results.append(ret)
repeats = set()
for ret in results:
@@ -526,7 +611,7 @@ def test_unit_tagging(solvent_protocol_dag, tmp_path):
def test_gather(solvent_protocol_dag, tmp_path):
# check .gather behaves as expected
with mock.patch(
"openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run",
"openfe.protocols.openmm_md.plain_md_methods.PlainMDSimulationUnit.run",
return_value={
"nc": "simulation.xtc",
"last_checkpoint": "checkpoint.chk",

View File

@@ -0,0 +1,167 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import os
import pathlib
import shutil
import gufe
import openmm
import openmm.unit as openmm_unit
import pytest
from gufe import ChemicalSystem, SmallMoleculeComponent
from gufe.protocols.errors import ProtocolUnitExecutionError
from openff.units import unit
import openfe
from openfe.data._registry import POOCH_CACHE
from openfe.protocols.openmm_md.plain_md_methods import (
PlainMDProtocol,
PlainMDSetupUnit,
PlainMDSimulationUnit,
)
from ...conftest import HAS_INTERNET
@pytest.fixture()
def vacuum_protocol_settings():
# setup a cheap vacuum md protocol
settings = PlainMDProtocol.default_settings()
settings.protocol_repeats = 1
settings.forcefield_settings.nonbonded_method = "nocutoff"
settings.engine_settings.compute_platform = None
settings.simulation_settings.equilibration_length_nvt = 1 * unit.picoseconds
settings.simulation_settings.equilibration_length = 1 * unit.picoseconds
settings.simulation_settings.production_length = 1 * unit.picoseconds
settings.output_settings.checkpoint_interval = 0.5 * unit.picoseconds
settings.output_settings.trajectory_write_interval = 0.5 * unit.picoseconds
return settings
def test_verify_execution_environment():
# verify using the current versions of the software
PlainMDSimulationUnit._verify_execution_environment(
setup_outputs={
"gufe_version": gufe.__version__,
"openfe_version": openfe.__version__,
"openmm_version": openmm.__version__,
}
)
def test_verify_execution_environment_fail():
# pass in different versions to force failure
with pytest.raises(ProtocolUnitExecutionError, match="Python environment"):
PlainMDSimulationUnit._verify_execution_environment(
setup_outputs={
"gufe_version": 0.1,
"openfe_version": openmm.__version__,
"openmm_version": openmm.__version__,
}
)
def test_verify_execution_env_missing_key():
errmsg = "Missing environment information from setup outputs."
with pytest.raises(ProtocolUnitExecutionError, match=errmsg):
PlainMDSimulationUnit._verify_execution_environment(
setup_outputs={
"foo_version": 0.1,
"openfe_version": openfe.__version__,
"openmm_version": openmm.__version__,
},
)
@pytest.mark.skipif(
not os.path.exists(POOCH_CACHE) and not HAS_INTERNET,
reason="Internet unavailable and test data is not cached locally",
)
def test_check_restart(vacuum_protocol_settings, plain_md_checkpoint_path):
# test we can correctly detect when we should be restarting
assert PlainMDSimulationUnit._check_restart(
output_settings=vacuum_protocol_settings.output_settings,
shared_path=plain_md_checkpoint_path.parent,
)
# make sure it does not try and restart if inputs are missing
assert not PlainMDSimulationUnit._check_restart(
output_settings=vacuum_protocol_settings.output_settings,
shared_path=pathlib.Path("."),
)
@pytest.mark.skipif(
not os.path.exists(POOCH_CACHE) and not HAS_INTERNET,
reason="Internet unavailable and test data is not cached locally",
)
class TestPlainMDResume:
@pytest.fixture
def protocol_dag(self, vacuum_protocol_settings, benzene_vacuum_system):
protocol = PlainMDProtocol(
settings=vacuum_protocol_settings,
)
return protocol.create(
stateA=benzene_vacuum_system, stateB=benzene_vacuum_system, mapping=None
)
def test_resume(
self, protocol_dag, tmp_path, caplog, vacuum_protocol_settings, plain_md_checkpoint_path
):
# test that we can resume a simulation from a checkpoint
protocol_units = list(protocol_dag.protocol_units)
setup_unit: PlainMDSetupUnit = protocol_units[0]
simulation_unit: PlainMDSimulationUnit = protocol_units[1]
# copy the files over
shutil.copyfile(plain_md_checkpoint_path, tmp_path / "checkpoint.xml")
# dry run the setup unit
setup_results = setup_unit.run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)
# make sure the protocol thinks it can restart
assert PlainMDSimulationUnit._check_restart(
output_settings=vacuum_protocol_settings.output_settings,
shared_path=tmp_path,
)
# now run the simulation unit in resume mode this should be 0.5 ps of equilibration and 1 ps of production
sim_results = simulation_unit.run(
system=setup_results["debug"]["system"],
positions=setup_results["debug"]["positions"],
topology=setup_results["debug"]["topology"],
equil_steps_nvt=setup_results["equil_steps_nvt"],
equil_steps_npt=setup_results["equil_steps_npt"],
prod_steps=setup_results["prod_steps"],
verbose=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)
# make sure it prints that its restarting
assert "Restarting simulation from checkpoint state" in caplog.text
# check the number of npt steps to run is correct, this should be 0.5 ps at 4fs timestep
assert "Running NPT equilibration for 125 steps" in caplog.text
# make sure the production phase steps are correct, this should be the full 1ps at 4fs timestep
assert "Running production phase for 250 steps" in caplog.text
# check the outputs of the simulation unit
assert sim_results["system_pdb"].exists()
assert sim_results["nc"].exists()
assert sim_results["last_checkpoint"]
# load the final checkpoint and check the simulation time is correct, this should be 3 ps
# also check the total step count
simulation = openmm.app.Simulation(
setup_results["debug"]["topology"],
setup_results["debug"]["system"],
openmm.LangevinMiddleIntegrator(
298.15 * openmm_unit.kelvin,
1.0 / openmm_unit.picosecond,
4 * openmm_unit.femtoseconds,
),
)
simulation.context.setPositions(setup_results["debug"]["positions"])
simulation.loadState(str(sim_results["last_checkpoint"]))
total_sim_time = simulation.context.getTime()
# check the time is 3 ps
assert total_sim_time.value_in_unit(openmm_unit.picoseconds) == pytest.approx(3)
# check the step count has been extended
assert simulation.context.getStepCount() == 750

View File

@@ -25,7 +25,7 @@ def test_vacuum_sim(
settings.simulation_settings.equilibration_length_nvt = None
settings.simulation_settings.equilibration_length = 10 * unit.picosecond
settings.simulation_settings.production_length = 20 * unit.picosecond
settings.output_settings.checkpoint_interval = 40 * unit.picosecond
settings.output_settings.checkpoint_interval = 5 * unit.picosecond
settings.forcefield_settings.nonbonded_method = "nocutoff"
settings.engine_settings.compute_platform = platform
@@ -46,9 +46,9 @@ def test_vacuum_sim(
assert r.ok()
assert len(r.protocol_unit_results) == 1
assert len(r.protocol_unit_results) == 2
pur = r.protocol_unit_results[0]
pur = r.protocol_unit_results[1]
unit_shared = tmp_path / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
@@ -59,20 +59,19 @@ def test_vacuum_sim(
"minimized.pdb",
"simulation.xtc",
"simulation.log",
"system.pdb",
"checkpoint.xml",
]
for file in files:
assert (unit_shared / file).exists()
# NVT PDB should not exist
assert not (unit_shared / "equil_nvt.pdb").exists()
assert not (unit_shared / "checkpoint.chk").exists()
# check that the output file paths are correct
assert pur.outputs["system_pdb"] == unit_shared / "system.pdb"
assert pur.outputs["minimized_pdb"] == unit_shared / "minimized.pdb"
assert pur.outputs["nc"] == unit_shared / "simulation.xtc"
assert pur.outputs["last_checkpoint"] is None
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.xml"
assert pur.outputs["npt_equil_pdb"] == unit_shared / "equil_npt.pdb"
assert pur.outputs["nvt_equil_pdb"] is None
@@ -113,22 +112,21 @@ def test_complex_solvent_sim_gpu(
assert r.ok()
assert len(r.protocol_unit_results) == 1
assert len(r.protocol_unit_results) == 2
pur = r.protocol_unit_results[0]
pur = r.protocol_unit_results[1]
unit_shared = tmp_path / f"shared_{pur.source_key}_attempt_0"
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
# check the files
files = [
"checkpoint.chk",
"checkpoint.xml",
"equil_nvt.pdb",
"equil_npt.pdb",
"minimized.pdb",
"simulation.xtc",
"simulation.log",
"system.pdb",
]
for file in files:
assert (unit_shared / file).exists()
@@ -137,6 +135,6 @@ def test_complex_solvent_sim_gpu(
assert pur.outputs["system_pdb"] == unit_shared / "system.pdb"
assert pur.outputs["minimized_pdb"] == unit_shared / "minimized.pdb"
assert pur.outputs["nc"] == unit_shared / "simulation.xtc"
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.chk"
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.xml"
assert pur.outputs["nvt_equil_pdb"] == unit_shared / "equil_nvt.pdb"
assert pur.outputs["npt_equil_pdb"] == unit_shared / "equil_npt.pdb"

View File

@@ -15,13 +15,27 @@ def protocol():
@pytest.fixture
def protocol_unit(protocol, benzene_system):
def protocol_units(protocol, benzene_system):
pus = protocol.create(
stateA=benzene_system,
stateB=benzene_system,
mapping=None,
)
return list(pus.protocol_units)[0]
return list(pus.protocol_units)
@pytest.fixture
def protocol_setup_unit(protocol, protocol_units):
for pu in protocol_units:
if isinstance(pu, openmm_md.PlainMDSetupUnit):
return pu
@pytest.fixture
def protocol_simulation_unit(protocol, protocol_units):
for pu in protocol_units:
if isinstance(pu, openmm_md.PlainMDSimulationUnit):
return pu
@pytest.fixture
@@ -48,14 +62,14 @@ class TestPlainMDProtocol(GufeTokenizableTestsMixin):
assert self.repr in repr(instance)
class TestPlainMDProtocolUnit(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDProtocolUnit
repr = "PlainMDProtocolUnit("
class TestPlainMDSetupUnit(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDSetupUnit
repr = "PlainMDSetupUnit("
key = None
@pytest.fixture
def instance(self, protocol_unit):
return protocol_unit
def instance(self, protocol_setup_unit):
return protocol_setup_unit
def test_repr(self, instance):
"""
@@ -65,6 +79,20 @@ class TestPlainMDProtocolUnit(GufeTokenizableTestsMixin):
assert self.repr in repr(instance)
class TestPlainMDSimulationUnit(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDSimulationUnit
repr = "PlainMDSimulationUnit("
key = None
@pytest.fixture()
def instance(self, protocol_simulation_unit):
return protocol_simulation_unit
def test_repr(self, instance):
assert isinstance(repr(instance), str)
assert self.repr in repr(instance)
class TestPlainMDProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_md.PlainMDProtocolResult
key = None