update the ghostly application function

This commit is contained in:
Josh Horton
2026-02-04 16:32:31 +00:00
parent 36b623d918
commit 603405728a

View File

@@ -8,6 +8,7 @@
# fmt: off
import ast
from collections import defaultdict
import itertools
import logging
import warnings
@@ -727,6 +728,72 @@ def set_and_check_new_positions(mapping, old_topology, new_topology,
return new_pos_array * omm_unit.angstrom
def _create_bond_lookup(htf: HybridTopologyFactory) -> dict:
"""
Create a lookup dictionary of bonded atoms in the hybrid topology.
Parameters
----------
htf : DevelopmentHybridTopologyFactory
The hybrid topology factory containing the hybrid topology.
Returns
-------
dict
A dictionary where keys are atom indices and values are sets of bonded atom indices.
"""
# get all the bonds in the hybrid topology so we can check for an improper
hybrid_bonds = list(htf.omm_hybrid_topology.bonds())
# construct a lookup of bonded atoms
bonded_atom_lookup = defaultdict(set)
for bond in hybrid_bonds:
a1 = bond[0].index
a2 = bond[1].index
bonded_atom_lookup[a1].add(a2)
bonded_atom_lookup[a2].add(a1)
return bonded_atom_lookup
def _copy_hybrid_system(htf: HybridTopologyFactory) -> openmm.System:
"""
Create a deep copy of the given HTF which is ready to be modified by scaling or ghostly corrections.
Parameters
----------
htf : HybridTopologyFactory
The hybrid topology factory with the hybrid openmm system to copy.
Returns
-------
openmm.System
A minimal copy of the hybrid system ready for modification.
Notes
-----
Things copied:
- hybrid system particles
- hybrid system constraints
- barostat if present and box vectors
"""
new_hybrid_system = openmm.System()
# add all the particles
for i in range(htf.hybrid_system.getNumParticles()):
new_hybrid_system.addParticle(htf.hybrid_system.getParticleMass(i))
# add all constraints
for i in range(htf.hybrid_system.getNumConstraints()):
p1, p2, dist = htf.hybrid_system.getConstraintParameters(i)
new_hybrid_system.addConstraint(p1, p2, dist)
# add the barostat and the box vectors
for force in htf.hybrid_system.getForces():
if isinstance(force, openmm.MonteCarloBarostat) or isinstance(force, openmm.MonteCarloMembraneBarostat):
new_hybrid_system.addForce(deepcopy(force))
# also add the box vectors
box_vectors = htf.hybrid_system.getDefaultPeriodicBoxVectors()
new_hybrid_system.setDefaultPeriodicBoxVectors(*box_vectors)
break
return new_hybrid_system
def _scale_angles_and_torsions(htf: HybridTopologyFactory, scale_factor: float = 0.1, scale_angles: bool = True) -> HybridTopologyFactory:
"""
Scale all angles and torsion force constants in the dummy-core junction by the given scale factor.
@@ -756,20 +823,7 @@ def _scale_angles_and_torsions(htf: HybridTopologyFactory, scale_factor: float =
# add all the particles
logger.info("Copying particles and constraints to new hybrid system.")
print("Copying particles and constraints to new hybrid system.")
for i in range(htf.hybrid_system.getNumParticles()):
softened_hybrid_system.addParticle(htf.hybrid_system.getParticleMass(i))
# add all constraints
for i in range(htf.hybrid_system.getNumConstraints()):
p1, p2, dist = htf.hybrid_system.getConstraintParameters(i)
softened_hybrid_system.addConstraint(p1, p2, dist)
# add the barostat and the box vectors
for force in htf.hybrid_system.getForces():
if isinstance(force, openmm.MonteCarloBarostat) or isinstance(force, openmm.MonteCarloMembraneBarostat):
softened_hybrid_system.addForce(deepcopy(force))
# also add the box vectors
box_vectors = htf.hybrid_system.getDefaultPeriodicBoxVectors()
softened_hybrid_system.setDefaultPeriodicBoxVectors(*box_vectors)
break
softened_hybrid_system = _copy_hybrid_system(htf=htf)
hybrid_forces = htf._hybrid_system_forces
# copy all forces which do not need to be modified
@@ -871,7 +925,13 @@ def _scale_angles_and_torsions(htf: HybridTopologyFactory, scale_factor: float =
def _load_ghostly_corrections(ghostly_output_string: str) -> dict:
"""Parse the Ghostly modification output JSON string and return the corrections as a dictionary."""
"""
Parse the Ghostly modification output JSON string and return the corrections as a dictionary.
Notes
-----
- The corrections are returned in a dictionary with the same structure as the Ghostly output but we use sets for the removed and stiffened angles/dihedrals for faster lookup.
"""
corrections = json.loads(ghostly_output_string)
# convert the string keys back to tuples
@@ -879,10 +939,10 @@ def _load_ghostly_corrections(ghostly_output_string: str) -> dict:
for correction_type in corrections[lambda_key].keys():
# these are stings for some reason, convert them back to tuples
if correction_type in ["removed_angles", "removed_dihedrals"]:
corrections[lambda_key][correction_type] = [ast.literal_eval(tup_str) for tup_str in corrections[lambda_key][correction_type]]
corrections[lambda_key][correction_type] = set([ast.literal_eval(tup_str) for tup_str in corrections[lambda_key][correction_type]])
# this is a list so we need to convert each to a tuple
elif correction_type == "stiffened_angles":
corrections[lambda_key][correction_type] = [tuple(angle) for angle in corrections[lambda_key][correction_type]]
corrections[lambda_key][correction_type] = set([tuple(angle) for angle in corrections[lambda_key][correction_type]])
elif correction_type == "softened_angles":
new_dict = {}
for tup_str, params in corrections[lambda_key][correction_type].items():
@@ -891,6 +951,7 @@ def _load_ghostly_corrections(ghostly_output_string: str) -> dict:
corrections[lambda_key][correction_type] = new_dict
return corrections
def _shift_ghostly_correction_indices(htf: HybridTopologyFactory, corrections: dict) -> dict:
"""Shift the atom indices in the ghostly corrections to account for the solvent environment in the HTF."""
core_atoms = htf._atom_classes["core_atoms"]
@@ -914,7 +975,7 @@ def _shift_ghostly_correction_indices(htf: HybridTopologyFactory, corrections: d
for lambda_key in corrections.keys():
for correction_type in corrections[lambda_key].keys():
if correction_type in ["removed_angles", "removed_dihedrals", "stiffened_angles"]:
shifted_list = []
shifted_list = set()
for angle in corrections[lambda_key][correction_type]:
shifted_angle = tuple(
atom_idx + (last_water_atom - max_state_a_atom)
@@ -923,7 +984,7 @@ def _shift_ghostly_correction_indices(htf: HybridTopologyFactory, corrections: d
if atom_idx not in core_atoms and atom_idx > max_state_a_atom else atom_idx
for atom_idx in angle
)
shifted_list.append(shifted_angle)
shifted_list.add(shifted_angle)
shifted_corrections.setdefault(lambda_key, {})[correction_type] = shifted_list
elif correction_type == "softened_angles":
new_dict = {}
@@ -941,28 +1002,29 @@ def _shift_ghostly_correction_indices(htf: HybridTopologyFactory, corrections: d
def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) -> HybridTopologyFactory:
"""Apply the ghostly corrections parsed from the output file to the HTF."""
"""
Apply the ghostly corrections parsed from the output file to the HTF.
Notes
-----
- The HTF is edited inplace due to issues with deepcopying the HTF object.
- The method will track which corrections were applied and compare them to the supplied corrections.
- The method will check that a correction is applied to all junctions involving dummy atoms identified using an internal method.
Raises
------
AssertionError
If a parameter is changed by ghostly but we can not determine what type of correction it was.
ValueError
If a correction provided by ghostly is not applied to the HTF.
"""
logger.info("Applying ghostly corrections to hybrid system.")
dummy_old_atoms = htf._atom_classes["unique_old_atoms"]
dummy_new_atoms = htf._atom_classes["unique_new_atoms"]
corrections = _shift_ghostly_correction_indices(htf, corrections)
new_hybrid_system = openmm.System()
# add all the particles
for i in range(htf.hybrid_system.getNumParticles()):
new_hybrid_system.addParticle(htf.hybrid_system.getParticleMass(i))
# add all constraints
for i in range(htf.hybrid_system.getNumConstraints()):
p1, p2, dist = htf.hybrid_system.getConstraintParameters(i)
new_hybrid_system.addConstraint(p1, p2, dist)
# add the barostat and the box vectors
for force in htf.hybrid_system.getForces():
if isinstance(force, openmm.MonteCarloBarostat) or isinstance(force, openmm.MonteCarloMembraneBarostat):
new_hybrid_system.addForce(deepcopy(force))
# also add the box vectors
box_vectors = htf.hybrid_system.getDefaultPeriodicBoxVectors()
new_hybrid_system.setDefaultPeriodicBoxVectors(*box_vectors)
break
new_hybrid_system = _copy_hybrid_system(htf=htf)
hybrid_forces = htf._hybrid_system_forces
# copy all forces which do not need to be modified
@@ -984,42 +1046,67 @@ def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) ->
# get a quick lookup of the forces
new_hybrid_forces = {force.getName(): force for force in new_hybrid_system.getForces()}
# track the applied corrections
applied_corrections = {
"lambda_0": {"removed_angles": set(), "stiffened_angles": set(), "softened_angles": set(), "removed_dihedrals": set()},
"lambda_1": {"removed_angles": set(), "stiffened_angles": set(), "softened_angles": set(), "removed_dihedrals": set()}
}
# process angles
custom_angle_force = new_hybrid_forces["CustomAngleForce"]
old_hybrid_angle_force = hybrid_forces["standard_angle_force"]
for i in range(old_hybrid_angle_force.getNumAngles()):
p1, p2, p3, theta_eq, k = old_hybrid_angle_force.getAngleParameters(i)
# set up the angle parameters for stiffening and zeroing
ZERO_K = 0.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
STIFF_K = 100.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
STIFF_THETA = 0.5 * math.pi * omm_unit.radian
for angle_idx in range(old_hybrid_angle_force.getNumAngles()):
p1, p2, p3, theta_eq, k = old_hybrid_angle_force.getAngleParameters(angle_idx)
# check if we have one ghost atom for this angle
angle = (p1, p2, p3)
print(angle)
if 1<= len(dummy_old_atoms.intersection(angle)) < 3 or 1<= len(dummy_new_atoms.intersection(angle)) < 3:
if 1 <= len(dummy_old_atoms.intersection(angle)) < 3 or 1<= len(dummy_new_atoms.intersection(angle)) < 3:
angle_reversed = (p3, p2, p1)
# set up containers for the end state values
lambda_0_k = k
lambda_0_theta_eq = theta_eq
lambda_1_k = k
lambda_1_theta_eq = theta_eq
end_state, correction_type = None, None
# check for removed angles
if angle in corrections["lambda_0"]["removed_angles"] or angle_reversed in corrections["lambda_0"]["removed_angles"]:
lambda_0_k = 0.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
elif angle in corrections["lambda_1"]["removed_angles"] or angle_reversed in corrections["lambda_1"]["removed_angles"]:
lambda_1_k = 0.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
if (prob_angle:= angle) in corrections["lambda_0"]["removed_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_0"]["removed_angles"]:
lambda_0_k = ZERO_K
end_state = 0
correction_type = "removed_angles"
elif (prob_angle:= angle) in corrections["lambda_1"]["removed_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_1"]["removed_angles"]:
lambda_1_k = ZERO_K
end_state = 1
correction_type = "removed_angles"
# check for stiffened angles
elif angle in corrections["lambda_0"]["stiffened_angles"] or angle_reversed in corrections["lambda_0"]["stiffened_angles"]:
lambda_0_k = 100.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2) # default stiffening k value
lambda_0_theta_eq = 0.5 * math.pi * omm_unit.radian # 90 degrees
elif angle in corrections["lambda_1"]["stiffened_angles"] or angle_reversed in corrections["lambda_1"]["stiffened_angles"]:
lambda_1_k = 100.0 * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2) # default stiffening k value
lambda_1_theta_eq = 0.5 * math.pi * omm_unit.radian # 90 degrees
elif (prob_angle:= angle) in corrections["lambda_0"]["stiffened_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_0"]["stiffened_angles"]:
lambda_0_k = STIFF_K # default stiffening k value
lambda_0_theta_eq = STIFF_THETA # 90 degrees
end_state = 0
correction_type = "stiffened_angles"
elif (prob_angle:= angle) in corrections["lambda_1"]["stiffened_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_1"]["stiffened_angles"]:
lambda_1_k = STIFF_K # default stiffening k value
lambda_1_theta_eq = STIFF_THETA # 90 degrees
end_state = 1
correction_type = "stiffened_angles"
# check for softened angles
elif (prob_angle:= angle) in corrections["lambda_0"]["softened_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_0"]["softened_angles"]:
soften_params = corrections["lambda_0"]["softened_angles"][prob_angle]
lambda_0_k = soften_params["k"] * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
lambda_0_theta_eq = soften_params["theta0"] * omm_unit.radian
end_state = 0
correction_type = "softened_angles"
elif (prob_angle:= angle) in corrections["lambda_1"]["softened_angles"] or (prob_angle:= angle_reversed) in corrections["lambda_1"]["softened_angles"]:
soften_params = corrections["lambda_1"]["softened_angles"][prob_angle]
lambda_1_k = soften_params["k"] * omm_unit.kilocalories_per_mole / (omm_unit.radian ** 2)
lambda_1_theta_eq = soften_params["theta0"] * omm_unit.radian
end_state = 1
correction_type = "softened_angles"
# some angles involving dummy atoms need to be kept to ensure 3 redundant connections
if lambda_0_k != lambda_1_k or lambda_0_theta_eq != lambda_1_theta_eq:
@@ -1033,6 +1120,10 @@ def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) ->
custom_angle_force.addAngle(p1, p2, p3,
[lambda_0_theta_eq, lambda_0_k,
lambda_1_theta_eq, lambda_1_k])
# log this as a correction applied
assert correction_type is not None, "Correction type should not be None if k or theta_eq differ!"
applied_corrections[f"lambda_{end_state}"][correction_type].add(prob_angle)
else:
# both k and theta_eq values are the same, just add to the standard angle force
new_harmonic_angle_force.addAngle(p1, p2, p3, theta_eq, k)
@@ -1044,20 +1135,45 @@ def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) ->
# process torsions
custom_torsion_force = new_hybrid_forces["CustomTorsionForce"]
old_hybrid_torsion_force = hybrid_forces["unique_atom_torsion_force"]
for i in range(old_hybrid_torsion_force.getNumTorsions()):
p1, p2, p3, p4, periodicity, phase, k = old_hybrid_torsion_force.getTorsionParameters(i)
# set up the torsion parameters for zeroing
TORSION_ZERO_K = 0.0 * omm_unit.kilocalories_per_mole
# get all the bonds in the hybrid topology so we can check for an improper
bonded_atom_lookup = _create_bond_lookup(htf=htf)
for torsion_idx in range(old_hybrid_torsion_force.getNumTorsions()):
p1, p2, p3, p4, periodicity, phase, k = old_hybrid_torsion_force.getTorsionParameters(torsion_idx)
# check if we have one ghost atom for this torsion
torsion = (p1, p2, p3, p4)
if 1<= len(dummy_old_atoms.intersection(torsion)) < 4 or 1<= len(dummy_new_atoms.intersection(torsion)) < 4:
torsion_reversed = (p4, p3, p2, p1)
# check if we have an improper torsion (central atoms bonded)
if not (p1 in bonded_atom_lookup[p2] and p2 in bonded_atom_lookup[p3] and p3 in bonded_atom_lookup[p4]):
# this is an improper with a dummy atom and should be skipped
# generate all permutations of the other atoms to check for removal
central_atom = p1
other_atoms = [p2, p3, p4]
torsion_variants = set()
for perm in itertools.permutations(other_atoms):
torsion_variants.add((central_atom, perm[0], perm[1], perm[2]))
# add the reverse as well as ghostly may list either
torsion_variants.add((perm[2], perm[1], perm[0], central_atom))
else:
torsion_variants = {torsion, torsion_reversed}
# set up containers for the end state values
lambda_0_k = k
lambda_1_k = k
end_state = None
# check for removed dihedrals
if torsion in corrections["lambda_0"]["removed_dihedrals"] or torsion_reversed in corrections["lambda_0"]["removed_dihedrals"]:
lambda_0_k = 0.0 * omm_unit.kilocalories_per_mole
elif torsion in corrections["lambda_1"]["removed_dihedrals"] or torsion_reversed in corrections["lambda_1"]["removed_dihedrals"]:
lambda_1_k = 0.0 * omm_unit.kilocalories_per_mole
if matched:= corrections["lambda_0"]["removed_dihedrals"].intersection(torsion_variants):
lambda_0_k = TORSION_ZERO_K
end_state = 0
elif matched:= corrections["lambda_1"]["removed_dihedrals"].intersection(torsion_variants):
lambda_1_k = TORSION_ZERO_K
end_state = 1
# some dihedrals involving ghost atoms need to be kept to ensure 3 redundant connections
if lambda_0_k != lambda_1_k:
# add the term to the interpolated custom torsion force
@@ -1071,6 +1187,9 @@ def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) ->
[periodicity, phase,
lambda_0_k, periodicity,
phase, lambda_1_k])
# log this as a correction applied
assert end_state is not None, "End state should not be None if k values differ!"
applied_corrections[f"lambda_{end_state}"]["removed_dihedrals"].update(matched)
else:
# both k values are the same, just add to the standard torsion force
new_torsion_force.addTorsion(p1, p2, p3, p4, periodicity, phase, k)
@@ -1078,6 +1197,33 @@ def _apply_ghostly_corrections(htf: HybridTopologyFactory, corrections: dict) ->
# the term does not involve any ghost atoms, so we can just copy it
new_torsion_force.addTorsion(p1, p2, p3, p4, periodicity, phase, k)
# compare the supplied and applied corrections
for lambda_key in corrections.keys():
for correction_type in corrections[lambda_key].keys():
supplied = set()
applied = applied_corrections[lambda_key][correction_type]
if correction_type in ["removed_angles", "stiffened_angles"]:
supplied = corrections[lambda_key][correction_type]
elif correction_type == "softened_angles":
supplied = set([tuple(tup) for tup in corrections[lambda_key][correction_type].keys()])
elif correction_type == "removed_dihedrals":
supplied = corrections[lambda_key][correction_type]
not_applied = supplied - applied
if len(not_applied) > 0 and correction_type == "removed_dihedrals":
# in some cases dihedrals are listed to be removed but are not present in the HTF these involve linear nitrile groups for example
# check if these missed dihedrals are this type
dummy_group = dummy_new_atoms if lambda_key == "lambda_1" else dummy_old_atoms
for dihedral in list(not_applied):
# if this is a linear group torsion then atom 2 or 3 will be bonded to only 2 other atoms
a2_bonds = bonded_atom_lookup[dihedral[1]] - dummy_group
a3_bonds = bonded_atom_lookup[dihedral[2]] - dummy_group
if len(a2_bonds) <= 2 or len(a3_bonds) <= 2:
not_applied.remove(dihedral)
if len(not_applied) > 0:
raise ValueError(f"The following {correction_type} corrections for {lambda_key} were not applied: {not_applied}")
htf._hybrid_system = new_hybrid_system
# set the hybrid system forces dict to the new one
htf._hybrid_system_forces = {force.getName(): force for force in new_hybrid_system.getForces()}