Merge remote-tracking branch 'origin/main' into test/pymbar-with-numba

This commit is contained in:
Mike Henry
2026-02-12 15:57:43 -07:00
355 changed files with 8827 additions and 5938 deletions

View File

@@ -14,7 +14,7 @@ see https://regro.github.io/rever-docs/news.html for details on how to add news
Checklist
* [ ] All new code is appropriately documented (user-facing code _must_ have complete docstrings).
* [ ] Added a ``news`` entry, or the changes are not user-facing.
* [ ] Ran pre-commit by making a comment with `pre-commit.ci autofix` before requesting review.
* [ ] Ran pre-commit: you can run [pre-commit](https://pre-commit.com) locally or comment on this PR with `pre-commit.ci autofix`.
Manual Tests: these are slow so don't need to be run every commit, only before merging and when relevant changes are made (generally at reviewer-discretion).
* [ ] [GPU integration tests](https://github.com/OpenFreeEnergy/openfe/actions/workflows/gpu-integration-tests.yaml)

View File

@@ -35,7 +35,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ["ubuntu-latest", "macos-latest"]
os: ["ubuntu-latest"]
openeye: ["no"]
python-version:
- "3.11"
@@ -45,6 +45,9 @@ jobs:
- os: "ubuntu-latest"
python-version: "3.13"
openeye: "yes"
- os: "macos-latest"
python-version: "3.12"
openeye: "no"
env:
OE_LICENSE: ${{ github.workspace }}/oe_license.txt

View File

@@ -20,11 +20,15 @@ jobs:
strategy:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest']
os: ['ubuntu-latest']
python-version:
- "3.11"
- "3.12"
- "3.13"
include:
- os: "macos-latest"
python-version: "3.12"
openeye: "no"
steps:
- name: Checkout Code
uses: actions/checkout@v4

View File

@@ -92,7 +92,7 @@ jobs:
DUECREDIT_ENABLE: 'yes'
OFE_INTEGRATION_TESTS: FALSE
run: |
pytest -n logical -vv --durations=10 --runslow openfecli/tests/ openfe/tests/
pytest -n logical -vv --durations=10 --runslow src/openfecli/tests/ src/openfe/tests/
stop-aws-runner:
runs-on: ubuntu-latest

View File

@@ -96,7 +96,7 @@ jobs:
OFE_INTEGRATION_TESTS: TRUE
run: |
# The -m flag will only run tests with @pytest.mark.integration
pytest -n logical -vv --durations=10 -m integration openfecli/tests/ openfe/tests/
pytest -n logical -vv --durations=10 -m integration src/openfecli/tests/ src/openfe/tests/
stop-aws-runner:
runs-on: ubuntu-latest

View File

@@ -12,7 +12,7 @@ defaults:
shell: bash -leo pipefail {0}
jobs:
test-conda-build:
test-example-notebooks:
runs-on: ubuntu-latest
steps:

View File

@@ -10,6 +10,7 @@ repos:
rev: v6.0.0
hooks:
- id: check-added-large-files
args: ["--maxkb=900"]
- id: check-case-conflict
- id: check-executables-have-shebangs
- id: check-symlinks
@@ -19,12 +20,12 @@ repos:
- id: debug-statements
- repo: https://github.com/tox-dev/pyproject-fmt
rev: "v2.8.0"
rev: "v2.11.1"
hooks:
- id: pyproject-fmt
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.13.3
rev: v0.14.10
hooks:
# Run the linter.
- id: ruff

View File

@@ -1,18 +1,18 @@
recursive-include openfe/tests/data/ *.sdf
recursive-include openfe/tests/data/ *.bz2
recursive-include openfe/tests/data/ *.csv
recursive-include openfe/tests/data/ *.pdb
recursive-include openfe/tests/data/ *.mol2
recursive-include openfe/tests/data/ *.xml
recursive-include openfe/tests/data/ *.graphml
recursive-include openfe/tests/data/ *.edge
recursive-include openfe/tests/data/ *.dat
recursive-include openfe/tests/data/ *.txt
recursive-include openfe/tests/data/ *.gz
recursive-include openfe/tests/data/ *json_results.gz
include openfecli/tests/data/*.json
include openfecli/tests/data/*.tar.gz
include openfecli/tests/commands/test_gather/*.tsv
recursive-include openfecli/tests/ *.sdf
recursive-include openfecli/tests/ *.pdb
include openfe/tests/data/openmm_rfe/vacuum_nocoord.nc
recursive-include src/openfe/tests/data/ *.sdf
recursive-include src/openfe/tests/data/ *.bz2
recursive-include src/openfe/tests/data/ *.csv
recursive-include src/openfe/tests/data/ *.pdb
recursive-include src/openfe/tests/data/ *.mol2
recursive-include src/openfe/tests/data/ *.xml
recursive-include src/openfe/tests/data/ *.graphml
recursive-include src/openfe/tests/data/ *.edge
recursive-include src/openfe/tests/data/ *.dat
recursive-include src/openfe/tests/data/ *.txt
recursive-include src/openfe/tests/data/ *.gz
recursive-include src/openfe/tests/data/ *json_results.gz
include src/openfecli/tests/data/*.json
include src/openfecli/tests/data/*.tar.gz
include src/openfecli/tests/commands/test_gather/*.tsv
recursive-include src/openfecli/tests/ *.sdf
recursive-include src/openfecli/tests/ *.pdb
include src/openfe/tests/data/openmm_rfe/vacuum_nocoord.nc

View File

@@ -4,6 +4,55 @@ Changelog
.. current developments
v1.9.0
====================
**Added:**
* The ``validate`` method for the RelativeHybridTopologyProtocol has been implemented.
This means that settings and system validation can mostly be done prior to Protocol execution by calling ``RelativeHybridTopologyProtocol.validate(stateA, stateB, mapping)`` (`PR 1740 <https://github.com/OpenFreeEnergy/openfe/pull/1740>`_).
* Added ``openfe test --download-only`` flag, which downloads all test data stored remotely to the local cache (`PR 1814 <https://github.com/OpenFreeEnergy/openfe/pull/1814>`_).
**Changed:**
* The absolute free energy protocols (AbsoluteBindingProtocol and AbsoluteSolvationProtocol) have been broken into multiple
protocol units, allowing for setup, run, and analysis to happen
separately in the future when relevant changes to protocol execution are
made (`PR 1776 <https://github.com/OpenFreeEnergy/openfe/pull/1776>`_).
* The relative free energy protocol (RelativeHybridTopologyProtocol) has been
broken into multiple protocol units, allowing for the setup, run, and analysis to happen
separately (`PR 1773 <https://github.com/OpenFreeEnergy/openfe/pull/1773>`_).
**Fixed:**
* Fixed bug in ligand network visualization (such as with ``openfe view-ligand-network``) so that ligand names are no longer cut off by the plot border (`PR 1822 <https://github.com/OpenFreeEnergy/openfe/pull/1822>`_).
* Endstates in the RelativeHybridTopologyProtocol are now being created
in a manner that allows for isomorphic molecules that differ between
endstates to have different parameters (`PR 1772 <https://github.com/OpenFreeEnergy/openfe/pull/1772>`_).
v1.8.1
====================
**Added:**
* Added a progress bar for ``openfe gather`` JSON loading (`PR #1786 <https://github.com/OpenFreeEnergy/openfe/pull/1786>`_).
**Fixed:**
* Due to issues with OpenFF's handling of toolkit registries
with NAGL, the use of NAGL models (e.g. AshGC) when OpenEye
is installed but not requested as the charge backend has been
disabled (Issue #1760, `PR #1762 <https://github.com/OpenFreeEnergy/openfe/pull/1762>`_).
* Fixed bug in ligand network visualization (such as with ``openfe view-ligand-network``) so that ligand names are no longer cut off by the plot border (`PR #1822 <https://github.com/OpenFreeEnergy/openfe/pull/1822>`_).
v1.8.0
====================
@@ -14,6 +63,7 @@ v1.8.0
* Added experimental features ``openfe gather-septop`` and ``openfe gather-abfe``, which are analogous to ``openfe gather`` and allow for gathering results generated by the Separated Topologies and Absolute Binding Free Energy protocols, respectively. These commands are experimental and are liable to be changed in a future release.
* Emit a clarifying log message when a user gets a warning from JAX (`PR #1585 <https://github.com/OpenFreeEnergy/openfe/pull/1585>`_, fixes `Issue #1499 <https://github.com/OpenFreeEnergy/openfe/issues/1499>`_).
* Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (`PR #1694 <https://github.com/OpenFreeEnergy/openfe/pull/1692>`_).
* New options have been added to the ``AlchemicalSettings`` of the ``SepTopProtocol``, ``AbsoluteSolvationProtocol`` and ``AbsoluteBindingProtocol``. Notably, these options allow users to control the softcore parameters as well as the use of long range dispersion corrections (`PR #1742 <https://github.com/OpenFreeEnergy/openfe/pull/1742>`_).
**Changed:**

View File

@@ -103,6 +103,7 @@ exclude_patterns = [
autodoc_mock_imports = [
"cinnabar",
"dill",
"MDAnalysis",
"matplotlib",
"mdtraj",
@@ -192,7 +193,7 @@ try:
else:
repo = git.Repo.clone_from(
"https://github.com/OpenFreeEnergy/ExampleNotebooks.git",
branch="2025.12.04",
branch="2026.01.26",
to_path=example_notebooks_path,
)
except Exception as e:

View File

@@ -10,7 +10,7 @@ This settings file has a series of sections for customising the different algori
For example, the settings file which re-specifies the default behaviour would look like ::
network:
method: plan_minimal_spanning_tree
method: generate_minimal_spanning_network
mapper:
method: LomapAtomMapper
settings:
@@ -29,7 +29,7 @@ All sections of the file ``network:``, ``mapper:`` and ``partial_charge:`` are
The settings YAML file is then provided to the ``-s`` option of ``openfe plan-rbfe-network``: ::
openfe plan-rbfe-network -M molecules.sdf -P protein.pdb -s settings.yaml
openfe plan-rbfe-network -M molecules.sdf -p protein.pdb -s settings.yaml
Customising the atom mapper
---------------------------

View File

@@ -25,12 +25,12 @@ Using this toolkit you can plan, execute, and analyze free energy calculations u
Follow our installation guide to get **openfe** running on your machine!
.. grid-item-card:: :fas:`laptop-code` CLI
.. grid-item-card:: :fas:`laptop-code` CLI Quickstart
:text-align: center
:link: reference/cli/index
:link: tutorials/rbfe_cli_tutorial
:link-type: doc
Documentation for **openfe**\'s simple command line interface.
Get started with **openfe**\'s command line interface.
.. grid-item-card:: :fas:`person-chalkboard` Tutorials
:text-align: center
@@ -53,12 +53,12 @@ Using this toolkit you can plan, execute, and analyze free energy calculations u
How-to guides for common tasks.
.. grid-item-card:: :fas:`code` Python API
.. grid-item-card:: :fas:`code` API Reference
:text-align: center
:link: reference/api/index
:link: reference/index
:link-type: doc
Comprehensive details of the **openfe** Python API.
Comprehensive details of the **openfe** Python and CLI APIs.
.. grid-item-card:: :fas:`gears` Protocols
:text-align: center

View File

@@ -4,7 +4,7 @@
We have reproduced API documentation from the `gufe`_ package here for convenience.
`gufe`_ serves as a foundation layer for openfe, providing abstract base classes and object models, and so might be more useful for developers.
OpenFE API Reference
Python API Reference
====================
.. toctree::

View File

@@ -16,8 +16,12 @@ Protocol API specification
:toctree: generated/
AbsoluteBindingProtocol
AbsoluteBindingComplexUnit
AbsoluteBindingSolventUnit
ABFEComplexAnalysisUnit
ABFEComplexSetupUnit
ABFEComplexSimUnit
ABFESolventAnalysisUnit
ABFESolventSetupUnit
ABFESolventSimUnit
AbsoluteBindingProtocolResult
Protocol Settings

View File

@@ -16,7 +16,9 @@ Protocol API specification
:toctree: generated/
RelativeHybridTopologyProtocol
RelativeHybridTopologyProtocolUnit
HybridTopologySetupUnit
HybridTopologyMultiStateSimulationUnit
HybridTopologyMultiStateAnalysisUnit
RelativeHybridTopologyProtocolResult
Protocol Settings

View File

@@ -16,8 +16,12 @@ Protocol API specification
:toctree: generated/
AbsoluteSolvationProtocol
AbsoluteSolvationVacuumUnit
AbsoluteSolvationSolventUnit
AHFESolventAnalysisUnit
AHFESolventSetupUnit
AHFESolventSimUnit
AHFEVacuumAnalysisUnit
AHFEVacuumSetupUnit
AHFEVacuumSimUnit
AbsoluteSolvationProtocolResult
Protocol Settings

View File

@@ -12,7 +12,7 @@ dependencies:
- lomap2>=3.2.1
- networkx
- numba
- numpy<2.3
- numpy
- openfe-analysis>=0.3.1
- openff-interchange-base
- openff-nagl-base >=0.3.3
@@ -21,14 +21,15 @@ dependencies:
- openff-units==0.3.1 # https://github.com/OpenFreeEnergy/openfe/pull/1374
- openmm ~=8.2.0 # omit 8.3.0 and 8.3.1 due to https://github.com/openmm/openmm/pull/5069, unpin once we've qualified 8.3.2
- openmmforcefields >=0.15.0 # min needed for https://github.com/OpenFreeEnergy/openfe/pull/1695
- openmmtools >=0.25.0
- openmmtools >=0.25.3 # fix to support numpy >=2.3: https://github.com/choderalab/openmmtools/pull/793
- packaging
- pandas
- parmed >=4.3.1 # fix to support numpy >=2.3: https://github.com/ParmEd/ParmEd/pull/1387
- perses>=0.10.3
- plugcli
- pint>=0.24.0
- pip
- pooch
- pooch >= 1.9.0 # min needed for https://github.com/fatiando/pooch/issues/502
- py3dmol
- pydantic >= 2.0.0, <2.12.0 # https://github.com/openforcefield/openff-interchange/issues/1346
- pygraphviz
@@ -55,3 +56,6 @@ dependencies:
- pip:
- git+https://github.com/OpenFreeEnergy/gufe@main
- git+https://github.com/choderalab/pymbar.git@ed40ec3bbef03bb08938ad1a74d459b0d1ab81f7
- run_constrained:
# drop this pin when handled upstream in espaloma-feedstock
- smirnoff99frosst>=1.1.0.1 #https://github.com/openforcefield/smirnoff99Frosst/issues/109

View File

@@ -1,27 +0,0 @@
**Added:**
* New options have been added to the ``AlchemicalSettings``
of the ``SepTopProtocol``, ``AbsoluteSolvationProtocol``
and ``AbsoluteBindingProtocol``. Notably, these options allow users to
control the softcore parameters as well as the use of
long range dispersion corrections.
**Changed:**
* <news item>
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View File

@@ -1,34 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
Run absolute free energy calculations using OpenMM and OpenMMTools.
"""
from .equil_binding_afe_method import (
AbsoluteBindingComplexUnit,
AbsoluteBindingProtocol,
AbsoluteBindingProtocolResult,
AbsoluteBindingSettings,
AbsoluteBindingSolventUnit,
)
from .equil_solvation_afe_method import (
AbsoluteSolvationProtocol,
AbsoluteSolvationProtocolResult,
AbsoluteSolvationSettings,
AbsoluteSolvationSolventUnit,
AbsoluteSolvationVacuumUnit,
)
# Public names exported by this package. Every entry must match a name that
# is actually imported above: a star-import looks each entry up on the module
# and fails with AttributeError for names that do not exist.
__all__ = [
    "AbsoluteSolvationProtocol",
    "AbsoluteSolvationSettings",
    "AbsoluteSolvationProtocolResult",
    "AbsoluteSolvationVacuumUnit",
    "AbsoluteSolvationSolventUnit",
    "AbsoluteBindingProtocol",
    "AbsoluteBindingSettings",
    "AbsoluteBindingProtocolResult",
    "AbsoluteBindingComplexUnit",
    "AbsoluteBindingSolventUnit",
]

File diff suppressed because it is too large Load Diff

View File

@@ -1,962 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method`
===============================================================================================================
This module implements the necessary methodology tooling to calculate an
absolute solvation free energy using OpenMM tools and one of the following
alchemical sampling methods:
* Hamiltonian Replica Exchange
* Self-adjusted mixture sampling
* Independent window sampling
Current limitations
-------------------
* Alchemical species with a net charge are not currently supported.
* Disappearing molecules are only allowed in state A. Support for
appearing molecules will be added in due course.
* Only small molecules are allowed to act as alchemical molecules.
Alchemically changing protein or solvent components would induce
perturbations which are too large to be handled by this Protocol.
Acknowledgements
----------------
* Originally based on hydration.py in
`espaloma_charge <https://github.com/choderalab/espaloma_charge>`_
"""
from __future__ import annotations
import itertools
import logging
import pathlib
import uuid
import warnings
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
import gufe
import numpy as np
import numpy.typing as npt
from gufe import (
ChemicalSystem,
ProteinComponent,
SmallMoleculeComponent,
SolventComponent,
settings,
)
from gufe.components import Component
from openff.units import Quantity, unit
from openmmtools import multistate
from openfe.due import Doi, due
from openfe.protocols.openmm_afe.equil_afe_settings import (
AbsoluteSolvationSettings,
AlchemicalSettings,
IntegratorSettings,
LambdaSettings,
MDOutputSettings,
MDSimulationSettings,
MultiStateOutputSettings,
MultiStateSimulationSettings,
OpenFFPartialChargeSettings,
OpenMMEngineSettings,
OpenMMSolvationSettings,
SettingsBaseModel,
)
from ..openmm_utils import settings_validation, system_validation
from .base import BaseAbsoluteUnit
due.cite(
Doi("10.5281/zenodo.596504"),
description="Yank",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
due.cite(
Doi("10.48550/ARXIV.2302.06758"),
description="EspalomaCharge",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
due.cite(
Doi("10.5281/zenodo.596622"),
description="OpenMMTools",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
due.cite(
Doi("10.1371/journal.pcbi.1005659"),
description="OpenMM",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
logger = logging.getLogger(__name__)
class AbsoluteSolvationProtocolResult(gufe.ProtocolResult):
    """Dict-like container for the output of a AbsoluteSolvationProtocol.

    All accessors read ``self.data["solvent"]`` and ``self.data["vacuum"]``
    and use the outputs of the first entry of each per-repeat list.
    """

    def __init__(self, **data):
        """Store the results and reject data that would need stitching.

        Raises
        ------
        NotImplementedError
          If any per-repeat list in the solvent or vacuum legs holds more
          than two entries.
        """
        super().__init__(**data)
        # TODO: Detect when we have extensions and stitch these together?
        if any(
            len(pur_list) > 2
            for pur_list in itertools.chain(
                self.data["solvent"].values(), self.data["vacuum"].values()
            )
        ):
            raise NotImplementedError("Can't stitch together results yet")

    def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]:
        """
        Get the individual estimate of the free energies.

        Returns
        -------
        dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]]
          A dictionary, keyed `solvent` and `vacuum` for each leg
          of the thermodynamic cycle, with lists of tuples containing
          the individual free energy estimates and associated MBAR
          uncertainties for each repeat of that simulation type.
        """
        vac_dGs = [
            (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
            for pus in self.data["vacuum"].values()
        ]
        solv_dGs = [
            (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
            for pus in self.data["solvent"].values()
        ]

        return {"solvent": solv_dGs, "vacuum": vac_dGs}

    def get_estimate(self):
        """Get the solvation free energy estimate for this calculation.

        The estimate is the average of the vacuum leg estimates minus the
        average of the solvent leg estimates.

        Returns
        -------
        dG : openff.units.Quantity
          The solvation free energy. This is a Quantity defined with units.
        """

        def _get_average(estimates):
            # Get the unit value of the first value in the estimates
            u = estimates[0][0].u
            # Loop through estimates and get the free energy values
            # in the unit of the first estimate
            dGs = [i[0].to(u).m for i in estimates]

            return np.average(dGs) * u

        individual_estimates = self.get_individual_estimates()
        vac_dG = _get_average(individual_estimates["vacuum"])
        solv_dG = _get_average(individual_estimates["solvent"])

        return vac_dG - solv_dG

    def get_uncertainty(self):
        """Get the solvation free energy error for this calculation.

        Returns
        -------
        err : openff.units.Quantity
          The standard deviation between estimates of the solvation free
          energy. This is a Quantity defined with units.
        """

        def _get_stdev(estimates):
            # Get the unit value of the first value in the estimates
            u = estimates[0][0].u
            # Loop through estimates and get the free energy values
            # in the unit of the first estimate
            dGs = [i[0].to(u).m for i in estimates]

            return np.std(dGs) * u

        individual_estimates = self.get_individual_estimates()
        vac_err = _get_stdev(individual_estimates["vacuum"])
        solv_err = _get_stdev(individual_estimates["solvent"])

        # return the combined error (per-leg deviations added in quadrature)
        return np.sqrt(vac_err**2 + solv_err**2)

    def get_forward_and_reverse_energy_analysis(
        self,
    ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]:
        """
        Get the reverse and forward analysis of the free energies.

        Returns
        -------
        forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]]
          A dictionary, keyed `solvent` and `vacuum` for each leg of the
          thermodynamic cycle which each contain a list of dictionaries
          containing the forward and reverse analysis of each repeat
          of that simulation type.

          The forward and reverse analysis dictionaries contain:
            - `fractions`: npt.NDArray
                The fractions of data used for the estimates
            - `forward_DGs`, `reverse_DGs`: openff.units.Quantity
                The forward and reverse estimates for each fraction of data
            - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity
                The forward and reverse estimate uncertainty for each
                fraction of data.

          If one of the cycle leg list entries is ``None``, this indicates
          that the analysis could not be carried out for that repeat. This
          is most likely caused by MBAR convergence issues when attempting to
          calculate free energies from too few samples.

        Warns
        -----
        UserWarning
          * If any of the forward and reverse dictionaries are ``None`` in a
            given thermodynamic cycle leg.
        """

        forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {}

        for key in ["solvent", "vacuum"]:
            forward_reverse[key] = [
                pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values()
            ]

            if None in forward_reverse[key]:
                wmsg = (
                    "One or more ``None`` entries were found in the forward "
                    f"and reverse dictionaries of the repeats of the {key} "
                    "calculations. This is likely caused by an MBAR convergence "
                    "failure caused by too few independent samples when "
                    "calculating the free energies of the 10% timeseries slice."
                )
                warnings.warn(wmsg)

        return forward_reverse

    def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]:
        """
        Get the MBAR overlap estimates for all legs of the simulation.

        Returns
        -------
        overlap_stats : dict[str, list[dict[str, npt.NDArray]]]
          A dictionary with keys `solvent` and `vacuum` for each
          leg of the thermodynamic cycle, each containing a
          list of dictionaries with the MBAR overlap estimates of
          each repeat of that simulation type.

          The underlying MBAR dictionaries contain the following keys:
            * ``scalar``: One minus the largest nontrivial eigenvalue
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              overlap matrix
            * ``matrix``: Estimated overlap matrix of observing a sample from
              state i in state j
        """
        # Loop through and get the repeats and get the matrices
        overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {}

        for key in ["solvent", "vacuum"]:
            overlap_stats[key] = [
                pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values()
            ]

        return overlap_stats

    def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]:
        """
        Get the replica exchange transition statistics for all
        legs of the simulation.

        Note
        ----
        This is currently only available in cases where a replica exchange
        simulation was run.

        Returns
        -------
        repex_stats : dict[str, list[dict[str, npt.NDArray]]]
          A dictionary with keys `solvent` and `vacuum` for each
          leg of the thermodynamic cycle, each containing
          a list of dictionaries containing the replica transition
          statistics for each repeat of that simulation type.

          The replica transition statistics dictionaries contain the following:
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              lambda state transition matrix
            * ``matrix``: The transition matrix estimate of a replica switching
              from state i to state j.

        Raises
        ------
        ValueError
          If a ``KeyError`` is hit while collecting the statistics, e.g.
          because the ``replica_exchange_statistics`` output is missing.
        """
        repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {}
        try:
            for key in ["solvent", "vacuum"]:
                repex_stats[key] = [
                    pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values()
                ]
        except KeyError as err:
            errmsg = "Replica exchange statistics were not found, did you run a repex calculation?"
            # Chain the original KeyError so the missing key stays visible
            raise ValueError(errmsg) from err

        return repex_stats

    def get_replica_states(self) -> dict[str, list[npt.NDArray]]:
        """
        Get the timeseries of replica states for all simulation legs.

        Returns
        -------
        replica_states : dict[str, list[npt.NDArray]]
          Dictionary keyed `solvent` and `vacuum` for each leg of
          the thermodynamic cycle, with lists of replica states
          timeseries for each repeat of that simulation type.

        Raises
        ------
        ValueError
          If the trajectory file or the checkpoint file of a repeat
          cannot be found on disk.
        """
        replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "vacuum": []}

        def is_file(filename: str):
            p = pathlib.Path(filename)

            if not p.exists():
                errmsg = f"File could not be found {p}"
                raise ValueError(errmsg)

            return p

        def get_replica_state(nc, chk):
            nc = is_file(nc)
            # The checkpoint is looked up next to the trajectory file, and
            # only its file name is handed to the reporter.
            dir_path = nc.parents[0]
            chk = is_file(dir_path / chk).name

            reporter = multistate.MultiStateReporter(
                storage=nc, checkpoint_storage=chk, open_mode="r"
            )

            # Always close the reporter, even if reading the states fails
            try:
                return np.asarray(reporter.read_replica_thermodynamic_states())
            finally:
                reporter.close()

        for key in ["solvent", "vacuum"]:
            for pus in self.data[key].values():
                states = get_replica_state(
                    pus[0].outputs["nc"],
                    pus[0].outputs["last_checkpoint"],
                )
                replica_states[key].append(states)

        return replica_states

    def equilibration_iterations(self) -> dict[str, list[float]]:
        """
        Get the number of equilibration iterations for each simulation.

        Returns
        -------
        equilibration_lengths : dict[str, list[float]]
          Dictionary keyed `solvent` and `vacuum` for each leg
          of the thermodynamic cycle, with lists containing the
          number of equilibration iterations for each repeat
          of that simulation type.
        """
        equilibration_lengths: dict[str, list[float]] = {}

        for key in ["solvent", "vacuum"]:
            equilibration_lengths[key] = [
                pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values()
            ]

        return equilibration_lengths

    def production_iterations(self) -> dict[str, list[float]]:
        """
        Get the number of production iterations for each simulation.
        Returns the number of uncorrelated production samples for each
        repeat of the calculation.

        Returns
        -------
        production_lengths : dict[str, list[float]]
          Dictionary keyed `solvent` and `vacuum` for each leg of the
          thermodynamic cycle, with lists with the number
          of production iterations for each repeat of that simulation
          type.
        """
        production_lengths: dict[str, list[float]] = {}

        for key in ["solvent", "vacuum"]:
            production_lengths[key] = [
                pus[0].outputs["production_iterations"] for pus in self.data[key].values()
            ]

        return production_lengths
class AbsoluteSolvationProtocol(gufe.Protocol):
"""
Absolute solvation free energy calculations using OpenMM and OpenMMTools.
See Also
--------
:mod:`openfe.protocols`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationSettings`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationProtocolResult`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit`
"""
result_cls = AbsoluteSolvationProtocolResult
_settings_cls = AbsoluteSolvationSettings
_settings: AbsoluteSolvationSettings
@classmethod
def _default_settings(cls):
    """Build the default settings object for this Protocol.

    The values are meant as a reasonable starting point; inspect and
    adjust them before running a calculation.

    Returns
    -------
    Settings
      a set of default settings
    """
    # Default lambda schedule: 14 windows per component.
    elec_windows = [
        0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0,
        1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
    ]  # fmt: skip
    vdw_windows = [
        0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24,
        0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0,
    ]  # fmt: skip
    restraint_windows = [
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
    ]  # fmt: skip

    # Keyword arguments are assembled in the same order in which they are
    # handed to the settings class.
    defaults = {
        "protocol_repeats": 3,
        "solvent_forcefield_settings": settings.OpenMMSystemGeneratorFFSettings(),
        "vacuum_forcefield_settings": settings.OpenMMSystemGeneratorFFSettings(
            nonbonded_method="nocutoff",
        ),
        "thermo_settings": settings.ThermoSettings(
            temperature=298.15 * unit.kelvin,
            pressure=1 * unit.bar,
        ),
        "alchemical_settings": AlchemicalSettings(),
        "lambda_settings": LambdaSettings(
            lambda_elec=elec_windows,
            lambda_vdw=vdw_windows,
            lambda_restraints=restraint_windows,
        ),
        "partial_charge_settings": OpenFFPartialChargeSettings(),
        "solvation_settings": OpenMMSolvationSettings(),
        "vacuum_engine_settings": OpenMMEngineSettings(),
        "solvent_engine_settings": OpenMMEngineSettings(),
        "integrator_settings": IntegratorSettings(),
        # Solvent leg: pre-equilibration MD followed by the multistate run
        "solvent_equil_simulation_settings": MDSimulationSettings(
            equilibration_length_nvt=0.1 * unit.nanosecond,
            equilibration_length=0.2 * unit.nanosecond,
            production_length=0.5 * unit.nanosecond,
        ),
        "solvent_equil_output_settings": MDOutputSettings(
            equil_nvt_structure="equil_nvt_structure.pdb",
            equil_npt_structure="equil_npt_structure.pdb",
            production_trajectory_filename="production_equil.xtc",
            log_output="equil_simulation.log",
        ),
        "solvent_simulation_settings": MultiStateSimulationSettings(
            n_replicas=14,
            equilibration_length=1.0 * unit.nanosecond,
            production_length=10.0 * unit.nanosecond,
        ),
        "solvent_output_settings": MultiStateOutputSettings(
            output_filename="solvent.nc",
            checkpoint_storage_filename="solvent_checkpoint.nc",
        ),
        # Vacuum leg: no NVT equilibration stage is configured
        "vacuum_equil_simulation_settings": MDSimulationSettings(
            equilibration_length_nvt=None,
            equilibration_length=0.2 * unit.nanosecond,
            production_length=0.5 * unit.nanosecond,
        ),
        "vacuum_equil_output_settings": MDOutputSettings(
            equil_nvt_structure=None,
            equil_npt_structure="equil_structure.pdb",
            production_trajectory_filename="production_equil.xtc",
            log_output="equil_simulation.log",
        ),
        "vacuum_simulation_settings": MultiStateSimulationSettings(
            n_replicas=14,
            equilibration_length=0.5 * unit.nanosecond,
            production_length=2.0 * unit.nanosecond,
        ),
        "vacuum_output_settings": MultiStateOutputSettings(
            output_filename="vacuum.nc",
            checkpoint_storage_filename="vacuum_checkpoint.nc",
        ),
    }

    return AbsoluteSolvationSettings(**defaults)
@staticmethod
def _validate_endstates(
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
) -> None:
    """
    Check that the two end states describe a supported solvation
    transformation.

    In terms of gufe components, the transformation must start from a
    solvated system and end in a state with exactly one component
    fewer; protein components are rejected.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B

    Raises
    ------
    ValueError
      If stateA or stateB contains a ProteinComponent.
      If there is no SolventComponent in either stateA or stateB.
      If the number of components unique to stateA is not exactly one.
      If the component unique to stateA is not a SmallMoleculeComponent.
      If the component unique to stateA has a non-zero total charge.
      If there are components unique to stateB.

    Notes
    -----
    * Currently doesn't support alchemical components in state B.
    * Currently doesn't support alchemical components which are not
      SmallMoleculeComponents.
    * Currently doesn't support more than one alchemical component
      being desolvated.
    * Currently doesn't support charged alchemical components.
    * Solvent must always be present in both end states.
    """
    # Proteins are never allowed in either end state
    has_protein = stateA.contains(ProteinComponent) or stateB.contains(ProteinComponent)
    if has_protein:
        raise ValueError(
            "Protein components are not allowed for absolute solvation free energies."
        )

    # Both end states must carry a solvent
    if not stateA.contains(SolventComponent) or not stateB.contains(SolventComponent):
        raise ValueError("No SolventComponent found in stateA and/or stateB")

    # Components that differ between the end states: index 0 holds those
    # unique to state A, index 1 those unique to state B
    changes = stateA.component_diff(stateB)
    only_in_a = changes[0]

    # Exactly one component may be unique to state A
    n_unique_a = len(only_in_a)
    if n_unique_a != 1:
        raise ValueError(
            "Only one alchemical species is supported "
            "for absolute solvation free energies. "
            f"Number of unique components found in stateA: {n_unique_a}."
        )

    alchem_comp = only_in_a[0]

    # ... and it has to be a small molecule
    if not isinstance(alchem_comp, SmallMoleculeComponent):
        raise ValueError(
            "Only dissapearing SmallMoleculeComponents "
            "are supported by this protocol. "
            f"Found a {type(alchem_comp)}"
        )

    # ... with no net charge
    net_charge = alchem_comp.total_charge
    if net_charge != 0:
        raise ValueError(
            "Charged alchemical molecules are not currently "
            "supported for solvation free energies. "
            f"Molecule total charge: {net_charge}."
        )

    # Nothing may appear in state B
    if len(changes[1]) > 0:
        raise ValueError("Components appearing in state B are not currently supported")
@staticmethod
def _validate_lambda_schedule(
    lambda_settings: LambdaSettings,
    simulation_settings: MultiStateSimulationSettings,
) -> None:
    """
    Check that the lambda schedule is set up correctly.

    Parameters
    ----------
    lambda_settings : LambdaSettings
      the lambda schedule Settings
    simulation_settings : MultiStateSimulationSettings
      the settings for either the vacuum or solvent phase

    Raises
    ------
    ValueError
      If the number of lambda windows differs between the electrostatics,
      sterics and restraints components.
      If the number of replicas does not match the number of lambda windows.
      If there are states with naked charges.

    Warns
    -----
    UserWarning
      If there are non-zero values for restraints (lambda_restraints).
      The same message is also sent to the module logger.
    """
    lambda_elec = lambda_settings.lambda_elec
    lambda_vdw = lambda_settings.lambda_vdw
    lambda_restraints = lambda_settings.lambda_restraints
    n_replicas = simulation_settings.n_replicas

    # Ensure that all lambda components have equal amount of windows;
    # the vdw schedule is used as the reference length.
    n_windows = len(lambda_vdw)
    if len(lambda_elec) != n_windows or len(lambda_restraints) != n_windows:
        errmsg = (
            "Components elec, vdw, and restraints must have equal amount"
            f" of lambda windows. Got {len(lambda_elec)} elec lambda"
            f" windows, {len(lambda_vdw)} vdw lambda windows, and"
            f" {len(lambda_restraints)} restraints lambda windows."
        )
        raise ValueError(errmsg)

    # Ensure that number of overall lambda windows matches number of lambda
    # windows for individual components
    if n_replicas != n_windows:
        errmsg = (
            f"Number of replicas {n_replicas} does not equal the"
            f" number of lambda windows {len(lambda_vdw)}"
        )
        raise ValueError(errmsg)

    # Check if there are lambda windows with naked charges, i.e. the
    # electrostatics are still (partially) on while the sterics are fully off
    for inx, lam in enumerate(lambda_elec):
        if lam < 1 and lambda_vdw[inx] == 1:
            errmsg = (
                "There are states along this lambda schedule "
                "where there are atoms with charges but no LJ "
                f"interactions: lambda {inx}: "
                f"elec {lam} vdW {lambda_vdw[inx]}"
            )
            raise ValueError(errmsg)

    # Check if there are lambda windows with non-zero restraints
    if any(r != 0 for r in lambda_restraints):
        wmsg = (
            "Non-zero restraint lambdas applied. The absolute "
            "solvation protocol doesn't apply restraints, "
            "therefore restraints won't be applied. "
            f"Given lambda_restraints: {lambda_restraints}"
        )
        logger.warning(wmsg)
        warnings.warn(wmsg)
def _validate(
    self,
    *,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
    extends: Optional[gufe.ProtocolDAGResult] = None,
):
    """
    Validate the inputs and settings before any units are created.

    The checks run in this order: no extension requested, mapping
    (warning only), end states, lambda schedules of both phases,
    nonbonded methods, solvation settings, vacuum NVT equilibration
    length and finally the timestep / hydrogen mass combination.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system at the first end state.
    stateB : ChemicalSystem
      The chemical system at the second end state.
    mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
      Unused by this Protocol; a warning is issued when it is not None.
    extends : Optional[ProtocolDAGResult]
      Must be None, extending simulations is not supported.

    Raises
    ------
    ValueError
      If ``extends`` is given, if the vacuum nonbonded method is not
      ``nocutoff``, or if a non-zero vacuum NVT equilibration is requested.
      The delegated validation helpers may raise their own errors.
    """
    # This should be a NotImplementedError, but the underlying
    # `validate` method wraps a call to `_validate` around a
    # NotImplementedError exception guard
    if extends is not None:
        raise ValueError("Can't extend simulations yet")

    # A mapping is tolerated but ignored, so only warn about it
    if mapping is not None:
        warnings.warn("A mapping was passed but is not used by this Protocol.")

    # End states & alchemical components
    self._validate_endstates(stateA, stateB)

    # Lambda schedule, once per phase (solvent first, then vacuum)
    phase_simulation_settings = (
        self.settings.solvent_simulation_settings,
        self.settings.vacuum_simulation_settings,
    )
    for sim_settings in phase_simulation_settings:
        self._validate_lambda_schedule(self.settings.lambda_settings, sim_settings)

    # Nonbonded method & solvent compatibility
    solv_nonbonded_method = self.settings.solvent_forcefield_settings.nonbonded_method
    vac_nonbonded_method = self.settings.vacuum_forcefield_settings.nonbonded_method

    # The solvent checks are delegated to the more complete system validation
    system_validation.validate_solvent(stateA, solv_nonbonded_method)

    # Gas phase is always gas phase
    if vac_nonbonded_method.lower() != "nocutoff":
        errmsg = (
            "Only the nocutoff nonbonded_method is supported for "
            f"vacuum calculations, {vac_nonbonded_method} was "
            "passed"
        )
        raise ValueError(errmsg)

    # Solvation settings
    settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings)

    # The vacuum equilibration must not request any NVT time
    nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt
    if nvt_time is not None and not np.allclose(nvt_time, 0 * unit.nanosecond):
        raise ValueError("NVT equilibration cannot be run in vacuum simulation")

    # Integrator: check the timestep against each phase's hydrogen mass
    settings_validation.validate_timestep(
        self.settings.vacuum_forcefield_settings.hydrogen_mass,
        self.settings.integrator_settings.timestep,
    )
    settings_validation.validate_timestep(
        self.settings.solvent_forcefield_settings.hydrogen_mass,
        self.settings.integrator_settings.timestep,
    )
def _create(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
    extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
    """
    Build the solvent and vacuum units for every protocol repeat.

    The inputs are validated first. One solvent unit and one vacuum unit
    are then created per repeat (``settings.protocol_repeats``), each with
    its own random ``repeat_id``.

    Returns
    -------
    list[gufe.ProtocolUnit]
      All solvent units followed by all vacuum units.
    """
    # Validate inputs
    self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)

    # Alchemical components and the name of the alchemical species
    alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
    alchname = alchem_comps["stateA"][0].name

    solvent_units = []
    for i in range(self.settings.protocol_repeats):
        solvent_units.append(
            AbsoluteSolvationSolventUnit(
                protocol=self,
                stateA=stateA,
                stateB=stateB,
                alchemical_components=alchem_comps,
                generation=0,
                repeat_id=int(uuid.uuid4()),
                name=(f"Absolute Solvation, {alchname} solvent leg: repeat {i} generation 0"),
            )
        )

    vacuum_units = []
    for i in range(self.settings.protocol_repeats):
        vacuum_units.append(
            AbsoluteSolvationVacuumUnit(
                # These don't really reflect the actual transform
                # Should these be overriden to be ChemicalSystem{smc} -> ChemicalSystem{} ?
                protocol=self,
                stateA=stateA,
                stateB=stateB,
                alchemical_components=alchem_comps,
                generation=0,
                repeat_id=int(uuid.uuid4()),
                name=(f"Absolute Solvation, {alchname} vacuum leg: repeat {i} generation 0"),
            )
        )

    return [*solvent_units, *vacuum_units]
def _gather(
    self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
) -> dict[str, dict[str, Any]]:
    """
    Group the successful unit results by phase and repeat.

    Results whose ``ok()`` is falsy are dropped. A result is filed under
    ``"solvent"`` when its ``simtype`` output equals ``"solvent"`` and
    under ``"vacuum"`` otherwise. Within a phase the results are keyed by
    ``str(repeat_id)`` and ordered by their ``generation`` output.
    """
    by_phase = {
        "solvent": defaultdict(list),
        "vacuum": defaultdict(list),
    }

    for dag_result in protocol_dag_results:
        for unit_result in dag_result.protocol_unit_results:
            if not unit_result.ok():
                continue
            outputs = unit_result.outputs
            phase = "solvent" if outputs["simtype"] == "solvent" else "vacuum"
            by_phase[phase][outputs["repeat_id"]].append(unit_result)

    # Stringify the repeat ids and order each repeat by generation
    return {
        phase: {
            str(repeat_id): sorted(results, key=lambda res: res.outputs["generation"])
            for repeat_id, results in grouped.items()
        }
        for phase, grouped in by_phase.items()
    }
class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit):
    """
    Protocol Unit for the vacuum phase of an absolute solvation free energy
    """

    simtype = "vacuum"

    def _get_components(self):
        """
        Collect the components needed for the vacuum transformation.

        Returns
        -------
        alchem_comps : dict[str, list[Component]]
          The alchemical components, as passed in the unit inputs.
        solv_comp : None
          Always None: the input state contains a solvent, but it is
          dropped because this is the gas phase unit.
        prot_comp : Optional[ProteinComponent]
          The protein component reported for stateA, if it exists.
        small_mols : dict[Component, OpenFF Molecule]
          The openff Molecules to add to the system, built from the
          stateA alchemical components only (since only disappearing
          ligands are allowed).
        """
        state_a = self._inputs["stateA"]
        alchemical = self._inputs["alchemical_components"]

        # Only the alchemical species of stateA are simulated in vacuum
        openff_mols = {}
        for comp in alchemical["stateA"]:
            openff_mols[comp] = comp.to_openff()

        _solvent, protein, _small_mols = system_validation.get_components(state_a)

        return alchemical, None, protein, openff_mols

    def _handle_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Pick the vacuum-phase settings out of the protocol settings.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : MDOutputSettings
            * simulation_settings : SimulationSettings
            * output_settings: MultiStateOutputSettings
        """
        source = self._inputs["protocol"].settings

        return {
            "forcefield_settings": source.vacuum_forcefield_settings,
            "thermo_settings": source.thermo_settings,
            "charge_settings": source.partial_charge_settings,
            "solvation_settings": source.solvation_settings,
            "alchemical_settings": source.alchemical_settings,
            "lambda_settings": source.lambda_settings,
            "engine_settings": source.vacuum_engine_settings,
            "integrator_settings": source.integrator_settings,
            "equil_simulation_settings": source.vacuum_equil_simulation_settings,
            "equil_output_settings": source.vacuum_equil_output_settings,
            "simulation_settings": source.vacuum_simulation_settings,
            "output_settings": source.vacuum_output_settings,
        }
class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit):
    """
    Protocol Unit for the solvent phase of an absolute solvation free energy
    """

    simtype = "solvent"

    def _get_components(self):
        """
        Collect the components needed for the solvent transformation.

        Returns
        -------
        alchem_comps : dict[str, Component]
          The alchemical components, as passed in the unit inputs.
        solv_comp : SolventComponent
          The SolventComponent of the system
        prot_comp : Optional[ProteinComponent]
          The protein component of the system, if it exists.
        small_mols : dict[SmallMoleculeComponent: OFFMolecule]
          The small molecules of stateA mapped to their openff Molecules.
        """
        state_a = self._inputs["stateA"]
        alchemical = self._inputs["alchemical_components"]

        solvent, protein, small_molecules = system_validation.get_components(state_a)

        openff_mols = {}
        for comp in small_molecules:
            openff_mols[comp] = comp.to_openff()

        # We don't need to check that the solvent is not None, otherwise
        # an error will have been raised when calling `validate_solvent`
        # in the Protocol's `_create`.
        # Similarly we don't need to check the protein since that's also
        # disallowed on create
        return alchemical, solvent, protein, openff_mols

    def _handle_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Pick the solvent-phase settings out of the protocol settings.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : MDOutputSettings
            * simulation_settings : MultiStateSimulationSettings
            * output_settings: MultiStateOutputSettings
        """
        source = self._inputs["protocol"].settings

        return {
            "forcefield_settings": source.solvent_forcefield_settings,
            "thermo_settings": source.thermo_settings,
            "charge_settings": source.partial_charge_settings,
            "solvation_settings": source.solvation_settings,
            "alchemical_settings": source.alchemical_settings,
            "lambda_settings": source.lambda_settings,
            "engine_settings": source.solvent_engine_settings,
            "integrator_settings": source.integrator_settings,
            "equil_simulation_settings": source.solvent_equil_simulation_settings,
            "equil_output_settings": source.solvent_equil_output_settings,
            "simulation_settings": source.solvent_simulation_settings,
            "output_settings": source.solvent_output_settings,
        }

View File

@@ -1,12 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from . import _rfe_utils
from .equil_rfe_methods import (
RelativeHybridTopologyProtocol,
RelativeHybridTopologyProtocolResult,
RelativeHybridTopologyProtocolUnit,
)
from .equil_rfe_settings import (
RelativeHybridTopologyProtocolSettings,
)

File diff suppressed because it is too large Load Diff

View File

@@ -1,122 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import gzip
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import openfe
from openfe.protocols.openmm_afe import (
AbsoluteBindingComplexUnit,
AbsoluteBindingProtocol,
AbsoluteBindingProtocolResult,
AbsoluteBindingSolventUnit,
)
@pytest.fixture
def protocol():
    """Provide an AbsoluteBindingProtocol built from its default settings."""
    default_settings = AbsoluteBindingProtocol.default_settings()
    return AbsoluteBindingProtocol(default_settings)
@pytest.fixture
def protocol_units(protocol, benzene_complex_system, T4_protein_component):
    """Create the protocol units for decoupling benzene from the T4 complex."""
    # End state without the ligand: protein and solvent only
    end_state = openfe.ChemicalSystem(
        {"protein": T4_protein_component, "solvent": openfe.SolventComponent()}
    )

    dag = protocol.create(
        stateA=benzene_complex_system,
        stateB=end_state,
        mapping=None,
    )

    return list(dag.protocol_units)
@pytest.fixture
def solvent_protocol_unit(protocol_units):
    """Return the first AbsoluteBindingSolventUnit, or None when there is none."""
    solvent_units = (
        unit for unit in protocol_units if isinstance(unit, AbsoluteBindingSolventUnit)
    )
    return next(solvent_units, None)
@pytest.fixture
def complex_protocol_unit(protocol_units):
    """Return the first AbsoluteBindingComplexUnit, or None when there is none."""
    complex_units = (
        unit for unit in protocol_units if isinstance(unit, AbsoluteBindingComplexUnit)
    )
    return next(complex_units, None)
@pytest.fixture
def protocol_result(abfe_transformation_json_path):
    """Load an AbsoluteBindingProtocolResult from the gzipped JSON file."""
    with gzip.open(abfe_transformation_json_path) as handle:
        return AbsoluteBindingProtocolResult.from_json(handle)
class TestAbsoluteBindingProtocol(GufeTokenizableTestsMixin):
    """Tokenization tests for the AbsoluteBindingProtocol."""

    cls = AbsoluteBindingProtocol
    key = None
    repr = "AbsoluteBindingProtocol-"

    @pytest.fixture()
    def instance(self, protocol):
        return protocol

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteBindingSolventUnit(GufeTokenizableTestsMixin):
    """Tokenization tests for the solvent leg unit of the binding protocol."""

    cls = AbsoluteBindingSolventUnit
    repr = "AbsoluteBindingSolventUnit(Absolute Binding, benzene solvent leg"
    key = None

    @pytest.fixture()
    def instance(self, solvent_protocol_unit):
        return solvent_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteBindingComplexUnit(GufeTokenizableTestsMixin):
    """Tokenization tests for the complex leg unit of the binding protocol."""

    cls = AbsoluteBindingComplexUnit
    repr = "AbsoluteBindingComplexUnit(Absolute Binding, benzene complex leg"
    key = None

    @pytest.fixture()
    def instance(self, complex_protocol_unit):
        return complex_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteBindingProtocolResult(GufeTokenizableTestsMixin):
    """Tokenization tests for the AbsoluteBindingProtocolResult."""

    cls = AbsoluteBindingProtocolResult
    key = None
    repr = "AbsoluteBindingProtocolResult-"

    @pytest.fixture()
    def instance(self, protocol_result):
        return protocol_result

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text

View File

@@ -1,116 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import json
import gufe
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
import openfe
from openfe.protocols import openmm_afe
@pytest.fixture
def protocol():
    """Provide an AbsoluteSolvationProtocol built from its default settings."""
    default_settings = openmm_afe.AbsoluteSolvationProtocol.default_settings()
    return openmm_afe.AbsoluteSolvationProtocol(default_settings)
@pytest.fixture
def protocol_units(protocol, benzene_system):
    """Create the protocol units for desolvating benzene."""
    # End state without the ligand: solvent only
    solvent_only = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()})

    dag = protocol.create(
        stateA=benzene_system,
        stateB=solvent_only,
        mapping=None,
    )

    return list(dag.protocol_units)
@pytest.fixture
def solvent_protocol_unit(protocol_units):
    """Return the first AbsoluteSolvationSolventUnit, or None when there is none."""
    solvent_units = (
        unit
        for unit in protocol_units
        if isinstance(unit, openmm_afe.AbsoluteSolvationSolventUnit)
    )
    return next(solvent_units, None)
@pytest.fixture
def vacuum_protocol_unit(protocol_units):
    """Return the first AbsoluteSolvationVacuumUnit, or None when there is none."""
    vacuum_units = (
        unit
        for unit in protocol_units
        if isinstance(unit, openmm_afe.AbsoluteSolvationVacuumUnit)
    )
    return next(vacuum_units, None)
@pytest.fixture
def protocol_result(afe_solv_transformation_json):
    """Rebuild an AbsoluteSolvationProtocolResult from the serialized JSON string."""
    decoder = gufe.tokenization.JSON_HANDLER.decoder
    payload = json.loads(afe_solv_transformation_json, cls=decoder)
    return openmm_afe.AbsoluteSolvationProtocolResult.from_dict(payload["protocol_result"])
class TestAbsoluteSolvationProtocol(GufeTokenizableTestsMixin):
    """Tokenization tests for the AbsoluteSolvationProtocol."""

    cls = openmm_afe.AbsoluteSolvationProtocol
    key = None
    repr = "AbsoluteSolvationProtocol-"

    @pytest.fixture()
    def instance(self, protocol):
        return protocol

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteSolvationSolventUnit(GufeTokenizableTestsMixin):
    """Tokenization tests for the solvent leg unit of the solvation protocol."""

    cls = openmm_afe.AbsoluteSolvationSolventUnit
    repr = "AbsoluteSolvationSolventUnit(Absolute Solvation, benzene solvent leg"
    key = None

    @pytest.fixture()
    def instance(self, solvent_protocol_unit):
        return solvent_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteSolvationVacuumUnit(GufeTokenizableTestsMixin):
    """Tokenization tests for the vacuum leg unit of the solvation protocol."""

    cls = openmm_afe.AbsoluteSolvationVacuumUnit
    repr = "AbsoluteSolvationVacuumUnit(Absolute Solvation, benzene vacuum leg"
    key = None

    @pytest.fixture()
    def instance(self, vacuum_protocol_unit):
        return vacuum_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestAbsoluteSolvationProtocolResult(GufeTokenizableTestsMixin):
    """Tokenization tests for the AbsoluteSolvationProtocolResult."""

    cls = openmm_afe.AbsoluteSolvationProtocolResult
    key = None
    repr = "AbsoluteSolvationProtocolResult-"

    @pytest.fixture()
    def instance(self, protocol_result):
        return protocol_result

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text

View File

@@ -1,105 +0,0 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import pytest
from gufe.tests.test_tokenization import GufeTokenizableTestsMixin
from openff.units import unit
from openfe.protocols import openmm_rfe
"""
todo:
- RelativeHybridTopologyProtocolResult
- RelativeHybridTopologyProtocol
- RelativeHybridTopologyProtocolUnit
"""
@pytest.fixture
def rfe_protocol():
    """Provide a RelativeHybridTopologyProtocol built from its default settings."""
    default_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()
    return openmm_rfe.RelativeHybridTopologyProtocol(default_settings)
@pytest.fixture
def rfe_protocol_other_units():
    """Identical to rfe_protocol, but with `kcal / mol` as input unit instead of `kilocalorie_per_mole`."""
    settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()
    # The unit is deliberately spelled as a quotient; keep the expression as written
    settings.simulation_settings.early_termination_target_error = 0.0 * unit.kilocalorie/unit.mol  # fmt: skip
    return openmm_rfe.RelativeHybridTopologyProtocol(settings)
@pytest.fixture
def protocol_unit(rfe_protocol, benzene_system, toluene_system, benzene_to_toluene_mapping):
    """Return the first protocol unit of a benzene to toluene transformation."""
    dag = rfe_protocol.create(
        stateA=benzene_system,
        stateB=toluene_system,
        mapping=[benzene_to_toluene_mapping],
    )
    all_units = list(dag.protocol_units)
    return all_units[0]
@pytest.mark.skip
class TestRelativeHybridTopologyProtocolResult(GufeTokenizableTestsMixin):
    """Placeholder for the protocol result tokenization tests (skipped)."""

    cls = openmm_rfe.RelativeHybridTopologyProtocolResult
    key = ""
    repr = ""

    @pytest.fixture()
    def instance(self):
        # No instance is available yet; the whole class is skipped
        return None
class TestRelativeHybridTopologyProtocolOtherUnits(GufeTokenizableTestsMixin):
    """Tokenization tests for a protocol whose settings use a different unit spelling."""

    cls = openmm_rfe.RelativeHybridTopologyProtocol
    repr = "<RelativeHybridTopologyProtocol-"
    key = None

    @pytest.fixture()
    def instance(self, rfe_protocol_other_units):
        return rfe_protocol_other_units

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestRelativeHybridTopologyProtocol(GufeTokenizableTestsMixin):
    """Tokenization tests for the RelativeHybridTopologyProtocol."""

    cls = openmm_rfe.RelativeHybridTopologyProtocol
    repr = "<RelativeHybridTopologyProtocol-"
    key = None

    @pytest.fixture()
    def instance(self, rfe_protocol):
        return rfe_protocol

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text
class TestRelativeHybridTopologyProtocolUnit(GufeTokenizableTestsMixin):
    """Tokenization tests for the RelativeHybridTopologyProtocolUnit."""

    cls = openmm_rfe.RelativeHybridTopologyProtocolUnit
    key = None
    repr = "RelativeHybridTopologyProtocolUnit(benzene to toluene repeat"

    @pytest.fixture()
    def instance(self, protocol_unit):
        return protocol_unit

    def test_key_stable(self):
        # The inherited key stability test is skipped for this unit
        pytest.skip()

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        # Only the stable prefix of the repr is checked
        text = repr(instance)
        assert isinstance(text, str)
        assert self.repr in text

View File

@@ -12,16 +12,14 @@ description = ""
readme = "README.md"
license = "MIT"
license-files = [ "LICENSE" ]
authors = [ { name = "The OpenFE developers", email = "openfe@omsf.io" } ]
requires-python = ">=3.10"
authors = [ { name = "The OpenFE developers", email = "openfreeenergy@omsf.io" } ]
requires-python = ">=3.11"
classifiers = [
"Development Status :: 1 - Planning",
"Development Status :: 5 - Production/Stable",
"Intended Audience :: Science/Research",
"Operating System :: Microsoft :: Windows",
"Operating System :: POSIX",
"Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
@@ -38,11 +36,12 @@ scripts.openfe = "openfecli.cli:main"
zip-safe = false
include-package-data = true
[tool.setuptools.packages]
find = { namespaces = false }
[tool.setuptools.packages.find]
where = [ "src" ]
namespaces = false
[tool.setuptools.package-data]
openfe = [ '"./openfe/tests/data/lomap_basic/toluene.mol2"' ]
openfe = [ '"./src/openfe/tests/data/lomap_basic/toluene.mol2"' ]
[tool.setuptools_scm]
fallback_version = "0.0.0"
@@ -71,9 +70,9 @@ lint.isort.known-first-party = [ "openfe" ]
[tool.coverage.run]
omit = [
"openfe/due.py",
"*/tests/dev/*py",
"*/tests/protocols/test_openmm_rfe_slow.py",
"src/openfe/due.py",
"src/*/tests/dev/*py",
"src/*/tests/protocols/test_openmm_rfe_slow.py",
]
[tool.coverage.report]
@@ -86,6 +85,6 @@ exclude_lines = [
]
[tool.mypy]
files = "openfe"
files = "src/openfe" # TODO: add src/openfecli
ignore_missing_imports = true
warn_unused_ignores = true

View File

@@ -0,0 +1,33 @@
import pooch
from ._registry import zenodo_data_registry
def retrieve_registry_data(zenodo_registry: list[dict], path: str) -> None:
    """Helper function for pulling all test data up-front.

    Every registry entry is fetched with ``pooch.retrieve`` through a DOI
    downloader. Files whose name ends in ``tar.gz`` or ``zip`` are also
    unpacked by the matching pooch processor.

    Parameters
    ----------
    zenodo_registry : list[dict]
        The entries to retrieve. Each entry must provide the keys
        ``base_url``, ``fname`` and ``known_hash``.
    path : str
        path to store the data - usually a pooch.os_cache instance.
    """
    downloader = pooch.DOIDownloader(progressbar=True)

    def _infer_processor(fname: str):
        # Pick the unpacking step from the file extension; plain files
        # need no post-processing.
        if fname.endswith("tar.gz"):
            return pooch.Untar()
        elif fname.endswith("zip"):
            return pooch.Unzip()
        else:
            return None

    for d in zenodo_registry:
        # Registry entries are not consistent about ending ``base_url`` with
        # a "/", so join with exactly one separator; plain concatenation
        # would fuse the record id and the file name into a malformed URL.
        url = d["base_url"].rstrip("/") + "/" + d["fname"]
        pooch.retrieve(
            url=url,
            known_hash=d["known_hash"],
            fname=d["fname"],
            processor=_infer_processor(d["fname"]),
            downloader=downloader,
            path=path,
        )

View File

@@ -0,0 +1,24 @@
import pooch
POOCH_CACHE = pooch.os_cache("openfe")
# Each entry describes one remote file: the DOI base URL it is hosted under,
# the file name, and the hash the download is checked against.
zenodo_rfe_simulation_nc = {
    "base_url": "doi:10.5281/zenodo.15375081/",
    "fname": "simulation.nc",
    "known_hash": "md5:bc4e842b47de17704d804ae345b91599",
}

zenodo_t4_lysozyme_traj = {
    "base_url": "doi:10.5281/zenodo.15212342",
    "fname": "t4_lysozyme_trajectory.zip",
    "known_hash": "sha256:e985d055db25b5468491e169948f641833a5fbb67a23dbb0a00b57fb7c0e59c8",
}

zenodo_industry_benchmark_systems = {
    "base_url": "doi:10.5281/zenodo.15212342",
    "fname": "industry_benchmark_systems.zip",
    "known_hash": "sha256:2bb5eee36e29b718b96bf6e9350e0b9957a592f6c289f77330cbb6f4311a07bd",
}

# All entries, in the order they are retrieved
zenodo_data_registry = [
    zenodo_rfe_simulation_nc,
    zenodo_t4_lysozyme_traj,
    zenodo_industry_benchmark_systems,
]

View File

@@ -0,0 +1,56 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
Run absolute free energy calculations using OpenMM and OpenMMTools.
"""
from .abfe_units import (
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
from .afe_protocol_results import (
AbsoluteBindingProtocolResult,
AbsoluteSolvationProtocolResult,
)
from .ahfe_units import (
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
from .equil_binding_afe_method import (
AbsoluteBindingProtocol,
AbsoluteBindingSettings,
)
from .equil_solvation_afe_method import (
AbsoluteSolvationProtocol,
AbsoluteSolvationSettings,
)
__all__ = [
"AbsoluteSolvationProtocol",
"AbsoluteSolvationSettings",
"AbsoluteSolvationProtocolResult",
"AHFESolventSetupUnit",
"AHFESolventSimUnit",
"AHFESolventAnalysisUnit",
"AHFEVacuumSetupUnit",
"AHFEVacuumSimUnit",
"AHFEVacuumAnalysisUnit",
"AbsoluteBindingProtocol",
"AbsoluteBindingSettings",
"AbsoluteBindingProtocolResult",
"ABFEComplexSetupUnit",
"ABFEComplexSimUnit",
"ABFEComplexAnalysisUnit",
"ABFESolventSetupUnit",
"ABFESolventSimUnit",
"ABFESolventAnalysisUnit",
]

View File

@@ -0,0 +1,513 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""ABFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.abfe_units`
========================================================================
This module defines the ProtocolUnits for the
:class:`AbsoluteBindingProtocol`.
"""
import logging
import pathlib
import MDAnalysis as mda
import numpy as np
import numpy.typing as npt
from gufe import (
SolventComponent,
)
from gufe.components import Component
from openff.units import Quantity
from openff.units.openmm import to_openmm
from openmm import System
from openmm import unit as ommunit
from openmm.app import Topology as omm_topology
from openmmtools.states import ThermodynamicState
from rdkit import Chem
from openfe.protocols.openmm_afe.equil_afe_settings import (
BoreschRestraintSettings,
SettingsBaseModel,
)
from openfe.protocols.openmm_utils import system_validation
from openfe.protocols.restraint_utils import geometry
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
from openfe.protocols.restraint_utils.openmm import omm_restraints
from openfe.protocols.restraint_utils.openmm.omm_restraints import BoreschRestraint
from .base_afe_units import (
BaseAbsoluteMultiStateAnalysisUnit,
BaseAbsoluteMultiStateSimulationUnit,
BaseAbsoluteSetupUnit,
)
logger = logging.getLogger(__name__)
class ComplexComponentsMixin:
    """Mixin providing ``_get_components`` for complex phase units."""

    def _get_components(self):
        """
        Collect the components needed for a complex transformation.

        Returns
        -------
        alchem_comps : dict[str, Component]
          The alchemical components, as passed in the unit inputs.
        solv_comp : SolventComponent
          The SolventComponent of the system
        prot_comp : ProteinComponent | None
          The protein component of the system, if it exists.
        small_mols : dict[SmallMoleculeComponent: OFFMolecule]
          The small molecules of stateA mapped to their openff Molecules.
        """
        state_a = self._inputs["stateA"]
        alchemical = self._inputs["alchemical_components"]

        solvent, protein, small_molecules = system_validation.get_components(state_a)

        openff_mols = {}
        for comp in small_molecules:
            openff_mols[comp] = comp.to_openff()

        # We don't need to check that the solvent is not None, otherwise
        # an error will have been raised when calling `validate_solvent`
        # in the Protocol's `_create`.
        # Similarly we don't need to check the protein
        return alchemical, solvent, protein, openff_mols
class ComplexSettingsMixin:
    """Mixin providing ``_get_settings`` for complex phase units."""

    def _get_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Pick the complex-phase settings out of the protocol settings.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : ABFEPreEquilOutputSettings
            * simulation_settings : SimulationSettings
            * output_settings: MultiStateOutputSettings
            * restraint_settings: BaseRestraintSettings
        """
        source = self._inputs["protocol"].settings  # type: ignore[attr-defined]

        return {
            "forcefield_settings": source.forcefield_settings,
            "thermo_settings": source.thermo_settings,
            "charge_settings": source.partial_charge_settings,
            "solvation_settings": source.complex_solvation_settings,
            "alchemical_settings": source.alchemical_settings,
            "lambda_settings": source.complex_lambda_settings,
            "engine_settings": source.engine_settings,
            "integrator_settings": source.integrator_settings,
            "equil_simulation_settings": source.complex_equil_simulation_settings,
            "equil_output_settings": source.complex_equil_output_settings,
            "simulation_settings": source.complex_simulation_settings,
            "output_settings": source.complex_output_settings,
            "restraint_settings": source.restraint_settings,
        }
class ABFEComplexSetupUnit(ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteSetupUnit):
"""
Setup unit for the complex phase of absolute binding free energy
transformations.
"""
simtype = "complex"
@staticmethod
def _get_mda_universe(
    topology: omm_topology,
    positions: ommunit.Quantity | None,
    trajectory: pathlib.Path | None,
) -> mda.Universe:
    """
    Helper method to get a Universe from an openmm Topology,
    and either an input trajectory or a set of positions.

    Parameters
    ----------
    topology : openmm.app.Topology
      An OpenMM Topology that defines the System.
    positions: openmm.unit.Quantity | None
      The System's current positions.
      Used if a trajectory file is None or is not a file.
    trajectory: pathlib.Path | None
      A Path to a trajectory file to read positions from.

    Returns
    -------
    mda.Universe
      An MDAnalysis Universe of the System.

    Raises
    ------
    ValueError
      If there is no trajectory file to read and ``positions`` is None.
    """
    from MDAnalysis.coordinates.memory import MemoryReader

    # Prefer the trajectory file whenever it actually exists on disk
    if trajectory is not None and trajectory.is_file():
        return mda.Universe(
            topology,
            trajectory,
            topology_format="OPENMMTOPOLOGY",
        )

    # Otherwise fall back to the in-memory positions
    if positions is None:
        raise ValueError("No positions to create the Universe with")

    # Convert explicitly to angstroms rather than reading the private
    # ``_value`` and assuming the Quantity is stored in nanometers.
    coordinates = np.array(positions.value_in_unit(ommunit.angstrom))

    return mda.Universe(
        topology,
        coordinates,
        topology_format="OPENMMTOPOLOGY",
        trajectory_format=MemoryReader,
    )
@staticmethod
def _get_idxs_from_residxs(
    topology: omm_topology,
    residxs: list[int],
) -> list[int]:
    """
    Helper method to get a list of atom indices which belong to a list
    of residues.

    Parameters
    ----------
    topology : openmm.app.Topology
      An OpenMM Topology that defines the System.
    residxs : list[int]
      A list of residue indices whose atom indices should be collected.

    Returns
    -------
    atom_ids : list[int]
      A list of atom indices, in the order the residues (and their atoms)
      are iterated by the topology.

    TODO
    ----
    * Check how this works when we deal with virtual sites.
    """
    # Build the lookup once so each residue is tested in O(1) instead of
    # scanning the whole list of requested residues.
    wanted = set(residxs)

    return [
        at.index
        for r in topology.residues()
        if r.index in wanted
        for at in r.atoms()
    ]
@staticmethod
def _get_boresch_restraint(
universe: mda.Universe,
guest_rdmol: Chem.Mol,
guest_atom_ids: list[int],
host_atom_ids: list[int],
temperature: Quantity,
settings: BoreschRestraintSettings,
) -> tuple[BoreschRestraintGeometry, BoreschRestraint]:
"""
Get a Boresch-like restraint Geometry and OpenMM restraint force
supplier.
Parameters
----------
universe : mda.Universe
An MDAnalysis Universe defining the system to get the restraint for.
guest_rdmol : Chem.Mol
An RDKit Molecule defining the guest molecule in the system.
guest_atom_ids: list[int]
A list of atom indices defining the guest molecule in the universe.
host_atom_ids : list[int]
A list of atom indices defining the host molecules in the universe.
temperature : openff.units.Quantity
The temperature of the simulation where the restraint will be added.
settings : BoreschRestraintSettings
Settings on how the Boresch-like restraint should be defined.
Returns
-------
geom : BoreschRestraintGeometry
A class defining the Boresch-like restraint.
restraint : BoreschRestraint
A factory class for generating Boresch restraints in OpenMM.
"""
# Take the minimum of the two possible force constants to check against
frc_const = min(settings.K_thetaA, settings.K_thetaB)
geom = geometry.boresch.find_boresch_restraint(
universe=universe,
guest_rdmol=guest_rdmol,
guest_idxs=guest_atom_ids,
host_idxs=host_atom_ids,
host_selection=settings.host_selection,
anchor_finding_strategy=settings.anchor_finding_strategy,
dssp_filter=settings.dssp_filter,
rmsf_cutoff=settings.rmsf_cutoff,
host_min_distance=settings.host_min_distance,
host_max_distance=settings.host_max_distance,
angle_force_constant=frc_const,
temperature=temperature,
)
restraint = omm_restraints.BoreschRestraint(settings)
return geom, restraint
def _add_restraints(
self,
system: System,
topology: omm_topology,
positions: ommunit.Quantity,
alchem_comps: dict[str, list[Component]],
comp_resids: dict[Component, npt.NDArray],
settings: dict[str, SettingsBaseModel],
) -> tuple[
Quantity,
System,
geometry.HostGuestRestraintGeometry,
]:
"""
Find and add restraints to the OpenMM System.
Notes
-----
Currently, only Boresch-like restraints are supported.
Parameters
----------
system : openmm.System
The System to add the restraint to.
topology : openmm.app.Topology
An OpenMM Topology that defines the System.
positions: openmm.unit.Quantity
The System's current positions.
Used if a trajectory file isn't found.
alchem_comps: dict[str, list[Component]]
A dictionary with a list of alchemical components
in both state A and B.
comp_resids: dict[Component, npt.NDArray]
A dictionary keyed by each Component in the System
which contains arrays with the residue indices that is contained
by that Component.
settings : dict[str, SettingsBaseModel]
A dictionary of settings that defines how to find and set
the restraint.
Returns
-------
correction : openff.units.Quantity
The standard state correction for the restraint.
system : openmm.System
A copy of the System with the restraint added.
rest_geom : geometry.HostGuestRestraintGeometry
The restraint Geometry object.
"""
if self.verbose:
self.logger.info("Generating restraints")
# Get the guest rdmol
guest_rdmol = alchem_comps["stateA"][0].to_rdkit()
# sanitize the rdmol if possible - warn if you can't
err = Chem.SanitizeMol(guest_rdmol, catchErrors=True)
if err:
msg = "restraint generation: could not sanitize ligand rdmol"
logger.warning(msg)
# Get the guest idxs
# concatenate a list of residue indexes for all alchemical components
residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]])
# get the alchemicical atom ids
guest_atom_ids = self._get_idxs_from_residxs(topology, residxs)
# Now get the host idxs
# We assume this is everything but the alchemical component
# and the solvent.
solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)]
exclude_comps = [alchem_comps["stateA"]] + solv_comps
residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps])
host_atom_ids = self._get_idxs_from_residxs(topology, residxs)
# Finally create an MDAnalysis Universe
# We try to pass the equilibration production file path through
# In some cases (debugging / dry runs) this won't be available
# so we'll default to using input positions.
univ = self._get_mda_universe(
topology,
positions,
self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename,
)
if isinstance(settings["restraint_settings"], BoreschRestraintSettings):
rest_geom, restraint = self._get_boresch_restraint(
univ,
guest_rdmol,
guest_atom_ids,
host_atom_ids,
settings["thermo_settings"].temperature,
settings["restraint_settings"],
)
else:
# TODO turn this into a direction for different restraint types supported?
raise NotImplementedError("Other restraint types are not yet available")
if self.verbose:
self.logger.info(f"restraint geometry is: {rest_geom}")
# We need a temporary thermodynamic state to add the restraint
# & get the correction
thermodynamic_state = ThermodynamicState(
system,
temperature=to_openmm(settings["thermo_settings"].temperature),
pressure=to_openmm(settings["thermo_settings"].pressure),
)
# Add the force to the thermodynamic state
restraint.add_force(
thermodynamic_state,
rest_geom,
controlling_parameter_name="lambda_restraints",
)
# Get the standard state correction as a unit.Quantity
correction = restraint.get_standard_state_correction(
thermodynamic_state,
rest_geom,
)
return (
correction,
# Remove the thermostat, otherwise you'll get an
# Andersen thermostat by default!
thermodynamic_state.get_system(remove_thermostat=True),
rest_geom,
)
class ABFEComplexSimUnit(
    ComplexComponentsMixin, ComplexSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
    """
    Complex-phase multi-state simulation unit for absolute binding
    free energy transformations.

    Multi-state here refers to multi replica methods such as
    Hamiltonian replica exchange.
    """

    simtype = "complex"
class ABFEComplexAnalysisUnit(ComplexSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
    """
    Complex-phase analysis unit for the multi-state simulations of
    absolute binding free energy transformations.
    """

    simtype = "complex"
class SolventComponentsMixin:
    def _get_components(self):
        """
        Get the relevant components for a solvent transformation.

        Returns
        -------
        alchem_comps : dict[str, list[Component]]
          The alchemical components, as passed in through the
          ``alchemical_components`` input.
        solv_comp : SolventComponent
          The SolventComponent of the system
        prot_comp : None
          Always ``None``; no protein component is returned for the
          solvent transformation, whatever ``stateA`` contains.
        small_mols : dict[SmallMoleculeComponent: OFFMolecule]
          The state A alchemical components to add to the system,
          keyed to their OpenFF Molecules.
        """
        stateA = self._inputs["stateA"]
        alchem_comps = self._inputs["alchemical_components"]

        # Only the solvent is taken from stateA; the protein and the
        # non-alchemical small molecules are deliberately discarded.
        solv_comp, _, _ = system_validation.get_components(stateA)
        off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]}

        # We don't need to check that solv_comp is not None, otherwise
        # an error will have been raised when calling `validate_solvent`
        # in the Protocol's `_create`.
        # Similarly we don't need to check prot_comp just return None
        return alchem_comps, solv_comp, None, off_comps
class SolventSettingsMixin:
    def _get_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Collect the Protocol settings that apply to a solvent transformation.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : ABFEPreEquilOutputSettings
            * simulation_settings : MultiStateSimulationSettings
            * output_settings: MultiStateOutputSettings
        """
        protocol_settings = self._inputs["protocol"].settings  # type: ignore[attr-defined]

        # (key in the returned dictionary, attribute on the Protocol settings)
        attribute_map = (
            ("forcefield_settings", "forcefield_settings"),
            ("thermo_settings", "thermo_settings"),
            ("charge_settings", "partial_charge_settings"),
            ("solvation_settings", "solvent_solvation_settings"),
            ("alchemical_settings", "alchemical_settings"),
            ("lambda_settings", "solvent_lambda_settings"),
            ("engine_settings", "engine_settings"),
            ("integrator_settings", "integrator_settings"),
            ("equil_simulation_settings", "solvent_equil_simulation_settings"),
            ("equil_output_settings", "solvent_equil_output_settings"),
            ("simulation_settings", "solvent_simulation_settings"),
            ("output_settings", "solvent_output_settings"),
        )

        return {key: getattr(protocol_settings, attr) for key, attr in attribute_map}
class ABFESolventSetupUnit(SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteSetupUnit):
    """
    Solvent-phase setup unit for absolute binding free energy
    transformations.
    """

    simtype = "solvent"
class ABFESolventSimUnit(
    SolventComponentsMixin, SolventSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
    """
    Solvent-phase multi-state simulation unit for absolute binding
    free energy transformations.

    Multi-state here refers to multi replica methods such as
    Hamiltonian replica exchange.
    """

    simtype = "solvent"
class ABFESolventAnalysisUnit(SolventSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
    """
    Solvent-phase analysis unit for the multi-state simulations of
    absolute binding free energy transformations.
    """

    simtype = "solvent"

View File

@@ -0,0 +1,550 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
Result classes for the Absolute Free Energy Protocols
=====================================================
This module implements :class:`gufe.ProtocolResult` classes for the absolute
free energy Protocols.
Specifically it implements:
* AbsoluteBindingProtocolResult
* AbsoluteSolvationProtocolResult
"""
import itertools
import logging
import pathlib
import warnings
from typing import Optional, Union
import gufe
import numpy as np
import numpy.typing as npt
from openff.units import Quantity
from openff.units import unit as offunit
from openmmtools import multistate
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
logger = logging.getLogger(__name__)
class AbsoluteProtocolResultMixin:
    """
    Shared result accessors for the absolute free energy Protocols.

    Every accessor reads ``self.data[leg]`` for the two legs named by
    ``bound_state`` and ``unbound_state``, takes the first unit result of
    each entry and looks up one key of its ``outputs`` dictionary.
    """

    # Names of the two thermodynamic cycle legs; subclasses override these.
    bound_state = "solvent"
    unbound_state = "vacuum"

    def __init__(self, **data):
        """
        Store the result data and reject results that need stitching.

        Raises
        ------
        NotImplementedError
          If any entry of either leg holds more than two unit results.
        """
        super().__init__(**data)
        # TODO: Detect when we have extensions and stitch these together?
        if any(
            len(pur_list) > 2
            for pur_list in itertools.chain(
                self.data[self.bound_state].values(), self.data[self.unbound_state].values()
            )
        ):
            raise NotImplementedError("Can't stitch together results yet")

    def get_forward_and_reverse_energy_analysis(
        self,
    ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]:
        """
        Get the reverse and forward analysis of the free energies.

        Returns
        -------
        forward_reverse : dict[str, list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]]
            A dictionary, keyed for each leg of the thermodynamic cycle,
            either ``solvent`` and ``vacuum`` for a solvation free energy or
            ``solvent`` and ``complex`` for a binding free energy,
            with each containing a list of dictionaries containing the forward
            and reverse analysis of each repeat of that simulation type.

            The forward and reverse analysis dictionaries contain:
              - `fractions`: npt.NDArray
                  The fractions of data used for the estimates
              - `forward_DGs`, `reverse_DGs`: openff.units.Quantity
                  The forward and reverse estimates for each fraction of data
              - `forward_dDGs`, `reverse_dDGs`: openff.units.Quantity
                  The forward and reverse estimate uncertainty for each
                  fraction of data.

            If one of the cycle leg list entries is ``None``, this indicates
            that the analysis could not be carried out for that repeat. This
            is most likely caused by MBAR convergence issues when attempting to
            calculate free energies from too few samples.

        Warns
        -----
        UserWarning
          * If any of the forward and reverse dictionaries are ``None`` in a
            given thermodynamic cycle leg.
        """
        forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {}

        for key in [self.bound_state, self.unbound_state]:
            forward_reverse[key] = [
                pus[0].outputs["forward_and_reverse_energies"]
                for pus in self.data[key].values()  # type: ignore[attr-defined]
            ]

            if None in forward_reverse[key]:
                wmsg = (
                    "One or more ``None`` entries were found in the forward "
                    f"and reverse dictionaries of the repeats of the {key} "
                    "calculations. This is likely caused by an MBAR convergence "
                    "failure caused by too few independent samples when "
                    "calculating the free energies of the 10% timeseries slice."
                )
                warnings.warn(wmsg)

        return forward_reverse

    def get_overlap_matrices(self) -> dict[str, list[dict[str, npt.NDArray]]]:
        """
        Get the MBAR overlap estimates for all legs of the simulation.

        Returns
        -------
        overlap_stats : dict[str, list[dict[str, npt.NDArray]]]
          A dictionary keyed for each leg of the thermodynamic cycle, either
          ``solvent`` and ``vacuum`` for a solvation free energy or
          ``solvent`` and ``complex`` for a binding free energy,
          with each containing a list of dictionaries with the MBAR overlap
          estimates of each repeat of that simulation type.

          The underlying MBAR dictionaries contain the following keys:
            * ``scalar``: One minus the largest nontrivial eigenvalue
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              overlap matrix
            * ``matrix``: Estimated overlap matrix of observing a sample from
              state i in state j
        """
        # Loop through and get the repeats and get the matrices
        overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {}

        for key in [self.bound_state, self.unbound_state]:
            overlap_stats[key] = [
                pus[0].outputs["unit_mbar_overlap"]
                for pus in self.data[key].values()  # type: ignore[attr-defined]
            ]

        return overlap_stats

    def get_replica_transition_statistics(self) -> dict[str, list[dict[str, npt.NDArray]]]:
        """
        Get the replica exchange transition statistics for all
        legs of the simulation.

        Note
        ----
        This is currently only available in cases where a replica exchange
        simulation was run.

        Returns
        -------
        repex_stats : dict[str, list[dict[str, npt.NDArray]]]
          A dictionary with keys for each leg of the thermodynamic cycle, either
          ``solvent`` and ``vacuum`` for a solvation free energy or
          ``solvent`` and ``complex`` for a binding free energy,
          with each containing a list of dictionaries containing the replica
          transition statistics for each repeat of that simulation type.

          The replica transition statistics dictionaries contain the following:
            * ``eigenvalues``: The sorted (descending) eigenvalues of the
              lambda state transition matrix
            * ``matrix``: The transition matrix estimate of a replica switching
              from state i to state j.

        Raises
        ------
        ValueError
          If the replica exchange statistics are missing from any of the
          unit outputs.
        """
        repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {}
        try:
            for key in [self.bound_state, self.unbound_state]:
                repex_stats[key] = [
                    pus[0].outputs["replica_exchange_statistics"]
                    for pus in self.data[key].values()  # type: ignore[attr-defined]
                ]
        except KeyError as err:
            errmsg = "Replica exchange statistics were not found, did you run a repex calculation?"
            # Chain explicitly so the missing key stays visible in the traceback
            raise ValueError(errmsg) from err

        return repex_stats

    def get_replica_states(self) -> dict[str, list[npt.NDArray]]:
        """
        Get the timeseries of replica states for all simulation legs.

        Returns
        -------
        replica_states : dict[str, list[npt.NDArray]]
          Dictionary keyed for each leg of the thermodynamic cycle, either
          `solvent` and `vacuum` for solvation free energies,
          or `complex` and `solvent` for binding free energies,
          with lists of replica states timeseries for each repeat of that
          simulation type.

        Raises
        ------
        ValueError
          If a trajectory or checkpoint file cannot be found on disk.
        """
        replica_states: dict[str, list[npt.NDArray]] = {
            self.bound_state: [],
            self.unbound_state: [],
        }

        def is_file(filename: Union[str, pathlib.Path]) -> pathlib.Path:
            # Return ``filename`` as a Path, failing loudly if it is missing
            p = pathlib.Path(filename)
            if not p.exists():
                errmsg = f"File could not be found {p}"
                raise ValueError(errmsg)
            return p

        def get_replica_state(nc, chk):
            nc = is_file(nc)
            dir_path = nc.parents[0]
            # The checkpoint is looked up next to the trajectory, but only
            # its file name is handed to the reporter.
            chk = is_file(dir_path / chk).name
            reporter = multistate.MultiStateReporter(
                storage=nc, checkpoint_storage=chk, open_mode="r"
            )
            try:
                return np.asarray(reporter.read_replica_thermodynamic_states())
            finally:
                # Always release the storage files, even if the read fails
                reporter.close()

        for key in [self.bound_state, self.unbound_state]:
            for pus in self.data[key].values():  # type: ignore[attr-defined]
                states = get_replica_state(
                    pus[0].outputs["trajectory"],
                    pus[0].outputs["checkpoint"],
                )
                replica_states[key].append(states)

        return replica_states

    def equilibration_iterations(self) -> dict[str, list[float]]:
        """
        Get the number of equilibration iterations for each simulation.

        Returns
        -------
        equilibration_lengths : dict[str, list[float]]
          Dictionary keyed for each leg of the thermodynamic cycle, either
          `solvent` and `vacuum` for solvation free energies,
          or `complex` and `solvent` for binding free energies,
          with lists containing the number of equilibration iterations for
          each repeat of that simulation type.
        """
        equilibration_lengths: dict[str, list[float]] = {}

        for key in [self.bound_state, self.unbound_state]:
            equilibration_lengths[key] = [
                pus[0].outputs["equilibration_iterations"]
                for pus in self.data[key].values()  # type: ignore[attr-defined]
            ]

        return equilibration_lengths

    def production_iterations(self) -> dict[str, list[float]]:
        """
        Get the number of production iterations for each simulation.
        Returns the number of uncorrelated production samples for each
        repeat of the calculation.

        Returns
        -------
        production_lengths : dict[str, list[float]]
          Dictionary keyed for each leg of the thermodynamic cycle, either
          `solvent` and `vacuum` for solvation free energies,
          or `complex` and `solvent` for binding free energies,
          with lists containing the number of production iterations for
          each repeat of that simulation type.
        """
        production_lengths: dict[str, list[float]] = {}

        for key in [self.bound_state, self.unbound_state]:
            production_lengths[key] = [
                pus[0].outputs["production_iterations"]
                for pus in self.data[key].values()  # type: ignore[attr-defined]
            ]

        return production_lengths

    def selection_indices(self) -> dict[str, list[Optional[npt.NDArray]]]:
        """
        Get the system selection indices used to write PDB and
        trajectory files.

        Returns
        -------
        indices : dict[str, list[npt.NDArray]]
          A dictionary keyed for each state, either
          `solvent` and `vacuum` for solvation free energies,
          or `complex` and `solvent` for binding free energies,
          each containing a list of NDArrays containing the corresponding
          full system atom indices for each atom written in the production
          trajectory files for each replica.
        """
        indices: dict[str, list[Optional[npt.NDArray]]] = {}

        for key in [self.bound_state, self.unbound_state]:
            indices[key] = [
                pus[0].outputs["selection_indices"]
                for pus in self.data[key].values()  # type: ignore[attr-defined]
            ]

        return indices
class AbsoluteSolvationProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
    """
    Protocol results holding the output of an AbsoluteSolvationProtocol.
    """

    bound_state = "solvent"
    unbound_state = "vacuum"

    def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]:
        """
        Get the per-repeat free energy estimates of each leg.

        Returns
        -------
        dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]]
          A dictionary, keyed `solvent` and `vacuum` for each leg
          of the thermodynamic cycle, with lists of tuples containing
          the individual free energy estimates and associated MBAR
          uncertainties for each repeat of that simulation type.
        """
        return {
            leg: [
                (repeat[0].outputs["unit_estimate"], repeat[0].outputs["unit_estimate_error"])
                for repeat in self.data[leg].values()
            ]
            for leg in (self.bound_state, self.unbound_state)
        }

    def get_estimate(self):
        """Get the solvation free energy estimate for this calculation.

        Returns
        -------
        dG : openff.units.Quantity
          The solvation free energy. This is a Quantity defined with units.
        """

        def _get_average(estimates):
            # Everything is expressed in the unit of the first estimate
            # before the magnitudes are averaged.
            ref_unit = estimates[0][0].u
            magnitudes = [entry[0].to(ref_unit).m for entry in estimates]
            return np.average(magnitudes) * ref_unit

        per_leg = self.get_individual_estimates()
        vacuum_mean = _get_average(per_leg["vacuum"])
        solvent_mean = _get_average(per_leg["solvent"])

        return vacuum_mean - solvent_mean

    def get_uncertainty(self):
        """Get the solvation free energy error for this calculation.

        Returns
        -------
        err : openff.units.Quantity
          The standard deviation between estimates of the solvation free
          energy. This is a Quantity defined with units.
        """

        def _get_stdev(estimates):
            # Everything is expressed in the unit of the first estimate
            # before the spread of the magnitudes is taken.
            ref_unit = estimates[0][0].u
            magnitudes = [entry[0].to(ref_unit).m for entry in estimates]
            return np.std(magnitudes) * ref_unit

        per_leg = self.get_individual_estimates()
        vacuum_spread = _get_stdev(per_leg["vacuum"])
        solvent_spread = _get_stdev(per_leg["solvent"])

        # The two legs are combined in quadrature
        return np.sqrt(vacuum_spread**2 + solvent_spread**2)
class AbsoluteBindingProtocolResult(gufe.ProtocolResult, AbsoluteProtocolResultMixin):
    """
    Protocol results holding the output of an AbsoluteBindingProtocol.
    """

    bound_state = "complex"
    unbound_state = "solvent"

    def get_individual_estimates(
        self,
    ) -> dict[str, list[tuple[Quantity, Quantity]]]:
        """
        Get the per-repeat free energy estimates of each part of the cycle.

        Returns
        -------
        dGs : dict[str, list[tuple[openff.units.Quantity, openff.units.Quantity]]]
          A dictionary, keyed `solvent`, `complex`, and
          `standard_state_correction` representing each portion of the
          thermodynamic cycle, with lists of tuples containing the
          individual free energy estimates and, for `solvent` and
          `complex`, the associated MBAR uncertainties for each repeat
          of that simulation type.

        Notes
        -----
        * Standard state correction has no error and so will return a value
          of 0.
        """
        complex_repeats = list(self.data["complex"].values())

        complex_dGs = [
            (repeat[0].outputs["unit_estimate"], repeat[0].outputs["unit_estimate_error"])
            for repeat in complex_repeats
        ]
        correction_dGs = [
            (
                repeat[0].outputs["standard_state_correction"],
                0 * offunit.kilocalorie_per_mole,  # correction has no error
            )
            for repeat in complex_repeats
        ]
        solv_dGs = [
            (repeat[0].outputs["unit_estimate"], repeat[0].outputs["unit_estimate_error"])
            for repeat in self.data["solvent"].values()
        ]

        return {
            "solvent": solv_dGs,
            "complex": complex_dGs,
            "standard_state_correction": correction_dGs,
        }

    @staticmethod
    def _add_complex_standard_state_corr(
        complex_dG: list[tuple[Quantity, Quantity]],
        standard_state_dG: list[tuple[Quantity, Quantity]],
    ) -> list[tuple[Quantity, Quantity]]:
        """
        Helper method to combine the complex leg with its
        standard state corrections.

        Parameters
        ----------
        complex_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]]
          The individual estimates of the complex leg,
          where the first entry of each tuple is the dG estimate
          and the second entry is the MBAR error.
        standard_state_dG : list[tuple[Quantity, Quantity]]
          The individual standard state corrections for each corresponding
          complex leg. The first entry is the correction, the second
          is an empty error value of 0.

        Returns
        -------
        combined_dG : list[tuple[openff.units.Quantity, openff.units.Quantity]]
          A list of dG estimates & MBAR errors for the combined
          complex & standard state correction of each repeat.

        Notes
        -----
        We assume that both lists of items are in the right order.
        """
        # No need to convert unit types, since pint takes care of that
        # except that mypy hates it because pint isn't typed properly...
        # No need to add errors since there's just the one
        return [
            (comp[0] + corr[0], comp[1])  # type: ignore[operator]
            for comp, corr in zip(complex_dG, standard_state_dG)
        ]

    def get_estimate(self) -> Quantity:
        """Get the binding free energy estimate for this calculation.

        Returns
        -------
        dG : openff.units.Quantity
          The binding free energy. This is a Quantity defined with units.
        """

        def _get_average(estimates):
            # Everything is expressed in the unit of the first estimate
            # before the magnitudes are averaged.
            ref_unit = estimates[0][0].u
            magnitudes = [entry[0].to(ref_unit).m for entry in estimates]
            return np.average(magnitudes) * ref_unit

        per_leg = self.get_individual_estimates()
        corrected_complex = self._add_complex_standard_state_corr(
            per_leg["complex"],
            per_leg["standard_state_correction"],
        )
        complex_mean = _get_average(corrected_complex)
        solvent_mean = _get_average(per_leg["solvent"])

        return -complex_mean + solvent_mean

    def get_uncertainty(self) -> Quantity:
        """Get the binding free energy error for this calculation.

        Returns
        -------
        err : openff.units.Quantity
          The standard deviation between estimates of the binding free
          energy. This is a Quantity defined with units.
        """

        def _get_stdev(estimates):
            # Everything is expressed in the unit of the first estimate
            # before the spread of the magnitudes is taken.
            ref_unit = estimates[0][0].u
            magnitudes = [entry[0].to(ref_unit).m for entry in estimates]
            return np.std(magnitudes) * ref_unit

        per_leg = self.get_individual_estimates()
        corrected_complex = self._add_complex_standard_state_corr(
            per_leg["complex"],
            per_leg["standard_state_correction"],
        )
        complex_spread = _get_stdev(corrected_complex)
        solvent_spread = _get_stdev(per_leg["solvent"])

        # The two legs are combined in quadrature
        return np.sqrt(complex_spread**2 + solvent_spread**2)

    def restraint_geometries(self) -> list[BoreschRestraintGeometry]:
        """
        Get a list of the restraint geometries for the
        complex simulations. These define the atoms that have
        been restrained in the system.

        Returns
        -------
        geometries : list[BoreschRestraintGeometry]
          One geometry per complex repeat, validated from the stored
          ``restraint_geometry`` output, describing the atoms in the
          system that are involved in the restraint.
        """
        return [
            BoreschRestraintGeometry.model_validate(repeat[0].outputs["restraint_geometry"])
            for repeat in self.data["complex"].values()
        ]

View File

@@ -0,0 +1,230 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
AHFE Protocol Units --- :mod:`openfe.protocols.openmm_afe.ahfe_units`
=====================================================================
This module defines the ProtocolUnits for the
:class:`AbsoluteSolvationProtocol`.
"""
import logging
from openfe.protocols.openmm_afe.equil_afe_settings import (
SettingsBaseModel,
)
from ..openmm_utils import system_validation
from .base_afe_units import (
BaseAbsoluteMultiStateAnalysisUnit,
BaseAbsoluteMultiStateSimulationUnit,
BaseAbsoluteSetupUnit,
)
logger = logging.getLogger(__name__)
class VacuumComponentsMixin:
    def _get_components(self):
        """
        Collect the components needed for a vacuum transformation.

        Returns
        -------
        alchem_comps : dict[str, list[Component]]
          The alchemical components, as passed in through the
          ``alchemical_components`` input.
        solv_comp : None
          For the gas phase transformation, None will always be returned
          for the solvent component of the chemical system.
        prot_comp : Optional[ProteinComponent]
          The protein component of the system, if it exists.
        small_mols : dict[Component, OpenFF Molecule]
          The openff Molecules to add to the system. This
          is equivalent to the alchemical components in stateA (since
          we only allow for disappearing ligands).
        """
        inputs = self._inputs
        state_a = inputs["stateA"]
        alchemical = inputs["alchemical_components"]

        openff_mols = {comp: comp.to_openff() for comp in alchemical["stateA"]}

        _solvent, protein, _small_mols = system_validation.get_components(state_a)

        # Notes:
        # 1. Our input state will contain a solvent, we ``None`` that out
        # since this is the gas phase unit.
        # 2. Our small molecules will always just be the alchemical components
        # (of stateA since we enforce only one disappearing ligand)
        return alchemical, None, protein, openff_mols
class VacuumSettingsMixin:
    def _get_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Collect the Protocol settings that apply to a vacuum transformation.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : MDOutputSettings
            * simulation_settings : SimulationSettings
            * output_settings: MultiStateOutputSettings
        """
        protocol_settings = self._inputs["protocol"].settings  # type: ignore[attr-defined]

        # Phase-specific entries come from the ``vacuum_*`` attributes,
        # everything else is shared across phases.
        return {
            "forcefield_settings": protocol_settings.vacuum_forcefield_settings,
            "thermo_settings": protocol_settings.thermo_settings,
            "charge_settings": protocol_settings.partial_charge_settings,
            "solvation_settings": protocol_settings.solvation_settings,
            "alchemical_settings": protocol_settings.alchemical_settings,
            "lambda_settings": protocol_settings.lambda_settings,
            "engine_settings": protocol_settings.vacuum_engine_settings,
            "integrator_settings": protocol_settings.integrator_settings,
            "equil_simulation_settings": protocol_settings.vacuum_equil_simulation_settings,
            "equil_output_settings": protocol_settings.vacuum_equil_output_settings,
            "simulation_settings": protocol_settings.vacuum_simulation_settings,
            "output_settings": protocol_settings.vacuum_output_settings,
        }
class AHFEVacuumSetupUnit(VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteSetupUnit):
    """
    Vacuum-phase setup unit for absolute hydration free energy
    transformations.
    """

    simtype = "vacuum"
class AHFEVacuumSimUnit(
    VacuumComponentsMixin, VacuumSettingsMixin, BaseAbsoluteMultiStateSimulationUnit
):
    """
    Vacuum-phase multi-state simulation unit for absolute hydration
    free energy transformations.

    Multi-state here refers to multi replica methods such as
    Hamiltonian replica exchange.
    """

    simtype = "vacuum"
class AHFEVacuumAnalysisUnit(VacuumSettingsMixin, BaseAbsoluteMultiStateAnalysisUnit):
    """
    Vacuum-phase analysis unit for the multi-state simulations of
    absolute hydration free energy transformations.
    """

    simtype = "vacuum"
class SolventComponentsMixin:
    def _get_components(self):
        """
        Collect the components needed for a solvent transformation.

        Returns
        -------
        alchem_comps : dict[str, Component]
          The alchemical components, as passed in through the
          ``alchemical_components`` input.
        solv_comp : SolventComponent
          The SolventComponent of the system
        prot_comp : Optional[ProteinComponent]
          The protein component of the system, if it exists.
        small_mols : dict[SmallMoleculeComponent: OFFMolecule]
          SmallMoleculeComponents to add to the system.
        """
        inputs = self._inputs
        state_a = inputs["stateA"]
        alchemical = inputs["alchemical_components"]

        solvent, protein, small_molecules = system_validation.get_components(state_a)
        openff_mols = {comp: comp.to_openff() for comp in small_molecules}

        # We don't need to check that solv_comp is not None, otherwise
        # an error will have been raised when calling `validate_solvent`
        # in the Protocol's `_create`.
        # Similarly we don't need to check prot_comp since that's also
        # disallowed on create
        return alchemical, solvent, protein, openff_mols
class SolventSettingsMixin:
    def _get_settings(self) -> dict[str, SettingsBaseModel]:
        """
        Collect the Protocol settings that apply to a solvent transformation.

        Returns
        -------
        settings : dict[str, SettingsBaseModel]
          A dictionary with the following entries:
            * forcefield_settings : OpenMMSystemGeneratorFFSettings
            * thermo_settings : ThermoSettings
            * charge_settings : OpenFFPartialChargeSettings
            * solvation_settings : OpenMMSolvationSettings
            * alchemical_settings : AlchemicalSettings
            * lambda_settings : LambdaSettings
            * engine_settings : OpenMMEngineSettings
            * integrator_settings : IntegratorSettings
            * equil_simulation_settings : MDSimulationSettings
            * equil_output_settings : MDOutputSettings
            * simulation_settings : MultiStateSimulationSettings
            * output_settings: MultiStateOutputSettings
        """
        protocol_settings = self._inputs["protocol"].settings  # type: ignore[attr-defined]

        # (key in the returned dictionary, attribute on the Protocol settings)
        attribute_map = (
            ("forcefield_settings", "solvent_forcefield_settings"),
            ("thermo_settings", "thermo_settings"),
            ("charge_settings", "partial_charge_settings"),
            ("solvation_settings", "solvation_settings"),
            ("alchemical_settings", "alchemical_settings"),
            ("lambda_settings", "lambda_settings"),
            ("engine_settings", "solvent_engine_settings"),
            ("integrator_settings", "integrator_settings"),
            ("equil_simulation_settings", "solvent_equil_simulation_settings"),
            ("equil_output_settings", "solvent_equil_output_settings"),
            ("simulation_settings", "solvent_simulation_settings"),
            ("output_settings", "solvent_output_settings"),
        )

        return {key: getattr(protocol_settings, attr) for key, attr in attribute_map}
class AHFESolventSetupUnit(
    SolventComponentsMixin,
    SolventSettingsMixin,
    BaseAbsoluteSetupUnit,
):
    """
    Solvent-phase setup unit for absolute hydration free energy
    transformations.

    Combines the solvent component and settings mixins with
    ``BaseAbsoluteSetupUnit``.
    """

    # Label identifying the phase this unit belongs to
    simtype = "solvent"
class AHFESolventSimUnit(
    SolventComponentsMixin,
    SolventSettingsMixin,
    BaseAbsoluteMultiStateSimulationUnit,
):
    """
    Solvent-phase multi-state simulation unit for absolute hydration free
    energy transformations.

    Multi-state here refers to multi replica methods such as Hamiltonian
    replica exchange. Combines the solvent component and settings mixins
    with ``BaseAbsoluteMultiStateSimulationUnit``.
    """

    # Label identifying the phase this unit belongs to
    simtype = "solvent"
class AHFESolventAnalysisUnit(
    SolventSettingsMixin,
    BaseAbsoluteMultiStateAnalysisUnit,
):
    """
    Solvent-phase analysis unit for the multi-state simulations of absolute
    hydration free energy transformations.

    Only the solvent settings mixin is combined with
    ``BaseAbsoluteMultiStateAnalysisUnit``; no component mixin is used here.
    """

    # Label identifying the phase this unit belongs to
    simtype = "solvent"

View File

@@ -0,0 +1,514 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""OpenMM Equilibrium Binding AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_binding_afe_method`
==========================================================================================================
This module implements the necessary methodology tooling to calculate an
absolute binding free energy using OpenMM tools and one of the following
alchemical sampling methods:
* Hamiltonian Replica Exchange
* Self-adjusted mixture sampling
* Independent window sampling
Current limitations
-------------------
* Alchemical species with a net charge are not currently supported.
* Disappearing molecules are only allowed in state A.
* Only small molecules are allowed to act as alchemical molecules.
Acknowledgements
----------------
* This Protocol re-implements components from
`Yank <https://github.com/choderalab/yank>`_.
"""
import logging
import uuid
import warnings
from collections import defaultdict
from typing import Any, Iterable
import gufe
from gufe import (
ChemicalSystem,
ProteinComponent,
SmallMoleculeComponent,
SolventComponent,
settings,
)
from openff.units import unit as offunit
from openfe.due import Doi, due
from openfe.protocols.openmm_afe.equil_afe_settings import (
ABFEPreEquilOutputSettings,
AbsoluteBindingSettings,
AlchemicalSettings,
BoreschRestraintSettings,
IntegratorSettings,
LambdaSettings,
MDSimulationSettings,
MultiStateOutputSettings,
MultiStateSimulationSettings,
OpenFFPartialChargeSettings,
OpenMMEngineSettings,
OpenMMSolvationSettings,
)
from openfe.protocols.openmm_utils import (
settings_validation,
system_validation,
)
from .abfe_units import (
ABFEComplexAnalysisUnit,
ABFEComplexSetupUnit,
ABFEComplexSimUnit,
ABFESolventAnalysisUnit,
ABFESolventSetupUnit,
ABFESolventSimUnit,
)
from .afe_protocol_results import AbsoluteBindingProtocolResult
# Register the citations associated with this module through the `due`
# object imported from `openfe.due`. Each entry attaches a DOI and a short
# description to this module's import path.

# Yank
due.cite(
Doi("10.5281/zenodo.596504"),
description="Yank",
path="openfe.protocols.openmm_afe.equil_binding_afe_method",
cite_module=True,
)
# OpenMMTools
due.cite(
Doi("10.5281/zenodo.596622"),
description="OpenMMTools",
path="openfe.protocols.openmm_afe.equil_binding_afe_method",
cite_module=True,
)
# OpenMM
due.cite(
Doi("10.1371/journal.pcbi.1005659"),
description="OpenMM",
path="openfe.protocols.openmm_afe.equil_binding_afe_method",
cite_module=True,
)
# Module-level logger named after this module
logger = logging.getLogger(__name__)
class AbsoluteBindingProtocol(gufe.Protocol):
"""
Absolute binding free energy calculations using OpenMM and OpenMMTools.
See Also
--------
:mod:`openfe.protocols`
:class:`openfe.protocols.openmm_afe.AbsoluteBindingSettings`
:class:`openfe.protocols.openmm_afe.AbsoluteBindingProtocolResult`
:class:`openfe.protocols.openmm_afe.AbsoluteBindingSolventUnit`
:class:`openfe.protocols.openmm_afe.AbsoluteBindingComplexUnit`
"""
# NOTE(review): the unit classes imported by this module are named
# ABFE{Solvent,Complex}{Setup,Sim,Analysis}Unit; confirm that the two
# unit cross-references in "See Also" above still resolve.

# Result class paired with this Protocol
result_cls = AbsoluteBindingProtocolResult
# Settings class used by this Protocol; `_default_settings` returns an
# instance of it
_settings_cls = AbsoluteBindingSettings
# Type annotation only: narrows the type of the instance settings
_settings: AbsoluteBindingSettings
@classmethod
def _default_settings(cls):
    """Build the default settings for creating this Protocol.

    These settings are intended as a suitable starting point for creating
    an instance of this protocol. It is recommended, however, that care is
    taken to inspect and customize them before running the Protocol.

    Returns
    -------
    AbsoluteBindingSettings
      a set of default settings
    """
    # fmt: off
    # Solvent leg, 14 windows: lambda_elec goes 0 -> 1 over the first five
    # windows, lambda_vdw goes 0 -> 1 over the remaining ones, and the
    # restraint lambda is zero throughout.
    solvent_elec = [0.0, 0.25, 0.5, 0.75] + [1.0] * 10
    solvent_vdw = [0.0] * 5 + [0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0]
    solvent_restraints = [0.0] * 14

    # Complex leg, 30 windows: lambda_restraints goes 0 -> 1 first (6
    # windows), then lambda_elec (10 windows), then lambda_vdw (14 windows).
    complex_elec = (
        [0.0] * 6
        + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
        + [1.0] * 14
    )
    complex_vdw = (
        [0.0] * 16
        + [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0]
    )
    complex_restraints = [0.0, 0.2, 0.4, 0.6, 0.8] + [1.0] * 25

    return AbsoluteBindingSettings(
        protocol_repeats=3,
        forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(),
        thermo_settings=settings.ThermoSettings(
            temperature=298.15 * offunit.kelvin,
            pressure=1 * offunit.bar,
        ),
        alchemical_settings=AlchemicalSettings(),
        solvent_lambda_settings=LambdaSettings(
            lambda_elec=solvent_elec,
            lambda_vdw=solvent_vdw,
            lambda_restraints=solvent_restraints,
        ),
        complex_lambda_settings=LambdaSettings(
            lambda_elec=complex_elec,
            lambda_vdw=complex_vdw,
            lambda_restraints=complex_restraints,
        ),
        partial_charge_settings=OpenFFPartialChargeSettings(),
        complex_solvation_settings=OpenMMSolvationSettings(
            solvent_padding=1.0 * offunit.nanometer,
        ),
        solvent_solvation_settings=OpenMMSolvationSettings(),
        engine_settings=OpenMMEngineSettings(),
        integrator_settings=IntegratorSettings(),
        restraint_settings=BoreschRestraintSettings(),
        # --- solvent leg ---
        solvent_equil_simulation_settings=MDSimulationSettings(
            equilibration_length_nvt=0.1 * offunit.nanosecond,
            equilibration_length=0.2 * offunit.nanosecond,
            production_length=0.5 * offunit.nanosecond,
        ),
        solvent_equil_output_settings=ABFEPreEquilOutputSettings(),
        solvent_simulation_settings=MultiStateSimulationSettings(
            n_replicas=14,
            equilibration_length=1.0 * offunit.nanosecond,
            production_length=10.0 * offunit.nanosecond,
        ),
        solvent_output_settings=MultiStateOutputSettings(
            output_structure="alchemical_system.pdb",
            output_filename="solvent.nc",
            checkpoint_storage_filename="solvent_checkpoint.nc",
        ),
        # --- complex leg ---
        complex_equil_simulation_settings=MDSimulationSettings(
            equilibration_length_nvt=0.25 * offunit.nanosecond,
            equilibration_length=0.5 * offunit.nanosecond,
            production_length=5.0 * offunit.nanosecond,
        ),
        complex_equil_output_settings=ABFEPreEquilOutputSettings(),
        complex_simulation_settings=MultiStateSimulationSettings(
            n_replicas=30,
            equilibration_length=1 * offunit.nanosecond,
            production_length=10.0 * offunit.nanosecond,
        ),
        complex_output_settings=MultiStateOutputSettings(
            output_structure="alchemical_system.pdb",
            output_filename="complex.nc",
            checkpoint_storage_filename="complex_checkpoint.nc",
        ),
    )
    # fmt: on
@staticmethod
def _validate_endstates(
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
) -> None:
    """
    Check that the end states describe a supported binding transformation.

    A binding transformation is defined (in terms of gufe components)
    as starting from one or more ligands with one protein and solvent,
    that then ends up in a state with one less ligand.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B

    Raises
    ------
    ValueError
      If stateA & stateB do not both contain a ProteinComponent.
      If stateA & stateB do not both contain a SolventComponent.
      If stateA does not have exactly one unique Component.
      If the stateA unique Component is not a SmallMoleculeComponent.
      If the alchemical species is charged.
      If stateB contains any unique Components.
    """
    if not (stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent)):
        errmsg = "No ProteinComponent found"
        raise ValueError(errmsg)

    if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)):
        errmsg = "No SolventComponent found"
        raise ValueError(errmsg)

    # Needs gufe 1.3
    # diff[0] holds the Components unique to stateA, diff[1] those unique
    # to stateB.
    diff = stateA.component_diff(stateB)

    if len(diff[0]) != 1:
        errmsg = (
            "Only one alchemical species is supported. "
            f"Number of unique components found in stateA: {len(diff[0])}."
        )
        raise ValueError(errmsg)

    alchemical_component = diff[0][0]

    if not isinstance(alchemical_component, SmallMoleculeComponent):
        errmsg = (
            "Only dissapearing small molecule components "
            "are supported by this protocol. "
            f"Found a {type(alchemical_component)}"
        )
        raise ValueError(errmsg)

    # Check that the state A unique isn't charged.
    # The message names binding free energies: the previous wording
    # ("solvation free energies") was carried over from the solvation
    # protocol and was misleading here.
    if alchemical_component.total_charge != 0:
        errmsg = (
            "Charged alchemical molecules are not currently "
            "supported for binding free energies. "
            f"Molecule total charge: {alchemical_component.total_charge}."
        )
        raise ValueError(errmsg)

    # If there are any alchemical Components in state B
    if len(diff[1]) > 0:
        errmsg = "Components appearing in state B are not currently supported"
        raise ValueError(errmsg)
@staticmethod
def _validate_lambda_schedule(
lambda_settings: LambdaSettings,
simulation_settings: MultiStateSimulationSettings,
) -> None:
"""
Checks that the lambda schedule is set up correctly.
Parameters
----------
lambda_settings : LambdaSettings
the lambda schedule Settings
simulation_settings : MultiStateSimulationSettings
the settings for either the complex or solvent phase
Raises
------
ValueError
If the number of lambda windows differs for electrostatics, sterics,
and restraints.
If the number of replicas does not match the number of lambda windows.
If there are states with naked charges.
"""
lambda_elec = lambda_settings.lambda_elec
lambda_vdw = lambda_settings.lambda_vdw
lambda_restraints = lambda_settings.lambda_restraints
n_replicas = simulation_settings.n_replicas
# Ensure that all lambda components have equal amount of windows
lambda_components = [lambda_vdw, lambda_elec, lambda_restraints]
it = iter(lambda_components)
the_len = len(next(it))
if not all(len(lambda_comp) == the_len for lambda_comp in it):
errmsg = (
"Components elec, vdw, and restraints must have equal amount"
f" of lambda windows. Got {len(lambda_elec)} elec lambda"
f" windows, {len(lambda_vdw)} vdw lambda windows, and"
f"{len(lambda_restraints)} restraints lambda windows."
)
raise ValueError(errmsg)
# Ensure that number of overall lambda windows matches number of lambda
# windows for individual components
if n_replicas != len(lambda_vdw):
errmsg = (
f"Number of replicas {n_replicas} does not equal the"
f" number of lambda windows {len(lambda_vdw)}"
)
raise ValueError(errmsg)
# Check if there are no lambda windows with naked charges
for inx, lam in enumerate(lambda_elec):
if lam < 1 and lambda_vdw[inx] == 1:
errmsg = (
"There are states along this lambda schedule "
"where there are atoms with charges but no LJ "
f"interactions: lambda {inx}: "
f"elec {lam} vdW {lambda_vdw[inx]}"
)
raise ValueError(errmsg)
def _validate(
    self,
    *,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None,
    extends: gufe.ProtocolDAGResult | None = None,
):
    """
    Check that the inputs and settings can be used by this Protocol.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B
    mapping : ComponentMapping or list of ComponentMapping, optional
      Not used by this Protocol; a warning is issued if one is passed.
    extends : ProtocolDAGResult, optional
      Extending a previous result is not supported; must be None.

    Raises
    ------
    ValueError
      If ``extends`` is not None. Errors raised by the end state, lambda
      schedule, solvent, solvation and timestep checks called here are
      propagated unchanged.
    """
    # Extending is not supported. This technically should be a
    # NotImplementedError, but gufe.Protocol.validate calls `_validate`
    # wrapped around an except for NotImplementedError, so it can't be
    # raised here.
    if extends is not None:
        raise ValueError("Can't extend simulations yet")

    # A mapping is accepted but nothing is done with it
    if mapping is not None:
        warnings.warn("A mapping was passed but is not used by this Protocol.")

    # End states & alchemical components
    self._validate_endstates(stateA, stateB)

    # Complex lambda schedule
    self._validate_lambda_schedule(
        self.settings.complex_lambda_settings,
        self.settings.complex_simulation_settings,
    )

    # An all-zero complex restraint schedule might be bad, but it is not
    # disallowed.
    complex_restraints = self.settings.complex_lambda_settings.lambda_restraints
    if all(value == 0.0 for value in complex_restraints):
        warnings.warn(
            "No restraints are being applied in the complex phase, "
            "this will likely lead to problematic results."
        )

    # Solvent lambda schedule
    self._validate_lambda_schedule(
        self.settings.solvent_lambda_settings,
        self.settings.solvent_simulation_settings,
    )

    # If the solvent restraint schedule is not all zero, it was likely
    # copied from the complex schedule. The values are ignored and the
    # user is told so. The settings are not modified here; per the
    # original author's note the list is dropped later in the solvent
    # unit because there is no restraint parameter state.
    solvent_restraints = self.settings.solvent_lambda_settings.lambda_restraints
    if any(value != 0.0 for value in solvent_restraints):
        warnings.warn(
            "There is an attempt to add restraints in the solvent "
            "phase. This protocol does not apply restraints in the "
            "solvent phase. These restraint lambda values will be ignored."
        )

    # Nonbonded method & solvent compatibility, using the more complete
    # system validation solvent checks
    system_validation.validate_solvent(
        stateA,
        self.settings.forcefield_settings.nonbonded_method,
    )

    # Solvation settings of both legs
    settings_validation.validate_openmm_solvation_settings(
        self.settings.solvent_solvation_settings
    )
    settings_validation.validate_openmm_solvation_settings(
        self.settings.complex_solvation_settings
    )

    # Integrator timestep against the hydrogen mass
    settings_validation.validate_timestep(
        self.settings.forcefield_settings.hydrogen_mass,
        self.settings.integrator_settings.timestep,
    )
def _create(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None = None,
    extends: gufe.ProtocolDAGResult | None = None,
) -> list[gufe.ProtocolUnit]:
    """
    Create the setup, simulation and analysis units for both legs.

    For each phase ("solvent", then "complex") and each protocol repeat,
    one setup, one simulation and one analysis unit are created. The
    three units of a repeat share a ``repeat_id``.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B
    mapping : ComponentMapping or list of ComponentMapping, optional
      Passed on to ``validate`` only.
    extends : ProtocolDAGResult, optional
      Passed on to ``validate`` only.

    Returns
    -------
    list[gufe.ProtocolUnit]
      All solvent units followed by all complex units.
    """
    # Validate inputs
    self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)

    # Alchemical components and the name of the alchemical species
    alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
    species_name = alchem_comps["stateA"][0].name

    # (setup, simulation, analysis) unit classes per phase; insertion
    # order fixes the order of the returned units.
    phase_unit_classes = {
        "solvent": (ABFESolventSetupUnit, ABFESolventSimUnit, ABFESolventAnalysisUnit),
        "complex": (ABFEComplexSetupUnit, ABFEComplexSimUnit, ABFEComplexAnalysisUnit),
    }

    units: list[gufe.ProtocolUnit] = []

    for phase, (setup_cls, sim_cls, analysis_cls) in phase_unit_classes.items():
        for repeat in range(self.settings.protocol_repeats):
            repeat_id = int(uuid.uuid4())

            setup = setup_cls(
                protocol=self,
                stateA=stateA,
                stateB=stateB,
                alchemical_components=alchem_comps,
                generation=0,
                repeat_id=repeat_id,
                name=f"ABFE Setup: {species_name} {phase} leg: repeat {repeat} generation 0",
            )

            # The simulation only needs state A & the alchemical components
            simulation = sim_cls(
                protocol=self,
                stateA=stateA,
                alchemical_components=alchem_comps,
                setup_results=setup,
                generation=0,
                repeat_id=repeat_id,
                name=f"ABFE Simulation: {species_name} {phase} leg: repeat {repeat} generation 0",
            )

            analysis = analysis_cls(
                protocol=self,
                setup_results=setup,
                simulation_results=simulation,
                generation=0,
                repeat_id=repeat_id,
                name=f"ABFE Analysis: {species_name} {phase} leg, repeat {repeat} generation 0",
            )

            units.extend((setup, simulation, analysis))

    return units
def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
) -> dict[str, dict[str, Any]]:
# result units will have a repeat_id and generation
# first group according to repeat_id
unsorted_solvent_repeats = defaultdict(list)
unsorted_complex_repeats = defaultdict(list)
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
if pu.outputs["simtype"] == "solvent":
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
else:
unsorted_complex_repeats[pu.outputs["repeat_id"]].append(pu)
repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = {
"solvent": {},
"complex": {},
}
for k, v in unsorted_solvent_repeats.items():
repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
for k, v in unsorted_complex_repeats.items():
repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
return repeats

View File

@@ -0,0 +1,531 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""OpenMM Equilibrium Solvation AFE Protocol --- :mod:`openfe.protocols.openmm_afe.equil_solvation_afe_method`
===============================================================================================================
This module implements the necessary methodology tooling to calculate an
absolute solvation free energy using OpenMM tools and one of the following
alchemical sampling methods:
* Hamiltonian Replica Exchange
* Self-adjusted mixture sampling
* Independent window sampling
Current limitations
-------------------
* Alchemical species with a net charge are not currently supported.
* Disappearing molecules are only allowed in state A. Support for
appearing molecules will be added in due course.
* Only small molecules are allowed to act as alchemical molecules.
Alchemically changing protein or solvent components would induce
perturbations which are too large to be handled by this Protocol.
Acknowledgements
----------------
* Originally based on hydration.py in
`espaloma_charge <https://github.com/choderalab/espaloma_charge>`_
"""
import logging
import uuid
import warnings
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
import gufe
import numpy as np
from gufe import (
ChemicalSystem,
ProteinComponent,
SmallMoleculeComponent,
SolventComponent,
settings,
)
from openff.units import unit as offunit
from openfe.due import Doi, due
from openfe.protocols.openmm_afe.equil_afe_settings import (
AbsoluteSolvationSettings,
AlchemicalSettings,
IntegratorSettings,
LambdaSettings,
MDOutputSettings,
MDSimulationSettings,
MultiStateOutputSettings,
MultiStateSimulationSettings,
OpenFFPartialChargeSettings,
OpenMMEngineSettings,
OpenMMSolvationSettings,
)
from ..openmm_utils import settings_validation, system_validation
from .afe_protocol_results import AbsoluteSolvationProtocolResult
from .ahfe_units import (
AHFESolventAnalysisUnit,
AHFESolventSetupUnit,
AHFESolventSimUnit,
AHFEVacuumAnalysisUnit,
AHFEVacuumSetupUnit,
AHFEVacuumSimUnit,
)
# Register the citations associated with this module through the `due`
# object imported from `openfe.due`. Each entry attaches a DOI and a short
# description to this module's import path.

# Yank
due.cite(
Doi("10.5281/zenodo.596504"),
description="Yank",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
# EspalomaCharge
due.cite(
Doi("10.48550/ARXIV.2302.06758"),
description="EspalomaCharge",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
# OpenMMTools
due.cite(
Doi("10.5281/zenodo.596622"),
description="OpenMMTools",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
# OpenMM
due.cite(
Doi("10.1371/journal.pcbi.1005659"),
description="OpenMM",
path="openfe.protocols.openmm_afe.equil_solvation_afe_method",
cite_module=True,
)
# Module-level logger named after this module
logger = logging.getLogger(__name__)
class AbsoluteSolvationProtocol(gufe.Protocol):
"""
Absolute solvation free energy calculations using OpenMM and OpenMMTools.
See Also
--------
:mod:`openfe.protocols`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationSettings`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationProtocolResult`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit`
:class:`openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit`
"""
# NOTE(review): the unit classes imported by this module are named
# AHFE{Solvent,Vacuum}{Setup,Sim,Analysis}Unit; confirm that the two
# unit cross-references in "See Also" above still resolve.

# Result class paired with this Protocol
result_cls = AbsoluteSolvationProtocolResult
# Settings class used by this Protocol; `_default_settings` returns an
# instance of it
_settings_cls = AbsoluteSolvationSettings
# Type annotation only: narrows the type of the instance settings
_settings: AbsoluteSolvationSettings
@classmethod
def _default_settings(cls):
    """Build the default settings for creating this Protocol.

    These settings are intended as a suitable starting point for creating
    an instance of this protocol. It is recommended, however, that care is
    taken to inspect and customize them before running the Protocol.

    Returns
    -------
    AbsoluteSolvationSettings
      a set of default settings
    """
    # One 14-window schedule is shared by the solvent and vacuum legs:
    # lambda_elec goes 0 -> 1 over the first five windows, lambda_vdw goes
    # 0 -> 1 over the remaining ones, and the restraint lambda is zero
    # throughout.
    elec_schedule = [0.0, 0.25, 0.5, 0.75] + [1.0] * 10
    vdw_schedule = [0.0] * 5 + [0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, 1.0]
    restraint_schedule = [0.0] * 14

    return AbsoluteSolvationSettings(
        protocol_repeats=3,
        solvent_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(),
        vacuum_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(
            nonbonded_method="nocutoff",
        ),
        thermo_settings=settings.ThermoSettings(
            temperature=298.15 * offunit.kelvin,
            pressure=1 * offunit.bar,
        ),
        alchemical_settings=AlchemicalSettings(),
        lambda_settings=LambdaSettings(
            lambda_elec=elec_schedule,
            lambda_vdw=vdw_schedule,
            lambda_restraints=restraint_schedule,
        ),
        partial_charge_settings=OpenFFPartialChargeSettings(),
        solvation_settings=OpenMMSolvationSettings(),
        vacuum_engine_settings=OpenMMEngineSettings(),
        solvent_engine_settings=OpenMMEngineSettings(),
        integrator_settings=IntegratorSettings(),
        # --- solvent leg ---
        solvent_equil_simulation_settings=MDSimulationSettings(
            equilibration_length_nvt=0.1 * offunit.nanosecond,
            equilibration_length=0.2 * offunit.nanosecond,
            production_length=0.5 * offunit.nanosecond,
        ),
        solvent_equil_output_settings=MDOutputSettings(
            equil_nvt_structure="equil_nvt_structure.pdb",
            equil_npt_structure="equil_npt_structure.pdb",
            production_trajectory_filename="production_equil.xtc",
            log_output="equil_simulation.log",
        ),
        solvent_simulation_settings=MultiStateSimulationSettings(
            n_replicas=14,
            equilibration_length=1.0 * offunit.nanosecond,
            production_length=10.0 * offunit.nanosecond,
        ),
        solvent_output_settings=MultiStateOutputSettings(
            output_filename="solvent.nc",
            checkpoint_storage_filename="solvent_checkpoint.nc",
        ),
        # --- vacuum leg: no NVT equilibration stage or structure ---
        vacuum_equil_simulation_settings=MDSimulationSettings(
            equilibration_length_nvt=None,
            equilibration_length=0.2 * offunit.nanosecond,
            production_length=0.5 * offunit.nanosecond,
        ),
        vacuum_equil_output_settings=MDOutputSettings(
            equil_nvt_structure=None,
            equil_npt_structure="equil_structure.pdb",
            production_trajectory_filename="production_equil.xtc",
            log_output="equil_simulation.log",
        ),
        vacuum_simulation_settings=MultiStateSimulationSettings(
            n_replicas=14,
            equilibration_length=0.5 * offunit.nanosecond,
            production_length=2.0 * offunit.nanosecond,
        ),
        vacuum_output_settings=MultiStateOutputSettings(
            output_filename="vacuum.nc",
            checkpoint_storage_filename="vacuum_checkpoint.nc",
        ),
    )  # fmt: skip
@staticmethod
def _validate_endstates(
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
) -> None:
    """
    Check that the end states describe a supported solvation transformation.

    A solvent transformation is defined (in terms of gufe components)
    as starting from one or more ligands in solvent and
    ending up in a state with one less ligand.

    No protein components are allowed.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B

    Raises
    ------
    ValueError
      If stateA or stateB contains a ProteinComponent.
      If there is no SolventComponent in either stateA or stateB.
      If stateA does not have exactly one unique Component.
      If the stateA unique Component is not a SmallMoleculeComponent.
      If the alchemical species is charged.
      If there are alchemical components in state B.

    Notes
    -----
    * Currently doesn't support alchemical components in state B.
    * Currently doesn't support alchemical components which are not
      SmallMoleculeComponents.
    * Currently doesn't support more than one alchemical component
      being desolvated.
    * Currently doesn't support charged alchemical components.
    * Solvent must always be present in both end states.
    """
    # No protein components in either end state
    has_protein = stateA.contains(ProteinComponent) or stateB.contains(ProteinComponent)
    if has_protein:
        raise ValueError(
            "Protein components are not allowed for absolute solvation free energies."
        )

    # Solvent must be present in both end states
    has_solvent = stateA.contains(SolventComponent) and stateB.contains(SolventComponent)
    if not has_solvent:
        raise ValueError("No SolventComponent found in stateA and/or stateB")

    # diff[0] holds the Components unique to stateA, diff[1] those unique
    # to stateB.
    diff = stateA.component_diff(stateB)

    # Exactly one state A unique Component
    n_unique_a = len(diff[0])
    if n_unique_a != 1:
        raise ValueError(
            "Only one alchemical species is supported "
            "for absolute solvation free energies. "
            f"Number of unique components found in stateA: {n_unique_a}."
        )

    alchemical_component = diff[0][0]

    # The state A unique Component must be an SMC
    if not isinstance(alchemical_component, SmallMoleculeComponent):
        raise ValueError(
            "Only dissapearing SmallMoleculeComponents "
            "are supported by this protocol. "
            f"Found a {type(alchemical_component)}"
        )

    # ... and must not carry a net charge
    if alchemical_component.total_charge != 0:
        raise ValueError(
            "Charged alchemical molecules are not currently "
            "supported for solvation free energies. "
            f"Molecule total charge: {alchemical_component.total_charge}."
        )

    # Nothing may appear in state B
    if len(diff[1]) > 0:
        raise ValueError("Components appearing in state B are not currently supported")
@staticmethod
def _validate_lambda_schedule(
lambda_settings: LambdaSettings,
simulation_settings: MultiStateSimulationSettings,
) -> None:
"""
Checks that the lambda schedule is set up correctly.
Parameters
----------
lambda_settings : LambdaSettings
the lambda schedule Settings
simulation_settings : MultiStateSimulationSettings
the settings for either the vacuum or solvent phase
Raises
------
ValueError
If the number of lambda windows differs for electrostatics and sterics.
If the number of replicas does not match the number of lambda windows.
If there are states with naked charges.
Warnings
If there are non-zero values for restraints (lambda_restraints).
"""
lambda_elec = lambda_settings.lambda_elec
lambda_vdw = lambda_settings.lambda_vdw
lambda_restraints = lambda_settings.lambda_restraints
n_replicas = simulation_settings.n_replicas
# Ensure that all lambda components have equal amount of windows
lambda_components = [lambda_vdw, lambda_elec, lambda_restraints]
it = iter(lambda_components)
the_len = len(next(it))
if not all(len(lambda_comp) == the_len for lambda_comp in it):
errmsg = (
"Components elec, vdw, and restraints must have equal amount"
f" of lambda windows. Got {len(lambda_elec)} elec lambda"
f" windows, {len(lambda_vdw)} vdw lambda windows, and"
f"{len(lambda_restraints)} restraints lambda windows."
)
raise ValueError(errmsg)
# Ensure that number of overall lambda windows matches number of lambda
# windows for individual components
if n_replicas != len(lambda_vdw):
errmsg = (
f"Number of replicas {n_replicas} does not equal the"
f" number of lambda windows {len(lambda_vdw)}"
)
raise ValueError(errmsg)
# Check if there are lambda windows with naked charges
for inx, lam in enumerate(lambda_elec):
if lam < 1 and lambda_vdw[inx] == 1:
errmsg = (
"There are states along this lambda schedule "
"where there are atoms with charges but no LJ "
f"interactions: lambda {inx}: "
f"elec {lam} vdW {lambda_vdw[inx]}"
)
raise ValueError(errmsg)
# Check if there are lambda windows with non-zero restraints
if len([r for r in lambda_restraints if r != 0]) > 0:
wmsg = (
"Non-zero restraint lambdas applied. The absolute "
"solvation protocol doesn't apply restraints, "
"therefore restraints won't be applied. "
f"Given lambda_restraints: {lambda_restraints}"
)
logger.warning(wmsg)
warnings.warn(wmsg)
def _validate(
    self,
    *,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
    extends: Optional[gufe.ProtocolDAGResult] = None,
):
    """
    Check that the inputs and settings can be used by this Protocol.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B
    mapping : ComponentMapping or list of ComponentMapping, optional
      Not used by this Protocol; a warning is issued if one is passed.
    extends : ProtocolDAGResult, optional
      Extending a previous result is not supported; must be None.

    Raises
    ------
    ValueError
      If ``extends`` is not None.
      If the vacuum nonbonded method is not "nocutoff" (case-insensitive).
      If a vacuum NVT equilibration length is set and is not close to 0 ns.
      Errors raised by the end state, lambda schedule, solvent, solvation
      and timestep checks called here are propagated unchanged.
    """
    # Extending is not supported. This should be a NotImplementedError,
    # but the underlying `validate` method wraps the call to `_validate`
    # in a NotImplementedError exception guard.
    if extends is not None:
        raise ValueError("Can't extend simulations yet")

    # A mapping is accepted but nothing is done with it
    if mapping is not None:
        warnings.warn("A mapping was passed but is not used by this Protocol.")

    # End states & alchemical components
    self._validate_endstates(stateA, stateB)

    # The single lambda schedule is checked against the replica count of
    # each leg in turn.
    solvent_sim_settings = self.settings.solvent_simulation_settings
    vacuum_sim_settings = self.settings.vacuum_simulation_settings
    self._validate_lambda_schedule(self.settings.lambda_settings, solvent_sim_settings)
    self._validate_lambda_schedule(self.settings.lambda_settings, vacuum_sim_settings)

    # Nonbonded method & solvent compatibility
    solv_nonbonded_method = self.settings.solvent_forcefield_settings.nonbonded_method
    vac_nonbonded_method = self.settings.vacuum_forcefield_settings.nonbonded_method

    # Use the more complete system validation solvent checks
    system_validation.validate_solvent(stateA, solv_nonbonded_method)

    # Gas phase is always gas phase
    if vac_nonbonded_method.lower() != "nocutoff":
        raise ValueError(
            "Only the nocutoff nonbonded_method is supported for "
            f"vacuum calculations, {vac_nonbonded_method} was "
            "passed"
        )

    # Solvation settings
    settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings)

    # The vacuum equilibration must not request any NVT time
    nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt
    if nvt_time is not None and not np.allclose(nvt_time, 0 * offunit.nanosecond):
        raise ValueError("NVT equilibration cannot be run in vacuum simulation")

    # Integrator timestep against the hydrogen mass of each force field
    timestep = self.settings.integrator_settings
    settings_validation.validate_timestep(
        self.settings.vacuum_forcefield_settings.hydrogen_mass,
        timestep.timestep,
    )
    settings_validation.validate_timestep(
        self.settings.solvent_forcefield_settings.hydrogen_mass,
        self.settings.integrator_settings.timestep,
    )
def _create(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None,
    extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
    """
    Create the setup, simulation and analysis units for both legs.

    For each phase ("solvent", then "vacuum") and each protocol repeat,
    one setup, one simulation and one analysis unit are created. The
    three units of a repeat share a ``repeat_id``.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A
    stateB : ChemicalSystem
      The chemical system of end state B
    mapping : ComponentMapping or list of ComponentMapping, optional
      Passed on to ``validate`` only.
    extends : ProtocolDAGResult, optional
      Passed on to ``validate`` only.

    Returns
    -------
    list[gufe.ProtocolUnit]
      All solvent units followed by all vacuum units.
    """
    # Validate inputs
    self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)

    # Alchemical components and the name of the alchemical species
    alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
    species_name = alchem_comps["stateA"][0].name

    # (setup, simulation, analysis) unit classes per phase; insertion
    # order fixes the order of the returned units.
    phase_unit_classes = {
        "solvent": (AHFESolventSetupUnit, AHFESolventSimUnit, AHFESolventAnalysisUnit),
        "vacuum": (AHFEVacuumSetupUnit, AHFEVacuumSimUnit, AHFEVacuumAnalysisUnit),
    }

    units: list[gufe.ProtocolUnit] = []

    for phase, (setup_cls, sim_cls, analysis_cls) in phase_unit_classes.items():
        for repeat in range(self.settings.protocol_repeats):
            repeat_id = int(uuid.uuid4())

            setup = setup_cls(
                protocol=self,
                stateA=stateA,
                stateB=stateB,
                alchemical_components=alchem_comps,
                generation=0,
                repeat_id=repeat_id,
                name=f"AHFE Setup: {species_name} {phase} leg: repeat {repeat} generation 0",
            )

            # The simulation only needs state A & the alchemical components
            simulation = sim_cls(
                protocol=self,
                stateA=stateA,
                alchemical_components=alchem_comps,
                setup_results=setup,
                generation=0,
                repeat_id=repeat_id,
                name=f"AHFE Simulation: {species_name} {phase} leg: repeat {repeat} generation 0",
            )

            analysis = analysis_cls(
                protocol=self,
                setup_results=setup,
                simulation_results=simulation,
                generation=0,
                repeat_id=repeat_id,
                name=f"AHFE Analysis: {species_name} {phase} leg, repeat {repeat} generation 0",
            )

            units.extend((setup, simulation, analysis))

    return units
def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
) -> dict[str, dict[str, Any]]:
# result units will have a repeat_id and generation
# first group according to repeat_id
unsorted_solvent_repeats = defaultdict(list)
unsorted_vacuum_repeats = defaultdict(list)
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
if pu.outputs["simtype"] == "solvent":
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
else:
unsorted_vacuum_repeats[pu.outputs["repeat_id"]].append(pu)
repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = {
"solvent": {},
"vacuum": {},
}
for k, v in unsorted_solvent_repeats.items():
repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
for k, v in unsorted_vacuum_repeats.items():
repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
return repeats

View File

@@ -0,0 +1,12 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from . import _rfe_utils
from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings
from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult
from .hybridtop_protocols import RelativeHybridTopologyProtocol
from .hybridtop_units import (
HybridTopologyMultiStateAnalysisUnit,
HybridTopologyMultiStateSimulationUnit,
HybridTopologySetupUnit,
)

View File

@@ -26,14 +26,15 @@ from .lambdaprotocol import RelativeAlchemicalState
logger = logging.getLogger(__name__)
class HybridCompatibilityMixin(object):
class HybridCompatibilityMixin:
"""
Mixin that allows the MultistateSampler to accommodate the situation where
unsampled endpoints have a different number of degrees of freedom.
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
self._hybrid_factory = hybrid_factory
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
self._hybrid_system = hybrid_system
self._hybrid_positions = hybrid_positions
super(HybridCompatibilityMixin, self).__init__(*args, **kwargs)
def setup(self, reporter, lambda_protocol,
@@ -73,15 +74,17 @@ class HybridCompatibilityMixin(object):
"""
n_states = len(lambda_protocol.lambda_schedule)
hybrid_system = self._factory.hybrid_system
lambda_zero_state = RelativeAlchemicalState.from_system(self._hybrid_system)
lambda_zero_state = RelativeAlchemicalState.from_system(hybrid_system)
thermostate = ThermodynamicState(
self._hybrid_system,
temperature=temperature
)
thermostate = ThermodynamicState(hybrid_system,
temperature=temperature)
compound_thermostate = CompoundThermodynamicState(
thermostate,
composable_states=[lambda_zero_state])
thermostate,
composable_states=[lambda_zero_state]
)
# create lists for storing thermostates and sampler states
thermodynamic_state_list = []
@@ -105,16 +108,20 @@ class HybridCompatibilityMixin(object):
raise ValueError(errmsg)
# starting with the hybrid factory positions
box = hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(self._factory.hybrid_positions,
box_vectors=box)
box = self._hybrid_system.getDefaultPeriodicBoxVectors()
sampler_state = SamplerState(
self._hybrid_positions,
box_vectors=box
)
# Loop over the lambdas and create & store a compound thermostate at
# that lambda value
for lambda_val in lambda_schedule:
compound_thermostate_copy = copy.deepcopy(compound_thermostate)
compound_thermostate_copy.set_alchemical_parameters(
lambda_val, lambda_protocol)
lambda_val,
lambda_protocol
)
thermodynamic_state_list.append(compound_thermostate_copy)
# now generating a sampler_state for each thermodyanmic state,
@@ -143,7 +150,8 @@ class HybridCompatibilityMixin(object):
# generating unsampled endstates
unsampled_dispersion_endstates = create_endstates(
copy.deepcopy(thermodynamic_state_list[0]),
copy.deepcopy(thermodynamic_state_list[-1]))
copy.deepcopy(thermodynamic_state_list[-1])
)
self.create(thermodynamic_states=thermodynamic_state_list,
sampler_states=sampler_state_list, storage=reporter,
unsampled_thermodynamic_states=unsampled_dispersion_endstates)
@@ -159,10 +167,13 @@ class HybridRepexSampler(HybridCompatibilityMixin,
number of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridRepexSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs)
self._factory = hybrid_factory
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
@@ -171,11 +182,13 @@ class HybridSAMSSampler(HybridCompatibilityMixin, sams.SAMSSampler):
of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridSAMSSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory
class HybridMultiStateSampler(HybridCompatibilityMixin,
@@ -184,11 +197,13 @@ class HybridMultiStateSampler(HybridCompatibilityMixin,
MultiStateSampler that supports unsample end states with a different
number of positions
"""
def __init__(self, *args, hybrid_factory=None, **kwargs):
def __init__(self, *args, hybrid_system, hybrid_positions, **kwargs):
super(HybridMultiStateSampler, self).__init__(
*args, hybrid_factory=hybrid_factory, **kwargs
*args,
hybrid_system=hybrid_system,
hybrid_positions=hybrid_positions,
**kwargs
)
self._factory = hybrid_factory
def create_endstates(first_thermostate, last_thermostate):

View File

@@ -0,0 +1,26 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""Equilibrium Relative Free Energy Protocol using OpenMM and OpenMMTools in a
Perses-like manner.
This module implements the necessary tooling to calculate the
relative free energy of a ligand transformation using OpenMM tools and one of
the following methods:
- Hamiltonian Replica Exchange
- Self-adjusted mixture sampling
- Independent window sampling
Acknowledgements
----------------
This Protocol is based on, and leverages components originating from
the Perses toolkit (https://github.com/choderalab/perses).
"""
from .equil_rfe_settings import RelativeHybridTopologyProtocolSettings
from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult
from .hybridtop_protocols import RelativeHybridTopologyProtocol
from .hybridtop_units import (
HybridTopologyMultiStateAnalysisUnit,
HybridTopologyMultiStateSimulationUnit,
HybridTopologySetupUnit,
)

View File

@@ -0,0 +1,241 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
ProtocolUnitResults for Hybrid Topology methods using
OpenMM and OpenMMTools in a Perses-like manner.
"""
import logging
import pathlib
import warnings
from typing import Optional, Union
import gufe
import numpy as np
import numpy.typing as npt
from openff.units import Quantity
from openmmtools import multistate
logger = logging.getLogger(__name__)
class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult):
"""
Protocol results with the output of a RelativeHybridTopologyProtocol.
"""
def __init__(self, **data):
    """Store the gathered results and reject data needing stitching.

    Raises
    ------
    NotImplementedError
        If any repeat holds more than two unit results.
    """
    super().__init__(**data)
    # self.data maps str(repeat_id) -> list of ProtocolUnitResults
    # TODO: Detect when we have extensions and stitch these together?
    for unit_results in self.data.values():
        if len(unit_results) > 2:
            raise NotImplementedError("Can't stitch together results yet")
@staticmethod
def compute_mean_estimate(dGs: list[Quantity]) -> Quantity:
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid an edge case where each value was in different units
vals = np.asarray([dG.to(u).m for dG in dGs])
return np.average(vals) * u
def get_estimate(self) -> Quantity:
    """Average free energy difference of this transformation

    Returns
    -------
    dG : openff.units.Quantity
        The mean of the per-repeat ``unit_estimate`` values, as computed by
        ``compute_mean_estimate``. This is a Quantity defined with units.
    """
    # TODO: Check this holds up completely for SAMS.
    per_repeat_estimates = []
    for unit_results in self.data.values():
        # only the first unit result of each repeat is used
        per_repeat_estimates.append(unit_results[0].outputs["unit_estimate"])
    return self.compute_mean_estimate(per_repeat_estimates)
@staticmethod
def compute_uncertainty(dGs: list[Quantity]) -> Quantity:
u = dGs[0].u
# convert all values to units of the first value, then take average of magnitude
# this would avoid a screwy case where each value was in different units
vals = np.asarray([dG.to(u).m for dG in dGs])
return np.std(vals) * u
def get_uncertainty(self) -> Quantity:
    """The uncertainty/error in the dG value.

    This is the standard deviation of the ``unit_estimate`` values of the
    individual repeats, as computed by ``compute_uncertainty``.
    """
    per_repeat_estimates = []
    for unit_results in self.data.values():
        # only the first unit result of each repeat is used
        per_repeat_estimates.append(unit_results[0].outputs["unit_estimate"])
    return self.compute_uncertainty(per_repeat_estimates)
def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]:
    """Return the free energy estimate and its error for every repeat.

    Returns
    -------
    dGs : list[tuple[openff.units.Quantity]]
        One tuple per repeat: the ``unit_estimate`` output (first entry)
        and the ``unit_estimate_error`` output (second entry) of that
        repeat's first unit result.
    """
    estimates = []
    for unit_results in self.data.values():
        outputs = unit_results[0].outputs
        estimates.append((outputs["unit_estimate"], outputs["unit_estimate_error"]))
    return estimates
def get_forward_and_reverse_energy_analysis(
    self,
) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]:
    """
    Get the forward and reverse free energy analysis of each repeat.

    Each list entry is the ``forward_and_reverse_energies`` output of a
    repeat. The dictionaries have the keys:

    'fractions' - the fraction of data used for this estimate
    'forward_DGs', 'reverse_DGs' - for each fraction of data, the estimate
    'forward_dDGs', 'reverse_dDGs' - for each estimate, the uncertainty

    The 'fractions' values are a numpy array, while the other arrays are
    Quantity arrays, with units attached.

    A ``None`` entry means that the analysis could not be carried out for
    that repeat, most likely because of MBAR convergence issues when
    calculating free energies from too few samples.

    Returns
    -------
    forward_reverse : list[Optional[dict[str, Union[npt.NDArray, openff.units.Quantity]]]]

    Warns
    -----
    UserWarning
      If any of the forward and reverse entries are ``None``.
    """
    analyses = []
    for unit_results in self.data.values():
        analyses.append(unit_results[0].outputs["forward_and_reverse_energies"])

    # Nothing to report when every repeat produced an analysis
    if None not in analyses:
        return analyses

    warnings.warn(
        "One or more ``None`` entries were found in the list of "
        "forward and reverse analyses. This is likely caused by "
        "an MBAR convergence failure caused by too few independent "
        "samples when calculating the free energies of the 10% "
        "timeseries slice."
    )
    return analyses
def get_overlap_matrices(self) -> list[dict[str, npt.NDArray]]:
    """
    Return the MBAR overlap estimates calculated for each repeat.

    Returns
    -------
    overlap_stats : list[dict[str, npt.NDArray]]
      One ``unit_mbar_overlap`` dictionary per repeat, with the keys:
        * ``scalar``: One minus the largest nontrivial eigenvalue
        * ``eigenvalues``: The sorted (descending) eigenvalues of the
          overlap matrix
        * ``matrix``: Estimated overlap matrix of observing a sample from
          state i in state j
    """
    overlap_stats = []
    for unit_results in self.data.values():
        # the overlap data lives on the first unit result of each repeat
        overlap_stats.append(unit_results[0].outputs["unit_mbar_overlap"])
    return overlap_stats
def get_replica_transition_statistics(self) -> list[dict[str, npt.NDArray]]:
    """The replica lambda state transition statistics for each repeat.

    Note
    ----
    This is currently only available in cases where a replica exchange
    simulation was run.

    Returns
    -------
    repex_stats : list[dict[str, npt.NDArray]]
      One ``replica_exchange_statistics`` dictionary per repeat, with:
        * ``eigenvalues``: The sorted (descending) eigenvalues of the
          lambda state transition matrix
        * ``matrix``: The transition matrix estimate of a replica switching
          from state i to state j.

    Raises
    ------
    ValueError
      If a repeat has no ``replica_exchange_statistics`` output.
    """
    repex_stats = []
    for unit_results in self.data.values():
        try:
            repex_stats.append(unit_results[0].outputs["replica_exchange_statistics"])
        except KeyError:
            # A missing key is reported as a ValueError with a user-facing hint
            raise ValueError(
                "Replica exchange statistics were not found, did you run a repex calculation?"
            )
    return repex_stats
def get_replica_states(self) -> list[npt.NDArray]:
    """
    Returns the timeseries of replica states for each repeat.

    For every repeat, the ``trajectory`` and ``checkpoint`` files named in
    the first unit result are opened read-only through a
    ``MultiStateReporter`` and the replica thermodynamic states are read.

    Returns
    -------
    replica_states : List[npt.NDArray]
      List of replica states for each repeat

    Raises
    ------
    ValueError
      If a trajectory or checkpoint file does not exist.
    """

    def is_file(filename: str):
        # Return ``filename`` as a Path, insisting that it exists on disk
        p = pathlib.Path(filename)
        if not p.exists():
            errmsg = f"File could not be found {p}"
            raise ValueError(errmsg)
        return p

    replica_states = []

    for pus in self.data.values():
        nc = is_file(pus[0].outputs["trajectory"])
        # Only the checkpoint's file name (not its full path) is handed to
        # the reporter.
        chk = is_file(pus[0].outputs["checkpoint"]).name
        reporter = multistate.MultiStateReporter(
            storage=nc, checkpoint_storage=chk, open_mode="r"
        )
        try:
            replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states()))
        finally:
            # Always release the file handles, even if the read fails
            reporter.close()

    return replica_states
def equilibration_iterations(self) -> list[float]:
    """
    Returns the number of equilibration iterations for each repeat
    of the calculation.

    Returns
    -------
    equilibration_lengths : list[float]
      The ``equilibration_iterations`` output of each repeat.
    """
    equilibration_lengths = []
    for unit_results in self.data.values():
        equilibration_lengths.append(unit_results[0].outputs["equilibration_iterations"])
    return equilibration_lengths
def production_iterations(self) -> list[float]:
    """
    Returns the number of uncorrelated production samples for each
    repeat of the calculation.

    Returns
    -------
    production_lengths : list[float]
      The ``production_iterations`` output of each repeat.
    """
    production_lengths = []
    for unit_results in self.data.values():
        production_lengths.append(unit_results[0].outputs["production_iterations"])
    return production_lengths

View File

@@ -0,0 +1,662 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
"""
Hybrid Topology Protocols using OpenMM and OpenMMTools in a Perses-like manner.
Acknowledgements
----------------
These Protocols are based on, and leverage components originating from
the Perses toolkit (https://github.com/choderalab/perses).
"""
from __future__ import annotations
import logging
import uuid
import warnings
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
import gufe
import numpy as np
from gufe import (
ChemicalSystem,
Component,
ComponentMapping,
LigandAtomMapping,
ProteinComponent,
SmallMoleculeComponent,
SolventComponent,
settings,
)
from openff.units import unit as offunit
from openfe.due import Doi, due
from ..openmm_utils import (
settings_validation,
system_validation,
)
from .equil_rfe_settings import (
AlchemicalSettings,
IntegratorSettings,
LambdaSettings,
MultiStateOutputSettings,
MultiStateSimulationSettings,
OpenFFPartialChargeSettings,
OpenMMEngineSettings,
OpenMMSolvationSettings,
RelativeHybridTopologyProtocolSettings,
)
from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult
from .hybridtop_units import (
HybridTopologyMultiStateAnalysisUnit,
HybridTopologyMultiStateSimulationUnit,
HybridTopologySetupUnit,
)
logger = logging.getLogger(__name__)
due.cite(
Doi("10.5281/zenodo.1297683"),
description="Perses",
path="openfe.protocols.openmm_rfe.hybridtop_protocols",
cite_module=True,
)
due.cite(
Doi("10.5281/zenodo.596622"),
description="OpenMMTools",
path="openfe.protocols.openmm_rfe.hybridtop_protocols",
cite_module=True,
)
due.cite(
Doi("10.1371/journal.pcbi.1005659"),
description="OpenMM",
path="openfe.protocols.openmm_rfe.hybridtop_protocols",
cite_module=True,
)
class RelativeHybridTopologyProtocol(gufe.Protocol):
"""
Relative Free Energy calculations using a Hybrid Topology scheme
using OpenMM and OpenMMTools.
Based on `Perses <https://github.com/choderalab/perses>`_
See Also
--------
:mod:`openfe.protocols`
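:class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolSettings`
:class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolResult`
:class:`openfe.protocols.openmm_rfe.HybridTopologySetupUnit`
:class:`openfe.protocols.openmm_rfe.HybridTopologyMultiStateSimulationUnit`
:class:`openfe.protocols.openmm_rfe.HybridTopologyMultiStateAnalysisUnit`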
"""
result_cls = RelativeHybridTopologyProtocolResult
_settings_cls = RelativeHybridTopologyProtocolSettings
_settings: RelativeHybridTopologyProtocolSettings
@classmethod
def _default_settings(cls):
    """Build the initial settings used when creating this Protocol.

    These settings are intended as a suitable starting point for creating
    an instance of this protocol. It is recommended, however, that care is
    taken to inspect and customize these before performing a Protocol.

    Returns
    -------
    Settings
      a set of default settings
    """
    # Only the thermodynamic and sampler settings override any defaults;
    # all other settings groups are created with their own defaults.
    thermo = settings.ThermoSettings(
        temperature=298.15 * offunit.kelvin,
        pressure=1 * offunit.bar,
    )
    sampler = MultiStateSimulationSettings(
        equilibration_length=1.0 * offunit.nanosecond,
        production_length=5.0 * offunit.nanosecond,
    )
    alchemical = AlchemicalSettings(softcore_LJ="gapsys")

    return RelativeHybridTopologyProtocolSettings(
        protocol_repeats=3,
        forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(),
        thermo_settings=thermo,
        partial_charge_settings=OpenFFPartialChargeSettings(),
        solvation_settings=OpenMMSolvationSettings(),
        alchemical_settings=alchemical,
        lambda_settings=LambdaSettings(),
        simulation_settings=sampler,
        engine_settings=OpenMMEngineSettings(),
        integrator_settings=IntegratorSettings(),
        output_settings=MultiStateOutputSettings(),
    )
@classmethod
def _adaptive_settings(
    cls,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: gufe.LigandAtomMapping | list[gufe.LigandAtomMapping],
    initial_settings: None | RelativeHybridTopologyProtocolSettings = None,
) -> RelativeHybridTopologyProtocolSettings:
    """
    Get the recommended OpenFE settings for this protocol, adapted to the
    input states involved in the transformation.

    These are intended as a suitable starting point for creating an
    instance of this protocol, which can be further customized before
    performing a Protocol.

    Parameters
    ----------
    stateA : ChemicalSystem
        The initial state of the transformation.
    stateB : ChemicalSystem
        The final state of the transformation.
    mapping : LigandAtomMapping | list[LigandAtomMapping]
        The mapping(s) between transforming components in stateA and
        stateB. If a list is given, only its first entry is used.
    initial_settings : None | RelativeHybridTopologyProtocolSettings, optional
        Initial settings to base the adaptive settings on. If None,
        default settings are used.

    Returns
    -------
    RelativeHybridTopologyProtocolSettings
        The recommended settings for this protocol based on the input states.

    Notes
    -----
    - If the transformation involves a change in net charge, an explicit
      charge correction is enabled and a more expensive protocol with 22
      lambda windows (and replicas) and 20 ns production length is used.
    - If both states contain a ProteinComponent, the solvation padding is
      set to 1 nm.
    - If initial_settings is provided, a deep copy of it is adapted and the
      passed object is left untouched.
    """
    # Start from a copy of the user's settings when given (needed for the
    # CLI so user settings are not overridden), else from the defaults.
    if initial_settings is None:
        protocol_settings = cls.default_settings()
    else:
        protocol_settings = initial_settings.model_copy(deep=True)

    ligand_mapping = mapping[0] if isinstance(mapping, list) else mapping

    if ligand_mapping.get_alchemical_charge_difference() != 0:
        # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated
        # <https://github.com/OpenFreeEnergy/IndustryBenchmarks2024/blob/2df362306e2727321d55d16e06919559338c4250/industry_benchmarks/utils/plan_rbfe_network.py#L128-L146>
        logger.info(
            "Charge changing transformation between ligands "
            f"{ligand_mapping.componentA.name} and {ligand_mapping.componentB.name}. "
            "A more expensive protocol with 22 lambda windows, sampled "
            "for 20 ns each, will be used here."
        )
        protocol_settings.alchemical_settings.explicit_charge_correction = True
        protocol_settings.simulation_settings.production_length = 20 * offunit.nanosecond
        protocol_settings.simulation_settings.n_replicas = 22
        protocol_settings.lambda_settings.lambda_windows = 22

    # adapt the solvation padding based on the system components
    both_have_protein = stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent)
    if both_have_protein:
        protocol_settings.solvation_settings.solvent_padding = 1 * offunit.nanometer

    return protocol_settings
@staticmethod
def _validate_endstates(
stateA: ChemicalSystem,
stateB: ChemicalSystem,
) -> None:
"""
Validates the end states for the RFE protocol.
Parameters
----------
stateA : ChemicalSystem
The chemical system of end state A.
stateB : ChemicalSystem
The chemical system of end state B.
Raises
------
ValueError
* If either state contains more than one unique Component.
* If unique components are not SmallMoleculeComponents.
"""
# Get the difference in Components between each state
diff = stateA.component_diff(stateB)
for i, entry in enumerate(diff):
state_label = "A" if i == 0 else "B"
# Check that there is only one unique Component in each state
if len(entry) != 1:
errmsg = (
"Only one alchemical component is allowed per end state. "
f"Found {len(entry)} in state {state_label}."
)
raise ValueError(errmsg)
# Check that the unique Component is a SmallMoleculeComponent
if not isinstance(entry[0], SmallMoleculeComponent):
errmsg = (
f"Alchemical component in state {state_label} is of type "
f"{type(entry[0])}, but only SmallMoleculeComponents "
"transformations are currently supported."
)
raise ValueError(errmsg)
@staticmethod
def _validate_mapping(
    mapping: Optional[Union[ComponentMapping, list[ComponentMapping]]],
    alchemical_components: dict[str, list[Component]],
) -> None:
    """
    Validates that the provided mapping(s) are suitable for the RFE protocol.

    Parameters
    ----------
    mapping : Optional[Union[ComponentMapping, list[ComponentMapping]]]
      all mappings between transforming components.
    alchemical_components : dict[str, list[Component]]
      Dictionary containing the alchemical components for
      states A and B (keys ``"stateA"`` and ``"stateB"``).

    Raises
    ------
    ValueError
      * If mapping is None, or is not exactly one mapping
        (an empty list is rejected as well).
      * If the mapping components are not in the alchemical components.

    Warns
    -----
    UserWarning
      * Mappings which involve element changes in core atoms
    """
    # if a single mapping is provided, convert to list
    if isinstance(mapping, ComponentMapping):
        mapping = [mapping]

    # For now we only support a single mapping. An empty list must be
    # rejected here too, otherwise callers indexing ``mapping[0]`` fail
    # later with an unhelpful IndexError.
    if mapping is None or len(mapping) != 1:
        errmsg = "A single LigandAtomMapping is expected for this Protocol"
        raise ValueError(errmsg)

    # check that the mapping components are in the alchemical components
    for m in mapping:
        for state in ["A", "B"]:
            comp = getattr(m, f"component{state}")
            if comp not in alchemical_components[f"state{state}"]:
                raise ValueError(
                    f"Mapping component{state} {comp} not "
                    f"in alchemical components of state{state}"
                )

    # TODO: remove - this is now the default behaviour?
    # Check for element changes in mappings
    for m in mapping:
        molA = m.componentA.to_rdkit()
        molB = m.componentB.to_rdkit()
        for i, j in m.componentA_to_componentB.items():
            atomA = molA.GetAtomWithIdx(i)
            atomB = molB.GetAtomWithIdx(j)
            if atomA.GetAtomicNum() != atomB.GetAtomicNum():
                wmsg = (
                    f"Element change in mapping between atoms "
                    f"Ligand A: {i} (element {atomA.GetAtomicNum()}) and "
                    f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n"
                    "No mass scaling is attempted in the hybrid topology, "
                    "the average mass of the two atoms will be used in the "
                    "simulation"
                )
                # Reported both in the log and as a Python warning
                logger.warning(wmsg)
                warnings.warn(wmsg)
@staticmethod
def _validate_smcs(
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
) -> None:
    """
    Validates the SmallMoleculeComponents.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A.
    stateB : ChemicalSystem
      The chemical system of end state B.

    Raises
    ------
    ValueError
      * If there are isomorphic SmallMoleculeComponents with
        different charges within a given ChemicalSystem.
    """
    smcs_A = stateA.get_components_of_type(SmallMoleculeComponent)
    smcs_B = stateB.get_components_of_type(SmallMoleculeComponent)

    def _equal_charges(moli, molj):
        # Compare the partial charges of two OpenFF molecules
        charges_i = moli.partial_charges
        charges_j = molj.partial_charges
        # Base case, both molecules don't have charges
        if charges_i is None and charges_j is None:
            return True
        # If either is None but not the other
        if charges_i is None or charges_j is None:
            return False
        # Check if the charges are close to each other
        return np.allclose(charges_i, charges_j)

    clashes = []

    # Clashes are only looked for within a state, never across states
    for smcs in [smcs_A, smcs_B]:
        offmols = [m.to_openff() for m in smcs]
        for i, moli in enumerate(offmols):
            for molj in offmols:
                if moli.is_isomorphic_with(molj) and not _equal_charges(moli, molj):
                    # Report each affected component once, even if it
                    # clashes with several others or is in both states
                    if smcs[i] not in clashes:
                        clashes.append(smcs[i])

    if clashes:
        errmsg = (
            "Found SmallMoleculeComponents that are isomorphic "
            "but with different charges, this is not currently allowed. "
            f"Affected components: {clashes}"
        )
        raise ValueError(errmsg)
@staticmethod
def _validate_charge_difference(
mapping: LigandAtomMapping,
nonbonded_method: str,
explicit_charge_correction: bool,
solvent_component: SolventComponent | None,
):
"""
Validates the net charge difference between the two states.
Parameters
----------
mapping : dict[str, ComponentMapping]
Dictionary of mappings between transforming components.
nonbonded_method : str
The OpenMM nonbonded method used for the simulation.
explicit_charge_correction : bool
Whether or not to use an explicit charge correction.
solvent_component : openfe.SolventComponent | None
The SolventComponent of the simulation.
Raises
------
ValueError
* If an explicit charge correction is attempted and the
nonbonded method is not PME.
* If the absolute charge difference is greater than one
and an explicit charge correction is attempted.
* If an explicit charge correction is attempted and there is no
solvent present.
UserWarning
* If there is any charge difference.
"""
difference = mapping.get_alchemical_charge_difference()
if abs(difference) == 0:
return
if not explicit_charge_correction:
wmsg = (
f"A charge difference of {difference} is observed "
"between the end states. No charge correction has "
"been requested, please account for this in your "
"final results."
)
logger.warning(wmsg)
warnings.warn(wmsg)
return
if solvent_component is None:
errmsg = "Cannot use explicit charge correction without solvent"
raise ValueError(errmsg)
# We implicitly check earlier that we have to have pme for a solvated
# system, so we only need to check the nonbonded method here
if nonbonded_method.lower() != "pme":
errmsg = "Explicit charge correction when not using PME is not currently supported."
raise ValueError(errmsg)
if abs(difference) > 1:
errmsg = (
f"A charge difference of {difference} is observed "
"between the end states and an explicit charge "
"correction has been requested. Unfortunately "
"only absolute differences of 1 are supported."
)
raise ValueError(errmsg)
ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference]
wmsg = (
f"A charge difference of {difference} is observed "
"between the end states. This will be addressed by "
f"transforming a water into a {ion} ion"
)
logger.info(wmsg)
@staticmethod
def _validate_simulation_settings(
    simulation_settings: MultiStateSimulationSettings,
    integrator_settings: IntegratorSettings,
    output_settings: MultiStateOutputSettings,
):
    """
    Validate various simulation settings, including but not limited to
    timestep conversions, and output file write frequencies.

    All checks are delegated to the ``settings_validation`` helpers; their
    return values are discarded.

    Parameters
    ----------
    simulation_settings : MultiStateSimulationSettings
      The sampler simulation settings.
    integrator_settings : IntegratorSettings
      Settings defining the behaviour of the integrator.
    output_settings : MultiStateOutputSettings
      Settings defining the simulation file writing behaviour.

    Raises
    ------
    ValueError
      Presumably raised by the ``settings_validation`` helpers when a
      setting is inconsistent -- TODO confirm against those helpers.
    """
    mc_steps = settings_validation.convert_steps_per_iteration(
        simulation_settings=simulation_settings,
        integrator_settings=integrator_settings,
    )

    # Equilibration and production lengths are checked the same way
    for length_name in ("equilibration_length", "production_length"):
        settings_validation.get_simsteps(
            sim_length=getattr(simulation_settings, length_name),
            timestep=integrator_settings.timestep,
            mc_steps=mc_steps,
        )

    settings_validation.convert_checkpoint_interval_to_iterations(
        checkpoint_interval=output_settings.checkpoint_interval,
        time_per_iteration=simulation_settings.time_per_iteration,
    )

    # Optional write frequencies are only checked when they are set
    for frequency_name in ("positions_write_frequency", "velocities_write_frequency"):
        write_frequency = getattr(output_settings, frequency_name)
        if write_frequency is None:
            continue
        settings_validation.divmod_time_and_check(
            numerator=write_frequency,
            denominator=simulation_settings.time_per_iteration,
            numerator_name=f"output settings' {frequency_name}",
            denominator_name="sampler settings' time_per_iteration",
        )

    settings_validation.convert_real_time_analysis_iterations(
        simulation_settings=simulation_settings,
    )
def _validate(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None,
    extends: gufe.ProtocolDAGResult | None = None,
) -> None:
    """Validate the inputs and settings of a transformation.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A.
    stateB : ChemicalSystem
      The chemical system of end state B.
    mapping : gufe.ComponentMapping | list[gufe.ComponentMapping] | None
      The mapping(s) between the transforming components.
    extends : gufe.ProtocolDAGResult | None
      A previous result to extend; extending is not supported.

    Raises
    ------
    ValueError
      * If ``extends`` is given.
      * If the number of replicas differs from the number of lambda
        windows.
      * Whatever the individual validation helpers raise.
    """
    # Check we're not trying to extend
    if extends:
        # This technically should be NotImplementedError
        # but gufe.Protocol.validate calls `_validate` wrapped around an
        # except for NotImplementedError, so we can't raise it here
        raise ValueError("Can't extend simulations yet")

    # End states, then the mapping, then the small molecule components
    self._validate_endstates(stateA, stateB)
    alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
    self._validate_mapping(mapping, alchem_comps)
    self._validate_smcs(stateA, stateB)

    protocol_settings = self.settings

    # Solvent component, solvation settings and protein component
    nonbond = protocol_settings.forcefield_settings.nonbonded_method
    system_validation.validate_solvent(stateA, nonbond)
    settings_validation.validate_openmm_solvation_settings(protocol_settings.solvation_settings)
    system_validation.validate_protein(stateA)

    # Charge difference
    # Note: validation depends on the mapping & solvent component checks
    solv_comp = (
        stateA.get_components_of_type(SolventComponent)[0]
        if stateA.contains(SolventComponent)
        else None
    )
    ligand_mapping = mapping[0] if isinstance(mapping, list) else mapping
    self._validate_charge_difference(
        mapping=ligand_mapping,
        nonbonded_method=protocol_settings.forcefield_settings.nonbonded_method,
        explicit_charge_correction=protocol_settings.alchemical_settings.explicit_charge_correction,
        solvent_component=solv_comp,
    )

    # Integrator things
    settings_validation.validate_timestep(
        protocol_settings.forcefield_settings.hydrogen_mass,
        protocol_settings.integrator_settings.timestep,
    )

    # Simulation & output settings
    self._validate_simulation_settings(
        protocol_settings.simulation_settings,
        protocol_settings.integrator_settings,
        protocol_settings.output_settings,
    )

    # Alchemical settings
    # PR #125 temporarily pin lambda schedule spacing to n_replicas
    n_replicas = protocol_settings.simulation_settings.n_replicas
    n_windows = protocol_settings.lambda_settings.lambda_windows
    if n_replicas != n_windows:
        raise ValueError(
            "Number of replicas in ``simulation_settings``: "
            f"{n_replicas} must equal "
            "the number of lambda windows in lambda_settings: "
            f"{n_windows}."
        )
def _create(
    self,
    stateA: ChemicalSystem,
    stateB: ChemicalSystem,
    mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]],
    extends: Optional[gufe.ProtocolDAGResult] = None,
) -> list[gufe.ProtocolUnit]:
    """Create the setup, simulation and analysis units of every repeat.

    Parameters
    ----------
    stateA : ChemicalSystem
      The chemical system of end state A.
    stateB : ChemicalSystem
      The chemical system of end state B.
    mapping : Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]]
      The mapping(s) between the transforming components; if a list is
      given its first entry is used.
    extends : Optional[gufe.ProtocolDAGResult]
      Passed on to ``validate``.

    Returns
    -------
    list[gufe.ProtocolUnit]
      All setup units, followed by all simulation units, followed by all
      analysis units.
    """
    # validate inputs
    self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends)

    # get alchemical components and mapping
    alchem_comps = system_validation.get_alchemical_components(stateA, stateB)
    ligandmapping = mapping[0] if isinstance(mapping, list) else mapping

    Anames = ",".join(c.name for c in alchem_comps["stateA"])
    Bnames = ",".join(c.name for c in alchem_comps["stateB"])

    # DAG dependency is setup -> simulation -> analysis
    #                      |--------------------->
    setup_units = []
    simulation_units = []
    analysis_units = []

    for i in range(self.settings.protocol_repeats):
        # the three units of one repeat share a random repeat_id
        repeat_id = int(uuid.uuid4())
        label = f"{Anames} to {Bnames} repeat {i} generation 0"

        setup = HybridTopologySetupUnit(
            protocol=self,
            stateA=stateA,
            stateB=stateB,
            ligandmapping=ligandmapping,
            alchemical_components=alchem_comps,
            generation=0,
            repeat_id=repeat_id,
            name=f"HybridTopology Setup: {label}",
        )
        setup_units.append(setup)

        simulation = HybridTopologyMultiStateSimulationUnit(
            protocol=self,
            setup_results=setup,
            generation=0,
            repeat_id=repeat_id,
            name=f"HybridTopology Simulation: {label}",
        )
        simulation_units.append(simulation)

        analysis = HybridTopologyMultiStateAnalysisUnit(
            protocol=self,
            setup_results=setup,
            simulation_results=simulation,
            generation=0,
            repeat_id=repeat_id,
            name=f"HybridTopology Analysis: {label}",
        )
        analysis_units.append(analysis)

    return setup_units + simulation_units + analysis_units
def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]:
# result units will have a repeat_id and generations within this repeat_id
# first group according to repeat_id
unsorted_repeats = defaultdict(list)
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
# We only need the analysis units that are ok
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
unsorted_repeats[pu.outputs["repeat_id"]].append(pu)
# then sort by generation within each repeat_id list
repeats: dict[str, list[gufe.ProtocolUnitResult]] = {}
for k, v in unsorted_repeats.items():
repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
# returns a dict of repeat_id: sorted list of ProtocolUnitResult
return repeats

File diff suppressed because it is too large Load Diff

View File

@@ -58,6 +58,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import (
from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit
from openfe.protocols.openmm_utils import omm_compute
from openfe.protocols.openmm_utils.omm_settings import SettingsBaseModel
from openfe.protocols.openmm_utils.serialization import deserialize
from openfe.utils import without_oechem_backend
from ..openmm_utils import (
@@ -66,7 +67,7 @@ from ..openmm_utils import (
settings_validation,
system_creation,
)
from .utils import SepTopParameterState, deserialize
from .utils import SepTopParameterState
logger = logging.getLogger(__name__)

View File

@@ -81,6 +81,7 @@ from openfe.protocols.openmm_septop.equil_septop_settings import (
SepTopSettings,
SettingsBaseModel,
)
from openfe.protocols.openmm_utils.serialization import serialize
from openfe.protocols.restraint_utils import geometry
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
from openfe.protocols.restraint_utils.openmm import omm_restraints
@@ -96,7 +97,6 @@ from ..restraint_utils.settings import (
DistanceRestraintSettings,
)
from .base import BaseSepTopRunUnit, BaseSepTopSetupUnit, _pre_equilibrate
from .utils import serialize
due.cite(
Doi("10.1021/acs.jctc.3c00282"),
@@ -2021,8 +2021,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
"topology": topology_file,
"standard_state_correction_A": corr_A.to("kilocalorie_per_mole"),
"standard_state_correction_B": corr_B.to("kilocalorie_per_mole"),
"restraint_geometry_A": restraint_geom_A.dict(),
"restraint_geometry_B": restraint_geom_B.dict(),
"restraint_geometry_A": restraint_geom_A.model_dump(),
"restraint_geometry_B": restraint_geom_B.model_dump(),
}
def _execute(

View File

@@ -1,84 +1,7 @@
import os
import pathlib
from openmmtools import states
from openmmtools.states import GlobalParameterState
def serialize(item, filename: pathlib.Path):
    """
    Serialize an OpenMM System, State, or Integrator.

    Parameters
    ----------
    item : System, State, or Integrator
      The thing to be serialized
    filename : pathlib.Path
      The file to serialize to. A ``.gz`` or ``.bz2`` suffix selects
      gzip or bz2 compression; any other suffix is written as plain text.
      Missing parent directories are created.
    """
    from openmm import XmlSerializer

    filename = pathlib.Path(filename)

    # Serialize first: if this fails, no empty or truncated file is left behind
    serialized_thing = XmlSerializer.serialize(item)

    # exist_ok avoids the check-then-create race when several workers
    # write into the same directory at once
    filename.parent.mkdir(parents=True, exist_ok=True)

    if filename.suffix == ".gz":
        import gzip

        with gzip.open(filename, mode="wb") as outfile:
            outfile.write(serialized_thing.encode())
    elif filename.suffix == ".bz2":
        import bz2

        with bz2.open(filename, mode="wb") as outfile:
            outfile.write(serialized_thing.encode())
    else:
        with open(filename, mode="w") as outfile:
            outfile.write(serialized_thing)
def deserialize(filename: pathlib.Path):
    """
    Deserialize an OpenMM System, State, or Integrator.

    Parameters
    ----------
    filename : pathlib.Path
      The file to read the serialized object from. A ``.gz`` or ``.bz2``
      suffix is read as gzip or bz2 compressed data; any other suffix is
      read as plain text.

    Returns
    -------
    item : System, State, or Integrator
      The deserialized object.

    Raises
    ------
    FileNotFoundError
      If ``filename`` does not exist.
    """
    filename = pathlib.Path(filename)

    # Reading must not touch the filesystem: a missing file is reported as
    # such, without creating its parent directory as a side effect.
    if filename.suffix == ".gz":
        import gzip

        with gzip.open(filename, mode="rb") as infile:
            serialized_thing = infile.read().decode()
    elif filename.suffix == ".bz2":
        import bz2

        with bz2.open(filename, mode="rb") as infile:
            serialized_thing = infile.read().decode()
    else:
        with open(filename) as infile:
            serialized_thing = infile.read()

    # Imported only once there is something to parse
    from openmm import XmlSerializer

    return XmlSerializer.deserialize(serialized_thing)
class SepTopParameterState(GlobalParameterState):
"""
Composable state to control lambda parameters for two ligands.

View File

@@ -7,7 +7,7 @@ Reusable utilities for assigning partial charges to ChemicalComponents.
import copy
import sys
import warnings
from typing import Callable, Literal, Optional, Union
from typing import Callable, Literal
import numpy as np
from gufe import SmallMoleculeComponent
@@ -112,7 +112,7 @@ def assign_offmol_espaloma_charges(offmol: OFFMol, toolkit_registry: ToolkitRegi
def assign_offmol_nagl_charges(
offmol: OFFMol,
toolkit_registry: ToolkitRegistry,
nagl_model: Optional[str] = None,
nagl_model: str | None = None,
) -> None:
"""
Assign NAGL charges using the OpenFF toolkit.
@@ -126,7 +126,7 @@ def assign_offmol_nagl_charges(
This strictly limits available toolkit wrappers by
overwriting the global registry during the partial charge
assignment stage.
nagl_model : Optional[str]
nagl_model : str | None
The NAGL model to use when assigning partial charges.
If ``None``, will fetch the latest production "am1bcc" model.
"""
@@ -208,7 +208,7 @@ def _generate_offmol_conformers(
offmol: OFFMol,
max_conf: int,
toolkit_registry: ToolkitRegistry,
generate_n_conformers: Optional[int],
generate_n_conformers: int | None,
) -> None:
"""
Helper method for OFF Molecule conformer generation in charge assignment.
@@ -223,7 +223,7 @@ def _generate_offmol_conformers(
Toolkit registry to use for generating conformers.
This strictly limits available toolkit wrappers by
overwriting the global registry during the conformer generation step.
generate_n_conformers : Optional[int]
generate_n_conformers : int | None
The number of conformers to generate. If ``None``, the existing
conformers are retained & used.
@@ -288,8 +288,8 @@ def assign_offmol_partial_charges(
overwrite: bool,
method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"],
toolkit_backend: Literal["ambertools", "openeye", "rdkit"],
generate_n_conformers: Optional[int],
nagl_model: Optional[str],
generate_n_conformers: int | None,
nagl_model: str | None,
) -> OFFMol:
"""
Assign partial charges to an OpenFF Molecule based on a selected method.
@@ -312,11 +312,11 @@ def assign_offmol_partial_charges(
* ``rdkit``: selects the RDKit toolkit Wrapper
Note that the ``rdkit`` backend cannot be used for `am1bcc` or
``am1bccelf10`` partial charge methods.
generate_n_conformers : Optional[int]
generate_n_conformers : int | None
Number of conformers to generate for partial charge generation.
If ``None`` (default), the input conformer will be used.
If ``None``, the input conformer will be used.
Values greater than 1 can only be used alongside ``am1bccelf10``.
nagl_model : Optional[str]
nagl_model : str | None
The NAGL model to use for charge assignment if method is ``nagl``.
If ``None``, the latest am1bcc NAGL charge model is used.
@@ -403,6 +403,12 @@ def assign_offmol_partial_charges(
errmsg = "OpenEye is not available and cannot be selected as a backend"
raise ImportError(errmsg)
# Issue 1760
if HAS_OPENEYE and method.lower() == "nagl":
if toolkit_backend.lower() != "openeye":
errmsg = "OpenEye toolkit is installed but not used in the OpenFF toolkit registry backend. This is not possible with NAGL charges."
raise ValueError(errmsg)
toolkits = ToolkitRegistry([i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]])
# We make a copy of the molecule since we're going to modify conformers
@@ -437,8 +443,8 @@ def bulk_assign_partial_charges(
overwrite: bool,
method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"],
toolkit_backend: Literal["ambertools", "openeye", "rdkit"],
generate_n_conformers: Optional[int],
nagl_model: Optional[str],
generate_n_conformers: int | None,
nagl_model: str | None,
processors: int = 1,
) -> list[SmallMoleculeComponent]:
"""
@@ -462,11 +468,11 @@ def bulk_assign_partial_charges(
* ``rdkit``: selects the RDKit toolkit Wrapper
Note that the ``rdkit`` backend cannot be used for `am1bcc` or
``am1bccelf10`` partial charge methods.
generate_n_conformers : Optional[int]
generate_n_conformers : int | None
Number of conformers to generate for partial charge generation.
If ``None`` (default), the input conformer will be used.
If ``None``, the input conformer will be used.
Values greater than 1 can only be used alongside ``am1bccelf10``.
nagl_model : Optional[str]
nagl_model : str | None
The NAGL model to use for charge assignment if method is ``nagl``.
If ``None``, the latest am1bcc NAGL charge model is used.
processors: int, default 1

View File

@@ -0,0 +1,89 @@
import os
import pathlib
from gufe.settings.typing import NanometerArrayQuantity
from openff.units import Quantity
from openmm import Vec3
from openmm import unit as ommunit
def serialize(item, filename: pathlib.Path):
    """
    Serialize an OpenMM System, State, or Integrator.

    Parameters
    ----------
    item : System, State, or Integrator
      The thing to be serialized
    filename : pathlib.Path
      The file to serialize to. A ``.bz2`` suffix selects bz2 compression;
      any other suffix is written as plain text. Missing parent
      directories are created.
    """
    from openmm import XmlSerializer

    filename = pathlib.Path(filename)

    # Serialize first: if this fails, no empty or truncated file is left behind
    serialized_thing = XmlSerializer.serialize(item)

    # exist_ok avoids the check-then-create race when several workers
    # write into the same directory at once
    filename.parent.mkdir(parents=True, exist_ok=True)

    if filename.suffix == ".bz2":
        import bz2

        with bz2.open(filename, mode="wb") as outfile:
            outfile.write(serialized_thing.encode())
    else:
        with open(filename, mode="w") as outfile:
            outfile.write(serialized_thing)
def deserialize(filename: pathlib.Path):
    """
    Deserialize an OpenMM System, State, or Integrator.

    Parameters
    ----------
    filename : pathlib.Path
      The file to read the serialized object from. A ``.bz2`` suffix is
      read as bz2 compressed data; any other suffix is read as plain text.

    Returns
    -------
    item : System, State, or Integrator
      The deserialized object.

    Raises
    ------
    FileNotFoundError
      If ``filename`` does not exist.
    """
    filename = pathlib.Path(filename)

    # Reading must not touch the filesystem: a missing file is reported as
    # such, without creating its parent directory as a side effect.
    if filename.suffix == ".bz2":
        import bz2

        with bz2.open(filename, mode="rb") as infile:
            serialized_thing = infile.read().decode()
    else:
        with open(filename) as infile:
            serialized_thing = infile.read()

    # Imported only once there is something to parse
    from openmm import XmlSerializer

    return XmlSerializer.deserialize(serialized_thing)
def make_vec3_box(dimensions: NanometerArrayQuantity) -> list[ommunit.Quantity]:
    """
    Convert an OpenFF box dimensions Quantity back into Vec3 format.

    Parameters
    ----------
    dimensions : gufe.settings.typing.NanometerArrayQuantity
      Unit-bearing array to turn into Vec3 format. The first three entries
      of every row are used.

    Returns
    -------
    list[openmm.unit.Quantity]
      One entry per row of ``dimensions``: the row as a ``Vec3``,
      multiplied by ``openmm.unit.nanometer``.
    """
    # Strip the OpenFF units (as nanometers) and re-attach OpenMM ones row by row
    return [
        Vec3(float(row[0]), float(row[1]), float(row[2])) * ommunit.nanometer
        for row in dimensions.m_as("nanometer")
    ]

View File

@@ -95,23 +95,24 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str):
`nocutoff`.
* If the SolventComponent solvent is not water.
"""
solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)]
solv_comps = state.get_components_of_type(SolventComponent)
if len(solv) > 0 and nonbonded_method.lower() == "nocutoff":
errmsg = "nocutoff cannot be used for solvent transformations"
raise ValueError(errmsg)
if len(solv_comps) > 0:
if nonbonded_method.lower() == "nocutoff":
errmsg = "nocutoff cannot be used for solvent transformations"
raise ValueError(errmsg)
if len(solv) == 0 and nonbonded_method.lower() == "pme":
errmsg = "PME cannot be used for vacuum transform"
raise ValueError(errmsg)
if len(solv_comps) > 1:
errmsg = "Multiple SolventComponent found, only one is supported"
raise ValueError(errmsg)
if len(solv) > 1:
errmsg = "Multiple SolventComponent found, only one is supported"
raise ValueError(errmsg)
if len(solv) > 0 and solv[0].smiles != "O":
errmsg = "Non water solvent is not currently supported"
raise ValueError(errmsg)
if solv_comps[0].smiles != "O":
errmsg = "Non water solvent is not currently supported"
raise ValueError(errmsg)
else:
if nonbonded_method.lower() == "pme":
errmsg = "PME cannot be used for vacuum transform"
raise ValueError(errmsg)
def validate_protein(state: ChemicalSystem):
@@ -129,9 +130,9 @@ def validate_protein(state: ChemicalSystem):
ValueError
If there are multiple ProteinComponent in the ChemicalSystem.
"""
nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent))
prot_comps = state.get_components_of_type(ProteinComponent)
if nprot > 1:
if len(prot_comps) > 1:
errmsg = "Multiple ProteinComponent found, only one is supported"
raise ValueError(errmsg)
@@ -161,24 +162,18 @@ def get_components(state: ChemicalSystem) -> ParseCompRet:
small_mols : list[SmallMoleculeComponent]
"""
def _get_single_comps(comp_list, comptype):
ret_comps = [comp for comp in comp_list if isinstance(comp, comptype)]
if ret_comps:
return ret_comps[0]
def _get_single_comps(state, comptype):
comps = state.get_components_of_type(comptype)
if len(comps) > 0:
return comps[0]
else:
return None
solvent_comp: Optional[SolventComponent] = _get_single_comps(
list(state.values()), SolventComponent
)
solvent_comp: Optional[SolventComponent] = _get_single_comps(state, SolventComponent)
protein_comp: Optional[ProteinComponent] = _get_single_comps(
list(state.values()), ProteinComponent
)
protein_comp: Optional[ProteinComponent] = _get_single_comps(state, ProteinComponent)
small_mols = []
for comp in state.components.values():
if isinstance(comp, SmallMoleculeComponent):
small_mols.append(comp)
small_mols = state.get_components_of_type(SmallMoleculeComponent)
return solvent_comp, protein_comp, small_mols

Some files were not shown because too many files have changed in this diff Show More