diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 37742b46..83beb542 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -14,7 +14,7 @@ see https://regro.github.io/rever-docs/news.html for details on how to add news Checklist * [ ] All new code is appropriately documented (user-facing code _must_ have complete docstrings). * [ ] Added a ``news`` entry, or the changes are not user-facing. -* [ ] Ran pre-commit by making a comment with `pre-commit.ci autofix` before requesting review. +* [ ] Ran pre-commit: you can run [pre-commit](https://pre-commit.com) locally or comment on this PR with `pre-commit.ci autofix`. Manual Tests: these are slow so don't need to be run every commit, only before merging and when relevant changes are made (generally at reviewer-discretion). * [ ] [GPU integration tests](https://github.com/OpenFreeEnergy/openfe/actions/workflows/gpu-integration-tests.yaml) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2d47c378..627910d2 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -35,7 +35,7 @@ jobs: strategy: fail-fast: false matrix: - os: ["ubuntu-latest", "macos-latest"] + os: ["ubuntu-latest"] openeye: ["no"] python-version: - "3.11" @@ -45,6 +45,9 @@ jobs: - os: "ubuntu-latest" python-version: "3.13" openeye: "yes" + - os: "macos-latest" + python-version: "3.12" + openeye: "no" env: OE_LICENSE: ${{ github.workspace }}/oe_license.txt diff --git a/.github/workflows/conda_cron.yaml b/.github/workflows/conda_cron.yaml index c96f00a9..c31b3fb1 100644 --- a/.github/workflows/conda_cron.yaml +++ b/.github/workflows/conda_cron.yaml @@ -20,11 +20,15 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest', 'macos-latest'] + os: ['ubuntu-latest'] python-version: - "3.11" - "3.12" - "3.13" + include: + - os: "macos-latest" + python-version: "3.12" + openeye: "no" steps: - name: Checkout Code uses: actions/checkout@v4 diff --git a/.github/workflows/cpu-long-tests.yaml b/.github/workflows/cpu-long-tests.yaml index 346e3a93..8cc1e480 100644 --- a/.github/workflows/cpu-long-tests.yaml +++ b/.github/workflows/cpu-long-tests.yaml @@ -92,7 +92,7 @@ jobs: DUECREDIT_ENABLE: 'yes' OFE_INTEGRATION_TESTS: FALSE run: | - pytest -n logical -vv --durations=10 --runslow openfecli/tests/ openfe/tests/ + pytest -n logical -vv --durations=10 --runslow src/openfecli/tests/ src/openfe/tests/ stop-aws-runner: runs-on: ubuntu-latest diff --git a/.github/workflows/gpu-integration-tests.yaml b/.github/workflows/gpu-integration-tests.yaml index 22db4116..535e526f 100644 --- a/.github/workflows/gpu-integration-tests.yaml +++ b/.github/workflows/gpu-integration-tests.yaml @@ -96,7 +96,7 @@ jobs: OFE_INTEGRATION_TESTS: TRUE run: | # The -m flag will only run tests with @pytest.mark.integration - pytest -n logical -vv --durations=10 -m integration openfecli/tests/ openfe/tests/ + pytest -n logical -vv --durations=10 -m integration src/openfecli/tests/ src/openfe/tests/ stop-aws-runner: runs-on: ubuntu-latest diff --git a/.github/workflows/test-example-notebooks.yaml b/.github/workflows/test-example-notebooks.yaml index fce82d42..df27c1fb 100644 --- a/.github/workflows/test-example-notebooks.yaml +++ b/.github/workflows/test-example-notebooks.yaml @@ -12,7 +12,7 @@ defaults: shell: bash -leo pipefail {0} jobs: - test-conda-build: + test-example-notebooks: runs-on: ubuntu-latest steps: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a18de974..d548167b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,6 +10,7 @@ repos: rev: v6.0.0 hooks: - id: check-added-large-files + args: ["--maxkb=900"] - id: check-case-conflict - id: check-executables-have-shebangs - id: check-symlinks @@ -19,12 +20,12 @@ repos: - id: debug-statements - repo: https://github.com/tox-dev/pyproject-fmt - rev: "v2.8.0" + rev: "v2.11.1" hooks: - id: pyproject-fmt - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.13.3 + rev: v0.14.10 hooks: # Run the linter. - id: ruff diff --git a/MANIFEST.in b/MANIFEST.in index 6e18d4c7..d2479fce 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,18 +1,18 @@ -recursive-include openfe/tests/data/ *.sdf -recursive-include openfe/tests/data/ *.bz2 -recursive-include openfe/tests/data/ *.csv -recursive-include openfe/tests/data/ *.pdb -recursive-include openfe/tests/data/ *.mol2 -recursive-include openfe/tests/data/ *.xml -recursive-include openfe/tests/data/ *.graphml -recursive-include openfe/tests/data/ *.edge -recursive-include openfe/tests/data/ *.dat -recursive-include openfe/tests/data/ *.txt -recursive-include openfe/tests/data/ *.gz -recursive-include openfe/tests/data/ *json_results.gz -include openfecli/tests/data/*.json -include openfecli/tests/data/*.tar.gz -include openfecli/tests/commands/test_gather/*.tsv -recursive-include openfecli/tests/ *.sdf -recursive-include openfecli/tests/ *.pdb -include openfe/tests/data/openmm_rfe/vacuum_nocoord.nc +recursive-include src/openfe/tests/data/ *.sdf +recursive-include src/openfe/tests/data/ *.bz2 +recursive-include src/openfe/tests/data/ *.csv +recursive-include src/openfe/tests/data/ *.pdb +recursive-include src/openfe/tests/data/ *.mol2 +recursive-include src/openfe/tests/data/ *.xml +recursive-include src/openfe/tests/data/ *.graphml +recursive-include src/openfe/tests/data/ *.edge +recursive-include src/openfe/tests/data/ *.dat +recursive-include src/openfe/tests/data/ *.txt +recursive-include src/openfe/tests/data/ *.gz +recursive-include src/openfe/tests/data/ *json_results.gz +include src/openfecli/tests/data/*.json +include src/openfecli/tests/data/*.tar.gz +include src/openfecli/tests/commands/test_gather/*.tsv +recursive-include src/openfecli/tests/ *.sdf +recursive-include src/openfecli/tests/ *.pdb +include src/openfe/tests/data/openmm_rfe/vacuum_nocoord.nc diff --git a/docs/CHANGELOG.rst b/docs/CHANGELOG.rst index 7b262e8d..acce0fc3 100644 --- a/docs/CHANGELOG.rst +++ b/docs/CHANGELOG.rst @@ -4,6 +4,55 @@ Changelog .. current developments + + + +v1.9.0 +==================== + +**Added:** + +* The ``validate`` method for the RelativeHybridTopologyProtocol has been implemented. + This means that settings and system validation can mostly be done prior to Protocol execution by calling ``RelativeHybridTopologyProtocol.validate(stateA, stateB, mapping)`` (`PR 1740 `_). + +* Added ``openfe test --download-only`` flag, which downloads all test data stored remotely to the local cache (`PR 1814 `_). + +**Changed:** + +* The absolute free energy protocols (AbsoluteBindingProtocol and AbsoluteSolvationProtocol) have been broken into multiple + protocol units, allowing for setup, run, and analysis to happen + separately in the future when relevant changes to protocol execution are + made (`PR 1776 `_). +* The relative free energy protocol (RelativeHybridTopologyProtocol) has been + broken into multiple protocol units, allowing for the setup, run, analysis to happen + separately (`PR 1773 `_). + +**Fixed:** + +* Fixed bug in ligand network visualization (such as with ``openfe view-ligand-network``) so that ligand names are no longer cut off by the plot border (`PR 1822 `_). +* Endstates in the RelativeHybridTopologyProtocol are now being created + in a manner that allows for isomorphic molecules that differ between + endstates to have different parameters (`PR 1772 `_). + + + +v1.8.1 +==================== + +**Added:** + +* Added a progress bar for ``openfe gather`` JSON loading (`PR #1786 `_). + +**Fixed:** + +* Due to issues with OpenFF's handling of toolkit registries + with NAGL, the use of NAGL models (e.g. AshGC) when OpenEye + is installed but not requested as the charge backend has been + disabled (Issue #1760, `PR #1762 `_). +* Fixed bug in ligand network visualization (such as with ``openfe view-ligand-network``) so that ligand names are no longer cut off by the plot border (`PR #1822 `_). + + + v1.8.0 ==================== @@ -14,6 +63,7 @@ v1.8.0 * Added experimental features ``openfe gather-septop`` and ``openfe gather-abfe``, which are analogous to ``openfe gather`` and allow for gathering results generated by the Separated Topologies and Absolute Binding Free Energy protocols, respectively. These commands are experimental and are liable to be changed in a future release. * Emit a clarifying log message when a user gets a warning from JAX (`PR #1585 `_, fixes `Issue #1499 `_). * Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (`PR #1694 `_). +* New options have been added to the ``AlchemicalSettings`` of the ``SepTopProtocol``, ``AbsoluteSolvationProtocol`` and ``AbsoluteBindingProtocol``. Notably, these options allow users to control the softcore parameters as well as the use of long range dispersion corrections (`PR #1742 `_). **Changed:** diff --git a/docs/conf.py b/docs/conf.py index 1b4b4723..f54790c2 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -103,6 +103,7 @@ exclude_patterns = [ autodoc_mock_imports = [ "cinnabar", + "dill", "MDAnalysis", "matplotlib", "mdtraj", @@ -192,7 +193,7 @@ try: else: repo = git.Repo.clone_from( "https://github.com/OpenFreeEnergy/ExampleNotebooks.git", - branch="2025.12.04", + branch="2026.01.26", to_path=example_notebooks_path, ) except Exception as e: diff --git a/docs/guide/cli/cli_yaml.rst b/docs/guide/cli/cli_yaml.rst index 84825ef7..7ca6304f 100644 --- a/docs/guide/cli/cli_yaml.rst +++ b/docs/guide/cli/cli_yaml.rst @@ -10,7 +10,7 @@ This settings file has a series of sections for customising the different algori For example, the settings file which re-specifies the default behaviour would look like :: network: - method: plan_minimal_spanning_tree + method: generate_minimal_spanning_network mapper: method: LomapAtomMapper settings: @@ -29,7 +29,7 @@ All sections of the file ``network:``, ``mapper:`` and ``partial_charge:`` are The settings YAML file is then provided to the ``-s`` option of ``openfe plan-rbfe-network``: :: - openfe plan-rbfe-network -M molecules.sdf -P protein.pdb -s settings.yaml + openfe plan-rbfe-network -M molecules.sdf -p protein.pdb -s settings.yaml Customising the atom mapper --------------------------- diff --git a/docs/index.rst b/docs/index.rst index 600d8431..4a0cb6f6 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -25,12 +25,12 @@ Using this toolkit you can plan, execute, and analyze free energy calculations u Follow our installation guide to get **openfe** running on your machine! - .. grid-item-card:: :fas:`laptop-code` CLI + .. grid-item-card:: :fas:`laptop-code` CLI Quickstart :text-align: center - :link: reference/cli/index + :link: tutorials/rbfe_cli_tutorial :link-type: doc - Documentation for **openfe**\'s simple command line interface. + Get started with **openfe**\'s command line interface. .. grid-item-card:: :fas:`person-chalkboard` Tutorials :text-align: center @@ -53,12 +53,12 @@ Using this toolkit you can plan, execute, and analyze free energy calculations u How-to guides for common tasks. - .. grid-item-card:: :fas:`code` Python API + .. grid-item-card:: :fas:`code` API Reference :text-align: center - :link: reference/api/index + :link: reference/index :link-type: doc - Comprehensive details of the **openfe** Python API. + Comprehensive details of the **openfe** Python and CLI APIs. .. grid-item-card:: :fas:`gears` Protocols :text-align: center diff --git a/docs/reference/api/index.rst b/docs/reference/api/index.rst index 5b93b873..d74b545d 100644 --- a/docs/reference/api/index.rst +++ b/docs/reference/api/index.rst @@ -4,7 +4,7 @@ We have reproduced API documentation from the `gufe`_ package here for convenience. `gufe`_ serves as a foundation layer for openfe, providing abstract base classes and object models, and so might be more useful for developers. -OpenFE API Reference +Python API Reference ==================== .. toctree:: diff --git a/docs/reference/api/openmm_binding_afe.rst b/docs/reference/api/openmm_binding_afe.rst index d5548fee..4c7a89ad 100644 --- a/docs/reference/api/openmm_binding_afe.rst +++ b/docs/reference/api/openmm_binding_afe.rst @@ -16,8 +16,12 @@ Protocol API specification :toctree: generated/ AbsoluteBindingProtocol - AbsoluteBindingComplexUnit - AbsoluteBindingSolventUnit + ABFEComplexAnalysisUnit + ABFEComplexSetupUnit + ABFEComplexSimUnit + ABFESolventAnalysisUnit + ABFESolventSetupUnit + ABFESolventSimUnit AbsoluteBindingProtocolResult Protocol Settings diff --git a/docs/reference/api/openmm_rfe.rst b/docs/reference/api/openmm_rfe.rst index 8e03d1d3..18a3a032 100644 --- a/docs/reference/api/openmm_rfe.rst +++ b/docs/reference/api/openmm_rfe.rst @@ -16,7 +16,9 @@ Protocol API specification :toctree: generated/ RelativeHybridTopologyProtocol - RelativeHybridTopologyProtocolUnit + HybridTopologySetupUnit + HybridTopologyMultiStateSimulationUnit + HybridTopologyMultiStateAnalysisUnit RelativeHybridTopologyProtocolResult Protocol Settings diff --git a/docs/reference/api/openmm_solvation_afe.rst b/docs/reference/api/openmm_solvation_afe.rst index de8ef118..c4a4f3de 100644 --- a/docs/reference/api/openmm_solvation_afe.rst +++ b/docs/reference/api/openmm_solvation_afe.rst @@ -16,8 +16,12 @@ Protocol API specification :toctree: generated/ AbsoluteSolvationProtocol - AbsoluteSolvationVacuumUnit - AbsoluteSolvationSolventUnit + AHFESolventAnalysisUnit + AHFESolventSetupUnit + AHFESolventSimUnit + AHFEVacuumAnalysisUnit + AHFEVacuumSetupUnit + AHFEVacuumSimUnit AbsoluteSolvationProtocolResult Protocol Settings diff --git a/environment.yml b/environment.yml index 21ad66a6..a08341b2 100644 --- a/environment.yml +++ b/environment.yml @@ -12,7 +12,7 @@ dependencies: - lomap2>=3.2.1 - networkx - numba - - numpy<2.3 + - numpy - openfe-analysis>=0.3.1 - openff-interchange-base - openff-nagl-base >=0.3.3 @@ -21,14 +21,15 @@ dependencies: - openff-units==0.3.1 # https://github.com/OpenFreeEnergy/openfe/pull/1374 - openmm ~=8.2.0 # omit 8.3.0 and 8.3.1 due to https://github.com/openmm/openmm/pull/5069, unpin once we've qualified 8.3.2 - openmmforcefields >=0.15.0 # min needed for https://github.com/OpenFreeEnergy/openfe/pull/1695 - - openmmtools >=0.25.0 + - openmmtools >=0.25.3 # fix to support numpy >=2.3: https://github.com/choderalab/openmmtools/pull/793 - packaging - pandas + - parmed >=4.3.1 # fix to support numpy >=2.3: https://github.com/ParmEd/ParmEd/pull/1387 - perses>=0.10.3 - plugcli - pint>=0.24.0 - pip - - pooch + - pooch >= 1.9.0 # min needed for https://github.com/fatiando/pooch/issues/502 - py3dmol - pydantic >= 2.0.0, <2.12.0 # https://github.com/openforcefield/openff-interchange/issues/1346 - pygraphviz @@ -55,3 +56,6 @@ dependencies: - pip: - git+https://github.com/OpenFreeEnergy/gufe@main - git+https://github.com/choderalab/pymbar.git@ed40ec3bbef03bb08938ad1a74d459b0d1ab81f7 + - run_constrained: + # drop this pin when handled upstream in espaloma-feedstock + - smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109 diff --git a/news/absolute_settings.rst b/news/absolute_settings.rst deleted file mode 100644 index dc4d715c..00000000 --- a/news/absolute_settings.rst +++ /dev/null @@ -1,27 +0,0 @@ -**Added:** - -* New options have been added to the ``AlchemicalSettings`` - of the ``SepTopProtocol``, ``AbsoluteSolvationProtocol`` - and ``AbsoluteBindingProtocol``. Notably, these options allow users to - control the softcore parameters as well as the use of - long range dispersion corrections. - -**Changed:** - -* - -**Deprecated:** - -* - -**Removed:** - -* - -**Fixed:** - -* - -**Security:** - -* diff --git a/openfe/protocols/openmm_afe/__init__.py b/openfe/protocols/openmm_afe/__init__.py deleted file mode 100644 index 48919cd0..00000000 --- a/openfe/protocols/openmm_afe/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -""" -Run absolute free energy calculations using OpenMM and OpenMMTools. - -""" - -from .equil_binding_afe_method import ( - AbsoluteBindingComplexUnit, - AbsoluteBindingProtocol, - AbsoluteBindingProtocolResult, - AbsoluteBindingSettings, - AbsoluteBindingSolventUnit, -) -from .equil_solvation_afe_method import ( - AbsoluteSolvationProtocol, - AbsoluteSolvationProtocolResult, - AbsoluteSolvationSettings, - AbsoluteSolvationSolventUnit, - AbsoluteSolvationVacuumUnit, -) - -__all__ = [ - "AbsoluteSolvationProtocol", - "AbsoluteSolvationSettings", - "AbsoluteSolvationProtocolResult", - "AbsoluteVacuumUnit", - "AbsoluteSolventUnit", - "AbsoluteBindingProtocol", - "AbsoluteBindingSettings", - "AbsoluteBindingProtocolResult", - "AbsoluteBindingComplexUnit", - "AbsoluteBindingSolventUnit", -] diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py deleted file mode 100644 index dc12f1a8..00000000 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ /dev/null @@ -1,1334 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -"""OpenMM Equilibrium Binding AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_binding_afe_method` -========================================================================================================== - -This module implements the necessary methodology tooling to calculate an -absolute binding free energy using OpenMM tools and one of the following -alchemical sampling methods: - -* Hamiltonian Replica Exchange -* Self-adjusted mixture sampling -* Independent window sampling - -Current limitations -------------------- -* Alchemical species with a net charge are not currently supported. -* Disapearing molecules are only allowed in state A. -* Only small molecules are allowed to act as alchemical molecules. - -Acknowledgements ----------------- -* This Protocol re-implements components from - `Yank `_. - -""" - -import itertools -import logging -import pathlib -import uuid -import warnings -from collections import defaultdict -from typing import Any, Iterable, Optional, Union - -import gufe -import MDAnalysis as mda -import numpy as np -import numpy.typing as npt -from gufe import ( - ChemicalSystem, - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, - settings, -) -from gufe.components import Component -from openff.units import Quantity -from openff.units import unit as offunit -from openff.units.openmm import to_openmm -from openmm import System -from openmm import unit as ommunit -from openmm.app import Topology as omm_topology -from openmmtools import multistate -from openmmtools.states import GlobalParameterState, ThermodynamicState -from rdkit import Chem - -from openfe.due import Doi, due -from openfe.protocols.openmm_afe.equil_afe_settings import ( - ABFEPreEquilOutputSettings, - AbsoluteBindingSettings, - AlchemicalSettings, - BoreschRestraintSettings, - IntegratorSettings, - LambdaSettings, - MDSimulationSettings, - MultiStateOutputSettings, - MultiStateSimulationSettings, - OpenFFPartialChargeSettings, - OpenMMEngineSettings, - OpenMMSolvationSettings, - SettingsBaseModel, -) -from openfe.protocols.openmm_utils import settings_validation, system_validation -from openfe.protocols.restraint_utils import geometry -from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry -from openfe.protocols.restraint_utils.openmm import omm_restraints -from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint - -from .base import BaseAbsoluteUnit - -due.cite( - Doi("10.5281/zenodo.596504"), - description="Yank", - path="openfe.protocols.openmm_afe.equil_binding_afe_method", - cite_module=True, -) - -due.cite( - Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_afe.equil_binding_afe_method", - cite_module=True, -) - -due.cite( - Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_afe.equil_binding_afe_method", - cite_module=True, -) - - -logger = logging.getLogger(__name__) - - -class AbsoluteBindingProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteBindingProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # TODO: Detect when we have extensions and stitch these together? - if any( - len(pur_list) > 2 - for pur_list in itertools.chain( - self.data["solvent"].values(), self.data["complex"].values() - ) - ): - raise NotImplementedError("Can't stitch together results yet") - - def get_individual_estimates( - self, - ) -> dict[str, list[tuple[Quantity, Quantity]]]: - """ - Get the individual estimate of the free energies. - - Returns - ------- - dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] - A dictionary, keyed `solvent`, `complex`, and 'standard_state' - representing each portion of the thermodynamic cycle, - with lists of tuples containing the individual free energy - estimates and, for 'solvent' and 'complex', the associated MBAR - uncertainties for each repeat of that simulation type. - - Notes - ----- - * Standard state correction has no error and so will return a value - of 0. - """ - complex_dGs = [] - correction_dGs = [] - solv_dGs = [] - - for pus in self.data["complex"].values(): - complex_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - correction_dGs.append( - ( - pus[0].outputs["standard_state_correction"], - 0 * offunit.kilocalorie_per_mole, # correction has no error - ) - ) - - for pus in self.data["solvent"].values(): - solv_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - - return { - "solvent": solv_dGs, - "complex": complex_dGs, - "standard_state_correction": correction_dGs, - } - - @staticmethod - def _add_complex_standard_state_corr( - complex_dG: list[tuple[Quantity, Quantity]], - standard_state_dG: list[tuple[Quantity, Quantity]], - ) -> list[tuple[Quantity, Quantity]]: - """ - Helper method to combine the - complex & standard state corrections legs. - - Parameters - ---------- - complex_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]] - The individual estimates of the complex leg, - where the first entry of each tuple is the dG estimate - and the second entry is the MBAR error. - standard_state_dG : list[tuple[Quantity, Quantity]] - The individual standard state corrections for each corresponding - complex leg. The first entry is the correction, the second - is an empty error value of 0. - - Returns - ------- - combined_dG : list[tuple[openff.units.Quantity,openff.units. Quantity]] - A list of dG estimates & MBAR errors for the combined - complex & standard state correction of each repeat. - - Notes - ----- - We assume that both list of items are in the right order. - """ - combined_dG: list[tuple[Quantity, Quantity]] = [] - for comp, corr in zip(complex_dG, standard_state_dG): - # No need to convert unit types, since pint takes care of that - # except that mypy hates it because pint isn't typed properly... - # No need to add errors since there's just the one - combined_dG.append((comp[0] + corr[0], comp[1])) # type: ignore[operator] - - return combined_dG - - def get_estimate(self) -> Quantity: - """Get the binding free energy estimate for this calculation. - - Returns - ------- - dG : openff.units.Quantity - The binding free energy. This is a Quantity defined with units. - """ - - def _get_average(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.average(dGs) * u - - individual_estimates = self.get_individual_estimates() - complex_dG = _get_average( - self._add_complex_standard_state_corr( - individual_estimates["complex"], - individual_estimates["standard_state_correction"], - ) - ) - solv_dG = _get_average(individual_estimates["solvent"]) - - return -complex_dG + solv_dG - - def get_uncertainty(self) -> Quantity: - """Get the binding free energy error for this calculation. - - Returns - ------- - err : openff.units.Quantity - The standard deviation between estimates of the binding free - energy. This is a Quantity defined with units. - """ - - def _get_stdev(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.std(dGs) * u - - individual_estimates = self.get_individual_estimates() - - complex_err = _get_stdev( - self._add_complex_standard_state_corr( - individual_estimates["complex"], individual_estimates["standard_state_correction"] - ) - ) - solv_err = _get_stdev(individual_estimates["solvent"]) - - # return the combined error - return np.sqrt(complex_err**2 + solv_err**2) - - def get_forward_and_reverse_energy_analysis( - self, - ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: - """ - Get the reverse and forward analysis of the free energies. - - Returns - ------- - forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] - A dictionary, keyed `solvent` and `complex` for each leg of the - thermodynamic cycle which each contain a list of dictionaries - containing the forward and reverse analysis of each repeat - of that simulation type. - - The forward and reverse analysis dictionaries contain: - - `fractions`: npt.NDArray - The fractions of data used for the estimates - - `forward_DGs`, `reverse_DGs`: openff.units.Quantity - The forward and reverse estimates for each fraction of data - - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity - The forward and reverse estimate uncertainty for each - fraction of data. - - If one of the cycle leg list entries is ``None``, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - Raises - ------ - UserWarning - * If any of the forward and reverse dictionaries are ``None`` in a - given thermodynamic cycle leg. - """ - - forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} - - for key in ["solvent", "complex"]: - forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() - ] - - if None in forward_reverse[key]: - wmsg = ( - "One or more ``None`` entries were found in the forward " - f"and reverse dictionaries of the repeats of the {key} " - "calculations. This is likely caused by an MBAR convergence " - "failure caused by too few independent samples when " - "calculating the free energies of the 10% timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get a the MBAR overlap estimates for all legs of the simulation. - - Returns - ------- - overlap_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `complex` for each - leg of the thermodynamic cycle, which each containing a - list of dictionaries with the MBAR overlap estimates of - each repeat of that simulation type. - - The underlying MBAR dictionaries contain the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - - for key in ["solvent", "complex"]: - overlap_stats[key] = [ - pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() - ] - - return overlap_stats - - def get_replica_transition_statistics( - self, - ) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get the replica exchange transition statistics for all - legs of the simulation. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `complex` for each - leg of the thermodynamic cycle, which each containing - a list of dictionaries containing the replica transition - statistics for each repeat of that simulation type. - - The replica transition statistics dictionaries contain the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - try: - for key in ["solvent", "complex"]: - repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> dict[str, list[npt.NDArray]]: - """ - Get the timeseries of replica states for all simulation legs. - - Returns - ------- - replica_states : dict[str, list[npt.NDArray]] - Dictionary keyed `solvent` and `complex` for each leg of - the thermodynamic cycle, with lists of replica states - timeseries for each repeat of that simulation type. - """ - replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "complex": []} - - def is_file(filename: str): - p = pathlib.Path(filename) - - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - - return p - - def get_replica_state(nc, chk): - nc = is_file(nc) - dir_path = nc.parents[0] - chk = is_file(dir_path / chk).name - - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - - retval = np.asarray(reporter.read_replica_thermodynamic_states()) - reporter.close() - - return retval - - for key in ["solvent", "complex"]: - for pus in self.data[key].values(): - states = get_replica_state( - pus[0].outputs["nc"], - pus[0].outputs["last_checkpoint"], - ) - replica_states[key].append(states) - - return replica_states - - def equilibration_iterations(self) -> dict[str, list[float]]: - """ - Get the number of equilibration iterations for each simulation. - - Returns - ------- - equilibration_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `complex` for each leg - of the thermodynamic cycle, with lists containing the - number of equilibration iterations for each repeat - of that simulation type. - """ - equilibration_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "complex"]: - equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() - ] - - return equilibration_lengths - - def production_iterations(self) -> dict[str, list[float]]: - """ - Get the number of production iterations for each simulation. - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `complex` for each leg of the - thermodynamic cycle, with lists with the number - of production iterations for each repeat of that simulation - type. - """ - production_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "complex"]: - production_lengths[key] = [ - pus[0].outputs["production_iterations"] for pus in self.data[key].values() - ] - - return production_lengths - - def restraint_geometries(self) -> list[BoreschRestraintGeometry]: - """ - Get a list of the restraint geometries for the - complex simulations. These define the atoms that have - been restrained in the system. - - Returns - ------- - geometries : list[dict[str, Any]] - A list of dictionaries containing the details of the atoms - in the system that are involved in the restraint. - """ - geometries = [ - BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry"]) - for pus in self.data["complex"].values() - ] - - return geometries - - def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]: - """ - Get the system selection indices used to write PDB and - trajectory files. - - Returns - ------- - indices : dict[str, list[npt.NDArray]] - A dictionary keyed as `complex` and `solvent` for each - state, each containing a list of NDArrays containing the corresponding - full system atom indices for each atom written in the production - trajectory files for each replica. - """ - indices: dict[str, list[Optional[npt.NDArray]]] = {} - - for key in ["complex", "solvent"]: - indices[key] = [] - for pus in self.data[key].values(): - indices[key].append(pus[0].outputs["selection_indices"]) - - return indices - - -class AbsoluteBindingProtocol(gufe.Protocol): - """ - Absolute binding free energy calculations using OpenMM and OpenMMTools. - - See Also - -------- - :mod:`openfe.protocols` - :class:`openfe.protocols.openmm_afe.AbsoluteBindingSettings` - :class:`openfe.protocols.openmm_afe.AbsoluteBindingProtocolResult` - :class:`openfe.protocols.openmm_afe.AbsoluteBindingSolventUnit` - :class:`openfe.protocols.openmm_afe.AbsoluteBindingComplexUnit` - """ - - result_cls = AbsoluteBindingProtocolResult - _settings_cls = AbsoluteBindingSettings - _settings: AbsoluteBindingSettings - - @classmethod - def _default_settings(cls): - """A dictionary of initial settings for this creating this Protocol - - These settings are intended as a suitable starting point for creating - an instance of this protocol. It is recommended, however that care is - taken to inspect and customize these before performing a Protocol. - - Returns - ------- - Settings - a set of default settings - """ - # fmt: off - return AbsoluteBindingSettings( - protocol_repeats=3, - forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), - thermo_settings=settings.ThermoSettings( - temperature=298.15 * offunit.kelvin, - pressure=1 * offunit.bar, - ), - alchemical_settings=AlchemicalSettings(), - solvent_lambda_settings=LambdaSettings( - lambda_elec=[ - 0.0, 0.25, 0.5, 0.75, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - ], - lambda_vdw=[ - 0.0, 0.0, 0.0, 0.0, 0.0, - 0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0 - ], - lambda_restraints=[ - 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 - ], - ), - complex_lambda_settings=LambdaSettings( - lambda_elec=[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0 - ], - lambda_vdw=[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0 - ], - lambda_restraints=[ - 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0 - ], - ), - partial_charge_settings=OpenFFPartialChargeSettings(), - complex_solvation_settings=OpenMMSolvationSettings( - solvent_padding=1.0 * offunit.nanometer, - ), - solvent_solvation_settings=OpenMMSolvationSettings(), - engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), - restraint_settings=BoreschRestraintSettings(), - solvent_equil_simulation_settings=MDSimulationSettings( - equilibration_length_nvt=0.1 * offunit.nanosecond, - equilibration_length=0.2 * offunit.nanosecond, - production_length=0.5 * offunit.nanosecond, - ), - solvent_equil_output_settings=ABFEPreEquilOutputSettings(), - solvent_simulation_settings=MultiStateSimulationSettings( - n_replicas=14, - equilibration_length=1.0 * offunit.nanosecond, - production_length=10.0 * offunit.nanosecond, - ), - solvent_output_settings=MultiStateOutputSettings( - output_structure="alchemical_system.pdb", - output_filename="solvent.nc", - checkpoint_storage_filename="solvent_checkpoint.nc", - ), - complex_equil_simulation_settings=MDSimulationSettings( - equilibration_length_nvt=0.25 * offunit.nanosecond, - equilibration_length=0.5 * offunit.nanosecond, - production_length=5.0 * offunit.nanosecond, - ), - complex_equil_output_settings=ABFEPreEquilOutputSettings(), - complex_simulation_settings=MultiStateSimulationSettings( - n_replicas=30, - equilibration_length=1 * offunit.nanosecond, - production_length=10.0 * offunit.nanosecond, - ), - complex_output_settings=MultiStateOutputSettings( - output_structure="alchemical_system.pdb", - output_filename="complex.nc", - checkpoint_storage_filename="complex_checkpoint.nc", - ), - ) - # fmt: on - - @staticmethod - def _validate_endstates( - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ) -> None: - """ - A binding transformation is defined (in terms of gufe components) - as starting from one or more ligands with one protein and solvent, - that then ends up in a state with one less ligand. - - Parameters - ---------- - stateA : ChemicalSystem - The chemical system of end state A - stateB : ChemicalSystem - The chemical system of end state B - - Raises - ------ - ValueError - If stateA & stateB do not contain a ProteinComponent. - If stateA & stateB do not contain a SolventComponent. - If stateA has more than one unique Component. - If the stateA unique Component is not a SmallMoleculeComponent. - If stateB contains any unique Components. - If the alchemical species is charged. - """ - if not (stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent)): - errmsg = "No ProteinComponent found" - raise ValueError(errmsg) - - if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)): - errmsg = "No SolventComponent found" - raise ValueError(errmsg) - - # Needs gufe 1.3 - diff = stateA.component_diff(stateB) - if len(diff[0]) != 1: - errmsg = ( - "Only one alchemical species is supported. " - f"Number of unique components found in stateA: {len(diff[0])}." - ) - raise ValueError(errmsg) - - if not isinstance(diff[0][0], SmallMoleculeComponent): - errmsg = ( - "Only dissapearing small molecule components " - "are supported by this protocol. " - f"Found a {type(diff[0][0])}" - ) - raise ValueError(errmsg) - - # Check that the state A unique isn't charged - if diff[0][0].total_charge != 0: - errmsg = ( - "Charged alchemical molecules are not currently " - "supported for solvation free energies. " - f"Molecule total charge: {diff[0][0].total_charge}." - ) - raise ValueError(errmsg) - - # If there are any alchemical Components in state B - if len(diff[1]) > 0: - errmsg = "Components appearing in state B are not currently supported" - raise ValueError(errmsg) - - @staticmethod - def _validate_lambda_schedule( - lambda_settings: LambdaSettings, - simulation_settings: MultiStateSimulationSettings, - ) -> None: - """ - Checks that the lambda schedule is set up correctly. - - Parameters - ---------- - lambda_settings : LambdaSettings - the lambda schedule Settings - simulation_settings : MultiStateSimulationSettings - the settings for either the complex or solvent phase - - Raises - ------ - ValueError - If the number of lambda windows differs for electrostatics, sterics, - and restraints. - If the number of replicas does not match the number of lambda windows. - If there are states with naked charges. - """ - - lambda_elec = lambda_settings.lambda_elec - lambda_vdw = lambda_settings.lambda_vdw - lambda_restraints = lambda_settings.lambda_restraints - n_replicas = simulation_settings.n_replicas - - # Ensure that all lambda components have equal amount of windows - lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] - it = iter(lambda_components) - the_len = len(next(it)) - if not all(len(lambda_comp) == the_len for lambda_comp in it): - errmsg = ( - "Components elec, vdw, and restraints must have equal amount" - f" of lambda windows. Got {len(lambda_elec)} elec lambda" - f" windows, {len(lambda_vdw)} vdw lambda windows, and" - f"{len(lambda_restraints)} restraints lambda windows." - ) - raise ValueError(errmsg) - - # Ensure that number of overall lambda windows matches number of lambda - # windows for individual components - if n_replicas != len(lambda_vdw): - errmsg = ( - f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(lambda_vdw)}" - ) - raise ValueError(errmsg) - - # Check if there are no lambda windows with naked charges - for inx, lam in enumerate(lambda_elec): - if lam < 1 and lambda_vdw[inx] == 1: - errmsg = ( - "There are states along this lambda schedule " - "where there are atoms with charges but no LJ " - f"interactions: lambda {inx}: " - f"elec {lam} vdW {lambda_vdw[inx]}" - ) - raise ValueError(errmsg) - - def _validate( - self, - *, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, - ): - # Check we're not extending - if extends is not None: - # This technically should be NotImplementedError - # but gufe.Protocol.validate calls `_validate` wrapped around an - # except for NotImplementedError, so we can't raise it here - raise ValueError("Can't extend simulations yet") - - # Check we're not using a mapping, since we're not doing anything with it - if mapping is not None: - wmsg = "A mapping was passed but is not used by this Protocol." - warnings.warn(wmsg) - - # Validate the end states & alchemical components - self._validate_endstates(stateA, stateB) - - # Validate the complex lambda schedule - self._validate_lambda_schedule( - self.settings.complex_lambda_settings, - self.settings.complex_simulation_settings, - ) - - # If the complex restraints schedule is all zero, it might be bad - # but we don't dissallow it. - if all([i == 0.0 for i in self.settings.complex_lambda_settings.lambda_restraints]): - wmsg = ( - "No restraints are being applied in the complex phase, " - "this will likely lead to problematic results." - ) - warnings.warn(wmsg) - - # Validate the solvent lambda schedule - self._validate_lambda_schedule( - self.settings.solvent_lambda_settings, - self.settings.solvent_simulation_settings, - ) - - # If the solvent restraints schedule is not all one, it was likely - # copied from the complex schedule. In this case we just ignore - # the values and let the user know. - # P.S. we don't need to change the settings at this point - # the list gets popped out later in the SolventUnit, because we - # don't have a restraint parameter state. - - if any([i != 0.0 for i in self.settings.solvent_lambda_settings.lambda_restraints]): - wmsg = ( - "There is an attempt to add restraints in the solvent " - "phase. This protocol does not apply restraints in the " - "solvent phase. These restraint lambda values will be ignored." - ) - warnings.warn(wmsg) - - # Check nonbond & solvent compatibility - nonbonded_method = self.settings.forcefield_settings.nonbonded_method - # Use the more complete system validation solvent checks - system_validation.validate_solvent(stateA, nonbonded_method) - - # Validate solvation settings - settings_validation.validate_openmm_solvation_settings( - self.settings.solvent_solvation_settings - ) - settings_validation.validate_openmm_solvation_settings( - self.settings.complex_solvation_settings - ) - - # Validate integrator things - settings_validation.validate_timestep( - self.settings.forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep, - ) - - def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # Validate inputs - self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) - - # Get the alchemical components - alchem_comps = system_validation.get_alchemical_components( - stateA, - stateB, - ) - - # Get the name of the alchemical species - alchname = alchem_comps["stateA"][0].name - - # Create list units for complex and solvent transforms - - solvent_units = [ - AbsoluteBindingSolventUnit( - protocol=self, - stateA=stateA, - stateB=stateB, - alchemical_components=alchem_comps, - generation=0, - repeat_id=int(uuid.uuid4()), - name=(f"Absolute Binding, {alchname} solvent leg: repeat {i} generation 0"), - ) - for i in range(self.settings.protocol_repeats) - ] - - complex_units = [ - AbsoluteBindingComplexUnit( - protocol=self, - stateA=stateA, - stateB=stateB, - alchemical_components=alchem_comps, - generation=0, - repeat_id=int(uuid.uuid4()), - name=(f"Absolute Binding, {alchname} complex leg: repeat {i} generation 0"), - ) - for i in range(self.settings.protocol_repeats) - ] - - return solvent_units + complex_units - - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, dict[str, Any]]: - # result units will have a repeat_id and generation - # first group according to repeat_id - unsorted_solvent_repeats = defaultdict(list) - unsorted_complex_repeats = defaultdict(list) - for d in protocol_dag_results: - pu: gufe.ProtocolUnitResult - for pu in d.protocol_unit_results: - if not pu.ok(): - continue - if pu.outputs["simtype"] == "solvent": - unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) - else: - unsorted_complex_repeats[pu.outputs["repeat_id"]].append(pu) - - repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { - "solvent": {}, - "complex": {}, - } - for k, v in unsorted_solvent_repeats.items(): - repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - - for k, v in unsorted_complex_repeats.items(): - repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats - - -class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the complex phase of an absolute binding free energy - """ - - simtype = "complex" - - def _get_components(self): - """ - Get the relevant components for a complex transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A dict of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in small_mols} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp - return alchem_comps, solv_comp, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a complex transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : ABFEPreEquilOutputSettings - * simulation_settings : SimulationSettings - * output_settings: MultiStateOutputSettings - * restraint_settings: BaseRestraintSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.complex_solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.complex_lambda_settings - settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.complex_equil_output_settings - settings["simulation_settings"] = prot_settings.complex_simulation_settings - settings["output_settings"] = prot_settings.complex_output_settings - settings["restraint_settings"] = prot_settings.restraint_settings - - return settings - - @staticmethod - def _get_mda_universe( - topology: omm_topology, - positions: ommunit.Quantity, - trajectory: Optional[pathlib.Path], - ) -> mda.Universe: - """ - Helper method to get a Universe from an openmm Topology, - and either an input trajectory or a set of positions. - - Parameters - ---------- - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - positions: openmm.unit.Quantity - The System's current positions. - Used if a trajectory file is None or is not a file. - trajectory: pathlib.Path - A Path to a trajectory file to read positions from. - - Returns - ------- - mda.Universe - An MDAnalysis Universe of the System. - """ - from MDAnalysis.coordinates.memory import MemoryReader - - # If the trajectory file doesn't exist, then we use positions - if trajectory is not None and trajectory.is_file(): - return mda.Universe( - topology, - trajectory, - topology_format="OPENMMTOPOLOGY", - ) - else: - # Positions is an openmm Quantity in nm we need - # to convert to angstroms - return mda.Universe( - topology, - np.array(positions._value) * 10, - topology_format="OPENMMTOPOLOGY", - trajectory_format=MemoryReader, - ) - - @staticmethod - def _get_idxs_from_residxs( - topology: omm_topology, - residxs: list[int], - ) -> list[int]: - """ - Helper method to get the a list of atom indices which belong to a list - of residues. - - Parameters - ---------- - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - residxs : list[int] - A list of residue numbers who's atoms we should get atom indices. - - Returns - ------- - atom_ids : list[int] - A list of atom indices. - - TODO - ---- - * Check how this works when we deal with virtual sites. - """ - atom_ids = [] - - for r in topology.residues(): - if r.index in residxs: - atom_ids.extend([at.index for at in r.atoms()]) - - return atom_ids - - @staticmethod - def _get_boresch_restraint( - universe: mda.Universe, - guest_rdmol: Chem.Mol, - guest_atom_ids: list[int], - host_atom_ids: list[int], - temperature: Quantity, - settings: BoreschRestraintSettings, - ) -> tuple[BoreschRestraintGeometry, BoreschRestraint]: - """ - Get a Boresch-like restraint Geometry and OpenMM restraint force - supplier. - - Parameters - ---------- - universe : mda.Universe - An MDAnalysis Universe defining the system to get the restraint for. - guest_rdmol : Chem.Mol - An RDKit Molecule defining the guest molecule in the system. - guest_atom_ids: list[int] - A list of atom indices defining the guest molecule in the universe. - host_atom_ids : list[int] - A list of atom indices defining the host molecules in the universe. - temperature : openff.units.Quantity - The temperature of the simulation where the restraint will be added. - settings : BoreschRestraintSettings - Settings on how the Boresch-like restraint should be defined. - - Returns - ------- - geom : BoreschRestraintGeometry - A class defining the Boresch-like restraint. - restraint : BoreschRestraint - A factory class for generating Boresch restraints in OpenMM. - """ - # Take the minimum of the two possible force constants to check against - frc_const = min(settings.K_thetaA, settings.K_thetaB) - - geom = geometry.boresch.find_boresch_restraint( - universe=universe, - guest_rdmol=guest_rdmol, - guest_idxs=guest_atom_ids, - host_idxs=host_atom_ids, - host_selection=settings.host_selection, - anchor_finding_strategy=settings.anchor_finding_strategy, - dssp_filter=settings.dssp_filter, - rmsf_cutoff=settings.rmsf_cutoff, - host_min_distance=settings.host_min_distance, - host_max_distance=settings.host_max_distance, - angle_force_constant=frc_const, - temperature=temperature, - ) - - restraint = omm_restraints.BoreschRestraint(settings) - return geom, restraint - - def _add_restraints( - self, - system: System, - topology: omm_topology, - positions: ommunit.Quantity, - alchem_comps: dict[str, list[Component]], - comp_resids: dict[Component, npt.NDArray], - settings: dict[str, SettingsBaseModel], - ) -> tuple[ - GlobalParameterState, - Quantity, - System, - geometry.HostGuestRestraintGeometry, - ]: - """ - Find and add restraints to the OpenMM System. - - Notes - ----- - Currently, only Boresch-like restraints are supported. - - Parameters - ---------- - system : openmm.System - The System to add the restraint to. - topology : openmm.app.Topology - An OpenMM Topology that defines the System. - positions: openmm.unit.Quantity - The System's current positions. - Used if a trajectory file isn't found. - alchem_comps: dict[str, list[Component]] - A dictionary with a list of alchemical components - in both state A and B. - comp_resids: dict[Component, npt.NDArray] - A dictionary keyed by each Component in the System - which contains arrays with the residue indices that is contained - by that Component. - settings : dict[str, SettingsBaseModel] - A dictionary of settings that defines how to find and set - the restraint. - - Returns - ------- - restraint_parameter_state : RestraintParameterState - A RestraintParameterState object that defines the control - parameter for the restraint. - correction : openff.units.Quantity - The standard state correction for the restraint. - system : openmm.System - A copy of the System with the restraint added. - rest_geom : geometry.HostGuestRestraintGeometry - The restraint Geometry object. - """ - if self.verbose: - self.logger.info("Generating restraints") - - # Get the guest rdmol - guest_rdmol = alchem_comps["stateA"][0].to_rdkit() - - # sanitize the rdmol if possible - warn if you can't - err = Chem.SanitizeMol(guest_rdmol, catchErrors=True) - - if err: - msg = "restraint generation: could not sanitize ligand rdmol" - logger.warning(msg) - - # Get the guest idxs - # concatenate a list of residue indexes for all alchemical components - residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) - - # get the alchemicical atom ids - guest_atom_ids = self._get_idxs_from_residxs(topology, residxs) - - # Now get the host idxs - # We assume this is everything but the alchemical component - # and the solvent. - solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)] - exclude_comps = [alchem_comps["stateA"]] + solv_comps - residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps]) - - host_atom_ids = self._get_idxs_from_residxs(topology, residxs) - - # Finally create an MDAnalysis Universe - # We try to pass the equilibration production file path through - # In some cases (debugging / dry runs) this won't be available - # so we'll default to using input positions. - univ = self._get_mda_universe( - topology, - positions, - self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename, - ) - - if isinstance(settings["restraint_settings"], BoreschRestraintSettings): - rest_geom, restraint = self._get_boresch_restraint( - univ, - guest_rdmol, - guest_atom_ids, - host_atom_ids, - settings["thermo_settings"].temperature, - settings["restraint_settings"], - ) - else: - # TODO turn this into a direction for different restraint types supported? - raise NotImplementedError("Other restraint types are not yet available") - - if self.verbose: - self.logger.info(f"restraint geometry is: {rest_geom}") - - # We need a temporary thermodynamic state to add the restraint - # & get the correction - thermodynamic_state = ThermodynamicState( - system, - temperature=to_openmm(settings["thermo_settings"].temperature), - pressure=to_openmm(settings["thermo_settings"].pressure), - ) - - # Add the force to the thermodynamic state - restraint.add_force( - thermodynamic_state, - rest_geom, - controlling_parameter_name="lambda_restraints", - ) - # Get the standard state correction as a unit.Quantity - correction = restraint.get_standard_state_correction( - thermodynamic_state, - rest_geom, - ) - - # Get the GlobalParameterState for the restraint - restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0) - return ( - restraint_parameter_state, - correction, - # Remove the thermostat, otherwise you'll get an - # Andersen thermostat by default! - thermodynamic_state.get_system(remove_thermostat=True), - rest_geom, - ) - - -class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the solvent phase of an absolute binding free energy - """ - - simtype = "solvent" - - def _get_components(self): - """ - Get the relevant components for a solvent transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A list of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp just return None - return alchem_comps, solv_comp, None, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a solvent transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : ABFEPreEquilOutputSettings - * simulation_settings : MultiStateSimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvent_solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.solvent_lambda_settings - settings["engine_settings"] = prot_settings.engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings - settings["simulation_settings"] = prot_settings.solvent_simulation_settings - settings["output_settings"] = prot_settings.solvent_output_settings - - return settings diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py deleted file mode 100644 index 4ffb9477..00000000 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ /dev/null @@ -1,962 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` -=============================================================================================================== - -This module implements the necessary methodology tooling to run calculate an -absolute solvation free energy using OpenMM tools and one of the following -alchemical sampling methods: - -* Hamiltonian Replica Exchange -* Self-adjusted mixture sampling -* Independent window sampling - -Current limitations -------------------- -* Alchemical species with a net charge are not currently supported. -* Disapearing molecules are only allowed in state A. Support for - appearing molecules will be added in due course. -* Only small molecules are allowed to act as alchemical molecules. - Alchemically changing protein or solvent components would induce - perturbations which are too large to be handled by this Protocol. - - -Acknowledgements ----------------- -* Originally based on hydration.py in - `espaloma_charge `_ - -""" - -from __future__ import annotations - -import itertools -import logging -import pathlib -import uuid -import warnings -from collections import defaultdict -from typing import Any, Iterable, Optional, Union - -import gufe -import numpy as np -import numpy.typing as npt -from gufe import ( - ChemicalSystem, - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, - settings, -) -from gufe.components import Component -from openff.units import Quantity, unit -from openmmtools import multistate - -from openfe.due import Doi, due -from openfe.protocols.openmm_afe.equil_afe_settings import ( - AbsoluteSolvationSettings, - AlchemicalSettings, - IntegratorSettings, - LambdaSettings, - MDOutputSettings, - MDSimulationSettings, - MultiStateOutputSettings, - MultiStateSimulationSettings, - OpenFFPartialChargeSettings, - OpenMMEngineSettings, - OpenMMSolvationSettings, - SettingsBaseModel, -) - -from ..openmm_utils import settings_validation, system_validation -from .base import BaseAbsoluteUnit - -due.cite( - Doi("10.5281/zenodo.596504"), - description="Yank", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True, -) - -due.cite( - Doi("10.48550/ARXIV.2302.06758"), - description="EspalomaCharge", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True, -) - -due.cite( - Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True, -) - -due.cite( - Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True, -) - - -logger = logging.getLogger(__name__) - - -class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteSolvationProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # TODO: Detect when we have extensions and stitch these together? - if any( - len(pur_list) > 2 - for pur_list in itertools.chain( - self.data["solvent"].values(), self.data["vacuum"].values() - ) - ): - raise NotImplementedError("Can't stitch together results yet") - - def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]: - """ - Get the individual estimate of the free energies. - - Returns - ------- - dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] - A dictionary, keyed `solvent` and `vacuum` for each leg - of the thermodynamic cycle, with lists of tuples containing - the individual free energy estimates and associated MBAR - uncertainties for each repeat of that simulation type. - """ - vac_dGs = [] - solv_dGs = [] - - for pus in self.data["vacuum"].values(): - vac_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) - - for pus in self.data["solvent"].values(): - solv_dGs.append( - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - ) - - return {"solvent": solv_dGs, "vacuum": vac_dGs} - - def get_estimate(self): - """Get the solvation free energy estimate for this calculation. - - Returns - ------- - dG : openff.units.Quantity - The solvation free energy. This is a Quantity defined with units. - """ - - def _get_average(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.average(dGs) * u - - individual_estimates = self.get_individual_estimates() - vac_dG = _get_average(individual_estimates["vacuum"]) - solv_dG = _get_average(individual_estimates["solvent"]) - - return vac_dG - solv_dG - - def get_uncertainty(self): - """Get the solvation free energy error for this calculation. - - Returns - ------- - err : openff.units.Quantity - The standard deviation between estimates of the solvation free - energy. This is a Quantity defined with units. - """ - - def _get_stdev(estimates): - # Get the unit value of the first value in the estimates - u = estimates[0][0].u - # Loop through estimates and get the free energy values - # in the unit of the first estimate - dGs = [i[0].to(u).m for i in estimates] - - return np.std(dGs) * u - - individual_estimates = self.get_individual_estimates() - vac_err = _get_stdev(individual_estimates["vacuum"]) - solv_err = _get_stdev(individual_estimates["solvent"]) - - # return the combined error - return np.sqrt(vac_err**2 + solv_err**2) - - def get_forward_and_reverse_energy_analysis( - self, - ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: - """ - Get the reverse and forward analysis of the free energies. - - Returns - ------- - forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] - A dictionary, keyed `solvent` and `vacuum` for each leg of the - thermodynamic cycle which each contain a list of dictionaries - containing the forward and reverse analysis of each repeat - of that simulation type. - - The forward and reverse analysis dictionaries contain: - - `fractions`: npt.NDArray - The fractions of data used for the estimates - - `forward_DGs`, `reverse_DGs`: openff.units.Quantity - The forward and reverse estimates for each fraction of data - - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity - The forward and reverse estimate uncertainty for each - fraction of data. - - If one of the cycle leg list entries is ``None``, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - Raises - ------ - UserWarning - * If any of the forward and reverse dictionaries are ``None`` in a - given thermodynamic cycle leg. - """ - - forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} - - for key in ["solvent", "vacuum"]: - forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() - ] - - if None in forward_reverse[key]: - wmsg = ( - "One or more ``None`` entries were found in the forward " - f"and reverse dictionaries of the repeats of the {key} " - "calculations. This is likely caused by an MBAR convergence " - "failure caused by too few independent samples when " - "calculating the free energies of the 10% timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get a the MBAR overlap estimates for all legs of the simulation. - - Returns - ------- - overlap_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `vacuum` for each - leg of the thermodynamic cycle, which each containing a - list of dictionaries with the MBAR overlap estimates of - each repeat of that simulation type. - - The underlying MBAR dictionaries contain the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - - for key in ["solvent", "vacuum"]: - overlap_stats[key] = [ - pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() - ] - - return overlap_stats - - def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: - """ - Get the replica exchange transition statistics for all - legs of the simulation. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : dict[str, list[dict[str, npt.NDArray]]] - A dictionary with keys `solvent` and `vacuum` for each - leg of the thermodynamic cycle, which each containing - a list of dictionaries containing the replica transition - statistics for each repeat of that simulation type. - - The replica transition statistics dictionaries contain the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - try: - for key in ["solvent", "vacuum"]: - repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> dict[str, list[npt.NDArray]]: - """ - Get the timeseries of replica states for all simulation legs. - - Returns - ------- - replica_states : dict[str, list[npt.NDArray]] - Dictionary keyed `solvent` and `vacuum` for each leg of - the thermodynamic cycle, with lists of replica states - timeseries for each repeat of that simulation type. - """ - replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "vacuum": []} - - def is_file(filename: str): - p = pathlib.Path(filename) - - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - - return p - - def get_replica_state(nc, chk): - nc = is_file(nc) - dir_path = nc.parents[0] - chk = is_file(dir_path / chk).name - - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - - retval = np.asarray(reporter.read_replica_thermodynamic_states()) - reporter.close() - - return retval - - for key in ["solvent", "vacuum"]: - for pus in self.data[key].values(): - states = get_replica_state( - pus[0].outputs["nc"], - pus[0].outputs["last_checkpoint"], - ) - replica_states[key].append(states) - - return replica_states - - def equilibration_iterations(self) -> dict[str, list[float]]: - """ - Get the number of equilibration iterations for each simulation. - - Returns - ------- - equilibration_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `vacuum` for each leg - of the thermodynamic cycle, with lists containing the - number of equilibration iterations for each repeat - of that simulation type. - """ - equilibration_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "vacuum"]: - equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() - ] - - return equilibration_lengths - - def production_iterations(self) -> dict[str, list[float]]: - """ - Get the number of production iterations for each simulation. - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : dict[str, list[float]] - Dictionary keyed `solvent` and `vacuum` for each leg of the - thermodynamic cycle, with lists with the number - of production iterations for each repeat of that simulation - type. - """ - production_lengths: dict[str, list[float]] = {} - - for key in ["solvent", "vacuum"]: - production_lengths[key] = [ - pus[0].outputs["production_iterations"] for pus in self.data[key].values() - ] - - return production_lengths - - -class AbsoluteSolvationProtocol(gufe.Protocol): - """ - Absolute solvation free energy calculations using OpenMM and OpenMMTools. - - See Also - -------- - :mod:`openfe.protocols` - :class:`openfe.protocols.openmm_afe.AbsoluteSolvationSettings` - :class:`openfe.protocols.openmm_afe.AbsoluteSolvationProtocolResult` - :class:`openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit` - :class:`openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit` - """ - - result_cls = AbsoluteSolvationProtocolResult - _settings_cls = AbsoluteSolvationSettings - _settings: AbsoluteSolvationSettings - - @classmethod - def _default_settings(cls): - """A dictionary of initial settings for this creating this Protocol - - These settings are intended as a suitable starting point for creating - an instance of this protocol. It is recommended, however that care is - taken to inspect and customize these before performing a Protocol. - - Returns - ------- - Settings - a set of default settings - """ - return AbsoluteSolvationSettings( - protocol_repeats=3, - solvent_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), - vacuum_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings( - nonbonded_method="nocutoff", - ), - thermo_settings=settings.ThermoSettings( - temperature=298.15 * unit.kelvin, - pressure=1 * unit.bar, - ), - alchemical_settings=AlchemicalSettings(), - lambda_settings=LambdaSettings( - lambda_elec=[ - 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - lambda_vdw=[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, - 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], - lambda_restraints=[ - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], - ), - partial_charge_settings=OpenFFPartialChargeSettings(), - solvation_settings=OpenMMSolvationSettings(), - vacuum_engine_settings=OpenMMEngineSettings(), - solvent_engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), - solvent_equil_simulation_settings=MDSimulationSettings( - equilibration_length_nvt=0.1 * unit.nanosecond, - equilibration_length=0.2 * unit.nanosecond, - production_length=0.5 * unit.nanosecond, - ), - solvent_equil_output_settings=MDOutputSettings( - equil_nvt_structure="equil_nvt_structure.pdb", - equil_npt_structure="equil_npt_structure.pdb", - production_trajectory_filename="production_equil.xtc", - log_output="equil_simulation.log", - ), - solvent_simulation_settings=MultiStateSimulationSettings( - n_replicas=14, - equilibration_length=1.0 * unit.nanosecond, - production_length=10.0 * unit.nanosecond, - ), - solvent_output_settings=MultiStateOutputSettings( - output_filename="solvent.nc", - checkpoint_storage_filename="solvent_checkpoint.nc", - ), - vacuum_equil_simulation_settings=MDSimulationSettings( - equilibration_length_nvt=None, - equilibration_length=0.2 * unit.nanosecond, - production_length=0.5 * unit.nanosecond, - ), - vacuum_equil_output_settings=MDOutputSettings( - equil_nvt_structure=None, - equil_npt_structure="equil_structure.pdb", - production_trajectory_filename="production_equil.xtc", - log_output="equil_simulation.log", - ), - vacuum_simulation_settings=MultiStateSimulationSettings( - n_replicas=14, - equilibration_length=0.5 * unit.nanosecond, - production_length=2.0 * unit.nanosecond, - ), - vacuum_output_settings=MultiStateOutputSettings( - output_filename="vacuum.nc", - checkpoint_storage_filename="vacuum_checkpoint.nc", - ), - ) # fmt: skip - - @staticmethod - def _validate_endstates( - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ) -> None: - """ - A solvent transformation is defined (in terms of gufe components) - as starting from one or more ligands in solvent and - ending up in a state with one less ligand. - - No protein components are allowed. - - Parameters - ---------- - stateA : ChemicalSystem - The chemical system of end state A - stateB : ChemicalSystem - The chemical system of end state B - - Raises - ------ - ValueError - If stateA or stateB contains a ProteinComponent. - If there is no SolventComponent in either stateA or stateB. - If there are alchemical components in state B. - If there are non SmallMoleculeComponent alchemical species. - If there are more than one alchemical species. - If the alchemical species is charged. - - Notes - ----- - * Currently doesn't support alchemical components in state B. - * Currently doesn't support alchemical components which are not - SmallMoleculeComponents. - * Currently doesn't support more than one alchemical component - being desolvated. - * Currently doesn't support charged alchemical components. - * Solvent must always be present in both end states. - """ - # Check that there are no protein components - if stateA.contains(ProteinComponent) or stateB.contains(ProteinComponent): - errmsg = "Protein components are not allowed for absolute solvation free energies." - raise ValueError(errmsg) - - # Check that there is a solvent component in both end states - if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)): - errmsg = "No SolventComponent found in stateA and/or stateB" - raise ValueError(errmsg) - - # Now we check the alchemical Components - diff = stateA.component_diff(stateB) - - # Check that there's only one state A unique Component - if len(diff[0]) != 1: - errmsg = ( - "Only one alchemical species is supported " - "for absolute solvation free energies. " - f"Number of unique components found in stateA: {len(diff[0])}." - ) - raise ValueError(errmsg) - - # Make sure that the state A unique is an SMC - if not isinstance(diff[0][0], SmallMoleculeComponent): - errmsg = ( - "Only dissapearing SmallMoleculeComponents " - "are supported by this protocol. " - f"Found a {type(diff[0][0])}" - ) - raise ValueError(errmsg) - - # Check that the state A unique isn't charged - if diff[0][0].total_charge != 0: - errmsg = ( - "Charged alchemical molecules are not currently " - "supported for solvation free energies. " - f"Molecule total charge: {diff[0][0].total_charge}." - ) - raise ValueError(errmsg) - - # If there are any alchemical Components in state B - if len(diff[1]) > 0: - errmsg = "Components appearing in state B are not currently supported" - raise ValueError(errmsg) - - @staticmethod - def _validate_lambda_schedule( - lambda_settings: LambdaSettings, - simulation_settings: MultiStateSimulationSettings, - ) -> None: - """ - Checks that the lambda schedule is set up correctly. - - Parameters - ---------- - lambda_settings : LambdaSettings - the lambda schedule Settings - simulation_settings : MultiStateSimulationSettings - the settings for either the vacuum or solvent phase - - Raises - ------ - ValueError - If the number of lambda windows differs for electrostatics and sterics. - If the number of replicas does not match the number of lambda windows. - If there are states with naked charges. - Warnings - If there are non-zero values for restraints (lambda_restraints). - """ - - lambda_elec = lambda_settings.lambda_elec - lambda_vdw = lambda_settings.lambda_vdw - lambda_restraints = lambda_settings.lambda_restraints - n_replicas = simulation_settings.n_replicas - - # Ensure that all lambda components have equal amount of windows - lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] - it = iter(lambda_components) - the_len = len(next(it)) - if not all(len(lambda_comp) == the_len for lambda_comp in it): - errmsg = ( - "Components elec, vdw, and restraints must have equal amount" - f" of lambda windows. Got {len(lambda_elec)} elec lambda" - f" windows, {len(lambda_vdw)} vdw lambda windows, and" - f"{len(lambda_restraints)} restraints lambda windows." - ) - raise ValueError(errmsg) - - # Ensure that number of overall lambda windows matches number of lambda - # windows for individual components - if n_replicas != len(lambda_vdw): - errmsg = ( - f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(lambda_vdw)}" - ) - raise ValueError(errmsg) - - # Check if there are lambda windows with naked charges - for inx, lam in enumerate(lambda_elec): - if lam < 1 and lambda_vdw[inx] == 1: - errmsg = ( - "There are states along this lambda schedule " - "where there are atoms with charges but no LJ " - f"interactions: lambda {inx}: " - f"elec {lam} vdW {lambda_vdw[inx]}" - ) - raise ValueError(errmsg) - - # Check if there are lambda windows with non-zero restraints - if len([r for r in lambda_restraints if r != 0]) > 0: - wmsg = ( - "Non-zero restraint lambdas applied. The absolute " - "solvation protocol doesn't apply restraints, " - "therefore restraints won't be applied. " - f"Given lambda_restraints: {lambda_restraints}" - ) - logger.warning(wmsg) - warnings.warn(wmsg) - - def _validate( - self, - *, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, - ): - # Check we're not extending - if extends is not None: - # This should be a NotImplementedError, but the underlying - # `validate` method wraps a call to `_validate` around a - # NotImplementedError exception guard - raise ValueError("Can't extend simulations yet") - - # Check we're not using a mapping, since we're not doing anything with it - if mapping is not None: - wmsg = "A mapping was passed but is not used by this Protocol." - warnings.warn(wmsg) - - # Validate the endstates & alchemical components - self._validate_endstates(stateA, stateB) - - # Validate the lambda schedule - for solv_sets in ( - self.settings.solvent_simulation_settings, - self.settings.vacuum_simulation_settings, - ): - self._validate_lambda_schedule( - self.settings.lambda_settings, - solv_sets, - ) - - # Check nonbond & solvent compatibility - solv_nonbonded_method = self.settings.solvent_forcefield_settings.nonbonded_method - vac_nonbonded_method = self.settings.vacuum_forcefield_settings.nonbonded_method - - # Use the more complete system validation solvent checks - system_validation.validate_solvent(stateA, solv_nonbonded_method) - - # Gas phase is always gas phase - if vac_nonbonded_method.lower() != "nocutoff": - errmsg = ( - "Only the nocutoff nonbonded_method is supported for " - f"vacuum calculations, {vac_nonbonded_method} was " - "passed" - ) - raise ValueError(errmsg) - - # Validate solvation settings - settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) - - # Check vacuum equilibration MD settings is 0 ns - nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt - if nvt_time is not None: - if not np.allclose(nvt_time, 0 * unit.nanosecond): - errmsg = "NVT equilibration cannot be run in vacuum simulation" - raise ValueError(errmsg) - - # Validate integrator things - settings_validation.validate_timestep( - self.settings.vacuum_forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep, - ) - - settings_validation.validate_timestep( - self.settings.solvent_forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep, - ) - - def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # Validate inputs - self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) - - # Get the alchemical components - alchem_comps = system_validation.get_alchemical_components( - stateA, - stateB, - ) - - # Get the name of the alchemical species - alchname = alchem_comps["stateA"][0].name - - # Create list units for vacuum and solvent transforms - solvent_units = [ - AbsoluteSolvationSolventUnit( - protocol=self, - stateA=stateA, - stateB=stateB, - alchemical_components=alchem_comps, - generation=0, - repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} solvent leg: repeat {i} generation 0"), - ) - for i in range(self.settings.protocol_repeats) - ] - - vacuum_units = [ - AbsoluteSolvationVacuumUnit( - # These don't really reflect the actual transform - # Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ? - protocol=self, - stateA=stateA, - stateB=stateB, - alchemical_components=alchem_comps, - generation=0, - repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} vacuum leg: repeat {i} generation 0"), - ) - for i in range(self.settings.protocol_repeats) - ] - - return solvent_units + vacuum_units - - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, dict[str, Any]]: - # result units will have a repeat_id and generation - # first group according to repeat_id - unsorted_solvent_repeats = defaultdict(list) - unsorted_vacuum_repeats = defaultdict(list) - for d in protocol_dag_results: - pu: gufe.ProtocolUnitResult - for pu in d.protocol_unit_results: - if not pu.ok(): - continue - if pu.outputs["simtype"] == "solvent": - unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) - else: - unsorted_vacuum_repeats[pu.outputs["repeat_id"]].append(pu) - - repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { - "solvent": {}, - "vacuum": {}, - } - for k, v in unsorted_solvent_repeats.items(): - repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - - for k, v in unsorted_vacuum_repeats.items(): - repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - return repeats - - -class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the vacuum phase of an absolute solvation free energy - """ - - simtype = "vacuum" - - def _get_components(self): - """ - Get the relevant components for a vacuum transformation. - - Returns - ------- - alchem_comps : dict[str, list[Component]] - A list of alchemical components - solv_comp : None - For the gas phase transformation, None will always be returned - for the solvent component of the chemical system. - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[Component, OpenFF Molecule] - The openff Molecules to add to the system. This - is equivalent to the alchemical components in stateA (since - we only allow for disappearing ligands). - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} - - _, prot_comp, _ = system_validation.get_components(stateA) - - # Notes: - # 1. Our input state will contain a solvent, we ``None`` that out - # since this is the gas phase unit. - # 2. Our small molecules will always just be the alchemical components - # (of stateA since we enforce only one disappearing ligand) - return alchem_comps, None, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a vacuum transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : MDOutputSettings - * simulation_settings : SimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.lambda_settings - settings["engine_settings"] = prot_settings.vacuum_engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.vacuum_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.vacuum_equil_output_settings - settings["simulation_settings"] = prot_settings.vacuum_simulation_settings - settings["output_settings"] = prot_settings.vacuum_output_settings - - return settings - - -class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): - """ - Protocol Unit for the solvent phase of an absolute solvation free energy - """ - - simtype = "solvent" - - def _get_components(self): - """ - Get the relevant components for a solvent transformation. - - Returns - ------- - alchem_comps : dict[str, Component] - A list of alchemical components - solv_comp : SolventComponent - The SolventComponent of the system - prot_comp : Optional[ProteinComponent] - The protein component of the system, if it exists. - small_mols : dict[SmallMoleculeComponent: OFFMolecule] - SmallMoleculeComponents to add to the system. - """ - stateA = self._inputs["stateA"] - alchem_comps = self._inputs["alchemical_components"] - - solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) - off_comps = {m: m.to_openff() for m in small_mols} - - # We don't need to check that solv_comp is not None, otherwise - # an error will have been raised when calling `validate_solvent` - # in the Protocol's `_create`. - # Similarly we don't need to check prot_comp since that's also - # disallowed on create - return alchem_comps, solv_comp, prot_comp, off_comps - - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Extract the relevant settings for a solvent transformation. - - Returns - ------- - settings : dict[str, SettingsBaseModel] - A dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * charge_settings : OpenFFPartialChargeSettings - * solvation_settings : OpenMMSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : MDOutputSettings - * simulation_settings : MultiStateSimulationSettings - * output_settings: MultiStateOutputSettings - """ - prot_settings = self._inputs["protocol"].settings - - settings = {} - settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings - settings["thermo_settings"] = prot_settings.thermo_settings - settings["charge_settings"] = prot_settings.partial_charge_settings - settings["solvation_settings"] = prot_settings.solvation_settings - settings["alchemical_settings"] = prot_settings.alchemical_settings - settings["lambda_settings"] = prot_settings.lambda_settings - settings["engine_settings"] = prot_settings.solvent_engine_settings - settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings - settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings - settings["simulation_settings"] = prot_settings.solvent_simulation_settings - settings["output_settings"] = prot_settings.solvent_output_settings - - return settings diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py deleted file mode 100644 index e400cc3d..00000000 --- a/openfe/protocols/openmm_rfe/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe - -from . import _rfe_utils -from .equil_rfe_methods import ( - RelativeHybridTopologyProtocol, - RelativeHybridTopologyProtocolResult, - RelativeHybridTopologyProtocolUnit, -) -from .equil_rfe_settings import ( - RelativeHybridTopologyProtocolSettings, -) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py deleted file mode 100644 index 51423763..00000000 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ /dev/null @@ -1,1316 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -"""Equilibrium Relative Free Energy methods using OpenMM and OpenMMTools in a -Perses-like manner. - -This module implements the necessary methodology toolking to run calculate a -ligand relative free energy transformation using OpenMM tools and one of the -following methods: - - Hamiltonian Replica Exchange - - Self-adjusted mixture sampling - - Independent window sampling - -TODO ----- -* Improve this docstring by adding an example use case. - -Acknowledgements ----------------- -This Protocol is based on, and leverages components originating from -the Perses toolkit (https://github.com/choderalab/perses). -""" - -from __future__ import annotations - -import json -import logging -import os -import pathlib -import subprocess -import uuid -import warnings -from collections import defaultdict -from itertools import chain -from typing import Any, Iterable, Optional, Union - -import gufe -import matplotlib.pyplot as plt -import mdtraj -import numpy as np -import numpy.typing as npt -import openmmtools -from gufe import ( - ChemicalSystem, - Component, - ComponentMapping, - LigandAtomMapping, - ProteinComponent, - SmallMoleculeComponent, - SolventComponent, - settings, -) -from openff.toolkit.topology import Molecule as OFFMolecule -from openff.units import Quantity, unit -from openff.units.openmm import ensure_quantity, from_openmm, to_openmm -from openmmtools import multistate -from rdkit import Chem - -from openfe.due import Doi, due -from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, -) - -from ...analysis import plotting -from ...utils import log_system_probe, without_oechem_backend -from ..openmm_utils import ( - charge_generation, - multistate_analysis, - omm_compute, - settings_validation, - system_creation, - system_validation, -) -from . import _rfe_utils -from .equil_rfe_settings import ( - AlchemicalSettings, - IntegratorSettings, - LambdaSettings, - MultiStateOutputSettings, - MultiStateSimulationSettings, - OpenFFPartialChargeSettings, - OpenMMEngineSettings, - OpenMMSolvationSettings, - RelativeHybridTopologyProtocolSettings, -) - -logger = logging.getLogger(__name__) - - -due.cite( - Doi("10.5281/zenodo.1297683"), - description="Perses", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - -due.cite( - Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - -due.cite( - Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True, -) - - -def _get_resname(off_mol) -> str: - # behaviour changed between 0.10 and 0.11 - omm_top = off_mol.to_topology().to_openmm() - names = [r.name for r in omm_top.residues()] - if len(names) > 1: - raise ValueError("We assume single residue") - return names[0] - - -def _get_alchemical_charge_difference( - mapping: LigandAtomMapping, - nonbonded_method: str, - explicit_charge_correction: bool, - solvent_component: SolventComponent, -) -> int: - """ - Checks and returns the difference in formal charge between state A and B. - - Raises - ------ - ValueError - * If an explicit charge correction is attempted and the - nonbonded method is not PME. - * If the absolute charge difference is greater than one - and an explicit charge correction is attempted. - UserWarning - If there is any charge difference. - - Parameters - ---------- - mapping : dict[str, ComponentMapping] - Dictionary of mappings between transforming components. - nonbonded_method : str - The OpenMM nonbonded method used for the simulation. - explicit_charge_correction : bool - Whether or not to use an explicit charge correction. - solvent_component : openfe.SolventComponent - The SolventComponent of the simulation. - - Returns - ------- - int - The formal charge difference between states A and B. - This is defined as sum(charge state A) - sum(charge state B) - """ - - difference = mapping.get_alchemical_charge_difference() - - if abs(difference) > 0: - if explicit_charge_correction: - if nonbonded_method.lower() != "pme": - errmsg = "Explicit charge correction when not using PME is not currently supported." - raise ValueError(errmsg) - if abs(difference) > 1: - errmsg = ( - f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported." - ) - raise ValueError(errmsg) - - ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[ - difference - ] - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - logger.warning(wmsg) - warnings.warn(wmsg) - else: - wmsg = ( - f"A charge difference of {difference} is observed " - "between the end states. No charge correction has " - "been requested, please account for this in your " - "final results." - ) - logger.warning(wmsg) - warnings.warn(wmsg) - - return difference - - -def _validate_alchemical_components( - alchemical_components: dict[str, list[Component]], - mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], -): - """ - Checks that the alchemical components are suitable for the RFE protocol. - - Specifically we check: - 1. That all alchemical components are mapped. - 2. That all alchemical components are SmallMoleculeComponents. - 3. If the mappings involves element changes in core atoms - - Parameters - ---------- - alchemical_components : dict[str, list[Component]] - Dictionary contatining the alchemical components for - states A and B. - mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] - all mappings between transforming components. - - Raises - ------ - ValueError - * If there are more than one mapping or mapping is None - * If there are any unmapped alchemical components. - * If there are any alchemical components that are not - SmallMoleculeComponents. - UserWarning - * Mappings which involve element changes in core atoms - """ - if isinstance(mapping, ComponentMapping): - mapping = [mapping] - # Check mapping - # For now we only allow for a single mapping, this will likely change - if mapping is None or len(mapping) != 1: - errmsg = "A single LigandAtomMapping is expected for this Protocol" - raise ValueError(errmsg) - - # Check that all alchemical components are mapped & small molecules - mapped = { - "stateA": [m.componentA for m in mapping], - "stateB": [m.componentB for m in mapping], - } - - for idx in ["stateA", "stateB"]: - if len(alchemical_components[idx]) != len(mapped[idx]): - errmsg = f"missing alchemical components in {idx}" - raise ValueError(errmsg) - for comp in alchemical_components[idx]: - if comp not in mapped[idx]: - raise ValueError(f"Unmapped alchemical component {comp}") - if not isinstance(comp, SmallMoleculeComponent): # pragma: no-cover - errmsg = ( - "Transformations involving non " - "SmallMoleculeComponent species {comp} " - "are not currently supported" - ) - raise ValueError(errmsg) - - # Validate element changes in mappings - for m in mapping: - molA = m.componentA.to_rdkit() - molB = m.componentB.to_rdkit() - for i, j in m.componentA_to_componentB.items(): - atomA = molA.GetAtomWithIdx(i) - atomB = molB.GetAtomWithIdx(j) - if atomA.GetAtomicNum() != atomB.GetAtomicNum(): - wmsg = ( - f"Element change in mapping between atoms " - f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " - f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" - "No mass scaling is attempted in the hybrid topology, " - "the average mass of the two atoms will be used in the " - "simulation" - ) - logger.warning(wmsg) - warnings.warn(wmsg) # TODO: remove this once logging is fixed - - -class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a RelativeHybridTopologyProtocol""" - - def __init__(self, **data): - super().__init__(**data) - # data is mapping of str(repeat_id): list[protocolunitresults] - # TODO: Detect when we have extensions and stitch these together? - if any(len(pur_list) > 2 for pur_list in self.data.values()): - raise NotImplementedError("Can't stitch together results yet") - - @staticmethod - def compute_mean_estimate(dGs: list[Quantity]) -> Quantity: - u = dGs[0].u - # convert all values to units of the first value, then take average of magnitude - # this would avoid a screwy case where each value was in different units - vals = np.asarray([dG.to(u).m for dG in dGs]) - - return np.average(vals) * u - - def get_estimate(self) -> Quantity: - """Average free energy difference of this transformation - - Returns - ------- - dG : openff.units.Quantity - The free energy difference between the first and last states. This is - a Quantity defined with units. - """ - # TODO: Check this holds up completely for SAMS. - dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] - return self.compute_mean_estimate(dGs) - - @staticmethod - def compute_uncertainty(dGs: list[Quantity]) -> Quantity: - u = dGs[0].u - # convert all values to units of the first value, then take average of magnitude - # this would avoid a screwy case where each value was in different units - vals = np.asarray([dG.to(u).m for dG in dGs]) - - return np.std(vals) * u - - def get_uncertainty(self) -> Quantity: - """The uncertainty/error in the dG value: The std of the estimates of - each independent repeat - """ - - dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] - return self.compute_uncertainty(dGs) - - def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]: - """Return a list of tuples containing the individual free energy - estimates and associated MBAR errors for each repeat. - - Returns - ------- - dGs : list[tuple[openff.units.Quantity]] - n_replicate simulation list of tuples containing the free energy - estimates (first entry) and associated MBAR estimate errors - (second entry). - """ - dGs = [ - (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) - for pus in self.data.values() - ] - return dGs - - def get_forward_and_reverse_energy_analysis( - self, - ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: - """ - Get a list of forward and reverse analysis of the free energies - for each repeat using uncorrelated production samples. - - The returned dicts have keys: - 'fractions' - the fraction of data used for this estimate - 'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate - 'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty - - The 'fractions' values are a numpy array, while the other arrays are - Quantity arrays, with units attached. - - If the list entry is ``None`` instead of a dictionary, this indicates - that the analysis could not be carried out for that repeat. This - is most likely caused by MBAR convergence issues when attempting to - calculate free energies from too few samples. - - - Returns - ------- - forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]] - - - Raises - ------ - UserWarning - If any of the forward and reverse entries are ``None``. - """ - forward_reverse = [ - pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values() - ] - - if None in forward_reverse: - wmsg = ( - "One or more ``None`` entries were found in the list of " - "forward and reverse analyses. This is likely caused by " - "an MBAR convergence failure caused by too few independent " - "samples when calculating the free energies of the 10% " - "timeseries slice." - ) - warnings.warn(wmsg) - - return forward_reverse - - def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]: - """ - Return a list of dictionary containing the MBAR overlap estimates - calculated for each repeat. - - Returns - ------- - overlap_stats : list[dict[str, npt.NDArray]] - A list of dictionaries containing the following keys: - * ``scalar``: One minus the largest nontrivial eigenvalue - * ``eigenvalues``: The sorted (descending) eigenvalues of the - overlap matrix - * ``matrix``: Estimated overlap matrix of observing a sample from - state i in state j - """ - # Loop through and get the repeats and get the matrices - overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] - - return overlap_stats - - def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]: - """The replica lambda state transition statistics for each repeat. - - Note - ---- - This is currently only available in cases where a replica exchange - simulation was run. - - Returns - ------- - repex_stats : list[dict[str, npt.NDArray]] - A list of dictionaries containing the following: - * ``eigenvalues``: The sorted (descending) eigenvalues of the - lambda state transition matrix - * ``matrix``: The transition matrix estimate of a replica switching - from state i to state j. - """ - try: - repex_stats = [ - pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values() - ] - except KeyError: - errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" - raise ValueError(errmsg) - - return repex_stats - - def get_replica_states(self) -> list[npt.NDArray]: - """ - Returns the timeseries of replica states for each repeat. - - Returns - ------- - replica_states : List[npt.NDArray] - List of replica states for each repeat - """ - - def is_file(filename: str): - p = pathlib.Path(filename) - if not p.exists(): - errmsg = f"File could not be found {p}" - raise ValueError(errmsg) - return p - - replica_states = [] - - for pus in self.data.values(): - nc = is_file(pus[0].outputs["nc"]) - dir_path = nc.parents[0] - chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name - reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode="r" - ) - replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) - reporter.close() - - return replica_states - - def equilibration_iterations(self) -> list[float]: - """ - Returns the number of equilibration iterations for each repeat - of the calculation. - - Returns - ------- - equilibration_lengths : list[float] - """ - equilibration_lengths = [ - pus[0].outputs["equilibration_iterations"] for pus in self.data.values() - ] - - return equilibration_lengths - - def production_iterations(self) -> list[float]: - """ - Returns the number of uncorrelated production samples for each - repeat of the calculation. - - Returns - ------- - production_lengths : list[float] - """ - production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] - - return production_lengths - - -class RelativeHybridTopologyProtocol(gufe.Protocol): - """ - Relative Free Energy calculations using OpenMM and OpenMMTools. - - Based on `Perses `_ - - See Also - -------- - :mod:`openfe.protocols` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologySettings` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyResult` - :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` - """ - - result_cls = RelativeHybridTopologyProtocolResult - _settings_cls = RelativeHybridTopologyProtocolSettings - _settings: RelativeHybridTopologyProtocolSettings - - @classmethod - def _default_settings(cls): - """A dictionary of initial settings for this creating this Protocol - - These settings are intended as a suitable starting point for creating - an instance of this protocol. It is recommended, however that care is - taken to inspect and customize these before performing a Protocol. - - Returns - ------- - Settings - a set of default settings - """ - return RelativeHybridTopologyProtocolSettings( - protocol_repeats=3, - forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), - thermo_settings=settings.ThermoSettings( - temperature=298.15 * unit.kelvin, - pressure=1 * unit.bar, - ), - partial_charge_settings=OpenFFPartialChargeSettings(), - solvation_settings=OpenMMSolvationSettings(), - alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), - lambda_settings=LambdaSettings(), - simulation_settings=MultiStateSimulationSettings( - equilibration_length=1.0 * unit.nanosecond, - production_length=5.0 * unit.nanosecond, - ), - engine_settings=OpenMMEngineSettings(), - integrator_settings=IntegratorSettings(), - output_settings=MultiStateOutputSettings(), - ) - - @classmethod - def _adaptive_settings( - cls, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping], - initial_settings: None | RelativeHybridTopologyProtocolSettings = None, - ) -> RelativeHybridTopologyProtocolSettings: - """ - Get the recommended OpenFE settings for this protocol based on the input states involved in the - transformation. - - These are intended as a suitable starting point for creating an instance of this protocol, which can be further - customized before performing a Protocol. - - Parameters - ---------- - stateA : ChemicalSystem - The initial state of the transformation. - stateB : ChemicalSystem - The final state of the transformation. - mapping : LigandAtomMapping | list[LigandAtomMapping] - The mapping(s) between transforming components in stateA and stateB. - initial_settings : None | RelativeHybridTopologyProtocolSettings, optional - Initial settings to base the adaptive settings on. If None, default settings are used. - - Returns - ------- - RelativeHybridTopologyProtocolSettings - The recommended settings for this protocol based on the input states. - - Notes - ----- - - If the transformation involves a change in net charge, the settings are adapted to use a more expensive - protocol with 22 lambda windows and 20 ns production length per window. - - If both states contain a ProteinComponent, the solvation padding is set to 1 nm. - - If initial_settings is provided, the adaptive settings are based on a copy of these settings. - """ - # use initial settings or default settings - # this is needed for the CLI so we don't override user settings - if initial_settings is not None: - protocol_settings = initial_settings.copy(deep=True) - else: - protocol_settings = cls.default_settings() - - if isinstance(mapping, list): - mapping = mapping[0] - - if mapping.get_alchemical_charge_difference() != 0: - # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated - # - info = ( - "Charge changing transformation between ligands " - f"{mapping.componentA.name} and {mapping.componentB.name}. " - "A more expensive protocol with 22 lambda windows, sampled " - "for 20 ns each, will be used here." - ) - logger.info(info) - protocol_settings.alchemical_settings.explicit_charge_correction = True - protocol_settings.simulation_settings.production_length = 20 * unit.nanosecond - protocol_settings.simulation_settings.n_replicas = 22 - protocol_settings.lambda_settings.lambda_windows = 22 - - # adapt the solvation padding based on the system components - if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): - protocol_settings.solvation_settings.solvent_padding = 1 * unit.nanometer - - return protocol_settings - - def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], - extends: Optional[gufe.ProtocolDAGResult] = None, - ) -> list[gufe.ProtocolUnit]: - # TODO: Extensions? - if extends: - raise NotImplementedError("Can't extend simulations yet") - - # Get alchemical components & validate them + mapping - alchem_comps = system_validation.get_alchemical_components(stateA, stateB) - _validate_alchemical_components(alchem_comps, mapping) - ligandmapping = mapping[0] if isinstance(mapping, list) else mapping - - # Validate solvent component - nonbond = self.settings.forcefield_settings.nonbonded_method - system_validation.validate_solvent(stateA, nonbond) - - # Validate solvation settings - settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) - - # Validate protein component - system_validation.validate_protein(stateA) - - # actually create and return Units - Anames = ",".join(c.name for c in alchem_comps["stateA"]) - Bnames = ",".join(c.name for c in alchem_comps["stateB"]) - # our DAG has no dependencies, so just list units - n_repeats = self.settings.protocol_repeats - units = [ - RelativeHybridTopologyProtocolUnit( - protocol=self, - stateA=stateA, - stateB=stateB, - ligandmapping=ligandmapping, - generation=0, - repeat_id=int(uuid.uuid4()), - name=f"{Anames} to {Bnames} repeat {i} generation 0", - ) - for i in range(n_repeats) - ] - - return units - - def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: - # result units will have a repeat_id and generations within this repeat_id - # first group according to repeat_id - unsorted_repeats = defaultdict(list) - for d in protocol_dag_results: - pu: gufe.ProtocolUnitResult - for pu in d.protocol_unit_results: - if not pu.ok(): - continue - - unsorted_repeats[pu.outputs["repeat_id"]].append(pu) - - # then sort by generation within each repeat_id list - repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} - for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) - - # returns a dict of repeat_id: sorted list of ProtocolUnitResult - return repeats - - -class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): - """ - Calculates the relative free energy of an alchemical ligand transformation. - """ - - def __init__( - self, - *, - protocol: RelativeHybridTopologyProtocol, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - ligandmapping: LigandAtomMapping, - generation: int, - repeat_id: int, - name: Optional[str] = None, - ): - """ - Parameters - ---------- - protocol : RelativeHybridTopologyProtocol - protocol used to create this Unit. Contains key information such - as the settings. - stateA, stateB : ChemicalSystem - the two ligand SmallMoleculeComponents to transform between. The - transformation will go from ligandA to ligandB. - ligandmapping : LigandAtomMapping - the mapping of atoms between the two ligand components - repeat_id : int - identifier for which repeat (aka replica/clone) this Unit is - generation : int - counter for how many times this repeat has been extended - name : str, optional - human-readable identifier for this Unit - - Notes - ----- - The mapping used must not involve any elemental changes. A check for - this is done on class creation. - """ - super().__init__( - name=name, - protocol=protocol, - stateA=stateA, - stateB=stateB, - ligandmapping=ligandmapping, - repeat_id=repeat_id, - generation=generation, - ) - - @staticmethod - def _assign_partial_charges( - charge_settings: OpenFFPartialChargeSettings, - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]], - ) -> None: - """ - Assign partial charges to SMCs. - - Parameters - ---------- - charge_settings : OpenFFPartialChargeSettings - Settings for controlling how the partial charges are assigned. - off_small_mols : dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - Dictionary of dictionary of OpenFF Molecules to add, keyed by - state and SmallMoleculeComponent. - """ - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - charge_generation.assign_offmol_partial_charges( - offmol=mol, - overwrite=False, - method=charge_settings.partial_charge_method, - toolkit_backend=charge_settings.off_toolkit_backend, - generate_n_conformers=charge_settings.number_of_conformers, - nagl_model=charge_settings.nagl_model, - ) - - def run( - self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None - ) -> dict[str, Any]: - """Run the relative free energy calculation. - - Parameters - ---------- - dry : bool - Do a dry run of the calculation, creating all necessary hybrid - system components (topology, system, sampler, etc...) but without - running the simulation. - verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging. - scratch_basepath: Pathlike, optional - Where to store temporary files, defaults to current working directory - shared_basepath : Pathlike, optional - Where to run the calculation, defaults to current working directory - - Returns - ------- - dict - Outputs created in the basepath directory or the debug objects - (i.e. sampler) if ``dry==True``. - - Raises - ------ - error - Exception if anything failed - """ - if verbose: - self.logger.info("Preparing the hybrid topology simulation") - if scratch_basepath is None: - scratch_basepath = pathlib.Path(".") - if shared_basepath is None: - # use cwd - shared_basepath = pathlib.Path(".") - - # 0. General setup and settings dependency resolution step - - # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[ - "protocol" - ].settings - stateA = self._inputs["stateA"] - stateB = self._inputs["stateB"] - mapping = self._inputs["ligandmapping"] - - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( - protocol_settings.forcefield_settings - ) - thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings - alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings - lambda_settings: LambdaSettings = protocol_settings.lambda_settings - charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings - solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings - sampler_settings: MultiStateSimulationSettings = protocol_settings.simulation_settings - output_settings: MultiStateOutputSettings = protocol_settings.output_settings - integrator_settings: IntegratorSettings = protocol_settings.integrator_settings - - # is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, integrator_settings.timestep - ) - # TODO: Also validate various conversions? - # Convert various time based inputs to steps/iterations - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings=sampler_settings, - integrator_settings=integrator_settings, - ) - - equil_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.equilibration_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - prod_steps = settings_validation.get_simsteps( - sim_length=sampler_settings.production_length, - timestep=integrator_settings.timestep, - mc_steps=steps_per_iteration, - ) - - solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) - - # Get the change difference between the end states - # and check if the charge correction used is appropriate - charge_difference = _get_alchemical_charge_difference( - mapping, - forcefield_settings.nonbonded_method, - alchem_settings.explicit_charge_correction, - solvent_comp, - ) - - # 1. Create stateA system - self.logger.info("Parameterizing molecules") - - # a. create offmol dictionaries and assign partial charges - # workaround for conformer generation failures - # see openfe issue #576 - # calculate partial charges manually if not already given - # convert to OpenFF here, - # and keep the molecule around to maintain the partial charges - off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] - off_small_mols = { - "stateA": [(mapping.componentA, mapping.componentA.to_openff())], - "stateB": [(mapping.componentB, mapping.componentB.to_openff())], - "both": [ - (m, m.to_openff()) - for m in small_mols - if (m != mapping.componentA and m != mapping.componentB) - ], - } - - self._assign_partial_charges(charge_settings, off_small_mols) - - # b. get a system generator - if output_settings.forcefield_cache is not None: - ffcache = shared_basepath / output_settings.forcefield_cache - else: - ffcache = None - - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system_generator = system_creation.get_system_generator( - forcefield_settings=forcefield_settings, - integrator_settings=integrator_settings, - thermo_settings=thermo_settings, - cache=ffcache, - has_solvent=solvent_comp is not None, - ) - - # c. force the creation of parameters - # This is necessary because we need to have the FF templates - # registered ahead of solvating the system. - for smc, mol in chain( - off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] - ): - system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) - - # c. get OpenMM Modeller + a dictionary of resids for each component - stateA_modeller, comp_resids = system_creation.get_omm_modeller( - protein_comp=protein_comp, - solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), - omm_forcefield=system_generator.forcefield, - solvent_settings=solvation_settings, - ) - - # d. get topology & positions - # Note: roundtrip positions to remove vec3 issues - stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) - - # e. create the stateA System - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateA_system = system_generator.create_system( - stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], - ) - - # 2. Get stateB system - # a. get the topology - stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( - stateA_topology, - # zeroth item (there's only one) then get the OFF representation - off_small_mols["stateB"][0][1].to_topology().to_openmm(), - exclude_resids=comp_resids[mapping.componentA], - ) - - # b. get a list of small molecules for stateB - # Block out oechem backend in system_generator calls to avoid - # any issues with smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - stateB_system = system_generator.create_system( - stateB_topology, - molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], - ) - - # c. Define correspondence mappings between the two systems - ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( - mapping.componentA_to_componentB, - stateA_system, - stateA_topology, - comp_resids[mapping.componentA], - stateB_system, - stateB_topology, - stateB_alchem_resids, - # These are non-optional settings for this method - fix_constraints=True, - ) - - # d. if a charge correction is necessary, select alchemical waters - # and transform them - if alchem_settings.explicit_charge_correction: - alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, - stateA_positions, - charge_difference, - alchem_settings.explicit_charge_correction_cutoff, - ) - _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, - stateB_topology, - stateB_system, - ligand_mappings, - charge_difference, - solvent_comp, - ) - - # e. Finally get the positions - stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, - stateA_topology, - stateB_topology, - old_positions=ensure_quantity(stateA_positions, "openmm"), - insert_positions=ensure_quantity( - off_small_mols["stateB"][0][1].conformers[0], "openmm" - ), - ) - - # 3. Create the hybrid topology - # a. Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == "gapsys": - softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == "beutler": - softcore_LJ_v2 = False - # b. Get hybrid topology factory - hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( - stateA_system, - stateA_positions, - stateA_topology, - stateB_system, - stateB_positions, - stateB_topology, - old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], - old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], - use_dispersion_correction=alchem_settings.use_dispersion_correction, - softcore_alpha=alchem_settings.softcore_alpha, - softcore_LJ_v2=softcore_LJ_v2, - softcore_LJ_v2_alpha=alchem_settings.softcore_alpha, - interpolate_old_and_new_14s=alchem_settings.turn_off_core_unique_exceptions, - ) - - # 4. Create lambda schedule - # TODO - this should be exposed to users, maybe we should offer the - # ability to print the schedule directly in settings? - # fmt: off - lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( - functions=lambda_settings.lambda_functions, - windows=lambda_settings.lambda_windows - ) - # fmt: on - # PR #125 temporarily pin lambda schedule spacing to n_replicas - n_replicas = sampler_settings.n_replicas - if n_replicas != len(lambdas.lambda_schedule): - errmsg = ( - f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}" - ) - raise ValueError(errmsg) - - # 9. Create the multistate reporter - # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) - - # a. Create the multistate reporter - # convert checkpoint_interval from time to iterations - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=sampler_settings.time_per_iteration, - ) - - nc = shared_basepath / output_settings.output_filename - chk = output_settings.checkpoint_storage_filename - - if output_settings.positions_write_frequency is not None: - pos_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.positions_write_frequency, - denominator=sampler_settings.time_per_iteration, - numerator_name="output settings' position_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - else: - pos_interval = 0 - - if output_settings.velocities_write_frequency is not None: - vel_interval = settings_validation.divmod_time_and_check( - numerator=output_settings.velocities_write_frequency, - denominator=sampler_settings.time_per_iteration, - numerator_name="output settings' velocity_write_frequency", - denominator_name="sampler settings' time_per_iteration", - ) - else: - vel_interval = 0 - - reporter = multistate.MultiStateReporter( - storage=nc, - analysis_particle_indices=selection_indices, - checkpoint_interval=chk_intervals, - checkpoint_storage=chk, - position_interval=pos_interval, - velocity_interval=vel_interval, - ) - - # b. Write out a PDB containing the subsampled hybrid state - # fmt: off - bfactors = np.zeros_like(selection_indices, dtype=float) # solvent - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core - bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B - # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor - if len(selection_indices) > 0: - traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), - ).save_pdb( - shared_basepath / output_settings.output_structure, - bfactors=bfactors, - ) - # fmt: on - - # 10. Get compute platform - # restrict to a single CPU if running vacuum - restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" - platform = omm_compute.get_openmm_platform( - platform_name=protocol_settings.engine_settings.compute_platform, - gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu, - ) - - # 11. Set the integrator - # a. Validate integrator settings for current system - # Virtual sites sanity check - ensure we restart velocities when - # there are virtual sites in the system - if hybrid_factory.has_virtual_sites: - if not integrator_settings.reassign_velocities: - errmsg = ( - "Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools" - ) - raise ValueError(errmsg) - - # b. create langevin integrator - integrator = openmmtools.mcmc.LangevinDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.langevin_collision_rate), - n_steps=steps_per_iteration, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - ) - - # 12. Create sampler - self.logger.info("Creating and setting up the sampler") - rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( - simulation_settings=sampler_settings, - ) - # convert early_termination_target_error from kcal/mol to kT - early_termination_target_error = ( - settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermo_settings.temperature, - sampler_settings.early_termination_target_error, - ) - ) - - if sampler_settings.sampler_method.lower() == "repex": - sampler = _rfe_utils.multistate.HybridRepexSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=rta_its, - online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, - ) - elif sampler_settings.sampler_method.lower() == "sams": - sampler = _rfe_utils.multistate.HybridSAMSSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=rta_its, - online_analysis_minimum_iterations=rta_min_its, - flatness_criteria=sampler_settings.sams_flatness_criteria, - gamma0=sampler_settings.sams_gamma0, - ) - elif sampler_settings.sampler_method.lower() == "independent": - sampler = _rfe_utils.multistate.HybridMultiStateSampler( - mcmc_moves=integrator, - hybrid_factory=hybrid_factory, - online_analysis_interval=rta_its, - online_analysis_target_error=early_termination_target_error, - online_analysis_minimum_iterations=rta_min_its, - ) - - else: - raise AttributeError(f"Unknown sampler {sampler_settings.sampler_method}") - - sampler.setup( - n_replicas=sampler_settings.n_replicas, - reporter=reporter, - lambda_protocol=lambdas, - temperature=to_openmm(thermo_settings.temperature), - endstates=alchem_settings.endstate_dispersion_correction, - minimization_platform=platform.getName(), - # Set minimization steps to None when running in dry mode - # otherwise do a very small one to avoid NaNs - minimization_steps=100 if not dry else None, - ) - - try: - # Create context caches (energy + sampler) - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache - - if not dry: # pragma: no-cover - # minimize - if verbose: - self.logger.info("Running minimization") - - sampler.minimize(max_iterations=sampler_settings.minimization_steps) - - # equilibrate - if verbose: - self.logger.info("Running equilibration phase") - - sampler.equilibrate(int(equil_steps / steps_per_iteration)) - - # production - if verbose: - self.logger.info("Running production phase") - - sampler.extend(int(prod_steps / steps_per_iteration)) - - self.logger.info("Production phase complete") - - self.logger.info("Post-simulation analysis of results") - # calculate relevant analyses of the free energies & sampling - # First close & reload the reporter to avoid netcdf clashes - analyzer = multistate_analysis.MultistateEquilFEAnalysis( - reporter, - sampling_method=sampler_settings.sampler_method.lower(), - result_units=unit.kilocalorie_per_mole, - ) - analyzer.plot(filepath=shared_basepath, filename_prefix="") - analyzer.close() - - else: - # clean up the reporter file - fns = [ - shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename, - ] - for fn in fns: - os.remove(fn) - finally: - # close reporter when you're done, prevent - # file handle clashes - reporter.close() - - # clear GPU contexts - # TODO: use cache.empty() calls when openmmtools #690 is resolved - # replace with above - for context in list(energy_context_cache._lru._data.keys()): - del energy_context_cache._lru._data[context] - for context in list(sampler_context_cache._lru._data.keys()): - del sampler_context_cache._lru._data[context] - # cautiously clear out the global context cache too - for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): - del openmmtools.cache.global_context_cache._lru._data[context] - - del sampler_context_cache, energy_context_cache - - if not dry: - del integrator, sampler - - if not dry: # pragma: no-cover - return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} - else: - return {"debug": {"sampler": sampler}} - - @staticmethod - def structural_analysis(scratch, shared) -> dict: - # don't put energy analysis in here, it uses the open file reporter - # whereas structural stuff requires that the file handle is closed - # TODO: we should just make openfe_analysis write an npz instead! - analysis_out = scratch / "structural_analysis.json" - - ret = subprocess.run( - [ - "openfe_analysis", # CLI entry point - "RFE_analysis", # CLI option - str(shared), # Where the simulation.nc fille - str(analysis_out), # Where the analysis json file is written - ], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if ret.returncode: - return {"structural_analysis_error": ret.stderr} - - with open(analysis_out, "rb") as f: - data = json.load(f) - - savedir = pathlib.Path(shared) - if d := data["protein_2D_RMSD"]: - fig = plotting.plot_2D_rmsd(d) - fig.savefig(savedir / "protein_2D_RMSD.png") - plt.close(fig) - f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) - f2.savefig(savedir / "ligand_COM_drift.png") - plt.close(f2) - - f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) - f3.savefig(savedir / "ligand_RMSD.png") - plt.close(f3) - - # Save to numpy compressed format (~ 6x more space efficient than JSON) - np.savez_compressed( - shared / "structural_analysis.npz", - protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32), - ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32), - ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32), - protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), - time_ps=np.asarray(data["time(ps)"], dtype=np.float32), - ) - - return {"structural_analysis": shared / "structural_analysis.npz"} - - def _execute( - self, - ctx: gufe.Context, - **kwargs, - ) -> dict[str, Any]: - log_system_probe(logging.INFO, paths=[ctx.scratch]) - - outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) - - structural_analysis_outputs = self.structural_analysis(ctx.scratch, ctx.shared) - - return { - "repeat_id": self._inputs["repeat_id"], - "generation": self._inputs["generation"], - **outputs, - **structural_analysis_outputs, - } diff --git a/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz b/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz deleted file mode 100644 index 8d4fe9d1..00000000 Binary files a/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz and /dev/null differ diff --git a/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz b/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz deleted file mode 100644 index 97ba8137..00000000 Binary files a/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz and /dev/null differ diff --git a/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz b/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz deleted file mode 100644 index e0bd6e83..00000000 Binary files a/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz and /dev/null differ diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py b/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py deleted file mode 100644 index 4bfdb155..00000000 --- a/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py +++ /dev/null @@ -1,122 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -import gzip - -import pytest -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin - -import openfe -from openfe.protocols.openmm_afe import ( - AbsoluteBindingComplexUnit, - AbsoluteBindingProtocol, - AbsoluteBindingProtocolResult, - AbsoluteBindingSolventUnit, -) - - -@pytest.fixture -def protocol(): - return AbsoluteBindingProtocol(AbsoluteBindingProtocol.default_settings()) - - -@pytest.fixture -def protocol_units(protocol, benzene_complex_system, T4_protein_component): - stateA = benzene_complex_system - stateB = openfe.ChemicalSystem( - {"protein": T4_protein_component, "solvent": openfe.SolventComponent()} - ) - pus = protocol.create( - stateA=stateA, - stateB=stateB, - mapping=None, - ) - return list(pus.protocol_units) - - -@pytest.fixture -def solvent_protocol_unit(protocol_units): - for pu in protocol_units: - if isinstance(pu, AbsoluteBindingSolventUnit): - return pu - - -@pytest.fixture -def complex_protocol_unit(protocol_units): - for pu in protocol_units: - if isinstance(pu, AbsoluteBindingComplexUnit): - return pu - - -@pytest.fixture -def protocol_result(abfe_transformation_json_path): - with gzip.open(abfe_transformation_json_path) as f: - pr = AbsoluteBindingProtocolResult.from_json(f) - return pr - - -class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin): - cls = AbsoluteBindingProtocol - key = None - repr = "AbsoluteBindingProtocol-" - - @pytest.fixture() - def instance(self, protocol): - return protocol - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteBindingSolventUnit(GufeTokenizableTestsMixin): - cls = AbsoluteBindingSolventUnit - repr = "AbsoluteBindingSolventUnit(Absolute Binding, benzene solvent leg" - key = None - - @pytest.fixture() - def instance(self, solvent_protocol_unit): - return solvent_protocol_unit - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteBindingComplexUnit(GufeTokenizableTestsMixin): - cls = AbsoluteBindingComplexUnit - repr = "AbsoluteBindingComplexUnit(Absolute Binding, benzene complex leg" - key = None - - @pytest.fixture() - def instance(self, complex_protocol_unit): - return complex_protocol_unit - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin): - cls = AbsoluteBindingProtocolResult - key = None - repr = "AbsoluteBindingProtocolResult-" - - @pytest.fixture() - def instance(self, protocol_result): - return protocol_result - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) diff --git a/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py b/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py deleted file mode 100644 index b444d865..00000000 --- a/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py +++ /dev/null @@ -1,116 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe -import json - -import gufe -import pytest -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin - -import openfe -from openfe.protocols import openmm_afe - - -@pytest.fixture -def protocol(): - return openmm_afe.AbsoluteSolvationProtocol( - openmm_afe.AbsoluteSolvationProtocol.default_settings() - ) - - -@pytest.fixture -def protocol_units(protocol, benzene_system): - pus = protocol.create( - stateA=benzene_system, - stateB=openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}), - mapping=None, - ) - return list(pus.protocol_units) - - -@pytest.fixture -def solvent_protocol_unit(protocol_units): - for pu in protocol_units: - if isinstance(pu, openmm_afe.AbsoluteSolvationSolventUnit): - return pu - - -@pytest.fixture -def vacuum_protocol_unit(protocol_units): - for pu in protocol_units: - if isinstance(pu, openmm_afe.AbsoluteSolvationVacuumUnit): - return pu - - -@pytest.fixture -def protocol_result(afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) - return pr - - -class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin): - cls = openmm_afe.AbsoluteSolvationProtocol - key = None - repr = "AbsoluteSolvationProtocol-" - - @pytest.fixture() - def instance(self, protocol): - return protocol - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin): - cls = openmm_afe.AbsoluteSolvationSolventUnit - repr = "AbsoluteSolvationSolventUnit(Absolute Solvation, benzene solvent leg" - key = None - - @pytest.fixture() - def instance(self, solvent_protocol_unit): - return solvent_protocol_unit - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteSolvationVacuumUnit(GufeTokenizableTestsMixin): - cls = openmm_afe.AbsoluteSolvationVacuumUnit - repr = "AbsoluteSolvationVacuumUnit(Absolute Solvation, benzene vacuum leg" - key = None - - @pytest.fixture() - def instance(self, vacuum_protocol_unit): - return vacuum_protocol_unit - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) - - -class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin): - cls = openmm_afe.AbsoluteSolvationProtocolResult - key = None - repr = "AbsoluteSolvationProtocolResult-" - - @pytest.fixture() - def instance(self, protocol_result): - return protocol_result - - def test_repr(self, instance): - """ - Overwrites the base `test_repr` call. - """ - assert isinstance(repr(instance), str) - assert self.repr in repr(instance) diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py b/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py deleted file mode 100644 index 322c0f95..00000000 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py +++ /dev/null @@ -1,105 +0,0 @@ -# This code is part of OpenFE and is licensed under the MIT license. -# For details, see https://github.com/OpenFreeEnergy/openfe - -import pytest -from gufe.tests.test_tokenization import GufeTokenizableTestsMixin -from openff.units import unit - -from openfe.protocols import openmm_rfe - -""" -todo: -- RelativeHybridTopologyProtocolResult -- RelativeHybridTopologyProtocol -- RelativeHybridTopologyProtocolUnit -""" - - -@pytest.fixture -def rfe_protocol(): - return openmm_rfe.RelativeHybridTopologyProtocol( - openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - ) - - -@pytest.fixture -def rfe_protocol_other_units(): - """Identical to rfe_protocol, but with `kcal / mol` as input unit instead of `kilocalorie_per_mole`.""" - new_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - new_settings.simulation_settings.early_termination_target_error = 0.0 * unit.kilocalorie/unit.mol # fmt: skip - return openmm_rfe.RelativeHybridTopologyProtocol(new_settings) - - -@pytest.fixture -def protocol_unit(rfe_protocol, benzene_system, toluene_system, benzene_to_toluene_mapping): - pus = rfe_protocol.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=[benzene_to_toluene_mapping], - ) - return list(pus.protocol_units)[0] - - -@pytest.mark.skip -class TestRelativeHybridTopologyProtocolResult(GufeTokenizableTestsMixin): - cls = openmm_rfe.RelativeHybridTopologyProtocolResult - repr = "" - key = "" - - @pytest.fixture() - def instance(self): - pass - - -class TestRelativeHybridTopologyProtocolOtherUnits(GufeTokenizableTestsMixin): - cls = openmm_rfe.RelativeHybridTopologyProtocol - key = None - repr = " None: + """Helper function for pulling all test data up-front. + + Parameters + ---------- + path : str + path to store the data - usually a pooch.os_cache instance. + + """ + downloader = pooch.DOIDownloader(progressbar=True) + + def _infer_processor(fname: str): + if fname.endswith("tar.gz"): + return pooch.Untar() + elif fname.endswith("zip"): + return pooch.Unzip() + else: + return None + + for d in zenodo_registry: + pooch.retrieve( + url=d["base_url"] + d["fname"], + known_hash=d["known_hash"], + fname=d["fname"], + processor=_infer_processor(d["fname"]), + downloader=downloader, + path=path, + ) diff --git a/src/openfe/data/_registry.py b/src/openfe/data/_registry.py new file mode 100644 index 00000000..7a87814d --- /dev/null +++ b/src/openfe/data/_registry.py @@ -0,0 +1,24 @@ +import pooch + +POOCH_CACHE = pooch.os_cache("openfe") + +zenodo_rfe_simulation_nc = dict( + base_url="doi:10.5281/zenodo.15375081/", + fname="simulation.nc", + known_hash="md5:bc4e842b47de17704d804ae345b91599", +) +zenodo_t4_lysozyme_traj = dict( + base_url="doi:10.5281/zenodo.15212342", + fname="t4_lysozyme_trajectory.zip", + known_hash="sha256:e985d055db25b5468491e169948f641833a5fbb67a23dbb0a00b57fb7c0e59c8", +) +zenodo_industry_benchmark_systems = dict( + base_url="doi:10.5281/zenodo.15212342", + fname="industry_benchmark_systems.zip", + known_hash="sha256:2bb5eee36e29b718b96bf6e9350e0b9957a592f6c289f77330cbb6f4311a07bd", +) +zenodo_data_registry = [ + zenodo_rfe_simulation_nc, + zenodo_t4_lysozyme_traj, + zenodo_industry_benchmark_systems, +] diff --git a/openfe/due.py b/src/openfe/due.py similarity index 100% rename from openfe/due.py rename to src/openfe/due.py diff --git a/openfe/protocols/restraint_utils/__init__.py b/src/openfe/orchestration/__init__.py similarity index 100% rename from openfe/protocols/restraint_utils/__init__.py rename to src/openfe/orchestration/__init__.py diff --git a/openfe/protocols/__init__.py b/src/openfe/protocols/__init__.py similarity index 100% rename from openfe/protocols/__init__.py rename to src/openfe/protocols/__init__.py diff --git a/src/openfe/protocols/openmm_afe/__init__.py b/src/openfe/protocols/openmm_afe/__init__.py new file mode 100644 index 00000000..6995d8ed --- /dev/null +++ b/src/openfe/protocols/openmm_afe/__init__.py @@ -0,0 +1,56 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Run absolute free energy calculations using OpenMM and OpenMMTools. + +""" + +from .abfe_units import ( + ABFEComplexAnalysisUnit, + ABFEComplexSetupUnit, + ABFEComplexSimUnit, + ABFESolventAnalysisUnit, + ABFESolventSetupUnit, + ABFESolventSimUnit, +) +from .afe_protocol_results import ( + AbsoluteBindingProtocolResult, + AbsoluteSolvationProtocolResult, +) +from .ahfe_units import ( + AHFESolventAnalysisUnit, + AHFESolventSetupUnit, + AHFESolventSimUnit, + AHFEVacuumAnalysisUnit, + AHFEVacuumSetupUnit, + AHFEVacuumSimUnit, +) +from .equil_binding_afe_method import ( + AbsoluteBindingProtocol, + AbsoluteBindingSettings, +) +from .equil_solvation_afe_method import ( + AbsoluteSolvationProtocol, + AbsoluteSolvationSettings, +) + +__all__ = [ + "AbsoluteSolvationProtocol", + "AbsoluteSolvationSettings", + "AbsoluteSolvationProtocolResult", + "AHFESolventSetupUnit", + "AHFESolventSimUnit", + "AHFESolventAnalysisUnit", + "AHFEVacuumSetupUnit", + "AHFEVacuumSimUnit", + "AHFEVacuumAnalysisUnit", + "AbsoluteBindingProtocol", + "AbsoluteBindingSettings", + "AbsoluteBindingProtocolResult", + "ABFEComplexSetupUnit", + "ABFEComplexSimUnit", + "ABFEComplexAnalysisUnit", + "ABFESolventSetupUnit", + "ABFESolventSimUnit", + "ABFESolventAnalysisUnit", +] diff --git a/src/openfe/protocols/openmm_afe/abfe_units.py b/src/openfe/protocols/openmm_afe/abfe_units.py new file mode 100644 index 00000000..f35fa808 --- /dev/null +++ b/src/openfe/protocols/openmm_afe/abfe_units.py @@ -0,0 +1,513 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""ABFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.abfe_units` +======================================================================== +This module defines the ProtocolUnits for the +:class:`AbsoluteBindingProtocol`. +""" + +import logging +import pathlib + +import MDAnalysis as mda +import numpy as np +import numpy.typing as npt +from gufe import ( + SolventComponent, +) +from gufe.components import Component +from openff.units import Quantity +from openff.units.openmm import to_openmm +from openmm import System +from openmm import unit as ommunit +from openmm.app import Topology as omm_topology +from openmmtools.states import ThermodynamicState +from rdkit import Chem + +from openfe.protocols.openmm_afe.equil_afe_settings import ( + BoreschRestraintSettings, + SettingsBaseModel, +) +from openfe.protocols.openmm_utils import system_validation +from openfe.protocols.restraint_utils import geometry +from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry +from openfe.protocols.restraint_utils.openmm import omm_restraints +from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint + +from .base_afe_units import ( + BaseAbsoluteMultiStateAnalysisUnit, + BaseAbsoluteMultiStateSimulationUnit, + BaseAbsoluteSetupUnit, +) + +logger = logging.getLogger(__name__) + + +class ComplexComponentsMixin: + def _get_components(self): + """ + Get the relevant components for a complex transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A dict of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : ProteinComponent | None + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp + return alchem_comps, solv_comp, prot_comp, off_comps + + +class ComplexSettingsMixin: + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a complex transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : ABFEPreEquilOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + * restraint_settings: BaseRestraintSettings + """ + prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined] + + settings = {} + settings["forcefield_settings"] = prot_settings.forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.complex_solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.complex_lambda_settings + settings["engine_settings"] = prot_settings.engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.complex_equil_output_settings + settings["simulation_settings"] = prot_settings.complex_simulation_settings + settings["output_settings"] = prot_settings.complex_output_settings + settings["restraint_settings"] = prot_settings.restraint_settings + + return settings + + +class ABFEComplexSetupUnit(ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteSetupUnit): + """ + Setup unit for the complex phase of absolute binding free energy + transformations. + """ + + simtype = "complex" + + @staticmethod + def _get_mda_universe( + topology: omm_topology, + positions: ommunit.Quantity | None, + trajectory: pathlib.Path | None, + ) -> mda.Universe: + """ + Helper method to get a Universe from an openmm Topology, + and either an input trajectory or a set of positions. + + Parameters + ---------- + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + positions: openmm.unit.Quantity | None + The System's current positions. + Used if a trajectory file is None or is not a file. + trajectory: pathlib.Path | None + A Path to a trajectory file to read positions from. + + Returns + ------- + mda.Universe + An MDAnalysis Universe of the System. + """ + from MDAnalysis.coordinates.memory import MemoryReader + + # If the trajectory file doesn't exist, then we use positions + if trajectory is not None and trajectory.is_file(): + return mda.Universe( + topology, + trajectory, + topology_format="OPENMMTOPOLOGY", + ) + else: + if positions is None: + raise ValueError("No positions to create the Universe with") + + # Positions is an openmm Quantity in nm we need + # to convert to angstroms + return mda.Universe( + topology, + np.array(positions._value) * 10, + topology_format="OPENMMTOPOLOGY", + trajectory_format=MemoryReader, + ) + + @staticmethod + def _get_idxs_from_residxs( + topology: omm_topology, + residxs: list[int], + ) -> list[int]: + """ + Helper method to get the a list of atom indices which belong to a list + of residues. + + Parameters + ---------- + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + residxs : list[int] + A list of residue numbers who's atoms we should get atom indices. + + Returns + ------- + atom_ids : list[int] + A list of atom indices. + + TODO + ---- + * Check how this works when we deal with virtual sites. + """ + atom_ids = [] + + for r in topology.residues(): + if r.index in residxs: + atom_ids.extend([at.index for at in r.atoms()]) + + return atom_ids + + @staticmethod + def _get_boresch_restraint( + universe: mda.Universe, + guest_rdmol: Chem.Mol, + guest_atom_ids: list[int], + host_atom_ids: list[int], + temperature: Quantity, + settings: BoreschRestraintSettings, + ) -> tuple[BoreschRestraintGeometry, BoreschRestraint]: + """ + Get a Boresch-like restraint Geometry and OpenMM restraint force + supplier. + + Parameters + ---------- + universe : mda.Universe + An MDAnalysis Universe defining the system to get the restraint for. + guest_rdmol : Chem.Mol + An RDKit Molecule defining the guest molecule in the system. + guest_atom_ids: list[int] + A list of atom indices defining the guest molecule in the universe. + host_atom_ids : list[int] + A list of atom indices defining the host molecules in the universe. + temperature : openff.units.Quantity + The temperature of the simulation where the restraint will be added. + settings : BoreschRestraintSettings + Settings on how the Boresch-like restraint should be defined. + + Returns + ------- + geom : BoreschRestraintGeometry + A class defining the Boresch-like restraint. + restraint : BoreschRestraint + A factory class for generating Boresch restraints in OpenMM. + """ + # Take the minimum of the two possible force constants to check against + frc_const = min(settings.K_thetaA, settings.K_thetaB) + + geom = geometry.boresch.find_boresch_restraint( + universe=universe, + guest_rdmol=guest_rdmol, + guest_idxs=guest_atom_ids, + host_idxs=host_atom_ids, + host_selection=settings.host_selection, + anchor_finding_strategy=settings.anchor_finding_strategy, + dssp_filter=settings.dssp_filter, + rmsf_cutoff=settings.rmsf_cutoff, + host_min_distance=settings.host_min_distance, + host_max_distance=settings.host_max_distance, + angle_force_constant=frc_const, + temperature=temperature, + ) + + restraint = omm_restraints.BoreschRestraint(settings) + return geom, restraint + + def _add_restraints( + self, + system: System, + topology: omm_topology, + positions: ommunit.Quantity, + alchem_comps: dict[str, list[Component]], + comp_resids: dict[Component, npt.NDArray], + settings: dict[str, SettingsBaseModel], + ) -> tuple[ + Quantity, + System, + geometry.HostGuestRestraintGeometry, + ]: + """ + Find and add restraints to the OpenMM System. + + Notes + ----- + Currently, only Boresch-like restraints are supported. + + Parameters + ---------- + system : openmm.System + The System to add the restraint to. + topology : openmm.app.Topology + An OpenMM Topology that defines the System. + positions: openmm.unit.Quantity + The System's current positions. + Used if a trajectory file isn't found. + alchem_comps: dict[str, list[Component]] + A dictionary with a list of alchemical components + in both state A and B. + comp_resids: dict[Component, npt.NDArray] + A dictionary keyed by each Component in the System + which contains arrays with the residue indices that is contained + by that Component. + settings : dict[str, SettingsBaseModel] + A dictionary of settings that defines how to find and set + the restraint. + + Returns + ------- + correction : openff.units.Quantity + The standard state correction for the restraint. + system : openmm.System + A copy of the System with the restraint added. + rest_geom : geometry.HostGuestRestraintGeometry + The restraint Geometry object. + """ + if self.verbose: + self.logger.info("Generating restraints") + + # Get the guest rdmol + guest_rdmol = alchem_comps["stateA"][0].to_rdkit() + + # sanitize the rdmol if possible - warn if you can't + err = Chem.SanitizeMol(guest_rdmol, catchErrors=True) + + if err: + msg = "restraint generation: could not sanitize ligand rdmol" + logger.warning(msg) + + # Get the guest idxs + # concatenate a list of residue indexes for all alchemical components + residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) + + # get the alchemicical atom ids + guest_atom_ids = self._get_idxs_from_residxs(topology, residxs) + + # Now get the host idxs + # We assume this is everything but the alchemical component + # and the solvent. + solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)] + exclude_comps = [alchem_comps["stateA"]] + solv_comps + residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps]) + + host_atom_ids = self._get_idxs_from_residxs(topology, residxs) + + # Finally create an MDAnalysis Universe + # We try to pass the equilibration production file path through + # In some cases (debugging / dry runs) this won't be available + # so we'll default to using input positions. + univ = self._get_mda_universe( + topology, + positions, + self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename, + ) + + if isinstance(settings["restraint_settings"], BoreschRestraintSettings): + rest_geom, restraint = self._get_boresch_restraint( + univ, + guest_rdmol, + guest_atom_ids, + host_atom_ids, + settings["thermo_settings"].temperature, + settings["restraint_settings"], + ) + else: + # TODO turn this into a direction for different restraint types supported? + raise NotImplementedError("Other restraint types are not yet available") + + if self.verbose: + self.logger.info(f"restraint geometry is: {rest_geom}") + + # We need a temporary thermodynamic state to add the restraint + # & get the correction + thermodynamic_state = ThermodynamicState( + system, + temperature=to_openmm(settings["thermo_settings"].temperature), + pressure=to_openmm(settings["thermo_settings"].pressure), + ) + + # Add the force to the thermodynamic state + restraint.add_force( + thermodynamic_state, + rest_geom, + controlling_parameter_name="lambda_restraints", + ) + # Get the standard state correction as a unit.Quantity + correction = restraint.get_standard_state_correction( + thermodynamic_state, + rest_geom, + ) + + return ( + correction, + # Remove the thermostat, otherwise you'll get an + # Andersen thermostat by default! + thermodynamic_state.get_system(remove_thermostat=True), + rest_geom, + ) + + +class ABFEComplexSimUnit( + ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteMultiStateSimulationUnit +): + """ + Multi-state simulation (e.g. multi replica methods like Hamiltonian + replica exchange) unit for the complex phase of absolute binding + free energy transformations. + """ + + simtype = "complex" + + +class ABFEComplexAnalysisUnit(ComplexSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit): + """ + Analysis unit for multi-state simulations with the complex phase + of absolute binding free energy transformations. + """ + + simtype = "complex" + + +class SolventComponentsMixin: + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : ProteinComponent | None + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp just return None + return alchem_comps, solv_comp, None, off_comps + + +class SolventSettingsMixin: + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a solvent transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : ABFEPreEquilOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined] + + settings = {} + settings["forcefield_settings"] = prot_settings.forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvent_solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.solvent_lambda_settings + settings["engine_settings"] = prot_settings.engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings + + return settings + + +class ABFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit): + """ + Setup unit for the solvent phase of absolute binding free energy + transformations. + """ + + simtype = "solvent" + + +class ABFESolventSimUnit( + SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit +): + """ + Multi-state simulation (e.g. multi replica methods like Hamiltonian + replica exchange) unit for the solvent phase of absolute binding + free energy transformations. + """ + + simtype = "solvent" + + +class ABFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit): + """ + Analysis unit for multi-state simulations with the solvent phase + of absolute binding free energy transformations. + """ + + simtype = "solvent" diff --git a/src/openfe/protocols/openmm_afe/afe_protocol_results.py b/src/openfe/protocols/openmm_afe/afe_protocol_results.py new file mode 100644 index 00000000..38d1e187 --- /dev/null +++ b/src/openfe/protocols/openmm_afe/afe_protocol_results.py @@ -0,0 +1,550 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Result classes for the Absolute Free Energy Protocols +===================================================== + +This module implements :class:`gufe.ProtocolResult` classes for the absolute +free energy Protocols. + +Specifically it implements: + * AbsoluteBindingProtocolResult + * AbsoluteSolvationProtocolResult +""" + +import itertools +import logging +import pathlib +import warnings +from typing import Optional, Union + +import gufe +import numpy as np +import numpy.typing as npt +from openff.units import Quantity +from openff.units import unit as offunit +from openmmtools import multistate + +from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry + +logger = logging.getLogger(__name__) + + +class AbsoluteProtocolResultMixin: + bound_state = "solvent" + unbound_state = "vacuum" + + def __init__(self, **data): + super().__init__(**data) + # TODO: Detect when we have extensions and stitch these together? + if any( + len(pur_list) > 2 + for pur_list in itertools.chain( + self.data[self.bound_state].values(), self.data[self.unbound_state].values() + ) + ): + raise NotImplementedError("Can't stitch together results yet") + + def get_forward_and_reverse_energy_analysis( + self, + ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: + """ + Get the reverse and forward analysis of the free energies. + + Returns + ------- + forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]] + A dictionary, keyed for each leg of the thermodynamic cycle, + either ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries containing the forward + and reverse analysis of each repeat of that simulation type. + + The forward and reverse analysis dictionaries contain: + - `fractions`: npt.NDArray + The fractions of data used for the estimates + - `forward_DGs`, `reverse_DGs`: openff.units.Quantity + The forward and reverse estimates for each fraction of data + - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity + The forward and reverse estimate uncertainty for each + fraction of data. + + If one of the cycle leg list entries is ``None``, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + Raises + ------ + UserWarning + * If any of the forward and reverse dictionaries are ``None`` in a + given thermodynamic cycle leg. + """ + + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} + + for key in [self.bound_state, self.unbound_state]: + forward_reverse[key] = [ + pus[0].outputs["forward_and_reverse_energies"] + for pus in self.data[key].values() # type: ignore[attr-defined] + ] + + if None in forward_reverse[key]: + wmsg = ( + "One or more ``None`` entries were found in the forward " + f"and reverse dictionaries of the repeats of the {key} " + "calculations. This is likely caused by an MBAR convergence " + "failure caused by too few independent samples when " + "calculating the free energies of the 10% timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get a the MBAR overlap estimates for all legs of the simulation. + + Returns + ------- + overlap_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary keyed for each leg of the thermodynamic cycle, either + ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries with the MBAR overlap + estimates of each repeat of that simulation type. + + The underlying MBAR dictionaries contain the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + + for key in [self.bound_state, self.unbound_state]: + overlap_stats[key] = [ + pus[0].outputs["unit_mbar_overlap"] + for pus in self.data[key].values() # type: ignore[attr-defined] + ] + + return overlap_stats + + def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]: + """ + Get the replica exchange transition statistics for all + legs of the simulation. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : dict[str, list[dict[str, npt.NDArray]]] + A dictionary with keys for each leg of the thermodynamic cycle, either + ``solvent`` and ``vaccuum` for a solvation free energy or + ``solvent`` and ``complex`` for a binding free energy, + with each containing a list of dictionaries containing the replica + transition statistics for each repeat of that simulation type. + + The replica transition statistics dictionaries contain the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} + try: + for key in [self.bound_state, self.unbound_state]: + repex_stats[key] = [ + pus[0].outputs["replica_exchange_statistics"] + for pus in self.data[key].values() # type: ignore[attr-defined] + ] + except KeyError: + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> dict[str, list[npt.NDArray]]: + """ + Get the timeseries of replica states for all simulation legs. + + Returns + ------- + replica_states : dict[str, list[npt.NDArray]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists of replica states timeseries for each repeat of that + simulation type. + """ + replica_states: dict[str, list[npt.NDArray]] = { + self.bound_state: [], + self.unbound_state: [], + } + + def is_file(filename: str): + p = pathlib.Path(filename) + + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + + return p + + def get_replica_state(nc, chk): + nc = is_file(nc) + dir_path = nc.parents[0] + chk = is_file(dir_path / chk).name + + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode="r" + ) + + retval = np.asarray(reporter.read_replica_thermodynamic_states()) + reporter.close() + + return retval + + for key in [self.bound_state, self.unbound_state]: + for pus in self.data[key].values(): # type: ignore[attr-defined] + states = get_replica_state( + pus[0].outputs["trajectory"], + pus[0].outputs["checkpoint"], + ) + replica_states[key].append(states) + + return replica_states + + def equilibration_iterations(self) -> dict[str, list[float]]: + """ + Get the number of equilibration iterations for each simulation. + + Returns + ------- + equilibration_lengths : dict[str, list[float]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists containing the number of equilibration iterations for + each repeat of that simulation type. + """ + equilibration_lengths: dict[str, list[float]] = {} + + for key in [self.bound_state, self.unbound_state]: + equilibration_lengths[key] = [ + pus[0].outputs["equilibration_iterations"] + for pus in self.data[key].values() # type: ignore[attr-defined] + ] + + return equilibration_lengths + + def production_iterations(self) -> dict[str, list[float]]: + """ + Get the number of production iterations for each simulation. + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : dict[str, list[float]] + Dictionary keyed for each leg of the thermodynamic cycle, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + with lists containing the number of equilibration iterations for + each repeat of that simulation type. + """ + production_lengths: dict[str, list[float]] = {} + + for key in [self.bound_state, self.unbound_state]: + production_lengths[key] = [ + pus[0].outputs["production_iterations"] + for pus in self.data[key].values() # type: ignore[attr-defined] + ] + + return production_lengths + + def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]: + """ + Get the system selection indices used to write PDB and + trajectory files. + + Returns + ------- + indices : dict[str, list[npt.NDArray]] + A dictionary keyed for each state, either + `solvent` and `vacuum` for solvation free energies, + or `complex` and `solvent` for binding free energies, + each containing a list of NDArrays containing the corresponding + full system atom indices for each atom written in the production + trajectory files for each replica. + """ + indices: dict[str, list[Optional[npt.NDArray]]] = {} + + for key in [self.bound_state, self.unbound_state]: + indices[key] = [] + for pus in self.data[key].values(): # type: ignore[attr-defined] + indices[key].append(pus[0].outputs["selection_indices"]) + + return indices + + +class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin): + """ + Protocol results with the output of a AbsoluteSolvationProtocol + """ + + bound_state = "solvent" + unbound_state = "vacuum" + + def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] + A dictionary, keyed `solvent` and `vacuum` for each leg + of the thermodynamic cycle, with lists of tuples containing + the individual free energy estimates and associated MBAR + uncertainties for each repeat of that simulation type. + """ + dGs = {} + + for state in [self.bound_state, self.unbound_state]: + state_dGs = [ + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + for pus in self.data[state].values() + ] + dGs[state] = state_dGs + + return dGs + + def get_estimate(self): + """Get the solvation free energy estimate for this calculation. + + Returns + ------- + dG : openff.units.Quantity + The solvation free energy. This is a Quantity defined with units. + """ + + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_dG = _get_average(individual_estimates["vacuum"]) + solv_dG = _get_average(individual_estimates["solvent"]) + + return vac_dG - solv_dG + + def get_uncertainty(self): + """Get the solvation free energy error for this calculation. + + Returns + ------- + err : openff.units.Quantity + The standard deviation between estimates of the solvation free + energy. This is a Quantity defined with units. + """ + + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + vac_err = _get_stdev(individual_estimates["vacuum"]) + solv_err = _get_stdev(individual_estimates["solvent"]) + + # return the combined error + return np.sqrt(vac_err**2 + solv_err**2) + + +class AbsoluteBindingProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin): + """ + Protocol results with the output of a AbsoluteBindingProtocol. + """ + + bound_state = "complex" + unbound_state = "solvent" + + def get_individual_estimates( + self, + ) -> dict[str, list[tuple[Quantity, Quantity]]]: + """ + Get the individual estimate of the free energies. + + Returns + ------- + dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]] + A dictionary, keyed `solvent`, `complex`, and 'standard_state' + representing each portion of the thermodynamic cycle, + with lists of tuples containing the individual free energy + estimates and, for 'solvent' and 'complex', the associated MBAR + uncertainties for each repeat of that simulation type. + + Notes + ----- + * Standard state correction has no error and so will return a value + of 0. + """ + complex_dGs = [] + correction_dGs = [] + solv_dGs = [] + + for pus in self.data["complex"].values(): + complex_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) + correction_dGs.append( + ( + pus[0].outputs["standard_state_correction"], + 0 * offunit.kilocalorie_per_mole, # correction has no error + ) + ) + + for pus in self.data["solvent"].values(): + solv_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) + + return { + "solvent": solv_dGs, + "complex": complex_dGs, + "standard_state_correction": correction_dGs, + } + + @staticmethod + def _add_complex_standard_state_corr( + complex_dG: list[tuple[Quantity, Quantity]], + standard_state_dG: list[tuple[Quantity, Quantity]], + ) -> list[tuple[Quantity, Quantity]]: + """ + Helper method to combine the + complex & standard state corrections legs. + + Parameters + ---------- + complex_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]] + The individual estimates of the complex leg, + where the first entry of each tuple is the dG estimate + and the second entry is the MBAR error. + standard_state_dG : list[tuple[Quantity, Quantity]] + The individual standard state corrections for each corresponding + complex leg. The first entry is the correction, the second + is an empty error value of 0. + + Returns + ------- + combined_dG : list[tuple[openff.units.Quantity,openff.units. Quantity]] + A list of dG estimates & MBAR errors for the combined + complex & standard state correction of each repeat. + + Notes + ----- + We assume that both list of items are in the right order. + """ + combined_dG: list[tuple[Quantity, Quantity]] = [] + for comp, corr in zip(complex_dG, standard_state_dG): + # No need to convert unit types, since pint takes care of that + # except that mypy hates it because pint isn't typed properly... + # No need to add errors since there's just the one + combined_dG.append((comp[0] + corr[0], comp[1])) # type: ignore[operator] + + return combined_dG + + def get_estimate(self) -> Quantity: + """Get the binding free energy estimate for this calculation. + + Returns + ------- + dG : openff.units.Quantity + The binding free energy. This is a Quantity defined with units. + """ + + def _get_average(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.average(dGs) * u + + individual_estimates = self.get_individual_estimates() + complex_dG = _get_average( + self._add_complex_standard_state_corr( + individual_estimates["complex"], + individual_estimates["standard_state_correction"], + ) + ) + solv_dG = _get_average(individual_estimates["solvent"]) + + return -complex_dG + solv_dG + + def get_uncertainty(self) -> Quantity: + """Get the binding free energy error for this calculation. + + Returns + ------- + err : openff.units.Quantity + The standard deviation between estimates of the binding free + energy. This is a Quantity defined with units. + """ + + def _get_stdev(estimates): + # Get the unit value of the first value in the estimates + u = estimates[0][0].u + # Loop through estimates and get the free energy values + # in the unit of the first estimate + dGs = [i[0].to(u).m for i in estimates] + + return np.std(dGs) * u + + individual_estimates = self.get_individual_estimates() + + complex_err = _get_stdev( + self._add_complex_standard_state_corr( + individual_estimates["complex"], + individual_estimates["standard_state_correction"], + ) + ) + solv_err = _get_stdev(individual_estimates["solvent"]) + + # return the combined error + return np.sqrt(complex_err**2 + solv_err**2) + + def restraint_geometries(self) -> list[BoreschRestraintGeometry]: + """ + Get a list of the restraint geometries for the + complex simulations. These define the atoms that have + been restrained in the system. + + Returns + ------- + geometries : list[dict[str, Any]] + A list of dictionaries containing the details of the atoms + in the system that are involved in the restraint. + """ + geometries = [ + BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry"]) + for pus in self.data["complex"].values() + ] + + return geometries diff --git a/src/openfe/protocols/openmm_afe/ahfe_units.py b/src/openfe/protocols/openmm_afe/ahfe_units.py new file mode 100644 index 00000000..d496dd5f --- /dev/null +++ b/src/openfe/protocols/openmm_afe/ahfe_units.py @@ -0,0 +1,230 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units` +===================================================================== + +This module defines the ProtocolUnits for the +:class:`AbsoluteSolvationProtocol`. +""" + +import logging + +from openfe.protocols.openmm_afe.equil_afe_settings import ( + SettingsBaseModel, +) + +from ..openmm_utils import system_validation +from .base_afe_units import ( + BaseAbsoluteMultiStateAnalysisUnit, + BaseAbsoluteMultiStateSimulationUnit, + BaseAbsoluteSetupUnit, +) + +logger = logging.getLogger(__name__) + + +class VacuumComponentsMixin: + def _get_components(self): + """ + Get the relevant components for a vacuum transformation. + + Returns + ------- + alchem_comps : dict[str, list[Component]] + A list of alchemical components + solv_comp : None + For the gas phase transformation, None will always be returned + for the solvent component of the chemical system. + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[Component, OpenFF Molecule] + The openff Molecules to add to the system. This + is equivalent to the alchemical components in stateA (since + we only allow for disappearing ligands). + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} + + _, prot_comp, _ = system_validation.get_components(stateA) + + # Notes: + # 1. Our input state will contain a solvent, we ``None`` that out + # since this is the gas phase unit. + # 2. Our small molecules will always just be the alchemical components + # (of stateA since we enforce only one disappearing ligand) + return alchem_comps, None, prot_comp, off_comps + + +class VacuumSettingsMixin: + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a vacuum transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : SimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined] + + settings = {} + settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.vacuum_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.vacuum_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.vacuum_equil_output_settings + settings["simulation_settings"] = prot_settings.vacuum_simulation_settings + settings["output_settings"] = prot_settings.vacuum_output_settings + + return settings + + +class AHFEVacuumSetupUnit(VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteSetupUnit): + """ + Setup unit for the vacuum phase of absolute hydration free energy + transformations. + """ + + simtype = "vacuum" + + +class AHFEVacuumSimUnit( + VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteMultiStateSimulationUnit +): + """ + Multi-state simulation (e.g. multi replica methods like Hamiltonian + replica exchange) unit for the vacuum phase of absolute hydration + free energy transformations. + """ + + simtype = "vacuum" + + +class AHFEVacuumAnalysisUnit(VacuumSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit): + """ + Analysis unit for multi-state simulations with the vacuum phase + of absolute hydration free energy transformations. + """ + + simtype = "vacuum" + + +class SolventComponentsMixin: + def _get_components(self): + """ + Get the relevant components for a solvent transformation. + + Returns + ------- + alchem_comps : dict[str, Component] + A list of alchemical components + solv_comp : SolventComponent + The SolventComponent of the system + prot_comp : Optional[ProteinComponent] + The protein component of the system, if it exists. + small_mols : dict[SmallMoleculeComponent: OFFMolecule] + SmallMoleculeComponents to add to the system. + """ + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] + + solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) + off_comps = {m: m.to_openff() for m in small_mols} + + # We don't need to check that solv_comp is not None, otherwise + # an error will have been raised when calling `validate_solvent` + # in the Protocol's `_create`. + # Similarly we don't need to check prot_comp since that's also + # disallowed on create + return alchem_comps, solv_comp, prot_comp, off_comps + + +class SolventSettingsMixin: + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Extract the relevant settings for a solvent transformation. + + Returns + ------- + settings : dict[str, SettingsBaseModel] + A dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * charge_settings : OpenFFPartialChargeSettings + * solvation_settings : OpenMMSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings: MultiStateOutputSettings + """ + prot_settings = self._inputs["protocol"].settings # type: ignore[attr-defined] + + settings = {} + settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.solvent_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings + + return settings + + +class AHFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit): + """ + Setup unit for the solvent phase of absolute hydration free energy + transformations. + """ + + simtype = "solvent" + + +class AHFESolventSimUnit( + SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit +): + """ + Multi-state simulation (e.g. multi replica methods like Hamiltonian + replica exchange) unit for the solvent phase of absolute hydration + free energy transformations. + """ + + simtype = "solvent" + + +class AHFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit): + """ + Analysis unit for multi-state simulations with the solvent phase + of absolute hydration free energy transformations. + """ + + simtype = "solvent" diff --git a/openfe/protocols/openmm_afe/base.py b/src/openfe/protocols/openmm_afe/base_afe_units.py similarity index 58% rename from openfe/protocols/openmm_afe/base.py rename to src/openfe/protocols/openmm_afe/base_afe_units.py index d28a630c..4095982a 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/src/openfe/protocols/openmm_afe/base_afe_units.py @@ -1,9 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -"""OpenMM Equilibrium AFE Protocol base classes -=============================================== +"""OpenMM AFE Protocol base classes +=================================== -Base classes for the equilibrium OpenMM absolute free energy ProtocolUnits. +Base classes for the OpenMM absolute free energy ProtocolUnits. Thist mostly implements BaseAbsoluteUnit whose methods can be overriden to define different types of alchemical transformations. @@ -15,14 +15,12 @@ TODO * Allow for a more flexible setting of Lambda regions. """ -from __future__ import annotations - import abc import copy import logging import os import pathlib -from typing import Any, Optional +from typing import Any import gufe import mdtraj as mdt @@ -30,10 +28,15 @@ import numpy as np import numpy.typing as npt import openmm import openmmtools -from gufe import ChemicalSystem, ProteinComponent, SmallMoleculeComponent, SolventComponent +from gufe import ( + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, +) from gufe.components import Component from openff.toolkit.topology import Molecule as OFFMolecule -from openff.units import Quantity, unit +from openff.units import Quantity +from openff.units import unit as offunit from openff.units.openmm import ensure_quantity, from_openmm, to_openmm from openmm import app from openmm import unit as ommunit @@ -71,67 +74,108 @@ from openfe.protocols.openmm_utils import ( system_creation, ) from openfe.protocols.openmm_utils.omm_settings import ( - BasePartialChargeSettings, SettingsBaseModel, ) +from openfe.protocols.openmm_utils.serialization import ( + deserialize, + make_vec3_box, + serialize, +) from openfe.protocols.restraint_utils import geometry +from openfe.protocols.restraint_utils.openmm import omm_restraints from openfe.utils import log_system_probe, without_oechem_backend logger = logging.getLogger(__name__) -class BaseAbsoluteUnit(gufe.ProtocolUnit): - """ - Base class for ligand absolute free energy transformations. - """ - - def __init__( +class AbsoluteUnitMixin: + def _prepare( self, - *, - protocol: gufe.Protocol, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - alchemical_components: dict[str, list[Component]], - generation: int = 0, - repeat_id: int = 0, - name: Optional[str] = None, + verbose: bool, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, ): """ + Set basepaths and do some initial logging. + Parameters ---------- - protocol : gufe.Protocol - protocol used to create this Unit. Contains key information such - as the settings. - stateA : ChemicalSystem - ChemicalSystem containing the components defining the state at - lambda 0. - stateB : ChemicalSystem - ChemicalSystem containing the components defining the state at - lambda 1. - alchemical_components : dict[str, Component] - the alchemical components for each state in this Unit - name : str, optional - Human-readable identifier for this Unit - repeat_id : int, optional - Identifier for which repeat (aka replica/clone) this Unit is, - default 0 - generation : int, optional - Generation counter which keeps track of how many times this repeat - has been extended, default 0. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath : pathlib.Path | None + Optional base path to write scratch files to. + shared_basepath : pathlib.Path | None + Optional base path to write shared files to. """ - super().__init__( - name=name, - protocol=protocol, - stateA=stateA, - stateB=stateB, - alchemical_components=alchemical_components, - repeat_id=repeat_id, - generation=generation, - ) + self.verbose = verbose + + if self.verbose: + self.logger.info("setting up alchemical system") # type: ignore[attr-defined] + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @abc.abstractmethod + def _get_settings(self) -> dict[str, SettingsBaseModel]: + """ + Get a dictionary with the following entries: + * forcefield_settings : OpenMMSystemGeneratorFFSettings + * thermo_settings : ThermoSettings + * solvation_settings : BaseSolvationSettings + * alchemical_settings : AlchemicalSettings + * lambda_settings : LambdaSettings + * engine_settings : OpenMMEngineSettings + * integrator_settings : IntegratorSettings + * equil_simulation_settings : MDSimulationSettings + * equil_output_settings : MDOutputSettings + * simulation_settings : MultiStateSimulationSettings + * output_settings : MultiStateOutputSettings + + Settings may change depending on what type of simulation you are + running. Cherry pick them and return them to be available later on. + + This method should also add various validation checks as necessary. + + Note + ---- + Must be implemented in the child class. + """ + ... + + +class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): + """ + Base class for setting up an absolute free energy transformations. + """ + + @abc.abstractmethod + def _get_components( + self, + ) -> tuple[ + dict[str, list[Component]], + gufe.SolventComponent | None, + gufe.ProteinComponent | None, + dict[SmallMoleculeComponent, OFFMolecule], + ]: + """ + Get the relevant components to create the alchemical system with. + + Note + ---- + Must be implemented in the child class. + """ + ... @staticmethod def _get_alchemical_indices( - omm_top: openmm.Topology, + omm_top: openmm.app.Topology, comp_resids: dict[Component, npt.NDArray], alchem_comps: dict[str, list[Component]], ) -> list[int]: @@ -282,134 +326,24 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): return equilibrated_positions, box - def _prepare( - self, - verbose: bool, - scratch_basepath: Optional[pathlib.Path], - shared_basepath: Optional[pathlib.Path], - ): - """ - Set basepaths and do some initial logging. - - Parameters - ---------- - verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging. - basepath : Optional[pathlib.Path] - Optional base path to write files to. - """ - self.verbose = verbose - - if self.verbose: - self.logger.info("setting up alchemical system") - - # set basepaths - def _set_optional_path(basepath): - if basepath is None: - return pathlib.Path(".") - return basepath - - self.scratch_basepath = _set_optional_path(scratch_basepath) - self.shared_basepath = _set_optional_path(shared_basepath) - - @abc.abstractmethod - def _get_components( - self, - ) -> tuple[ - dict[str, list[Component]], - Optional[gufe.SolventComponent], - Optional[gufe.ProteinComponent], - dict[SmallMoleculeComponent, OFFMolecule], - ]: - """ - Get the relevant components to create the alchemical system with. - - Note - ---- - Must be implemented in the child class. - """ - ... - - @abc.abstractmethod - def _handle_settings(self) -> dict[str, SettingsBaseModel]: - """ - Get a dictionary with the following entries: - * forcefield_settings : OpenMMSystemGeneratorFFSettings - * thermo_settings : ThermoSettings - * solvation_settings : BaseSolvationSettings - * alchemical_settings : AlchemicalSettings - * lambda_settings : LambdaSettings - * engine_settings : OpenMMEngineSettings - * integrator_settings : IntegratorSettings - * equil_simulation_settings : MDSimulationSettings - * equil_output_settings : MDOutputSettings - * simulation_settings : MultiStateSimulationSettings - * output_settings : MultiStateOutputSettings - - Settings may change depending on what type of simulation you are - running. Cherry pick them and return them to be available later on. - - This method should also add various validation checks as necessary. - - Note - ---- - Must be implemented in the child class. - """ - ... - - def _get_system_generator( - self, settings: dict[str, SettingsBaseModel], solvent_comp: Optional[SolventComponent] - ) -> SystemGenerator: - """ - Get a system generator through the system creation - utilities - - Parameters - ---------- - settings : dict[str, SettingsBaseModel] - A dictionary of settings object for the unit. - solvent_comp : Optional[SolventComponent] - The solvent component of this system, if there is one. - - Returns - ------- - system_generator : openmmforcefields.generator.SystemGenerator - System Generator to parameterise this unit. - """ - ffcache = settings["output_settings"].forcefield_cache - if ffcache is not None: - ffcache = self.shared_basepath / ffcache - - # Block out oechem backend to avoid any issues with - # smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system_generator = system_creation.get_system_generator( - forcefield_settings=settings["forcefield_settings"], - integrator_settings=settings["integrator_settings"], - thermo_settings=settings["thermo_settings"], - cache=ffcache, - has_solvent=solvent_comp is not None, - ) - return system_generator - @staticmethod def _assign_partial_charges( partial_charge_settings: OpenFFPartialChargeSettings, - smc_components: dict[SmallMoleculeComponent, OFFMolecule], + small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> None: """ - Assign partial charges to SMCs. + Assign partial charges to the OpenFF Molecules associated with + all the SmallMoleculeComponents in the transformation. Parameters ---------- charge_settings : OpenFFPartialChargeSettings Settings for controlling how the partial charges are assigned. - smc_components : dict[SmallMoleculeComponent, openff.toolkit.Molecule] - Dictionary of OpenFF Molecules to add, keyed by - SmallMoleculeComponent. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules to add, keyed by their + associated SmallMoleculeComponent. """ - for mol in smc_components.values(): + for mol in small_mols.values(): charge_generation.assign_offmol_partial_charges( offmol=mol, overwrite=False, @@ -419,13 +353,57 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): nagl_model=partial_charge_settings.nagl_model, ) + @staticmethod + def _get_system_generator( + settings: dict[str, SettingsBaseModel], + solvent_component: SolventComponent | None, + openff_molecules: list[OFFMolecule], + ffcache: pathlib.Path | None, + ) -> SystemGenerator: + """ + Get a system generator through the system creation + utilities + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + A dictionary of settings object for the unit. + solvent_comp : SolventComponent | None + The solvent component of this system, if there is one. + openff_molecules : list[openff.toolkit.Molecule] | None + A list of OpenFF Molecules to generate templates for, if any. + ffcache : pathlib.Path | None + Path to the force field parameter cache. + + Returns + ------- + system_generator : openmmforcefields.generator.SystemGenerator + System Generator to parameterise this unit. + """ + system_generator = system_creation.get_system_generator( + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], + cache=ffcache, + has_solvent=solvent_component is not None, + ) + + # Handle openff Molecule templates + # TODO: revisit this once the SystemGenerator update happens + if openff_molecules is None: + return system_generator + + # Register all the templates, pass unique molecules to avoid clashes + system_generator.add_molecules(list(set(openff_molecules))) + + return system_generator + + @staticmethod def _get_modeller( - self, - protein_component: Optional[ProteinComponent], - solvent_component: Optional[SolventComponent], - smc_components: dict[SmallMoleculeComponent, OFFMolecule], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], system_generator: SystemGenerator, - partial_charge_settings: BasePartialChargeSettings, solvation_settings: BaseSolvationSettings, ) -> tuple[app.Modeller, dict[Component, npt.NDArray]]: """ @@ -434,18 +412,15 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): Parameters ---------- - protein_component : Optional[ProteinComponent] + protein_component : ProteinComponent | None Protein Component, if it exists. - solvent_component : Optional[ProteinCompoinent] + solvent_component : SolventComponent | None Solvent Component, if it exists. - smc_components : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] Dictionary of OpenFF Molecules to add, keyed by SmallMoleculeComponent. system_generator : openmmforcefields.generator.SystemGenerator System Generator to parameterise this unit. - partial_charge_settings : BasePartialChargeSettings - Settings detailing how to assign partial charges to the - SMCs of the system. solvation_settings : BaseSolvationSettings Settings detailing how to solvate the system. @@ -457,39 +432,23 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): comp_resids : dict[Component, npt.NDArray] Dictionary of residue indices for each component in system. """ - if self.verbose: - self.logger.info("Parameterizing molecules") - - # Assign partial charges to smcs - self._assign_partial_charges(partial_charge_settings, smc_components) - - # TODO: guard the following from non-RDKit backends - # force the creation of parameters for the small molecules - # this is necessary because we need to have the FF generated ahead - # of solvating the system. - # Block out oechem backend to avoid any issues with - # smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - for mol in smc_components.values(): - system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) - - # get OpenMM modeller + dictionary of resids for each component - system_modeller, comp_resids = system_creation.get_omm_modeller( - protein_comp=protein_component, - solvent_comp=solvent_component, - small_mols=smc_components, - omm_forcefield=system_generator.forcefield, - solvent_settings=solvation_settings, - ) + # get OpenMM modeller + dictionary of resids for each component + system_modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=small_mols, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) return system_modeller, comp_resids def _get_omm_objects( self, settings: dict[str, SettingsBaseModel], - protein_component: Optional[ProteinComponent], - solvent_component: Optional[SolventComponent], - smc_components: dict[SmallMoleculeComponent, OFFMolecule], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], ) -> tuple[ app.Topology, openmm.System, @@ -504,12 +463,13 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ---------- settings : dict[str, SettingsBaseModel] Protocol settings - protein_component : Optional[ProteinComponent] + protein_component : ProteinComponent | None Protein component for the system. - solvent_component : Optional[SolventComponent] + solvent_component : SolventComponent | None Solvent component for the system. - smc_components : dict[str, OFFMolecule] - SmallMoleculeComponents defining ligands to be added to the system + small_mols : dict[str, openff.toolkit.Molecule] + Dictionary of SmallMoleculeComponents and OpenFF Molecules + defining the ligands to be added to the system Returns ------- @@ -525,76 +485,33 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): if self.verbose: self.logger.info("Parameterizing system") - system_generator = self._get_system_generator( - settings=settings, solvent_comp=solvent_component - ) + with without_oechem_backend(): + system_generator = self._get_system_generator( + settings=settings, + solvent_component=solvent_component, + openff_molecules=list(small_mols.values()), + ffcache=self.shared_basepath / settings["output_settings"].forcefield_cache, + ) - modeller, comp_resids = self._get_modeller( - protein_component=protein_component, - solvent_component=solvent_component, - smc_components=smc_components, - system_generator=system_generator, - partial_charge_settings=settings["charge_settings"], - solvation_settings=settings["solvation_settings"], - ) + modeller, comp_resids = self._get_modeller( + protein_component=protein_component, + solvent_component=solvent_component, + small_mols=small_mols, + system_generator=system_generator, + solvation_settings=settings["solvation_settings"], + ) + + system = system_generator.create_system( + topology=modeller.topology, + molecules=list(small_mols.values()), + ) topology = modeller.getTopology() # roundtrip positions to remove vec3 issues positions = to_openmm(from_openmm(modeller.getPositions())) - # Block out oechem backend to avoid any issues with - # smiles roundtripping between rdkit and oechem - with without_oechem_backend(): - system = system_generator.create_system( - topology=modeller.topology, - molecules=list(smc_components.values()), - ) - - # Check and fail early on the presence of virtual sites - # and multistate sampler not using velocity restart - if not settings["integrator_settings"].reassign_velocities: - has_vsite = any(system.isVirtualSite(i) for i in range(system.getNumParticles())) - if has_vsite: - errmsg = "Simulations with virtual sites without velocity reassignment are unstable" - raise ValueError(errmsg) - return topology, system, positions, comp_resids - def _get_lambda_schedule( - self, settings: dict[str, SettingsBaseModel] - ) -> dict[str, list[float]]: - """ - Create the lambda schedule - - Parameters - ---------- - settings : dict[str, SettingsBaseModel] - Settings for the unit. - - Returns - ------- - lambdas : dict[str, list[float]] - - TODO - ---- - * Augment this by using something akin to the RFE protocol's - LambdaProtocol - """ - lambdas = dict() - - lambda_elec = settings["lambda_settings"].lambda_elec - lambda_vdw = settings["lambda_settings"].lambda_vdw - lambda_rest = settings["lambda_settings"].lambda_restraints - - # Reverse lambda schedule for vdw, elect, and restraints - # since in AbsoluteAlchemicalFactory 1 means fully - # interacting (which would be non-interacting for us) - lambdas["lambda_electrostatics"] = [1 - x for x in lambda_elec] - lambdas["lambda_sterics"] = [1 - x for x in lambda_vdw] - lambdas["lambda_restraints"] = [x for x in lambda_rest] - - return lambdas - def _add_restraints( self, system: openmm.System, @@ -604,15 +521,14 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): comp_resids: dict[Component, npt.NDArray], settings: dict[str, SettingsBaseModel], ) -> tuple[ - Optional[GlobalParameterState], - Optional[Quantity], - Optional[openmm.System], - Optional[geometry.BaseRestraintGeometry], + Quantity | None, + openmm.System | None, + geometry.BaseRestraintGeometry | None, ]: """ Placeholder method to add restraints if necessary """ - return None, None, system, None + return None, system, None def _get_alchemical_system( self, @@ -680,15 +596,259 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): return alchemical_factory, alchemical_system, alchemical_indices + @staticmethod + def _subsample_topology( + topology: openmm.app.Topology, + positions: openmm.unit.Quantity, + output_selection: str, + output_file: pathlib.Path, + ) -> npt.NDArray: + """ + Subsample the system based on user-selected output selection + and write the subsampled topology to a PDB file. + + Parameters + ---------- + topology : openmm.app.Topology + The system topology to subsample. + positions : openmm.unit.Quantity + The system positions. + output_selection : str + An MDTraj selection string to subsample the topology with. + output_file : pathlib.Path + Path to the file to write the PDB to. + + Returns + ------- + selection_indices : npt.NDArray + The indices of the subselected system. + """ + mdt_top = mdt.Topology.from_openmm(topology) + selection_indices = mdt_top.select(output_selection) + + # Write out the subselected structure to PDB if not empty + if len(selection_indices) > 0: + traj = mdt.Trajectory( + positions[selection_indices, :], + mdt_top.subset(selection_indices), + ) + traj.save_pdb(output_file) + + return selection_indices + + def run( + self, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, + ) -> dict[str, Any]: + """Run the setup phase of an absolute free energy calculation. + + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary alchemical + system components (topology, system, etc...) but without + running the simulation, default False + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging, default True + scratch_basepath : pathlib.Path | None + Path to the scratch (temporary) directory space. Defaults to the + current working directory if ``None``. + shared_basepath : pathlib.Path | None + Path to the shared (persistent) directory space. Defaults to the + current working directory if ``None``. + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + """ + # General preparation tasks + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get components + alchem_comps, solv_comp, prot_comp, small_mols = self._get_components() + + # Get settings + settings = self._get_settings() + + # Assign partial charges now to avoid any discrepancies later + self._assign_partial_charges(settings["charge_settings"], small_mols) + + # Get OpenMM topology, positions, system, and comp_resids + omm_topology, omm_system, positions, comp_resids = self._get_omm_objects( + settings=settings, + protein_component=prot_comp, + solvent_component=solv_comp, + small_mols=small_mols, + ) + + # Pre-equilbrate System (Test + Avoid NaNs + get stable system) + positions, box_vectors = self._pre_equilibrate( + omm_system, omm_topology, positions, settings, dry + ) + + # Add restraints + # Note: when no restraint is applied, restrained_omm_system == omm_system + ( + standard_state_corr, + restrained_omm_system, + restraint_geometry, + ) = self._add_restraints( + omm_system, + omm_topology, + positions, + alchem_comps, + comp_resids, + settings, + ) + + # Get alchemical system + alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( + topology=omm_topology, + system=restrained_omm_system, + comp_resids=comp_resids, + alchem_comps=alchem_comps, + alchemical_settings=settings["alchemical_settings"], + ) + + # Subselect system based on user inputs & write initial PDB + selection_indices = self._subsample_topology( + topology=omm_topology, + positions=positions, + output_selection=settings["output_settings"].output_indices, + output_file=self.shared_basepath / settings["output_settings"].output_structure, + ) + + # Serialize relevant outputs + system_outfile = self.shared_basepath / "alchemical_system.xml.bz2" + serialize(alchem_system, system_outfile) + + positions_outfile = self.shared_basepath / "system_positions.npy" + npy_positions = from_openmm(positions).to("nanometer").m + np.save(positions_outfile, npy_positions) + + # Set the PDB file name + if len(selection_indices) > 0: + pdb_structure = self.shared_basepath / settings["output_settings"].output_structure + else: + pdb_structure = None + + unit_results_dict = { + "system": system_outfile, + "positions": positions_outfile, + "pdb_structure": pdb_structure, + "selection_indices": selection_indices, + "box_vectors": from_openmm(box_vectors), + } + + if standard_state_corr is not None: + unit_results_dict["standard_state_correction"] = standard_state_corr.to( + "kilocalorie_per_mole" + ) + else: + unit_results_dict["standard_state_correction"] = 0 * offunit.kilocalorie_per_mole + + if restraint_geometry is not None: + unit_results_dict["restraint_geometry"] = restraint_geometry.model_dump() + else: + unit_results_dict["restraint_geometry"] = None + + if dry: + unit_results_dict |= { + "standard_system": omm_system, + "restrained_system": restrained_omm_system, + "alchem_system": alchem_system, + "alchem_indices": alchem_indices, + "alchem_factory": alchem_factory, + "debug_positions": positions, + } + return unit_results_dict + + def _execute( + self, + ctx: gufe.Context, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + "simtype": self.simtype, + **outputs, + } + + +class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): + @abc.abstractmethod + def _get_components( + self, + ) -> tuple[ + dict[str, list[Component]], + gufe.SolventComponent | None, + gufe.ProteinComponent | None, + dict[SmallMoleculeComponent, OFFMolecule], + ]: + """ + Get the relevant components to create the alchemical system with. + + Note + ---- + Must be implemented in the child class. + """ + ... + + def _get_lambda_schedule( + self, settings: dict[str, SettingsBaseModel] + ) -> dict[str, list[float]]: + """ + Create the lambda schedule + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + Settings for the unit. + + Returns + ------- + lambdas : dict[str, list[float]] + + TODO + ---- + * Augment this by using something akin to the RFE protocol's + LambdaProtocol + """ + lambdas = dict() + + lambda_elec = settings["lambda_settings"].lambda_elec + lambda_vdw = settings["lambda_settings"].lambda_vdw + lambda_rest = settings["lambda_settings"].lambda_restraints + + # Reverse lambda schedule for vdw, end elec, + # since in AbsoluteAlchemicalFactory 1 means fully + # interacting (which would be non-interacting for us) + lambdas["lambda_electrostatics"] = [1 - x for x in lambda_elec] + lambdas["lambda_sterics"] = [1 - x for x in lambda_vdw] + lambdas["lambda_restraints"] = [x for x in lambda_rest] + + return lambdas + def _get_states( self, alchemical_system: openmm.System, positions: openmm.unit.Quantity, box_vectors: openmm.unit.Quantity, - settings: dict[str, SettingsBaseModel], + thermodynamic_settings: ThermoSettings, lambdas: dict[str, list[float]], - solvent_comp: Optional[SolventComponent], - restraint_state: Optional[GlobalParameterState], + solvent_component: SolventComponent | None, + alchemically_restrained: bool, ) -> tuple[list[SamplerState], list[ThermodynamicState]]: """ Get a list of sampler and thermodynmic states from an @@ -702,14 +862,15 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): Positions of the alchemical system. box_vectors : openmm.unit.Quantity Box vectors of the alchemical system. - settings : dict[str, SettingsBaseModel] - A dictionary of settings for the protocol unit. + thermodynamic_settings : ThermoSettings + Settings controlling the thermodynamic parameters. lambdas : dict[str, list[float]] A dictionary of lambda scales. - solvent_comp : Optional[SolventComponent] + solvent_component : SolventComponent | None The solvent component of the system, if there is one. - restraint_state : Optional[GlobalParameterState] - The restraint parameter control state, if there is one. + alchemically_restrained : bool + Whether or not the system requires a control parameter + for any alchemical restraints. Returns ------- @@ -722,19 +883,20 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): alchemical_state = AlchemicalState.from_system(alchemical_system) # Set up the system constants - temperature = settings["thermo_settings"].temperature - pressure = settings["thermo_settings"].pressure + temperature = thermodynamic_settings.temperature + pressure = thermodynamic_settings.pressure constants = dict() constants["temperature"] = ensure_quantity(temperature, "openmm") - if solvent_comp is not None: + if solvent_component is not None: constants["pressure"] = ensure_quantity(pressure, "openmm") # Get the thermodynamic parameter protocol param_protocol = copy.deepcopy(lambdas) # Get the composable states - if restraint_state is not None: + if alchemically_restrained: + restraint_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0) composable_states = [alchemical_state, restraint_state] else: composable_states = [alchemical_state] @@ -758,10 +920,67 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): return sampler_states, cmp_states + @staticmethod + def _get_integrator( + integrator_settings: IntegratorSettings, + simulation_settings: MultiStateSimulationSettings, + system: openmm.System, + ) -> openmmtools.mcmc.LangevinDynamicsMove: + """ + Return a LangevinDynamicsMove integrator + + Parameters + ---------- + integrator_settings : IntegratorSettings + Settings controlling the Langevin integrator + simulation_settings : MultiStateSimulationSettings + Settings controlling the simulation. + system : openmm.System + The OpenMM System. + + Returns + ------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + A configured integrator object. + + Raises + ------ + ValueError + If there are virtual sites in the system, but + velocities are not being reassigned after every MCMC move. + """ + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings, integrator_settings + ) + + integrator = openmmtools.mcmc.LangevinDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.langevin_collision_rate), + n_steps=steps_per_iteration, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + ) + + # Validate for known issue when dealing with virtual sites + # and mutltistate simulations + if not integrator_settings.reassign_velocities: + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable with MCMC integrators. " + "You can set `reassign_velocities` to ``True`` in the " + "`integrator_settings` to avoid this issue." + ) + raise ValueError(errmsg) + + return integrator + + @staticmethod def _get_reporter( - self, - topology: app.Topology, - positions: openmm.unit.Quantity, + storage_path: pathlib.Path, + selection_indices: npt.NDArray, simulation_settings: MultiStateSimulationSettings, output_settings: MultiStateOutputSettings, ) -> multistate.MultiStateReporter: @@ -770,10 +989,10 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): Parameters ---------- - topology : app.Topology - A Topology of the system being created. - positions : openmm.unit.Quantity - Positions of the pre-alchemical simulation system. + storage_path : pathlib.Path + Path to the directory where files should be written. + selection_indices : npt.NDArray + Array of system particle indices to subsample the system by. simulation_settings : MultiStateSimulationSettings Multistate simulation control settings, specifically containing the amount of time per state sampling iteration. @@ -785,18 +1004,10 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): reporter : multistate.MultiStateReporter The reporter for the simulation. """ - mdt_top = mdt.Topology.from_openmm(topology) - - # Store the selection indices in self to use later - # when storing them in the unit results - self.selection_indices = mdt_top.select(output_settings.output_indices) - - nc = self.shared_basepath / output_settings.output_filename + nc = storage_path / output_settings.output_filename + # The checkpoint file in openmmtools is taken as a file relative + # to the location of the nc file, so you only want the filename chk = output_settings.checkpoint_storage_filename - chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( - checkpoint_interval=output_settings.checkpoint_interval, - time_per_iteration=simulation_settings.time_per_iteration, - ) if output_settings.positions_write_frequency is not None: pos_interval = settings_validation.divmod_time_and_check( @@ -818,110 +1029,31 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): else: vel_interval = 0 + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + reporter = multistate.MultiStateReporter( storage=nc, - analysis_particle_indices=self.selection_indices, + analysis_particle_indices=selection_indices, checkpoint_interval=chk_intervals, checkpoint_storage=chk, position_interval=pos_interval, velocity_interval=vel_interval, ) - # Write out the structure's PDB whilst we're here - if len(self.selection_indices) > 0: - traj = mdt.Trajectory( - positions[self.selection_indices, :], - mdt_top.subset(self.selection_indices), - ) - traj.save_pdb(self.shared_basepath / output_settings.output_structure) - return reporter - def _get_ctx_caches( - self, - forcefield_settings: OpenMMSystemGeneratorFFSettings, - engine_settings: OpenMMEngineSettings, - ) -> tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]: - """ - Set the context caches based on the chosen platform - - Parameters - ---------- - forcefield_settings: OpenMMSystemGeneratorFFSettings - engine_settings : OpenMMEngineSettings - - Returns - ------- - energy_context_cache : openmmtools.cache.ContextCache - The energy state context cache. - sampler_context_cache : openmmtools.cache.ContextCache - The sampler state context cache. - """ - # Get the compute platform - # Set the number of CPUs to 1 if running a vacuum simulation - restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" - platform = omm_compute.get_openmm_platform( - platform_name=engine_settings.compute_platform, - gpu_device_index=engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu, - ) - - energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, - time_to_live=None, - platform=platform, - ) - - return energy_context_cache, sampler_context_cache - - @staticmethod - def _get_integrator( - integrator_settings: IntegratorSettings, simulation_settings: MultiStateSimulationSettings - ) -> openmmtools.mcmc.LangevinDynamicsMove: - """ - Return a LangevinDynamicsMove integrator - - Parameters - ---------- - integrator_settings : IntegratorSettings - simulation_settings : MultiStateSimulationSettings - - Returns - ------- - integrator : openmmtools.mcmc.LangevinDynamicsMove - A configured integrator object. - """ - steps_per_iteration = settings_validation.convert_steps_per_iteration( - simulation_settings, integrator_settings - ) - - integrator = openmmtools.mcmc.LangevinDynamicsMove( - timestep=to_openmm(integrator_settings.timestep), - collision_rate=to_openmm(integrator_settings.langevin_collision_rate), - n_steps=steps_per_iteration, - reassign_velocities=integrator_settings.reassign_velocities, - n_restart_attempts=integrator_settings.n_restart_attempts, - constraint_tolerance=integrator_settings.constraint_tolerance, - ) - - return integrator - @staticmethod def _get_sampler( integrator: openmmtools.mcmc.LangevinDynamicsMove, reporter: openmmtools.multistate.MultiStateReporter, simulation_settings: MultiStateSimulationSettings, - thermo_settings: ThermoSettings, - cmp_states: list[ThermodynamicState], + thermodynamic_settings: ThermoSettings, + compound_states: list[ThermodynamicState], sampler_states: list[SamplerState], - energy_context_cache: openmmtools.cache.ContextCache, - sampler_context_cache: openmmtools.cache.ContextCache, + platform: openmm.Platform, ) -> multistate.MultiStateSampler: """ Get a sampler based on the equilibrium sampling method requested. @@ -934,16 +1066,14 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): The reporter to hook up to the sampler. simulation_settings : MultiStateSimulationSettings Settings for the alchemical sampler. - thermo_settings : ThermoSettings + thermodynamic_settings : ThermoSettings Thermodynamic settings - cmp_states : list[ThermodynamicState] + compound_states : list[ThermodynamicState] A list of thermodynamic states to sample. sampler_states : list[SamplerState] A list of sampler states. - energy_context_cache : openmmtools.cache.ContextCache - Context cache for the energy states. - sampler_context_cache : openmmtool.cache.ContextCache - Context cache for the sampler states. + platform : openmm.Platform + The compute platform to use. Returns ------- @@ -954,7 +1084,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): simulation_settings=simulation_settings, ) et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermo_settings.temperature, + thermodynamic_settings.temperature, simulation_settings.early_termination_target_error, ) @@ -984,11 +1114,22 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ) sampler.create( - thermodynamic_states=cmp_states, sampler_states=sampler_states, storage=reporter + thermodynamic_states=compound_states, + sampler_states=sampler_states, + storage=reporter, ) - sampler.energy_context_cache = energy_context_cache - sampler.sampler_context_cache = sampler_context_cache + sampler.energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + + sampler.sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) return sampler @@ -997,7 +1138,6 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): sampler: multistate.MultiStateSampler, reporter: multistate.MultiStateReporter, settings: dict[str, SettingsBaseModel], - standard_state_corr: Optional[Quantity], dry: bool, ): """ @@ -1011,16 +1151,8 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): The reporter associated with the sampler. settings : dict[str, SettingsBaseModel] The dictionary of settings for the protocol. - standard_state_corr : Optional[openff.units.Quantity] - The standard state correction, if available. dry : bool Whether or not to dry run the simulation - - Returns - ------- - unit_results_dict : Optional[dict] - A dictionary containing all the free energy results, - if not a dry run. """ # Get the relevant simulation steps mc_steps = settings_validation.convert_steps_per_iteration( @@ -1060,26 +1192,6 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): if self.verbose: self.logger.info("production phase complete") - if self.verbose: - self.logger.info("post-simulation result analysis") - - analyzer = multistate_analysis.MultistateEquilFEAnalysis( - reporter, - sampling_method=settings["simulation_settings"].sampler_method.lower(), - result_units=unit.kilocalorie_per_mole, - ) - analyzer.plot(filepath=self.shared_basepath, filename_prefix="") - analyzer.close() - - return_dict = analyzer.unit_results_dict - - if standard_state_corr is not None: - return_dict["standard_state_correction"] = standard_state_corr.to( - "kilocalorie_per_mole" - ) - - return return_dict - else: # close reporter when you're done, prevent file handle clashes reporter.close() @@ -1092,130 +1204,118 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): for fn in fns: os.remove(fn) - return None - def run( - self, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + self, + *, + system: openmm.System, + positions: openmm.unit.Quantity, + box_vectors: Quantity, + selection_indices: npt.NDArray, + alchemical_restraints: bool, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, ) -> dict[str, Any]: - """Run the absolute free energy calculation. + """ + Run the free energy calculation using a multistate sampler. Parameters ---------- - dry : bool - Do a dry run of the calculation, creating all necessary alchemical - system components (topology, system, sampler, etc...) but without - running the simulation, default False + system : openmm.System + The System to simulate. + positions : openmm.unit.Quantity + The positions of the System. + box_vectors : openff.units.Quantity + The box vectors of the System. + selection_indices : npt.NDArray + Indices of the System particles to write to file. + alchemical_restraints: bool, + Whether or not the system has alchemical restraints. + dry: bool + Do a dry run of the calculation, creating all the necessary + components, but without running the simulation. verbose : bool - Verbose output of the simulation progress. Output is provided via - INFO level logging, default True - scratch_basepath : pathlib.Path - Path to the scratch (temporary) directory space. - shared_basepath : pathlib.Path - Path to the shared (persistent) directory space. + Verbose output of the simulation progress. Output is provided at + the INFO logging level. + scratch_basepath : pathlib.Path | None + Where to store temporary files, defaults to the current working + directory if ``None``. + shared_basepath : pathlib.Path | None + Where to store calculation outputs, defaults to the current working + directory if ``None``. Returns ------- dict - Outputs created in the basepath directory or the debug objects - (i.e. sampler) if ``dry==True``. + Outputs created by the unit, including the debug objects + (i.e. sampler) if ``dry==True`` """ - # 0. Generaly preparation tasks + # Prepare paths & verbosity self._prepare(verbose, scratch_basepath, shared_basepath) - # 1. Get components - alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components() + # Get the settings + settings = self._get_settings() - # 2. Get settings - settings = self._handle_settings() + # Get the components + alchem_comps, solv_comp, prot_comp, small_mols = self._get_components() - # 3. Get OpenMM topology, positions, system, and comp_resids - omm_topology, omm_system, positions, comp_resids = self._get_omm_objects( - settings=settings, - protein_component=prot_comp, - solvent_component=solv_comp, - smc_components=smc_comps, - ) - - # 4. Pre-equilbrate System (Test + Avoid NaNs + get stable system) - positions, box_vectors = self._pre_equilibrate( - omm_system, omm_topology, positions, settings, dry - ) - - # 5. Get lambdas + # Get the lambda schedule lambdas = self._get_lambda_schedule(settings) - # 6. Add restraints - # Note: when no restraint is applied, restrained_omm_system == omm_system - ( - restraint_parameter_state, - standard_state_corr, - restrained_omm_system, - restraint_geometry, - ) = self._add_restraints( - omm_system, - omm_topology, - positions, - alchem_comps, - comp_resids, - settings, + # Get the compute platform + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" + platform = omm_compute.get_openmm_platform( + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, ) - # 7. Get alchemical system - alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( - topology=omm_topology, - system=restrained_omm_system, - comp_resids=comp_resids, - alchem_comps=alchem_comps, - alchemical_settings=settings["alchemical_settings"], - ) - - # 8. Get compound and sampler states + # Get compound and sampler states sampler_states, cmp_states = self._get_states( - alchem_system, - positions, - box_vectors, - settings, - lambdas, - solv_comp, - restraint_parameter_state, + alchemical_system=system, + positions=positions, + # convert the box vectors to vec3 from openff + box_vectors=make_vec3_box(box_vectors), + thermodynamic_settings=settings["thermo_settings"], + lambdas=lambdas, + solvent_component=solv_comp, + alchemically_restrained=alchemical_restraints, ) - # 9. Create the multistate reporter & create PDB - reporter = self._get_reporter( - omm_topology, - positions, - settings["simulation_settings"], - settings["output_settings"], + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=system, ) - # Wrap in try/finally to avoid memory leak issues try: - # 10. Get context caches - energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( - settings["forcefield_settings"], settings["engine_settings"] + # Create or get the multistate reporter + reporter = self._get_reporter( + storage_path=self.shared_basepath, + selection_indices=selection_indices, + simulation_settings=settings["simulation_settings"], + output_settings=settings["output_settings"], ) - # 11. Get integrator - integrator = self._get_integrator( - settings["integrator_settings"], - settings["simulation_settings"], - ) - - # 12. Get sampler + # Get sampler sampler = self._get_sampler( - integrator, - reporter, - settings["simulation_settings"], - settings["thermo_settings"], - cmp_states, - sampler_states, - energy_ctx_cache, - sampler_ctx_cache, + integrator=integrator, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + thermodynamic_settings=settings["thermo_settings"], + compound_states=cmp_states, + sampler_states=sampler_states, + platform=platform, ) - # 13. Run simulation - unit_result_dict = self._run_simulation( - sampler, reporter, settings, standard_state_corr, dry + # Run simulation + self._run_simulation( + sampler=sampler, + reporter=reporter, + settings=settings, + dry=dry, ) finally: @@ -1224,15 +1324,15 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # clear GPU context # Note: use cache.empty() when openmmtools #690 is resolved - for context in list(energy_ctx_cache._lru._data.keys()): - del energy_ctx_cache._lru._data[context] - for context in list(sampler_ctx_cache._lru._data.keys()): - del sampler_ctx_cache._lru._data[context] + for context in list(sampler.energy_context_cache._lru._data.keys()): + del sampler.energy_context_cache._lru._data[context] + for context in list(sampler.sampler_context_cache._lru._data.keys()): + del sampler.sampler_context_cache._lru._data[context] # cautiously clear out the global context cache too for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] - del sampler_ctx_cache, energy_ctx_cache + del sampler.sampler_context_cache, sampler.energy_context_cache # Keep these around in a dry run so we can inspect things if not dry: @@ -1240,37 +1340,45 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): if not dry: nc = self.shared_basepath / settings["output_settings"].output_filename - chk = settings["output_settings"].checkpoint_storage_filename - unit_result_dict["nc"] = nc - unit_result_dict["last_checkpoint"] = chk - unit_result_dict["selection_indices"] = self.selection_indices - - if restraint_geometry is not None: - unit_result_dict["restraint_geometry"] = restraint_geometry.model_dump() - - return unit_result_dict + chk = self.shared_basepath / settings["output_settings"].checkpoint_storage_filename + return { + "trajectory": nc, + "checkpoint": chk, + } else: return { - # Add in various objects we can used to test the system - "debug": { - "sampler": sampler, - "system": omm_system, - "restrained_system": restrained_omm_system, - "alchem_system": alchem_system, - "alchem_indices": alchem_indices, - "alchem_factory": alchem_factory, - "positions": positions, - } + "sampler": sampler, + "integrator": integrator, } def _execute( self, ctx: gufe.Context, - **kwargs, + *, + setup_results, + **inputs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) + system = deserialize(setup_results.outputs["system"]) + positions = to_openmm(np.load(setup_results.outputs["positions"]) * offunit.nanometer) + selection_indices = setup_results.outputs["selection_indices"] + box_vectors = setup_results.outputs["box_vectors"] + + if setup_results.outputs["restraint_geometry"] is not None: + alchemical_restraints = True + else: + alchemical_restraints = False + + outputs = self.run( + system=system, + positions=positions, + box_vectors=box_vectors, + selection_indices=selection_indices, + alchemical_restraints=alchemical_restraints, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) return { "repeat_id": self._inputs["repeat_id"], @@ -1278,3 +1386,147 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): "simtype": self.simtype, **outputs, } + + +class BaseAbsoluteMultiStateAnalysisUnit(gufe.ProtocolUnit, AbsoluteUnitMixin): + @staticmethod + def _analyze_multistate_energies( + trajectory: pathlib.Path, + checkpoint: pathlib.Path, + sampler_method: str, + output_directory: pathlib.Path, + dry: bool, + ): + """ + Analyze multistate energies and generate plots. + + Parameters + ---------- + trajectory : pathlib.Path + Path to the NetCDF trajectory file. + checkpoint : pathlib.Path + The name of the checkpoint file. Note this is + relative in path to the trajectory file. + sampler_method : str + The multistate sampler method used. + output_directory : pathlib.Path + The path to where plots will be written. + dry : bool + Whether or not we are running a dry run. + """ + reporter = multistate.MultiStateReporter( + storage=trajectory, + # Note: openmmtools only wants the name of the checkpoint + # file, it assumes it to be in the same place as the trajectory + checkpoint_storage=checkpoint.name, + open_mode="r", + ) + + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter=reporter, + sampling_method=sampler_method, + result_units=offunit.kilocalorie_per_mole, + ) + + # Only create plots when not doing a dry run + if not dry: + analyzer.plot(filepath=output_directory, filename_prefix="") + + analyzer.close() + reporter.close() + return analyzer.unit_results_dict + + def run( + self, + *, + trajectory: pathlib.Path, + checkpoint: pathlib.Path, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, + ) -> dict[str, Any]: + """Analyze the multistate simulation. + + Parameters + ---------- + trajectory : pathlib.Path + Path to the MultiStateReporter generated NetCDF file. + checkpoint : pathlib.Path + Path to the checkpoint file generated by MultiStateReporter. + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: pathlib.Path | None + Where to store temporary files, defaults to current working directory + shared_basepath : pathlib.Path | None + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get the settings + settings = self._get_settings() + + # Energies analysis + if verbose: + self.logger.info("Analyzing energies") + + energy_analysis = self._analyze_multistate_energies( + trajectory=trajectory, + checkpoint=checkpoint, + sampler_method=settings["simulation_settings"].sampler_method.lower(), + output_directory=self.shared_basepath, + dry=dry, + ) + + return energy_analysis + + def _execute( + self, + ctx: gufe.Context, + *, + setup_results, + simulation_results, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + pdb_file = setup_results.outputs["pdb_structure"] + selection_indices = setup_results.outputs["selection_indices"] + restraint_geometry = setup_results.outputs["restraint_geometry"] + standard_state_corr = setup_results.outputs["standard_state_correction"] + trajectory = simulation_results.outputs["trajectory"] + checkpoint = simulation_results.outputs["checkpoint"] + + outputs = self.run( + trajectory=trajectory, + checkpoint=checkpoint, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + "simtype": self.simtype, + # We re-include things here also to make + # life easier when gathering results. + "pdb_structure": pdb_file, + "trajectory": trajectory, + "checkpoint": checkpoint, + "selection_indices": selection_indices, + "restraint_geometry": restraint_geometry, + "standard_state_correction": standard_state_corr, + **outputs, + } diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/src/openfe/protocols/openmm_afe/equil_afe_settings.py similarity index 100% rename from openfe/protocols/openmm_afe/equil_afe_settings.py rename to src/openfe/protocols/openmm_afe/equil_afe_settings.py diff --git a/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py new file mode 100644 index 00000000..28054a14 --- /dev/null +++ b/src/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -0,0 +1,514 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Binding AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_binding_afe_method` +========================================================================================================== + +This module implements the necessary methodology tooling to calculate an +absolute binding free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Alchemical species with a net charge are not currently supported. +* Disapearing molecules are only allowed in state A. +* Only small molecules are allowed to act as alchemical molecules. + +Acknowledgements +---------------- +* This Protocol re-implements components from + `Yank `_. + +""" + +import logging +import uuid +import warnings +from collections import defaultdict +from typing import Any, Iterable + +import gufe +from gufe import ( + ChemicalSystem, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, +) +from openff.units import unit as offunit + +from openfe.due import Doi, due +from openfe.protocols.openmm_afe.equil_afe_settings import ( + ABFEPreEquilOutputSettings, + AbsoluteBindingSettings, + AlchemicalSettings, + BoreschRestraintSettings, + IntegratorSettings, + LambdaSettings, + MDSimulationSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, +) +from openfe.protocols.openmm_utils import ( + settings_validation, + system_validation, +) + +from .abfe_units import ( + ABFEComplexAnalysisUnit, + ABFEComplexSetupUnit, + ABFEComplexSimUnit, + ABFESolventAnalysisUnit, + ABFESolventSetupUnit, + ABFESolventSimUnit, +) +from .afe_protocol_results import AbsoluteBindingProtocolResult + +due.cite( + Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True, +) + +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True, +) + +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_binding_afe_method", + cite_module=True, +) + + +logger = logging.getLogger(__name__) + + +class AbsoluteBindingProtocol(gufe.Protocol): + """ + Absolute binding free energy calculations using OpenMM and OpenMMTools. + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSettings` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingProtocolResult` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingSolventUnit` + :class:`openfe.protocols.openmm_afe.AbsoluteBindingComplexUnit` + """ + + result_cls = AbsoluteBindingProtocolResult + _settings_cls = AbsoluteBindingSettings + _settings: AbsoluteBindingSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + # fmt: off + return AbsoluteBindingSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, + ), + alchemical_settings=AlchemicalSettings(), + solvent_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.25, 0.5, 0.75, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + ], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0 + ], + lambda_restraints=[ + 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 + ], + ), + complex_lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0 + ], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0 + ], + lambda_restraints=[ + 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0 + ], + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + complex_solvation_settings=OpenMMSolvationSettings( + solvent_padding=1.0 * offunit.nanometer, + ), + solvent_solvation_settings=OpenMMSolvationSettings(), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + restraint_settings=BoreschRestraintSettings(), + solvent_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.1 * offunit.nanosecond, + equilibration_length=0.2 * offunit.nanosecond, + production_length=0.5 * offunit.nanosecond, + ), + solvent_equil_output_settings=ABFEPreEquilOutputSettings(), + solvent_simulation_settings=MultiStateSimulationSettings( + n_replicas=14, + equilibration_length=1.0 * offunit.nanosecond, + production_length=10.0 * offunit.nanosecond, + ), + solvent_output_settings=MultiStateOutputSettings( + output_structure="alchemical_system.pdb", + output_filename="solvent.nc", + checkpoint_storage_filename="solvent_checkpoint.nc", + ), + complex_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.25 * offunit.nanosecond, + equilibration_length=0.5 * offunit.nanosecond, + production_length=5.0 * offunit.nanosecond, + ), + complex_equil_output_settings=ABFEPreEquilOutputSettings(), + complex_simulation_settings=MultiStateSimulationSettings( + n_replicas=30, + equilibration_length=1 * offunit.nanosecond, + production_length=10.0 * offunit.nanosecond, + ), + complex_output_settings=MultiStateOutputSettings( + output_structure="alchemical_system.pdb", + output_filename="complex.nc", + checkpoint_storage_filename="complex_checkpoint.nc", + ), + ) + # fmt: on + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + A binding transformation is defined (in terms of gufe components) + as starting from one or more ligands with one protein and solvent, + that then ends up in a state with one less ligand. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A + stateB : ChemicalSystem + The chemical system of end state B + + Raises + ------ + ValueError + If stateA & stateB do not contain a ProteinComponent. + If stateA & stateB do not contain a SolventComponent. + If stateA has more than one unique Component. + If the stateA unique Component is not a SmallMoleculeComponent. + If stateB contains any unique Components. + If the alchemical species is charged. + """ + if not (stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent)): + errmsg = "No ProteinComponent found" + raise ValueError(errmsg) + + if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)): + errmsg = "No SolventComponent found" + raise ValueError(errmsg) + + # Needs gufe 1.3 + diff = stateA.component_diff(stateB) + if len(diff[0]) != 1: + errmsg = ( + "Only one alchemical species is supported. " + f"Number of unique components found in stateA: {len(diff[0])}." + ) + raise ValueError(errmsg) + + if not isinstance(diff[0][0], SmallMoleculeComponent): + errmsg = ( + "Only dissapearing small molecule components " + "are supported by this protocol. " + f"Found a {type(diff[0][0])}" + ) + raise ValueError(errmsg) + + # Check that the state A unique isn't charged + if diff[0][0].total_charge != 0: + errmsg = ( + "Charged alchemical molecules are not currently " + "supported for solvation free energies. " + f"Molecule total charge: {diff[0][0].total_charge}." + ) + raise ValueError(errmsg) + + # If there are any alchemical Components in state B + if len(diff[1]) > 0: + errmsg = "Components appearing in state B are not currently supported" + raise ValueError(errmsg) + + @staticmethod + def _validate_lambda_schedule( + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> None: + """ + Checks that the lambda schedule is set up correctly. + + Parameters + ---------- + lambda_settings : LambdaSettings + the lambda schedule Settings + simulation_settings : MultiStateSimulationSettings + the settings for either the complex or solvent phase + + Raises + ------ + ValueError + If the number of lambda windows differs for electrostatics, sterics, + and restraints. + If the number of replicas does not match the number of lambda windows. + If there are states with naked charges. + """ + + lambda_elec = lambda_settings.lambda_elec + lambda_vdw = lambda_settings.lambda_vdw + lambda_restraints = lambda_settings.lambda_restraints + n_replicas = simulation_settings.n_replicas + + # Ensure that all lambda components have equal amount of windows + lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] + it = iter(lambda_components) + the_len = len(next(it)) + if not all(len(lambda_comp) == the_len for lambda_comp in it): + errmsg = ( + "Components elec, vdw, and restraints must have equal amount" + f" of lambda windows. Got {len(lambda_elec)} elec lambda" + f" windows, {len(lambda_vdw)} vdw lambda windows, and" + f"{len(lambda_restraints)} restraints lambda windows." + ) + raise ValueError(errmsg) + + # Ensure that number of overall lambda windows matches number of lambda + # windows for individual components + if n_replicas != len(lambda_vdw): + errmsg = ( + f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(lambda_vdw)}" + ) + raise ValueError(errmsg) + + # Check if there are no lambda windows with naked charges + for inx, lam in enumerate(lambda_elec): + if lam < 1 and lambda_vdw[inx] == 1: + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: lambda {inx}: " + f"elec {lam} vdW {lambda_vdw[inx]}" + ) + raise ValueError(errmsg) + + def _validate( + self, + *, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, + ): + # Check we're not extending + if extends is not None: + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") + + # Check we're not using a mapping, since we're not doing anything with it + if mapping is not None: + wmsg = "A mapping was passed but is not used by this Protocol." + warnings.warn(wmsg) + + # Validate the end states & alchemical components + self._validate_endstates(stateA, stateB) + + # Validate the complex lambda schedule + self._validate_lambda_schedule( + self.settings.complex_lambda_settings, + self.settings.complex_simulation_settings, + ) + + # If the complex restraints schedule is all zero, it might be bad + # but we don't dissallow it. + if all([i == 0.0 for i in self.settings.complex_lambda_settings.lambda_restraints]): + wmsg = ( + "No restraints are being applied in the complex phase, " + "this will likely lead to problematic results." + ) + warnings.warn(wmsg) + + # Validate the solvent lambda schedule + self._validate_lambda_schedule( + self.settings.solvent_lambda_settings, + self.settings.solvent_simulation_settings, + ) + + # If the solvent restraints schedule is not all one, it was likely + # copied from the complex schedule. In this case we just ignore + # the values and let the user know. + # P.S. we don't need to change the settings at this point + # the list gets popped out later in the SolventUnit, because we + # don't have a restraint parameter state. + + if any([i != 0.0 for i in self.settings.solvent_lambda_settings.lambda_restraints]): + wmsg = ( + "There is an attempt to add restraints in the solvent " + "phase. This protocol does not apply restraints in the " + "solvent phase. These restraint lambda values will be ignored." + ) + warnings.warn(wmsg) + + # Check nonbond & solvent compatibility + nonbonded_method = self.settings.forcefield_settings.nonbonded_method + # Use the more complete system validation solvent checks + system_validation.validate_solvent(stateA, nonbonded_method) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings( + self.settings.solvent_solvation_settings + ) + settings_validation.validate_openmm_solvation_settings( + self.settings.complex_solvation_settings + ) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> list[gufe.ProtocolUnit]: + # Validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # Get the alchemical components + alchem_comps = system_validation.get_alchemical_components( + stateA, + stateB, + ) + + # Get the name of the alchemical species + alchname = alchem_comps["stateA"][0].name + unit_classes = { + "solvent": { + "setup": ABFESolventSetupUnit, + "simulation": ABFESolventSimUnit, + "analysis": ABFESolventAnalysisUnit, + }, + "complex": { + "setup": ABFEComplexSetupUnit, + "simulation": ABFEComplexSimUnit, + "analysis": ABFEComplexAnalysisUnit, + }, + } + + protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "complex": []} + + for phase in ["solvent", "complex"]: + for i in range(self.settings.protocol_repeats): + repeat_id = int(uuid.uuid4()) + + setup = unit_classes[phase]["setup"]( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, + repeat_id=repeat_id, + name=f"ABFE Setup: {alchname} {phase} leg: repeat {i} generation 0", + ) + + simulation = unit_classes[phase]["simulation"]( + protocol=self, + # only need state A & alchem comps + stateA=stateA, + alchemical_components=alchem_comps, + setup_results=setup, + generation=0, + repeat_id=repeat_id, + name=f"ABFE Simulation: {alchname} {phase} leg: repeat {i} generation 0", + ) + + analysis = unit_classes[phase]["analysis"]( + protocol=self, + setup_results=setup, + simulation_results=simulation, + generation=0, + repeat_id=repeat_id, + name=f"ABFE Analysis: {alchname} {phase} leg, repeat {i} generation 0", + ) + + protocol_units[phase] += [setup, simulation, analysis] + + return protocol_units["solvent"] + protocol_units["complex"] + + def _gather( + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + ) -> dict[str, dict[str, Any]]: + # result units will have a repeat_id and generation + # first group according to repeat_id + unsorted_solvent_repeats = defaultdict(list) + unsorted_complex_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if ("Analysis" not in pu.name) or (not pu.ok()): + continue + if pu.outputs["simtype"] == "solvent": + unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) + else: + unsorted_complex_repeats[pu.outputs["repeat_id"]].append(pu) + + repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { + "solvent": {}, + "complex": {}, + } + for k, v in unsorted_solvent_repeats.items(): + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + + for k, v in unsorted_complex_repeats.items(): + repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + return repeats diff --git a/src/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/src/openfe/protocols/openmm_afe/equil_solvation_afe_method.py new file mode 100644 index 00000000..72d129f4 --- /dev/null +++ b/src/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -0,0 +1,531 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method` +=============================================================================================================== + +This module implements the necessary methodology tooling to run calculate an +absolute solvation free energy using OpenMM tools and one of the following +alchemical sampling methods: + +* Hamiltonian Replica Exchange +* Self-adjusted mixture sampling +* Independent window sampling + +Current limitations +------------------- +* Alchemical species with a net charge are not currently supported. +* Disapearing molecules are only allowed in state A. Support for + appearing molecules will be added in due course. +* Only small molecules are allowed to act as alchemical molecules. + Alchemically changing protein or solvent components would induce + perturbations which are too large to be handled by this Protocol. + + +Acknowledgements +---------------- +* Originally based on hydration.py in + `espaloma_charge `_ + +""" + +import logging +import uuid +import warnings +from collections import defaultdict +from typing import Any, Iterable, Optional, Union + +import gufe +import numpy as np +from gufe import ( + ChemicalSystem, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, +) +from openff.units import unit as offunit + +from openfe.due import Doi, due +from openfe.protocols.openmm_afe.equil_afe_settings import ( + AbsoluteSolvationSettings, + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MDOutputSettings, + MDSimulationSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, +) + +from ..openmm_utils import settings_validation, system_validation +from .afe_protocol_results import AbsoluteSolvationProtocolResult +from .ahfe_units import ( + AHFESolventAnalysisUnit, + AHFESolventSetupUnit, + AHFESolventSimUnit, + AHFEVacuumAnalysisUnit, + AHFEVacuumSetupUnit, + AHFEVacuumSimUnit, +) + +due.cite( + Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) + +due.cite( + Doi("10.48550/ARXIV.2302.06758"), + description="EspalomaCharge", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) + +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) + +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) + + +logger = logging.getLogger(__name__) + + +class AbsoluteSolvationProtocol(gufe.Protocol): + """ + Absolute solvation free energy calculations using OpenMM and OpenMMTools. + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_afe.AbsoluteSolvationSettings` + :class:`openfe.protocols.openmm_afe.AbsoluteSolvationProtocolResult` + :class:`openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit` + :class:`openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit` + """ + + result_cls = AbsoluteSolvationProtocolResult + _settings_cls = AbsoluteSolvationSettings + _settings: AbsoluteSolvationSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + return AbsoluteSolvationSettings( + protocol_repeats=3, + solvent_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + vacuum_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings( + nonbonded_method="nocutoff", + ), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, + ), + alchemical_settings=AlchemicalSettings(), + lambda_settings=LambdaSettings( + lambda_elec=[ + 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + lambda_vdw=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, + 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0], + lambda_restraints=[ + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, + 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + vacuum_engine_settings=OpenMMEngineSettings(), + solvent_engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + solvent_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=0.1 * offunit.nanosecond, + equilibration_length=0.2 * offunit.nanosecond, + production_length=0.5 * offunit.nanosecond, + ), + solvent_equil_output_settings=MDOutputSettings( + equil_nvt_structure="equil_nvt_structure.pdb", + equil_npt_structure="equil_npt_structure.pdb", + production_trajectory_filename="production_equil.xtc", + log_output="equil_simulation.log", + ), + solvent_simulation_settings=MultiStateSimulationSettings( + n_replicas=14, + equilibration_length=1.0 * offunit.nanosecond, + production_length=10.0 * offunit.nanosecond, + ), + solvent_output_settings=MultiStateOutputSettings( + output_filename="solvent.nc", + checkpoint_storage_filename="solvent_checkpoint.nc", + ), + vacuum_equil_simulation_settings=MDSimulationSettings( + equilibration_length_nvt=None, + equilibration_length=0.2 * offunit.nanosecond, + production_length=0.5 * offunit.nanosecond, + ), + vacuum_equil_output_settings=MDOutputSettings( + equil_nvt_structure=None, + equil_npt_structure="equil_structure.pdb", + production_trajectory_filename="production_equil.xtc", + log_output="equil_simulation.log", + ), + vacuum_simulation_settings=MultiStateSimulationSettings( + n_replicas=14, + equilibration_length=0.5 * offunit.nanosecond, + production_length=2.0 * offunit.nanosecond, + ), + vacuum_output_settings=MultiStateOutputSettings( + output_filename="vacuum.nc", + checkpoint_storage_filename="vacuum_checkpoint.nc", + ), + ) # fmt: skip + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + A solvent transformation is defined (in terms of gufe components) + as starting from one or more ligands in solvent and + ending up in a state with one less ligand. + + No protein components are allowed. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A + stateB : ChemicalSystem + The chemical system of end state B + + Raises + ------ + ValueError + If stateA or stateB contains a ProteinComponent. + If there is no SolventComponent in either stateA or stateB. + If there are alchemical components in state B. + If there are non SmallMoleculeComponent alchemical species. + If there are more than one alchemical species. + If the alchemical species is charged. + + Notes + ----- + * Currently doesn't support alchemical components in state B. + * Currently doesn't support alchemical components which are not + SmallMoleculeComponents. + * Currently doesn't support more than one alchemical component + being desolvated. + * Currently doesn't support charged alchemical components. + * Solvent must always be present in both end states. + """ + # Check that there are no protein components + if stateA.contains(ProteinComponent) or stateB.contains(ProteinComponent): + errmsg = "Protein components are not allowed for absolute solvation free energies." + raise ValueError(errmsg) + + # Check that there is a solvent component in both end states + if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)): + errmsg = "No SolventComponent found in stateA and/or stateB" + raise ValueError(errmsg) + + # Now we check the alchemical Components + diff = stateA.component_diff(stateB) + + # Check that there's only one state A unique Component + if len(diff[0]) != 1: + errmsg = ( + "Only one alchemical species is supported " + "for absolute solvation free energies. " + f"Number of unique components found in stateA: {len(diff[0])}." + ) + raise ValueError(errmsg) + + # Make sure that the state A unique is an SMC + if not isinstance(diff[0][0], SmallMoleculeComponent): + errmsg = ( + "Only dissapearing SmallMoleculeComponents " + "are supported by this protocol. " + f"Found a {type(diff[0][0])}" + ) + raise ValueError(errmsg) + + # Check that the state A unique isn't charged + if diff[0][0].total_charge != 0: + errmsg = ( + "Charged alchemical molecules are not currently " + "supported for solvation free energies. " + f"Molecule total charge: {diff[0][0].total_charge}." + ) + raise ValueError(errmsg) + + # If there are any alchemical Components in state B + if len(diff[1]) > 0: + errmsg = "Components appearing in state B are not currently supported" + raise ValueError(errmsg) + + @staticmethod + def _validate_lambda_schedule( + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> None: + """ + Checks that the lambda schedule is set up correctly. + + Parameters + ---------- + lambda_settings : LambdaSettings + the lambda schedule Settings + simulation_settings : MultiStateSimulationSettings + the settings for either the vacuum or solvent phase + + Raises + ------ + ValueError + If the number of lambda windows differs for electrostatics and sterics. + If the number of replicas does not match the number of lambda windows. + If there are states with naked charges. + Warnings + If there are non-zero values for restraints (lambda_restraints). + """ + + lambda_elec = lambda_settings.lambda_elec + lambda_vdw = lambda_settings.lambda_vdw + lambda_restraints = lambda_settings.lambda_restraints + n_replicas = simulation_settings.n_replicas + + # Ensure that all lambda components have equal amount of windows + lambda_components = [lambda_vdw, lambda_elec, lambda_restraints] + it = iter(lambda_components) + the_len = len(next(it)) + if not all(len(lambda_comp) == the_len for lambda_comp in it): + errmsg = ( + "Components elec, vdw, and restraints must have equal amount" + f" of lambda windows. Got {len(lambda_elec)} elec lambda" + f" windows, {len(lambda_vdw)} vdw lambda windows, and" + f"{len(lambda_restraints)} restraints lambda windows." + ) + raise ValueError(errmsg) + + # Ensure that number of overall lambda windows matches number of lambda + # windows for individual components + if n_replicas != len(lambda_vdw): + errmsg = ( + f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(lambda_vdw)}" + ) + raise ValueError(errmsg) + + # Check if there are lambda windows with naked charges + for inx, lam in enumerate(lambda_elec): + if lam < 1 and lambda_vdw[inx] == 1: + errmsg = ( + "There are states along this lambda schedule " + "where there are atoms with charges but no LJ " + f"interactions: lambda {inx}: " + f"elec {lam} vdW {lambda_vdw[inx]}" + ) + raise ValueError(errmsg) + + # Check if there are lambda windows with non-zero restraints + if len([r for r in lambda_restraints if r != 0]) > 0: + wmsg = ( + "Non-zero restraint lambdas applied. The absolute " + "solvation protocol doesn't apply restraints, " + "therefore restraints won't be applied. " + f"Given lambda_restraints: {lambda_restraints}" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + def _validate( + self, + *, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, + ): + # Check we're not extending + if extends is not None: + # This should be a NotImplementedError, but the underlying + # `validate` method wraps a call to `_validate` around a + # NotImplementedError exception guard + raise ValueError("Can't extend simulations yet") + + # Check we're not using a mapping, since we're not doing anything with it + if mapping is not None: + wmsg = "A mapping was passed but is not used by this Protocol." + warnings.warn(wmsg) + + # Validate the endstates & alchemical components + self._validate_endstates(stateA, stateB) + + # Validate the lambda schedule + for solv_sets in ( + self.settings.solvent_simulation_settings, + self.settings.vacuum_simulation_settings, + ): + self._validate_lambda_schedule( + self.settings.lambda_settings, + solv_sets, + ) + + # Check nonbond & solvent compatibility + solv_nonbonded_method = self.settings.solvent_forcefield_settings.nonbonded_method + vac_nonbonded_method = self.settings.vacuum_forcefield_settings.nonbonded_method + + # Use the more complete system validation solvent checks + system_validation.validate_solvent(stateA, solv_nonbonded_method) + + # Gas phase is always gas phase + if vac_nonbonded_method.lower() != "nocutoff": + errmsg = ( + "Only the nocutoff nonbonded_method is supported for " + f"vacuum calculations, {vac_nonbonded_method} was " + "passed" + ) + raise ValueError(errmsg) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) + + # Check vacuum equilibration MD settings is 0 ns + nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt + if nvt_time is not None: + if not np.allclose(nvt_time, 0 * offunit.nanosecond): + errmsg = "NVT equilibration cannot be run in vacuum simulation" + raise ValueError(errmsg) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.vacuum_forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + settings_validation.validate_timestep( + self.settings.solvent_forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # Validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # Get the alchemical components + alchem_comps = system_validation.get_alchemical_components( + stateA, + stateB, + ) + + # Get the name of the alchemical species + alchname = alchem_comps["stateA"][0].name + + unit_classes = { + "solvent": { + "setup": AHFESolventSetupUnit, + "simulation": AHFESolventSimUnit, + "analysis": AHFESolventAnalysisUnit, + }, + "vacuum": { + "setup": AHFEVacuumSetupUnit, + "simulation": AHFEVacuumSimUnit, + "analysis": AHFEVacuumAnalysisUnit, + }, + } + + protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "vacuum": []} + + for phase in ["solvent", "vacuum"]: + for i in range(self.settings.protocol_repeats): + repeat_id = int(uuid.uuid4()) + + setup = unit_classes[phase]["setup"]( + protocol=self, + stateA=stateA, + stateB=stateB, + alchemical_components=alchem_comps, + generation=0, + repeat_id=repeat_id, + name=f"AHFE Setup: {alchname} {phase} leg: repeat {i} generation 0", + ) + + simulation = unit_classes[phase]["simulation"]( + protocol=self, + # only need state A & alchem comps + stateA=stateA, + alchemical_components=alchem_comps, + setup_results=setup, + generation=0, + repeat_id=repeat_id, + name=f"AHFE Simulation: {alchname} {phase} leg: repeat {i} generation 0", + ) + + analysis = unit_classes[phase]["analysis"]( + protocol=self, + setup_results=setup, + simulation_results=simulation, + generation=0, + repeat_id=repeat_id, + name=f"AHFE Analysis: {alchname} {phase} leg, repeat {i} generation 0", + ) + + protocol_units[phase] += [setup, simulation, analysis] + + return protocol_units["solvent"] + protocol_units["vacuum"] + + def _gather( + self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] + ) -> dict[str, dict[str, Any]]: + # result units will have a repeat_id and generation + # first group according to repeat_id + unsorted_solvent_repeats = defaultdict(list) + unsorted_vacuum_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + if ("Analysis" not in pu.name) or (not pu.ok()): + continue + if pu.outputs["simtype"] == "solvent": + unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) + else: + unsorted_vacuum_repeats[pu.outputs["repeat_id"]].append(pu) + + repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { + "solvent": {}, + "vacuum": {}, + } + for k, v in unsorted_solvent_repeats.items(): + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + + for k, v in unsorted_vacuum_repeats.items(): + repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + return repeats diff --git a/openfe/protocols/openmm_md/__init__.py b/src/openfe/protocols/openmm_md/__init__.py similarity index 100% rename from openfe/protocols/openmm_md/__init__.py rename to src/openfe/protocols/openmm_md/__init__.py diff --git a/openfe/protocols/openmm_md/plain_md_methods.py b/src/openfe/protocols/openmm_md/plain_md_methods.py similarity index 100% rename from openfe/protocols/openmm_md/plain_md_methods.py rename to src/openfe/protocols/openmm_md/plain_md_methods.py diff --git a/openfe/protocols/openmm_md/plain_md_settings.py b/src/openfe/protocols/openmm_md/plain_md_settings.py similarity index 100% rename from openfe/protocols/openmm_md/plain_md_settings.py rename to src/openfe/protocols/openmm_md/plain_md_settings.py diff --git a/src/openfe/protocols/openmm_rfe/__init__.py b/src/openfe/protocols/openmm_rfe/__init__.py new file mode 100644 index 00000000..f0fe367c --- /dev/null +++ b/src/openfe/protocols/openmm_rfe/__init__.py @@ -0,0 +1,12 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +from . import _rfe_utils +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocols import RelativeHybridTopologyProtocol +from .hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, +) diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py similarity index 100% rename from openfe/protocols/openmm_rfe/_rfe_utils/__init__.py rename to src/openfe/protocols/openmm_rfe/_rfe_utils/__init__.py diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py similarity index 100% rename from openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py rename to src/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py similarity index 89% rename from openfe/protocols/openmm_rfe/_rfe_utils/multistate.py rename to src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index 8c6b4edd..299a846f 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/src/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -26,14 +26,15 @@ from .lambdaprotocol import RelativeAlchemicalState logger = logging.getLogger(__name__) -class HybridCompatibilityMixin(object): +class HybridCompatibilityMixin: """ Mixin that allows the MultistateSampler to accommodate the situation where unsampled endpoints have a different number of degrees of freedom. """ - def __init__(self, *args, hybrid_factory=None, **kwargs): - self._hybrid_factory = hybrid_factory + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): + self._hybrid_system = hybrid_system + self._hybrid_positions = hybrid_positions super(HybridCompatibilityMixin, self).__init__(*args, **kwargs) def setup(self, reporter, lambda_protocol, @@ -73,15 +74,17 @@ class HybridCompatibilityMixin(object): """ n_states = len(lambda_protocol.lambda_schedule) - hybrid_system = self._factory.hybrid_system + lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system) - lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system) + thermostate = ThermodynamicState( + self._hybrid_system, + temperature=temperature + ) - thermostate = ThermodynamicState(hybrid_system, - temperature=temperature) compound_thermostate = CompoundThermodynamicState( - thermostate, - composable_states=[lambda_zero_state]) + thermostate, + composable_states=[lambda_zero_state] + ) # create lists for storing thermostates and sampler states thermodynamic_state_list = [] @@ -105,16 +108,20 @@ class HybridCompatibilityMixin(object): raise ValueError(errmsg) # starting with the hybrid factory positions - box = hybrid_system.getDefaultPeriodicBoxVectors() - sampler_state = SamplerState(self._factory.hybrid_positions, - box_vectors=box) + box = self._hybrid_system.getDefaultPeriodicBoxVectors() + sampler_state = SamplerState( + self._hybrid_positions, + box_vectors=box + ) # Loop over the lambdas and create & store a compound thermostate at # that lambda value for lambda_val in lambda_schedule: compound_thermostate_copy = copy.deepcopy(compound_thermostate) compound_thermostate_copy.set_alchemical_parameters( - lambda_val, lambda_protocol) + lambda_val, + lambda_protocol + ) thermodynamic_state_list.append(compound_thermostate_copy) # now generating a sampler_state for each thermodyanmic state, @@ -143,7 +150,8 @@ class HybridCompatibilityMixin(object): # generating unsampled endstates unsampled_dispersion_endstates = create_endstates( copy.deepcopy(thermodynamic_state_list[0]), - copy.deepcopy(thermodynamic_state_list[-1])) + copy.deepcopy(thermodynamic_state_list[-1]) + ) self.create(thermodynamic_states=thermodynamic_state_list, sampler_states=sampler_state_list, storage=reporter, unsampled_thermodynamic_states=unsampled_dispersion_endstates) @@ -159,10 +167,13 @@ class HybridRepexSampler(HybridCompatibilityMixin, number of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridRepexSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs) - self._factory = hybrid_factory + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs + ) class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): @@ -171,11 +182,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler): of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridSAMSSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory class HybridMultiStateSampler(HybridCompatibilityMixin, @@ -184,11 +197,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin, MultiStateSampler that supports unsample end states with a different number of positions """ - def __init__(self, *args, hybrid_factory=None, **kwargs): + def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs): super(HybridMultiStateSampler, self).__init__( - *args, hybrid_factory=hybrid_factory, **kwargs + *args, + hybrid_system=hybrid_system, + hybrid_positions=hybrid_positions, + **kwargs ) - self._factory = hybrid_factory def create_endstates(first_thermostate, last_thermostate): diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/relative.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py similarity index 100% rename from openfe/protocols/openmm_rfe/_rfe_utils/relative.py rename to src/openfe/protocols/openmm_rfe/_rfe_utils/relative.py diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py b/src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py similarity index 100% rename from openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py rename to src/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py diff --git a/src/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/src/openfe/protocols/openmm_rfe/equil_rfe_methods.py new file mode 100644 index 00000000..2b0d37e1 --- /dev/null +++ b/src/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -0,0 +1,26 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +"""Equilibrium Relative Free Energy Protocol using OpenMM and OpenMMTools in a +Perses-like manner. + +This module implements the necessary tooling to calculate the +relative free energy of a ligand transformation using OpenMM tools and one of +the following methods: + - Hamiltonian Replica Exchange + - Self-adjusted mixture sampling + - Independent window sampling + +Acknowledgements +---------------- +This Protocol is based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult +from .hybridtop_protocols import RelativeHybridTopologyProtocol +from .hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, +) diff --git a/openfe/protocols/openmm_rfe/equil_rfe_settings.py b/src/openfe/protocols/openmm_rfe/equil_rfe_settings.py similarity index 100% rename from openfe/protocols/openmm_rfe/equil_rfe_settings.py rename to src/openfe/protocols/openmm_rfe/equil_rfe_settings.py diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py b/src/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py new file mode 100644 index 00000000..c4da8e2a --- /dev/null +++ b/src/openfe/protocols/openmm_rfe/hybridtop_protocol_results.py @@ -0,0 +1,241 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +ProtocolUnitResults for Hybrid Topology methods using +OpenMM and OpenMMTools in a Perses-like manner. +""" + +import logging +import pathlib +import warnings +from typing import Optional, Union + +import gufe +import numpy as np +import numpy.typing as npt +from openff.units import Quantity +from openmmtools import multistate + +logger = logging.getLogger(__name__) + + +class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): + """ + Protocol results with the output of a RelativeHybridTopologyProtocol. + """ + + def __init__(self, **data): + super().__init__(**data) + # data is mapping of str(repeat_id): list[protocolunitresults] + # TODO: Detect when we have extensions and stitch these together? + if any(len(pur_list) > 2 for pur_list in self.data.values()): + raise NotImplementedError("Can't stitch together results yet") + + @staticmethod + def compute_mean_estimate(dGs: list[Quantity]) -> Quantity: + u = dGs[0].u + # convert all values to units of the first value, then take average of magnitude + # this would avoid an edge case where each value was in different units + vals = np.asarray([dG.to(u).m for dG in dGs]) + + return np.average(vals) * u + + def get_estimate(self) -> Quantity: + """Average free energy difference of this transformation + + Returns + ------- + dG : openff.units.Quantity + The free energy difference between the first and last states. This is + a Quantity defined with units. + """ + # TODO: Check this holds up completely for SAMS. + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] + return self.compute_mean_estimate(dGs) + + @staticmethod + def compute_uncertainty(dGs: list[Quantity]) -> Quantity: + u = dGs[0].u + # convert all values to units of the first value, then take average of magnitude + # this would avoid a screwy case where each value was in different units + vals = np.asarray([dG.to(u).m for dG in dGs]) + + return np.std(vals) * u + + def get_uncertainty(self) -> Quantity: + """The uncertainty/error in the dG value: The std of the estimates of + each independent repeat + """ + + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] + return self.compute_uncertainty(dGs) + + def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]: + """Return a list of tuples containing the individual free energy + estimates and associated MBAR errors for each repeat. + + Returns + ------- + dGs : list[tuple[openff.units.Quantity]] + n_replicate simulation list of tuples containing the free energy + estimates (first entry) and associated MBAR estimate errors + (second entry). + """ + dGs = [ + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + for pus in self.data.values() + ] + return dGs + + def get_forward_and_reverse_energy_analysis( + self, + ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: + """ + Get a list of forward and reverse analysis of the free energies + for each repeat using uncorrelated production samples. + + The returned dicts have keys: + 'fractions' - the fraction of data used for this estimate + 'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate + 'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty + + The 'fractions' values are a numpy array, while the other arrays are + Quantity arrays, with units attached. + + If the list entry is ``None`` instead of a dictionary, this indicates + that the analysis could not be carried out for that repeat. This + is most likely caused by MBAR convergence issues when attempting to + calculate free energies from too few samples. + + + Returns + ------- + forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]] + + + Raises + ------ + UserWarning + If any of the forward and reverse entries are ``None``. + """ + forward_reverse = [ + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values() + ] + + if None in forward_reverse: + wmsg = ( + "One or more ``None`` entries were found in the list of " + "forward and reverse analyses. This is likely caused by " + "an MBAR convergence failure caused by too few independent " + "samples when calculating the free energies of the 10% " + "timeseries slice." + ) + warnings.warn(wmsg) + + return forward_reverse + + def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]: + """ + Return a list of dictionary containing the MBAR overlap estimates + calculated for each repeat. + + Returns + ------- + overlap_stats : list[dict[str, npt.NDArray]] + A list of dictionaries containing the following keys: + * ``scalar``: One minus the largest nontrivial eigenvalue + * ``eigenvalues``: The sorted (descending) eigenvalues of the + overlap matrix + * ``matrix``: Estimated overlap matrix of observing a sample from + state i in state j + """ + # Loop through and get the repeats and get the matrices + overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] + + return overlap_stats + + def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]: + """The replica lambda state transition statistics for each repeat. + + Note + ---- + This is currently only available in cases where a replica exchange + simulation was run. + + Returns + ------- + repex_stats : list[dict[str, npt.NDArray]] + A list of dictionaries containing the following: + * ``eigenvalues``: The sorted (descending) eigenvalues of the + lambda state transition matrix + * ``matrix``: The transition matrix estimate of a replica switching + from state i to state j. + """ + try: + repex_stats = [ + pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values() + ] + except KeyError: + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" + raise ValueError(errmsg) + + return repex_stats + + def get_replica_states(self) -> list[npt.NDArray]: + """ + Returns the timeseries of replica states for each repeat. + + Returns + ------- + replica_states : List[npt.NDArray] + List of replica states for each repeat + """ + + def is_file(filename: str): + p = pathlib.Path(filename) + if not p.exists(): + errmsg = f"File could not be found {p}" + raise ValueError(errmsg) + return p + + replica_states = [] + + for pus in self.data.values(): + nc = is_file(pus[0].outputs["trajectory"]) + dir_path = nc.parents[0] + chk = is_file(pus[0].outputs["checkpoint"]).name + reporter = multistate.MultiStateReporter( + storage=nc, checkpoint_storage=chk, open_mode="r" + ) + replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) + reporter.close() + + return replica_states + + def equilibration_iterations(self) -> list[float]: + """ + Returns the number of equilibration iterations for each repeat + of the calculation. + + Returns + ------- + equilibration_lengths : list[float] + """ + equilibration_lengths = [ + pus[0].outputs["equilibration_iterations"] for pus in self.data.values() + ] + + return equilibration_lengths + + def production_iterations(self) -> list[float]: + """ + Returns the number of uncorrelated production samples for each + repeat of the calculation. + + Returns + ------- + production_lengths : list[float] + """ + production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] + + return production_lengths diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py new file mode 100644 index 00000000..47ac3779 --- /dev/null +++ b/src/openfe/protocols/openmm_rfe/hybridtop_protocols.py @@ -0,0 +1,662 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +Hybrid Topology Protocols using OpenMM and OpenMMTools in a Perses-like manner. + +Acknowledgements +---------------- +These Protocols are based on, and leverages components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +from __future__ import annotations + +import logging +import uuid +import warnings +from collections import defaultdict +from typing import Any, Iterable, Optional, Union + +import gufe +import numpy as np +from gufe import ( + ChemicalSystem, + Component, + ComponentMapping, + LigandAtomMapping, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, + settings, +) +from openff.units import unit as offunit + +from openfe.due import Doi, due + +from ..openmm_utils import ( + settings_validation, + system_validation, +) +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + RelativeHybridTopologyProtocolSettings, +) +from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult +from .hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, +) + +logger = logging.getLogger(__name__) + + +due.cite( + Doi("10.5281/zenodo.1297683"), + description="Perses", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_rfe.hybridtop_protocols", + cite_module=True, +) + + +class RelativeHybridTopologyProtocol(gufe.Protocol): + """ + Relative Free Energy calculations using a Hybrid Topology scheme + using OpenMM and OpenMMTools. + + Based on `Perses `_ + + See Also + -------- + :mod:`openfe.protocols` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologySettings` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyResult` + :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` + """ + + result_cls = RelativeHybridTopologyProtocolResult + _settings_cls = RelativeHybridTopologyProtocolSettings + _settings: RelativeHybridTopologyProtocolSettings + + @classmethod + def _default_settings(cls): + """A dictionary of initial settings for this creating this Protocol + + These settings are intended as a suitable starting point for creating + an instance of this protocol. It is recommended, however that care is + taken to inspect and customize these before performing a Protocol. + + Returns + ------- + Settings + a set of default settings + """ + return RelativeHybridTopologyProtocolSettings( + protocol_repeats=3, + forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), + thermo_settings=settings.ThermoSettings( + temperature=298.15 * offunit.kelvin, + pressure=1 * offunit.bar, + ), + partial_charge_settings=OpenFFPartialChargeSettings(), + solvation_settings=OpenMMSolvationSettings(), + alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), + lambda_settings=LambdaSettings(), + simulation_settings=MultiStateSimulationSettings( + equilibration_length=1.0 * offunit.nanosecond, + production_length=5.0 * offunit.nanosecond, + ), + engine_settings=OpenMMEngineSettings(), + integrator_settings=IntegratorSettings(), + output_settings=MultiStateOutputSettings(), + ) + + @classmethod + def _adaptive_settings( + cls, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping], + initial_settings: None | RelativeHybridTopologyProtocolSettings = None, + ) -> RelativeHybridTopologyProtocolSettings: + """ + Get the recommended OpenFE settings for this protocol based on the input states involved in the + transformation. + + These are intended as a suitable starting point for creating an instance of this protocol, which can be further + customized before performing a Protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The initial state of the transformation. + stateB : ChemicalSystem + The final state of the transformation. + mapping : LigandAtomMapping | list[LigandAtomMapping] + The mapping(s) between transforming components in stateA and stateB. + initial_settings : None | RelativeHybridTopologyProtocolSettings, optional + Initial settings to base the adaptive settings on. If None, default settings are used. + + Returns + ------- + RelativeHybridTopologyProtocolSettings + The recommended settings for this protocol based on the input states. + + Notes + ----- + - If the transformation involves a change in net charge, the settings are adapted to use a more expensive + protocol with 22 lambda windows and 20 ns production length per window. + - If both states contain a ProteinComponent, the solvation padding is set to 1 nm. + - If initial_settings is provided, the adaptive settings are based on a copy of these settings. + """ + # use initial settings or default settings + # this is needed for the CLI so we don't override user settings + if initial_settings is not None: + protocol_settings = initial_settings.model_copy(deep=True) + else: + protocol_settings = cls.default_settings() + + if isinstance(mapping, list): + mapping = mapping[0] + + if mapping.get_alchemical_charge_difference() != 0: + # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated + # + info = ( + "Charge changing transformation between ligands " + f"{mapping.componentA.name} and {mapping.componentB.name}. " + "A more expensive protocol with 22 lambda windows, sampled " + "for 20 ns each, will be used here." + ) + logger.info(info) + protocol_settings.alchemical_settings.explicit_charge_correction = True + protocol_settings.simulation_settings.production_length = 20 * offunit.nanosecond + protocol_settings.simulation_settings.n_replicas = 22 + protocol_settings.lambda_settings.lambda_windows = 22 + + # adapt the solvation padding based on the system components + if stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent): + protocol_settings.solvation_settings.solvent_padding = 1 * offunit.nanometer + + return protocol_settings + + @staticmethod + def _validate_endstates( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the end states for the RFE protocol. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If either state contains more than one unique Component. + * If unique components are not SmallMoleculeComponents. + """ + # Get the difference in Components between each state + diff = stateA.component_diff(stateB) + + for i, entry in enumerate(diff): + state_label = "A" if i == 0 else "B" + + # Check that there is only one unique Component in each state + if len(entry) != 1: + errmsg = ( + "Only one alchemical component is allowed per end state. " + f"Found {len(entry)} in state {state_label}." + ) + raise ValueError(errmsg) + + # Check that the unique Component is a SmallMoleculeComponent + if not isinstance(entry[0], SmallMoleculeComponent): + errmsg = ( + f"Alchemical component in state {state_label} is of type " + f"{type(entry[0])}, but only SmallMoleculeComponents " + "transformations are currently supported." + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_mapping( + mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]], + alchemical_components: dict[str, list[Component]], + ) -> None: + """ + Validates that the provided mapping(s) are suitable for the RFE protocol. + + Parameters + ---------- + mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]] + all mappings between transforming components. + alchemical_components : dict[str, list[Component]] + Dictionary contatining the alchemical components for + states A and B. + + Raises + ------ + ValueError + * If there are more than one mapping or mapping is None + * If the mapping components are not in the alchemical components. + UserWarning + * Mappings which involve element changes in core atoms + """ + # if a single mapping is provided, convert to list + if isinstance(mapping, ComponentMapping): + mapping = [mapping] + + # For now we only support a single mapping + if mapping is None or len(mapping) > 1: + errmsg = "A single LigandAtomMapping is expected for this Protocol" + raise ValueError(errmsg) + + # check that the mapping components are in the alchemical components + for m in mapping: + for state in ["A", "B"]: + comp = getattr(m, f"component{state}") + if comp not in alchemical_components[f"state{state}"]: + raise ValueError( + f"Mapping component{state} {comp} not " + f"in alchemical components of state{state}" + ) + + # TODO: remove - this is now the default behaviour? + # Check for element changes in mappings + for m in mapping: + molA = m.componentA.to_rdkit() + molB = m.componentB.to_rdkit() + for i, j in m.componentA_to_componentB.items(): + atomA = molA.GetAtomWithIdx(i) + atomB = molB.GetAtomWithIdx(j) + if atomA.GetAtomicNum() != atomB.GetAtomicNum(): + wmsg = ( + f"Element change in mapping between atoms " + f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and " + f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" + "No mass scaling is attempted in the hybrid topology, " + "the average mass of the two atoms will be used in the " + "simulation" + ) + logger.warning(wmsg) + warnings.warn(wmsg) + + @staticmethod + def _validate_smcs( + stateA: ChemicalSystem, + stateB: ChemicalSystem, + ) -> None: + """ + Validates the SmallMoleculeComponents. + + Parameters + ---------- + stateA : ChemicalSystem + The chemical system of end state A. + stateB : ChemicalSystem + The chemical system of end state B. + + Raises + ------ + ValueError + * If there are isomorphic SmallMoleculeComponents with + different charges within a given ChemicalSystem. + """ + smcs_A = stateA.get_components_of_type(SmallMoleculeComponent) + smcs_B = stateB.get_components_of_type(SmallMoleculeComponent) + smcs_all = list(set(smcs_A).union(set(smcs_B))) + + def _equal_charges(moli, molj): + # Base case, both molecules don't have charges + if (moli.partial_charges is None) & (molj.partial_charges is None): + return True + # If either is None but not the other + if (moli.partial_charges is None) ^ (molj.partial_charges is None): + return False + # Check if the charges are close to each other + return np.allclose(moli.partial_charges, molj.partial_charges) + + clashes = [] + + for smcs in [smcs_A, smcs_B]: + offmols = [m.to_openff() for m in smcs] + for i, moli in enumerate(offmols): + for molj in offmols: + if moli.is_isomorphic_with(molj): + if not _equal_charges(moli, molj): + clashes.append(smcs[i]) + + if len(clashes) > 0: + errmsg = ( + "Found SmallMoleculeComponents that are isomorphic " + "but with different charges, this is not currently allowed. " + f"Affected components: {clashes}" + ) + raise ValueError(errmsg) + + @staticmethod + def _validate_charge_difference( + mapping: LigandAtomMapping, + nonbonded_method: str, + explicit_charge_correction: bool, + solvent_component: SolventComponent | None, + ): + """ + Validates the net charge difference between the two states. + + Parameters + ---------- + mapping : dict[str, ComponentMapping] + Dictionary of mappings between transforming components. + nonbonded_method : str + The OpenMM nonbonded method used for the simulation. + explicit_charge_correction : bool + Whether or not to use an explicit charge correction. + solvent_component : openfe.SolventComponent | None + The SolventComponent of the simulation. + + Raises + ------ + ValueError + * If an explicit charge correction is attempted and the + nonbonded method is not PME. + * If the absolute charge difference is greater than one + and an explicit charge correction is attempted. + * If an explicit charge correction is attempted and there is no + solvent present. + UserWarning + * If there is any charge difference. + """ + difference = mapping.get_alchemical_charge_difference() + + if abs(difference) == 0: + return + + if not explicit_charge_correction: + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) + logger.warning(wmsg) + warnings.warn(wmsg) + return + + if solvent_component is None: + errmsg = "Cannot use explicit charge correction without solvent" + raise ValueError(errmsg) + + # We implicitly check earlier that we have to have pme for a solvated + # system, so we only need to check the nonbonded method here + if nonbonded_method.lower() != "pme": + errmsg = "Explicit charge correction when not using PME is not currently supported." + raise ValueError(errmsg) + + if abs(difference) > 1: + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) + raise ValueError(errmsg) + + ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference] + + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) + logger.info(wmsg) + + @staticmethod + def _validate_simulation_settings( + simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, + ): + """ + Validate various simulation settings, including but not limited to + timestep conversions, and output file write frequencies. + + Parameters + ---------- + simulation_settings : MultiStateSimulationSettings + The sampler simulation settings. + integrator_settings : IntegratorSettings + Settings defining the behaviour of the integrator. + output_settings : MultiStateOutputSettings + Settings defining the simulation file writing behaviour. + + Raises + ------ + ValueError + * If the + """ + + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=steps_per_iteration, + ) + + _ = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + if output_settings.positions_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' positions_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + if output_settings.velocities_write_frequency is not None: + _ = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' velocities_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + + _, _ = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, + ) + + def _validate( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None, + extends: gufe.ProtocolDAGResult | None = None, + ) -> None: + # Check we're not trying to extend + if extends: + # This technically should be NotImplementedError + # but gufe.Protocol.validate calls `_validate` wrapped around an + # except for NotImplementedError, so we can't raise it here + raise ValueError("Can't extend simulations yet") + + # Validate the end states + self._validate_endstates(stateA, stateB) + + # Validate the mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + self._validate_mapping(mapping, alchem_comps) + + # Validate the small molecule components + self._validate_smcs(stateA, stateB) + + # Validate solvent component + nonbond = self.settings.forcefield_settings.nonbonded_method + system_validation.validate_solvent(stateA, nonbond) + + # Validate solvation settings + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) + + # Validate protein component + system_validation.validate_protein(stateA) + + # Validate charge difference + # Note: validation depends on the mapping & solvent component checks + if stateA.contains(SolventComponent): + solv_comp = stateA.get_components_of_type(SolventComponent)[0] + else: + solv_comp = None + + self._validate_charge_difference( + mapping=mapping[0] if isinstance(mapping, list) else mapping, + nonbonded_method=self.settings.forcefield_settings.nonbonded_method, + explicit_charge_correction=self.settings.alchemical_settings.explicit_charge_correction, + solvent_component=solv_comp, + ) + + # Validate integrator things + settings_validation.validate_timestep( + self.settings.forcefield_settings.hydrogen_mass, + self.settings.integrator_settings.timestep, + ) + + # Validate simulation & output settings + self._validate_simulation_settings( + self.settings.simulation_settings, + self.settings.integrator_settings, + self.settings.output_settings, + ) + + # Validate alchemical settings + # PR #125 temporarily pin lambda schedule spacing to n_replicas + if ( + self.settings.simulation_settings.n_replicas + != self.settings.lambda_settings.lambda_windows + ): + errmsg = ( + "Number of replicas in ``simulation_settings``: " + f"{self.settings.simulation_settings.n_replicas} must equal " + "the number of lambda windows in lambda_settings: " + f"{self.settings.lambda_settings.lambda_windows}." + ) + raise ValueError(errmsg) + + def _create( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]], + extends: Optional[gufe.ProtocolDAGResult] = None, + ) -> list[gufe.ProtocolUnit]: + # validate inputs + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) + + # get alchemical components and mapping + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) + ligandmapping = mapping[0] if isinstance(mapping, list) else mapping + + # actually create and return Units + Anames = ",".join(c.name for c in alchem_comps["stateA"]) + Bnames = ",".join(c.name for c in alchem_comps["stateB"]) + + # DAG dependency is setup -> simulation -> analysis + # |---------------------> + setup_units = [] + simulation_units = [] + analysis_units = [] + + for i in range(self.settings.protocol_repeats): + repeat_id = int(uuid.uuid4()) + + setup = HybridTopologySetupUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + ligandmapping=ligandmapping, + alchemical_components=alchem_comps, + generation=0, + repeat_id=repeat_id, + name=(f"HybridTopology Setup: {Anames} to {Bnames} repeat {i} generation 0"), + ) + + simulation = HybridTopologyMultiStateSimulationUnit( + protocol=self, + setup_results=setup, + generation=0, + repeat_id=repeat_id, + name=(f"HybridTopology Simulation: {Anames} to {Bnames} repeat {i} generation 0"), + ) + + analysis = HybridTopologyMultiStateAnalysisUnit( + protocol=self, + setup_results=setup, + simulation_results=simulation, + generation=0, + repeat_id=repeat_id, + name=(f"HybridTopology Analysis: {Anames} to {Bnames} repeat {i} generation 0"), + ) + setup_units.append(setup) + simulation_units.append(simulation) + analysis_units.append(analysis) + + return [*setup_units, *simulation_units, *analysis_units] + + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: + # result units will have a repeat_id and generations within this repeat_id + # first group according to repeat_id + unsorted_repeats = defaultdict(list) + for d in protocol_dag_results: + pu: gufe.ProtocolUnitResult + for pu in d.protocol_unit_results: + # We only need the analysis units that are ok + if ("Analysis" not in pu.name) or (not pu.ok()): + continue + + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) + + # then sort by generation within each repeat_id list + repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} + for k, v in unsorted_repeats.items(): + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) + + # returns a dict of repeat_id: sorted list of ProtocolUnitResult + return repeats diff --git a/src/openfe/protocols/openmm_rfe/hybridtop_units.py b/src/openfe/protocols/openmm_rfe/hybridtop_units.py new file mode 100644 index 00000000..c7afbd9c --- /dev/null +++ b/src/openfe/protocols/openmm_rfe/hybridtop_units.py @@ -0,0 +1,1521 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +""" +ProtocolUnits for Hybrid Topology methods using OpenMM and OpenMMTools in a +Perses-like manner. + +Acknowledgements +---------------- +These ProtocolUnits are based on, and leverage components originating from +the Perses toolkit (https://github.com/choderalab/perses). +""" + +import logging +import os +import pathlib +import subprocess +from itertools import chain +from typing import Any + +import gufe +import matplotlib.pyplot as plt +import mdtraj +import numpy as np +import numpy.typing as npt +import openmm +import openmmtools +from gufe import ( + ChemicalSystem, + Component, + LigandAtomMapping, + ProteinComponent, + SmallMoleculeComponent, + SolventComponent, +) +from gufe.settings import ( + SettingsBaseModel, + ThermoSettings, +) +from openff.toolkit.topology import Molecule as OFFMolecule +from openff.units import Quantity +from openff.units import unit as offunit +from openff.units.openmm import ensure_quantity, from_openmm, to_openmm +from openmmforcefields.generators import SystemGenerator +from openmmtools import multistate + +from openfe.protocols.openmm_utils.omm_settings import ( + BasePartialChargeSettings, +) + +from ...analysis import plotting +from ...utils import log_system_probe, without_oechem_backend +from ..openmm_utils import ( + charge_generation, + multistate_analysis, + omm_compute, + settings_validation, + system_creation, + system_validation, +) +from ..openmm_utils.serialization import ( + deserialize, + serialize, +) +from . import _rfe_utils +from ._rfe_utils.relative import HybridTopologyFactory +from .equil_rfe_settings import ( + AlchemicalSettings, + IntegratorSettings, + LambdaSettings, + MultiStateOutputSettings, + MultiStateSimulationSettings, + OpenFFPartialChargeSettings, + OpenMMEngineSettings, + OpenMMSolvationSettings, + RelativeHybridTopologyProtocolSettings, +) + +logger = logging.getLogger(__name__) + + +class HybridTopologyUnitMixin: + def _prepare( + self, + verbose: bool, + scratch_basepath: pathlib.Path | None, + shared_basepath: pathlib.Path | None, + ): + """ + Set basepaths and do some initial logging. + + Parameters + ---------- + verbose : bool + Verbose output of the simulation progress. Output is provided at the + INFO level logging. + scratch_basepath : pathlib.Path | None + Optional scratch base path to write scratch files to. + shared_basepath : pathlib.Path | None + Optional shared base path to write shared files to. + """ + self.verbose = verbose + + if self.verbose: + self.logger.info("Setting up the hybrid topology simulation") # type: ignore[attr-defined] + + # set basepaths + def _set_optional_path(basepath): + if basepath is None: + return pathlib.Path(".") + return basepath + + self.scratch_basepath = _set_optional_path(scratch_basepath) + self.shared_basepath = _set_optional_path(shared_basepath) + + @staticmethod + def _get_settings( + settings: RelativeHybridTopologyProtocolSettings, + ) -> dict[str, SettingsBaseModel]: + """ + Get a dictionary of Protocol settings. + + Returns + ------- + protocol_settings : dict[str, SettingsBaseModel] + + Notes + ----- + We return a dict so that we can duck type behaviour between phases. + For example subclasses may contain both `solvent` and `complex` + settings, using this approach we can extract the relevant entry + to the same key and pass it to other methods in a seamless manner. + """ + protocol_settings: dict[str, SettingsBaseModel] = {} + protocol_settings["forcefield_settings"] = settings.forcefield_settings + protocol_settings["thermo_settings"] = settings.thermo_settings + protocol_settings["alchemical_settings"] = settings.alchemical_settings + protocol_settings["lambda_settings"] = settings.lambda_settings + protocol_settings["charge_settings"] = settings.partial_charge_settings + protocol_settings["solvation_settings"] = settings.solvation_settings + protocol_settings["simulation_settings"] = settings.simulation_settings + protocol_settings["output_settings"] = settings.output_settings + protocol_settings["integrator_settings"] = settings.integrator_settings + protocol_settings["engine_settings"] = settings.engine_settings + return protocol_settings + + +class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin): + """ + Setup unit for Hybrid Topology Protocol transformations. + """ + + @staticmethod + def _get_components( + stateA: ChemicalSystem, stateB: ChemicalSystem + ) -> tuple[SolventComponent, ProteinComponent, dict[SmallMoleculeComponent, OFFMolecule]]: + """ + Get the components from the ChemicalSystem inputs. + + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining the state A components. + stateB : CHemicalSystem + ChemicalSystem defining the state B components. + + Returns + ------- + solv_comp : SolventComponent + The solvent component. + protein_comp : ProteinComponent + The protein component. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of small molecule components paired + with their OpenFF Molecule. + """ + solvent_comp, protein_comp, smcs_A = system_validation.get_components(stateA) + _, _, smcs_B = system_validation.get_components(stateB) + + small_mols = {m: m.to_openff() for m in set(smcs_A).union(set(smcs_B))} + + return solvent_comp, protein_comp, small_mols + + @staticmethod + def _assign_partial_charges( + charge_settings: OpenFFPartialChargeSettings, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + ) -> None: + """ + Assign partial charges to the OpenFF Molecules associated with all + the SmallMoleculeComponents in the transformation. + + Parameters + ---------- + charge_settings : OpenFFPartialChargeSettings + Settings for controlling how the partial charges are assigned. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules to add, keyed by + their associated SmallMoleculeComponent. + """ + for smc, mol in small_mols.items(): + charge_generation.assign_offmol_partial_charges( + offmol=mol, + overwrite=False, + method=charge_settings.partial_charge_method, + toolkit_backend=charge_settings.off_toolkit_backend, + generate_n_conformers=charge_settings.number_of_conformers, + nagl_model=charge_settings.nagl_model, + ) + + @staticmethod + def _get_system_generator( + settings: dict[str, SettingsBaseModel], + solvent_component: SolventComponent | None, + openff_molecules: list[OFFMolecule] | None, + ffcache: pathlib.Path | None, + ) -> SystemGenerator: + """ + Get an OpenMM SystemGenerator. + + Parameters + ---------- + settings : dict[str, SettingsBaseModel] + A dictionary of protocol settings. + solvent_component : SolventComponent | None + The solvent component of the system, if any. + openff_molecules : list[openff.toolkit.Molecule] | None + A list of openff molecules to generate templates for, if any. + ffcache : pathlib.Path | None + Path to the force field parameter cache. + + Returns + ------- + system_generator : openmmtools.SystemGenerator + The SystemGenerator for the protocol. + """ + system_generator = system_creation.get_system_generator( + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], + cache=ffcache, + has_solvent=solvent_component is not None, + ) + + # Handle openff Molecule templates + # TODO: revisit this once the SystemGenerator update happens + # and we start loading the whole protein into OpenFF Topologies + if openff_molecules is None: + return system_generator + + # Register all the templates, pass unique molecules to avoid clashes + system_generator.add_molecules(list(set(openff_molecules))) + + return system_generator + + @staticmethod + def _create_stateA_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + system_generator: SystemGenerator, + solvation_settings: OpenMMSolvationSettings, + ) -> tuple[ + openmm.System, openmm.app.Topology, openmm.unit.Quantity, dict[Component, npt.NDArray] + ]: + """ + Create an OpenMM System for state A. + + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + A list of small molecules to include in the System. + protein_component : ProteinComponent | None + Optionally, the protein component to include in the System. + solvent_component : SolventComponent | None + Optionally, the solvent component to include in the System. + system_generator : SystemGenerator + The SystemGenerator object ot use to construct the System. + solvation_settings : OpenMMSolvationSettings + Settings defining how to build the System. + + Returns + ------- + system : openmm.System + The System that defines state A. + topology : openmm.app.Topology + The Topology defining the returned System. + positions : openmm.unit.Quantity + The positions of the particles in the System. + comp_residues : dict[Component, npt.NDArray] + A dictionary defining which residues in the System + belong to which ChemicalSystem Component. + """ + modeller, comp_resids = system_creation.get_omm_modeller( + protein_comp=protein_component, + solvent_comp=solvent_component, + small_mols=small_mols, + omm_forcefield=system_generator.forcefield, + solvent_settings=solvation_settings, + ) + + topology = modeller.getTopology() + # Note: roundtrip positions to remove vec3 issues + positions = to_openmm(from_openmm(modeller.getPositions())) + + system = system_generator.create_system( + modeller.topology, + molecules=list(small_mols.values()), + ) + + return system, topology, positions, comp_resids + + @staticmethod + def _create_stateB_system( + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + mapping: LigandAtomMapping, + stateA_topology: openmm.app.Topology, + exclude_resids: npt.NDArray, + system_generator: SystemGenerator, + ) -> tuple[openmm.System, openmm.app.Topology, npt.NDArray]: + """ + Create the state B System from the state A Topology. + + Parameters + ---------- + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + Dictionary of OpenFF Molecules keyed by SmallMoleculeComponent + to be present in system B. + mapping : LigandAtomMapping + LigandAtomMapping defining the correspondance betwee state A + and B's alchemical ligand. + stateA_topology : openmm.app.Topology + The OpenMM topology for state A. + exclude_resids : npt.NDArray + A list of residues to exclude from state A when building state B. + system_generator : SystemGenerator + The SystemGenerator to use to build System B. + + Returns + ------- + system : openmm.System + The state B System. + topology : openmm.app.Topology + The OpenMM Topology associated with the state B System. + alchem_resids : npt.NDArray + The residue indices of the state B alchemical species. + """ + topology, alchem_resids = _rfe_utils.topologyhelpers.combined_topology( + topology1=stateA_topology, + topology2=small_mols[mapping.componentB].to_topology().to_openmm(), + exclude_resids=exclude_resids, + ) + + system = system_generator.create_system( + topology, + molecules=list(small_mols.values()), + ) + + return system, topology, alchem_resids + + @staticmethod + def _handle_net_charge( + stateA_topology: openmm.app.Topology, + stateA_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + stateB_system: openmm.System, + charge_difference: int, + system_mappings: dict[str, dict[int, int]], + distance_cutoff: Quantity, + solvent_component: SolventComponent | None, + ) -> None: + """ + Handle system net charge by adding an alchemical water. + + Parameters + ---------- + stateA_topology : openmm.app.Topology + stateA_positions : openmm.unit.Quantity + stateB_topology : openmm.app.Topology + stateB_system : openmm.System + charge_difference : int + system_mappings : dict[str, dict[int, int]] + distance_cutoff : Quantity + solvent_component : SolventComponent | None + """ + # Base case, return if no net charge + if charge_difference == 0: + return + + # Get the residue ids for waters to turn alchemical + alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( + topology=stateA_topology, + positions=stateA_positions, + charge_difference=charge_difference, + distance_cutoff=distance_cutoff, + ) + + # In-place modify state B alchemical waters to ions + _rfe_utils.topologyhelpers.handle_alchemical_waters( + water_resids=alchem_water_resids, + topology=stateB_topology, + system=stateB_system, + system_mapping=system_mappings, + charge_difference=charge_difference, + solvent_component=solvent_component, + ) + + def _get_omm_objects( + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: LigandAtomMapping, + settings: dict[str, SettingsBaseModel], + protein_component: ProteinComponent | None, + solvent_component: SolventComponent | None, + small_mols: dict[SmallMoleculeComponent, OFFMolecule], + ) -> tuple[ + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + openmm.System, + openmm.app.Topology, + openmm.unit.Quantity, + dict[str, dict[int, int]], + ]: + """ + Get OpenMM objects for both end states A and B. + + Parameters + ---------- + stateA : ChemicalSystem + ChemicalSystem defining end state A. + stateB : ChemicalSystem + ChemicalSystem defining end state B. + mapping : LigandAtomMapping + The mapping for alchemical components between state A and B. + settings : dict[str, SettingsBaseModel] + Settings for the transformation. + protein_component : ProteinComponent | None + The common ProteinComponent between the end states, if there is is one. + solvent_component : SolventComponent | None + The common SolventComponent between the end states, if there is one. + small_mols : dict[SmallMoleculeComponent, openff.toolkit.Molecule] + The small molecules for both end states. + + Returns + ------- + stateA_system : openmm.System + OpenMM System for state A. + stateA_topology : openmm.app.Topology + OpenMM Topology for the state A System. + stateA_positions : openmm.unit.Quantity + Positions of partials for state A System. + stateB_system : openmm.System + OpenMM System for state B. + stateB_topology : openmm.app.Topology + OpenMM Topology for the state B System. + stateB_positions : openmm.unit.Quantity + Positions of partials for state B System. + system_mapping : dict[str, dict[int, int]] + Dictionary of mappings defining the correspondance between + the two state Systems. + """ + if self.verbose: + self.logger.info("Parameterizing systems") + + def _filter_small_mols(smols, state): + return {smc: offmol for smc, offmol in smols.items() if state.contains(smc)} + + states_inputs = { + "A": {"state": stateA, "mols": _filter_small_mols(small_mols, stateA)}, + "B": {"state": stateB, "mols": _filter_small_mols(small_mols, stateB)}, + } + + # Everything involving systemgenerator handling has a risk of + # oechem <-> rdkit smiles conversion clashes, cautiously ban it. + with without_oechem_backend(): + # Get the system generators with all the templates registered + for state in ["A", "B"]: + ffcache = settings["output_settings"].forcefield_cache + if ffcache is not None: + ffcache = self.shared_basepath / (f"{state}_" + ffcache) + + states_inputs[state]["generator"] = self._get_system_generator( + settings=settings, + solvent_component=solvent_component, + openff_molecules=list(states_inputs[state]["mols"].values()), + ffcache=ffcache, + ) + + (stateA_system, stateA_topology, stateA_positions, comp_resids) = ( + self._create_stateA_system( + small_mols=states_inputs["A"]["mols"], + protein_component=protein_component, + solvent_component=solvent_component, + system_generator=states_inputs["A"]["generator"], + solvation_settings=settings["solvation_settings"], + ) + ) + + (stateB_system, stateB_topology, stateB_alchem_resids) = self._create_stateB_system( + small_mols=states_inputs["B"]["mols"], + mapping=mapping, + stateA_topology=stateA_topology, + exclude_resids=comp_resids[mapping.componentA], + system_generator=states_inputs["B"]["generator"], + ) + + # Get the mapping between the two systems + system_mappings = _rfe_utils.topologyhelpers.get_system_mappings( + old_to_new_atom_map=mapping.componentA_to_componentB, + old_system=stateA_system, + old_topology=stateA_topology, + old_resids=comp_resids[mapping.componentA], + new_system=stateB_system, + new_topology=stateB_topology, + new_resids=stateB_alchem_resids, + # These are non-optional settings for this method + fix_constraints=True, + ) + + # Net charge: add alchemical water if needed + # Must be done here as we in-place modify the particles of state B. + if settings["alchemical_settings"].explicit_charge_correction: + self._handle_net_charge( + stateA_topology=stateA_topology, + stateA_positions=stateA_positions, + stateB_topology=stateB_topology, + stateB_system=stateB_system, + charge_difference=mapping.get_alchemical_charge_difference(), + system_mappings=system_mappings, + distance_cutoff=settings["alchemical_settings"].explicit_charge_correction_cutoff, + solvent_component=solvent_component, + ) + + # Finally get the state B positions + stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( + system_mappings, + stateA_topology, + stateB_topology, + old_positions=ensure_quantity(stateA_positions, "openmm"), + insert_positions=ensure_quantity( + small_mols[mapping.componentB].conformers[0], "openmm" + ), + ) + + return ( + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, + ) + + @staticmethod + def _get_alchemical_system( + stateA_system: openmm.System, + stateA_positions: openmm.unit.Quantity, + stateA_topology: openmm.app.Topology, + stateB_system: openmm.System, + stateB_positions: openmm.unit.Quantity, + stateB_topology: openmm.app.Topology, + system_mappings: dict[str, dict[int, int]], + alchemical_settings: AlchemicalSettings, + ): + """ + Get the hybrid topology alchemical system. + + Parameters + ---------- + stateA_system : openmm.System + State A OpenMM System + stateA_positions : openmm.unit.Quantity + Positions of state A System + stateA_topology : openmm.app.Topology + Topology of state A System + stateB_system : openmm.System + State B OpenMM System + stateB_positions : openmm.unit.Quantity + Positions of state B System + stateB_topology : openmm.app.Topology + Topology of state B System + system_mappings : dict[str, dict[int, int]] + Mapping of corresponding atoms between the two Systems. + alchemical_settings : AlchemicalSettings + The alchemical settings defining how the alchemical system + will be built. + + Returns + ------- + hybrid_factory : HybridTopologyFactory + The factory creating the hybrid system. + hybrid_system : openmm.System + The hybrid System. + """ + if alchemical_settings.softcore_LJ.lower() == "gapsys": + softcore_LJ_v2 = True + elif alchemical_settings.softcore_LJ.lower() == "beutler": + softcore_LJ_v2 = False + + hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + old_to_new_atom_map=system_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=system_mappings["old_to_new_core_atom_map"], + use_dispersion_correction=alchemical_settings.use_dispersion_correction, + softcore_alpha=alchemical_settings.softcore_alpha, + softcore_LJ_v2=softcore_LJ_v2, + softcore_LJ_v2_alpha=alchemical_settings.softcore_alpha, + interpolate_old_and_new_14s=alchemical_settings.turn_off_core_unique_exceptions, + ) + + return hybrid_factory, hybrid_factory.hybrid_system + + def _subsample_topology( + self, + hybrid_topology: openmm.app.Topology, + hybrid_positions: openmm.unit.Quantity, + output_selection: str, + output_filename: str, + atom_classes: dict[str, set[int]], + ) -> npt.NDArray: + """ + Subsample the hybrid topology based on user-selected output selection + and write the subsampled topology to a PDB file. + + Parameters + ---------- + hybrid_topology : openmm.app.Topology + The hybrid system topology to subsample. + hybrid_positions : openmm.unit.Quantity + The hybrid system positions. + output_selection : str + An MDTraj selection string to subsample the topology with. + output_filename : str + The name of the file to write the PDB to. + atom_classes : dict[str, set[int]] + A dictionary defining what atoms belong to the different + components of the hybrid system. + + Returns + ------- + selection_indices : npt.NDArray + The indices of the subselected system. + + TODO + ---- + Modify this to also store the full system. + """ + selection_indices = hybrid_topology.select(output_selection) + + # Write out a PDB containing the subsampled hybrid state + # We use bfactors as a hack to label different states + # bfactor of 0 is environment atoms + # bfactor of 0.25 is unique old atoms + # bfactor of 0.5 is core atoms + # bfactor of 0.75 is unique new atoms + bfactors = np.zeros_like(selection_indices, dtype=float) + bfactors[np.isin(selection_indices, list(atom_classes["unique_old_atoms"]))] = 0.25 + bfactors[np.isin(selection_indices, list(atom_classes["core_atoms"]))] = 0.50 + bfactors[np.isin(selection_indices, list(atom_classes["unique_new_atoms"]))] = 0.75 + + if len(selection_indices) > 0: + traj = mdtraj.Trajectory( + hybrid_positions[selection_indices, :], + hybrid_topology.subset(selection_indices), + ).save_pdb( + self.shared_basepath / output_filename, + bfactors=bfactors, + ) + + return selection_indices + + def run( + self, + *, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, + ) -> dict[str, Any]: + """Setup a hybrid topology system. + + Parameters + ---------- + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: pathlib.Path | None + Where to store temporary files, defaults to current working directory + shared_basepath : pathlib.Path | None + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created by the setup unit or the debug objects + (e.g. HybridTopologyFactory) if ``dry==True``. + + Raises + ------ + error + Exception if anything failed + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Get components + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + mapping = self._inputs["ligandmapping"] + alchem_comps = self._inputs["alchemical_components"] + solvent_comp, protein_comp, small_mols = self._get_components(stateA, stateB) + + # Assign partial charges now to avoid any discrepancies later + self._assign_partial_charges(settings["charge_settings"], small_mols) + + ( + stateA_system, + stateA_topology, + stateA_positions, + stateB_system, + stateB_topology, + stateB_positions, + system_mappings, + ) = self._get_omm_objects( + stateA=stateA, + stateB=stateB, + mapping=mapping, + settings=settings, + protein_component=protein_comp, + solvent_component=solvent_comp, + small_mols=small_mols, + ) + + # Get the hybrid factory & system + hybrid_factory, hybrid_system = self._get_alchemical_system( + stateA_system=stateA_system, + stateA_positions=stateA_positions, + stateA_topology=stateA_topology, + stateB_system=stateB_system, + stateB_positions=stateB_positions, + stateB_topology=stateB_topology, + system_mappings=system_mappings, + alchemical_settings=settings["alchemical_settings"], + ) + + # Subselect system based on user inputs & write initial PDB + selection_indices = self._subsample_topology( + hybrid_topology=hybrid_factory.hybrid_topology, + hybrid_positions=hybrid_factory.hybrid_positions, + output_selection=settings["output_settings"].output_indices, + output_filename=settings["output_settings"].output_structure, + atom_classes=hybrid_factory._atom_classes, + ) + + # Serialize things + # OpenMM System + system_outfile = self.shared_basepath / "hybrid_system.xml.bz2" + serialize(hybrid_system, system_outfile) + + # Positions + positions_outfile = self.shared_basepath / "hybrid_positions.npy" + npy_positions = from_openmm(hybrid_factory.hybrid_positions).to("nanometer").m + np.save(positions_outfile, npy_positions) + + unit_results_dict = { + "system": system_outfile, + "positions": positions_outfile, + "pdb_structure": self.shared_basepath / settings["output_settings"].output_structure, + "selection_indices": selection_indices, + } + + if dry: + unit_results_dict |= { + # Adding unserialized objects so we can directly use them + # to chain units in tests + "hybrid_factory": hybrid_factory, + "hybrid_system": hybrid_system, + "hybrid_positions": hybrid_factory.hybrid_positions, + } + + return unit_results_dict + + def _execute( + self, + ctx: gufe.Context, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, + } + + +class HybridTopologyMultiStateSimulationUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin): + """ + Multi-state simulation (e.g. multi replica methods like hamiltonian + replica exchange) unit for Hybrid Topology Protocol transformations. + """ + + @staticmethod + def _get_integrator( + integrator_settings: IntegratorSettings, + simulation_settings: MultiStateSimulationSettings, + system: openmm.System, + ) -> openmmtools.mcmc.LangevinDynamicsMove: + """ + Get and validate the integrator + + Parameters + ---------- + integrator_settings : IntegratorSettings + Settings controlling the Langevin integrator. + simulation_settings : MultiStateSimulationSettings + Settings controlling the simulation. + system : openmm.System + The OpenMM System. + + Returns + ------- + integrator : openmmtools.mcmc.LangevinDynamicsMove + The LangevinDynamicsMove integrator. + + Raises + ------ + ValueError + If there are virtual sites in the system, but velocities + are not being reassigned after every MCMC move. + """ + steps_per_iteration = settings_validation.convert_steps_per_iteration( + simulation_settings, integrator_settings + ) + + integrator = openmmtools.mcmc.LangevinDynamicsMove( + timestep=to_openmm(integrator_settings.timestep), + collision_rate=to_openmm(integrator_settings.langevin_collision_rate), + n_steps=steps_per_iteration, + reassign_velocities=integrator_settings.reassign_velocities, + n_restart_attempts=integrator_settings.n_restart_attempts, + constraint_tolerance=integrator_settings.constraint_tolerance, + ) + + # Validate for known issue when dealing with virtual sites + # and multistate simulations + if not integrator_settings.reassign_velocities: + for particle_idx in range(system.getNumParticles()): + if system.isVirtualSite(particle_idx): + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable with MCMC integrators." + ) + raise ValueError(errmsg) + + return integrator + + @staticmethod + def _get_reporter( + storage_path: pathlib.Path, + selection_indices: npt.NDArray, + output_settings: MultiStateOutputSettings, + simulation_settings: MultiStateSimulationSettings, + ) -> multistate.MultiStateReporter: + """ + Get the multistate reporter. + + Parameters + ---------- + storage_path : pathlib.Path + Path to the directory where files should be written. + selection_indices : npt.NDArray + The set of system indices to report positions & velocities for. + output_settings : MultiStateOutputSettings + Settings defining how outputs should be written. + simulation_settings : MultiStateSimulationSettings + Settings defining out the simulation should be run. + """ + nc = storage_path / output_settings.output_filename + chk = output_settings.checkpoint_storage_filename + + if output_settings.positions_write_frequency is not None: + pos_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.positions_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' position_write_frequency", + denominator_name="simulation settings' time_per_iteration", + ) + else: + pos_interval = 0 + + if output_settings.velocities_write_frequency is not None: + vel_interval = settings_validation.divmod_time_and_check( + numerator=output_settings.velocities_write_frequency, + denominator=simulation_settings.time_per_iteration, + numerator_name="output settings' velocity_write_frequency", + denominator_name="sampler settings' time_per_iteration", + ) + else: + vel_interval = 0 + + chk_intervals = settings_validation.convert_checkpoint_interval_to_iterations( + checkpoint_interval=output_settings.checkpoint_interval, + time_per_iteration=simulation_settings.time_per_iteration, + ) + + return multistate.MultiStateReporter( + storage=nc, + analysis_particle_indices=selection_indices, + checkpoint_interval=chk_intervals, + checkpoint_storage=chk, + position_interval=pos_interval, + velocity_interval=vel_interval, + ) + + @staticmethod + def _get_sampler( + system: openmm.System, + positions: openmm.unit.Quantity, + lambdas: _rfe_utils.lambdaprotocol.LambdaProtocol, + integrator: openmmtools.mcmc.MCMCMove, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + thermo_settings: ThermoSettings, + alchem_settings: AlchemicalSettings, + platform: openmm.Platform, + dry: bool, + ) -> multistate.MultiStateSampler: + """ + Get the MultiStateSampler. + + Parameters + ---------- + system : openmm.System + The OpenMM System to simulate. + positions : openmm.unit.Quantity + The positions of the OpenMM System. + lambdas : LambdaProtocol + The lambda protocol to sample along. + integrator : openmmtools.mcmc.MCMCMove + The integrator to use. + reporter : multistate.MultiStateReporter + The reporter to attach to the sampler. + simulation_settings : MultiStateSimulationSettings + The simulation control settings. + thermo_settings : ThermoSettings + The thermodynamic control settings. + alchem_settings : AlchemicalSettings + The alchemical transformation settings. + platform : openmm.Platform + The compute platform to use. + dry : bool + Whether or not this is a dry run. + + Returns + ------- + sampler : multistate.MultiStateSampler + The requested sampler. + """ + rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, + ) + + # convert early_termination_target_error from kcal/mol to kT + early_termination_target_error = ( + settings_validation.convert_target_error_from_kcal_per_mole_to_kT( + thermo_settings.temperature, + simulation_settings.early_termination_target_error, + ) + ) + + if simulation_settings.sampler_method.lower() == "repex": + sampler = _rfe_utils.multistate.HybridRepexSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_target_error=early_termination_target_error, + online_analysis_minimum_iterations=rta_min_its, + ) + + elif simulation_settings.sampler_method.lower() == "sams": + sampler = _rfe_utils.multistate.HybridSAMSSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_minimum_iterations=rta_min_its, + flatness_criteria=simulation_settings.sams_flatness_criteria, + gamma0=simulation_settings.sams_gamma0, + ) + + elif simulation_settings.sampler_method.lower() == "independent": + sampler = _rfe_utils.multistate.HybridMultiStateSampler( + mcmc_moves=integrator, + hybrid_system=system, + hybrid_positions=positions, + online_analysis_interval=rta_its, + online_analysis_target_error=early_termination_target_error, + online_analysis_minimum_iterations=rta_min_its, + ) + + else: + raise AttributeError(f"Unknown sampler {simulation_settings.sampler_method}") + + sampler.setup( + n_replicas=simulation_settings.n_replicas, + reporter=reporter, + lambda_protocol=lambdas, + temperature=to_openmm(thermo_settings.temperature), + endstates=alchem_settings.endstate_dispersion_correction, + minimization_platform=platform.getName(), + # Set minimization steps to None when running in dry mode + # otherwise do a very small one to avoid NaNs + minimization_steps=100 if not dry else None, + ) + + # Get and set the context caches + sampler.energy_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + sampler.sampler_context_cache = openmmtools.cache.ContextCache( + capacity=None, + time_to_live=None, + platform=platform, + ) + + return sampler + + def _run_simulation( + self, + sampler: multistate.MultiStateSampler, + reporter: multistate.MultiStateReporter, + simulation_settings: MultiStateSimulationSettings, + integrator_settings: IntegratorSettings, + output_settings: MultiStateOutputSettings, + dry: bool, + ): + """ + Run the simulation. + + Parameters + ---------- + sampler : multistate.MultiStateSampler. + The sampler associated with the simulation to run. + reporter : multistate.MultiStateReporter + The reporter associated with the sampler. + simulation_settings : MultiStateSimulationSettings + Simulation control settings. + integrator_settings : IntegratorSettings + Integrator control settings. + output_settings : MultiStateOutputSettings + Simulation output control settings. + dry : bool + Whether or not to dry run the simulation. + """ + # Get the relevant simulation steps + mc_steps = settings_validation.convert_steps_per_iteration( + simulation_settings=simulation_settings, + integrator_settings=integrator_settings, + ) + + equil_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.equilibration_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + prod_steps = settings_validation.get_simsteps( + sim_length=simulation_settings.production_length, + timestep=integrator_settings.timestep, + mc_steps=mc_steps, + ) + + if not dry: # pragma: no-cover + # minimize + if self.verbose: + self.logger.info("minimizing systems") + + sampler.minimize(max_iterations=simulation_settings.minimization_steps) + + # equilibrate + if self.verbose: + self.logger.info("equilibrating systems") + + sampler.equilibrate(int(equil_steps / mc_steps)) + + # production + if self.verbose: + self.logger.info("running production phase") + + sampler.extend(int(prod_steps / mc_steps)) + + if self.verbose: + self.logger.info("production phase complete") + else: + # We ran a dry simulation + # close reporter when you're done, prevent file handle clashes + reporter.close() + + # TODO: review this is likely no longer necessary + # clean up the reporter file + fns = [ + self.shared_basepath / output_settings.output_filename, + self.shared_basepath / output_settings.checkpoint_storage_filename, + ] + for fn in fns: + os.remove(fn) + + def run( + self, + *, + system: openmm.System, + positions: openmm.unit.Quantity, + selection_indices: npt.NDArray, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, + ) -> dict[str, Any]: + """Run the free energy calculation using a multistate sampler. + + Parameters + ---------- + system : openmm.System + The System to simulate. + positions : openmm.unit.Quantity + The positions of the System. + selection_indices : npt.NDArray + Indices of the System particles to write to file. + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: pathlib.Path | None + Where to store temporary files, defaults to current working directory + shared_basepath : pathlib.Path | None + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + + Raises + ------ + error + Exception if anything failed + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get the settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Get the lambda schedule + # TODO - this should be better exposed to users + lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( + functions=settings["lambda_settings"].lambda_functions, + windows=settings["lambda_settings"].lambda_windows, + ) + + # Get the compute platform + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" + platform = omm_compute.get_openmm_platform( + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, + ) + + # Get the integrator + integrator = self._get_integrator( + integrator_settings=settings["integrator_settings"], + simulation_settings=settings["simulation_settings"], + system=system, + ) + + try: + # Get the reporter + reporter = self._get_reporter( + storage_path=self.shared_basepath, + selection_indices=selection_indices, + output_settings=settings["output_settings"], + simulation_settings=settings["simulation_settings"], + ) + + # Get sampler + sampler = self._get_sampler( + system=system, + positions=positions, + lambdas=lambdas, + integrator=integrator, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + thermo_settings=settings["thermo_settings"], + alchem_settings=settings["alchemical_settings"], + platform=platform, + dry=dry, + ) + + self._run_simulation( + sampler=sampler, + reporter=reporter, + simulation_settings=settings["simulation_settings"], + integrator_settings=settings["integrator_settings"], + output_settings=settings["output_settings"], + dry=dry, + ) + finally: + # close reporter when you're done, prevent + # file handle clashes + reporter.close() + + # clear GPU contexts + # TODO: use cache.empty() calls when openmmtools #690 is resolved + # replace with above + for context in list(sampler.energy_context_cache._lru._data.keys()): + del sampler.energy_context_cache._lru._data[context] + for context in list(sampler.sampler_context_cache._lru._data.keys()): + del sampler.sampler_context_cache._lru._data[context] + # cautiously clear out the global context cache too + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): + del openmmtools.cache.global_context_cache._lru._data[context] + + del sampler.sampler_context_cache, sampler.energy_context_cache + + if not dry: + del integrator, sampler + + if not dry: # pragma: no-cover + return { + "nc": self.shared_basepath / settings["output_settings"].output_filename, + "checkpoint": self.shared_basepath + / settings["output_settings"].checkpoint_storage_filename, + } + else: + return { + "sampler": sampler, + "integrator": integrator, + } + + def _execute( + self, + ctx: gufe.Context, + *, + setup_results, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + # Get the relevant inputs + system = deserialize(setup_results.outputs["system"]) + positions = to_openmm(np.load(setup_results.outputs["positions"]) * offunit.nm) + selection_indices = setup_results.outputs["selection_indices"] + + # Run the unit + outputs = self.run( + system=system, + positions=positions, + selection_indices=selection_indices, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, + } + + +class HybridTopologyMultiStateAnalysisUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin): + """ + Analysis unit for multi-state Hybrid Topology Protocol transformations. + """ + + @staticmethod + def _analyze_multistate_energies( + trajectory: pathlib.Path, + checkpoint: pathlib.Path, + sampler_method: str, + output_directory: pathlib.Path, + dry: bool, + ): + """ + Analyze multistate energies and generate plots. + + Parameters + ---------- + trajectory : pathlib.Path + Path to the NetCDF trajectory file. + checkpoint : pathlib.Path + The name of the checkpoint file. Note this is + relative in path to the trajectory file. + sampler_method : str + The multistate sampler method used. + output_directory : pathlib.Path + The path to where plots will be written. + dry : bool + Whether or not we are running a dry run. + """ + reporter = multistate.MultiStateReporter( + storage=trajectory, + # Note: openmmtools only wants the name of the checkpoint + # file, it assumes it to be in the same place as the trajectory + checkpoint_storage=checkpoint.name, + open_mode="r", + ) + + analyzer = multistate_analysis.MultistateEquilFEAnalysis( + reporter=reporter, + sampling_method=sampler_method, + result_units=offunit.kilocalorie_per_mole, + ) + + # Only create plots when not doing a dry run + if not dry: + analyzer.plot(filepath=output_directory, filename_prefix="") + + analyzer.close() + reporter.close() + return analyzer.unit_results_dict + + @staticmethod + def _structural_analysis( + pdb_file: pathlib.Path, + trj_file: pathlib.Path, + output_directory: pathlib.Path, + dry: bool, + ) -> dict[str, str | pathlib.Path]: + """ + Run structural analysis using ``openfe-analysis``. + + Parameters + ---------- + pdb_file : pathlib.Path + Path to the PDB file. + trj_file : pathlib.Path + Path to the trajectory file. + output_directory : pathlib.Path + The output directory where plots and the data NPZ file + will be stored. + dry : bool + Whether or not we are running a dry run. + + Returns + ------- + dict[str, str | pathlib.Path] + Dictionary containing either the path to the NPZ + file with the structural data, or the analysis error. + + Notes + ----- + Don't put energy analysis here as it uses the MultiStateReporter, + the structural analysis requires the file handle to be closed. + """ + from openfe_analysis import rmsd + + try: + data = rmsd.gather_rms_data(pdb_file, trj_file) + # TODO: eventually change this to more specific exception types + except Exception as e: + return {"structural_analysis_error": str(e)} + + # Generate relevant plots if not a dry run + if not dry: + if d := data["protein_2D_RMSD"]: + fig = plotting.plot_2D_rmsd(d) + fig.savefig(output_directory / "protein_2D_RMSD.png") + plt.close(fig) + f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) + f2.savefig(output_directory / "ligand_COM_drift.png") + plt.close(f2) + + f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) + f3.savefig(output_directory / "ligand_RMSD.png") + plt.close(f3) + + # Write out an NPZ with all the relevant analysis data + npz_file = output_directory / "structural_analysis.npz" + np.savez_compressed( + npz_file, + protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32), + ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32), + ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32), + protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), + time_ps=np.asarray(data["time(ps)"], dtype=np.float32), + ) + + return {"structural_analysis": npz_file} + + def run( + self, + *, + pdb_file: pathlib.Path, + trajectory: pathlib.Path, + checkpoint: pathlib.Path, + dry: bool = False, + verbose: bool = True, + scratch_basepath: pathlib.Path | None = None, + shared_basepath: pathlib.Path | None = None, + ) -> dict[str, Any]: + """Analyze the multistate simulation. + + Parameters + ---------- + pdb_file : pathlib.Path + Path to the PDB file representing the subsampled structure. + trajectory : pathlib.Path + Path to the MultiStateReporter generated NetCDF file. + checkpoint : pathlib.Path + Path to the checkpoint file generated by MultiStateReporter. + dry : bool + Do a dry run of the calculation, creating all necessary hybrid + system components (topology, system, sampler, etc...) but without + running the simulation. + verbose : bool + Verbose output of the simulation progress. Output is provided via + INFO level logging. + scratch_basepath: pathlib.Path | None + Where to store temporary files, defaults to current working directory + shared_basepath : pathlib.Path | None + Where to run the calculation, defaults to current working directory + + Returns + ------- + dict + Outputs created in the basepath directory or the debug objects + (i.e. sampler) if ``dry==True``. + + Raises + ------ + error + Exception if anything failed + """ + # Prepare paths & verbosity + self._prepare(verbose, scratch_basepath, shared_basepath) + + # Get the settings + settings = self._get_settings(self._inputs["protocol"].settings) + + # Energies analysis + if verbose: + self.logger.info("Analyzing energies") + + energy_analysis = self._analyze_multistate_energies( + trajectory=trajectory, + checkpoint=checkpoint, + sampler_method=settings["simulation_settings"].sampler_method.lower(), + output_directory=self.shared_basepath, + dry=dry, + ) + + # Structural analysis + if verbose: + self.logger.info("Analyzing structural outputs") + + structural_analysis = self._structural_analysis( + pdb_file=pdb_file, + trj_file=trajectory, + output_directory=self.shared_basepath, + dry=dry, + ) + + # Return relevant things + outputs = energy_analysis | structural_analysis + return outputs + + def _execute( + self, + ctx: gufe.Context, + *, + setup_results, + simulation_results, + **inputs, + ) -> dict[str, Any]: + log_system_probe(logging.INFO, paths=[ctx.scratch]) + + pdb_file = setup_results.outputs["pdb_structure"] + selection_indices = setup_results.outputs["selection_indices"] + trajectory = simulation_results.outputs["nc"] + checkpoint = simulation_results.outputs["checkpoint"] + + outputs = self.run( + pdb_file=pdb_file, + trajectory=trajectory, + checkpoint=checkpoint, + scratch_basepath=ctx.scratch, + shared_basepath=ctx.shared, + ) + + return { + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + # We include various other outputs here to make + # things easier when gathering. + "pdb_structure": pdb_file, + "trajectory": trajectory, + "checkpoint": checkpoint, + "selection_indices": selection_indices, + **outputs, + } diff --git a/openfe/protocols/openmm_septop/__init__.py b/src/openfe/protocols/openmm_septop/__init__.py similarity index 100% rename from openfe/protocols/openmm_septop/__init__.py rename to src/openfe/protocols/openmm_septop/__init__.py diff --git a/openfe/protocols/openmm_septop/base.py b/src/openfe/protocols/openmm_septop/base.py similarity index 99% rename from openfe/protocols/openmm_septop/base.py rename to src/openfe/protocols/openmm_septop/base.py index 000f0b14..b3af8a4f 100644 --- a/openfe/protocols/openmm_septop/base.py +++ b/src/openfe/protocols/openmm_septop/base.py @@ -58,6 +58,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import ( from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit from openfe.protocols.openmm_utils import omm_compute from openfe.protocols.openmm_utils.omm_settings import SettingsBaseModel +from openfe.protocols.openmm_utils.serialization import deserialize from openfe.utils import without_oechem_backend from ..openmm_utils import ( @@ -66,7 +67,7 @@ from ..openmm_utils import ( settings_validation, system_creation, ) -from .utils import SepTopParameterState, deserialize +from .utils import SepTopParameterState logger = logging.getLogger(__name__) diff --git a/openfe/protocols/openmm_septop/equil_septop_method.py b/src/openfe/protocols/openmm_septop/equil_septop_method.py similarity index 99% rename from openfe/protocols/openmm_septop/equil_septop_method.py rename to src/openfe/protocols/openmm_septop/equil_septop_method.py index 87f7aa7a..51e0fded 100644 --- a/openfe/protocols/openmm_septop/equil_septop_method.py +++ b/src/openfe/protocols/openmm_septop/equil_septop_method.py @@ -81,6 +81,7 @@ from openfe.protocols.openmm_septop.equil_septop_settings import ( SepTopSettings, SettingsBaseModel, ) +from openfe.protocols.openmm_utils.serialization import serialize from openfe.protocols.restraint_utils import geometry from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry from openfe.protocols.restraint_utils.openmm import omm_restraints @@ -96,7 +97,6 @@ from ..restraint_utils.settings import ( DistanceRestraintSettings, ) from .base import BaseSepTopRunUnit, BaseSepTopSetupUnit, _pre_equilibrate -from .utils import serialize due.cite( Doi("10.1021/acs.jctc.3c00282"), @@ -2021,8 +2021,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): "topology": topology_file, "standard_state_correction_A": corr_A.to("kilocalorie_per_mole"), "standard_state_correction_B": corr_B.to("kilocalorie_per_mole"), - "restraint_geometry_A": restraint_geom_A.dict(), - "restraint_geometry_B": restraint_geom_B.dict(), + "restraint_geometry_A": restraint_geom_A.model_dump(), + "restraint_geometry_B": restraint_geom_B.model_dump(), } def _execute( diff --git a/openfe/protocols/openmm_septop/equil_septop_settings.py b/src/openfe/protocols/openmm_septop/equil_septop_settings.py similarity index 100% rename from openfe/protocols/openmm_septop/equil_septop_settings.py rename to src/openfe/protocols/openmm_septop/equil_septop_settings.py diff --git a/openfe/protocols/openmm_septop/utils.py b/src/openfe/protocols/openmm_septop/utils.py similarity index 62% rename from openfe/protocols/openmm_septop/utils.py rename to src/openfe/protocols/openmm_septop/utils.py index 53744b4d..38d7f1d3 100644 --- a/openfe/protocols/openmm_septop/utils.py +++ b/src/openfe/protocols/openmm_septop/utils.py @@ -1,84 +1,7 @@ -import os -import pathlib - from openmmtools import states from openmmtools.states import GlobalParameterState -def serialize(item, filename: pathlib.Path): - """ - Serialize an OpenMM System, State, or Integrator. - - Parameters - ---------- - item : System, State, or Integrator - The thing to be serialized - filename : str - The filename to serialize to - """ - from openmm import XmlSerializer - - # Create parent directory if it doesn't exist - filename_basedir = filename.parent - if not filename_basedir.exists(): - os.makedirs(filename_basedir) - - if filename.suffix == ".gz": - import gzip - - with gzip.open(filename, mode="wb") as outfile: - serialized_thing = XmlSerializer.serialize(item) - outfile.write(serialized_thing.encode()) - elif filename.suffix == ".bz2": - import bz2 - - with bz2.open(filename, mode="wb") as outfile: - serialized_thing = XmlSerializer.serialize(item) - outfile.write(serialized_thing.encode()) - else: - with open(filename, mode="w") as outfile: - serialized_thing = XmlSerializer.serialize(item) - outfile.write(serialized_thing) - - -def deserialize(filename: pathlib.Path): - """ - Deserialize an OpenMM System, State, or Integrator. - - Parameters - ---------- - item : System, State, or Integrator - The thing to be serialized - filename : str - The filename to serialize to - """ - from openmm import XmlSerializer - - # Create parent directory if it doesn't exist - filename_basedir = filename.parent - if not filename_basedir.exists(): - os.makedirs(filename_basedir) - - if filename.suffix == ".gz": - import gzip - - with gzip.open(filename, mode="rb") as infile: - serialized_thing = infile.read().decode() - item = XmlSerializer.deserialize(serialized_thing) - elif filename.suffix == ".bz2": - import bz2 - - with bz2.open(filename, mode="rb") as infile: - serialized_thing = infile.read().decode() - item = XmlSerializer.deserialize(serialized_thing) - else: - with open(filename) as infile: - serialized_thing = infile.read() - item = XmlSerializer.deserialize(serialized_thing) - - return item - - class SepTopParameterState(GlobalParameterState): """ Composable state to control lambda parameters for two ligands. diff --git a/openfe/protocols/openmm_utils/__init__.py b/src/openfe/protocols/openmm_utils/__init__.py similarity index 100% rename from openfe/protocols/openmm_utils/__init__.py rename to src/openfe/protocols/openmm_utils/__init__.py diff --git a/openfe/protocols/openmm_utils/charge_generation.py b/src/openfe/protocols/openmm_utils/charge_generation.py similarity index 95% rename from openfe/protocols/openmm_utils/charge_generation.py rename to src/openfe/protocols/openmm_utils/charge_generation.py index 2dcdd0ed..332eda51 100644 --- a/openfe/protocols/openmm_utils/charge_generation.py +++ b/src/openfe/protocols/openmm_utils/charge_generation.py @@ -7,7 +7,7 @@ Reusable utilities for assigning partial charges to ChemicalComponents. import copy import sys import warnings -from typing import Callable, Literal, Optional, Union +from typing import Callable, Literal import numpy as np from gufe import SmallMoleculeComponent @@ -112,7 +112,7 @@ def assign_offmol_espaloma_charges(offmol: OFFMol, toolkit_registry: ToolkitRegi def assign_offmol_nagl_charges( offmol: OFFMol, toolkit_registry: ToolkitRegistry, - nagl_model: Optional[str] = None, + nagl_model: str | None = None, ) -> None: """ Assign NAGL charges using the OpenFF toolkit. @@ -126,7 +126,7 @@ def assign_offmol_nagl_charges( This strictly limits available toolkit wrappers by overwriting the global registry during the partial charge assignment stage. - nagl_model : Optional[str] + nagl_model : str | None The NAGL model to use when assigning partial charges. If ``None``, will fetch the latest production "am1bcc" model. """ @@ -208,7 +208,7 @@ def _generate_offmol_conformers( offmol: OFFMol, max_conf: int, toolkit_registry: ToolkitRegistry, - generate_n_conformers: Optional[int], + generate_n_conformers: int | None, ) -> None: """ Helper method for OFF Molecule conformer generation in charge assignment. @@ -223,7 +223,7 @@ def _generate_offmol_conformers( Toolkit registry to use for generating conformers. This strictly limits available toolkit wrappers by overwriting the global registry during the conformer generation step. - generate_n_conformers : Optional[int] + generate_n_conformers : int | None The number of conformers to generate. If ``None``, the existing conformers are retained & used. @@ -288,8 +288,8 @@ def assign_offmol_partial_charges( overwrite: bool, method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"], toolkit_backend: Literal["ambertools", "openeye", "rdkit"], - generate_n_conformers: Optional[int], - nagl_model: Optional[str], + generate_n_conformers: int | None, + nagl_model: str | None, ) -> OFFMol: """ Assign partial charges to an OpenFF Molecule based on a selected method. @@ -312,11 +312,11 @@ def assign_offmol_partial_charges( * ``rdkit``: selects the RDKit toolkit Wrapper Note that the ``rdkit`` backend cannot be used for `am1bcc` or ``am1bccelf10`` partial charge methods. - generate_n_conformers : Optional[int] + generate_n_conformers : int | None Number of conformers to generate for partial charge generation. - If ``None`` (default), the input conformer will be used. + If ``None``, the input conformer will be used. Values greater than 1 can only be used alongside ``am1bccelf10``. - nagl_model : Optional[str] + nagl_model : str | None The NAGL model to use for charge assignment if method is ``nagl``. If ``None``, the latest am1bcc NAGL charge model is used. @@ -403,6 +403,12 @@ def assign_offmol_partial_charges( errmsg = "OpenEye is not available and cannot be selected as a backend" raise ImportError(errmsg) + # Issue 1760 + if HAS_OPENEYE and method.lower() == "nagl": + if toolkit_backend.lower() != "openeye": + errmsg = "OpenEye toolkit is installed but not used in the OpenFF toolkit registry backend. This is not possible with NAGL charges." + raise ValueError(errmsg) + toolkits = ToolkitRegistry([i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]]) # We make a copy of the molecule since we're going to modify conformers @@ -437,8 +443,8 @@ def bulk_assign_partial_charges( overwrite: bool, method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"], toolkit_backend: Literal["ambertools", "openeye", "rdkit"], - generate_n_conformers: Optional[int], - nagl_model: Optional[str], + generate_n_conformers: int | None, + nagl_model: str | None, processors: int = 1, ) -> list[SmallMoleculeComponent]: """ @@ -462,11 +468,11 @@ def bulk_assign_partial_charges( * ``rdkit``: selects the RDKit toolkit Wrapper Note that the ``rdkit`` backend cannot be used for `am1bcc` or ``am1bccelf10`` partial charge methods. - generate_n_conformers : Optional[int] + generate_n_conformers : int | None Number of conformers to generate for partial charge generation. - If ``None`` (default), the input conformer will be used. + If ``None``, the input conformer will be used. Values greater than 1 can only be used alongside ``am1bccelf10``. - nagl_model : Optional[str] + nagl_model : str | None The NAGL model to use for charge assignment if method is ``nagl``. If ``None``, the latest am1bcc NAGL charge model is used. processors: int, default 1 diff --git a/openfe/protocols/openmm_utils/multistate_analysis.py b/src/openfe/protocols/openmm_utils/multistate_analysis.py similarity index 100% rename from openfe/protocols/openmm_utils/multistate_analysis.py rename to src/openfe/protocols/openmm_utils/multistate_analysis.py diff --git a/openfe/protocols/openmm_utils/omm_compute.py b/src/openfe/protocols/openmm_utils/omm_compute.py similarity index 100% rename from openfe/protocols/openmm_utils/omm_compute.py rename to src/openfe/protocols/openmm_utils/omm_compute.py diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/src/openfe/protocols/openmm_utils/omm_settings.py similarity index 100% rename from openfe/protocols/openmm_utils/omm_settings.py rename to src/openfe/protocols/openmm_utils/omm_settings.py diff --git a/src/openfe/protocols/openmm_utils/serialization.py b/src/openfe/protocols/openmm_utils/serialization.py new file mode 100644 index 00000000..e0f2c28e --- /dev/null +++ b/src/openfe/protocols/openmm_utils/serialization.py @@ -0,0 +1,89 @@ +import os +import pathlib + +from gufe.settings.typing import NanometerArrayQuantity +from openff.units import Quantity +from openmm import Vec3 +from openmm import unit as ommunit + + +def serialize(item, filename: pathlib.Path): + """ + Serialize an OpenMM System, State, or Integrator. + + Parameters + ---------- + item : System, State, or Integrator + The thing to be serialized + filename : str + The filename to serialize to + """ + from openmm import XmlSerializer + + # Create parent directory if it doesn't exist + filename_basedir = filename.parent + if not filename_basedir.exists(): + os.makedirs(filename_basedir) + + if filename.suffix == ".bz2": + import bz2 + + with bz2.open(filename, mode="wb") as outfile: + serialized_thing = XmlSerializer.serialize(item) + outfile.write(serialized_thing.encode()) + else: + with open(filename, mode="w") as outfile: + serialized_thing = XmlSerializer.serialize(item) + outfile.write(serialized_thing) + + +def deserialize(filename: pathlib.Path): + """ + Deserialize an OpenMM System, State, or Integrator. + + Parameters + ---------- + item : System, State, or Integrator + The thing to be serialized + filename : str + The filename to serialize to + """ + from openmm import XmlSerializer + + # Create parent directory if it doesn't exist + filename_basedir = filename.parent + if not filename_basedir.exists(): + os.makedirs(filename_basedir) + + if filename.suffix == ".bz2": + import bz2 + + with bz2.open(filename, mode="rb") as infile: + serialized_thing = infile.read().decode() + item = XmlSerializer.deserialize(serialized_thing) + else: + with open(filename) as infile: + serialized_thing = infile.read() + item = XmlSerializer.deserialize(serialized_thing) + + return item + + +def make_vec3_box(dimensions: NanometerArrayQuantity) -> Vec3: + """ + Convert an OpenFF box dimensions Quantity back into Vec3 format. + + Parameters + ---------- + dimensions : gufe.settings.typing.NanometerArrayQuantity + United array to turn to Vec3 format. + + Returns + ------- + openmm.Vec3 + The input array in Vec3 format. + """ + return [ + Vec3(float(row[0]), float(row[1]), float(row[2])) * ommunit.nanometer + for row in dimensions.m_as("nanometer") + ] diff --git a/openfe/protocols/openmm_utils/settings_validation.py b/src/openfe/protocols/openmm_utils/settings_validation.py similarity index 100% rename from openfe/protocols/openmm_utils/settings_validation.py rename to src/openfe/protocols/openmm_utils/settings_validation.py diff --git a/openfe/protocols/openmm_utils/system_creation.py b/src/openfe/protocols/openmm_utils/system_creation.py similarity index 100% rename from openfe/protocols/openmm_utils/system_creation.py rename to src/openfe/protocols/openmm_utils/system_creation.py diff --git a/openfe/protocols/openmm_utils/system_validation.py b/src/openfe/protocols/openmm_utils/system_validation.py similarity index 78% rename from openfe/protocols/openmm_utils/system_validation.py rename to src/openfe/protocols/openmm_utils/system_validation.py index 0fd3c351..3e8ed5c5 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/src/openfe/protocols/openmm_utils/system_validation.py @@ -95,23 +95,24 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. """ - solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] + solv_comps = state.get_components_of_type(SolventComponent) - if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": - errmsg = "nocutoff cannot be used for solvent transformations" - raise ValueError(errmsg) + if len(solv_comps) > 0: + if nonbonded_method.lower() == "nocutoff": + errmsg = "nocutoff cannot be used for solvent transformations" + raise ValueError(errmsg) - if len(solv) == 0 and nonbonded_method.lower() == "pme": - errmsg = "PME cannot be used for vacuum transform" - raise ValueError(errmsg) + if len(solv_comps) > 1: + errmsg = "Multiple SolventComponent found, only one is supported" + raise ValueError(errmsg) - if len(solv) > 1: - errmsg = "Multiple SolventComponent found, only one is supported" - raise ValueError(errmsg) - - if len(solv) > 0 and solv[0].smiles != "O": - errmsg = "Non water solvent is not currently supported" - raise ValueError(errmsg) + if solv_comps[0].smiles != "O": + errmsg = "Non water solvent is not currently supported" + raise ValueError(errmsg) + else: + if nonbonded_method.lower() == "pme": + errmsg = "PME cannot be used for vacuum transform" + raise ValueError(errmsg) def validate_protein(state: ChemicalSystem): @@ -129,9 +130,9 @@ def validate_protein(state: ChemicalSystem): ValueError If there are multiple ProteinComponent in the ChemicalSystem. """ - nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) + prot_comps = state.get_components_of_type(ProteinComponent) - if nprot > 1: + if len(prot_comps) > 1: errmsg = "Multiple ProteinComponent found, only one is supported" raise ValueError(errmsg) @@ -161,24 +162,18 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: small_mols : list[SmallMoleculeComponent] """ - def _get_single_comps(comp_list, comptype): - ret_comps = [comp for comp in comp_list if isinstance(comp, comptype)] - if ret_comps: - return ret_comps[0] + def _get_single_comps(state, comptype): + comps = state.get_components_of_type(comptype) + + if len(comps) > 0: + return comps[0] else: return None - solvent_comp: Optional[SolventComponent] = _get_single_comps( - list(state.values()), SolventComponent - ) + solvent_comp: Optional[SolventComponent] = _get_single_comps(state, SolventComponent) - protein_comp: Optional[ProteinComponent] = _get_single_comps( - list(state.values()), ProteinComponent - ) + protein_comp: Optional[ProteinComponent] = _get_single_comps(state, ProteinComponent) - small_mols = [] - for comp in state.components.values(): - if isinstance(comp, SmallMoleculeComponent): - small_mols.append(comp) + small_mols = state.get_components_of_type(SmallMoleculeComponent) return solvent_comp, protein_comp, small_mols diff --git a/openfe/protocols/restraint_utils/openmm/__init__.py b/src/openfe/protocols/restraint_utils/__init__.py similarity index 100% rename from openfe/protocols/restraint_utils/openmm/__init__.py rename to src/openfe/protocols/restraint_utils/__init__.py diff --git a/openfe/protocols/restraint_utils/geometry/__init__.py b/src/openfe/protocols/restraint_utils/geometry/__init__.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/__init__.py rename to src/openfe/protocols/restraint_utils/geometry/__init__.py diff --git a/openfe/protocols/restraint_utils/geometry/base.py b/src/openfe/protocols/restraint_utils/geometry/base.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/base.py rename to src/openfe/protocols/restraint_utils/geometry/base.py diff --git a/openfe/protocols/restraint_utils/geometry/boresch/__init__.py b/src/openfe/protocols/restraint_utils/geometry/boresch/__init__.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/boresch/__init__.py rename to src/openfe/protocols/restraint_utils/geometry/boresch/__init__.py diff --git a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py b/src/openfe/protocols/restraint_utils/geometry/boresch/geometry.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/boresch/geometry.py rename to src/openfe/protocols/restraint_utils/geometry/boresch/geometry.py diff --git a/openfe/protocols/restraint_utils/geometry/boresch/guest.py b/src/openfe/protocols/restraint_utils/geometry/boresch/guest.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/boresch/guest.py rename to src/openfe/protocols/restraint_utils/geometry/boresch/guest.py diff --git a/openfe/protocols/restraint_utils/geometry/boresch/host.py b/src/openfe/protocols/restraint_utils/geometry/boresch/host.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/boresch/host.py rename to src/openfe/protocols/restraint_utils/geometry/boresch/host.py diff --git a/openfe/protocols/restraint_utils/geometry/flatbottom.py b/src/openfe/protocols/restraint_utils/geometry/flatbottom.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/flatbottom.py rename to src/openfe/protocols/restraint_utils/geometry/flatbottom.py diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/src/openfe/protocols/restraint_utils/geometry/harmonic.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/harmonic.py rename to src/openfe/protocols/restraint_utils/geometry/harmonic.py diff --git a/openfe/protocols/restraint_utils/geometry/utils.py b/src/openfe/protocols/restraint_utils/geometry/utils.py similarity index 100% rename from openfe/protocols/restraint_utils/geometry/utils.py rename to src/openfe/protocols/restraint_utils/geometry/utils.py diff --git a/openfe/storage/__init__.py b/src/openfe/protocols/restraint_utils/openmm/__init__.py similarity index 100% rename from openfe/storage/__init__.py rename to src/openfe/protocols/restraint_utils/openmm/__init__.py diff --git a/openfe/protocols/restraint_utils/openmm/omm_forces.py b/src/openfe/protocols/restraint_utils/openmm/omm_forces.py similarity index 100% rename from openfe/protocols/restraint_utils/openmm/omm_forces.py rename to src/openfe/protocols/restraint_utils/openmm/omm_forces.py diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/src/openfe/protocols/restraint_utils/openmm/omm_restraints.py similarity index 100% rename from openfe/protocols/restraint_utils/openmm/omm_restraints.py rename to src/openfe/protocols/restraint_utils/openmm/omm_restraints.py diff --git a/openfe/protocols/restraint_utils/settings.py b/src/openfe/protocols/restraint_utils/settings.py similarity index 100% rename from openfe/protocols/restraint_utils/settings.py rename to src/openfe/protocols/restraint_utils/settings.py diff --git a/openfe/setup/__init__.py b/src/openfe/setup/__init__.py similarity index 100% rename from openfe/setup/__init__.py rename to src/openfe/setup/__init__.py diff --git a/openfe/setup/alchemical_network_planner/__init__.py b/src/openfe/setup/alchemical_network_planner/__init__.py similarity index 100% rename from openfe/setup/alchemical_network_planner/__init__.py rename to src/openfe/setup/alchemical_network_planner/__init__.py diff --git a/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py b/src/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py similarity index 100% rename from openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py rename to src/openfe/setup/alchemical_network_planner/abstract_alchemical_network_planner.py diff --git a/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py b/src/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py similarity index 100% rename from openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py rename to src/openfe/setup/alchemical_network_planner/relative_alchemical_network_planner.py diff --git a/openfe/setup/atom_mapping/__init__.py b/src/openfe/setup/atom_mapping/__init__.py similarity index 100% rename from openfe/setup/atom_mapping/__init__.py rename to src/openfe/setup/atom_mapping/__init__.py diff --git a/openfe/setup/atom_mapping/ligandatommapper.py b/src/openfe/setup/atom_mapping/ligandatommapper.py similarity index 100% rename from openfe/setup/atom_mapping/ligandatommapper.py rename to src/openfe/setup/atom_mapping/ligandatommapper.py diff --git a/openfe/setup/atom_mapping/lomap_mapper.py b/src/openfe/setup/atom_mapping/lomap_mapper.py similarity index 100% rename from openfe/setup/atom_mapping/lomap_mapper.py rename to src/openfe/setup/atom_mapping/lomap_mapper.py diff --git a/openfe/setup/atom_mapping/lomap_scorers.py b/src/openfe/setup/atom_mapping/lomap_scorers.py similarity index 100% rename from openfe/setup/atom_mapping/lomap_scorers.py rename to src/openfe/setup/atom_mapping/lomap_scorers.py diff --git a/openfe/setup/atom_mapping/perses_mapper.py b/src/openfe/setup/atom_mapping/perses_mapper.py similarity index 100% rename from openfe/setup/atom_mapping/perses_mapper.py rename to src/openfe/setup/atom_mapping/perses_mapper.py diff --git a/openfe/setup/atom_mapping/perses_scorers.py b/src/openfe/setup/atom_mapping/perses_scorers.py similarity index 100% rename from openfe/setup/atom_mapping/perses_scorers.py rename to src/openfe/setup/atom_mapping/perses_scorers.py diff --git a/openfe/setup/chemicalsystem_generator/__init__.py b/src/openfe/setup/chemicalsystem_generator/__init__.py similarity index 100% rename from openfe/setup/chemicalsystem_generator/__init__.py rename to src/openfe/setup/chemicalsystem_generator/__init__.py diff --git a/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py b/src/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py similarity index 100% rename from openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py rename to src/openfe/setup/chemicalsystem_generator/abstract_chemicalsystem_generator.py diff --git a/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py b/src/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py similarity index 100% rename from openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py rename to src/openfe/setup/chemicalsystem_generator/easy_chemicalsystem_generator.py diff --git a/openfe/setup/ligand_network_planning.py b/src/openfe/setup/ligand_network_planning.py similarity index 100% rename from openfe/setup/ligand_network_planning.py rename to src/openfe/setup/ligand_network_planning.py diff --git a/openfe/tests/__init__.py b/src/openfe/storage/__init__.py similarity index 100% rename from openfe/tests/__init__.py rename to src/openfe/storage/__init__.py diff --git a/openfe/storage/metadatastore.py b/src/openfe/storage/metadatastore.py similarity index 100% rename from openfe/storage/metadatastore.py rename to src/openfe/storage/metadatastore.py diff --git a/openfe/storage/resultclient.py b/src/openfe/storage/resultclient.py similarity index 100% rename from openfe/storage/resultclient.py rename to src/openfe/storage/resultclient.py diff --git a/openfe/storage/resultserver.py b/src/openfe/storage/resultserver.py similarity index 100% rename from openfe/storage/resultserver.py rename to src/openfe/storage/resultserver.py diff --git a/openfe/tests/analysis/__init__.py b/src/openfe/tests/__init__.py similarity index 100% rename from openfe/tests/analysis/__init__.py rename to src/openfe/tests/__init__.py diff --git a/openfe/tests/data/__init__.py b/src/openfe/tests/analysis/__init__.py similarity index 100% rename from openfe/tests/data/__init__.py rename to src/openfe/tests/analysis/__init__.py diff --git a/openfe/tests/analysis/test_plotting.py b/src/openfe/tests/analysis/test_plotting.py similarity index 100% rename from openfe/tests/analysis/test_plotting.py rename to src/openfe/tests/analysis/test_plotting.py diff --git a/openfe/tests/conftest.py b/src/openfe/tests/conftest.py similarity index 99% rename from openfe/tests/conftest.py rename to src/openfe/tests/conftest.py index a1bec3b2..deaad59e 100644 --- a/openfe/tests/conftest.py +++ b/src/openfe/tests/conftest.py @@ -12,6 +12,7 @@ import mdtraj import numpy as np import openmm import pandas as pd +import pooch import pytest from gufe import AtomMapper, LigandAtomMapping, ProteinComponent, SmallMoleculeComponent from openff.toolkit import ForceField @@ -21,9 +22,10 @@ from rdkit import Chem from rdkit.Chem import AllChem import openfe +from openfe.data._registry import POOCH_CACHE from openfe.protocols.openmm_rfe import RelativeHybridTopologyProtocol from openfe.protocols.openmm_rfe._rfe_utils.relative import HybridTopologyFactory -from openfe.protocols.openmm_septop.utils import deserialize +from openfe.protocols.openmm_utils.serialization import deserialize from openfe.tests.protocols.openmm_rfe.helpers import make_htf diff --git a/openfe/tests/data/181l_only.pdb b/src/openfe/tests/data/181l_only.pdb similarity index 100% rename from openfe/tests/data/181l_only.pdb rename to src/openfe/tests/data/181l_only.pdb diff --git a/openfe/tests/data/6CZJ.pdb.gz b/src/openfe/tests/data/6CZJ.pdb.gz similarity index 100% rename from openfe/tests/data/6CZJ.pdb.gz rename to src/openfe/tests/data/6CZJ.pdb.gz diff --git a/openfe/tests/data/CN.sdf b/src/openfe/tests/data/CN.sdf similarity index 100% rename from openfe/tests/data/CN.sdf rename to src/openfe/tests/data/CN.sdf diff --git a/openfe/tests/data/cdk8/__init__.py b/src/openfe/tests/data/__init__.py similarity index 100% rename from openfe/tests/data/cdk8/__init__.py rename to src/openfe/tests/data/__init__.py diff --git a/openfe/tests/data/benzene_modifications.sdf b/src/openfe/tests/data/benzene_modifications.sdf similarity index 100% rename from openfe/tests/data/benzene_modifications.sdf rename to src/openfe/tests/data/benzene_modifications.sdf diff --git a/openfe/tests/data/eg5/__init__.py b/src/openfe/tests/data/cdk8/__init__.py similarity index 100% rename from openfe/tests/data/eg5/__init__.py rename to src/openfe/tests/data/cdk8/__init__.py diff --git a/openfe/tests/data/cdk8/cdk8_ligands.sdf b/src/openfe/tests/data/cdk8/cdk8_ligands.sdf similarity index 100% rename from openfe/tests/data/cdk8/cdk8_ligands.sdf rename to src/openfe/tests/data/cdk8/cdk8_ligands.sdf diff --git a/openfe/tests/data/cdk8/cdk8_protein.pdb b/src/openfe/tests/data/cdk8/cdk8_protein.pdb similarity index 100% rename from openfe/tests/data/cdk8/cdk8_protein.pdb rename to src/openfe/tests/data/cdk8/cdk8_protein.pdb diff --git a/openfe/tests/data/htf/__init__.py b/src/openfe/tests/data/eg5/__init__.py similarity index 100% rename from openfe/tests/data/htf/__init__.py rename to src/openfe/tests/data/eg5/__init__.py diff --git a/openfe/tests/data/eg5/eg5_cofactor.sdf b/src/openfe/tests/data/eg5/eg5_cofactor.sdf similarity index 100% rename from openfe/tests/data/eg5/eg5_cofactor.sdf rename to src/openfe/tests/data/eg5/eg5_cofactor.sdf diff --git a/openfe/tests/data/eg5/eg5_ligands.sdf b/src/openfe/tests/data/eg5/eg5_ligands.sdf similarity index 100% rename from openfe/tests/data/eg5/eg5_ligands.sdf rename to src/openfe/tests/data/eg5/eg5_ligands.sdf diff --git a/openfe/tests/data/eg5/eg5_protein.pdb b/src/openfe/tests/data/eg5/eg5_protein.pdb similarity index 100% rename from openfe/tests/data/eg5/eg5_protein.pdb rename to src/openfe/tests/data/eg5/eg5_protein.pdb diff --git a/openfe/tests/data/external_formats/__init__.py b/src/openfe/tests/data/external_formats/__init__.py similarity index 100% rename from openfe/tests/data/external_formats/__init__.py rename to src/openfe/tests/data/external_formats/__init__.py diff --git a/openfe/tests/data/external_formats/somebenzenes_edges.edge b/src/openfe/tests/data/external_formats/somebenzenes_edges.edge similarity index 100% rename from openfe/tests/data/external_formats/somebenzenes_edges.edge rename to src/openfe/tests/data/external_formats/somebenzenes_edges.edge diff --git a/openfe/tests/data/external_formats/somebenzenes_nes.dat b/src/openfe/tests/data/external_formats/somebenzenes_nes.dat similarity index 100% rename from openfe/tests/data/external_formats/somebenzenes_nes.dat rename to src/openfe/tests/data/external_formats/somebenzenes_nes.dat diff --git a/openfe/tests/data/lomap_basic/__init__.py b/src/openfe/tests/data/htf/__init__.py similarity index 100% rename from openfe/tests/data/lomap_basic/__init__.py rename to src/openfe/tests/data/htf/__init__.py diff --git a/openfe/tests/data/htf/t4_lysozyme_data/chlorobenzene.sdf b/src/openfe/tests/data/htf/t4_lysozyme_data/chlorobenzene.sdf similarity index 100% rename from openfe/tests/data/htf/t4_lysozyme_data/chlorobenzene.sdf rename to src/openfe/tests/data/htf/t4_lysozyme_data/chlorobenzene.sdf diff --git a/openfe/tests/data/htf/t4_lysozyme_data/fluorobenzene.sdf b/src/openfe/tests/data/htf/t4_lysozyme_data/fluorobenzene.sdf similarity index 100% rename from openfe/tests/data/htf/t4_lysozyme_data/fluorobenzene.sdf rename to src/openfe/tests/data/htf/t4_lysozyme_data/fluorobenzene.sdf diff --git a/openfe/tests/data/htf/t4_lysozyme_data/t4_lysozyme_solvated.pdb.gz b/src/openfe/tests/data/htf/t4_lysozyme_data/t4_lysozyme_solvated.pdb.gz similarity index 100% rename from openfe/tests/data/htf/t4_lysozyme_data/t4_lysozyme_solvated.pdb.gz rename to src/openfe/tests/data/htf/t4_lysozyme_data/t4_lysozyme_solvated.pdb.gz diff --git a/openfe/tests/data/lomap_basic/1,3,7-trimethylnaphthalene.mol2 b/src/openfe/tests/data/lomap_basic/1,3,7-trimethylnaphthalene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/1,3,7-trimethylnaphthalene.mol2 rename to src/openfe/tests/data/lomap_basic/1,3,7-trimethylnaphthalene.mol2 diff --git a/openfe/tests/data/lomap_basic/1-butyl-4-methylbenzene.mol2 b/src/openfe/tests/data/lomap_basic/1-butyl-4-methylbenzene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/1-butyl-4-methylbenzene.mol2 rename to src/openfe/tests/data/lomap_basic/1-butyl-4-methylbenzene.mol2 diff --git a/openfe/tests/data/lomap_basic/2,6-dimethylnaphthalene.mol2 b/src/openfe/tests/data/lomap_basic/2,6-dimethylnaphthalene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/2,6-dimethylnaphthalene.mol2 rename to src/openfe/tests/data/lomap_basic/2,6-dimethylnaphthalene.mol2 diff --git a/openfe/tests/data/lomap_basic/2-methyl-6-propylnaphthalene.mol2 b/src/openfe/tests/data/lomap_basic/2-methyl-6-propylnaphthalene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/2-methyl-6-propylnaphthalene.mol2 rename to src/openfe/tests/data/lomap_basic/2-methyl-6-propylnaphthalene.mol2 diff --git a/openfe/tests/data/lomap_basic/2-methylnaphthalene.mol2 b/src/openfe/tests/data/lomap_basic/2-methylnaphthalene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/2-methylnaphthalene.mol2 rename to src/openfe/tests/data/lomap_basic/2-methylnaphthalene.mol2 diff --git a/openfe/tests/data/lomap_basic/2-naftanol.mol2 b/src/openfe/tests/data/lomap_basic/2-naftanol.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/2-naftanol.mol2 rename to src/openfe/tests/data/lomap_basic/2-naftanol.mol2 diff --git a/openfe/tests/data/lomap_basic/README.md b/src/openfe/tests/data/lomap_basic/README.md similarity index 100% rename from openfe/tests/data/lomap_basic/README.md rename to src/openfe/tests/data/lomap_basic/README.md diff --git a/openfe/tests/data/openmm_rfe/__init__.py b/src/openfe/tests/data/lomap_basic/__init__.py similarity index 100% rename from openfe/tests/data/openmm_rfe/__init__.py rename to src/openfe/tests/data/lomap_basic/__init__.py diff --git a/openfe/tests/data/lomap_basic/methylcyclohexane.mol2 b/src/openfe/tests/data/lomap_basic/methylcyclohexane.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/methylcyclohexane.mol2 rename to src/openfe/tests/data/lomap_basic/methylcyclohexane.mol2 diff --git a/openfe/tests/data/lomap_basic/toluene.mol2 b/src/openfe/tests/data/lomap_basic/toluene.mol2 similarity index 100% rename from openfe/tests/data/lomap_basic/toluene.mol2 rename to src/openfe/tests/data/lomap_basic/toluene.mol2 diff --git a/openfe/tests/data/multi_molecule.sdf b/src/openfe/tests/data/multi_molecule.sdf similarity index 100% rename from openfe/tests/data/multi_molecule.sdf rename to src/openfe/tests/data/multi_molecule.sdf diff --git a/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz b/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz new file mode 100644 index 00000000..5cbc5d9f Binary files /dev/null and b/src/openfe/tests/data/openmm_afe/ABFEProtocol_json_results.json.gz differ diff --git a/src/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz b/src/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz new file mode 100644 index 00000000..fd78834f Binary files /dev/null and b/src/openfe/tests/data/openmm_afe/AHFEProtocol_json_results.gz differ diff --git a/openfe/tests/data/openmm_afe/T4_abfe_system.xml.bz2 b/src/openfe/tests/data/openmm_afe/T4_abfe_system.xml.bz2 similarity index 100% rename from openfe/tests/data/openmm_afe/T4_abfe_system.xml.bz2 rename to src/openfe/tests/data/openmm_afe/T4_abfe_system.xml.bz2 diff --git a/openfe/tests/data/openmm_afe/__init__.py b/src/openfe/tests/data/openmm_afe/__init__.py similarity index 100% rename from openfe/tests/data/openmm_afe/__init__.py rename to src/openfe/tests/data/openmm_afe/__init__.py diff --git a/openfe/tests/data/openmm_md/MDProtocol_json_results.gz b/src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz similarity index 100% rename from openfe/tests/data/openmm_md/MDProtocol_json_results.gz rename to src/openfe/tests/data/openmm_md/MDProtocol_json_results.gz diff --git a/openfe/tests/data/openmm_md/__init__.py b/src/openfe/tests/data/openmm_md/__init__.py similarity index 100% rename from openfe/tests/data/openmm_md/__init__.py rename to src/openfe/tests/data/openmm_md/__init__.py diff --git a/src/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz b/src/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz new file mode 100644 index 00000000..e3a327d2 Binary files /dev/null and b/src/openfe/tests/data/openmm_rfe/RHFEProtocol_json_results.gz differ diff --git a/openfe/tests/data/openmm_septop/__init__.py b/src/openfe/tests/data/openmm_rfe/__init__.py similarity index 100% rename from openfe/tests/data/openmm_septop/__init__.py rename to src/openfe/tests/data/openmm_rfe/__init__.py diff --git a/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_atoms.csv b/src/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_atoms.csv similarity index 100% rename from openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_atoms.csv rename to src/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_atoms.csv diff --git a/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_bonds.txt b/src/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_bonds.txt similarity index 100% rename from openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_bonds.txt rename to src/openfe/tests/data/openmm_rfe/benzene_toluene_hybrid_top/hybrid_topology_bonds.txt diff --git a/openfe/tests/data/openmm_rfe/charged_benzenes.sdf b/src/openfe/tests/data/openmm_rfe/charged_benzenes.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/charged_benzenes.sdf rename to src/openfe/tests/data/openmm_rfe/charged_benzenes.sdf diff --git a/openfe/tests/data/openmm_rfe/dummy_charge_ligand_23.sdf b/src/openfe/tests/data/openmm_rfe/dummy_charge_ligand_23.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/dummy_charge_ligand_23.sdf rename to src/openfe/tests/data/openmm_rfe/dummy_charge_ligand_23.sdf diff --git a/openfe/tests/data/openmm_rfe/dummy_charge_ligand_55.sdf b/src/openfe/tests/data/openmm_rfe/dummy_charge_ligand_55.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/dummy_charge_ligand_55.sdf rename to src/openfe/tests/data/openmm_rfe/dummy_charge_ligand_55.sdf diff --git a/openfe/tests/data/openmm_rfe/ligand_23.sdf b/src/openfe/tests/data/openmm_rfe/ligand_23.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/ligand_23.sdf rename to src/openfe/tests/data/openmm_rfe/ligand_23.sdf diff --git a/openfe/tests/data/openmm_rfe/ligand_55.sdf b/src/openfe/tests/data/openmm_rfe/ligand_55.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/ligand_55.sdf rename to src/openfe/tests/data/openmm_rfe/ligand_55.sdf diff --git a/openfe/tests/data/openmm_rfe/malt1_shapefit_1832577-09-9.sdf b/src/openfe/tests/data/openmm_rfe/malt1_shapefit_1832577-09-9.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/malt1_shapefit_1832577-09-9.sdf rename to src/openfe/tests/data/openmm_rfe/malt1_shapefit_1832577-09-9.sdf diff --git a/openfe/tests/data/openmm_rfe/malt1_shapefit_Pfizer-01-01.sdf b/src/openfe/tests/data/openmm_rfe/malt1_shapefit_Pfizer-01-01.sdf similarity index 100% rename from openfe/tests/data/openmm_rfe/malt1_shapefit_Pfizer-01-01.sdf rename to src/openfe/tests/data/openmm_rfe/malt1_shapefit_Pfizer-01-01.sdf diff --git a/openfe/tests/data/openmm_rfe/reference.xml b/src/openfe/tests/data/openmm_rfe/reference.xml similarity index 100% rename from openfe/tests/data/openmm_rfe/reference.xml rename to src/openfe/tests/data/openmm_rfe/reference.xml diff --git a/openfe/tests/data/openmm_rfe/vacuum_nocoord.nc b/src/openfe/tests/data/openmm_rfe/vacuum_nocoord.nc similarity index 100% rename from openfe/tests/data/openmm_rfe/vacuum_nocoord.nc rename to src/openfe/tests/data/openmm_rfe/vacuum_nocoord.nc diff --git a/openfe/tests/data/openmm_rfe/vacuum_nocoord_checkpoint.nc b/src/openfe/tests/data/openmm_rfe/vacuum_nocoord_checkpoint.nc similarity index 100% rename from openfe/tests/data/openmm_rfe/vacuum_nocoord_checkpoint.nc rename to src/openfe/tests/data/openmm_rfe/vacuum_nocoord_checkpoint.nc diff --git a/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz b/src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz similarity index 100% rename from openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz rename to src/openfe/tests/data/openmm_septop/SepTopProtocol_json_results.gz diff --git a/openfe/tests/data/serialization/__init__.py b/src/openfe/tests/data/openmm_septop/__init__.py similarity index 100% rename from openfe/tests/data/serialization/__init__.py rename to src/openfe/tests/data/openmm_septop/__init__.py diff --git a/openfe/tests/data/openmm_septop/system.xml.bz2 b/src/openfe/tests/data/openmm_septop/system.xml.bz2 similarity index 100% rename from openfe/tests/data/openmm_septop/system.xml.bz2 rename to src/openfe/tests/data/openmm_septop/system.xml.bz2 diff --git a/openfe/tests/dev/__init__.py b/src/openfe/tests/data/serialization/__init__.py similarity index 100% rename from openfe/tests/dev/__init__.py rename to src/openfe/tests/data/serialization/__init__.py diff --git a/openfe/tests/data/serialization/ethane_template.sdf b/src/openfe/tests/data/serialization/ethane_template.sdf similarity index 100% rename from openfe/tests/data/serialization/ethane_template.sdf rename to src/openfe/tests/data/serialization/ethane_template.sdf diff --git a/openfe/tests/data/serialization/network_template.graphml b/src/openfe/tests/data/serialization/network_template.graphml similarity index 100% rename from openfe/tests/data/serialization/network_template.graphml rename to src/openfe/tests/data/serialization/network_template.graphml diff --git a/openfe/tests/protocols/__init__.py b/src/openfe/tests/dev/__init__.py similarity index 100% rename from openfe/tests/protocols/__init__.py rename to src/openfe/tests/dev/__init__.py diff --git a/openfe/tests/dev/serialization_test_templates.py b/src/openfe/tests/dev/serialization_test_templates.py similarity index 100% rename from openfe/tests/dev/serialization_test_templates.py rename to src/openfe/tests/dev/serialization_test_templates.py diff --git a/openfe/tests/protocols/openmm_abfe/__init__.py b/src/openfe/tests/protocols/__init__.py similarity index 100% rename from openfe/tests/protocols/openmm_abfe/__init__.py rename to src/openfe/tests/protocols/__init__.py diff --git a/openfe/tests/protocols/conftest.py b/src/openfe/tests/protocols/conftest.py similarity index 85% rename from openfe/tests/protocols/conftest.py rename to src/openfe/tests/protocols/conftest.py index 46e724b1..b5f30294 100644 --- a/openfe/tests/protocols/conftest.py +++ b/src/openfe/tests/protocols/conftest.py @@ -1,12 +1,15 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe import gzip +import pathlib from importlib import resources from typing import Optional +import MDAnalysis as mda import openmm import pooch import pytest +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin from openff.units import Quantity, unit from openff.units.openmm import from_openmm from openmm import Platform @@ -14,6 +17,12 @@ from rdkit import Chem from rdkit.Geometry import Point3D import openfe +from openfe.data._registry import ( + POOCH_CACHE, + zenodo_industry_benchmark_systems, + zenodo_rfe_simulation_nc, + zenodo_t4_lysozyme_traj, +) @pytest.fixture @@ -279,22 +288,50 @@ def septop_json() -> str: return f.read().decode() # type: ignore -RFE_OUTPUT = pooch.create( - path=pooch.os_cache("openfe_analysis"), - base_url="doi:10.6084/m9.figshare.24101655", +pooch_industry_benchmark_systems = pooch.create( + path=POOCH_CACHE, + base_url=zenodo_industry_benchmark_systems["base_url"], registry={ - "checkpoint.nc": "5af398cb14340fddf7492114998b244424b6c3f4514b2e07e4bd411484c08464", - "db.json": "b671f9eb4daf9853f3e1645f9fd7c18150fd2a9bf17c18f23c5cf0c9fd5ca5b3", - "hybrid_system.pdb": "07203679cb14b840b36e4320484df2360f45e323faadb02d6eacac244fddd517", - "simulation.nc": "92361a0864d4359a75399470135f56642b72c605069a4c33dbc4be6f91f28b31", - "simulation_real_time_analysis.yaml": "65706002f371fafba96037f29b054fd7e050e442915205df88567f48f5e5e1cf", + zenodo_industry_benchmark_systems["fname"]: zenodo_industry_benchmark_systems["known_hash"] }, ) +@pytest.fixture +def industry_benchmark_files(): + pooch_industry_benchmark_systems.fetch( + "industry_benchmark_systems.zip", processor=pooch.Unzip() + ) + cache_dir = pathlib.Path( + POOCH_CACHE / "industry_benchmark_systems.zip.unzip/industry_benchmark_systems" + ) + return cache_dir + + +pooch_t4_lysozyme = pooch.create( + path=POOCH_CACHE, + base_url=zenodo_t4_lysozyme_traj["base_url"], + registry={zenodo_t4_lysozyme_traj["fname"]: zenodo_t4_lysozyme_traj["known_hash"]}, +) + + +# session scope for downstream reuse +@pytest.fixture(scope="session") +def t4_lysozyme_trajectory_dir(): + pooch_t4_lysozyme.fetch("t4_lysozyme_trajectory.zip", processor=pooch.Unzip()) + cache_dir = pathlib.Path( + POOCH_CACHE / "t4_lysozyme_trajectory.zip.unzip/t4_lysozyme_trajectory" + ) + return cache_dir + + @pytest.fixture def simulation_nc(): - return RFE_OUTPUT.fetch("simulation.nc") + return pooch.retrieve( + url=zenodo_rfe_simulation_nc["base_url"] + zenodo_rfe_simulation_nc["fname"], + known_hash=zenodo_rfe_simulation_nc["known_hash"], + path=POOCH_CACHE, + ) @pytest.fixture @@ -326,6 +363,20 @@ def get_available_openmm_platforms() -> set[str]: return working_platforms +class ModGufeTokenizableTestsMixin(GufeTokenizableTestsMixin): + """ + A modified gufe tokenizable tests mixin which allows + for repr to be lazily evaluated. + """ + + def test_repr(self, instance): + """ + Overwrites the base `test_repr` call. + """ + assert isinstance(repr(instance), str) + assert self.repr in repr(instance) + + def compute_energy( system: openmm.System, positions: openmm.unit.Quantity, diff --git a/openfe/tests/protocols/openmm_ahfe/__init__.py b/src/openfe/tests/protocols/openmm_abfe/__init__.py similarity index 100% rename from openfe/tests/protocols/openmm_ahfe/__init__.py rename to src/openfe/tests/protocols/openmm_abfe/__init__.py diff --git a/openfe/tests/protocols/openmm_abfe/conftest.py b/src/openfe/tests/protocols/openmm_abfe/conftest.py similarity index 100% rename from openfe/tests/protocols/openmm_abfe/conftest.py rename to src/openfe/tests/protocols/openmm_abfe/conftest.py diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_energies.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_energies.py similarity index 94% rename from openfe/tests/protocols/openmm_abfe/test_abfe_energies.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_energies.py index b842eaf7..55b3cd42 100644 --- a/openfe/tests/protocols/openmm_abfe/test_abfe_energies.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_energies.py @@ -17,11 +17,11 @@ from openmmtools.alchemy import ( ) from openfe.protocols import openmm_afe -from openfe.protocols.openmm_afe import ( - AbsoluteBindingComplexUnit, +from openfe.protocols.openmm_afe.abfe_units import ( + ABFEComplexSetupUnit, ) -from openfe.protocols.openmm_septop.utils import deserialize from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings +from openfe.protocols.openmm_utils.serialization import deserialize class AlchemStateRest(AlchemicalState): @@ -147,11 +147,11 @@ class TestT4EnergiesRegression: dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None) - complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)] + complex_units = [u for u in dag.protocol_units if isinstance(u, ABFEComplexSetupUnit)] with tmpdir.as_cwd(): - data = complex_units[0].run(dry=True)["debug"] - return data + results = complex_units[0].run(dry=True) + return results @staticmethod def get_energy_components( @@ -182,7 +182,7 @@ class TestT4EnergiesRegression: energies_ref = self.get_energy_components( t4_reference_system, t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_val, lambda_val, lambda_val, @@ -191,7 +191,7 @@ class TestT4EnergiesRegression: energies_val = self.get_energy_components( t4_validation_data["alchem_system"], t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_val, lambda_val, lambda_val, @@ -223,7 +223,7 @@ class TestT4EnergiesRegression: energies = self.get_energy_components( t4_validation_data["alchem_system"], t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_sterics=1.0, lambda_electrostatics=1.0, lambda_restraints=1.0, @@ -236,7 +236,7 @@ class TestT4EnergiesRegression: energies = self.get_energy_components( t4_validation_data["alchem_system"], t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_sterics=1.0, lambda_electrostatics=1.0, lambda_restraints=0.0, @@ -249,7 +249,7 @@ class TestT4EnergiesRegression: energies = self.get_energy_components( t4_validation_data["alchem_system"], t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_sterics=1.0, lambda_electrostatics=0.0, lambda_restraints=0.0, @@ -270,7 +270,7 @@ class TestT4EnergiesRegression: energies = self.get_energy_components( t4_validation_data["alchem_system"], t4_validation_data["alchem_indices"], - t4_validation_data["positions"], + t4_validation_data["debug_positions"], lambda_sterics=0.0, lambda_electrostatics=0.0, lambda_restraints=0.0, diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py similarity index 69% rename from openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py index c991a573..21d1009a 100644 --- a/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol.py @@ -1,6 +1,5 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -from importlib import resources from math import sqrt from unittest import mock @@ -23,9 +22,7 @@ from openmm import ( ) from openmm import unit as ommunit from openmmtools.alchemy import ( - AbsoluteAlchemicalFactory, AlchemicalRegion, - AlchemicalState, ) from openmmtools.multistate.multistatesampler import MultiStateSampler from openmmtools.tests.test_alchemy import ( @@ -38,11 +35,16 @@ import openfe from openfe import ChemicalSystem, SmallMoleculeComponent, SolventComponent from openfe.protocols import openmm_afe from openfe.protocols.openmm_afe import ( - AbsoluteBindingComplexUnit, AbsoluteBindingProtocol, - AbsoluteBindingSolventUnit, ) -from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings +from openfe.protocols.openmm_afe.abfe_units import ( + ABFEComplexSetupUnit, + ABFEComplexSimUnit, + ABFESolventSetupUnit, + ABFESolventSimUnit, +) + +from .utils import UNIT_TYPES, _get_units @pytest.fixture() @@ -75,37 +77,53 @@ def test_serialize_protocol(default_settings): assert protocol == ret -def test_unit_tagging(benzene_complex_dag, tmpdir): - # test that executing the units includes correct gen and repeat info +def test_repeat_units(benzene_modifications, T4_protein_component): + protocol = openmm_afe.AbsoluteBindingProtocol( + settings=openmm_afe.AbsoluteBindingProtocol.default_settings() + ) - dag_units = benzene_complex_dag.protocol_units + stateA = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "benzene": benzene_modifications["benzene"], + "solvent": gufe.SolventComponent(), + } + ) - with ( - mock.patch( - "openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - mock.patch( - "openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - ): - results = [] - for u in dag_units: - ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) - results.append(ret) + stateB = gufe.ChemicalSystem( + { + "protein": T4_protein_component, + "solvent": gufe.SolventComponent(), + } + ) - solv_repeats = set() - complex_repeats = set() - for ret in results: - assert isinstance(ret, gufe.ProtocolUnitResult) - assert ret.outputs["generation"] == 0 - if ret.outputs["simtype"] == "complex": - complex_repeats.add(ret.outputs["repeat_id"]) - else: - solv_repeats.add(ret.outputs["repeat_id"]) - # Repeat ids are random ints so just check their lengths - assert len(complex_repeats) == len(solv_repeats) == 3 + dag = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + + # 6 protocol unit, 3 per repeat + pus = list(dag.protocol_units) + assert len(pus) == 18 + + # Check info for each repeat + for phase in ["solvent", "complex"]: + setup = _get_units(pus, UNIT_TYPES[phase]["setup"]) + sim = _get_units(pus, UNIT_TYPES[phase]["sim"]) + analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"]) + + # Should be 3 of each set + assert len(setup) == len(sim) == len(analysis) == 3 + + # Check that the dag chain is correct + for analysis_pu in analysis: + repeat_id = analysis_pu.inputs["repeat_id"] + setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0] + sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0] + assert analysis_pu.inputs["setup_results"] == setup_pu + assert analysis_pu.inputs["simulation_results"] == sim_pu + assert sim_pu.inputs["setup_results"] == setup_pu def test_create_independent_repeat_ids(benzene_modifications, T4_protein_component): @@ -137,12 +155,26 @@ def test_create_independent_repeat_ids(benzene_modifications, T4_protein_compone repeat_ids = set() for dag in dags: + # 3 sets of 6 units + assert len(list(dag.protocol_units)) == 18 for u in dag.protocol_units: repeat_ids.add(u.inputs["repeat_id"]) + # squashed by repeat_id, that's 2 sets of 6 assert len(repeat_ids) == 12 +def test_mda_universe_error(): + """ + Test that we get an error if we pass no positions or trajectory + when calling the mda Universe getter. + """ + with pytest.raises(ValueError, match="No positions to create"): + _ = openmm_afe.ABFEComplexSetupUnit._get_mda_universe( + topology="foo", positions=None, trajectory=None + ) + + class TestT4LysozymeDryRun: solvent = SolventComponent(ion_concentration=0 * offunit.molar) num_all_not_water = 2634 @@ -190,17 +222,32 @@ class TestT4LysozymeDryRun: ) @pytest.fixture(scope="class") - def complex_units(self, dag): - return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)] + def complex_setup_units(self, dag): + return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"]) @pytest.fixture(scope="class") - def solvent_units(self, dag): - return [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingSolventUnit)] + def complex_sim_units(self, dag): + return _get_units(dag.protocol_units, UNIT_TYPES["complex"]["sim"]) - def test_number_of_units(self, dag, complex_units, solvent_units): - assert len(list(dag.protocol_units)) == 2 - assert len(complex_units) == 1 - assert len(solvent_units) == 1 + @pytest.fixture(scope="class") + def solvent_setup_units(self, dag): + return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["setup"]) + + @pytest.fixture(scope="class") + def solvent_sim_units(self, dag): + return _get_units(dag.protocol_units, UNIT_TYPES["solvent"]["sim"]) + + def test_number_of_units( + self, + dag, + complex_setup_units, + complex_sim_units, + solvent_setup_units, + solvent_sim_units, + ): + assert len(list(dag.protocol_units)) == 6 + assert len(complex_setup_units) == len(complex_sim_units) == 1 + assert len(solvent_setup_units) == len(solvent_sim_units) == 1 def _assert_force_num(self, system, forcetype, number): forces = [f for f in system.getForces() if isinstance(f, forcetype)] @@ -345,83 +392,99 @@ class TestT4LysozymeDryRun: positions=positions, ) - def test_complex_dry_run(self, complex_units, settings, tmpdir): + def test_complex_dry_run(self, complex_setup_units, complex_sim_units, settings, tmpdir): with tmpdir.as_cwd(): - data = complex_units[0].run(dry=True, verbose=True)["debug"] + setup_results = complex_setup_units[0].run(dry=True, verbose=True) + sim_results = complex_sim_units[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=True, + dry=True, + ) # Check the sampler - self._verify_sampler(data["sampler"], complexed=True, settings=settings) + self._verify_sampler(sim_results["sampler"], complexed=True, settings=settings) # Check the alchemical system self._assert_expected_alchemical_forces( - data["alchem_system"], complexed=True, settings=settings + setup_results["alchem_system"], complexed=True, settings=settings ) - self._test_dodecahedron_vectors(data["alchem_system"]) + self._test_dodecahedron_vectors(setup_results["alchem_system"]) # Check the alchemical indices expected_indices = [i + self.num_complex_atoms for i in range(self.num_solvent_atoms)] - assert expected_indices == data["alchem_indices"] + assert expected_indices == setup_results["alchem_indices"] # Check the non-alchemical system - self._assert_expected_nonalchemical_forces(data["system"], settings) - self._test_dodecahedron_vectors(data["system"]) + self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings) + self._test_dodecahedron_vectors(setup_results["standard_system"]) # Check the box vectors haven't changed (they shouldn't have because we didn't do MD) assert_allclose( - from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()), - from_openmm(data["system"].getDefaultPeriodicBoxVectors()), + from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()), + from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()), ) # Check the PDB - pdb = mdt.load_pdb("alchemical_system.pdb") + pdb = mdt.load_pdb(setup_results["pdb_structure"]) assert pdb.n_atoms == self.num_all_not_water # Check energies - alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"]) + alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"]) self._test_energies( - reference_system=data["system"], - alchemical_system=data["alchem_system"], + reference_system=setup_results["standard_system"], + alchemical_system=setup_results["alchem_system"], alchemical_regions=alchem_region, - positions=data["positions"], + positions=setup_results["debug_positions"], ) - def test_solvent_dry_run(self, solvent_units, settings, tmpdir): + def test_solvent_dry_run(self, solvent_setup_units, solvent_sim_units, settings, tmpdir): with tmpdir.as_cwd(): - data = solvent_units[0].run(dry=True, verbose=True)["debug"] + setup_results = solvent_setup_units[0].run(dry=True, verbose=True) + sim_results = solvent_sim_units[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) # Check the sampler - self._verify_sampler(data["sampler"], complexed=False, settings=settings) + self._verify_sampler(sim_results["sampler"], complexed=False, settings=settings) # Check the alchemical system self._assert_expected_alchemical_forces( - data["alchem_system"], complexed=False, settings=settings + setup_results["alchem_system"], complexed=False, settings=settings ) - self._test_cubic_vectors(data["alchem_system"]) + self._test_cubic_vectors(setup_results["alchem_system"]) # Check the alchemical indices expected_indices = [i for i in range(self.num_solvent_atoms)] - assert expected_indices == data["alchem_indices"] + assert expected_indices == setup_results["alchem_indices"] # Check the non-alchemical system - self._assert_expected_nonalchemical_forces(data["system"], settings) - self._test_cubic_vectors(data["system"]) + self._assert_expected_nonalchemical_forces(setup_results["standard_system"], settings) + self._test_cubic_vectors(setup_results["standard_system"]) # Check the box vectors haven't changed (they shouldn't have because we didn't do MD) assert_allclose( - from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()), - from_openmm(data["system"].getDefaultPeriodicBoxVectors()), + from_openmm(setup_results["alchem_system"].getDefaultPeriodicBoxVectors()), + from_openmm(setup_results["standard_system"].getDefaultPeriodicBoxVectors()), ) # Check the PDB - pdb = mdt.load_pdb("alchemical_system.pdb") + pdb = mdt.load_pdb(setup_results["pdb_structure"]) assert pdb.n_atoms == self.num_solvent_atoms # Check energies - alchem_region = AlchemicalRegion(alchemical_atoms=data["alchem_indices"]) + alchem_region = AlchemicalRegion(alchemical_atoms=setup_results["alchem_indices"]) self._test_energies( - reference_system=data["system"], - alchemical_system=data["alchem_system"], + reference_system=setup_results["standard_system"], + alchemical_system=setup_results["alchem_system"], alchemical_regions=alchem_region, - positions=data["positions"], + positions=setup_results["debug_positions"], ) @@ -499,15 +562,17 @@ def test_user_charges(benzene_modifications, T4_protein_component, tmpdir): dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None) - complex_units = [u for u in dag.protocol_units if isinstance(u, AbsoluteBindingComplexUnit)] + complex_setup_units = _get_units(dag.protocol_units, UNIT_TYPES["complex"]["setup"]) with tmpdir.as_cwd(): - data = complex_units[0].run(dry=True)["debug"] + results = complex_setup_units[0].run(dry=True) - system_nbf = [f for f in data["system"].getForces() if isinstance(f, NonbondedForce)][0] + system_nbf = [ + f for f in results["standard_system"].getForces() if isinstance(f, NonbondedForce) + ][0] alchem_system_nbf = [ f - for f in data["alchem_system"].getForces() + for f in results["alchem_system"].getForces() if isinstance(f, NonbondedForce) ][0] # fmt: skip diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py similarity index 58% rename from openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py index 0314a1a1..5d815c71 100644 --- a/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_protocol_results.py @@ -3,6 +3,7 @@ import gzip import itertools import json +from pathlib import Path from unittest import mock import gufe @@ -14,25 +15,78 @@ import openfe from openfe.protocols import openmm_afe from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry +from .utils import UNIT_TYPES, _get_units -def test_gather(benzene_complex_dag, tmpdir): - # check that .gather behaves as expected + +@pytest.fixture() +def patcher(): with ( mock.patch( - "openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingSolventUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + "openfe.protocols.openmm_afe.abfe_units.ABFESolventSetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, + "standard_state_correction": 0 * offunit.kilocalorie_per_mole, + "restraint_geometry": None, + }, ), mock.patch( - "openfe.protocols.openmm_afe.equil_binding_afe_method.AbsoluteBindingComplexUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, + "openfe.protocols.openmm_afe.abfe_units.ABFEComplexSetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, + "standard_state_correction": 0 * offunit.kilocalorie_per_mole, + "restraint_geometry": True, + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.base_afe_units.np.load", + return_value=np.zeros(100), + ), + mock.patch( + "openfe.protocols.openmm_afe.base_afe_units.deserialize", + return_value="foo", + ), + mock.patch( + "openfe.protocols.openmm_afe.abfe_units.ABFEComplexSimUnit.run", + return_value={ + "trajectory": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.abfe_units.ABFESolventSimUnit.run", + return_value={ + "trajectory": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.abfe_units.ABFEComplexAnalysisUnit.run", + return_value={"foo": "bar"}, + ), + mock.patch( + "openfe.protocols.openmm_afe.abfe_units.ABFESolventAnalysisUnit.run", + return_value={"foo": "bar"}, ), ): - dagres = gufe.protocols.execute_DAG( - benzene_complex_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True, - ) + yield + + +def test_gather(benzene_complex_dag, patcher, tmpdir): + # check that .gather behaves as expected + dagres = gufe.protocols.execute_DAG( + benzene_complex_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) protocol = openmm_afe.AbsoluteBindingProtocol( settings=openmm_afe.AbsoluteBindingProtocol.default_settings(), @@ -43,6 +97,47 @@ def test_gather(benzene_complex_dag, tmpdir): assert isinstance(res, openmm_afe.AbsoluteBindingProtocolResult) +def test_unit_tagging(benzene_complex_dag, patcher, tmpdir): + # test that executing the units includes correct gen and repeat info + + dag_units = benzene_complex_dag.protocol_units + + for phase in ["solvent", "complex"]: + setup_results = {} + sim_results = {} + analysis_results = {} + + setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"]) + sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"]) + a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"]) + + for u in setup_units: + rid = u.inputs["repeat_id"] + setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir)) + + for u in sim_units: + rid = u.inputs["repeat_id"] + sim_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + ) + + for u in a_units: + rid = u.inputs["repeat_id"] + analysis_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + simulation_results=sim_results[rid], + ) + + for results in [setup_results, sim_results, analysis_results]: + for ret in results.values(): + assert isinstance(ret, gufe.ProtocolUnitResult) + assert ret.outputs["generation"] == 0 + + assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 + + class TestProtocolResult: @pytest.fixture() def protocolresult(self, abfe_transformation_json_path): @@ -62,7 +157,7 @@ class TestProtocolResult: est = protocolresult.get_estimate() assert est - assert est.m == pytest.approx(-21.71, abs=0.01) + assert est.m == pytest.approx(-21.35, abs=0.01) assert isinstance(est, offunit.Quantity) assert est.is_compatible_with(offunit.kilojoule_per_mole) @@ -70,7 +165,7 @@ class TestProtocolResult: est = protocolresult.get_uncertainty() assert est - assert est.m == pytest.approx(0.73, abs=0.01) + assert est.m == pytest.approx(1.04, abs=0.01) assert isinstance(est, offunit.Quantity) assert est.is_compatible_with(offunit.kilojoule_per_mole) @@ -176,12 +271,12 @@ class TestProtocolResult: assert isinstance(geom[0], BoreschRestraintGeometry) assert geom[0].guest_atoms == [1779, 1778, 1777] assert geom[0].host_atoms == [880, 865, 864] - assert pytest.approx(geom[0].r_aA0) == 1.083558 * offunit.nanometer - assert pytest.approx(geom[0].theta_A0) == 0.6786444 * offunit.radian - assert pytest.approx(geom[0].theta_B0) == 1.649905 * offunit.radian - assert pytest.approx(geom[0].phi_A0) == -0.3640583 * offunit.radian - assert pytest.approx(geom[0].phi_B0) == 1.892376 * offunit.radian - assert pytest.approx(geom[0].phi_C0) == -0.6106747 * offunit.radian + assert pytest.approx(geom[0].r_aA0, rel=1e-2) == 1.083558 * offunit.nanometer + assert pytest.approx(geom[0].theta_A0, rel=1e-2) == 0.711876 * offunit.radian + assert pytest.approx(geom[0].theta_B0, rel=1e-2) == 1.687366 * offunit.radian + assert pytest.approx(geom[0].phi_A0, rel=1e-2) == -0.2164231 * offunit.radian + assert pytest.approx(geom[0].phi_B0, rel=1e-2) == 1.892376 * offunit.radian + assert pytest.approx(geom[0].phi_C0, rel=1e-2) == -0.522031870 * offunit.radian @pytest.mark.parametrize( "key, expected_size", diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py similarity index 100% rename from openfe/tests/protocols/openmm_abfe/test_abfe_settings.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_settings.py diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_slow.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_slow.py similarity index 77% rename from openfe/tests/protocols/openmm_abfe/test_abfe_slow.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_slow.py index aa714766..3fb58178 100644 --- a/openfe/tests/protocols/openmm_abfe/test_abfe_slow.py +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_slow.py @@ -13,7 +13,7 @@ from openfe.protocols.openmm_utils.charge_generation import HAS_NAGL, HAS_OPENEY @pytest.mark.integration @pytest.mark.flaky(reruns=3) # pytest-rerunfailures; we can get bad minimisation @pytest.mark.skipif(not HAS_NAGL, reason="need NAGL") -@pytest.mark.xfail( +@pytest.mark.skipif( HAS_OPENEYE and HAS_NAGL, reason="NAGL/openeye incompatibility. See https://github.com/openforcefield/openff-nagl/issues/177", ) @@ -96,16 +96,36 @@ def test_openmm_run_engine( r = openfe.execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() - for pur in r.protocol_unit_results: - unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" - assert unit_shared.exists() - assert pathlib.Path(unit_shared).is_dir() - checkpoint = pur.outputs["last_checkpoint"] - assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc" - assert (unit_shared / checkpoint).exists() - nc = pur.outputs["nc"] - assert nc == unit_shared / f"{pur.outputs['simtype']}.nc" - assert nc.exists() + + # Check outputs of solvent & complex results + for phase in ["solvent", "complex"]: + purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase] + + # get the path to the simulation unit shared dict + for pur in purs: + if "Simulation" in pur.name: + sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert sim_shared.exists() + assert pathlib.Path(sim_shared).is_dir() + + # check the analysis outputs + for pur in purs: + if "Analysis" not in pur.name: + continue + + unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert unit_shared.exists() + assert pathlib.Path(unit_shared).is_dir() + + # Does the checkpoint file exist? + checkpoint = pur.outputs["checkpoint"] + assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc" + assert checkpoint.exists() + + # Does the trajectory file exist? + nc = pur.outputs["trajectory"] + assert nc == sim_shared / f"{pur.outputs['simtype']}.nc" + assert nc.exists() # Test results methods that need files present results = protocol.gather([r]) diff --git a/src/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py new file mode 100644 index 00000000..46390ca4 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_abfe/test_abfe_tokenization.py @@ -0,0 +1,161 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import gzip + +import pytest + +import openfe +from openfe.protocols.openmm_afe import ( + ABFEComplexAnalysisUnit, + ABFEComplexSetupUnit, + ABFEComplexSimUnit, + ABFESolventAnalysisUnit, + ABFESolventSetupUnit, + ABFESolventSimUnit, + AbsoluteBindingProtocol, + AbsoluteBindingProtocolResult, +) + +from ..conftest import ModGufeTokenizableTestsMixin + + +@pytest.fixture +def protocol(): + return AbsoluteBindingProtocol(AbsoluteBindingProtocol.default_settings()) + + +@pytest.fixture +def protocol_units(protocol, benzene_complex_system, T4_protein_component): + stateA = benzene_complex_system + stateB = openfe.ChemicalSystem( + {"protein": T4_protein_component, "solvent": openfe.SolventComponent()} + ) + pus = protocol.create( + stateA=stateA, + stateB=stateB, + mapping=None, + ) + return list(pus.protocol_units) + + +def _filter_units(pus, classtype): + for pu in pus: + if isinstance(pu, classtype): + return pu + + +@pytest.fixture +def complex_protocol_setup_unit(protocol_units): + return _filter_units(protocol_units, ABFEComplexSetupUnit) + + +@pytest.fixture +def complex_protocol_sim_unit(protocol_units): + return _filter_units(protocol_units, ABFEComplexSimUnit) + + +@pytest.fixture +def complex_protocol_analysis_unit(protocol_units): + return _filter_units(protocol_units, ABFEComplexAnalysisUnit) + + +@pytest.fixture +def solvent_protocol_setup_unit(protocol_units): + return _filter_units(protocol_units, ABFESolventSetupUnit) + + +@pytest.fixture +def solvent_protocol_sim_unit(protocol_units): + return _filter_units(protocol_units, ABFESolventSimUnit) + + +@pytest.fixture +def solvent_protocol_analysis_unit(protocol_units): + return _filter_units(protocol_units, ABFESolventAnalysisUnit) + + +@pytest.fixture +def protocol_result(abfe_transformation_json_path): + with gzip.open(abfe_transformation_json_path) as f: + pr = AbsoluteBindingProtocolResult.from_json(f) + return pr + + +class TestAbsoluteBindingProtocol(ModGufeTokenizableTestsMixin): + cls = AbsoluteBindingProtocol + key = None + repr = "AbsoluteBindingProtocol-" + + @pytest.fixture() + def instance(self, protocol): + return protocol + + +class TestABFESolventSetupUnit(ModGufeTokenizableTestsMixin): + cls = ABFESolventSetupUnit + repr = "ABFESolventSetupUnit(ABFE Setup: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_setup_unit): + return solvent_protocol_setup_unit + + +class TestABFESolventSimUnit(ModGufeTokenizableTestsMixin): + cls = ABFESolventSimUnit + repr = "ABFESolventSimUnit(ABFE Simulation: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_sim_unit): + return solvent_protocol_sim_unit + + +class TestABFESolventAnalysisUnit(ModGufeTokenizableTestsMixin): + cls = ABFESolventAnalysisUnit + repr = "ABFESolventAnalysisUnit(ABFE Analysis: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_analysis_unit): + return solvent_protocol_analysis_unit + + +class TestABFEComplexSetupUnit(ModGufeTokenizableTestsMixin): + cls = ABFEComplexSetupUnit + repr = "ABFEComplexSetupUnit(ABFE Setup: benzene complex leg" + key = None + + @pytest.fixture() + def instance(self, complex_protocol_setup_unit): + return complex_protocol_setup_unit + + +class TestABFEComplexSimUnit(ModGufeTokenizableTestsMixin): + cls = ABFEComplexSimUnit + repr = "ABFEComplexSimUnit(ABFE Simulation: benzene complex leg" + key = None + + @pytest.fixture() + def instance(self, complex_protocol_sim_unit): + return complex_protocol_sim_unit + + +class TestABFEComplexAnalysisUnit(ModGufeTokenizableTestsMixin): + cls = ABFEComplexAnalysisUnit + repr = "ABFEComplexAnalysisUnit(ABFE Analysis: benzene complex leg" + key = None + + @pytest.fixture() + def instance(self, complex_protocol_analysis_unit): + return complex_protocol_analysis_unit + + +class TestAbsoluteBindingProtocolResult(ModGufeTokenizableTestsMixin): + cls = AbsoluteBindingProtocolResult + key = None + repr = "AbsoluteBindingProtocolResult-" + + @pytest.fixture() + def instance(self, protocol_result): + return protocol_result diff --git a/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py b/src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py similarity index 100% rename from openfe/tests/protocols/openmm_abfe/test_abfe_validation.py rename to src/openfe/tests/protocols/openmm_abfe/test_abfe_validation.py diff --git a/src/openfe/tests/protocols/openmm_abfe/utils.py b/src/openfe/tests/protocols/openmm_abfe/utils.py new file mode 100644 index 00000000..89ded7f8 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_abfe/utils.py @@ -0,0 +1,30 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +from openfe.protocols.openmm_afe.abfe_units import ( + ABFEComplexAnalysisUnit, + ABFEComplexSetupUnit, + ABFEComplexSimUnit, + ABFESolventAnalysisUnit, + ABFESolventSetupUnit, + ABFESolventSimUnit, +) + +UNIT_TYPES = { + "solvent": { + "setup": ABFESolventSetupUnit, + "sim": ABFESolventSimUnit, + "analysis": ABFESolventAnalysisUnit, + }, + "complex": { + "setup": ABFEComplexSetupUnit, + "sim": ABFEComplexSimUnit, + "analysis": ABFEComplexAnalysisUnit, + }, +} + + +def _get_units(protocol_units, unit_type): + """ + Helper method to extract setup units. + """ + return [pu for pu in protocol_units if isinstance(pu, unit_type)] diff --git a/openfe/tests/protocols/openmm_md/__init__.py b/src/openfe/tests/protocols/openmm_ahfe/__init__.py similarity index 100% rename from openfe/tests/protocols/openmm_md/__init__.py rename to src/openfe/tests/protocols/openmm_ahfe/__init__.py diff --git a/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py similarity index 66% rename from openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py rename to src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py index 206ff98c..b2f2e387 100644 --- a/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol.py @@ -1,12 +1,9 @@ # This code is part of OpenFE and is licensed under the MIT license. # For details, see https://github.com/OpenFreeEnergy/openfe -import itertools -import json import sys from math import sqrt from unittest import mock -import gufe import mdtraj as mdt import numpy as np import pytest @@ -18,7 +15,6 @@ from openmm import ( CustomNonbondedForce, HarmonicAngleForce, HarmonicBondForce, - MonteCarloBarostat, NonbondedForce, PeriodicTorsionForce, ) @@ -29,16 +25,15 @@ from openfe import ChemicalSystem, SolventComponent from openfe.protocols import openmm_afe from openfe.protocols.openmm_afe import ( AbsoluteSolvationProtocol, - AbsoluteSolvationSolventUnit, - AbsoluteSolvationVacuumUnit, ) -from openfe.protocols.openmm_utils import system_validation from openfe.protocols.openmm_utils.charge_generation import ( HAS_ESPALOMA_CHARGE, HAS_NAGL, HAS_OPENEYE, ) +from .utils import UNIT_TYPES, _get_units + @pytest.fixture() def protocol_dry_settings(): @@ -68,6 +63,40 @@ def test_serialize_protocol(default_settings): assert protocol == ret +def test_repeat_units(benzene_system): + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=openmm_afe.AbsoluteSolvationProtocol.default_settings() + ) + + dag = protocol.create( + stateA=benzene_system, + stateB=ChemicalSystem({"solvent": SolventComponent()}), + mapping=None, + ) + + # 6 protocol unit, 3 per repeat + pus = list(dag.protocol_units) + assert len(pus) == 18 + + # Check info for each repeat + for phase in ["solvent", "vacuum"]: + setup = _get_units(pus, UNIT_TYPES[phase]["setup"]) + sim = _get_units(pus, UNIT_TYPES[phase]["sim"]) + analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"]) + + # Should be 3 of each set + assert len(setup) == len(sim) == len(analysis) == 3 + + # Check that the dag chain is correct + for analysis_pu in analysis: + repeat_id = analysis_pu.inputs["repeat_id"] + setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0] + sim_pu = [s for s in sim if s.inputs["repeat_id"] == repeat_id][0] + assert analysis_pu.inputs["setup_results"] == setup_pu + assert analysis_pu.inputs["simulation_results"] == sim_pu + assert sim_pu.inputs["setup_results"] == setup_pu + + def test_create_independent_repeat_ids(benzene_system): protocol = openmm_afe.AbsoluteSolvationProtocol( settings=openmm_afe.AbsoluteSolvationProtocol.default_settings() @@ -88,9 +117,12 @@ def test_create_independent_repeat_ids(benzene_system): repeat_ids = set() for dag in dags: + # 3 sets of 6 units + assert len(list(dag.protocol_units)) == 18 for u in dag.protocol_units: repeat_ids.add(u.inputs["repeat_id"]) + # squashed by repeat_id, that's 2 sets of 6 assert len(repeat_ids) == 12 @@ -143,7 +175,7 @@ def _verify_alchemical_sterics_force_parameters( @pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) -def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir): +def test_setup_dry_sim_vac_benzene(benzene_system, method, protocol_dry_settings, tmpdir): protocol_dry_settings.vacuum_simulation_settings.sampler_method = method protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings) @@ -161,21 +193,32 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd ) prot_units = list(dag.protocol_units) - assert len(prot_units) == 2 + assert len(prot_units) == 6 - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] + vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"]) + vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"]) - assert len(vac_unit) == 1 - assert len(sol_unit) == 1 + assert len(vac_setup_unit) == 1 + assert len(vac_sim_unit) == 1 with tmpdir.as_cwd(): - debug = vac_unit[0].run(dry=True)["debug"] - vac_sampler = debug["sampler"] - assert not vac_sampler.is_periodic + setup_results = vac_setup_unit[0].run(dry=True) + sim_results = vac_sim_unit[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) + + sampler = sim_results["sampler"] + assert isinstance(sampler, MultiStateSampler) + assert not sampler.is_periodic + assert sampler._thermodynamic_states[0].barostat is None # standard system - system = debug["system"] + system = setup_results["standard_system"] assert system.getNumParticles() == 12 assert len(system.getForces()) == 4 _assert_num_forces(system, NonbondedForce, 1) @@ -184,7 +227,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd _assert_num_forces(system, PeriodicTorsionForce, 1) # alchemical system - alchem_system = debug["alchem_system"] + alchem_system = setup_results["alchem_system"] assert alchem_system.getNumParticles() == 12 assert len(alchem_system.getForces()) == 12 _assert_num_forces(alchem_system, NonbondedForce, 1) @@ -212,7 +255,7 @@ def test_dry_run_vac_benzene(benzene_system, method, protocol_dry_settings, tmpd [0.35, 2.2, 1.5, 0, False], ], ) -def test_alchemical_settings_dry_run_vacuum( +def test_alchemical_settings_setup_vacuum( alpha, a, b, c, correction, benzene_system, protocol_dry_settings, tmpdir ): """ @@ -238,18 +281,18 @@ def test_alchemical_settings_dry_run_vacuum( ) prot_units = list(dag.protocol_units) - assert len(prot_units) == 2 + assert len(prot_units) == 6 - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] + vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"]) + vac_sim_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["sim"]) - assert len(vac_unit) == 1 - assert len(sol_unit) == 1 + assert len(vac_setup_unit) == 1 + assert len(vac_sim_unit) == 1 with tmpdir.as_cwd(): - debug = vac_unit[0].run(dry=True)["debug"] + results = vac_setup_unit[0].run(dry=True) - alchem_system = debug["alchem_system"] + alchem_system = results["alchem_system"] _assert_num_forces(alchem_system, NonbondedForce, 1) _assert_num_forces(alchem_system, CustomNonbondedForce, 4) _assert_num_forces(alchem_system, CustomBondForce, 4) @@ -291,16 +334,16 @@ def test_confgen_fail_AFE(benzene_system, protocol_dry_settings, tmpdir): mapping=None, ) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] + vac_setup_unit = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"]) with tmpdir.as_cwd(): with mock.patch("rdkit.Chem.AllChem.EmbedMultipleConfs", return_value=0): - vac_sampler = vac_unit[0].run(dry=True)["debug"]["sampler"] - - assert vac_sampler + # If this worked, the system will have been built + system = vac_setup_unit[0].run(dry=True)["alchem_system"] + assert system -def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir): +def test_setup_solv_benzene(benzene_system, protocol_dry_settings, tmpdir): protocol_dry_settings.solvent_output_settings.output_indices = "resname UNK" protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings) @@ -318,19 +361,25 @@ def test_dry_run_solv_benzene(benzene_system, protocol_dry_settings, tmpdir): ) prot_units = list(dag.protocol_units) - assert len(prot_units) == 2 + sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"]) + sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"]) - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)] - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] - - assert len(vac_unit) == 1 - assert len(sol_unit) == 1 + assert len(sol_setup_unit) == len(sol_sim_unit) == 1 with tmpdir.as_cwd(): - sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"] + setup_results = sol_setup_unit[0].run(dry=True) + sim_results = sol_sim_unit[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) + sol_sampler = sim_results["sampler"] assert sol_sampler.is_periodic - pdb = mdt.load_pdb("hybrid_system.pdb") + pdb = mdt.load_pdb(setup_results["pdb_structure"]) assert pdb.n_atoms == 12 @@ -363,14 +412,23 @@ def test_dry_run_vsite_fail(benzene_system, tmpdir, protocol_dry_settings): ) prot_units = list(dag.protocol_units) - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] + sol_setup_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"]) + sol_sim_unit = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"]) with tmpdir.as_cwd(): + setup_results = sol_setup_unit[0].run(dry=True) with pytest.raises(ValueError, match="are unstable"): - _ = sol_unit[0].run(dry=True) + sim_results = sol_sim_unit[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) -def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir): +def test_setup_dry_sim_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdir): protocol_dry_settings.vacuum_forcefield_settings.forcefields = [ "amber/ff14SB.xml", # ff14SB protein force field "amber/tip4pew_standard.xml", # FF we are testsing with the fun VS @@ -399,10 +457,20 @@ def test_dry_run_solv_benzene_tip4p(benzene_system, protocol_dry_settings, tmpdi ) prot_units = list(dag.protocol_units) - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] + sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"]) + sol_sim_units = _get_units(prot_units, UNIT_TYPES["solvent"]["sim"]) with tmpdir.as_cwd(): - sol_sampler = sol_unit[0].run(dry=True)["debug"]["sampler"] + setup_results = sol_setup_units[0].run(dry=True) + sim_results = sol_sim_units[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) + sol_sampler = sim_results["sampler"] assert sol_sampler.is_periodic @@ -425,11 +493,11 @@ def test_dry_run_solv_benzene_noncubic(benzene_system, protocol_dry_settings, tm ) prot_units = list(dag.protocol_units) - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)] + sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"]) with tmpdir.as_cwd(): - sampler = sol_unit[0].run(dry=True)["debug"]["sampler"] - system = sampler._thermodynamic_states[0].system + results = sol_setup_units[0].run(dry=True) + system = results["alchem_system"] vectors = system.getDefaultPeriodicBoxVectors() width = float(from_openmm(vectors)[0][0].to("nanometer").m) @@ -486,13 +554,13 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0] - sol_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationSolventUnit)][0] + vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"]) + sol_setup_units = _get_units(prot_units, UNIT_TYPES["solvent"]["setup"]) # check sol_unit charges with tmpdir.as_cwd(): - sampler = sol_unit.run(dry=True)["debug"]["sampler"] - system = sampler._thermodynamic_states[0].system + results = sol_setup_units[0].run(dry=True) + system = results["alchem_system"] nonbond = [f for f in system.getForces() if isinstance(f, NonbondedForce)] assert len(nonbond) == 1 @@ -506,8 +574,8 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s # check vac_unit charges with tmpdir.as_cwd(): - sampler = vac_unit.run(dry=True)["debug"]["sampler"] - system = sampler._thermodynamic_states[0].system + results = vac_setup_units[0].run(dry=True) + system = results["alchem_system"] nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)] assert len(nonbond) == 4 @@ -537,8 +605,8 @@ def test_dry_run_solv_user_charges_benzene(benzene_modifications, protocol_dry_s "rdkit", "nagl", marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith("darwin"), - reason="needs NAGL and/or on macos", + not HAS_NAGL or HAS_OPENEYE or sys.platform.startswith("darwin"), + reason="needs NAGL (without oechem) and/or on macos", ), ), pytest.param( @@ -572,12 +640,12 @@ def test_dry_run_charge_backends( dag = protocol.create(stateA=stateA, stateB=stateB, mapping=None) prot_units = list(dag.protocol_units) - vac_unit = [u for u in prot_units if isinstance(u, AbsoluteSolvationVacuumUnit)][0] + vac_setup_units = _get_units(prot_units, UNIT_TYPES["vacuum"]["setup"]) # check vac_unit charges with tmpdir.as_cwd(): - sampler = vac_unit.run(dry=True)["debug"]["sampler"] - system = sampler._thermodynamic_states[0].system + results = vac_setup_units[0].run(dry=True) + system = results["alchem_system"] nonbond = [f for f in system.getForces() if isinstance(f, CustomNonbondedForce)] assert len(nonbond) == 4 @@ -609,187 +677,6 @@ def benzene_solvation_dag(benzene_system, protocol_dry_settings): return protocol.create(stateA=stateA, stateB=stateB, mapping=None) -def test_unit_tagging(benzene_solvation_dag, tmpdir): - # test that executing the units includes correct gen and repeat info - - dag_units = benzene_solvation_dag.protocol_units - - with ( - mock.patch( - "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - mock.patch( - "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - ): - results = [] - for u in dag_units: - ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) - results.append(ret) - - solv_repeats = set() - vac_repeats = set() - for ret in results: - assert isinstance(ret, gufe.ProtocolUnitResult) - assert ret.outputs["generation"] == 0 - if ret.outputs["simtype"] == "vacuum": - vac_repeats.add(ret.outputs["repeat_id"]) - else: - solv_repeats.add(ret.outputs["repeat_id"]) - # Repeat ids are random ints so just check their lengths - assert len(vac_repeats) == len(solv_repeats) == 3 - - -def test_gather(benzene_solvation_dag, tmpdir): - # check that .gather behaves as expected - with ( - mock.patch( - "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationSolventUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - mock.patch( - "openfe.protocols.openmm_afe.equil_solvation_afe_method.AbsoluteSolvationVacuumUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chck.nc"}, - ), - ): - dagres = gufe.protocols.execute_DAG( - benzene_solvation_dag, - shared_basedir=tmpdir, - scratch_basedir=tmpdir, - keep_shared=True, - ) - - protocol = AbsoluteSolvationProtocol( - settings=AbsoluteSolvationProtocol.default_settings(), - ) - - res = protocol.gather([dagres]) - - assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult) - - -class TestProtocolResult: - @pytest.fixture() - def protocolresult(self, afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - - pr = openfe.ProtocolResult.from_dict(d["protocol_result"]) - - return pr - - def test_reload_protocol_result(self, afe_solv_transformation_json): - d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) - - pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) - - assert pr - - def test_get_estimate(self, protocolresult): - est = protocolresult.get_estimate() - - assert est - assert est.m == pytest.approx(-2.47, abs=0.5) - assert isinstance(est, offunit.Quantity) - assert est.is_compatible_with(offunit.kilojoule_per_mole) - - def test_get_uncertainty(self, protocolresult): - est = protocolresult.get_uncertainty() - - assert est - assert est.m == pytest.approx(0.2, abs=0.2) - assert isinstance(est, offunit.Quantity) - assert est.is_compatible_with(offunit.kilojoule_per_mole) - - def test_get_individual(self, protocolresult): - inds = protocolresult.get_individual_estimates() - - assert isinstance(inds, dict) - assert isinstance(inds["solvent"], list) - assert isinstance(inds["vacuum"], list) - assert len(inds["solvent"]) == len(inds["vacuum"]) == 3 - for e, u in itertools.chain(inds["solvent"], inds["vacuum"]): - assert e.is_compatible_with(offunit.kilojoule_per_mole) - assert u.is_compatible_with(offunit.kilojoule_per_mole) - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_get_forwards_etc(self, key, protocolresult): - far = protocolresult.get_forward_and_reverse_energy_analysis() - - assert isinstance(far, dict) - assert isinstance(far[key], list) - far1 = far[key][0] - assert isinstance(far1, dict) - - for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]: - assert k in far1 - - if k == "fractions": - assert isinstance(far1[k], np.ndarray) - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_get_frwd_reverse_none_return(self, key, protocolresult): - # fetch the first result of type key - data = [i for i in protocolresult.data[key].values()][0][0] - # set the output to None - data.outputs["forward_and_reverse_energies"] = None - - # now fetch the analysis results and expect a warning - wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}" - with pytest.warns(UserWarning, match=wmsg): - protocolresult.get_forward_and_reverse_energy_analysis() - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_get_overlap_matrices(self, key, protocolresult): - ovp = protocolresult.get_overlap_matrices() - - assert isinstance(ovp, dict) - assert isinstance(ovp[key], list) - assert len(ovp[key]) == 3 - - ovp1 = ovp[key][0] - assert isinstance(ovp1["matrix"], np.ndarray) - assert ovp1["matrix"].shape == (14, 14) - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_get_replica_transition_statistics(self, key, protocolresult): - rpx = protocolresult.get_replica_transition_statistics() - - assert isinstance(rpx, dict) - assert isinstance(rpx[key], list) - assert len(rpx[key]) == 3 - rpx1 = rpx[key][0] - assert "eigenvalues" in rpx1 - assert "matrix" in rpx1 - assert rpx1["eigenvalues"].shape == (14,) - assert rpx1["matrix"].shape == (14, 14) - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_equilibration_iterations(self, key, protocolresult): - eq = protocolresult.equilibration_iterations() - - assert isinstance(eq, dict) - assert isinstance(eq[key], list) - assert len(eq[key]) == 3 - assert all(isinstance(v, float) for v in eq[key]) - - @pytest.mark.parametrize("key", ["solvent", "vacuum"]) - def test_production_iterations(self, key, protocolresult): - prod = protocolresult.production_iterations() - - assert isinstance(prod, dict) - assert isinstance(prod[key], list) - assert len(prod[key]) == 3 - assert all(isinstance(v, float) for v in prod[key]) - - def test_filenotfound_replica_states(self, protocolresult): - errmsg = "File could not be found" - - with pytest.raises(ValueError, match=errmsg): - protocolresult.get_replica_states() - - @pytest.mark.parametrize( "positions_write_frequency,velocities_write_frequency", [ @@ -821,7 +708,6 @@ def test_dry_run_vacuum_write_frequency( stateB = ChemicalSystem({"solvent": SolventComponent()}) # Create DAG from protocol, get the vacuum and solvent units - # and eventually dry run the first solvent unit dag = protocol.create( stateA=stateA, stateB=stateB, @@ -829,11 +715,23 @@ def test_dry_run_vacuum_write_frequency( ) prot_units = list(dag.protocol_units) - assert len(prot_units) == 2 + assert len(prot_units) == 6 - with tmpdir.as_cwd(): - for u in prot_units: - sampler = u.run(dry=True)["debug"]["sampler"] + for phase in ["solvent", "vacuum"]: + setup_units = _get_units(prot_units, UNIT_TYPES[phase]["setup"]) + sim_units = _get_units(prot_units, UNIT_TYPES[phase]["sim"]) + + with tmpdir.as_cwd(): + setup_results = setup_units[0].run(dry=True) + sim_results = sim_units[0].run( + system=setup_results["alchem_system"], + positions=setup_results["debug_positions"], + selection_indices=setup_results["selection_indices"], + box_vectors=setup_results["box_vectors"], + alchemical_restraints=False, + dry=True, + ) + sampler = sim_results["sampler"] reporter = sampler._reporter if positions_write_frequency: assert reporter.position_interval == positions_write_frequency.m diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py new file mode 100644 index 00000000..0cb2d2d2 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_protocol_results.py @@ -0,0 +1,278 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import itertools +import json +from pathlib import Path +from unittest import mock + +import gufe +import numpy as np +import pytest +from openff.units import unit as offunit + +import openfe +from openfe import ChemicalSystem, SolventComponent +from openfe.protocols import openmm_afe + +from .utils import UNIT_TYPES, _get_units + + +@pytest.fixture() +def protocol_dry_settings(): + settings = openmm_afe.AbsoluteSolvationProtocol.default_settings() + settings.vacuum_engine_settings.compute_platform = None + settings.solvent_engine_settings.compute_platform = None + settings.protocol_repeats = 1 + return settings + + +@pytest.fixture +def benzene_solvation_dag(benzene_system, protocol_dry_settings): + protocol_dry_settings.protocol_repeats = 3 + protocol = openmm_afe.AbsoluteSolvationProtocol(settings=protocol_dry_settings) + + stateA = benzene_system + + stateB = ChemicalSystem({"solvent": SolventComponent()}) + + return protocol.create(stateA=stateA, stateB=stateB, mapping=None) + + +@pytest.fixture +def patcher(): + with ( + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFESolventSetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, + "standard_state_correction": 0 * offunit.kilocalorie_per_mole, + "restraint_geometry": None, + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + "box_vectors": [np.zeros(3), np.zeros(3), np.zeros(3)] * offunit.nm, + "standard_state_correction": 0 * offunit.kilocalorie_per_mole, + "restraint_geometry": None, + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.base_afe_units.np.load", + return_value=np.zeros(100), + ), + mock.patch( + "openfe.protocols.openmm_afe.base_afe_units.deserialize", + return_value="foo", + ), + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFESolventSimUnit.run", + return_value={ + "trajectory": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumSimUnit.run", + return_value={ + "trajectory": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFESolventAnalysisUnit.run", + return_value={"foo": "bar"}, + ), + mock.patch( + "openfe.protocols.openmm_afe.ahfe_units.AHFEVacuumAnalysisUnit.run", + return_value={"foo": "bar"}, + ), + ): + yield + + +def test_gather(benzene_solvation_dag, patcher, tmpdir): + # check that .gather behaves as expected + dagres = gufe.protocols.execute_DAG( + benzene_solvation_dag, + shared_basedir=tmpdir, + scratch_basedir=tmpdir, + keep_shared=True, + ) + + protocol = openmm_afe.AbsoluteSolvationProtocol( + settings=openmm_afe.AbsoluteSolvationProtocol.default_settings(), + ) + + res = protocol.gather([dagres]) + + assert isinstance(res, openmm_afe.AbsoluteSolvationProtocolResult) + + +def test_unit_tagging(benzene_solvation_dag, patcher, tmpdir): + # test that executing the units includes correct gen and repeat info + + dag_units = benzene_solvation_dag.protocol_units + + for phase in ["solvent", "vacuum"]: + setup_results = {} + sim_results = {} + analysis_results = {} + + setup_units = _get_units(dag_units, UNIT_TYPES[phase]["setup"]) + sim_units = _get_units(dag_units, UNIT_TYPES[phase]["sim"]) + a_units = _get_units(dag_units, UNIT_TYPES[phase]["analysis"]) + + for u in setup_units: + rid = u.inputs["repeat_id"] + setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir)) + + for u in sim_units: + rid = u.inputs["repeat_id"] + sim_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + ) + + for u in a_units: + rid = u.inputs["repeat_id"] + analysis_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + simulation_results=sim_results[rid], + ) + + for results in [setup_results, sim_results, analysis_results]: + for ret in results.values(): + assert isinstance(ret, gufe.ProtocolUnitResult) + assert ret.outputs["generation"] == 0 + + assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 + + +class TestProtocolResult: + @pytest.fixture() + def protocolresult(self, afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) + + pr = openfe.ProtocolResult.from_dict(d["protocol_result"]) + + return pr + + def test_reload_protocol_result(self, afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) + + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) + + assert pr + + def test_get_estimate(self, protocolresult): + est = protocolresult.get_estimate() + + assert est + assert est.m == pytest.approx(-2.47, abs=0.5) + assert isinstance(est, offunit.Quantity) + assert est.is_compatible_with(offunit.kilojoule_per_mole) + + def test_get_uncertainty(self, protocolresult): + est = protocolresult.get_uncertainty() + + assert est + assert est.m == pytest.approx(0.2, abs=0.2) + assert isinstance(est, offunit.Quantity) + assert est.is_compatible_with(offunit.kilojoule_per_mole) + + def test_get_individual(self, protocolresult): + inds = protocolresult.get_individual_estimates() + + assert isinstance(inds, dict) + assert isinstance(inds["solvent"], list) + assert isinstance(inds["vacuum"], list) + assert len(inds["solvent"]) == len(inds["vacuum"]) == 3 + for e, u in itertools.chain(inds["solvent"], inds["vacuum"]): + assert e.is_compatible_with(offunit.kilojoule_per_mole) + assert u.is_compatible_with(offunit.kilojoule_per_mole) + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_get_forwards_etc(self, key, protocolresult): + far = protocolresult.get_forward_and_reverse_energy_analysis() + + assert isinstance(far, dict) + assert isinstance(far[key], list) + far1 = far[key][0] + assert isinstance(far1, dict) + + for k in ["fractions", "forward_DGs", "forward_dDGs", "reverse_DGs", "reverse_dDGs"]: + assert k in far1 + + if k == "fractions": + assert isinstance(far1[k], np.ndarray) + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_get_frwd_reverse_none_return(self, key, protocolresult): + # fetch the first result of type key + data = [i for i in protocolresult.data[key].values()][0][0] + # set the output to None + data.outputs["forward_and_reverse_energies"] = None + + # now fetch the analysis results and expect a warning + wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}" + with pytest.warns(UserWarning, match=wmsg): + protocolresult.get_forward_and_reverse_energy_analysis() + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_get_overlap_matrices(self, key, protocolresult): + ovp = protocolresult.get_overlap_matrices() + + assert isinstance(ovp, dict) + assert isinstance(ovp[key], list) + assert len(ovp[key]) == 3 + + ovp1 = ovp[key][0] + assert isinstance(ovp1["matrix"], np.ndarray) + assert ovp1["matrix"].shape == (14, 14) + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_get_replica_transition_statistics(self, key, protocolresult): + rpx = protocolresult.get_replica_transition_statistics() + + assert isinstance(rpx, dict) + assert isinstance(rpx[key], list) + assert len(rpx[key]) == 3 + rpx1 = rpx[key][0] + assert "eigenvalues" in rpx1 + assert "matrix" in rpx1 + assert rpx1["eigenvalues"].shape == (14,) + assert rpx1["matrix"].shape == (14, 14) + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_equilibration_iterations(self, key, protocolresult): + eq = protocolresult.equilibration_iterations() + + assert isinstance(eq, dict) + assert isinstance(eq[key], list) + assert len(eq[key]) == 3 + assert all(isinstance(v, float) for v in eq[key]) + + @pytest.mark.parametrize("key", ["solvent", "vacuum"]) + def test_production_iterations(self, key, protocolresult): + prod = protocolresult.production_iterations() + + assert isinstance(prod, dict) + assert isinstance(prod[key], list) + assert len(prod[key]) == 3 + assert all(isinstance(v, float) for v in prod[key]) + + def test_filenotfound_replica_states(self, protocolresult): + errmsg = "File could not be found" + + with pytest.raises(ValueError, match=errmsg): + protocolresult.get_replica_states() diff --git a/openfe/tests/protocols/openmm_ahfe/test_ahfe_settings.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_settings.py similarity index 100% rename from openfe/tests/protocols/openmm_ahfe/test_ahfe_settings.py rename to src/openfe/tests/protocols/openmm_ahfe/test_ahfe_settings.py diff --git a/openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py similarity index 74% rename from openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py rename to src/openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py index 81ba1f4c..353becde 100644 --- a/openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_slow.py @@ -79,16 +79,36 @@ def test_openmm_run_engine( r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() - for pur in r.protocol_unit_results: - unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" - assert unit_shared.exists() - assert pathlib.Path(unit_shared).is_dir() - checkpoint = pur.outputs["last_checkpoint"] - assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc" - assert (unit_shared / checkpoint).exists() - nc = pur.outputs["nc"] - assert nc == unit_shared / f"{pur.outputs['simtype']}.nc" - assert nc.exists() + + # Check outputs of solvent & vacuum results + for phase in ["solvent", "vacuum"]: + purs = [pur for pur in r.protocol_unit_results if pur.outputs["simtype"] == phase] + + # get the path to the simulation unit shared dict + for pur in purs: + if "Simulation" in pur.name: + sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert sim_shared.exists() + assert pathlib.Path(sim_shared).is_dir() + + # check the analysis outputs + for pur in purs: + if "Analysis" not in pur.name: + continue + + unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert unit_shared.exists() + assert pathlib.Path(unit_shared).is_dir() + + # Does the checkpoint file exist? + checkpoint = pur.outputs["checkpoint"] + assert checkpoint == sim_shared / f"{pur.outputs['simtype']}_checkpoint.nc" + assert checkpoint.exists() + + # Does the trajectory file exist? + nc = pur.outputs["trajectory"] + assert nc == sim_shared / f"{pur.outputs['simtype']}.nc" + assert nc.exists() # Test results methods that need files present results = protocol.gather([r]) diff --git a/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py new file mode 100644 index 00000000..8c919443 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_tokenization.py @@ -0,0 +1,159 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +import json + +import gufe +import pytest + +import openfe +from openfe.protocols import openmm_afe +from openfe.protocols.openmm_afe import ( + AHFESolventAnalysisUnit, + AHFESolventSetupUnit, + AHFESolventSimUnit, + AHFEVacuumAnalysisUnit, + AHFEVacuumSetupUnit, + AHFEVacuumSimUnit, +) + +from ..conftest import ModGufeTokenizableTestsMixin + + +@pytest.fixture +def protocol(): + return openmm_afe.AbsoluteSolvationProtocol( + openmm_afe.AbsoluteSolvationProtocol.default_settings() + ) + + +@pytest.fixture +def protocol_units(protocol, benzene_system): + pus = protocol.create( + stateA=benzene_system, + stateB=openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}), + mapping=None, + ) + return list(pus.protocol_units) + + +def _filter_units(pus, classtype): + for pu in pus: + if isinstance(pu, classtype): + return pu + + +@pytest.fixture +def solvent_protocol_setup_unit(protocol_units): + return _filter_units(protocol_units, AHFESolventSetupUnit) + + +@pytest.fixture +def solvent_protocol_sim_unit(protocol_units): + return _filter_units(protocol_units, AHFESolventSimUnit) + + +@pytest.fixture +def solvent_protocol_analysis_unit(protocol_units): + return _filter_units(protocol_units, AHFESolventAnalysisUnit) + + +@pytest.fixture +def vacuum_protocol_setup_unit(protocol_units): + return _filter_units(protocol_units, AHFEVacuumSetupUnit) + + +@pytest.fixture +def vacuum_protocol_sim_unit(protocol_units): + return _filter_units(protocol_units, AHFEVacuumSimUnit) + + +@pytest.fixture +def vacuum_protocol_analysis_unit(protocol_units): + return _filter_units(protocol_units, AHFEVacuumAnalysisUnit) + + +@pytest.fixture +def protocol_result(afe_solv_transformation_json): + d = json.loads(afe_solv_transformation_json, cls=gufe.tokenization.JSON_HANDLER.decoder) + pr = openmm_afe.AbsoluteSolvationProtocolResult.from_dict(d["protocol_result"]) + return pr + + +class TestAbsoluteSolvationProtocol(ModGufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationProtocol + key = None + repr = "AbsoluteSolvationProtocol-" + + @pytest.fixture() + def instance(self, protocol): + return protocol + + +class TestAHFESolventSetupUnit(ModGufeTokenizableTestsMixin): + cls = AHFESolventSetupUnit + repr = "AHFESolventSetupUnit(AHFE Setup: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_setup_unit): + return solvent_protocol_setup_unit + + +class TestAHFESolventSimUnit(ModGufeTokenizableTestsMixin): + cls = AHFESolventSimUnit + repr = "AHFESolventSimUnit(AHFE Simulation: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_sim_unit): + return solvent_protocol_sim_unit + + +class TestAHFESolventAnalysisUnit(ModGufeTokenizableTestsMixin): + cls = AHFESolventAnalysisUnit + repr = "AHFESolventAnalysisUnit(AHFE Analysis: benzene solvent leg" + key = None + + @pytest.fixture() + def instance(self, solvent_protocol_analysis_unit): + return solvent_protocol_analysis_unit + + +class TestAHFEVacuumSetupUnit(ModGufeTokenizableTestsMixin): + cls = AHFEVacuumSetupUnit + repr = "AHFEVacuumSetupUnit(AHFE Setup: benzene vacuum leg" + key = None + + @pytest.fixture() + def instance(self, vacuum_protocol_setup_unit): + return vacuum_protocol_setup_unit + + +class TestAHFEVacuumSimUnit(ModGufeTokenizableTestsMixin): + cls = AHFEVacuumSimUnit + repr = "AHFEVacuumSimUnit(AHFE Simulation: benzene vacuum leg" + key = None + + @pytest.fixture() + def instance(self, vacuum_protocol_sim_unit): + return vacuum_protocol_sim_unit + + +class TestAHFEVacuumAnalysisUnit(ModGufeTokenizableTestsMixin): + cls = AHFEVacuumAnalysisUnit + repr = "AHFEVacuumAnalysisUnit(AHFE Analysis: benzene vacuum leg" + key = None + + @pytest.fixture() + def instance(self, vacuum_protocol_analysis_unit): + return vacuum_protocol_analysis_unit + + +class TestAbsoluteSolvationProtocolResult(ModGufeTokenizableTestsMixin): + cls = openmm_afe.AbsoluteSolvationProtocolResult + key = None + repr = "AbsoluteSolvationProtocolResult-" + + @pytest.fixture() + def instance(self, protocol_result): + return protocol_result diff --git a/openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py similarity index 98% rename from openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py rename to src/openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py index 9e747472..4db8c4ea 100644 --- a/openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py +++ b/src/openfe/tests/protocols/openmm_ahfe/test_ahfe_validation.py @@ -8,15 +8,8 @@ from openfe import ChemicalSystem, SolventComponent from openfe.protocols import openmm_afe from openfe.protocols.openmm_afe import ( AbsoluteSolvationProtocol, - AbsoluteSolvationSolventUnit, - AbsoluteSolvationVacuumUnit, ) from openfe.protocols.openmm_utils import system_validation -from openfe.protocols.openmm_utils.charge_generation import ( - HAS_ESPALOMA_CHARGE, - HAS_NAGL, - HAS_OPENEYE, -) @pytest.fixture() diff --git a/src/openfe/tests/protocols/openmm_ahfe/utils.py b/src/openfe/tests/protocols/openmm_ahfe/utils.py new file mode 100644 index 00000000..39108188 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_ahfe/utils.py @@ -0,0 +1,31 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe +from openfe.protocols.openmm_afe import ( + AbsoluteSolvationProtocol, + AHFESolventAnalysisUnit, + AHFESolventSetupUnit, + AHFESolventSimUnit, + AHFEVacuumAnalysisUnit, + AHFEVacuumSetupUnit, + AHFEVacuumSimUnit, +) + +UNIT_TYPES = { + "solvent": { + "setup": AHFESolventSetupUnit, + "sim": AHFESolventSimUnit, + "analysis": AHFESolventAnalysisUnit, + }, + "vacuum": { + "setup": AHFEVacuumSetupUnit, + "sim": AHFEVacuumSimUnit, + "analysis": AHFEVacuumAnalysisUnit, + }, +} + + +def _get_units(protocol_units, unit_type): + """ + Helper method to extract setup units. + """ + return [pu for pu in protocol_units if isinstance(pu, unit_type)] diff --git a/openfe/tests/protocols/openmm_rfe/__init__.py b/src/openfe/tests/protocols/openmm_md/__init__.py similarity index 100% rename from openfe/tests/protocols/openmm_rfe/__init__.py rename to src/openfe/tests/protocols/openmm_md/__init__.py diff --git a/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py similarity index 99% rename from openfe/tests/protocols/openmm_md/test_plain_md_protocol.py rename to src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py index f7dddfb0..60c7e8c4 100644 --- a/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py +++ b/src/openfe/tests/protocols/openmm_md/test_plain_md_protocol.py @@ -243,8 +243,8 @@ def test_dry_run_espaloma_vacuum_user_charges(benzene_modifications, vac_setting "rdkit", "nagl", marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith("darwin"), - reason="needs NAGL and/or on macos", + not HAS_NAGL or HAS_OPENEYE or sys.platform.startswith("darwin"), + reason="needs NAGL (without oechem) and/or on macos", ), ), pytest.param( diff --git a/openfe/tests/protocols/openmm_md/test_plain_md_slow.py b/src/openfe/tests/protocols/openmm_md/test_plain_md_slow.py similarity index 100% rename from openfe/tests/protocols/openmm_md/test_plain_md_slow.py rename to src/openfe/tests/protocols/openmm_md/test_plain_md_slow.py diff --git a/openfe/tests/protocols/openmm_md/test_plain_md_tokenization.py b/src/openfe/tests/protocols/openmm_md/test_plain_md_tokenization.py similarity index 100% rename from openfe/tests/protocols/openmm_md/test_plain_md_tokenization.py rename to src/openfe/tests/protocols/openmm_md/test_plain_md_tokenization.py diff --git a/openfe/tests/protocols/openmm_septop/__init__.py b/src/openfe/tests/protocols/openmm_rfe/__init__.py similarity index 100% rename from openfe/tests/protocols/openmm_septop/__init__.py rename to src/openfe/tests/protocols/openmm_rfe/__init__.py diff --git a/openfe/tests/protocols/openmm_rfe/helpers.py b/src/openfe/tests/protocols/openmm_rfe/helpers.py similarity index 100% rename from openfe/tests/protocols/openmm_rfe/helpers.py rename to src/openfe/tests/protocols/openmm_rfe/helpers.py diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_factory.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_factory.py similarity index 100% rename from openfe/tests/protocols/openmm_rfe/test_hybrid_factory.py rename to src/openfe/tests/protocols/openmm_rfe/test_hybrid_factory.py diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py similarity index 81% rename from openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py rename to src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py index f5ea92cf..bd7a1f72 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_protocol.py @@ -30,9 +30,10 @@ import openfe from openfe import setup from openfe.protocols import openmm_rfe from openfe.protocols.openmm_rfe._rfe_utils import topologyhelpers -from openfe.protocols.openmm_rfe.equil_rfe_methods import ( - _get_alchemical_charge_difference, - _validate_alchemical_components, +from openfe.protocols.openmm_rfe.hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, ) from openfe.protocols.openmm_utils import omm_compute, system_creation from openfe.protocols.openmm_utils.charge_generation import ( @@ -42,6 +43,13 @@ from openfe.protocols.openmm_utils.charge_generation import ( ) +def _get_units(protocol_units, unit_type): + """ + Helper method to extract setup units + """ + return [pu for pu in protocol_units if isinstance(pu, unit_type)] + + @pytest.fixture() def vac_settings(): settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() @@ -167,6 +175,42 @@ def test_serialize_protocol(): assert protocol == ret +def test_repeat_units(benzene_system, toluene_system, benzene_to_toluene_mapping): + settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + + protocol = openmm_rfe.RelativeHybridTopologyProtocol( + settings=settings, + ) + + dag = protocol.create( + stateA=benzene_system, + stateB=toluene_system, + mapping=benzene_to_toluene_mapping, + ) + + # 9 protocol units, 3 per repeat + pus = list(dag.protocol_units) + assert len(pus) == 9 + + # Aggregate some info for each repeat + setup = _get_units(pus, HybridTopologySetupUnit) + simulation = _get_units(pus, HybridTopologyMultiStateSimulationUnit) + analysis = _get_units(pus, HybridTopologyMultiStateAnalysisUnit) + + # Should be 3 of everything + assert len(setup) == len(simulation) == len(analysis) == 3 + + # Check that the dag chain is correct + for analysis_pu in analysis: + repeat_id = analysis_pu.inputs["repeat_id"] + setup_pu = [s for s in setup if s.inputs["repeat_id"] == repeat_id][0] + sim_pu = [s for s in simulation if s.inputs["repeat_id"] == repeat_id][0] + + assert analysis_pu.inputs["setup_results"] == setup_pu + assert analysis_pu.inputs["simulation_results"] == sim_pu + assert sim_pu.inputs["setup_results"] == setup_pu + + def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_to_toluene_mapping): # if we create two dags each with 3 repeats, they should give 6 repeat_ids # this allows multiple DAGs in flight for one Transformation that don't clash on gather @@ -187,7 +231,6 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t ) repeat_ids = set() - u: openmm_rfe.RelativeHybridTopologyProtocolUnit for u in dag1.protocol_units: repeat_ids.add(u.inputs["repeat_id"]) for u in dag2.protocol_units: @@ -196,23 +239,8 @@ def test_create_independent_repeat_ids(benzene_system, toluene_system, benzene_t assert len(repeat_ids) == 6 -@pytest.mark.parametrize( - "mapping", - [None, [], ["A", "B"]], -) -def test_validate_alchemical_components_wrong_mappings(mapping): - with pytest.raises(ValueError, match="A single LigandAtomMapping"): - _validate_alchemical_components({"stateA": [], "stateB": []}, mapping) - - -def test_validate_alchemical_components_missing_alchem_comp(benzene_to_toluene_mapping): - alchem_comps = {"stateA": [openfe.SolventComponent()], "stateB": []} - with pytest.raises(ValueError, match="Unmapped alchemical component"): - _validate_alchemical_components(alchem_comps, benzene_to_toluene_mapping) - - @pytest.mark.parametrize("method", ["repex", "sams", "independent", "InDePeNdENT"]) -def test_dry_run_default_vacuum( +def test_setup_dry_sim_default_vacuum( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, @@ -233,16 +261,27 @@ def test_dry_run_default_vacuum( stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] + # Manually run the units + setup_results = dag_setup_unit.run(dry=True) + + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) + + sampler = sim_results["sampler"] assert isinstance(sampler, MultiStateSampler) assert not sampler.is_periodic assert sampler._thermodynamic_states[0].barostat is None # Check hybrid OMM and MDTtraj Topologies - htf = sampler._hybrid_factory + htf = setup_results["hybrid_factory"] # 16 atoms: # 11 common atoms, 1 extra hydrogen in benzene, 4 extra in toluene # 12 bonds in benzene + 4 extra toluene bonds @@ -281,9 +320,13 @@ def test_dry_run_default_vacuum( ) -def test_dry_run_gaff_vacuum( +def test_setup_gaff_vacuum( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, tmpdir ): + """ + Simple dry run of the setup unit to make sure that parameterisation + will work with gaff. + """ vac_settings.forcefield_settings.small_molecule_forcefield = "gaff-2.11" protocol = openmm_rfe.RelativeHybridTopologyProtocol( @@ -296,9 +339,10 @@ def test_dry_run_gaff_vacuum( stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, ) - unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + with tmpdir.as_cwd(): - _ = unit.run(dry=True)["debug"]["sampler"] + _ = dag_setup_unit.run(dry=True) @pytest.mark.slow @@ -310,7 +354,7 @@ def test_dry_many_molecules_solvent( tmpdir, ): """ - A basic test flushing "will it work if you pass multiple molecules" + A basic setup test flushing "will it work if you pass multiple molecules" """ protocol = openmm_rfe.RelativeHybridTopologyProtocol( settings=solv_settings, @@ -322,10 +366,10 @@ def test_dry_many_molecules_solvent( stateB=toluene_many_solv_system, mapping=benzene_to_toluene_mapping, ) - unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)["debug"]["sampler"] + _ = dag_setup_unit.run(dry=True) BENZ = """\ @@ -394,7 +438,7 @@ $$$$ """ -def test_dry_core_element_change(vac_settings, tmpdir): +def test_setup_core_element_change(vac_settings, tmpdir): benz = openfe.SmallMoleculeComponent(Chem.MolFromMolBlock(BENZ, removeHs=False)) pyr = openfe.SmallMoleculeComponent(Chem.MolFromMolBlock(PYRIDINE, removeHs=False)) @@ -410,11 +454,11 @@ def test_dry_core_element_change(vac_settings, tmpdir): mapping=mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - system = sampler._hybrid_factory.hybrid_system + results = dag_setup_unit.run(dry=True) + system = results["hybrid_system"] assert system.getNumParticles() == 12 # Average mass between nitrogen and carbon assert system.getParticleMass(1) == 12.0127235 * omm_unit.amu @@ -443,10 +487,21 @@ def test_dry_run_ligand( stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] + # Manually run the units + setup_results = dag_setup_unit.run(dry=True) + + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) + + sampler = sim_results["sampler"] assert isinstance(sampler, MultiStateSampler) assert sampler.is_periodic assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloBarostat) @@ -507,18 +562,18 @@ def tip4p_hybrid_factory( stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = [pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit)][0] shared_temp = tmp_path_factory.mktemp("tip4p_shared") scratch_temp = tmp_path_factory.mktemp("tip4p_scratch") - dag_unit_result = dag_unit.run( + dag_unit_setup_result = dag_setup_unit.run( dry=True, scratch_basepath=scratch_temp, shared_basepath=shared_temp, ) - return dag_unit_result["debug"]["sampler"]._factory + return dag_unit_setup_result["hybrid_factory"] def test_tip4p_particle_count(tip4p_hybrid_factory): @@ -603,7 +658,7 @@ def test_tip4p_check_vsite_parameters(tip4p_hybrid_factory): 0.9 * unit.nanometer, ], ) -def test_dry_run_ligand_system_cutoff( +def test_setup_ligand_system_cutoff( cutoff, benzene_system, toluene_system, benzene_to_toluene_mapping, solv_settings, tmpdir ): """ @@ -615,16 +670,17 @@ def test_dry_run_ligand_system_cutoff( protocol = openmm_rfe.RelativeHybridTopologyProtocol( settings=solv_settings, ) + dag = protocol.create( stateA=benzene_system, stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + + dag_setup_unit = [pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit)][0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - hs = sampler._factory.hybrid_system + hs = dag_setup_unit.run(dry=True)["hybrid_system"] nbfs = [ f @@ -652,8 +708,8 @@ def test_dry_run_ligand_system_cutoff( "rdkit", "nagl", marks=pytest.mark.skipif( - not HAS_NAGL or sys.platform.startswith("darwin"), - reason="needs NAGL and/or on macos", + not HAS_NAGL or HAS_OPENEYE or sys.platform.startswith("darwin"), + reason="needs NAGL (without oechem) and/or on macos", ), ), pytest.param( @@ -664,7 +720,7 @@ def test_dry_run_ligand_system_cutoff( ), ], ) -def test_dry_run_charge_backends( +def test_setup_charge_backends( CN_molecule, tmpdir, method, backend, ref_key, vac_settings, am1bcc_ref_charges ): vac_settings.partial_charge_settings.partial_charge_method = method @@ -688,12 +744,12 @@ def test_dry_run_charge_backends( dag = protocol.create(stateA=systemA, stateB=systemB, mapping=mapping) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = [pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit)][0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory - hybrid_system = htf.hybrid_system + results = dag_setup_unit.run(dry=True) + htf = results["hybrid_factory"] + hybrid_system = results["hybrid_system"] # get the standard nonbonded force nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] @@ -727,8 +783,72 @@ def test_dry_run_charge_backends( np.testing.assert_allclose(c, ref, rtol=1e-4) +def test_setup_same_mol_different_charges(benzene_modifications, vac_settings, tmpdir): + """ + Issue #1120 - make sure we can do an RFE of a system with different + parameters but the same molecule. + """ + protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=vac_settings) + + benzene_offmol = benzene_modifications["benzene"].to_openff() + # Give state A some gasteiger charges + benzene_offmol.assign_partial_charges(partial_charge_method="gasteiger") + stateA_charges = copy.deepcopy(benzene_offmol.partial_charges) + stateA_mol = openfe.SmallMoleculeComponent.from_openff(benzene_offmol) + + # Give state B gasteiger charges scaled by 0.9 + benzene_offmol.partial_charges *= 0.9 + stateB_charges = copy.deepcopy(benzene_offmol.partial_charges) + stateB_mol = openfe.SmallMoleculeComponent.from_openff(benzene_offmol) + + # Create new mapping + mapping = gufe.LigandAtomMapping( + componentA=stateA_mol, + componentB=stateB_mol, + componentA_to_componentB={i: i for i in range(12)}, + ) + + # create DAG from protocol and take first (and only) work unit from within + dag = protocol.create( + stateA=openfe.ChemicalSystem({"l": stateA_mol}), + stateB=openfe.ChemicalSystem({"l": stateB_mol}), + mapping=mapping, + ) + dag_setup_unit = [pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit)][0] + + with tmpdir.as_cwd(): + results = dag_setup_unit.run(dry=True) + htf = results["hybrid_factory"] + hybrid_system = results["hybrid_system"] + + # get the standard nonbonded force + nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] + + # get the particle parameters & offsets + for i in range(hybrid_system.getNumParticles()): + # All particles should be core atoms + assert i in htf._atom_classes["core_atoms"] + + # offsets + offset = ensure_quantity(nonbond[0].getParticleParameterOffset(i)[2], "openff") + + # parameters + c, s, e = nonbond[0].getParticleParameters(i) + c = ensure_quantity(c, "openff") + + # check state A charge + assert pytest.approx(c) == stateA_charges[i] + + # check state B charge + c_diff = stateB_charges[i] - stateA_charges[i] + assert pytest.approx(offset) == c_diff + + # check that the offset value is non-zero + assert abs(offset) > 0 * offset.units + + @pytest.mark.flaky(reruns=3) # bad minimisation can happen -def test_dry_run_user_charges(benzene_modifications, vac_settings, tmpdir): +def test_setup_user_charges(benzene_modifications, vac_settings, tmpdir): """ Create a hybrid system with a set of fictitious user supplied charges and ensure that they are properly passed through to the constructed @@ -782,12 +902,12 @@ def test_dry_run_user_charges(benzene_modifications, vac_settings, tmpdir): stateB=openfe.ChemicalSystem({"l": toluene_smc}), mapping=mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = [pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit)][0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory - hybrid_system = htf.hybrid_system + results = dag_setup_unit.run(dry=True) + htf = results["hybrid_factory"] + hybrid_system = results["hybrid_system"] # get the standard nonbonded force nonbond = [f for f in hybrid_system.getForces() if isinstance(f, NonbondedForce)] @@ -875,15 +995,23 @@ def test_virtual_sites_no_reassign( stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] with tmpdir.as_cwd(): + # Manually run the units + setup_results = dag_setup_unit.run(dry=True) errmsg = "Simulations with virtual sites without velocity" with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True) + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) -def test_dodecahdron_ligand_box( +def test_setup_dodecahdron_ligand_box( benzene_system, toluene_system, benzene_to_toluene_mapping, solv_settings, tmpdir ): """ @@ -898,11 +1026,10 @@ def test_dodecahdron_ligand_box( stateB=toluene_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] - hs = sampler._factory.hybrid_system + hs = dag_setup_unit.run(dry=True)["hybrid_system"] vectors = hs.getDefaultPeriodicBoxVectors() @@ -940,10 +1067,18 @@ def test_dry_run_complex( stateB=toluene_complex_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] + setup_results = dag_setup_unit.run(dry=True) + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) + sampler = sim_results["sampler"] assert isinstance(sampler, MultiStateSampler) assert sampler.is_periodic assert isinstance(sampler._thermodynamic_states[0].barostat, MonteCarloBarostat) @@ -967,247 +1102,7 @@ def test_lambda_schedule(windows): assert len(lambdas.lambda_schedule) == windows -def test_hightimestep( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - vac_settings.forcefield_settings.hydrogen_mass = 1.0 - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - errmsg = "too large for hydrogen mass" - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True) - - -def test_n_replicas_not_n_windows( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - vac_settings, - tmpdir, -): - # For PR #125 we pin such that the number of lambda windows - # equals the numbers of replicas used - TODO: remove limitation - # default lambda windows is 11 - vac_settings.simulation_settings.n_replicas = 13 - - errmsg = "Number of replicas 13 does not equal the number of lambda windows 11" - - with tmpdir.as_cwd(): - with pytest.raises(ValueError, match=errmsg): - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - dag = p.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - dag_unit.run(dry=True) - - -def test_missing_ligand(benzene_system, benzene_to_toluene_mapping): - # state B doesn't have a ligand component - stateB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - - match_str = "missing alchemical components in stateB" - with pytest.raises(ValueError, match=match_str): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_vaccuum_PME_error( - benzene_vacuum_system, benzene_modifications, benzene_to_toluene_mapping -): - # state B doesn't have a solvent component (i.e. its vacuum) - stateB = openfe.ChemicalSystem({"ligand": benzene_modifications["toluene"]}) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "PME cannot be used for vacuum transform" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_vacuum_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_incompatible_solvent(benzene_system, benzene_modifications, benzene_to_toluene_mapping): - # the solvents are different - stateB = openfe.ChemicalSystem( - { - "ligand": benzene_modifications["toluene"], - "solvent": openfe.SolventComponent(positive_ion="K", negative_ion="Cl"), - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - # We don't have a way to map non-ligand components so for now it - # just triggers that it's not a mapped component - errmsg = "missing alchemical components in stateA" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=stateB, - mapping=benzene_to_toluene_mapping, - ) - - -def test_mapping_mismatch_A(benzene_system, toluene_system, benzene_modifications): - # the atom mapping doesn't refer to the ligands in the systems - mapping = setup.LigandAtomMapping( - componentA=benzene_system.components["ligand"], - componentB=benzene_modifications["phenol"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=toluene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_mapping_mismatch_B(benzene_system, toluene_system, benzene_modifications): - mapping = setup.LigandAtomMapping( - componentA=benzene_modifications["phenol"], - componentB=toluene_system.components["ligand"], - componentA_to_componentB=dict(), - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = ( - r"Unmapped alchemical component " - r"SmallMoleculeComponent\(name=benzene\)" - ) - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=mapping, - ) - - -def test_complex_mismatch(benzene_system, toluene_complex_system, benzene_to_toluene_mapping): - # only one complex - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_system, - stateB=toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_too_many_specified_mappings(benzene_system, toluene_system, benzene_to_toluene_mapping): - # mapping dict requires 'ligand' key - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - errmsg = "A single LigandAtomMapping is expected for this Protocol" - with pytest.raises(ValueError, match=errmsg): - _ = p.create( - stateA=benzene_system, - stateB=toluene_system, - mapping=[benzene_to_toluene_mapping, benzene_to_toluene_mapping], - ) - - -def test_protein_mismatch( - benzene_complex_system, toluene_complex_system, benzene_to_toluene_mapping -): - # hack one protein to be labelled differently - prot = toluene_complex_system["protein"] - alt_prot = openfe.ProteinComponent(prot.to_rdkit(), name="Mickey Mouse") - alt_toluene_complex_system = openfe.ChemicalSystem( - { - "ligand": toluene_complex_system["ligand"], - "solvent": toluene_complex_system["solvent"], - "protein": alt_prot, - } - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.raises(ValueError): - _ = p.create( - stateA=benzene_complex_system, - stateB=alt_toluene_complex_system, - mapping=benzene_to_toluene_mapping, - ) - - -def test_element_change_warning(atom_mapping_basic_test_files): - # check a mapping with element change gets rejected early - l1 = atom_mapping_basic_test_files["2-methylnaphthalene"] - l2 = atom_mapping_basic_test_files["2-naftanol"] - - # We use the 'old' lomap defaults because the - # basic test files inputs we use aren't fully aligned - mapper = setup.LomapAtomMapper( - time=20, threed=True, max3d=1000.0, element_change=True, seed="", shift=True - ) - - mapping = next(mapper.suggest_mappings(l1, l2)) - - sys1 = openfe.ChemicalSystem( - {"ligand": l1, "solvent": openfe.SolventComponent()}, - ) - sys2 = openfe.ChemicalSystem( - {"ligand": l2, "solvent": openfe.SolventComponent()}, - ) - - p = openmm_rfe.RelativeHybridTopologyProtocol( - settings=openmm_rfe.RelativeHybridTopologyProtocol.default_settings(), - ) - with pytest.warns(UserWarning, match="Element change"): - _ = p.create( - stateA=sys1, - stateB=sys2, - mapping=mapping, - ) - - -def test_ligand_overlap_warning( +def test_setup_ligand_overlap_warning( benzene_vacuum_system, toluene_vacuum_system, benzene_to_toluene_mapping, vac_settings, tmpdir ): protocol = openmm_rfe.RelativeHybridTopologyProtocol( @@ -1239,9 +1134,11 @@ def test_ligand_overlap_warning( stateB=toluene_vacuum_system, mapping=mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = [ + pu for pu in dag.protocol_units if isinstance(pu, HybridTopologySetupUnit) + ][0] with tmpdir.as_cwd(): - dag_unit.run(dry=True) + dag_setup_unit.run(dry=True) @pytest.fixture @@ -1260,28 +1157,109 @@ def solvent_protocol_dag(benzene_system, toluene_system, benzene_to_toluene_mapp def test_unit_tagging(solvent_protocol_dag, tmpdir): # test that executing the Units includes correct generation and repeat info dag_units = solvent_protocol_dag.protocol_units - with mock.patch( - "openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + with ( + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.np.load", + return_value=np.zeros(100), + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.deserialize", + return_value={ + "item": "foo", + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateSimulationUnit.run", + return_value={ + "nc": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit.run", + return_value={ + "foo": "bar", + }, + ), ): - results = [] - for u in dag_units: - ret = u.execute(context=gufe.Context(tmpdir, tmpdir)) - results.append(ret) - repeats = set() - for ret in results: - assert isinstance(ret, gufe.ProtocolUnitResult) - assert ret.outputs["generation"] == 0 - repeats.add(ret.outputs["repeat_id"]) + setup_results = {} + sim_results = {} + analysis_results = {} + + setup_units = _get_units(dag_units, HybridTopologySetupUnit) + sim_units = _get_units(dag_units, HybridTopologyMultiStateSimulationUnit) + analysis_units = _get_units(dag_units, HybridTopologyMultiStateAnalysisUnit) + + for u in setup_units: + rid = u.inputs["repeat_id"] + setup_results[rid] = u.execute(context=gufe.Context(tmpdir, tmpdir)) + + for u in sim_units: + rid = u.inputs["repeat_id"] + sim_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), setup_results=setup_results[rid] + ) + + for u in analysis_units: + rid = u.inputs["repeat_id"] + analysis_results[rid] = u.execute( + context=gufe.Context(tmpdir, tmpdir), + setup_results=setup_results[rid], + simulation_results=sim_results[rid], + ) + for results in [setup_results, sim_results, analysis_results]: + for ret in results.values(): + assert isinstance(ret, gufe.ProtocolUnitResult) + assert ret.outputs["generation"] == 0 + # repeats are random ints, so check we got 3 individual numbers - assert len(repeats) == 3 + assert len(setup_results) == len(sim_results) == len(analysis_results) == 3 def test_gather(solvent_protocol_dag, tmpdir): # check .gather behaves as expected - with mock.patch( - "openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit.run", - return_value={"nc": "file.nc", "last_checkpoint": "chk.nc"}, + with ( + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit.run", + return_value={ + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.zeros(100), + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.np.load", + return_value=np.zeros(100), + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.deserialize", + return_value={ + "item": "foo", + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateSimulationUnit.run", + return_value={ + "nc": Path("file.nc"), + "checkpoint": Path("chk.chk"), + }, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit.run", + return_value={ + "foo": "bar", + }, + ), ): dagres = gufe.protocols.execute_DAG( solvent_protocol_dag, @@ -1592,13 +1570,13 @@ def tyk2_xml(tmp_path_factory): stateB=openfe.ChemicalSystem({"ligand": lig55}), mapping=mapping, ) - pu = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] tmp = tmp_path_factory.mktemp("xml_reg") - dryrun = pu.run(dry=True, shared_basepath=tmp) + setup_results = dag_setup_unit.run(dry=True, shared_basepath=tmp) - system = dryrun["debug"]["sampler"]._hybrid_factory.hybrid_system + system = setup_results["hybrid_system"] return ET.fromstring(XmlSerializer.serialize(system)) @@ -1749,68 +1727,6 @@ class TestProtocolResult: protocolresult.get_replica_states() -@pytest.mark.parametrize( - "mapping_name,result", - [ - ["benzene_to_toluene_mapping", 0], - ["benzene_to_benzoic_mapping", 1], - ["benzene_to_aniline_mapping", -1], - ["aniline_to_benzene_mapping", 1], - ], -) -def test_get_charge_difference(mapping_name, result, request): - mapping = request.getfixturevalue(mapping_name) - if result != 0: - ion = r"Na\+" if result == -1 else r"Cl\-" - wmsg = ( - f"A charge difference of {result} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion" - ) - with pytest.warns(UserWarning, match=wmsg): - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - else: - val = _get_alchemical_charge_difference(mapping, "pme", True, openfe.SolventComponent()) - assert result == pytest.approx(val) - - -def test_get_charge_difference_no_pme(benzene_to_benzoic_mapping): - errmsg = "Explicit charge correction when not using PME" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "nocutoff", - True, - openfe.SolventComponent(), - ) - - -def test_get_charge_difference_no_corr(benzene_to_benzoic_mapping): - wmsg = ( - "A charge difference of 1 is observed between the end states. " - "No charge correction has been requested" - ) - with pytest.warns(UserWarning, match=wmsg): - _get_alchemical_charge_difference( - benzene_to_benzoic_mapping, - "pme", - False, - openfe.SolventComponent(), - ) - - -def test_greater_than_one_charge_difference_error(aniline_to_benzoic_mapping): - errmsg = "A charge difference of 2" - with pytest.raises(ValueError, match=errmsg): - _get_alchemical_charge_difference( - aniline_to_benzoic_mapping, - "pme", - True, - openfe.SolventComponent(), - ) - - @pytest.fixture(scope="session") def benzene_solvent_openmm_system(benzene_modifications): smc = benzene_modifications["benzene"] @@ -2150,11 +2066,12 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, solv_settings, stateB=stateB_system, mapping=benzene_to_benzoic_mapping, ) - unit = list(dag.protocol_units)[0] + + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory + results = dag_setup_unit.run(dry=True) + htf = results["hybrid_factory"] _assert_total_charge(htf.hybrid_system, htf._atom_classes, 0, 0) assert len(htf._atom_classes["core_atoms"]) == 14 @@ -2175,7 +2092,7 @@ def test_dry_run_alchemwater_solvent(benzene_to_benzoic_mapping, solv_settings, ["benzoic_to_benzene_mapping", 0, 1, False, 11, 1, 3], ], ) -def test_dry_run_complex_alchemwater_totcharge( +def test_setup_complex_alchemwater_totcharge( mapping_name, chgA, chgB, @@ -2219,11 +2136,11 @@ def test_dry_run_complex_alchemwater_totcharge( stateB=stateB_system, mapping=mapping, ) - unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] with tmpdir.as_cwd(): - sampler = unit.run(dry=True)["debug"]["sampler"] - htf = sampler._factory + setup_results = dag_setup_unit.run(dry=True) + htf = setup_results["hybrid_factory"] _assert_total_charge(htf.hybrid_system, htf._atom_classes, chgA, chgB) assert len(htf._atom_classes["core_atoms"]) == core_atoms @@ -2233,8 +2150,11 @@ def test_dry_run_complex_alchemwater_totcharge( def test_structural_analysis_error(tmpdir): with tmpdir.as_cwd(): - ret = openmm_rfe.RelativeHybridTopologyProtocolUnit.structural_analysis( - Path("."), Path(".") + ret = openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit._structural_analysis( + Path("."), + Path("."), + Path("."), + True, ) assert "structural_analysis_error" in ret @@ -2274,10 +2194,21 @@ def test_dry_run_vacuum_write_frequency( stateB=toluene_vacuum_system, mapping=benzene_to_toluene_mapping, ) - dag_unit = list(dag.protocol_units)[0] + dag_setup_unit = _get_units(dag.protocol_units, HybridTopologySetupUnit)[0] + dag_sim_unit = _get_units(dag.protocol_units, HybridTopologyMultiStateSimulationUnit)[0] with tmpdir.as_cwd(): - sampler = dag_unit.run(dry=True)["debug"]["sampler"] + # Manually run the units + setup_results = dag_setup_unit.run(dry=True) + + sim_results = dag_sim_unit.run( + system=setup_results["hybrid_system"], + positions=setup_results["hybrid_positions"], + selection_indices=setup_results["selection_indices"], + dry=True, + ) + + sampler = sim_results["sampler"] reporter = sampler._reporter if positions_write_frequency: assert reporter.position_interval == positions_write_frequency.m @@ -2287,40 +2218,3 @@ def test_dry_run_vacuum_write_frequency( assert reporter.velocity_interval == velocities_write_frequency.m else: assert reporter.velocity_interval == 0 - - -@pytest.mark.parametrize( - "positions_write_frequency,velocities_write_frequency", - [ - [100.1 * unit.picosecond, 100 * unit.picosecond], - [100 * unit.picosecond, 100.1 * unit.picosecond], - ], -) -def test_pos_write_frequency_not_divisible( - benzene_vacuum_system, - toluene_vacuum_system, - benzene_to_toluene_mapping, - positions_write_frequency, - velocities_write_frequency, - tmpdir, - vac_settings, -): - vac_settings.output_settings.positions_write_frequency = positions_write_frequency - vac_settings.output_settings.velocities_write_frequency = velocities_write_frequency - - protocol = openmm_rfe.RelativeHybridTopologyProtocol( - settings=vac_settings, - ) - - # create DAG from protocol and take first (and only) work unit from within - dag = protocol.create( - stateA=benzene_vacuum_system, - stateB=toluene_vacuum_system, - mapping=benzene_to_toluene_mapping, - ) - dag_unit = list(dag.protocol_units)[0] - - with tmpdir.as_cwd(): - errmsg = "The output settings' " - with pytest.raises(ValueError, match=errmsg): - dag_unit.run(dry=True)["debug"]["sampler"] diff --git a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py similarity index 68% rename from openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py rename to src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py index a8875142..555bd10e 100644 --- a/openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_slow.py @@ -6,10 +6,11 @@ import numpy as np import pytest from gufe.protocols import execute_DAG from numpy.testing import assert_allclose -from openff.units import unit +from openff.units import unit as offunit import openfe from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_utils.charge_generation import HAS_NAGL, HAS_OPENEYE @pytest.mark.slow @@ -28,14 +29,14 @@ def test_openmm_run_engine( # these settings are a small self to self sim, that has enough eq that # it doesn't occasionally crash s = openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - s.simulation_settings.equilibration_length = 0.1 * unit.picosecond - s.simulation_settings.production_length = 0.1 * unit.picosecond - s.simulation_settings.time_per_iteration = 20 * unit.femtosecond + s.simulation_settings.equilibration_length = 0.1 * offunit.picosecond + s.simulation_settings.production_length = 0.1 * offunit.picosecond + s.simulation_settings.time_per_iteration = 20 * offunit.femtosecond s.forcefield_settings.nonbonded_method = "nocutoff" s.protocol_repeats = 1 s.engine_settings.compute_platform = platform - s.output_settings.checkpoint_interval = 20 * unit.femtosecond - s.output_settings.positions_write_frequency = 20 * unit.femtosecond + s.output_settings.checkpoint_interval = 20 * offunit.femtosecond + s.output_settings.positions_write_frequency = 20 * offunit.femtosecond p = openmm_rfe.RelativeHybridTopologyProtocol(s) @@ -61,24 +62,37 @@ def test_openmm_run_engine( r = execute_DAG(dag, shared_basedir=cwd, scratch_basedir=cwd, keep_shared=True) assert r.ok() + + # Get the path to the simulation unit shared for pur in r.protocol_unit_results: - unit_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" - assert unit_shared.exists() - assert pathlib.Path(unit_shared).is_dir() + if "Simulation" in pur.name: + sim_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert sim_shared.exists() + assert pathlib.Path(sim_shared).is_dir() + + for pur in r.protocol_unit_results: + if "Analysis" not in pur.name: + continue + + analysis_shared = tmpdir / f"shared_{pur.source_key}_attempt_0" + assert analysis_shared.exists() + assert pathlib.Path(analysis_shared).is_dir() # Check the checkpoint file exists - checkpoint = pur.outputs["last_checkpoint"] - assert checkpoint == "checkpoint.chk" - assert (unit_shared / checkpoint).exists() + checkpoint = pur.outputs["checkpoint"] + assert checkpoint.name == "checkpoint.chk" + assert checkpoint == sim_shared / "checkpoint.chk" + assert checkpoint.exists() # Check the nc simulation file exists # TODO: assert the number of frames - nc = pur.outputs["nc"] - assert nc == unit_shared / "simulation.nc" + nc = pur.outputs["trajectory"] + assert nc.name == "simulation.nc" + assert nc == sim_shared / "simulation.nc" assert nc.exists() # Check structural analysis contents - structural_analysis_file = unit_shared / "structural_analysis.npz" + structural_analysis_file = analysis_shared / "structural_analysis.npz" assert (structural_analysis_file).exists() assert pur.outputs["structural_analysis"] == structural_analysis_file @@ -110,6 +124,11 @@ def test_openmm_run_engine( @pytest.mark.integration # takes ~7 minutes to run @pytest.mark.flaky(reruns=3) +@pytest.mark.skipif(not HAS_NAGL, reason="need NAGL") +@pytest.mark.skipif( + HAS_OPENEYE and HAS_NAGL, + reason="NAGL/openeye incompatibility. See https://github.com/openforcefield/openff-nagl/issues/177", +) def test_run_eg5_sim(eg5_protein, eg5_ligands, eg5_cofactor, tmpdir): # this runs a very short eg5 complex leg # different to previous test: @@ -118,11 +137,14 @@ def test_run_eg5_sim(eg5_protein, eg5_ligands, eg5_cofactor, tmpdir): # - runs in solvated protein # if this passes 99.9% chance of a good time s = openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - s.simulation_settings.equilibration_length = 0.1 * unit.picosecond - s.simulation_settings.production_length = 0.1 * unit.picosecond - s.simulation_settings.time_per_iteration = 20 * unit.femtosecond + s.simulation_settings.equilibration_length = 0.1 * offunit.picosecond + s.simulation_settings.production_length = 0.1 * offunit.picosecond + s.simulation_settings.time_per_iteration = 20 * offunit.femtosecond + s.forcefield_settings.nonbonded_cutoff = 0.8 * offunit.nanometer + s.partial_charge_settings.partial_charge_method = "nagl" + s.partial_charge_settings.nagl_model = "openff-gnn-am1bcc-0.1.0-rc.3.pt" s.protocol_repeats = 1 - s.output_settings.checkpoint_interval = 20 * unit.femtosecond + s.output_settings.checkpoint_interval = 20 * offunit.femtosecond p = openmm_rfe.RelativeHybridTopologyProtocol(s) @@ -158,13 +180,13 @@ def test_run_dodecahedron_sim(benzene_system, toluene_system, benzene_to_toluene Test that we can run a ligand in solvent RFE with a non-cubic box """ settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() - settings.solvation_settings.solvent_padding = 1.5 * unit.nanometer + settings.solvation_settings.solvent_padding = 1.5 * offunit.nanometer settings.solvation_settings.box_shape = "dodecahedron" settings.protocol_repeats = 1 - settings.simulation_settings.equilibration_length = 0.1 * unit.picosecond - settings.simulation_settings.production_length = 0.1 * unit.picosecond - settings.simulation_settings.time_per_iteration = 20 * unit.femtosecond - settings.output_settings.checkpoint_interval = 20 * unit.femtosecond + settings.simulation_settings.equilibration_length = 0.1 * offunit.picosecond + settings.simulation_settings.production_length = 0.1 * offunit.picosecond + settings.simulation_settings.time_per_iteration = 20 * offunit.femtosecond + settings.output_settings.checkpoint_interval = 20 * offunit.femtosecond protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings) dag = protocol.create( diff --git a/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py new file mode 100644 index 00000000..404b03c0 --- /dev/null +++ b/src/openfe/tests/protocols/openmm_rfe/test_hybrid_top_tokenization.py @@ -0,0 +1,171 @@ +# This code is part of OpenFE and is licensed under the MIT license. +# For details, see https://github.com/OpenFreeEnergy/openfe + +import pytest +from gufe.tests.test_tokenization import GufeTokenizableTestsMixin +from openff.units import unit + +from openfe.protocols import openmm_rfe +from openfe.protocols.openmm_rfe.hybridtop_units import ( + HybridTopologyMultiStateAnalysisUnit, + HybridTopologyMultiStateSimulationUnit, + HybridTopologySetupUnit, +) + +""" +todo: +- RelativeHybridTopologyProtocolResult +- RelativeHybridTopologyProtocol +- RelativeHybridTopologyProtocolUnit +""" + + +@pytest.fixture +def rfe_protocol(): + return openmm_rfe.RelativeHybridTopologyProtocol( + openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + ) + + +@pytest.fixture +def rfe_protocol_other_input_units(): + """Identical to rfe_protocol, but with `kcal / mol` as input unit instead of `kilocalorie_per_mole`.""" + new_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() + new_settings.simulation_settings.early_termination_target_error = 0.0 * unit.kilocalorie/unit.mol # fmt: skip + return openmm_rfe.RelativeHybridTopologyProtocol(new_settings) + + +@pytest.fixture +def protocol_units(rfe_protocol, benzene_system, toluene_system, benzene_to_toluene_mapping): + pus = rfe_protocol.create( + stateA=benzene_system, + stateB=toluene_system, + mapping=[benzene_to_toluene_mapping], + ) + return list(pus.protocol_units) + + +@pytest.fixture +def protocol_setup_unit(protocol_units): + for pu in protocol_units: + if isinstance(pu, HybridTopologySetupUnit): + return pu + + +@pytest.fixture +def protocol_simulation_unit(protocol_units): + for pu in protocol_units: + if isinstance(pu, HybridTopologyMultiStateSimulationUnit): + return pu + + +@pytest.fixture +def protocol_analysis_unit(protocol_units): + for pu in protocol_units: + if isinstance(pu, HybridTopologyMultiStateAnalysisUnit): + return pu + + +@pytest.mark.skip +class TestRelativeHybridTopologyProtocolResult(GufeTokenizableTestsMixin): + cls = openmm_rfe.RelativeHybridTopologyProtocolResult + repr = "" + key = "" + + @pytest.fixture() + def instance(self): + pass + + +class TestRelativeHybridTopologyProtocolOtherInputUnits(GufeTokenizableTestsMixin): + cls = openmm_rfe.RelativeHybridTopologyProtocol + key = None + repr = "" -def test_serialize_gz(tmpdir): - filename = pathlib.Path(tmpdir / "file.xml.gz") - expected = "" - - with mock.patch("openmm.XmlSerializer.serialize", return_value=expected): - serialize(object(), filename) - - with gzip.open(filename, "rb") as f: - read_back = f.read().decode() - assert read_back == expected - - def test_serialize_bz2(tmpdir): filename = pathlib.Path(tmpdir / "file.xml.bz2") expected = "" @@ -61,19 +49,6 @@ def test_deserialize_xml(tmpdir): assert result == "DESERIALIZED" -def test_deserialize_gz(tmpdir): - filename = pathlib.Path(tmpdir / "file.xml.gz") - expected_serialized = "gz" - with gzip.open(filename, "wb") as f: - f.write(expected_serialized.encode()) - - with mock.patch("openmm.XmlSerializer.deserialize", return_value="FROM_GZ") as deser: - result = deserialize(filename) - - deser.assert_called_once_with(expected_serialized) - assert result == "FROM_GZ" - - def test_deserialize_bz2(tmpdir): filename = pathlib.Path(tmpdir / "file.xml.bz2") expected_serialized = "bz2" diff --git a/openfe/tests/setup/alchemical_network_planner/__init__.py b/src/openfe/tests/setup/__init__.py similarity index 100% rename from openfe/tests/setup/alchemical_network_planner/__init__.py rename to src/openfe/tests/setup/__init__.py diff --git a/openfe/tests/setup/atom_mapping/__init__.py b/src/openfe/tests/setup/alchemical_network_planner/__init__.py similarity index 100% rename from openfe/tests/setup/atom_mapping/__init__.py rename to src/openfe/tests/setup/alchemical_network_planner/__init__.py diff --git a/openfe/tests/setup/alchemical_network_planner/edge_types.py b/src/openfe/tests/setup/alchemical_network_planner/edge_types.py similarity index 100% rename from openfe/tests/setup/alchemical_network_planner/edge_types.py rename to src/openfe/tests/setup/alchemical_network_planner/edge_types.py diff --git a/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py b/src/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py similarity index 100% rename from openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py rename to src/openfe/tests/setup/alchemical_network_planner/test_relative_alchemical_network_planner.py diff --git a/openfe/tests/setup/chemicalsystem_generator/__init__.py b/src/openfe/tests/setup/atom_mapping/__init__.py similarity index 100% rename from openfe/tests/setup/chemicalsystem_generator/__init__.py rename to src/openfe/tests/setup/atom_mapping/__init__.py diff --git a/openfe/tests/setup/atom_mapping/conftest.py b/src/openfe/tests/setup/atom_mapping/conftest.py similarity index 100% rename from openfe/tests/setup/atom_mapping/conftest.py rename to src/openfe/tests/setup/atom_mapping/conftest.py diff --git a/openfe/tests/setup/atom_mapping/test_atommapper.py b/src/openfe/tests/setup/atom_mapping/test_atommapper.py similarity index 100% rename from openfe/tests/setup/atom_mapping/test_atommapper.py rename to src/openfe/tests/setup/atom_mapping/test_atommapper.py diff --git a/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py b/src/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py similarity index 100% rename from openfe/tests/setup/atom_mapping/test_lomap_atommapper.py rename to src/openfe/tests/setup/atom_mapping/test_lomap_atommapper.py diff --git a/openfe/tests/setup/atom_mapping/test_lomap_scorers.py b/src/openfe/tests/setup/atom_mapping/test_lomap_scorers.py similarity index 100% rename from openfe/tests/setup/atom_mapping/test_lomap_scorers.py rename to src/openfe/tests/setup/atom_mapping/test_lomap_scorers.py diff --git a/openfe/tests/setup/atom_mapping/test_perses_atommapper.py b/src/openfe/tests/setup/atom_mapping/test_perses_atommapper.py similarity index 100% rename from openfe/tests/setup/atom_mapping/test_perses_atommapper.py rename to src/openfe/tests/setup/atom_mapping/test_perses_atommapper.py diff --git a/openfe/tests/setup/atom_mapping/test_perses_scorers.py b/src/openfe/tests/setup/atom_mapping/test_perses_scorers.py similarity index 100% rename from openfe/tests/setup/atom_mapping/test_perses_scorers.py rename to src/openfe/tests/setup/atom_mapping/test_perses_scorers.py diff --git a/openfe/tests/storage/__init__.py b/src/openfe/tests/setup/chemicalsystem_generator/__init__.py similarity index 100% rename from openfe/tests/storage/__init__.py rename to src/openfe/tests/setup/chemicalsystem_generator/__init__.py diff --git a/openfe/tests/setup/chemicalsystem_generator/component_checks.py b/src/openfe/tests/setup/chemicalsystem_generator/component_checks.py similarity index 100% rename from openfe/tests/setup/chemicalsystem_generator/component_checks.py rename to src/openfe/tests/setup/chemicalsystem_generator/component_checks.py diff --git a/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py b/src/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py similarity index 100% rename from openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py rename to src/openfe/tests/setup/chemicalsystem_generator/test_easy_chemicalsystem_generator.py diff --git a/openfe/tests/setup/test_network_planning.py b/src/openfe/tests/setup/test_network_planning.py similarity index 100% rename from openfe/tests/setup/test_network_planning.py rename to src/openfe/tests/setup/test_network_planning.py diff --git a/openfe/tests/utils/__init__.py b/src/openfe/tests/storage/__init__.py similarity index 100% rename from openfe/tests/utils/__init__.py rename to src/openfe/tests/storage/__init__.py diff --git a/openfe/tests/storage/conftest.py b/src/openfe/tests/storage/conftest.py similarity index 100% rename from openfe/tests/storage/conftest.py rename to src/openfe/tests/storage/conftest.py diff --git a/openfe/tests/storage/test_metadatastore.py b/src/openfe/tests/storage/test_metadatastore.py similarity index 100% rename from openfe/tests/storage/test_metadatastore.py rename to src/openfe/tests/storage/test_metadatastore.py diff --git a/openfe/tests/storage/test_resultclient.py b/src/openfe/tests/storage/test_resultclient.py similarity index 100% rename from openfe/tests/storage/test_resultclient.py rename to src/openfe/tests/storage/test_resultclient.py diff --git a/openfe/tests/storage/test_resultserver.py b/src/openfe/tests/storage/test_resultserver.py similarity index 100% rename from openfe/tests/storage/test_resultserver.py rename to src/openfe/tests/storage/test_resultserver.py diff --git a/openfecli/commands/__init__.py b/src/openfe/tests/utils/__init__.py similarity index 100% rename from openfecli/commands/__init__.py rename to src/openfe/tests/utils/__init__.py diff --git a/openfe/tests/utils/conftest.py b/src/openfe/tests/utils/conftest.py similarity index 100% rename from openfe/tests/utils/conftest.py rename to src/openfe/tests/utils/conftest.py diff --git a/openfe/tests/utils/test_atommapping_network_plotting.py b/src/openfe/tests/utils/test_atommapping_network_plotting.py similarity index 100% rename from openfe/tests/utils/test_atommapping_network_plotting.py rename to src/openfe/tests/utils/test_atommapping_network_plotting.py diff --git a/openfe/tests/utils/test_duecredit.py b/src/openfe/tests/utils/test_duecredit.py similarity index 96% rename from openfe/tests/utils/test_duecredit.py rename to src/openfe/tests/utils/test_duecredit.py index 28edeb08..67652ad7 100644 --- a/openfe/tests/utils/test_duecredit.py +++ b/src/openfe/tests/utils/test_duecredit.py @@ -34,7 +34,7 @@ class TestDuecredit: ], ], [ - "openfe.protocols.openmm_rfe.equil_rfe_methods", + "openfe.protocols.openmm_rfe.hybridtop_protocols", ["10.5281/zenodo.1297683", "10.5281/zenodo.596622", "10.1371/journal.pcbi.1005659"], ], [ diff --git a/openfe/tests/utils/test_log_control.py b/src/openfe/tests/utils/test_log_control.py similarity index 100% rename from openfe/tests/utils/test_log_control.py rename to src/openfe/tests/utils/test_log_control.py diff --git a/openfe/tests/utils/test_network_plotting.py b/src/openfe/tests/utils/test_network_plotting.py similarity index 100% rename from openfe/tests/utils/test_network_plotting.py rename to src/openfe/tests/utils/test_network_plotting.py diff --git a/openfe/tests/utils/test_optional_imports.py b/src/openfe/tests/utils/test_optional_imports.py similarity index 100% rename from openfe/tests/utils/test_optional_imports.py rename to src/openfe/tests/utils/test_optional_imports.py diff --git a/openfe/tests/utils/test_remove_oechem.py b/src/openfe/tests/utils/test_remove_oechem.py similarity index 100% rename from openfe/tests/utils/test_remove_oechem.py rename to src/openfe/tests/utils/test_remove_oechem.py diff --git a/openfe/tests/utils/test_system_probe.py b/src/openfe/tests/utils/test_system_probe.py similarity index 99% rename from openfe/tests/utils/test_system_probe.py rename to src/openfe/tests/utils/test_system_probe.py index bfdbfbe0..f5a47fb6 100644 --- a/openfe/tests/utils/test_system_probe.py +++ b/src/openfe/tests/utils/test_system_probe.py @@ -6,9 +6,13 @@ import sys from collections import namedtuple from unittest.mock import Mock, patch -import psutil import pytest -from psutil._common import sdiskusage + +try: + from psutil._ntuples import sdiskusage +except ImportError: + from psutil._common import sdiskusage + from openfe.utils.system_probe import ( _get_disk_usage, diff --git a/openfe/tests/utils/test_visualization_3D.py b/src/openfe/tests/utils/test_visualization_3D.py similarity index 100% rename from openfe/tests/utils/test_visualization_3D.py rename to src/openfe/tests/utils/test_visualization_3D.py diff --git a/openfe/utils/__init__.py b/src/openfe/utils/__init__.py similarity index 100% rename from openfe/utils/__init__.py rename to src/openfe/utils/__init__.py diff --git a/openfe/utils/atommapping_network_plotting.py b/src/openfe/utils/atommapping_network_plotting.py similarity index 93% rename from openfe/utils/atommapping_network_plotting.py rename to src/openfe/utils/atommapping_network_plotting.py index 844dcc99..13d899d6 100644 --- a/openfe/utils/atommapping_network_plotting.py +++ b/src/openfe/utils/atommapping_network_plotting.py @@ -142,7 +142,7 @@ class LigandNode(Node): class AtomMappingNetworkDrawing(GraphDrawing): """ - Class for drawing atom mappings from a provided ligang network. + Class for drawing atom mappings from a provided ligand network. Parameters ---------- @@ -167,6 +167,12 @@ def plot_atommapping_network(network: LigandNetwork): Returns ------- :class:`matplotlib.figure.Figure` : - the matplotlib figure containing the iteractive visualization + the matplotlib figure containing the interactive visualization """ - return AtomMappingNetworkDrawing(network.graph).fig + fig = AtomMappingNetworkDrawing(network.graph).fig + axes = fig.axes + for ax in axes: + ax.set_frame_on(False) # remove the black frame + for t in ax.texts: + t.set_clip_on(False) # do not clip the label in the network plot + return fig diff --git a/openfe/utils/custom_typing.py b/src/openfe/utils/custom_typing.py similarity index 100% rename from openfe/utils/custom_typing.py rename to src/openfe/utils/custom_typing.py diff --git a/openfe/utils/ligand_utils.py b/src/openfe/utils/ligand_utils.py similarity index 100% rename from openfe/utils/ligand_utils.py rename to src/openfe/utils/ligand_utils.py diff --git a/openfe/utils/logging_control.py b/src/openfe/utils/logging_control.py similarity index 100% rename from openfe/utils/logging_control.py rename to src/openfe/utils/logging_control.py diff --git a/openfe/utils/network_plotting.py b/src/openfe/utils/network_plotting.py similarity index 100% rename from openfe/utils/network_plotting.py rename to src/openfe/utils/network_plotting.py diff --git a/openfe/utils/optional_imports.py b/src/openfe/utils/optional_imports.py similarity index 100% rename from openfe/utils/optional_imports.py rename to src/openfe/utils/optional_imports.py diff --git a/openfe/utils/remove_oechem.py b/src/openfe/utils/remove_oechem.py similarity index 100% rename from openfe/utils/remove_oechem.py rename to src/openfe/utils/remove_oechem.py diff --git a/openfe/utils/silence_root_logging.py b/src/openfe/utils/silence_root_logging.py similarity index 100% rename from openfe/utils/silence_root_logging.py rename to src/openfe/utils/silence_root_logging.py diff --git a/openfe/utils/system_probe.py b/src/openfe/utils/system_probe.py similarity index 100% rename from openfe/utils/system_probe.py rename to src/openfe/utils/system_probe.py diff --git a/openfe/utils/visualization_3D.py b/src/openfe/utils/visualization_3D.py similarity index 100% rename from openfe/utils/visualization_3D.py rename to src/openfe/utils/visualization_3D.py diff --git a/openfecli/README.md b/src/openfecli/README.md similarity index 100% rename from openfecli/README.md rename to src/openfecli/README.md diff --git a/openfecli/__init__.py b/src/openfecli/__init__.py similarity index 100% rename from openfecli/__init__.py rename to src/openfecli/__init__.py diff --git a/openfecli/cli.py b/src/openfecli/cli.py similarity index 100% rename from openfecli/cli.py rename to src/openfecli/cli.py diff --git a/openfecli/clicktypes/__init__.py b/src/openfecli/clicktypes/__init__.py similarity index 100% rename from openfecli/clicktypes/__init__.py rename to src/openfecli/clicktypes/__init__.py diff --git a/openfecli/clicktypes/hyphenchoice.py b/src/openfecli/clicktypes/hyphenchoice.py similarity index 100% rename from openfecli/clicktypes/hyphenchoice.py rename to src/openfecli/clicktypes/hyphenchoice.py diff --git a/openfecli/tests/__init__.py b/src/openfecli/commands/__init__.py similarity index 100% rename from openfecli/tests/__init__.py rename to src/openfecli/commands/__init__.py diff --git a/openfecli/commands/atommapping.py b/src/openfecli/commands/atommapping.py similarity index 100% rename from openfecli/commands/atommapping.py rename to src/openfecli/commands/atommapping.py diff --git a/openfecli/commands/fetch.py b/src/openfecli/commands/fetch.py similarity index 100% rename from openfecli/commands/fetch.py rename to src/openfecli/commands/fetch.py diff --git a/openfecli/commands/gather.py b/src/openfecli/commands/gather.py similarity index 91% rename from openfecli/commands/gather.py rename to src/openfecli/commands/gather.py index 730118b1..cbb2b93f 100644 --- a/openfecli/commands/gather.py +++ b/src/openfecli/commands/gather.py @@ -7,7 +7,6 @@ import sys from typing import List, Literal import click -import gufe import pandas as pd from openfecli import OFECommandPlugin @@ -221,9 +220,16 @@ def _get_names(result: dict) -> tuple[str, str]: # TODO: I don't like this [0][0] indexing, but I can't think of a better way currently protocol_data = list(result["protocol_result"]["data"].values())[0][0] - - name_A = protocol_data["inputs"]["ligandmapping"]["componentA"]["molprops"]["ofe-name"] - name_B = protocol_data["inputs"]["ligandmapping"]["componentB"]["molprops"]["ofe-name"] + try: + name_A = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentA"][ + "molprops" + ]["ofe-name"] + name_B = protocol_data["inputs"]["setup_results"]["inputs"]["ligandmapping"]["componentB"][ + "molprops" + ]["ofe-name"] + except KeyError: + name_A = protocol_data["inputs"]["ligandmapping"]["componentA"]["molprops"]["ofe-name"] + name_B = protocol_data["inputs"]["ligandmapping"]["componentB"]["molprops"]["ofe-name"] return str(name_A), str(name_B) @@ -232,9 +238,17 @@ def _get_type(result: dict) -> Literal["vacuum", "solvent", "complex"]: """Determine the simulation type based on the component types.""" protocol_data = list(result["protocol_result"]["data"].values())[0][0] - component_types = [ - x["__module__"] for x in protocol_data["inputs"]["stateA"]["components"].values() - ] + try: + component_types = [ + x["__module__"] + for x in protocol_data["inputs"]["setup_results"]["inputs"]["stateA"][ + "components" + ].values() + ] + except KeyError: + component_types = [ + x["__module__"] for x in protocol_data["inputs"]["stateA"]["components"].values() + ] if "gufe.components.solventcomponent" not in component_types: return "vacuum" elif "gufe.components.proteincomponent" in component_types: @@ -613,9 +627,8 @@ def _collect_result_jsons(results: List[os.PathLike | str]) -> List[pathlib.Path # 1) find all possible jsons json_fns = collect_jsons(results) - # 2) filter only result jsons - result_fns = filter(is_results_json, json_fns) + result_fns = list(filter(is_results_json, json_fns)) return result_fns @@ -643,35 +656,45 @@ def _get_legs_from_result_jsons( legs = defaultdict(lambda: defaultdict(list)) - for result_fn in result_fns: - result_info, result = _load_valid_result_json(result_fn) + with click.progressbar( + result_fns, + label="Loading results:", + fill_char="▇", + empty_char=" ", + bar_template="%(label)s %(bar)s %(info)s files", + length=len(result_fns), + show_percent=False, + show_pos=True, + show_eta=False, + ) as bar: + for result_fn in bar: + result_info, result = _load_valid_result_json(result_fn) - if result_info is None: # this means it couldn't find names and/or simtype - continue - names, simtype = result_info - if report.lower() == "raw": - if result is None: - parsed_raw_data = [(None, None)] + if result_info is None: # this means it couldn't find names and/or simtype + continue + names, simtype = result_info + if report.lower() == "raw": + if result is None: + parsed_raw_data = [(None, None)] + else: + parsed_raw_data = [ + ( + v[0]["outputs"]["unit_estimate"], + v[0]["outputs"]["unit_estimate_error"], + ) + for v in result["protocol_result"]["data"].values() + ] + legs[names][simtype].append(parsed_raw_data) else: - parsed_raw_data = [ - ( - v[0]["outputs"]["unit_estimate"], - v[0]["outputs"]["unit_estimate_error"], - ) - for v in result["protocol_result"]["data"].values() - ] - legs[names][simtype].append(parsed_raw_data) - else: - if result is None: - # we want the dict name/simtype entry to exist for error reporting, even if there's no valid data - dGs = [] - else: - dGs = [ - v[0]["outputs"]["unit_estimate"] - for v in result["protocol_result"]["data"].values() - ] - legs[names][simtype].extend(dGs) - + if result is None: + # we want the dict name/simtype entry to exist for error reporting, even if there's no valid data + dGs = [] + else: + dGs = [ + v[0]["outputs"]["unit_estimate"] + for v in result["protocol_result"]["data"].values() + ] + legs[names][simtype].extend(dGs) return legs diff --git a/openfecli/commands/gather_abfe.py b/src/openfecli/commands/gather_abfe.py similarity index 97% rename from openfecli/commands/gather_abfe.py rename to src/openfecli/commands/gather_abfe.py index a541bfc5..3a49f643 100644 --- a/openfecli/commands/gather_abfe.py +++ b/src/openfecli/commands/gather_abfe.py @@ -33,7 +33,12 @@ def _get_name(result: dict) -> str: """ solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0] - name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"] + try: + name = solvent_data["inputs"]["setup_results"]["inputs"]["alchemical_components"]["stateA"][ + 0 + ]["molprops"]["ofe-name"] + except KeyError: + name = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"] return str(name) diff --git a/openfecli/commands/gather_septop.py b/src/openfecli/commands/gather_septop.py similarity index 100% rename from openfecli/commands/gather_septop.py rename to src/openfecli/commands/gather_septop.py diff --git a/openfecli/commands/generate_partial_charges.py b/src/openfecli/commands/generate_partial_charges.py similarity index 100% rename from openfecli/commands/generate_partial_charges.py rename to src/openfecli/commands/generate_partial_charges.py diff --git a/openfecli/commands/plan_rbfe_network.py b/src/openfecli/commands/plan_rbfe_network.py similarity index 100% rename from openfecli/commands/plan_rbfe_network.py rename to src/openfecli/commands/plan_rbfe_network.py diff --git a/openfecli/commands/plan_rhfe_network.py b/src/openfecli/commands/plan_rhfe_network.py similarity index 100% rename from openfecli/commands/plan_rhfe_network.py rename to src/openfecli/commands/plan_rhfe_network.py diff --git a/openfecli/commands/quickrun.py b/src/openfecli/commands/quickrun.py similarity index 100% rename from openfecli/commands/quickrun.py rename to src/openfecli/commands/quickrun.py diff --git a/openfecli/commands/test.py b/src/openfecli/commands/test.py similarity index 63% rename from openfecli/commands/test.py rename to src/openfecli/commands/test.py index d659ed80..47cd06d9 100644 --- a/openfecli/commands/test.py +++ b/src/openfecli/commands/test.py @@ -4,13 +4,23 @@ import sys import click import pytest +from openfe.data import _downloader +from openfe.data._registry import zenodo_data_registry as api_test_data_registry from openfecli import OFECommandPlugin +from openfecli.data._registry import POOCH_CACHE +from openfecli.data._registry import zenodo_data_registry as cli_test_data_registry from openfecli.utils import write @click.command("test", short_help="Run the OpenFE test suite") @click.option('--long', is_flag=True, default=False, help="Run additional tests (takes much longer)") # fmt: skip -def test(long): +@click.option( + "--download-only", + is_flag=True, + default=False, + help="Download data to the cache if not already present (this is helpful if internet is unreliable). If all data exists in the cache, only the cache location is shown.", +) +def test(long, download_only): """ Run the OpenFE test suite. This first checks that OpenFE is correctly imported, and then runs the main test suite, which should take several @@ -22,6 +32,14 @@ def test(long): terminals, these show as green or yellow. Warnings are not a concern. However, You should not see anything that fails or errors (red). """ + + if download_only: + click.echo(f"Checking for test data in cache location:\n{POOCH_CACHE}") + _downloader.retrieve_registry_data( + cli_test_data_registry + api_test_data_registry, POOCH_CACHE + ) + sys.exit(0) + try: old_env = dict(os.environ) os.environ["OFE_SLOW_TESTS"] = str(long) diff --git a/openfecli/commands/view_ligand_network.py b/src/openfecli/commands/view_ligand_network.py similarity index 100% rename from openfecli/commands/view_ligand_network.py rename to src/openfecli/commands/view_ligand_network.py diff --git a/openfecli/tests/commands/__init__.py b/src/openfecli/data/__init__.py similarity index 100% rename from openfecli/tests/commands/__init__.py rename to src/openfecli/data/__init__.py diff --git a/src/openfecli/data/_registry.py b/src/openfecli/data/_registry.py new file mode 100644 index 00000000..f9b811f7 --- /dev/null +++ b/src/openfecli/data/_registry.py @@ -0,0 +1,38 @@ +"""Registry for all remotely-stored CLI test data.""" + +import pooch + +POOCH_CACHE = pooch.os_cache("openfe") +zenodo_cmet_data = dict( + base_url="doi:10.5281/zenodo.15200083/", + fname="cmet_results.tar.gz", + known_hash="md5:a4ca67a907f744c696b09660dc1eb8ec", +) +zenodo_rbfe_serial_data = dict( + base_url="doi:10.5281/zenodo.15042470/", + fname="rbfe_results_serial_repeats.tar.gz", + known_hash="md5:2355ecc80e03242a4c7fcbf20cb45487", +) +zenodo_rbfe_parallel_data = dict( + base_url="doi:10.5281/zenodo.15042470/", + fname="rbfe_results_parallel_repeats.tar.gz", + known_hash="md5:ff7313e14eb6f2940c6ffd50f2192181", +) +zenodo_abfe_data = dict( + base_url="doi:10.5281/zenodo.17348229/", + fname="abfe_results.zip", + known_hash="md5:547f896e867cce61979d75b7e082f6ba", +) +zenodo_septop_data = dict( + base_url="doi:10.5281/zenodo.17435569/", + fname="septop_results.zip", + known_hash="md5:2cfa18da59a20228f5c75a1de6ec879e", +) + +zenodo_data_registry = [ + zenodo_cmet_data, + zenodo_rbfe_serial_data, + zenodo_rbfe_parallel_data, + zenodo_abfe_data, + zenodo_septop_data, +] diff --git a/openfecli/fetchables.py b/src/openfecli/fetchables.py similarity index 100% rename from openfecli/fetchables.py rename to src/openfecli/fetchables.py diff --git a/openfecli/fetching.py b/src/openfecli/fetching.py similarity index 100% rename from openfecli/fetching.py rename to src/openfecli/fetching.py diff --git a/openfecli/parameters/__init__.py b/src/openfecli/parameters/__init__.py similarity index 100% rename from openfecli/parameters/__init__.py rename to src/openfecli/parameters/__init__.py diff --git a/openfecli/parameters/mapper.py b/src/openfecli/parameters/mapper.py similarity index 100% rename from openfecli/parameters/mapper.py rename to src/openfecli/parameters/mapper.py diff --git a/openfecli/parameters/misc.py b/src/openfecli/parameters/misc.py similarity index 100% rename from openfecli/parameters/misc.py rename to src/openfecli/parameters/misc.py diff --git a/openfecli/parameters/mol.py b/src/openfecli/parameters/mol.py similarity index 100% rename from openfecli/parameters/mol.py rename to src/openfecli/parameters/mol.py diff --git a/openfecli/parameters/molecules.py b/src/openfecli/parameters/molecules.py similarity index 100% rename from openfecli/parameters/molecules.py rename to src/openfecli/parameters/molecules.py diff --git a/openfecli/parameters/output.py b/src/openfecli/parameters/output.py similarity index 100% rename from openfecli/parameters/output.py rename to src/openfecli/parameters/output.py diff --git a/openfecli/parameters/output_dir.py b/src/openfecli/parameters/output_dir.py similarity index 100% rename from openfecli/parameters/output_dir.py rename to src/openfecli/parameters/output_dir.py diff --git a/openfecli/parameters/plan_network_options.py b/src/openfecli/parameters/plan_network_options.py similarity index 99% rename from openfecli/parameters/plan_network_options.py rename to src/openfecli/parameters/plan_network_options.py index 5798d69e..b349d897 100644 --- a/openfecli/parameters/plan_network_options.py +++ b/src/openfecli/parameters/plan_network_options.py @@ -205,7 +205,7 @@ def load_yaml_planner_options(path: Optional[str], context) -> PlanNetworkOption # TODO: do we want this in the docs anywhere? DEFAULT_YAML = """ - mapper: KartografAtomMapper + mapper: kartograf settings: atom_max_distance: 0.95 atom_map_hydrogens: true diff --git a/openfecli/parameters/protein.py b/src/openfecli/parameters/protein.py similarity index 100% rename from openfecli/parameters/protein.py rename to src/openfecli/parameters/protein.py diff --git a/openfecli/parameters/utils.py b/src/openfecli/parameters/utils.py similarity index 100% rename from openfecli/parameters/utils.py rename to src/openfecli/parameters/utils.py diff --git a/openfecli/plan_alchemical_networks_utils.py b/src/openfecli/plan_alchemical_networks_utils.py similarity index 100% rename from openfecli/plan_alchemical_networks_utils.py rename to src/openfecli/plan_alchemical_networks_utils.py diff --git a/openfecli/plugins.py b/src/openfecli/plugins.py similarity index 100% rename from openfecli/plugins.py rename to src/openfecli/plugins.py diff --git a/openfecli/tests/data/__init__.py b/src/openfecli/tests/__init__.py similarity index 100% rename from openfecli/tests/data/__init__.py rename to src/openfecli/tests/__init__.py diff --git a/openfecli/tests/clicktypes/test_hyphenchoice.py b/src/openfecli/tests/clicktypes/test_hyphenchoice.py similarity index 100% rename from openfecli/tests/clicktypes/test_hyphenchoice.py rename to src/openfecli/tests/clicktypes/test_hyphenchoice.py diff --git a/openfecli/tests/data/rbfe_tutorial/__init__.py b/src/openfecli/tests/commands/__init__.py similarity index 100% rename from openfecli/tests/data/rbfe_tutorial/__init__.py rename to src/openfecli/tests/commands/__init__.py diff --git a/openfecli/tests/commands/conftest.py b/src/openfecli/tests/commands/conftest.py similarity index 100% rename from openfecli/tests/commands/conftest.py rename to src/openfecli/tests/commands/conftest.py diff --git a/openfecli/tests/commands/test_atommapping.py b/src/openfecli/tests/commands/test_atommapping.py similarity index 91% rename from openfecli/tests/commands/test_atommapping.py rename to src/openfecli/tests/commands/test_atommapping.py index 48c558e5..8b459de6 100644 --- a/openfecli/tests/commands/test_atommapping.py +++ b/src/openfecli/tests/commands/test_atommapping.py @@ -30,7 +30,7 @@ def mapper_args(): @pytest.fixture -def mols(molA_args, molB_args): +def mols_AB(molA_args, molB_args): return MOL.get(molA_args[1]), MOL.get(molB_args[1]) @@ -89,8 +89,8 @@ def test_atommapping_missing_mapper(molA_args, molB_args): @pytest.mark.parametrize("n_mappings", [0, 1, 2]) -def test_generate_mapping(n_mappings, mols): - molA, molB = mols +def test_generate_mapping(n_mappings, mols_AB): + molA, molB = mols_AB mappings = [ LigandAtomMapping(molA, molB, {i: i for i in range(7)}), LigandAtomMapping(molA, molB, {i: (i + 1) % 7 for i in range(7)}), @@ -104,8 +104,8 @@ def test_generate_mapping(n_mappings, mols): generate_mapping(mapper, molA, molB) -def test_atommapping_print_dict_main(capsys, mols): - molA, molB = mols +def test_atommapping_print_dict_main(capsys, mols_AB): + molA, molB = mols_AB mapper = LomapAtomMapper mapping = LigandAtomMapping(molA, molB, {i: i for i in range(7)}) with mock.patch("openfecli.commands.atommapping.generate_mapping", mock.Mock(return_value=mapping)): # fmt: skip @@ -114,14 +114,14 @@ def test_atommapping_print_dict_main(capsys, mols): assert captured.out == str(mapping.componentA_to_componentB) + "\n" -def test_atommapping_visualize_main(mols, tmpdir): - molA, molB = mols +def test_atommapping_visualize_main(mols_AB, tmpdir): + molA, molB = mols_AB mapper = LomapAtomMapper pytest.skip() # TODO: probably with a smoke test -def test_atommapping_visualize_main_bad_extension(mols, tmpdir): - molA, molB = mols +def test_atommapping_visualize_main_bad_extension(mols_AB, tmpdir): + molA, molB = mols_AB mapper = LomapAtomMapper mapping = LigandAtomMapping(molA, molB, {i: i for i in range(7)}) with mock.patch("openfecli.commands.atommapping.generate_mapping", mock.Mock(return_value=mapping)): # fmt: skip diff --git a/openfecli/tests/commands/test_charge_generation.py b/src/openfecli/tests/commands/test_charge_generation.py similarity index 94% rename from openfecli/tests/commands/test_charge_generation.py rename to src/openfecli/tests/commands/test_charge_generation.py index db843c30..ae481cf4 100644 --- a/openfecli/tests/commands/test_charge_generation.py +++ b/src/openfecli/tests/commands/test_charge_generation.py @@ -9,6 +9,10 @@ from openff.toolkit import Molecule from openff.units import unit from openff.utilities.testing import skip_if_missing +from openfe.protocols.openmm_utils.charge_generation import ( + HAS_NAGL, + HAS_OPENEYE, +) from openfecli.commands.generate_partial_charges import charge_molecules @@ -122,8 +126,13 @@ def test_charge_molecules_overwrite( pytest.param(2, id="2"), ], ) -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_charge_settings(methane, tmpdir, caplog, yaml_nagl_settings, ncores): runner = CliRunner() mol_path = tmpdir / "methane.sdf" diff --git a/openfecli/tests/commands/test_gather.py b/src/openfecli/tests/commands/test_gather.py similarity index 89% rename from openfecli/tests/commands/test_gather.py rename to src/openfecli/tests/commands/test_gather.py index 498f4a2e..9584482a 100644 --- a/openfecli/tests/commands/test_gather.py +++ b/src/openfecli/tests/commands/test_gather.py @@ -17,25 +17,34 @@ from openfecli.commands.gather import ( ) from openfecli.commands.gather_abfe import gather_abfe from openfecli.commands.gather_septop import gather_septop +from openfecli.data._registry import ( + POOCH_CACHE, + zenodo_abfe_data, + zenodo_cmet_data, + zenodo_rbfe_parallel_data, + zenodo_rbfe_serial_data, + zenodo_septop_data, +) from ..conftest import HAS_INTERNET from ..utils import assert_click_success -POOCH_CACHE = pooch.os_cache("openfe") -ZENODO_RBFE_DATA = pooch.create( +pooch_rbfe_serial = pooch.create( path=POOCH_CACHE, - base_url="doi:10.5281/zenodo.15042470", - registry={ - "rbfe_results_serial_repeats.tar.gz": "md5:2355ecc80e03242a4c7fcbf20cb45487", - "rbfe_results_parallel_repeats.tar.gz": "md5:ff7313e14eb6f2940c6ffd50f2192181", - }, - retry_if_failed=5, + base_url=zenodo_rbfe_serial_data["base_url"], + registry={zenodo_rbfe_serial_data["fname"]: zenodo_rbfe_serial_data["known_hash"]}, ) -ZENODO_CMET_DATA = pooch.create( + +pooch_rbfe_parallel = pooch.create( path=POOCH_CACHE, - base_url="doi:10.5281/zenodo.15200083", - registry={"cmet_results.tar.gz": "md5:a4ca67a907f744c696b09660dc1eb8ec"}, - retry_if_failed=5, + base_url=zenodo_rbfe_parallel_data["base_url"], + registry={zenodo_rbfe_parallel_data["fname"]: zenodo_rbfe_parallel_data["known_hash"]}, +) + +pooch_cmet = pooch.create( + path=POOCH_CACHE, + base_url=zenodo_cmet_data["base_url"], + registry={zenodo_cmet_data["fname"]: zenodo_cmet_data["known_hash"]}, ) @@ -140,6 +149,7 @@ def test_no_results_found(): _RBFE_EXPECTED_DG = b""" +Loading results: ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_ejm_31 -0.09 0.05 lig_ejm_42 0.7 0.1 @@ -154,6 +164,7 @@ lig_jmc_28 -1.25 0.08 """ _RBFE_EXPECTED_DDG = b""" +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_ejm_31 lig_ejm_42 0.8 0.1 lig_ejm_31 lig_ejm_46 -0.89 0.06 @@ -167,6 +178,7 @@ lig_ejm_46 lig_jmc_28 -0.27 0.06 """ _RBFE_EXPECTED_RAW = b"""\ +Loading results: leg ligand_i ligand_j DG(i->j) (kcal/mol) MBAR uncertainty (kcal/mol) complex lig_ejm_31 lig_ejm_42 -14.9 0.8 complex lig_ejm_31 lig_ejm_42 -14.8 0.8 @@ -225,19 +237,9 @@ solvent lig_ejm_46 lig_jmc_28 23.4 0.8 """ -@pytest.fixture -def rbfe_result_dir() -> pathlib.Path: - def _rbfe_result_dir(dataset) -> str: - ZENODO_RBFE_DATA.fetch(f"{dataset}.tar.gz", processor=pooch.Untar()) - cache_dir = pathlib.Path(POOCH_CACHE) / f"{dataset}.tar.gz.untar/{dataset}/" - return cache_dir - - return _rbfe_result_dir - - @pytest.fixture def cmet_result_dir() -> pathlib.Path: - ZENODO_CMET_DATA.fetch("cmet_results.tar.gz", processor=pooch.Untar()) + pooch_cmet.fetch("cmet_results.tar.gz", processor=pooch.Untar()) result_dir = pathlib.Path(POOCH_CACHE) / "cmet_results.tar.gz.untar/cmet_results/" return result_dir @@ -346,14 +348,34 @@ class TestGatherCMET: assert pathlib.Path(fname).is_file() +@pytest.fixture +def rbfe_results_serial_dir() -> pathlib.Path: + pooch_rbfe_serial.fetch("rbfe_results_serial_repeats.tar.gz", processor=pooch.Untar()) + result_dir = ( + pathlib.Path(POOCH_CACHE) + / "rbfe_results_serial_repeats.tar.gz.untar/rbfe_results_serial_repeats/" + ) + return result_dir + + +@pytest.fixture +def rbfe_results_parallel_dir() -> pathlib.Path: + pooch_rbfe_parallel.fetch("rbfe_results_parallel_repeats.tar.gz", processor=pooch.Untar()) + result_dir = ( + pathlib.Path(POOCH_CACHE) + / "rbfe_results_parallel_repeats.tar.gz.untar/rbfe_results_parallel_repeats/" + ) + return result_dir + + @pytest.mark.skipif( not os.path.exists(POOCH_CACHE) and not HAS_INTERNET, reason="Internet seems to be unavailable and test data is not cached locally.", ) -@pytest.mark.parametrize("dataset", ["rbfe_results_serial_repeats", "rbfe_results_parallel_repeats"]) # fmt: skip +@pytest.mark.parametrize("dataset", ["rbfe_results_serial_dir", "rbfe_results_parallel_dir"]) # fmt: skip @pytest.mark.parametrize("report", ["", "dg", "ddg", "raw"]) @pytest.mark.parametrize("input_mode", ["directory", "filepaths"]) -def test_rbfe_gather(rbfe_result_dir, dataset, report, input_mode): +def test_rbfe_gather(request, dataset, report, input_mode): expected = { "": _RBFE_EXPECTED_DG, "dg": _RBFE_EXPECTED_DG, @@ -367,7 +389,7 @@ def test_rbfe_gather(rbfe_result_dir, dataset, report, input_mode): else: args = [] - results = rbfe_result_dir(dataset) + results = request.getfixturevalue(dataset) if input_mode == "directory": results = [str(results)] elif input_mode == "filepaths": @@ -382,11 +404,11 @@ def test_rbfe_gather(rbfe_result_dir, dataset, report, input_mode): assert set(expected.split(b"\n")) == actual_lines -def test_rbfe_gather_single_repeats_dg_error(rbfe_result_dir): +def test_rbfe_gather_single_repeats_dg_error(rbfe_results_parallel_dir): """A single repeat is insufficient for a dg calculation - should fail cleanly.""" runner = CliRunner() - results = rbfe_result_dir("rbfe_results_parallel_repeats") + results = rbfe_results_parallel_dir args = ["report", "dg"] cli_result = runner.invoke(gather, [f"{results}/replicate_0"] + args + ["--tsv"]) assert cli_result.exit_code == 1 @@ -398,9 +420,9 @@ def test_rbfe_gather_single_repeats_dg_error(rbfe_result_dir): ) class TestRBFEGatherFailedEdges: @pytest.fixture() - def results_paths_serial_missing_legs(self, rbfe_result_dir) -> str: + def results_paths_serial_missing_legs(self, rbfe_results_serial_dir) -> str: """Example output data, with replicates run in serial and two missing results JSONs.""" - result_dir = rbfe_result_dir("rbfe_results_serial_repeats") + result_dir = rbfe_results_serial_dir results = glob.glob(f"{result_dir}/*", recursive=True) files_to_skip = [ @@ -441,14 +463,13 @@ class TestRBFEGatherFailedEdges: ZENODO_ABFE_DATA = pooch.create( path=POOCH_CACHE, - base_url="doi:10.5281/zenodo.17348229", - registry={"abfe_results.zip": "md5:547f896e867cce61979d75b7e082f6ba"}, + base_url=zenodo_abfe_data["base_url"], + registry={zenodo_abfe_data["fname"]: zenodo_abfe_data["known_hash"]}, ) ZENODO_SEPTOP_DATA = pooch.create( path=POOCH_CACHE, - base_url="doi:10.5281/zenodo.17435569", - registry={"septop_results.zip": "md5:2cfa18da59a20228f5c75a1de6ec879e"}, - retry_if_failed=2, + base_url=zenodo_septop_data["base_url"], + registry={zenodo_septop_data["fname"]: zenodo_septop_data["known_hash"]}, ) diff --git a/openfecli/tests/commands/test_gather/test_abfe_full_results_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_abfe_full_results_dg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_abfe_full_results_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_abfe_full_results_dg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_abfe_full_results_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_abfe_full_results_raw_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_abfe_full_results_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_abfe_full_results_raw_.tsv diff --git a/openfecli/tests/commands/test_gather/test_abfe_single_repeat_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_abfe_single_repeat_dg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_abfe_single_repeat_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_abfe_single_repeat_dg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_abfe_single_repeat_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_abfe_single_repeat_raw_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_abfe_single_repeat_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_abfe_single_repeat_raw_.tsv diff --git a/openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv similarity index 87% rename from openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv index ba72b148..8aa5dd8a 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_ddg_.tsv @@ -1,2 +1,3 @@ +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 1.1 0.1 diff --git a/openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv similarity index 96% rename from openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv index 209b47fa..4d7f5bf7 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_failed_edge_raw_.tsv @@ -1,3 +1,4 @@ +Loading results: leg ligand_i ligand_j DG(i->j) (kcal/mol) MBAR uncertainty (kcal/mol) complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.54 0.06 complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.31 0.06 diff --git a/openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv similarity index 94% rename from openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv index 5bf680b0..a9bbcca5 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_ddg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 1.2 0.1 lig_CHEMBL3402745_200_5 lig_CHEMBL3402749_500_9 3.6 0.2 diff --git a/openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv similarity index 92% rename from openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv index 9fffe426..4fc2ae0b 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_dg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 -0.3 0.1 lig_CHEMBL3402744_300_4 0.9 0.2 diff --git a/openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv similarity index 98% rename from openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv index 5ca82584..1dd30fa3 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_full_results_raw_.tsv @@ -1,3 +1,4 @@ +Loading results: leg ligand_i ligand_j DG(i->j) (kcal/mol) MBAR uncertainty (kcal/mol) complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.54 0.06 complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.31 0.06 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv similarity index 94% rename from openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv index d4287452..0823aad6 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_allow_partial_ddg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 Error Error lig_CHEMBL3402745_200_5 lig_CHEMBL3402749_500_9 Error Error diff --git a/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_ddg_.tsv new file mode 100644 index 00000000..989bf55d --- /dev/null +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_ddg_.tsv @@ -0,0 +1 @@ +Loading results: diff --git a/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_dg_.tsv new file mode 100644 index 00000000..989bf55d --- /dev/null +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_dg_.tsv @@ -0,0 +1 @@ +Loading results: diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv similarity index 94% rename from openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv index c16311fa..c314f562 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_ddg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 1.2 0.1 lig_CHEMBL3402745_200_5 lig_CHEMBL3402749_500_9 3.6 0.2 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv similarity index 92% rename from openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv index 9f6c94f3..4c4ce05c 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_dg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 -0.3 0.1 lig_CHEMBL3402744_300_4 0.8 0.2 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv similarity index 98% rename from openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv index a203a4e4..999936b2 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_complex_leg_raw_.tsv @@ -1,3 +1,4 @@ +Loading results: leg ligand_i ligand_j DG(i->j) (kcal/mol) MBAR uncertainty (kcal/mol) complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.54 0.06 complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.31 0.06 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv similarity index 93% rename from openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv index 56dac09e..7efe75da 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_ddg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand_i ligand_j DDG(i->j) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 1.2 0.1 lig_CHEMBL3402745_200_5 lig_CHEMBL3402749_500_9 3.6 0.2 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv similarity index 91% rename from openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv index 6fe7ea24..7573757a 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_dg_.tsv @@ -1,3 +1,4 @@ +Loading results: ligand DG(MLE) (kcal/mol) uncertainty (kcal/mol) lig_CHEMBL3402745_200_5 -0.90 0.08 lig_CHEMBL3402744_300_4 0.3 0.1 diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv similarity index 98% rename from openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv index 451fa9bf..339578aa 100644 --- a/openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv +++ b/src/openfecli/tests/commands/test_gather/test_cmet_missing_edge_raw_.tsv @@ -1,3 +1,4 @@ +Loading results: leg ligand_i ligand_j DG(i->j) (kcal/mol) MBAR uncertainty (kcal/mol) complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.54 0.06 complex lig_CHEMBL3402745_200_5 lig_CHEMBL3402744_300_4 -12.31 0.06 diff --git a/openfecli/tests/commands/test_gather/test_septop_full_results_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_full_results_ddg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_full_results_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_full_results_ddg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_septop_full_results_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_full_results_dg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_full_results_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_full_results_dg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_septop_full_results_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_full_results_raw_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_full_results_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_full_results_raw_.tsv diff --git a/openfecli/tests/commands/test_gather/test_septop_single_repeat_ddg_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_single_repeat_ddg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_single_repeat_ddg_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_single_repeat_ddg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_septop_single_repeat_dg_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_single_repeat_dg_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_single_repeat_dg_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_single_repeat_dg_.tsv diff --git a/openfecli/tests/commands/test_gather/test_septop_single_repeat_raw_.tsv b/src/openfecli/tests/commands/test_gather/test_septop_single_repeat_raw_.tsv similarity index 100% rename from openfecli/tests/commands/test_gather/test_septop_single_repeat_raw_.tsv rename to src/openfecli/tests/commands/test_gather/test_septop_single_repeat_raw_.tsv diff --git a/openfecli/tests/commands/test_ligand_network_viewer.py b/src/openfecli/tests/commands/test_ligand_network_viewer.py similarity index 100% rename from openfecli/tests/commands/test_ligand_network_viewer.py rename to src/openfecli/tests/commands/test_ligand_network_viewer.py diff --git a/openfecli/tests/commands/test_plan_rbfe_network.py b/src/openfecli/tests/commands/test_plan_rbfe_network.py similarity index 92% rename from openfecli/tests/commands/test_plan_rbfe_network.py rename to src/openfecli/tests/commands/test_plan_rbfe_network.py index dda23dac..ed380512 100644 --- a/openfecli/tests/commands/test_plan_rbfe_network.py +++ b/src/openfecli/tests/commands/test_plan_rbfe_network.py @@ -10,7 +10,10 @@ from gufe import AlchemicalNetwork, SmallMoleculeComponent from openff.units import unit from openff.utilities import skip_if_missing -from openfe.protocols.openmm_utils.charge_generation import HAS_OPENEYE +from openfe.protocols.openmm_utils.charge_generation import ( + HAS_NAGL, + HAS_OPENEYE, +) from openfecli.commands.plan_rbfe_network import ( plan_rbfe_network, plan_rbfe_network_main, @@ -70,8 +73,13 @@ def validate_charges(smc): assert len(off_mol.partial_charges) == off_mol.n_atoms -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rbfe_network_main(): from gufe import ( ProteinComponent, @@ -137,8 +145,13 @@ partial_charge: """ -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rbfe_network(mol_dir_args, protein_args, tmpdir, yaml_nagl_settings): """ smoke test @@ -218,8 +231,13 @@ def test_plan_rbfe_network_n_repeats(mol_dir_args, protein_args, input_n_repeat, pytest.param(False, id="No overwrite"), ], ) -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rbfe_network_charge_overwrite(dummy_charge_dir_args, protein_args, tmpdir, yaml_nagl_settings, overwrite): # fmt: skip # make sure the dummy charges are overwritten when requested @@ -268,9 +286,13 @@ def eg5_files(): yield pdb_path, lig_path, cof_path -@pytest.mark.xfail(HAS_OPENEYE, reason="openff-nagl#177") -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rbfe_network_cofactors(eg5_files, tmpdir, yaml_nagl_settings): # use nagl charges for CI speed! settings_path = tmpdir / "settings.yaml" @@ -372,8 +394,13 @@ partial_charge: """ -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_lomap_yaml_plan_rbfe_smoke_test(lomap_yaml_settings, cdk8_files, tmpdir): protein, ligand = cdk8_files settings_path = tmpdir / "settings.yaml" @@ -413,9 +440,13 @@ partial_charge: """ -@pytest.mark.xfail(HAS_OPENEYE, reason="openff-nagl#177") -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_custom_yaml_plan_radial_smoke_test(custom_yaml_radial, eg5_files, tmpdir): protein, ligand, cofactor = eg5_files settings_path = tmpdir / "settings.yaml" diff --git a/openfecli/tests/commands/test_plan_rhfe_network.py b/src/openfecli/tests/commands/test_plan_rhfe_network.py similarity index 90% rename from openfecli/tests/commands/test_plan_rhfe_network.py rename to src/openfecli/tests/commands/test_plan_rhfe_network.py index e7de9262..a3780d79 100644 --- a/openfecli/tests/commands/test_plan_rhfe_network.py +++ b/src/openfecli/tests/commands/test_plan_rhfe_network.py @@ -10,6 +10,10 @@ from gufe import AlchemicalNetwork, SmallMoleculeComponent, SolventComponent from gufe.tokenization import JSON_HANDLER from openff.utilities.testing import skip_if_missing +from openfe.protocols.openmm_utils.charge_generation import ( + HAS_NAGL, + HAS_OPENEYE, +) from openfecli.commands.plan_rhfe_network import ( plan_rhfe_network, plan_rhfe_network_main, @@ -54,8 +58,13 @@ def validate_charges(smc): assert len(off_mol.partial_charges) == off_mol.n_atoms -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rhfe_network_main(): from openfe.protocols.openmm_utils.omm_settings import OpenFFPartialChargeSettings from openfe.setup import ( @@ -102,8 +111,13 @@ partial_charge: """ -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rhfe_network(mol_dir_args, tmpdir, yaml_nagl_settings): """ smoke test @@ -175,8 +189,13 @@ partial_charge: """ -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_custom_yaml_plan_rhfe_smoke_test(custom_yaml_settings, mol_dir_args, tmpdir): settings_path = tmpdir / "settings.yaml" with open(settings_path, "w") as f: @@ -201,8 +220,13 @@ def test_custom_yaml_plan_rhfe_smoke_test(custom_yaml_settings, mol_dir_args, tm pytest.param(False, id="No overwrite"), ], ) -@skip_if_missing("openff.nagl") -@skip_if_missing("openff.nagl_models") +@pytest.mark.skipif( + not HAS_NAGL, + reason="needs NAGL", +) +@pytest.mark.skipif( + HAS_OPENEYE, reason="cannot use NAGL with rdkit backend when OpenEye is installed" +) def test_plan_rhfe_network_charge_overwrite(dummy_charge_dir_args, tmpdir, yaml_nagl_settings, overwrite): # fmt: skip # make sure the dummy charges are overwritten when requested diff --git a/openfecli/tests/commands/test_quickrun.py b/src/openfecli/tests/commands/test_quickrun.py similarity index 100% rename from openfecli/tests/commands/test_quickrun.py rename to src/openfecli/tests/commands/test_quickrun.py diff --git a/openfecli/tests/commands/test_test.py b/src/openfecli/tests/commands/test_test.py similarity index 100% rename from openfecli/tests/commands/test_test.py rename to src/openfecli/tests/commands/test_test.py diff --git a/openfecli/tests/conftest.py b/src/openfecli/tests/conftest.py similarity index 100% rename from openfecli/tests/conftest.py rename to src/openfecli/tests/conftest.py diff --git a/openfecli/tests/dev/__init__.py b/src/openfecli/tests/data/__init__.py similarity index 100% rename from openfecli/tests/dev/__init__.py rename to src/openfecli/tests/data/__init__.py diff --git a/openfecli/tests/data/bad_transformation.json b/src/openfecli/tests/data/bad_transformation.json similarity index 100% rename from openfecli/tests/data/bad_transformation.json rename to src/openfecli/tests/data/bad_transformation.json diff --git a/openfecli/tests/data/rbfe_results.tar.gz b/src/openfecli/tests/data/rbfe_results.tar.gz similarity index 100% rename from openfecli/tests/data/rbfe_results.tar.gz rename to src/openfecli/tests/data/rbfe_results.tar.gz diff --git a/openfecli/tests/parameters/__init__.py b/src/openfecli/tests/data/rbfe_tutorial/__init__.py similarity index 100% rename from openfecli/tests/parameters/__init__.py rename to src/openfecli/tests/data/rbfe_tutorial/__init__.py diff --git a/openfecli/tests/data/rbfe_tutorial/tyk2_ligands.sdf b/src/openfecli/tests/data/rbfe_tutorial/tyk2_ligands.sdf similarity index 100% rename from openfecli/tests/data/rbfe_tutorial/tyk2_ligands.sdf rename to src/openfecli/tests/data/rbfe_tutorial/tyk2_ligands.sdf diff --git a/openfecli/tests/data/rbfe_tutorial/tyk2_protein.pdb b/src/openfecli/tests/data/rbfe_tutorial/tyk2_protein.pdb similarity index 100% rename from openfecli/tests/data/rbfe_tutorial/tyk2_protein.pdb rename to src/openfecli/tests/data/rbfe_tutorial/tyk2_protein.pdb diff --git a/openfecli/tests/data/transformation.json b/src/openfecli/tests/data/transformation.json similarity index 100% rename from openfecli/tests/data/transformation.json rename to src/openfecli/tests/data/transformation.json diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_ddg_.tsv b/src/openfecli/tests/dev/__init__.py similarity index 100% rename from openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_ddg_.tsv rename to src/openfecli/tests/dev/__init__.py diff --git a/openfecli/tests/dev/write_transformation_json.py b/src/openfecli/tests/dev/write_transformation_json.py similarity index 100% rename from openfecli/tests/dev/write_transformation_json.py rename to src/openfecli/tests/dev/write_transformation_json.py diff --git a/openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_dg_.tsv b/src/openfecli/tests/parameters/__init__.py similarity index 100% rename from openfecli/tests/commands/test_gather/test_cmet_missing_all_complex_legs_fail_dg_.tsv rename to src/openfecli/tests/parameters/__init__.py diff --git a/openfecli/tests/parameters/test_mapper.py b/src/openfecli/tests/parameters/test_mapper.py similarity index 100% rename from openfecli/tests/parameters/test_mapper.py rename to src/openfecli/tests/parameters/test_mapper.py diff --git a/openfecli/tests/parameters/test_mol.py b/src/openfecli/tests/parameters/test_mol.py similarity index 100% rename from openfecli/tests/parameters/test_mol.py rename to src/openfecli/tests/parameters/test_mol.py diff --git a/openfecli/tests/parameters/test_molecules.py b/src/openfecli/tests/parameters/test_molecules.py similarity index 100% rename from openfecli/tests/parameters/test_molecules.py rename to src/openfecli/tests/parameters/test_molecules.py diff --git a/openfecli/tests/parameters/test_output.py b/src/openfecli/tests/parameters/test_output.py similarity index 100% rename from openfecli/tests/parameters/test_output.py rename to src/openfecli/tests/parameters/test_output.py diff --git a/openfecli/tests/parameters/test_output_dir.py b/src/openfecli/tests/parameters/test_output_dir.py similarity index 100% rename from openfecli/tests/parameters/test_output_dir.py rename to src/openfecli/tests/parameters/test_output_dir.py diff --git a/openfecli/tests/parameters/test_plan_network_options.py b/src/openfecli/tests/parameters/test_plan_network_options.py similarity index 100% rename from openfecli/tests/parameters/test_plan_network_options.py rename to src/openfecli/tests/parameters/test_plan_network_options.py diff --git a/openfecli/tests/parameters/test_protein.py b/src/openfecli/tests/parameters/test_protein.py similarity index 100% rename from openfecli/tests/parameters/test_protein.py rename to src/openfecli/tests/parameters/test_protein.py diff --git a/openfecli/tests/parameters/test_utils.py b/src/openfecli/tests/parameters/test_utils.py similarity index 100% rename from openfecli/tests/parameters/test_utils.py rename to src/openfecli/tests/parameters/test_utils.py diff --git a/openfecli/tests/test_cli.py b/src/openfecli/tests/test_cli.py similarity index 100% rename from openfecli/tests/test_cli.py rename to src/openfecli/tests/test_cli.py diff --git a/openfecli/tests/test_fetchables.py b/src/openfecli/tests/test_fetchables.py similarity index 100% rename from openfecli/tests/test_fetchables.py rename to src/openfecli/tests/test_fetchables.py diff --git a/openfecli/tests/test_fetching.py b/src/openfecli/tests/test_fetching.py similarity index 100% rename from openfecli/tests/test_fetching.py rename to src/openfecli/tests/test_fetching.py diff --git a/openfecli/tests/test_plugins.py b/src/openfecli/tests/test_plugins.py similarity index 100% rename from openfecli/tests/test_plugins.py rename to src/openfecli/tests/test_plugins.py diff --git a/openfecli/tests/test_rbfe_tutorial.py b/src/openfecli/tests/test_rbfe_tutorial.py similarity index 61% rename from openfecli/tests/test_rbfe_tutorial.py rename to src/openfecli/tests/test_rbfe_tutorial.py index bb06ed65..1ace81cd 100644 --- a/openfecli/tests/test_rbfe_tutorial.py +++ b/src/openfecli/tests/test_rbfe_tutorial.py @@ -8,8 +8,10 @@ Tests the easy start guide import os from importlib import resources from os import path +from pathlib import Path from unittest import mock +import numpy as np import pytest from click.testing import CliRunner from openff.units import unit @@ -89,28 +91,10 @@ def test_plan_tyk2(tyk2_ligands, tyk2_protein, expected_transformations): assert "n_protocol_repeats=3" in result.output -@pytest.fixture -def mock_execute(expected_transformations): - def fake_execute(*args, **kwargs): - return { - "repeat_id": kwargs["repeat_id"], - "generation": kwargs["generation"], - "nc": "file.nc", - "last_checkpoint": "checkpoint.nc", - "unit_estimate": 4.2 * unit.kilocalories_per_mole, - } - - with mock.patch( - "openfe.protocols.openmm_rfe.equil_rfe_methods.RelativeHybridTopologyProtocolUnit._execute" - ) as m: - m.side_effect = fake_execute - - yield m - - @pytest.fixture def ref_gather(): return """\ +Loading results: ligand_i\tligand_j\tDDG(i->j) (kcal/mol)\tuncertainty (kcal/mol) lig_ejm_31\tlig_ejm_46\t0.0\t0.0 lig_ejm_31\tlig_ejm_47\t0.0\t0.0 @@ -124,10 +108,67 @@ lig_jmc_27\tlig_jmc_28\t0.0\t0.0 """ -def test_run_tyk2(tyk2_ligands, tyk2_protein, expected_transformations, mock_execute, ref_gather): +@pytest.fixture +def fake_setup_execute_results(): + """Use for mocking the expensive _execute step and instead directly return plausible results.""" + + def _fake_execute_results(*args, **kwargs): + return { + "repeat_id": kwargs["repeat_id"], + "generation": kwargs["generation"], + "system": Path("system.xml.bz2"), + "positions": Path("positions.npy"), + "pdb_structure": Path("hybrid_system.pdb"), + "selection_indices": np.arange(50), + } + + return _fake_execute_results + + +@pytest.fixture +def fake_sim_execute_results(): + """Use for mocking the expensive _execute step and instead directly return plausible results.""" + + def _fake_execute_results(*args, **kwargs): + return { + "repeat_id": kwargs["repeat_id"], + "generation": kwargs["generation"], + "nc": Path("file.nc"), + "checkpoint": Path("chk.chk"), + } + + return _fake_execute_results + + +@pytest.fixture +def fake_analysis_execute_results(): + """Use for mocking the expensive _execute step and instead directly return plausible results.""" + + def _fake_execute_results(*args, **kwargs): + return { + "repeat_id": kwargs["repeat_id"], + "generation": kwargs["generation"], + "pdb_structure": Path("hybrid_system.pdb"), + "checkpoint": Path("chk.chk"), + "selection_indices": np.arange(50), + "unit_estimate": 4.2 * unit.kilocalories_per_mole, + } + + return _fake_execute_results + + +def test_run_tyk2( + tyk2_ligands, + tyk2_protein, + expected_transformations, + fake_setup_execute_results, + fake_sim_execute_results, + fake_analysis_execute_results, + ref_gather, +): runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke( + result_setup = runner.invoke( plan_rbfe_network, [ "-M", tyk2_ligands, @@ -135,14 +176,27 @@ def test_run_tyk2(tyk2_ligands, tyk2_protein, expected_transformations, mock_exe ], ) # fmt: skip - assert_click_success(result) + assert_click_success(result_setup) + with ( + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologySetupUnit._execute", + side_effect=fake_setup_execute_results, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateSimulationUnit._execute", + side_effect=fake_sim_execute_results, + ), + mock.patch( + "openfe.protocols.openmm_rfe.hybridtop_units.HybridTopologyMultiStateAnalysisUnit._execute", + side_effect=fake_analysis_execute_results, + ), + ): + for f in expected_transformations: + fn = path.join("alchemicalNetwork/transformations", f) + result_run = runner.invoke(quickrun, [fn]) + assert_click_success(result_run) - for f in expected_transformations: - fn = path.join("alchemicalNetwork/transformations", f) - result2 = runner.invoke(quickrun, [fn]) - assert_click_success(result2) + result_gather = runner.invoke(gather, ["--report", "ddg", ".", "--tsv"]) - gather_result = runner.invoke(gather, ["--report", "ddg", ".", "--tsv"]) - - assert_click_success(gather_result) - assert gather_result.stdout == ref_gather + assert_click_success(result_gather) + assert result_gather.stdout == ref_gather diff --git a/openfecli/tests/test_utils.py b/src/openfecli/tests/test_utils.py similarity index 100% rename from openfecli/tests/test_utils.py rename to src/openfecli/tests/test_utils.py diff --git a/openfecli/tests/utils.py b/src/openfecli/tests/utils.py similarity index 100% rename from openfecli/tests/utils.py rename to src/openfecli/tests/utils.py diff --git a/openfecli/utils.py b/src/openfecli/utils.py similarity index 100% rename from openfecli/utils.py rename to src/openfecli/utils.py