Initial test at what we may need to change to support charge changes for explicitly solvated systems

This commit is contained in:
hannahbaumann
2026-05-18 16:50:08 +02:00
parent cb10892e79
commit 403d38e02f
4 changed files with 144 additions and 15 deletions

View File

@@ -10,6 +10,7 @@
import itertools
import logging
import warnings
from collections import Counter
from copy import deepcopy
from typing import Optional, Union
@@ -123,10 +124,14 @@ def _fix_alchemical_water_atom_mapping(
def handle_alchemical_waters(
water_resids: list[int], topology: app.Topology,
system: System, system_mapping: dict,
water_resids: list[int],
topology: app.Topology,
system: System,
system_mapping: dict,
charge_difference: int,
solvent_component: SolventComponent,
positive_ion_resname: str,
negative_ion_resname: str,
water_resname: str,
):
"""
Add alchemical waters from a pre-defined list.
@@ -150,7 +155,7 @@ def handle_alchemical_waters(
The name of a negative ion to replace the water with if the absolute
charge difference is negative.
water_resname : str
The residue name of the water to get parameters for. Default 'HOH'.
The residue name of the water to get parameters for.
Raises
------
@@ -172,16 +177,16 @@ def handle_alchemical_waters(
raise ValueError(errmsg)
if charge_difference > 0:
ion_resname = solvent_component.positive_ion.strip('-+').upper()
ion_resname = positive_ion_resname
elif charge_difference < 0:
ion_resname = solvent_component.negative_ion.strip('-+').upper()
ion_resname = negative_ion_resname
# if there's no charge difference then just skip altogether
else:
return None
ion_charge, ion_sigma, ion_epsilon, o_charge, h_charge = _get_ion_and_water_parameters(
topology, system, ion_resname,
'HOH', # Modeller always adds HOH waters
water_resname,
)
# get the nonbonded forces
@@ -433,7 +438,6 @@ def _remove_constraints(old_to_new_atom_map, old_system, old_topology,
* Very slow, needs refactoring
* Can we drop having topologies as inputs here?
"""
from collections import Counter
no_const_old_to_new_atom_map = deepcopy(old_to_new_atom_map)
@@ -720,3 +724,60 @@ def set_and_check_new_positions(mapping, old_topology, new_topology,
logging.warning(wmsg)
return new_pos_array * omm_unit.angstrom
def _get_ion_resnames_from_topology(topology: app.Topology) -> tuple[str, str]:
"""
Infer positive and negative ion residue names from a topology by
finding the most common monovalent ion of each charge type.
Falls back to NA/CL if none are found.
Parameters
----------
topology : app.Topology
The topology to search for ions.
Returns
-------
pos_ion : str
The residue name of the most abundant positive monovalent ion (Na, K).
neg_ion : str
The residue name of the most abundant negative monovalent ion (Cl).
"""
known_positive = {'NA', 'K'}
# This doesn't make much sense yet to check it, since it's only Cl, but
# leaving it here for now so we can add other neg ions.
known_negative = {'CL'}
pos_counts = Counter(
r.name for r in topology.residues() if r.name in known_positive
)
neg_counts = Counter(
r.name for r in topology.residues() if r.name in known_negative
)
if not pos_counts:
wmsg = (
"Could not find any known positive monovalent ions "
f"(searched for {known_positive}) in the topology. "
"Defaulting to NA for explicit charge correction."
)
warnings.warn(wmsg)
logger.warning(wmsg)
pos_ion = 'NA'
else:
pos_ion = max(pos_counts, key=pos_counts.get)
if not neg_counts:
wmsg = (
"Could not find any known negative monovalent ions "
f"(searched for {known_negative}) in the topology. "
"Defaulting to CL for explicit charge correction."
)
warnings.warn(wmsg)
logger.warning(wmsg)
neg_ion = 'CL'
else:
neg_ion = max(neg_counts, key=neg_counts.get)
return pos_ion, neg_ion

View File

@@ -28,6 +28,7 @@ from gufe import (
ProteinComponent,
ProteinMembraneComponent,
SmallMoleculeComponent,
SolvatedPDBComponent,
SolventComponent,
settings,
)
@@ -50,6 +51,7 @@ from .equil_rfe_settings import (
OpenMMSolvationSettings,
RelativeHybridTopologyProtocolSettings,
)
from . import _rfe_utils
from .hybridtop_protocol_results import RelativeHybridTopologyProtocolResult
from .hybridtop_units import (
HybridTopologyMultiStateAnalysisUnit,
@@ -435,7 +437,20 @@ class RelativeHybridTopologyProtocol(gufe.Protocol):
)
raise ValueError(errmsg)
ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[difference]
# resolve ion names from SolventComponent or topology
if isinstance(solvent_component, SolventComponent):
positive_ion = solvent_component.positive_ion.strip(
'-+').upper()
negative_ion = solvent_component.negative_ion.strip(
'-+').upper()
elif isinstance(solvent_component, SolvatedPDBComponent):
positive_ion, negative_ion = (
_rfe_utils.topologyhelpers._get_ion_resnames_from_topology(
solvent_component.to_openmm_topology()
)
)
ion = {-1: positive_ion, 1: negative_ion}[difference]
wmsg = (
f"A charge difference of {difference} is observed "

View File

@@ -410,6 +410,17 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
if charge_difference == 0:
return
# Resolve ion names from the solvent component if available,
# otherwise infer from the topology (for explicitly solvated systems).
if isinstance(solvent_component, SolventComponent):
positive_ion = solvent_component.positive_ion.strip('-+').upper()
negative_ion = solvent_component.negative_ion.strip('-+').upper()
else:
positive_ion, negative_ion = (
_rfe_utils.topologyhelpers._get_ion_resnames_from_topology(
stateA_topology)
)
# Get the residue ids for waters to turn alchemical
alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters(
topology=stateA_topology,
@@ -425,7 +436,9 @@ class HybridTopologySetupUnit(gufe.ProtocolUnit, HybridTopologyUnitMixin):
system=stateB_system,
system_mapping=system_mappings,
charge_difference=charge_difference,
solvent_component=solvent_component,
positive_ion_resname=positive_ion,
negative_ion_resname=negative_ion,
water_resname='HOH' # Need to check if this is also true for systems not prepped with addSolvent
)
def _get_omm_objects(

View File

@@ -1215,6 +1215,36 @@ def test_dry_run_membrane_complex(
)
def test_validate_charge_difference_membrane_system(
a2a_protein_membrane_component,
a2a_ligands,
):
ligA = next(c for c in a2a_ligands if c.name == "4g")
ligB = next(c for c in a2a_ligands if c.name == "4h")
mapping = openfe.LigandAtomMapping(
componentA=ligA,
componentB=ligB,
componentA_to_componentB={i: i for i in range(36)},
)
settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings()
# Mock get_alchemical_charge_difference to return non-zero
with mock.patch.object(
mapping,
"get_alchemical_charge_difference",
return_value=1,
):
protocol = openmm_rfe.RelativeHybridTopologyProtocol(settings=settings)
protocol._validate_charge_difference(
mapping=mapping,
nonbonded_method=settings.forcefield_settings.nonbonded_method,
explicit_charge_correction=True,
solvent_component=a2a_protein_membrane_component,
)
def test_lambda_schedule_default():
lambdas = openmm_rfe._rfe_utils.lambdaprotocol.LambdaProtocol(functions="default")
assert len(lambdas.lambda_schedule) == 10
@@ -1980,7 +2010,9 @@ def test_handle_alchemwats_incorrect_count(
system=system,
system_mapping={},
charge_difference=1,
solvent_component=openfe.SolventComponent(),
positive_ion_resname='NA',
negative_ion_resname='CL',
water_resname='HOH',
)
@@ -2004,7 +2036,9 @@ def test_handle_alchemwats_too_many_nbf(
system=new_system,
system_mapping={},
charge_difference=1,
solvent_component=openfe.SolventComponent(),
positive_ion_resname='NA',
negative_ion_resname='CL',
water_resname='HOH',
)
@@ -2026,7 +2060,9 @@ def test_handle_alchemwats_vsite_water(
system=system,
system_mapping={},
charge_difference=1,
solvent_component=openfe.SolventComponent(),
positive_ion_resname='NA',
negative_ion_resname='CL',
water_resname='HOH',
)
@@ -2054,7 +2090,9 @@ def test_handle_alchemwats_incorrect_atom(
system=new_system,
system_mapping=benzene_self_system_mapping,
charge_difference=1,
solvent_component=openfe.SolventComponent(),
positive_ion_resname='NA',
negative_ion_resname='CL',
water_resname='HOH',
)
@@ -2073,7 +2111,9 @@ def test_handle_alchemical_wats(
system=system,
system_mapping=benzene_self_system_mapping,
charge_difference=1,
solvent_component=openfe.SolventComponent(),
positive_ion_resname='NA',
negative_ion_resname='CL',
water_resname='HOH',
)
# check the mappings