mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
Plain MD restart support (#1884)
* split md protocol to setup and simulate, add restart support * allow for resume in any stage * add a single run dynamics function * update restart to only look for checkpoints, split out remaining step logic, update tests --------- Co-authored-by: Irfan Alibay <IAlibay@users.noreply.github.com>
This commit is contained in:
@@ -16,7 +16,8 @@ Protocol API Specification
|
||||
:toctree: generated/
|
||||
|
||||
PlainMDProtocol
|
||||
PlainMDProtocolUnit
|
||||
PlainMDSetupUnit
|
||||
PlainMDSimulationUnit
|
||||
PlainMDProtocolResult
|
||||
|
||||
|
||||
|
||||
24
news/resume-plainmd.rst
Normal file
24
news/resume-plainmd.rst
Normal file
@@ -0,0 +1,24 @@
|
||||
**Added:**
|
||||
|
||||
* * Added API support for resuming the PlainMDProtocol.
|
||||
PR #1884.
|
||||
|
||||
**Changed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Deprecated:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Removed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Fixed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
**Security:**
|
||||
|
||||
* <news item>
|
||||
@@ -22,10 +22,16 @@ zenodo_resume_data = dict(
|
||||
fname="multistate_checkpoints.zip",
|
||||
known_hash="md5:a6bdceff0c4a2f200538edb17c21d443",
|
||||
)
|
||||
zenodo_md_resume_data = dict(
|
||||
base_url="doi:10.5281/zenodo.19694944",
|
||||
fname="checkpoint.xml",
|
||||
known_hash="md5:0f3957c263b5def8de727c5c419b31b5",
|
||||
)
|
||||
|
||||
zenodo_data_registry = [
|
||||
zenodo_rfe_simulation_nc,
|
||||
zenodo_t4_lysozyme_traj,
|
||||
zenodo_industry_benchmark_systems,
|
||||
zenodo_resume_data,
|
||||
zenodo_md_resume_data,
|
||||
]
|
||||
|
||||
@@ -65,7 +65,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
OpenFFPartialChargeSettings,
|
||||
ThermoSettings,
|
||||
)
|
||||
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
|
||||
from openfe.protocols.openmm_md.plain_md_methods import PlainMDSimulationUnit
|
||||
from openfe.protocols.openmm_utils import (
|
||||
charge_generation,
|
||||
multistate_analysis,
|
||||
@@ -320,9 +320,9 @@ class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin):
|
||||
box = system.getDefaultPeriodicBoxVectors()
|
||||
return positions, to_openmm(from_openmm(box))
|
||||
|
||||
# Use the _run_MD method from the PlainMDProtocolUnit
|
||||
# Use the _run_MD method from the PlainMDSimulationUnit
|
||||
# Should in-place modify the simulation
|
||||
PlainMDProtocolUnit._run_MD(
|
||||
PlainMDSimulationUnit._run_MD(
|
||||
simulation=simulation,
|
||||
positions=positions,
|
||||
simulation_settings=settings["equil_simulation_settings"],
|
||||
|
||||
@@ -9,12 +9,14 @@ from .plain_md_methods import (
|
||||
PlainMDProtocol,
|
||||
PlainMDProtocolResult,
|
||||
PlainMDProtocolSettings,
|
||||
PlainMDProtocolUnit,
|
||||
PlainMDSetupUnit,
|
||||
PlainMDSimulationUnit,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PlainMDProtocol",
|
||||
"PlainMDProtocolSettings",
|
||||
"PlainMDProtocolResult",
|
||||
"PlainMDProtocolUnit",
|
||||
"PlainMDSetupUnit",
|
||||
"PlainMDSimulationUnit",
|
||||
]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -56,7 +56,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
|
||||
OpenMMSystemGeneratorFFSettings,
|
||||
ThermoSettings,
|
||||
)
|
||||
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
|
||||
from openfe.protocols.openmm_md.plain_md_methods import PlainMDSimulationUnit
|
||||
from openfe.protocols.openmm_utils import omm_compute
|
||||
from openfe.protocols.openmm_utils.omm_settings import SettingsBaseModel
|
||||
from openfe.protocols.openmm_utils.serialization import deserialize
|
||||
@@ -202,9 +202,9 @@ def _pre_equilibrate(
|
||||
errmsg = f"Only 'A', 'B', and 'AB' are accepted as endstates. Got {endstate}"
|
||||
raise ValueError(errmsg)
|
||||
|
||||
# Use the _run_MD method from the PlainMDProtocolUnit
|
||||
# Use the _run_MD method from the PlainMDSimulationUnit
|
||||
# Should in-place modify the simulation
|
||||
PlainMDProtocolUnit._run_MD(
|
||||
PlainMDSimulationUnit._run_MD(
|
||||
simulation=simulation,
|
||||
positions=positions,
|
||||
simulation_settings=settings["equil_simulation_settings"],
|
||||
|
||||
@@ -237,4 +237,4 @@ styrene
|
||||
10 11 4 0 0 0 0
|
||||
11 16 1 0 0 0 0
|
||||
M END
|
||||
$$$$
|
||||
$$$$
|
||||
@@ -20,6 +20,7 @@ import openfe
|
||||
from openfe.data._registry import (
|
||||
POOCH_CACHE,
|
||||
zenodo_industry_benchmark_systems,
|
||||
zenodo_md_resume_data,
|
||||
zenodo_resume_data,
|
||||
zenodo_rfe_simulation_nc,
|
||||
zenodo_t4_lysozyme_traj,
|
||||
@@ -406,6 +407,19 @@ def septop_solv_checkpoint_path():
|
||||
return pathlib.Path(pooch.os_cache("openfe") / f"{topdir}/{subdir}/{filename}")
|
||||
|
||||
|
||||
pooch_md_resume_data = pooch.create(
|
||||
path=POOCH_CACHE,
|
||||
base_url=zenodo_md_resume_data["base_url"],
|
||||
registry={zenodo_md_resume_data["fname"]: zenodo_md_resume_data["known_hash"]},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def plain_md_checkpoint_path():
|
||||
pooch_md_resume_data.fetch("checkpoint.xml")
|
||||
return pathlib.Path(pooch.os_cache("openfe") / "checkpoint.xml")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def available_platforms() -> set[str]:
|
||||
return {
|
||||
|
||||
@@ -7,8 +7,10 @@ import sys
|
||||
from unittest import mock
|
||||
|
||||
import gufe
|
||||
import numpy as np
|
||||
import openmm
|
||||
import pytest
|
||||
from gufe import ChemicalSystem, SmallMoleculeComponent
|
||||
from gufe import ChemicalSystem, LigandAtomMapping, SmallMoleculeComponent
|
||||
from numpy.testing import assert_allclose
|
||||
from openff.units import unit
|
||||
from openff.units.openmm import from_openmm, to_openmm
|
||||
@@ -22,8 +24,10 @@ from openfe.protocols import openmm_md
|
||||
from openfe.protocols.openmm_md.plain_md_methods import (
|
||||
PlainMDProtocol,
|
||||
PlainMDProtocolResult,
|
||||
PlainMDProtocolUnit,
|
||||
PlainMDSetupUnit,
|
||||
PlainMDSimulationUnit,
|
||||
)
|
||||
from openfe.protocols.openmm_utils import serialization
|
||||
from openfe.protocols.openmm_utils.charge_generation import (
|
||||
HAS_ESPALOMA_CHARGE,
|
||||
HAS_NAGL,
|
||||
@@ -40,6 +44,25 @@ def vac_settings():
|
||||
return settings
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"inputs, expected",
|
||||
[
|
||||
# inputs are current step count, nvt steps, npt steps and prod steps
|
||||
# outputs are steps to run for nvt, npt, prod and if the production phase has started
|
||||
pytest.param([50, 100, 100, 100], [50, 100, 100, False], id="nvt resuming"),
|
||||
pytest.param([101, 100, 100, 100], [0, 99, 100, False], id="npt resuming"),
|
||||
pytest.param([220, 100, 100, 100], [0, 0, 80, True], id="prod resuming"),
|
||||
pytest.param([200, 100, 100, 100], [0, 0, 100, False], id="prod resuming not started"),
|
||||
],
|
||||
)
|
||||
def test_get_remaining_steps(inputs, expected):
|
||||
nvt, npt, prod, is_prod = PlainMDSimulationUnit._get_remaining_steps(*inputs)
|
||||
assert nvt == expected[0]
|
||||
assert npt == expected[1]
|
||||
assert prod == expected[2]
|
||||
assert is_prod == expected[3]
|
||||
|
||||
|
||||
def test_create_default_settings():
|
||||
settings = PlainMDProtocol.default_settings()
|
||||
|
||||
@@ -92,7 +115,7 @@ def test_create_independent_repeat_ids(benzene_system):
|
||||
)
|
||||
|
||||
repeat_ids = set()
|
||||
u: PlainMDProtocolUnit
|
||||
u: PlainMDSetupUnit | PlainMDSimulationUnit
|
||||
for u in dag1.protocol_units:
|
||||
repeat_ids.add(u.inputs["repeat_id"])
|
||||
for u in dag2.protocol_units:
|
||||
@@ -104,14 +127,14 @@ def test_create_independent_repeat_ids(benzene_system):
|
||||
def test_dry_run_default_vacuum(benzene_vacuum_system, vac_settings, tmp_path):
|
||||
protocol = PlainMDProtocol(settings=vac_settings)
|
||||
|
||||
# create DAG from protocol and take first (and only) work unit from within
|
||||
# create DAG from protocol and take the setup unit
|
||||
dag = protocol.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=benzene_vacuum_system,
|
||||
mapping=None,
|
||||
)
|
||||
dag_unit = list(dag.protocol_units)[0]
|
||||
result = dag_unit.run(
|
||||
setup_unit = list(dag.protocol_units)[0]
|
||||
result = setup_unit.run(
|
||||
dry=True, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)
|
||||
system = result["debug"]["system"]
|
||||
@@ -138,22 +161,46 @@ def test_dry_run_logger_output(benzene_vacuum_system, vac_settings, tmp_path, ca
|
||||
settings=vac_settings,
|
||||
)
|
||||
|
||||
# create DAG from protocol and take first (and only) work unit from within
|
||||
# create DAG from protocol
|
||||
dag = protocol.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=benzene_vacuum_system,
|
||||
mapping=None,
|
||||
)
|
||||
dag_unit = list(dag.protocol_units)[0]
|
||||
setup_unit = list(dag.protocol_units)[0]
|
||||
|
||||
caplog.set_level(logging.INFO)
|
||||
dag_unit.run(dry=False, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
|
||||
setup_results = setup_unit.run(
|
||||
dry=False, verbose=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)
|
||||
|
||||
messages = [r.message for r in caplog.records]
|
||||
assert "minimizing systems" in messages
|
||||
assert "Running NVT equilibration" in messages
|
||||
assert "Running NPT equilibration" in messages
|
||||
assert "running production phase" in messages
|
||||
assert "Creating system" in messages
|
||||
# now run the production unit after extracting outputs from the setup unit
|
||||
system = serialization.deserialize(setup_results["system"])
|
||||
positions = np.load(setup_results["positions"]) * omm_unit.nanometers
|
||||
topology = openmm.app.PDBFile(str(setup_results["system_pdb"])).getTopology()
|
||||
equil_steps_nvt = setup_results["equil_steps_nvt"]
|
||||
equil_steps_npt = setup_results["equil_steps_npt"]
|
||||
prod_steps = setup_results["prod_steps"]
|
||||
prod_unit = list(dag.protocol_units)[1]
|
||||
prod_unit.run(
|
||||
system=system,
|
||||
positions=positions,
|
||||
topology=topology,
|
||||
equil_steps_nvt=equil_steps_nvt,
|
||||
equil_steps_npt=equil_steps_npt,
|
||||
prod_steps=prod_steps,
|
||||
dry=False,
|
||||
verbose=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)
|
||||
messages = [r.message for r in caplog.records]
|
||||
assert "Minimizing systems" in messages
|
||||
assert "Running NVT equilibration for 250 steps" in messages
|
||||
assert "Running NPT equilibration for 250 steps" in messages
|
||||
assert "Running production phase for 250 steps" in messages
|
||||
|
||||
|
||||
def test_dry_run_ffcache_none_vacuum(benzene_vacuum_system, vac_settings, tmp_path):
|
||||
@@ -441,17 +488,14 @@ def test_hightimestep(benzene_vacuum_system, tmp_path):
|
||||
settings.forcefield_settings.nonbonded_method = "nocutoff"
|
||||
|
||||
p = PlainMDProtocol(settings=settings)
|
||||
|
||||
dag = p.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=benzene_vacuum_system,
|
||||
mapping=None,
|
||||
)
|
||||
dag_unit = list(dag.protocol_units)[0]
|
||||
|
||||
errmsg = "too large for hydrogen mass"
|
||||
# make sure this is triggered in validate
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
dag_unit.run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
|
||||
_ = p.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=benzene_vacuum_system,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
|
||||
def test_vaccuum_PME_error(benzene_vacuum_system):
|
||||
@@ -459,6 +503,7 @@ def test_vaccuum_PME_error(benzene_vacuum_system):
|
||||
settings=PlainMDProtocol.default_settings(),
|
||||
)
|
||||
errmsg = "PME cannot be used for vacuum transform"
|
||||
# make sure this is triggered in validate
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
_ = p.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
@@ -486,6 +531,35 @@ def test_multiple_basesolvents_error(a2a_protein_membrane_component):
|
||||
)
|
||||
|
||||
|
||||
def test_states_not_matching_error(benzene_vacuum_system, toluene_vacuum_system):
|
||||
p = PlainMDProtocol(settings=PlainMDProtocol.default_settings())
|
||||
errmsg = "The two end states do not match."
|
||||
with pytest.raises(ValueError, match=errmsg):
|
||||
_ = p.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=toluene_vacuum_system,
|
||||
mapping=None,
|
||||
)
|
||||
|
||||
|
||||
def test_mapping_warning(benzene_vacuum_system, tmp_path):
|
||||
settings = PlainMDProtocol.default_settings()
|
||||
settings.forcefield_settings.nonbonded_method = "nocutoff"
|
||||
p = PlainMDProtocol(settings=settings)
|
||||
warnmsg = "A mapping was passed but is not used by this Protocol."
|
||||
benzene = benzene_vacuum_system.components["ligand"]
|
||||
with pytest.warns(match=warnmsg):
|
||||
_ = p.create(
|
||||
stateA=benzene_vacuum_system,
|
||||
stateB=benzene_vacuum_system,
|
||||
mapping=LigandAtomMapping(
|
||||
componentA=benzene,
|
||||
componentB=benzene,
|
||||
componentA_to_componentB=dict((i, i) for i in range(12)),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def solvent_protocol_dag(benzene_system):
|
||||
settings = PlainMDProtocol.default_settings()
|
||||
@@ -499,11 +573,20 @@ def solvent_protocol_dag(benzene_system):
|
||||
)
|
||||
|
||||
|
||||
def test_unit_tagging(solvent_protocol_dag, tmp_path):
|
||||
def test_unit_tagging(benzene_system, tmp_path):
|
||||
# test that executing the Units includes correct generation and repeat info
|
||||
dag_units = solvent_protocol_dag.protocol_units
|
||||
settings = PlainMDProtocol.default_settings()
|
||||
settings.protocol_repeats = 3
|
||||
protocol = PlainMDProtocol(settings=settings)
|
||||
dag = protocol.create(
|
||||
stateA=benzene_system,
|
||||
stateB=benzene_system,
|
||||
mapping=None,
|
||||
)
|
||||
dag_units = dag.protocol_units
|
||||
|
||||
with mock.patch(
|
||||
"openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run",
|
||||
"openfe.protocols.openmm_md.plain_md_methods.PlainMDSimulationUnit.run",
|
||||
return_value={
|
||||
"nc": "simulation.xtc",
|
||||
"last_checkpoint": "checkpoint.chk",
|
||||
@@ -511,8 +594,10 @@ def test_unit_tagging(solvent_protocol_dag, tmp_path):
|
||||
):
|
||||
results = []
|
||||
for u in dag_units:
|
||||
ret = u.execute(context=gufe.Context(tmp_path, tmp_path))
|
||||
results.append(ret)
|
||||
# just execute the setup unit so we don't have to pass the results though to the simulation unit
|
||||
if isinstance(u, PlainMDSetupUnit):
|
||||
ret = u.execute(context=gufe.Context(tmp_path, tmp_path))
|
||||
results.append(ret)
|
||||
|
||||
repeats = set()
|
||||
for ret in results:
|
||||
@@ -526,7 +611,7 @@ def test_unit_tagging(solvent_protocol_dag, tmp_path):
|
||||
def test_gather(solvent_protocol_dag, tmp_path):
|
||||
# check .gather behaves as expected
|
||||
with mock.patch(
|
||||
"openfe.protocols.openmm_md.plain_md_methods.PlainMDProtocolUnit.run",
|
||||
"openfe.protocols.openmm_md.plain_md_methods.PlainMDSimulationUnit.run",
|
||||
return_value={
|
||||
"nc": "simulation.xtc",
|
||||
"last_checkpoint": "checkpoint.chk",
|
||||
|
||||
167
src/openfe/tests/protocols/openmm_md/test_plain_md_resume.py
Normal file
167
src/openfe/tests/protocols/openmm_md/test_plain_md_resume.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# This code is part of OpenFE and is licensed under the MIT license.
|
||||
# For details, see https://github.com/OpenFreeEnergy/openfe
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
|
||||
import gufe
|
||||
import openmm
|
||||
import openmm.unit as openmm_unit
|
||||
import pytest
|
||||
from gufe import ChemicalSystem, SmallMoleculeComponent
|
||||
from gufe.protocols.errors import ProtocolUnitExecutionError
|
||||
from openff.units import unit
|
||||
|
||||
import openfe
|
||||
from openfe.data._registry import POOCH_CACHE
|
||||
from openfe.protocols.openmm_md.plain_md_methods import (
|
||||
PlainMDProtocol,
|
||||
PlainMDSetupUnit,
|
||||
PlainMDSimulationUnit,
|
||||
)
|
||||
|
||||
from ...conftest import HAS_INTERNET
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def vacuum_protocol_settings():
|
||||
# setup a cheap vacuum md protocol
|
||||
settings = PlainMDProtocol.default_settings()
|
||||
settings.protocol_repeats = 1
|
||||
settings.forcefield_settings.nonbonded_method = "nocutoff"
|
||||
settings.engine_settings.compute_platform = None
|
||||
settings.simulation_settings.equilibration_length_nvt = 1 * unit.picoseconds
|
||||
settings.simulation_settings.equilibration_length = 1 * unit.picoseconds
|
||||
settings.simulation_settings.production_length = 1 * unit.picoseconds
|
||||
settings.output_settings.checkpoint_interval = 0.5 * unit.picoseconds
|
||||
settings.output_settings.trajectory_write_interval = 0.5 * unit.picoseconds
|
||||
return settings
|
||||
|
||||
|
||||
def test_verify_execution_environment():
|
||||
# verify using the current versions of the software
|
||||
PlainMDSimulationUnit._verify_execution_environment(
|
||||
setup_outputs={
|
||||
"gufe_version": gufe.__version__,
|
||||
"openfe_version": openfe.__version__,
|
||||
"openmm_version": openmm.__version__,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_verify_execution_environment_fail():
|
||||
# pass in different versions to force failure
|
||||
with pytest.raises(ProtocolUnitExecutionError, match="Python environment"):
|
||||
PlainMDSimulationUnit._verify_execution_environment(
|
||||
setup_outputs={
|
||||
"gufe_version": 0.1,
|
||||
"openfe_version": openmm.__version__,
|
||||
"openmm_version": openmm.__version__,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_verify_execution_env_missing_key():
|
||||
errmsg = "Missing environment information from setup outputs."
|
||||
with pytest.raises(ProtocolUnitExecutionError, match=errmsg):
|
||||
PlainMDSimulationUnit._verify_execution_environment(
|
||||
setup_outputs={
|
||||
"foo_version": 0.1,
|
||||
"openfe_version": openfe.__version__,
|
||||
"openmm_version": openmm.__version__,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.path.exists(POOCH_CACHE) and not HAS_INTERNET,
|
||||
reason="Internet unavailable and test data is not cached locally",
|
||||
)
|
||||
def test_check_restart(vacuum_protocol_settings, plain_md_checkpoint_path):
|
||||
# test we can correctly detect when we should be restarting
|
||||
assert PlainMDSimulationUnit._check_restart(
|
||||
output_settings=vacuum_protocol_settings.output_settings,
|
||||
shared_path=plain_md_checkpoint_path.parent,
|
||||
)
|
||||
|
||||
# make sure it does not try and restart if inputs are missing
|
||||
assert not PlainMDSimulationUnit._check_restart(
|
||||
output_settings=vacuum_protocol_settings.output_settings,
|
||||
shared_path=pathlib.Path("."),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not os.path.exists(POOCH_CACHE) and not HAS_INTERNET,
|
||||
reason="Internet unavailable and test data is not cached locally",
|
||||
)
|
||||
class TestPlainMDResume:
|
||||
@pytest.fixture
|
||||
def protocol_dag(self, vacuum_protocol_settings, benzene_vacuum_system):
|
||||
protocol = PlainMDProtocol(
|
||||
settings=vacuum_protocol_settings,
|
||||
)
|
||||
return protocol.create(
|
||||
stateA=benzene_vacuum_system, stateB=benzene_vacuum_system, mapping=None
|
||||
)
|
||||
|
||||
def test_resume(
|
||||
self, protocol_dag, tmp_path, caplog, vacuum_protocol_settings, plain_md_checkpoint_path
|
||||
):
|
||||
# test that we can resume a simulation from a checkpoint
|
||||
protocol_units = list(protocol_dag.protocol_units)
|
||||
setup_unit: PlainMDSetupUnit = protocol_units[0]
|
||||
simulation_unit: PlainMDSimulationUnit = protocol_units[1]
|
||||
# copy the files over
|
||||
shutil.copyfile(plain_md_checkpoint_path, tmp_path / "checkpoint.xml")
|
||||
# dry run the setup unit
|
||||
setup_results = setup_unit.run(
|
||||
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
|
||||
)
|
||||
# make sure the protocol thinks it can restart
|
||||
assert PlainMDSimulationUnit._check_restart(
|
||||
output_settings=vacuum_protocol_settings.output_settings,
|
||||
shared_path=tmp_path,
|
||||
)
|
||||
# now run the simulation unit in resume mode this should be 0.5 ps of equilibration and 1 ps of production
|
||||
sim_results = simulation_unit.run(
|
||||
system=setup_results["debug"]["system"],
|
||||
positions=setup_results["debug"]["positions"],
|
||||
topology=setup_results["debug"]["topology"],
|
||||
equil_steps_nvt=setup_results["equil_steps_nvt"],
|
||||
equil_steps_npt=setup_results["equil_steps_npt"],
|
||||
prod_steps=setup_results["prod_steps"],
|
||||
verbose=True,
|
||||
scratch_basepath=tmp_path,
|
||||
shared_basepath=tmp_path,
|
||||
)
|
||||
# make sure it prints that its restarting
|
||||
assert "Restarting simulation from checkpoint state" in caplog.text
|
||||
# check the number of npt steps to run is correct, this should be 0.5 ps at 4fs timestep
|
||||
assert "Running NPT equilibration for 125 steps" in caplog.text
|
||||
# make sure the production phase steps are correct, this should be the full 1ps at 4fs timestep
|
||||
assert "Running production phase for 250 steps" in caplog.text
|
||||
|
||||
# check the outputs of the simulation unit
|
||||
assert sim_results["system_pdb"].exists()
|
||||
assert sim_results["nc"].exists()
|
||||
assert sim_results["last_checkpoint"]
|
||||
|
||||
# load the final checkpoint and check the simulation time is correct, this should be 3 ps
|
||||
# also check the total step count
|
||||
simulation = openmm.app.Simulation(
|
||||
setup_results["debug"]["topology"],
|
||||
setup_results["debug"]["system"],
|
||||
openmm.LangevinMiddleIntegrator(
|
||||
298.15 * openmm_unit.kelvin,
|
||||
1.0 / openmm_unit.picosecond,
|
||||
4 * openmm_unit.femtoseconds,
|
||||
),
|
||||
)
|
||||
simulation.context.setPositions(setup_results["debug"]["positions"])
|
||||
simulation.loadState(str(sim_results["last_checkpoint"]))
|
||||
total_sim_time = simulation.context.getTime()
|
||||
# check the time is 3 ps
|
||||
assert total_sim_time.value_in_unit(openmm_unit.picoseconds) == pytest.approx(3)
|
||||
# check the step count has been extended
|
||||
assert simulation.context.getStepCount() == 750
|
||||
@@ -25,7 +25,7 @@ def test_vacuum_sim(
|
||||
settings.simulation_settings.equilibration_length_nvt = None
|
||||
settings.simulation_settings.equilibration_length = 10 * unit.picosecond
|
||||
settings.simulation_settings.production_length = 20 * unit.picosecond
|
||||
settings.output_settings.checkpoint_interval = 40 * unit.picosecond
|
||||
settings.output_settings.checkpoint_interval = 5 * unit.picosecond
|
||||
settings.forcefield_settings.nonbonded_method = "nocutoff"
|
||||
settings.engine_settings.compute_platform = platform
|
||||
|
||||
@@ -46,9 +46,9 @@ def test_vacuum_sim(
|
||||
|
||||
assert r.ok()
|
||||
|
||||
assert len(r.protocol_unit_results) == 1
|
||||
assert len(r.protocol_unit_results) == 2
|
||||
|
||||
pur = r.protocol_unit_results[0]
|
||||
pur = r.protocol_unit_results[1]
|
||||
unit_shared = tmp_path / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
@@ -59,20 +59,19 @@ def test_vacuum_sim(
|
||||
"minimized.pdb",
|
||||
"simulation.xtc",
|
||||
"simulation.log",
|
||||
"system.pdb",
|
||||
"checkpoint.xml",
|
||||
]
|
||||
for file in files:
|
||||
assert (unit_shared / file).exists()
|
||||
|
||||
# NVT PDB should not exist
|
||||
assert not (unit_shared / "equil_nvt.pdb").exists()
|
||||
assert not (unit_shared / "checkpoint.chk").exists()
|
||||
|
||||
# check that the output file paths are correct
|
||||
assert pur.outputs["system_pdb"] == unit_shared / "system.pdb"
|
||||
assert pur.outputs["minimized_pdb"] == unit_shared / "minimized.pdb"
|
||||
assert pur.outputs["nc"] == unit_shared / "simulation.xtc"
|
||||
assert pur.outputs["last_checkpoint"] is None
|
||||
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.xml"
|
||||
assert pur.outputs["npt_equil_pdb"] == unit_shared / "equil_npt.pdb"
|
||||
assert pur.outputs["nvt_equil_pdb"] is None
|
||||
|
||||
@@ -113,22 +112,21 @@ def test_complex_solvent_sim_gpu(
|
||||
|
||||
assert r.ok()
|
||||
|
||||
assert len(r.protocol_unit_results) == 1
|
||||
assert len(r.protocol_unit_results) == 2
|
||||
|
||||
pur = r.protocol_unit_results[0]
|
||||
pur = r.protocol_unit_results[1]
|
||||
unit_shared = tmp_path / f"shared_{pur.source_key}_attempt_0"
|
||||
assert unit_shared.exists()
|
||||
assert pathlib.Path(unit_shared).is_dir()
|
||||
|
||||
# check the files
|
||||
files = [
|
||||
"checkpoint.chk",
|
||||
"checkpoint.xml",
|
||||
"equil_nvt.pdb",
|
||||
"equil_npt.pdb",
|
||||
"minimized.pdb",
|
||||
"simulation.xtc",
|
||||
"simulation.log",
|
||||
"system.pdb",
|
||||
]
|
||||
for file in files:
|
||||
assert (unit_shared / file).exists()
|
||||
@@ -137,6 +135,6 @@ def test_complex_solvent_sim_gpu(
|
||||
assert pur.outputs["system_pdb"] == unit_shared / "system.pdb"
|
||||
assert pur.outputs["minimized_pdb"] == unit_shared / "minimized.pdb"
|
||||
assert pur.outputs["nc"] == unit_shared / "simulation.xtc"
|
||||
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.chk"
|
||||
assert pur.outputs["last_checkpoint"] == unit_shared / "checkpoint.xml"
|
||||
assert pur.outputs["nvt_equil_pdb"] == unit_shared / "equil_nvt.pdb"
|
||||
assert pur.outputs["npt_equil_pdb"] == unit_shared / "equil_npt.pdb"
|
||||
|
||||
@@ -15,13 +15,27 @@ def protocol():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol_unit(protocol, benzene_system):
|
||||
def protocol_units(protocol, benzene_system):
|
||||
pus = protocol.create(
|
||||
stateA=benzene_system,
|
||||
stateB=benzene_system,
|
||||
mapping=None,
|
||||
)
|
||||
return list(pus.protocol_units)[0]
|
||||
return list(pus.protocol_units)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol_setup_unit(protocol, protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_md.PlainMDSetupUnit):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def protocol_simulation_unit(protocol, protocol_units):
|
||||
for pu in protocol_units:
|
||||
if isinstance(pu, openmm_md.PlainMDSimulationUnit):
|
||||
return pu
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -48,14 +62,14 @@ class TestPlainMDProtocol(GufeTokenizableTestsMixin):
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestPlainMDProtocolUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_md.PlainMDProtocolUnit
|
||||
repr = "PlainMDProtocolUnit("
|
||||
class TestPlainMDSetupUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_md.PlainMDSetupUnit
|
||||
repr = "PlainMDSetupUnit("
|
||||
key = None
|
||||
|
||||
@pytest.fixture
|
||||
def instance(self, protocol_unit):
|
||||
return protocol_unit
|
||||
def instance(self, protocol_setup_unit):
|
||||
return protocol_setup_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
"""
|
||||
@@ -65,6 +79,20 @@ class TestPlainMDProtocolUnit(GufeTokenizableTestsMixin):
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestPlainMDSimulationUnit(GufeTokenizableTestsMixin):
|
||||
cls = openmm_md.PlainMDSimulationUnit
|
||||
repr = "PlainMDSimulationUnit("
|
||||
key = None
|
||||
|
||||
@pytest.fixture()
|
||||
def instance(self, protocol_simulation_unit):
|
||||
return protocol_simulation_unit
|
||||
|
||||
def test_repr(self, instance):
|
||||
assert isinstance(repr(instance), str)
|
||||
assert self.repr in repr(instance)
|
||||
|
||||
|
||||
class TestPlainMDProtocolResult(GufeTokenizableTestsMixin):
|
||||
cls = openmm_md.PlainMDProtocolResult
|
||||
key = None
|
||||
|
||||
Reference in New Issue
Block a user