Add Analysis unit for Septop (#1937)

* Add an Analysis unit to SepTopProtocol.
* Update PDB writing via MDTraj in AFE and SepTop Protocols to account for box dimensions with mdtraj_from_openmm utility.
* Split up some SepTop tests into different files.
* Updated the CLI gathering to account for post openfe v1.11 Analysis units

---------

Co-authored-by: Hannah Baumann <43765638+hannahbaumann@users.noreply.github.com>
This commit is contained in:
Irfan Alibay
2026-04-24 10:00:06 +01:00
committed by GitHub
parent 7959cc6e74
commit c81d27cf0d
21 changed files with 1430 additions and 989 deletions

View File

@@ -18,8 +18,10 @@ Protocol API specification
SepTopProtocol
SepTopComplexSetupUnit
SepTopComplexRunUnit
SepTopComplexAnalysisUnit
SepTopSolventSetupUnit
SepTopSolventRunUnit
SepTopSolventAnalysisUnit
SepTopProtocolResult
Protocol Settings

27
news/analysis-septop.rst Normal file
View File

@@ -0,0 +1,27 @@
**Added:**
* <news item>
**Changed:**
* The SepTopProtocol now has a dedicated Analysis unit.
At the top level API, this does not change behaviour, but
if you are directly interfacing with the ProtocolUnits, you
will have to account for this change. The SepTopProtocolResult now
solely uses the Analysis units. PR #1937
**Deprecated:**
* <news item>
**Removed:**
* <news item>
**Fixed:**
* <news item>
**Security:**
* <news item>

View File

@@ -23,7 +23,6 @@ import pathlib
from typing import Any
import gufe
import mdtraj as mdt
import numpy as np
import numpy.typing as npt
import openmm
@@ -75,6 +74,9 @@ from openfe.protocols.openmm_utils import (
system_creation,
system_validation,
)
from openfe.protocols.openmm_utils.mdtraj_utils import (
mdtraj_from_openmm,
)
from openfe.protocols.openmm_utils.omm_settings import (
SettingsBaseModel,
)
@@ -643,16 +645,14 @@ class BaseAbsoluteSetupUnit(gufe.ProtocolUnit, AbsoluteUnitMixin):
selection_indices : npt.NDArray
The indices of the subselected system.
"""
mdt_top = mdt.Topology.from_openmm(topology)
selection_indices = mdt_top.select(output_selection)
traj = mdtraj_from_openmm(topology, positions)
selection_indices = traj.topology.select(output_selection)
# Write out the subselected structure to PDB if not empty
if len(selection_indices) > 0:
traj = mdt.Trajectory(
positions[selection_indices, :],
mdt_top.subset(selection_indices),
)
traj.save_pdb(output_file)
sub_traj = traj.atom_slice(selection_indices)
sub_traj.save_pdb(output_file)
return selection_indices
@@ -1346,7 +1346,7 @@ class BaseAbsoluteMultiStateSimulationUnit(gufe.ProtocolUnit, AbsoluteUnitMixin)
self.shared_basepath / settings["output_settings"].checkpoint_storage_filename,
]
for fn in fns:
os.remove(fn)
fn.unlink()
def run(
self,

View File

@@ -487,7 +487,7 @@ class AbsoluteBindingProtocol(gufe.Protocol):
# Get the name of the alchemical species
alchname = alchem_comps["stateA"][0].name
unit_classes = {
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
"solvent": {
"setup": ABFESolventSetupUnit,
"simulation": ABFESolventSimUnit,

View File

@@ -451,7 +451,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol):
# Get the name of the alchemical species
alchname = alchem_comps["stateA"][0].name
unit_classes = {
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
"solvent": {
"setup": AHFESolventSetupUnit,
"simulation": AHFESolventSimUnit,

View File

@@ -19,7 +19,7 @@ from typing import Any
import gufe
import matplotlib.pyplot as plt
import mdtraj
import mdtraj as mdt
import numpy as np
import numpy.typing as npt
import openmm
@@ -643,7 +643,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
def _subsample_topology(
self,
hybrid_topology: openmm.app.Topology,
hybrid_topology: mdt.Topology,
hybrid_positions: openmm.unit.Quantity,
output_selection: str,
output_filename: str,
@@ -655,7 +655,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
Parameters
----------
hybrid_topology : openmm.app.Topology
hybrid_topology : mdtraj.Topology
The hybrid system topology to subsample.
hybrid_positions : openmm.unit.Quantity
The hybrid system positions.
@@ -674,7 +674,8 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
TODO
----
Modify this to also store the full system.
* Modify this to also store the full system.
* Use the mdtraj_from_openmm utility.
"""
selection_indices = hybrid_topology.select(output_selection)
@@ -690,7 +691,7 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
bfactors[np.isin(selection_indices, list(atom_classes["unique_new_atoms"]))] = 0.75
if len(selection_indices) > 0:
traj = mdtraj.Trajectory(
traj = mdt.Trajectory(
hybrid_positions[selection_indices, :],
hybrid_topology.subset(selection_indices),
).save_pdb(

View File

@@ -6,10 +6,12 @@ Run SepTop free energy calculations using OpenMM and OpenMMTools.
"""
from .equil_septop_method import (
SepTopComplexAnalysisUnit,
SepTopComplexRunUnit,
SepTopComplexSetupUnit,
SepTopProtocol,
SepTopProtocolResult,
SepTopSolventAnalysisUnit,
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
@@ -25,4 +27,6 @@ __all__ = [
"SepTopSolventSetupUnit",
"SepTopSolventRunUnit",
"SepTopComplexRunUnit",
"SepTopSolventAnalysisUnit",
"SepTopComplexAnalysisUnit",
]

File diff suppressed because it is too large Load Diff

View File

@@ -74,8 +74,10 @@ from ..restraint_utils.settings import (
)
from .septop_protocol_results import SepTopProtocolResult
from .septop_units import (
SepTopComplexAnalysisUnit,
SepTopComplexRunUnit,
SepTopComplexSetupUnit,
SepTopSolventAnalysisUnit,
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
@@ -148,7 +150,7 @@ class SepTopProtocol(gufe.Protocol):
:class:`openfe.protocols.openmm_septop.SepTopProtocolResult`
:class:`openfe.protocols.openmm_septop.SepTopComplexSetupUnit`
:class:`openfe.protocols.openmm_septop.SepTopComplexRunUnit`
:class:`openfe.protocols.openmm_septop.SepTopSolventSetupUnit
:class:`openfe.protocols.openmm_septop.SepTopSolventSetupUnit`
:class:`openfe.protocols.openmm_septop.SepTopSolventRunUnit`
"""
@@ -460,7 +462,7 @@ class SepTopProtocol(gufe.Protocol):
self,
stateA: ChemicalSystem,
stateB: ChemicalSystem,
mapping: gufe.ComponentMapping | list[gufe.ComponentMappping] | None,
mapping: gufe.ComponentMapping | list[gufe.ComponentMapping] | None,
extends: gufe.ProtocolDAGResult | None = None,
) -> None:
# Check we're not trying to extend
@@ -525,89 +527,96 @@ class SepTopProtocol(gufe.Protocol):
)
# Create list units for complex and solvent transforms
def create_setup_units(unit_cls, leg):
return [
unit_cls(
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=int(uuid.uuid4()),
name=(
f"SepTop RBFE Setup, transformation {alchname_A} to "
f"{alchname_B}, {leg} leg: repeat {i} generation 0"
),
)
for i in range(self.settings.protocol_repeats)
]
def create_run_units(unit_cls, leg, setup):
return [
unit_cls(
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
setup=setup[i],
generation=0,
repeat_id=int(uuid.uuid4()),
name=(
f"SepTop RBFE Run, transformation {alchname_A} to "
f"{alchname_B}, {leg} leg: repeat {i} generation 0"
),
)
for i in range(self.settings.protocol_repeats)
]
alchname_A = alchem_comps["stateA"][0].name
alchname_B = alchem_comps["stateB"][0].name
solvent_setup = create_setup_units(SepTopSolventSetupUnit, "solvent")
solvent_run = create_run_units(SepTopSolventRunUnit, "solvent", setup=solvent_setup)
complex_setup = create_setup_units(SepTopComplexSetupUnit, "complex")
complex_run = create_run_units(SepTopComplexRunUnit, "complex", setup=complex_setup)
unit_classes: dict[str, dict[str, type[gufe.ProtocolUnit]]] = {
"solvent": {
"setup": SepTopSolventSetupUnit,
"simulation": SepTopSolventRunUnit,
"analysis": SepTopSolventAnalysisUnit,
},
"complex": {
"setup": SepTopComplexSetupUnit,
"simulation": SepTopComplexRunUnit,
"analysis": SepTopComplexAnalysisUnit,
},
}
return solvent_setup + solvent_run + complex_setup + complex_run
protocol_units: dict[str, list[gufe.ProtocolUnit]] = {"solvent": [], "complex": []}
for i in range(self.settings.protocol_repeats):
repeat_id = int(uuid.uuid4())
for phase in ["solvent", "complex"]:
setup = unit_classes[phase]["setup"](
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
generation=0,
repeat_id=repeat_id,
name=(
f"SepTop RBFE Setup, transformation {alchname_A} to "
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
),
)
simulation = unit_classes[phase]["simulation"](
protocol=self,
stateA=stateA,
stateB=stateB,
alchemical_components=alchem_comps,
setup=setup,
generation=0,
repeat_id=repeat_id,
name=(
f"SepTop RBFE Run, transformation {alchname_A} to "
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
),
)
analysis = unit_classes[phase]["analysis"](
protocol=self,
setup=setup,
simulation=simulation,
generation=0,
repeat_id=repeat_id,
name=(
f"SepTop RBFE Analysis, transformation {alchname_A} to "
f"{alchname_B}, {phase} leg: repeat {i} generation 0"
),
)
protocol_units[phase] += [setup, simulation, analysis]
return protocol_units["solvent"] + protocol_units["complex"]
def _gather(
self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]
) -> dict[str, dict[str, Any]]:
# result units will have a repeat_id and generation
# first group according to repeat_id
unsorted_solvent_repeats_setup = defaultdict(list)
unsorted_solvent_repeats_run = defaultdict(list)
unsorted_complex_repeats_setup = defaultdict(list)
unsorted_complex_repeats_run = defaultdict(list)
unsorted_solvent_repeats = defaultdict(list)
unsorted_complex_repeats = defaultdict(list)
for d in protocol_dag_results:
pu: gufe.ProtocolUnitResult
for pu in d.protocol_unit_results:
if not pu.ok():
if ("Analysis" not in pu.name) or (not pu.ok()):
continue
if pu.outputs["simtype"] == "solvent":
if "Run" in pu.name:
unsorted_solvent_repeats_run[pu.outputs["repeat_id"]].append(pu)
elif "Setup" in pu.name:
unsorted_solvent_repeats_setup[pu.outputs["repeat_id"]].append(pu)
unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu)
else:
if "Run" in pu.name:
unsorted_complex_repeats_run[pu.outputs["repeat_id"]].append(pu)
elif "Setup" in pu.name:
unsorted_complex_repeats_setup[pu.outputs["repeat_id"]].append(pu)
unsorted_complex_repeats[pu.outputs["repeat_id"]].append(pu)
repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = {
"solvent_setup": {},
"solvent": {},
"complex_setup": {},
"complex": {},
}
for k, v in unsorted_solvent_repeats_setup.items():
repeats["solvent_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
for k, v in unsorted_solvent_repeats_run.items():
for k, v in unsorted_solvent_repeats.items():
repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
for k, v in unsorted_complex_repeats_setup.items():
repeats["complex_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
for k, v in unsorted_complex_repeats_run.items():
for k, v in unsorted_complex_repeats.items():
repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"])
return repeats

View File

@@ -65,8 +65,6 @@ class SepTopProtocolResult(gufe.ProtocolResult):
complex_dGs.append(
(pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
)
for pus in self.data["complex_setup"].values():
complex_correction_dGs_A.append(
(
pus[0].outputs["standard_state_correction_A"],
@@ -84,8 +82,6 @@ class SepTopProtocolResult(gufe.ProtocolResult):
solv_dGs.append(
(pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])
)
for pus in self.data["solvent_setup"].values():
solv_correction_dGs.append(
(
pus[0].outputs["standard_state_correction"],
@@ -417,8 +413,8 @@ class SepTopProtocolResult(gufe.ProtocolResult):
for key in ["complex", "solvent"]:
for pus in self.data[key].values():
states = get_replica_state(
pus[0].outputs["nc"],
pus[0].outputs["last_checkpoint"],
pus[0].outputs["trajectory"],
pus[0].outputs["checkpoint"],
)
replica_states[key].append(states)
@@ -487,11 +483,11 @@ class SepTopProtocolResult(gufe.ProtocolResult):
"""
geometry_A = [
BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_A"])
for pus in self.data["complex_setup"].values()
for pus in self.data["complex"].values()
]
geometry_B = [
BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_B"])
for pus in self.data["complex_setup"].values()
for pus in self.data["complex"].values()
]
return geometry_A, geometry_B

View File

@@ -34,6 +34,7 @@ from openff.units.openmm import from_openmm, to_openmm
from openmmtools.states import ThermodynamicState
from rdkit import Chem
from openfe.protocols.openmm_utils import omm_compute
from openfe.protocols.openmm_utils.serialization import serialize
from openfe.protocols.restraint_utils import geometry
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
@@ -43,59 +44,25 @@ from openfe.protocols.restraint_utils.openmm.omm_restraints import (
add_force_in_separate_group,
)
from ..openmm_utils import settings_validation, system_validation
from ..openmm_utils import (
settings_validation,
system_validation,
)
from ..openmm_utils.mdtraj_utils import mdtraj_from_openmm
from ..restraint_utils.settings import (
BoreschRestraintSettings,
DistanceRestraintSettings,
)
from .base_units import BaseSepTopRunUnit, BaseSepTopSetupUnit, _pre_equilibrate
from .base_units import (
BaseSepTopAnalysisUnit,
BaseSepTopRunUnit,
BaseSepTopSetupUnit,
_pre_equilibrate,
)
logger = logging.getLogger(__name__)
def _get_mdtraj_from_openmm(
omm_topology: openmm.app.Topology,
omm_positions: openmm.unit.Quantity,
):
"""
Get an mdtraj object from an OpenMM topology and positions.
Parameters
----------
omm_topology: openmm.app.Topology
The OpenMM topology
omm_positions: openmm.unit.Quantity
The OpenMM positions
Returns
-------
mdtraj_system: md.Trajectory
"""
mdtraj_topology = md.Topology.from_openmm(omm_topology)
positions_in_mdtraj_format = omm_positions.value_in_unit(omm_units.nanometers)
box = omm_topology.getPeriodicBoxVectors()
x, y, z = [np.array(b._value) for b in box]
lx = np.linalg.norm(x)
ly = np.linalg.norm(y)
lz = np.linalg.norm(z)
# angle between y and z
alpha = np.arccos(np.dot(y, z) / (ly * lz))
# angle between x and z
beta = np.arccos(np.dot(x, z) / (lx * lz))
# angle between x and y
gamma = np.arccos(np.dot(x, y) / (lx * ly))
mdtraj_system = md.Trajectory(
positions_in_mdtraj_format,
mdtraj_topology,
unitcell_lengths=np.array([lx, ly, lz]),
unitcell_angles=np.array([np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)]),
)
return mdtraj_system
class SepTopComplexMixin:
"""
A mixin to get the components and the settings for the Complex Units.
@@ -132,7 +99,7 @@ class SepTopComplexMixin:
return alchem_comps, solv_comp, prot_comp, small_mols
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a complex transformation.
@@ -218,9 +185,9 @@ class SepTopSolventMixin:
return alchem_comps, solv_comp, None, small_mols
def _handle_settings(self) -> dict[str, SettingsBaseModel]:
def _get_settings(self) -> dict[str, SettingsBaseModel]:
"""
Extract the relevant settings for a complex transformation.
Extract the relevant settings for a solvent transformation.
Returns
-------
@@ -391,8 +358,8 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
updated_positions_B: openmm.unit.Quantity
Updated positions of the complex B
"""
mdtraj_complex_A = _get_mdtraj_from_openmm(omm_topology_A, positions_A)
mdtraj_complex_B = _get_mdtraj_from_openmm(omm_topology_B, positions_B)
mdtraj_complex_A = mdtraj_from_openmm(omm_topology_A, positions_A)
mdtraj_complex_B = mdtraj_from_openmm(omm_topology_B, positions_B)
alignment_indices = SepTopComplexSetupUnit._get_selection_atom_indices(mdtraj_complex_A)
imaged_complex_B = mdtraj_complex_B.image_molecules()
imaged_complex_B.superpose(
@@ -701,13 +668,15 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
# 0. General preparation tasks
self._prepare(verbose, scratch_basepath, shared_basepath)
self.logger.info("Setting up SepTop complex system.")
# 1. Get components
self.logger.info("Creating and setting up the OpenMM systems")
alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components()
smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps)
# 3. Get settings
settings = self._handle_settings()
settings = self._get_settings()
# 4. Assign partial charges
self._assign_partial_charges(settings["charge_settings"], smc_comps_AB)
@@ -740,11 +709,6 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
smc_comp_B_unique,
settings,
)
# Virtual sites sanity check - ensure we restart velocities when
# there are virtual sites in the system
self.check_assign_velocities_with_virtual_site(
omm_system_AB, settings["integrator_settings"]
)
# Get the comp_resids of the AB system
resids_A = list(itertools.chain(*comp_resids_A.values()))
@@ -753,28 +717,38 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
comp_resids_AB = comp_resids_A | {alchem_comps["stateB"][0]: np.array(diff_resids)}
# 6. Pre-equilbrate System (for restraint selection)
self.logger.info("Pre-equilibrating the systems")
equil_positions_A, box_A = _pre_equilibrate(
omm_system_A,
omm_topology_A,
positions_A,
settings,
"A",
dry,
self.shared_basepath,
self.verbose,
self.logger,
platform = omm_compute.get_openmm_platform(
platform_name=settings["engine_settings"].compute_platform,
gpu_device_index=settings["engine_settings"].gpu_device_index,
restrict_cpu_count=False,
)
self.logger.info("Pre-equilibrating the systems")
equil_positions_A, box_A = _pre_equilibrate(
system=omm_system_A,
topology=omm_topology_A,
positions=positions_A,
settings=settings,
endstate="A",
dry=dry,
shared_basepath=self.shared_basepath,
platform=platform,
verbose=self.verbose,
logger=self.logger,
)
equil_positions_B, box_B = _pre_equilibrate(
omm_system_B,
omm_topology_B,
positions_B,
settings,
"B",
dry,
self.shared_basepath,
self.verbose,
self.logger,
system=omm_system_B,
topology=omm_topology_B,
positions=positions_B,
settings=settings,
endstate="B",
dry=dry,
shared_basepath=self.shared_basepath,
platform=platform,
verbose=self.verbose,
logger=self.logger,
)
# 7. Get all the right atom indices for alignments
@@ -832,20 +806,35 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
)
equil_positions_AB, box_AB = _pre_equilibrate(
system,
omm_topology_AB,
positions_AB,
settings,
"AB",
dry,
self.shared_basepath,
self.verbose,
self.logger,
system=system,
topology=omm_topology_AB,
positions=positions_AB,
settings=settings,
endstate="AB",
dry=dry,
platform=platform,
shared_basepath=self.shared_basepath,
verbose=self.verbose,
logger=self.logger,
)
# Update box vectors
omm_topology_AB.setPeriodicBoxVectors(box_AB)
# Serialize system, state and integrator
# Subselect system based on user inputs & write initial subsampled PDB
sub_pdb_structure = self.shared_basepath / settings["output_settings"].output_structure
selection_indices = self._subsample_topology(
topology=omm_topology_AB,
positions=positions_AB,
output_selection=settings["output_settings"].output_indices,
output_file=self.shared_basepath / settings["output_settings"].output_structure,
)
# The subsampled PDB may not have been written if selection_indices is empty
# Issue #1942 - maybe move this to the method?
if len(selection_indices) == 0:
sub_pdb_structure = None
# Serialize the system and PDB topology
system_outfile = self.shared_basepath / "system.xml.bz2"
serialize(system, system_outfile)
@@ -864,21 +853,23 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit):
"standard_state_correction_B": corr_B.to("kilocalorie_per_mole"),
"restraint_geometry_A": restraint_geom_A.model_dump(),
"restraint_geometry_B": restraint_geom_B.model_dump(),
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
}
else:
return {
# Add in various objects we can use to test the system
"debug": {
"system": system_outfile,
"topology": topology_file,
"system_A": omm_system_A,
"system_B": omm_system_B,
"system_AB": omm_system_AB,
"restrained_system": system,
"alchem_system": alchemical_system,
"alchem_factory": alchemical_factory,
"positions": equil_positions_AB,
}
"system": system_outfile,
"topology": topology_file,
"system_A": omm_system_A,
"system_B": omm_system_B,
"system_AB": omm_system_AB,
"alchem_restrained_system": system,
"alchem_system": alchemical_system,
"alchem_factory": alchemical_factory,
"positions": equil_positions_AB,
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
}
@@ -1049,13 +1040,15 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
# 0. General preparation tasks
self._prepare(verbose, scratch_basepath, shared_basepath)
self.logger.info("Setting up SepTop solvent system.")
# 1. Get components
self.logger.info("Creating and setting up the OpenMM systems")
alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components()
smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps)
# 2. Get settings
settings = self._handle_settings()
settings = self._get_settings()
# 3. Assign partial charges
self._assign_partial_charges(settings["charge_settings"], smc_comps_AB)
@@ -1078,12 +1071,6 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
)
) # fmt: skip
# Virtual sites sanity check - ensure we restart velocities when
# there are virtual sites in the system
self.check_assign_velocities_with_virtual_site(
omm_system_AB, settings["integrator_settings"]
)
# 6. Get atom indices for ligand A and ligand B and the solvent in the
# system AB
comp_atomids_AB = self._get_atom_indices(omm_topology_AB, comp_resids_AB)
@@ -1116,16 +1103,27 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
positions_AB,
)
# Write the full system PDB
topology_file = self.shared_basepath / "topology.pdb"
openmm.app.pdbfile.PDBFile.writeFile(
omm_topology_AB, positions_AB, open(topology_file, "w")
)
# ToDo: also apply REST
# Subselect system based on user inputs & write initial subsampled PDB
sub_pdb_structure = self.shared_basepath / settings["output_settings"].output_structure
selection_indices = self._subsample_topology(
topology=omm_topology_AB,
positions=positions_AB,
output_selection=settings["output_settings"].output_indices,
output_file=self.shared_basepath / settings["output_settings"].output_structure,
)
# The subsampled PDB may not have been written if selection_indices is empty
# Issue #1942 - maybe move this to the method?
if len(selection_indices) == 0:
sub_pdb_structure = None
# Serialize the system
system_outfile = self.shared_basepath / "system.xml.bz2"
# Serialize system, state and integrator
serialize(system, system_outfile)
if not dry:
@@ -1133,19 +1131,21 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit):
"system": system_outfile,
"topology": topology_file,
"standard_state_correction": corr.to("kilocalorie_per_mole"),
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
}
else:
return {
# Add in various objects we can use to test the system
"debug": {
"system": system_outfile,
"topology": topology_file,
"system_AB": omm_system_AB,
"restrained_system": system,
"alchem_system": alchemical_system,
"alchem_factory": alchemical_factory,
"positions": positions_AB,
}
"system": system_outfile,
"topology": topology_file,
"system_AB": omm_system_AB,
"alchem_restrained_system": system,
"alchem_system": alchemical_system,
"alchem_factory": alchemical_factory,
"positions": positions_AB,
"selection_indices": selection_indices,
"subsampled_pdb_structure": sub_pdb_structure,
}
@@ -1218,3 +1218,19 @@ class SepTopComplexRunUnit(SepTopComplexMixin, BaseSepTopRunUnit):
lambdas["lambda_restraints_B"] = lambda_restraints_B
return lambdas
class SepTopSolventAnalysisUnit(SepTopSolventMixin, BaseSepTopAnalysisUnit):
    """
    Protocol Unit for the analysis of the solvent phase of a relative SepTop
    free energy calculation.
    """

    # Label identifying the leg (phase) this unit belongs to.
    simtype = "solvent"
class SepTopComplexAnalysisUnit(SepTopComplexMixin, BaseSepTopAnalysisUnit):
    """
    Protocol Unit for the analysis of the complex phase of a relative SepTop
    free energy calculation.
    """

    # Label identifying the leg (phase) this unit belongs to.
    simtype = "complex"

View File

@@ -0,0 +1,57 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import mdtraj as mdt
import numpy as np
import openmm
from openmm import unit as omm_unit
def mdtraj_from_openmm(
    omm_topology: openmm.app.Topology,
    omm_positions: openmm.unit.Quantity,
) -> mdt.Trajectory:
    """
    Get an mdtraj object from an OpenMM topology and positions.

    Parameters
    ----------
    omm_topology : openmm.app.Topology
        The OpenMM topology. If it carries periodic box vectors, they are
        converted into unit cell lengths and angles for the trajectory.
    omm_positions : openmm.unit.Quantity
        The OpenMM positions

    Returns
    -------
    mdtraj_trajectory : mdtraj.Trajectory
        A trajectory built from the positions (converted to nanometers) and
        the converted topology. Unit cell lengths are in nanometers and
        angles in degrees; both are ``None`` when the topology has no
        periodic box vectors.
    """
    mdtraj_topology = mdt.Topology.from_openmm(omm_topology)
    positions_in_mdtraj_format = omm_positions.value_in_unit(omm_unit.nanometers)

    box = omm_topology.getPeriodicBoxVectors()

    if box is not None:
        # Convert explicitly to nanometers instead of reading the private
        # ``_value`` attribute, so the result does not depend on the unit
        # the box vectors happen to be stored in.
        x, y, z = [np.array(b.value_in_unit(omm_unit.nanometers)) for b in box]
        lx = np.linalg.norm(x)
        ly = np.linalg.norm(y)
        lz = np.linalg.norm(z)

        # Clip the cosines to [-1, 1]: floating point round-off can push them
        # marginally outside the arccos domain, which would yield NaN angles.
        # angle between y and z
        alpha = np.arccos(np.clip(np.dot(y, z) / (ly * lz), -1.0, 1.0))
        # angle between x and z
        beta = np.arccos(np.clip(np.dot(x, z) / (lx * lz), -1.0, 1.0))
        # angle between x and y
        gamma = np.arccos(np.clip(np.dot(x, y) / (lx * ly), -1.0, 1.0))

        unitcell_lengths = np.array([lx, ly, lz])
        unitcell_angles = np.array([np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)])
    else:
        # No periodic box defined on the topology: leave the unit cell unset
        unitcell_lengths = None
        unitcell_angles = None

    mdtraj_trajectory = mdt.Trajectory(
        positions_in_mdtraj_format,
        mdtraj_topology,
        unitcell_lengths=unitcell_lengths,
        unitcell_angles=unitcell_angles,
    )

    return mdtraj_trajectory

View File

@@ -1,3 +1,6 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import os
import pathlib

View File

@@ -0,0 +1,30 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import gufe
import pytest
from openfe.protocols.openmm_septop import SepTopProtocol
@pytest.fixture()
def protocol_dry_settings():
    """
    Default SepTop settings adjusted for dry run tests: a single protocol
    repeat and no explicit compute platform.
    """
    dry_settings = SepTopProtocol.default_settings()
    dry_settings.protocol_repeats = 1
    dry_settings.engine_settings.compute_platform = None
    return dry_settings
@pytest.fixture
def benzene_toluene_dag(
    benzene_complex_system,
    toluene_complex_system,
    protocol_dry_settings,
):
    """
    A SepTop ProtocolDAG going from the benzene complex to the toluene
    complex, created with the dry run settings and no mapping.
    """
    septop_protocol = SepTopProtocol(settings=protocol_dry_settings)
    dag = septop_protocol.create(
        mapping=None,
        stateA=benzene_complex_system,
        stateB=toluene_complex_system,
    )
    return dag

View File

@@ -33,10 +33,12 @@ from openmmtools.multistate.multistatesampler import MultiStateSampler
import openfe.protocols.openmm_septop
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols.openmm_septop import (
SepTopComplexAnalysisUnit,
SepTopComplexRunUnit,
SepTopComplexSetupUnit,
SepTopProtocol,
SepTopProtocolResult,
SepTopSolventAnalysisUnit,
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
@@ -48,6 +50,8 @@ from openfe.tests.protocols.openmm_ahfe.test_ahfe_protocol import (
_verify_alchemical_sterics_force_parameters,
)
from .utils import UNIT_TYPES, _get_units
E_CHARGE = 1.602176634e-19 * openmm.unit.coulomb
EPSILON0 = (
1e-6
@@ -59,15 +63,6 @@ EPSILON0 = (
ONE_4PI_EPS0 = 1 / (4 * np.pi * EPSILON0) * EPSILON0.unit * 10.0 # nm -> angstrom
@pytest.fixture()
def protocol_dry_settings():
# a set of settings for dry run tests
s = SepTopProtocol.default_settings()
s.engine_settings.compute_platform = None
s.protocol_repeats = 1
return s
@pytest.fixture()
def default_settings():
s = SepTopProtocol.default_settings()
@@ -93,6 +88,47 @@ def test_serialize_protocol(default_settings):
assert protocol == ret
def test_repeat_units(benzene_complex_system, toluene_complex_system, default_settings):
    """
    Check that a DAG created with three protocol repeats holds the expected
    number of setup, simulation and analysis units for each phase, and that
    the units of a given repeat are chained to one another.
    """
    default_settings.protocol_repeats = 3
    protocol = SepTopProtocol(
        settings=default_settings,
    )

    dag = protocol.create(
        stateA=benzene_complex_system,
        stateB=toluene_complex_system,
        mapping=None,
    )

    # 18 protocol units in total:
    # 3 repeats x 2 phases x 3 units (setup, simulation, analysis)
    pus = list(dag.protocol_units)
    assert len(pus) == 18

    # Check the units of each phase
    for phase in ["solvent", "complex"]:
        setup = _get_units(pus, UNIT_TYPES[phase]["setup"])
        sim = _get_units(pus, UNIT_TYPES[phase]["sim"])
        analysis = _get_units(pus, UNIT_TYPES[phase]["analysis"])

        # Should be 3 of each set (one per repeat)
        assert len(setup) == 3
        assert len(sim) == 3
        assert len(analysis) == 3

        # check that the dag chain is correct: each analysis unit must point
        # at the setup and simulation units sharing its repeat_id
        for analysis_pu in analysis:
            repeat_id = analysis_pu.inputs["repeat_id"]
            setup_pu = [
                s for s in setup if (s.inputs["repeat_id"] == repeat_id) and (s.simtype == phase)
            ][0]
            sim_pu = [
                s for s in sim if (s.inputs["repeat_id"] == repeat_id) and (s.simtype == phase)
            ][0]
            assert analysis_pu.inputs["setup"] == setup_pu
            assert analysis_pu.inputs["simulation"] == sim_pu
            assert sim_pu.inputs["setup"] == setup_pu
def test_create_independent_repeat_ids(
benzene_complex_system, toluene_complex_system, default_settings
):
@@ -104,25 +140,26 @@ def test_create_independent_repeat_ids(
settings=default_settings,
)
dag1 = protocol.create(
stateA=benzene_complex_system,
stateB=toluene_complex_system,
mapping=None,
)
dag2 = protocol.create(
stateA=benzene_complex_system,
stateB=toluene_complex_system,
mapping=None,
)
# print([u for u in dag1.protocol_units])
repeat_ids = set()
for u in dag1.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
for u in dag2.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
dags = []
for i in range(2):
dags.append(
protocol.create(
stateA=benzene_complex_system,
stateB=toluene_complex_system,
mapping=None,
)
)
# There are 4 units per repeat per DAG: 4 * 3 * 2 = 24
assert len(repeat_ids) == 24
repeat_ids = set()
for dag in dags:
# 3 repeats of 6 units
assert len(list(dag.protocol_units)) == 18
for u in dag.protocol_units:
repeat_ids.add(u.inputs["repeat_id"])
# one uuid per repeat, so should equal 6
assert len(repeat_ids) == 6
# Tests for the alchemical systems. This tests were modified from
@@ -325,25 +362,10 @@ class TestNonbondedInteractions:
assert_allclose(energy, from_openmm(expected_energy), rtol=1e-05)
@pytest.fixture
def benzene_toluene_dag(
benzene_complex_system,
toluene_complex_system,
protocol_dry_settings,
):
protocol = SepTopProtocol(settings=protocol_dry_settings)
return protocol.create(
stateA=benzene_complex_system,
stateB=toluene_complex_system,
mapping=None,
)
def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
prot_units = list(benzene_toluene_dag.protocol_units)
assert len(prot_units) == 4
assert len(prot_units) == 6
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
sol_run_unit = [u for u in prot_units if isinstance(u, SepTopSolventRunUnit)]
@@ -356,7 +378,7 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
solv_setup_output = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
pdb = md.load_pdb(tmp_path / "topology.pdb")
assert pdb.n_atoms == 1762
central_atoms = np.array([[2, 19]], dtype=np.int32)
@@ -367,10 +389,11 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
solv_sampler = sol_run_unit[0].run(
alchem_system,
pdb_file,
solv_setup_output["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)["debug"]["sampler"] # fmt: skip
)["sampler"] # fmt: skip
assert solv_sampler.is_periodic
assert isinstance(solv_sampler, MultiStateSampler)
@@ -397,16 +420,17 @@ def test_dry_run_benzene_toluene(benzene_toluene_dag, tmp_path):
complex_setup_output = complex_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
pdb_file = openmm.app.pdbfile.PDBFile(str(complex_setup_output["topology"]))
alchem_system = deserialize(complex_setup_output["system"])
complex_sampler = complex_run_unit[0].run(
alchem_system,
pdb_file,
complex_setup_output["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)["debug"]["sampler"] # fmt: skip
)["sampler"] # fmt: skip
assert complex_sampler.is_periodic
assert isinstance(complex_sampler, MultiStateSampler)
@@ -461,16 +485,17 @@ def test_dry_run_methods(
solv_setup_output = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
alchem_system = deserialize(solv_setup_output["system"])
solv_sampler = sol_run_unit[0].run(
alchem_system,
pdb_file,
solv_setup_output["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)["debug"]["sampler"] # fmt: skip
)["sampler"] # fmt: skip
assert isinstance(solv_sampler, MultiStateSampler)
assert solv_sampler.is_periodic
@@ -518,16 +543,17 @@ def test_dry_run_ligand_system_pressure(
solv_setup_output = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
alchem_system = deserialize(solv_setup_output["system"])
solv_sampler = sol_run_unit[0].run(
alchem_system,
pdb_file,
solv_setup_output["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)["debug"]["sampler"] # fmt: skip
)["sampler"] # fmt: skip
# at this point, the units will be in openmm units
assert solv_sampler._thermodynamic_states[1].pressure == pressure * openmm.unit.bar
@@ -561,9 +587,25 @@ def test_virtual_sites_no_reassign(
dag_units = list(dag.protocol_units)
# Only check the Solvent Unit
solv_setup_unit = [u for u in dag_units if isinstance(u, SepTopSolventSetupUnit)]
errmsg = "Simulations with virtual sites without velocity"
with pytest.raises(ValueError, match=errmsg):
_ = solv_setup_unit[0].run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
solv_run_unit = [u for u in dag_units if isinstance(u, SepTopSolventRunUnit)]
setup_results = solv_setup_unit[0].run(
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)
pdb_file = openmm.app.pdbfile.PDBFile(str(setup_results["topology"]))
with pytest.raises(ValueError, match="are unstable"):
_ = solv_run_unit[0].run(
setup_results["alchem_system"],
pdb_file,
setup_results["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
) # fmt: skip
@pytest.mark.parametrize(
@@ -597,7 +639,7 @@ def test_dry_run_ligand_system_cutoff(
serialized_system = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]["system"]
)["system"]
system = deserialize(serialized_system)
nbfs = [
f
@@ -636,7 +678,7 @@ def test_dry_run_benzene_toluene_tip4p(
prot_units = list(dag.protocol_units)
assert len(prot_units) == 4
assert len(prot_units) == 6
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
sol_run_unit = [u for u in prot_units if isinstance(u, SepTopSolventRunUnit)]
@@ -646,16 +688,17 @@ def test_dry_run_benzene_toluene_tip4p(
solv_setup_output = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
alchem_system = deserialize(solv_setup_output["system"])
solv_run = sol_run_unit[0].run(
alchem_system,
pdb_file,
solv_setup_output["selection_indices"],
dry=True,
scratch_basepath=tmp_path,
shared_basepath=tmp_path,
)["debug"]["sampler"] # fmt: skip
)["sampler"] # fmt: skip
assert solv_run.is_periodic
@@ -681,7 +724,7 @@ def test_dry_run_benzene_toluene_noncubic(
prot_units = list(dag.protocol_units)
assert len(prot_units) == 4
assert len(prot_units) == 6
solv_setup_unit = [u for u in prot_units if isinstance(u, SepTopSolventSetupUnit)]
@@ -689,7 +732,7 @@ def test_dry_run_benzene_toluene_noncubic(
solv_setup_output = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]
)
serialized_system = solv_setup_output["system"]
system = deserialize(serialized_system)
vectors = system.getDefaultPeriodicBoxVectors()
@@ -779,7 +822,7 @@ def test_dry_run_solv_user_charges_benzene_toluene(
# check sol_unit charges
serialized_system = solv_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]["system"]
)["system"]
system = deserialize(serialized_system)
nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)]
assert len(nonbond) == 1
@@ -799,7 +842,7 @@ def test_dry_run_solv_user_charges_benzene_toluene(
# check complex_unit charges
serialized_system = complex_setup_unit[0].run(
dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path
)["debug"]["system"]
)["system"]
system = deserialize(serialized_system)
nonbond = [f for f in system.getForces() if isinstance(f, openmm.NonbondedForce)]
assert len(nonbond) == 1
@@ -840,6 +883,23 @@ def test_high_timestep(
prot_units[0].run(dry=True, scratch_basepath=tmp_path, shared_basepath=tmp_path)
def test_bad_sampler():
    """``_get_sampler`` must raise ``AttributeError`` for an unknown ``sampler_method``."""

    class FakeSimSettings(gufe.settings.SettingsBaseModel):
        # Deliberately a value that the expected error message echoes back.
        sampler_method: str = "foo bar"

    # Only the simulation settings carry information; the rest are placeholders.
    sampler_kwargs = {
        "integrator": None,
        "reporter": None,
        "simulation_settings": FakeSimSettings(),
        "thermodynamic_settings": None,
        "compound_states": None,
        "sampler_states": None,
        "platform": None,
    }

    with pytest.raises(AttributeError, match="Unknown sampler foo bar"):
        SepTopSolventRunUnit._get_sampler(**sampler_kwargs)
@pytest.fixture
def T4L_xml(
benzene_complex_system,
@@ -865,7 +925,7 @@ def T4L_xml(
tmp = tmp_path_factory.mktemp("xml_reg")
dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp)["debug"]
dryrun = solv_setup_unit[0].run(dry=True, shared_basepath=tmp)
system = dryrun["system"]
return deserialize(system)
@@ -908,270 +968,6 @@ class TestT4LXmlRegression:
assert a[2] == b[2]
def test_unit_tagging(benzene_toluene_dag, tmp_path):
    """Execute every unit with mocked internals and check generation / repeat tagging."""
    dag_units = benzene_toluene_dag.protocol_units
    method_path = "openfe.protocols.openmm_septop.equil_septop_method"

    def fake_setup_outputs():
        # Canned return value for the mocked Setup unit ``run`` methods.
        return {
            "system": pathlib.Path("system.xml.bz2"),
            "topology": "topology.pdb",
        }

    def fake_run_outputs(simtype):
        # Canned return value for the mocked Run unit ``_execute`` methods.
        return {
            "repeat_id": 0,
            "generation": 0,
            "simtype": simtype,
            "nc": "file.nc",
            "last_checkpoint": "chck.nc",
        }

    with (
        mock.patch(
            f"{method_path}.SepTopComplexSetupUnit.run",
            return_value=fake_setup_outputs(),
        ),
        mock.patch(
            f"{method_path}.SepTopComplexRunUnit._execute",
            return_value=fake_run_outputs("complex"),
        ),
        mock.patch(
            f"{method_path}.SepTopSolventSetupUnit.run",
            return_value=fake_setup_outputs(),
        ),
        mock.patch(
            f"{method_path}.SepTopSolventRunUnit._execute",
            return_value=fake_run_outputs("solvent"),
        ),
    ):
        results = [u.execute(context=gufe.Context(tmp_path, tmp_path)) for u in dag_units]

    repeat_ids = {"complex": set(), "solvent": set()}
    for unit_result in results:
        assert isinstance(unit_result, gufe.ProtocolUnitResult)
        assert unit_result.outputs["generation"] == 0
        # Anything not tagged "complex" is counted towards the solvent leg.
        leg = "complex" if unit_result.outputs["simtype"] == "complex" else "solvent"
        repeat_ids[leg].add(unit_result.outputs["repeat_id"])

    # Repeat ids are random ints so just check their lengths
    assert len(repeat_ids["complex"]) == len(repeat_ids["solvent"]) == 2
def test_gather(benzene_toluene_dag, tmp_path):
    """Check that ``SepTopProtocol.gather`` returns a ``SepTopProtocolResult``.

    The DAG is executed with the Setup ``run`` and Run ``_execute`` methods of
    both legs mocked out, and the resulting DAG result is handed to ``gather``.
    """
    method_path = "openfe.protocols.openmm_septop.equil_septop_method"

    def fake_setup_outputs():
        # Canned return value for the mocked Setup unit ``run`` methods.
        return {
            "system": pathlib.Path("system.xml.bz2"),
            "topology": "topology.pdb",
        }

    def fake_run_outputs(simtype):
        # Canned return value for the mocked Run unit ``_execute`` methods.
        return {
            "repeat_id": 0,
            "generation": 0,
            "simtype": simtype,
            "nc": "file.nc",
            "last_checkpoint": "chck.nc",
        }

    with (
        mock.patch(
            f"{method_path}.SepTopComplexSetupUnit.run",
            return_value=fake_setup_outputs(),
        ),
        mock.patch(
            f"{method_path}.SepTopComplexRunUnit._execute",
            return_value=fake_run_outputs("complex"),
        ),
        mock.patch(
            f"{method_path}.SepTopSolventSetupUnit.run",
            return_value=fake_setup_outputs(),
        ),
        mock.patch(
            f"{method_path}.SepTopSolventRunUnit._execute",
            return_value=fake_run_outputs("solvent"),
        ),
    ):
        dag_result = gufe.protocols.execute_DAG(
            benzene_toluene_dag,
            shared_basedir=tmp_path,
            scratch_basedir=tmp_path,
            keep_shared=True,
        )

    default_protocol = SepTopProtocol(settings=SepTopProtocol.default_settings())
    gathered = default_protocol.gather([dag_result])

    assert isinstance(gathered, openfe.protocols.openmm_septop.SepTopProtocolResult)
class TestProtocolResult:
    """Checks on a ``ProtocolResult`` rebuilt from the ``septop_json`` fixture."""

    @pytest.fixture()
    def protocolresult(self, septop_json):
        """Decode ``septop_json`` and rebuild its ``protocol_result`` entry."""
        decoded = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
        return openfe.ProtocolResult.from_dict(decoded["protocol_result"])

    def test_reload_protocol_result(self, septop_json):
        """The stored result can be rebuilt directly via ``SepTopProtocolResult``."""
        decoded = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
        reloaded = SepTopProtocolResult.from_dict(decoded["protocol_result"])

        assert reloaded

    def test_get_estimate(self, protocolresult):
        """The estimate is a truthy energy quantity with magnitude close to 3.82."""
        estimate = protocolresult.get_estimate()

        assert estimate
        assert estimate.m == pytest.approx(3.82, abs=0.1)
        assert isinstance(estimate, offunit.Quantity)
        assert estimate.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_uncertainty(self, protocolresult):
        """The uncertainty is an energy quantity with magnitude close to zero."""
        uncertainty = protocolresult.get_uncertainty()

        assert uncertainty.m == pytest.approx(0.0, abs=0.1)
        assert isinstance(uncertainty, offunit.Quantity)
        assert uncertainty.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_individual(self, protocolresult):
        """Each leg holds one (estimate, uncertainty) pair in energy units."""
        inds = protocolresult.get_individual_estimates()

        assert isinstance(inds, dict)
        assert isinstance(inds["solvent"], list)
        assert isinstance(inds["complex"], list)
        assert len(inds["solvent"]) == len(inds["complex"]) == 1
        for estimate, uncertainty in itertools.chain(inds["solvent"], inds["complex"]):
            assert estimate.is_compatible_with(offunit.kilojoule_per_mole)
            assert uncertainty.is_compatible_with(offunit.kilojoule_per_mole)

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_get_forwards_etc(self, key, protocolresult):
        """
        Due to the short simulation times, we expect the forward/reverse
        analysis entries to be None (and a warning to be emitted).
        """
        wmsg = f"were found in the forward and reverse dictionaries of the repeats of the {key}"
        with pytest.warns(UserWarning, match=wmsg):
            far = protocolresult.get_forward_and_reverse_energy_analysis()

        assert isinstance(far, dict)
        assert isinstance(far[key], list)
        assert all(entry is None for entry in far[key])

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_get_overlap_matrices(self, key, protocolresult):
        """One square overlap matrix per leg: 27 states for solvent, 19 for complex."""
        ovp = protocolresult.get_overlap_matrices()

        assert isinstance(ovp, dict)
        assert isinstance(ovp[key], list)
        assert len(ovp[key]) == 1

        first_overlap = ovp[key][0]
        assert isinstance(first_overlap["matrix"], np.ndarray)

        lambda_nr = 27 if key == "solvent" else 19
        assert first_overlap["matrix"].shape == (lambda_nr, lambda_nr)

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_get_replica_transition_statistics(self, key, protocolresult):
        """One set of eigenvalues and transition matrix per leg, sized by state count."""
        rpx = protocolresult.get_replica_transition_statistics()
        lambda_nr = 27 if key == "solvent" else 19

        assert isinstance(rpx, dict)
        assert isinstance(rpx[key], list)
        assert len(rpx[key]) == 1

        first_stats = rpx[key][0]
        assert "eigenvalues" in first_stats
        assert "matrix" in first_stats
        assert first_stats["eigenvalues"].shape == (lambda_nr,)
        assert first_stats["matrix"].shape == (lambda_nr, lambda_nr)

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_equilibration_iterations(self, key, protocolresult):
        """Each leg reports a single float equilibration iteration count."""
        eq = protocolresult.equilibration_iterations()

        assert isinstance(eq, dict)
        assert isinstance(eq[key], list)
        assert len(eq[key]) == 1
        assert all(isinstance(v, float) for v in eq[key])

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_production_iterations(self, key, protocolresult):
        """Each leg reports a single float production iteration count."""
        prod = protocolresult.production_iterations()

        assert isinstance(prod, dict)
        assert isinstance(prod[key], list)
        assert len(prod[key]) == 1
        assert all(isinstance(v, float) for v in prod[key])

    @pytest.mark.parametrize(
        "key, expected_size",
        [
            ("solvent", 87),
            ("complex", 1868),
        ],
    )
    def test_selection_indices(self, key, protocolresult, expected_size):
        """Selection indices are numpy arrays of the expected length for each leg."""
        indices = protocolresult.selection_indices()

        assert isinstance(indices, dict)
        assert isinstance(indices[key], list)
        for leg_indices in indices[key]:
            assert isinstance(leg_indices, np.ndarray)
            assert len(leg_indices) == expected_size

    def test_filenotfound_replica_states(self, protocolresult):
        """``get_replica_states`` raises ``ValueError`` with the file-not-found message."""
        with pytest.raises(ValueError, match="File could not be found"):
            protocolresult.get_replica_states()

    def test_restraint_geometry(self, protocolresult):
        """The first stored restraint geometry is a Boresch geometry with known values."""
        geom = protocolresult.restraint_geometries()

        assert isinstance(geom, tuple)
        assert len(geom) == 2
        assert isinstance(geom[0], list)

        boresch = geom[0][0]
        assert isinstance(boresch, BoreschRestraintGeometry)
        assert boresch.guest_atoms == [1779, 1778, 1777]
        assert boresch.host_atoms == [802, 801, 800]

        expected_values = {
            "r_aA0": 0.798936 * offunit.nanometer,
            "theta_A0": 2.049091 * offunit.radian,
            "theta_B0": 1.221973 * offunit.radian,
            "phi_A0": 0.956774 * offunit.radian,
            "phi_B0": -1.217188 * offunit.radian,
            "phi_C0": -1.068226 * offunit.radian,
        }
        for attribute, expected in expected_values.items():
            assert pytest.approx(getattr(boresch, attribute)) == expected
@pytest.mark.slow
class TestA2AMembraneDryRun:
solvent = SolventComponent(ion_concentration=0 * offunit.molar)
@@ -1231,6 +1027,10 @@ class TestA2AMembraneDryRun:
def complex_run_units(self, dag):
return [u for u in dag.protocol_units if isinstance(u, SepTopComplexRunUnit)]
@pytest.fixture(scope="function")
def complex_analysis_unit(self, dag):
    """List the ``SepTopComplexAnalysisUnit`` instances among ``dag.protocol_units``."""
    return list(filter(lambda u: isinstance(u, SepTopComplexAnalysisUnit), dag.protocol_units))
@pytest.fixture(scope="function")
def solvent_setup_units(self, dag):
    """List the ``SepTopSolventSetupUnit`` instances among ``dag.protocol_units``."""
    return list(filter(lambda u: isinstance(u, SepTopSolventSetupUnit), dag.protocol_units))
@@ -1239,10 +1039,14 @@ class TestA2AMembraneDryRun:
def solvent_run_units(self, dag):
return [u for u in dag.protocol_units if isinstance(u, SepTopSolventRunUnit)]
@pytest.fixture(scope="function")
def solvent_analysis_unit(self, dag):
    """List the ``SepTopSolventAnalysisUnit`` instances among ``dag.protocol_units``."""
    return list(filter(lambda u: isinstance(u, SepTopSolventAnalysisUnit), dag.protocol_units))
def test_number_of_units(
self, dag, complex_setup_units, complex_run_units, solvent_setup_units, solvent_run_units
):
assert len(list(dag.protocol_units)) == 4
assert len(list(dag.protocol_units)) == 6
assert len(complex_setup_units) == 1
assert len(complex_run_units) == 1
assert len(solvent_setup_units) == 1
@@ -1400,18 +1204,21 @@ class TestA2AMembraneDryRun:
adaptive_settings.complex_integrator_settings.barostat
== "MonteCarloMembraneBarostat"
)
complex_setup_output = complex_setup_units[0].run(dry=True)["debug"]
complex_setup_output = complex_setup_units[0].run(dry=True)
pdb_file = openmm.app.pdbfile.PDBFile(str(complex_setup_output["topology"]))
system = deserialize(complex_setup_output["system"])
data = complex_run_units[0].run(system, pdb_file, dry=True)["debug"] # fmt: skip
indices = complex_setup_output["selection_indices"]
data = complex_run_units[0].run(system, pdb_file, indices, dry=True) # fmt: skip
# Check the sampler
self._verify_sampler(data["sampler"], complexed=True, settings=adaptive_settings)
# Check the alchemical system
self._assert_expected_alchemical_forces(
data["alchem_system"], complexed=True, settings=adaptive_settings
complex_setup_output["alchem_restrained_system"],
complexed=True,
settings=adaptive_settings,
)
self._test_orthogonal_vectors(data["alchem_system"])
self._test_orthogonal_vectors(complex_setup_output["alchem_restrained_system"])
# Check the non-alchemical system
self._assert_expected_nonalchemical_forces(
@@ -1420,7 +1227,9 @@ class TestA2AMembraneDryRun:
self._test_orthogonal_vectors(complex_setup_output["system_AB"])
# Check the box vectors haven't changed (they shouldn't have because we didn't do MD)
assert_allclose(
from_openmm(data["alchem_system"].getDefaultPeriodicBoxVectors()),
from_openmm(
complex_setup_output["alchem_restrained_system"].getDefaultPeriodicBoxVectors()
),
from_openmm(complex_setup_output["system_AB"].getDefaultPeriodicBoxVectors()),
)
@@ -1433,23 +1242,24 @@ class TestA2AMembraneDryRun:
def test_solvent_dry_run(self, solvent_setup_units, solvent_run_units, settings, tmpdir):
with tmpdir.as_cwd():
solv_setup_output = solvent_setup_units[0].run(dry=True)["debug"]
solv_setup_output = solvent_setup_units[0].run(dry=True)
pdb_file = openmm.app.pdbfile.PDBFile(str(solv_setup_output["topology"]))
system = deserialize(solv_setup_output["system"])
data = solvent_run_units[0].run(system, pdb_file, dry=True)["debug"] # fmt: skip
indices = solv_setup_output["selection_indices"]
data = solvent_run_units[0].run(system, pdb_file, indices, dry=True) # fmt: skip
# Check the sampler
self._verify_sampler(data["sampler"], complexed=False, settings=settings)
# Check the alchemical system
self._assert_expected_alchemical_forces(
data["alchem_system"], complexed=False, settings=settings
solv_setup_output["alchem_restrained_system"], complexed=False, settings=settings
)
self._test_cubic_vectors(data["alchem_system"])
self._test_cubic_vectors(solv_setup_output["alchem_restrained_system"])
# Check the alchemical indices
expected_indices = [i for i in range(self.num_ligand_atoms_A + self.num_ligand_atoms_B)]
assert expected_indices == data["selection_indices"].tolist()
assert expected_indices == solv_setup_output["selection_indices"].tolist()
# Check the non-alchemical system
self._assert_expected_nonalchemical_forces(

View File

@@ -0,0 +1,338 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
import itertools
import json
import math
import pathlib
from unittest import mock
import gufe
import mdtraj as md
import numpy as np
import openmm
import openmm.app
import openmm.unit
import pytest
from numpy.testing import assert_allclose
from openff.units import unit as offunit
from openff.units.openmm import ensure_quantity, from_openmm, to_openmm
from openmm import (
CustomBondForce,
CustomCompoundBondForce,
CustomNonbondedForce,
HarmonicAngleForce,
HarmonicBondForce,
MonteCarloBarostat,
MonteCarloMembraneBarostat,
NonbondedForce,
PeriodicTorsionForce,
)
from openmmtools.alchemy import AbsoluteAlchemicalFactory, AlchemicalRegion
from openmmtools.multistate.multistatesampler import MultiStateSampler
import openfe.protocols.openmm_septop
from openfe import ChemicalSystem, SolventComponent
from openfe.protocols.openmm_septop import (
SepTopComplexRunUnit,
SepTopComplexSetupUnit,
SepTopProtocol,
SepTopProtocolResult,
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
from openfe.protocols.openmm_septop.base_units import (
BaseSepTopAnalysisUnit,
BaseSepTopRunUnit,
BaseSepTopSetupUnit,
)
from openfe.protocols.openmm_utils.serialization import deserialize
from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry
from openfe.tests.protocols.conftest import compute_energy
from openfe.tests.protocols.openmm_ahfe.test_ahfe_protocol import (
_assert_num_forces,
_verify_alchemical_sterics_force_parameters,
)
from .utils import UNIT_TYPES, _get_units
@pytest.fixture
def patcher():
    """Patch the SepTop unit ``run`` methods and loaders with canned return values.

    While the requesting test runs, the ``run`` method of every Setup, Run and
    Analysis unit (complex and solvent) returns a fixed dictionary, and the
    ``deserialize`` and ``openmm.app.pdbfile.PDBFile`` names reached through
    ``base_units`` return the placeholder string ``"foo"``.
    """
    base_path = "openfe.protocols.openmm_septop.base_units"
    protocol_path = "openfe.protocols.openmm_septop.equil_septop_method"

    # Canned Setup outputs; the complex leg carries per-ligand corrections
    # and restraint geometries, the solvent leg a single correction.
    complex_setup_outputs = {
        "system": pathlib.Path("system.xml.bz2"),
        "topology": "topology.pdb",
        "standard_state_correction_A": 0 * offunit.kilocalorie_per_mole,
        "standard_state_correction_B": 0 * offunit.kilocalorie_per_mole,
        "restraint_geometry_A": None,
        "restraint_geometry_B": None,
        "selection_indices": np.array([0]),
        "subsampled_pdb_structure": "subsampled.pdb",
    }
    solvent_setup_outputs = {
        "system": pathlib.Path("system.xml.bz2"),
        "topology": "topology.pdb",
        "standard_state_correction": 0 * offunit.kilocalorie_per_mole,
        "selection_indices": np.array([0]),
        "subsampled_pdb_structure": "subsampled.pdb",
    }

    # Separate dictionaries per leg so the mocks never share a return object.
    complex_run_outputs = {"trajectory": "foo.nc", "checkpoint": "bar.nc"}
    solvent_run_outputs = {"trajectory": "foo.nc", "checkpoint": "bar.nc"}
    complex_analysis_outputs = {"foo": "bar"}
    solvent_analysis_outputs = {"foo": "bar"}

    with (
        mock.patch(
            f"{protocol_path}.SepTopComplexSetupUnit.run",
            return_value=complex_setup_outputs,
        ),
        mock.patch(
            f"{protocol_path}.SepTopSolventSetupUnit.run",
            return_value=solvent_setup_outputs,
        ),
        mock.patch(
            f"{protocol_path}.SepTopComplexRunUnit.run",
            return_value=complex_run_outputs,
        ),
        mock.patch(
            f"{protocol_path}.SepTopSolventRunUnit.run",
            return_value=solvent_run_outputs,
        ),
        mock.patch(
            f"{protocol_path}.SepTopComplexAnalysisUnit.run",
            return_value=complex_analysis_outputs,
        ),
        mock.patch(
            f"{protocol_path}.SepTopSolventAnalysisUnit.run",
            return_value=solvent_analysis_outputs,
        ),
        mock.patch(f"{base_path}.deserialize", return_value="foo"),
        mock.patch(f"{base_path}.openmm.app.pdbfile.PDBFile", return_value="foo"),
    ):
        yield
def test_unit_tagging(benzene_toluene_dag, patcher, tmp_path):
    """Execute each leg's setup, simulation and analysis units and check their tagging.

    Results are keyed by the unit's ``repeat_id`` input so that downstream
    units receive the upstream results of the same repeat.
    """
    dag_units = benzene_toluene_dag.protocol_units

    def execute(unit, **upstream_results):
        # Every unit runs against a fresh context rooted at tmp_path.
        return unit.execute(context=gufe.Context(tmp_path, tmp_path), **upstream_results)

    for phase in ("solvent", "complex"):
        phase_types = UNIT_TYPES[phase]
        setup_units = _get_units(dag_units, phase_types["setup"])
        sim_units = _get_units(dag_units, phase_types["sim"])
        a_units = _get_units(dag_units, phase_types["analysis"])

        setup_results = {}
        for unit in setup_units:
            repeat = unit.inputs["repeat_id"]
            setup_results[repeat] = execute(unit)

        sim_results = {}
        for unit in sim_units:
            repeat = unit.inputs["repeat_id"]
            sim_results[repeat] = execute(unit, setup=setup_results[repeat])

        analysis_results = {}
        for unit in a_units:
            repeat = unit.inputs["repeat_id"]
            analysis_results[repeat] = execute(
                unit,
                setup=setup_results[repeat],
                simulation=sim_results[repeat],
            )

        every_result = itertools.chain(
            setup_results.values(),
            sim_results.values(),
            analysis_results.values(),
        )
        for unit_result in every_result:
            assert isinstance(unit_result, gufe.ProtocolUnitResult)
            assert unit_result.outputs["generation"] == 0

        assert len(setup_results) == 1
        assert len(sim_results) == 1
        assert len(analysis_results) == 1
def test_gather(benzene_toluene_dag, patcher, tmp_path):
    """Check that ``SepTopProtocol.gather`` returns a ``SepTopProtocolResult``.

    The DAG is executed while the ``patcher`` fixture is active, and the
    resulting DAG result is passed to a default-settings protocol.
    """
    dag_result = gufe.protocols.execute_DAG(
        benzene_toluene_dag,
        shared_basedir=tmp_path,
        scratch_basedir=tmp_path,
        keep_shared=True,
    )

    default_protocol = SepTopProtocol(settings=SepTopProtocol.default_settings())
    gathered = default_protocol.gather([dag_result])

    assert isinstance(gathered, openfe.protocols.openmm_septop.SepTopProtocolResult)
class TestProtocolResult:
    """Checks on a ``ProtocolResult`` rebuilt from the ``septop_json`` fixture."""

    @pytest.fixture()
    def protocolresult(self, septop_json):
        """Decode ``septop_json`` and rebuild its ``protocol_result`` entry."""
        decoded = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
        return openfe.ProtocolResult.from_dict(decoded["protocol_result"])

    def test_reload_protocol_result(self, septop_json):
        """The stored result can be rebuilt directly via ``SepTopProtocolResult``."""
        decoded = json.loads(septop_json, cls=gufe.tokenization.JSON_HANDLER.decoder)
        reloaded = SepTopProtocolResult.from_dict(decoded["protocol_result"])

        assert reloaded

    def test_get_estimate(self, protocolresult):
        """The estimate is a truthy energy quantity with magnitude close to 1.6."""
        estimate = protocolresult.get_estimate()

        assert estimate
        assert estimate.m == pytest.approx(1.6, abs=0.1)
        assert isinstance(estimate, offunit.Quantity)
        assert estimate.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_uncertainty(self, protocolresult):
        """The uncertainty is an energy quantity with magnitude close to zero."""
        uncertainty = protocolresult.get_uncertainty()

        assert uncertainty.m == pytest.approx(0.0, abs=0.1)
        assert isinstance(uncertainty, offunit.Quantity)
        assert uncertainty.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_individual(self, protocolresult):
        """Each leg holds one (estimate, uncertainty) pair in energy units."""
        inds = protocolresult.get_individual_estimates()

        assert isinstance(inds, dict)
        assert isinstance(inds["solvent"], list)
        assert isinstance(inds["complex"], list)
        assert len(inds["solvent"]) == len(inds["complex"]) == 1
        for estimate, uncertainty in itertools.chain(inds["solvent"], inds["complex"]):
            assert estimate.is_compatible_with(offunit.kilojoule_per_mole)
            assert uncertainty.is_compatible_with(offunit.kilojoule_per_mole)

    def test_get_forwards_etc(self, protocolresult):
        """
        Due to the short simulation times, we expect the forward/reverse
        analysis of the solvent to be None, while the complex entry holds
        ten values for each forward/reverse series.
        """
        wmsg = "were found in the forward and reverse dictionaries of the repeats of the solvent"
        with pytest.warns(UserWarning, match=wmsg):
            far = protocolresult.get_forward_and_reverse_energy_analysis()

        assert isinstance(far, dict)
        for leg in ("solvent", "complex"):
            assert isinstance(far[leg], list)

        assert far["solvent"][0] is None

        complex_keys = list(far["complex"][0].keys())
        expected_series = (
            "fractions",
            "forward_DGs",
            "forward_dDGs",
            "reverse_DGs",
            "reverse_dDGs",
        )
        for series in expected_series:
            assert series in complex_keys
            assert len(far["complex"][0][series]) == 10

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_get_overlap_matrices(self, key, protocolresult):
        """One square overlap matrix per leg: 27 states for solvent, 19 for complex."""
        ovp = protocolresult.get_overlap_matrices()

        assert isinstance(ovp, dict)
        assert isinstance(ovp[key], list)
        assert len(ovp[key]) == 1

        first_overlap = ovp[key][0]
        assert isinstance(first_overlap["matrix"], np.ndarray)

        lambda_nr = 27 if key == "solvent" else 19
        assert first_overlap["matrix"].shape == (lambda_nr, lambda_nr)

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_get_replica_transition_statistics(self, key, protocolresult):
        """One set of eigenvalues and transition matrix per leg, sized by state count."""
        rpx = protocolresult.get_replica_transition_statistics()
        lambda_nr = 27 if key == "solvent" else 19

        assert isinstance(rpx, dict)
        assert isinstance(rpx[key], list)
        assert len(rpx[key]) == 1

        first_stats = rpx[key][0]
        assert "eigenvalues" in first_stats
        assert "matrix" in first_stats
        assert first_stats["eigenvalues"].shape == (lambda_nr,)
        assert first_stats["matrix"].shape == (lambda_nr, lambda_nr)

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_equilibration_iterations(self, key, protocolresult):
        """Each leg reports a single float equilibration iteration count."""
        eq = protocolresult.equilibration_iterations()

        assert isinstance(eq, dict)
        assert isinstance(eq[key], list)
        assert len(eq[key]) == 1
        assert all(isinstance(v, float) for v in eq[key])

    @pytest.mark.parametrize("key", ["solvent", "complex"])
    def test_production_iterations(self, key, protocolresult):
        """Each leg reports a single float production iteration count."""
        prod = protocolresult.production_iterations()

        assert isinstance(prod, dict)
        assert isinstance(prod[key], list)
        assert len(prod[key]) == 1
        assert all(isinstance(v, float) for v in prod[key])

    @pytest.mark.parametrize(
        "key, expected_size",
        [
            ("solvent", 87),
            ("complex", 1868),
        ],
    )
    def test_selection_indices(self, key, protocolresult, expected_size):
        """Selection indices are numpy arrays of the expected length for each leg."""
        indices = protocolresult.selection_indices()

        assert isinstance(indices, dict)
        assert isinstance(indices[key], list)
        for leg_indices in indices[key]:
            assert isinstance(leg_indices, np.ndarray)
            assert len(leg_indices) == expected_size

    def test_filenotfound_replica_states(self, protocolresult):
        """``get_replica_states`` raises ``ValueError`` with the file-not-found message."""
        with pytest.raises(ValueError, match="File could not be found"):
            protocolresult.get_replica_states()

    def test_restraint_geometry(self, protocolresult):
        """The first stored restraint geometry is a Boresch geometry near known values."""
        geom = protocolresult.restraint_geometries()

        assert isinstance(geom, tuple)
        assert len(geom) == 2
        assert isinstance(geom[0], list)

        boresch = geom[0][0]
        assert isinstance(boresch, BoreschRestraintGeometry)
        assert boresch.guest_atoms == [1779, 1778, 1777]
        assert boresch.host_atoms == [802, 801, 800]

        # Values are compared with an absolute tolerance of 0.01.
        expected_values = {
            "r_aA0": 0.75 * offunit.nanometer,
            "theta_A0": 1.95 * offunit.radian,
            "theta_B0": 1.33 * offunit.radian,
            "phi_A0": 1.01 * offunit.radian,
            "phi_B0": -1.24 * offunit.radian,
            "phi_C0": -1.08 * offunit.radian,
        }
        for attribute, expected in expected_values.items():
            assert pytest.approx(getattr(boresch, attribute), abs=0.01) == expected

View File

@@ -277,10 +277,10 @@ def test_openmm_run_engine(
assert unit_shared.exists()
assert pathlib.Path(unit_shared).is_dir()
if "SepTopComplexRunUnit" in pur.source_key.split("-") or "SepTopSolventRunUnit" in pur.source_key.split("-"): # fmt: skip
checkpoint = pur.outputs["last_checkpoint"]
assert checkpoint == f"{pur.outputs['simtype']}_checkpoint.nc"
checkpoint = pur.outputs["checkpoint"]
assert checkpoint == unit_shared / f"{pur.outputs['simtype']}_checkpoint.nc"
assert (unit_shared / checkpoint).exists()
nc = pur.outputs["nc"]
nc = pur.outputs["trajectory"]
assert nc == unit_shared / f"{pur.outputs['simtype']}.nc"
assert nc.exists()

View File

@@ -38,6 +38,13 @@ def solvent_run_protocol_unit(protocol_units):
return pu
@pytest.fixture
def solvent_analysis_protocol_unit(protocol_units):
    """Return the first ``SepTopSolventAnalysisUnit`` in ``protocol_units`` (``None`` if absent)."""
    matches = (
        pu for pu in protocol_units if isinstance(pu, openmm_septop.SepTopSolventAnalysisUnit)
    )
    return next(matches, None)
@pytest.fixture
def complex_setup_protocol_unit(protocol_units):
for pu in protocol_units:
@@ -52,6 +59,13 @@ def complex_run_protocol_unit(protocol_units):
return pu
@pytest.fixture
def complex_analysis_protocol_unit(protocol_units):
    """Return the first ``SepTopComplexAnalysisUnit`` in ``protocol_units`` (``None`` if absent)."""
    matches = (
        pu for pu in protocol_units if isinstance(pu, openmm_septop.SepTopComplexAnalysisUnit)
    )
    return next(matches, None)
@pytest.fixture
def protocol_result(septop_json):
d = json.loads(
@@ -115,6 +129,23 @@ class TestSepTopSolventRunUnit(GufeTokenizableTestsMixin):
assert self.repr in repr(instance)
class TestSepTopSolventAnalysisUnit(GufeTokenizableTestsMixin):
    """Tokenizable-mixin tests for the solvent-leg SepTop analysis unit."""

    cls = openmm_septop.SepTopSolventAnalysisUnit
    # Prefix expected inside the unit's repr (checked by substring, not equality).
    repr = (
        "SepTopSolventAnalysisUnit(SepTop RBFE Analysis, transformation benzene to toluene, solvent leg"
    )
    key = None

    @pytest.fixture()
    def instance(self, solvent_analysis_protocol_unit):
        """Provide the solvent analysis unit as the instance under test."""
        return solvent_analysis_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        assert isinstance(repr(instance), str)
        expected_fragment = self.repr
        assert expected_fragment in repr(instance)
class TestSepTopComplexSetupUnit(GufeTokenizableTestsMixin):
cls = openmm_septop.SepTopComplexSetupUnit
repr = (
@@ -151,6 +182,23 @@ class TestSepTopComplexRunUnit(GufeTokenizableTestsMixin):
assert self.repr in repr(instance)
class TestSepTopComplexAnalysisUnit(GufeTokenizableTestsMixin):
    """Tokenizable-mixin tests for the complex-leg SepTop analysis unit."""

    cls = openmm_septop.SepTopComplexAnalysisUnit
    # Prefix expected inside the unit's repr (checked by substring, not equality).
    repr = (
        "SepTopComplexAnalysisUnit(SepTop RBFE Analysis, transformation benzene to toluene, complex leg"
    )
    key = None

    @pytest.fixture()
    def instance(self, complex_analysis_protocol_unit):
        """Provide the complex analysis unit as the instance under test."""
        return complex_analysis_protocol_unit

    def test_repr(self, instance):
        """
        Overwrites the base `test_repr` call.
        """
        assert isinstance(repr(instance), str)
        expected_fragment = self.repr
        assert expected_fragment in repr(instance)
class TestSepTopProtocolResult(GufeTokenizableTestsMixin):
cls = openmm_septop.SepTopProtocolResult
key = None

View File

@@ -0,0 +1,30 @@
# This code is part of OpenFE and is licensed under the MIT license.
# For details, see https://github.com/OpenFreeEnergy/openfe
from openfe.protocols.openmm_septop import (
SepTopComplexAnalysisUnit,
SepTopComplexRunUnit,
SepTopComplexSetupUnit,
SepTopSolventAnalysisUnit,
SepTopSolventRunUnit,
SepTopSolventSetupUnit,
)
# Lookup of the SepTop unit class for each leg ("solvent" / "complex") and
# stage ("setup" / "sim" / "analysis").
UNIT_TYPES = {
    "solvent": dict(
        setup=SepTopSolventSetupUnit,
        sim=SepTopSolventRunUnit,
        analysis=SepTopSolventAnalysisUnit,
    ),
    "complex": dict(
        setup=SepTopComplexSetupUnit,
        sim=SepTopComplexRunUnit,
        analysis=SepTopComplexAnalysisUnit,
    ),
}
def _get_units(protocol_units, unit_type):
"""
Helper method to extract setup units.
"""
return [pu for pu in protocol_units if isinstance(pu, unit_type)]

View File

@@ -97,7 +97,23 @@ def _get_legs_from_result_jsons(
for k in result["unit_results"].keys()
if k.startswith("ProtocolUnitResult")
] # fmt: skip
# In openfe v1.11+, we only want to pick up results from
# the Analysis Unit. To ensure backwards compatibility,
# we check if there are any analysis units. If so,
# we set a flag and later exclude Setup and Run.
has_analysis_units = any(
["Analysis" in result["unit_results"][p]["source_key"] for p in proto_key]
)
for p in proto_key:
# Skip Setup and Run units when Analysis units are present
if has_analysis_units and (
"Setup" in result["unit_results"][p]["source_key"]
or "Run" in result["unit_results"][p]["source_key"]
):
continue
if "unit_estimate" in result["unit_results"][p]["outputs"]:
simtype = result["unit_results"][p]["outputs"]["simtype"]
dg = result["unit_results"][p]["outputs"]["unit_estimate"]
@@ -135,8 +151,13 @@ def _get_names(result: dict) -> tuple[str, str]:
solvent_data = list(result["protocol_result"]["data"]["solvent"].values())[0][0]
name_A = solvent_data["inputs"]["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
name_B = solvent_data["inputs"]["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"]
try:
setup_data = solvent_data["inputs"]["setup"]["inputs"]
except KeyError:
setup_data = solvent_data["inputs"]
name_A = setup_data["alchemical_components"]["stateA"][0]["molprops"]["ofe-name"]
name_B = setup_data["alchemical_components"]["stateB"][0]["molprops"]["ofe-name"]
return str(name_A), str(name_B)