diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 11445c4c..a0680555 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,22 +1,32 @@ ci: - autofix_commit_msg: | - [pre-commit.ci] auto fixes from pre-commit.com hooks - - for more information, see https://pre-commit.ci - autofix_prs: true - autoupdate_branch: '' - autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' autoupdate_schedule: quarterly + # comment / label "pre-commit.ci autofix" to a pull request to manually trigger auto-fixing + autofix_prs: false skip: [] submodules: false + repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v6.0.0 hooks: - id: check-added-large-files + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-symlinks + - id: check-toml + # - id: check-yaml # TODO: resolve violation in devtools/installer/construct.yaml + - id: debug-statements - repo: https://github.com/tox-dev/pyproject-fmt - rev: "v2.7.0" + rev: "v2.8.0" hooks: - id: pyproject-fmt +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.13.3 + hooks: + # Run the linter. + - id: ruff + args: [--fix ] + # Run the formatter. + - id: ruff-format diff --git a/devtools/data/fix_rbfe_results.py b/devtools/data/fix_rbfe_results.py index ccc1beb2..e729cf48 100644 --- a/devtools/data/fix_rbfe_results.py +++ b/devtools/data/fix_rbfe_results.py @@ -4,6 +4,7 @@ Useful if Settings are ever changed in a backwards-incompatible way Will expect "rbfe_results.tar.gz" in this directory, will overwrite this file """ + from gufe.tokenization import JSON_HANDLER import glob import json @@ -20,38 +21,38 @@ def untar(fn): def retar(loc, name): """create tar.gz called *name* of directory *loc*""" - with tarfile.open(name, mode='w:gz') as f: + with tarfile.open(name, mode="w:gz") as f: f.add(loc, arcname=os.path.basename(loc)) def replace_settings(fn, new_settings): """replace settings instances in *fn* with *new_settings*""" - with open(fn, 'r') as f: + with open(fn, "r") as f: data = json.load(f) - for k in data['protocol_result']['data']: - data['protocol_result']['data'][k][0]['inputs']['settings'] = new_settings + for k in data["protocol_result"]["data"]: + data["protocol_result"]["data"][k][0]["inputs"]["settings"] = new_settings - for k in data['unit_results']: - data['unit_results'][k]['inputs']['settings'] = new_settings + for k in data["unit_results"]: + data["unit_results"][k]["inputs"]["settings"] = new_settings - with open(fn, 'w') as f: + with open(fn, "w") as f: json.dump(data, f, cls=JSON_HANDLER.encoder) def fix_rbfe_results(): - untar('rbfe_results.tar.gz') + untar("rbfe_results.tar.gz") # generate valid settings as defaults new_settings = openmm_rfe.RelativeHybridTopologyProtocol.default_settings() # walk over all result jsons - for fn in glob.glob('./results/*json'): + for fn in glob.glob("./results/*json"): # replace instances of settings within with valid settings replace_settings(fn, new_settings) - retar('results', 'rbfe_results.tar.gz') + retar("results", "rbfe_results.tar.gz") -if __name__ == '__main__': +if __name__ == "__main__": fix_rbfe_results() diff --git a/devtools/data/gen_serialized_results.py b/devtools/data/gen_serialized_results.py index cbae94a4..7a18e867 100644 --- a/devtools/data/gen_serialized_results.py +++ b/devtools/data/gen_serialized_results.py @@ -13,18 +13,15 @@ Generates - MDProtocol_json_results.gz - used in md_json fixture """ + import gzip import json import logging import pathlib from rdkit import Chem import tempfile -from openff.toolkit import ( - Molecule, RDKitToolkitWrapper, AmberToolsToolkitWrapper -) -from openff.toolkit.utils.toolkit_registry import ( - toolkit_registry_manager, ToolkitRegistry -) +from openff.toolkit import Molecule, RDKitToolkitWrapper, AmberToolsToolkitWrapper +from openff.toolkit.utils.toolkit_registry import toolkit_registry_manager, ToolkitRegistry from openff.units import unit from kartograf.atom_aligner import align_mol_shape from kartograf import KartografAtomMapper @@ -46,20 +43,18 @@ from openfecli.utils import configure_logger sys.stdout.reconfigure(line_buffering=True) stdout_handler = logging.StreamHandler(sys.stdout) -configure_logger('gufekey', handler=stdout_handler) -configure_logger('gufe', handler=stdout_handler) -configure_logger('openfe', handler=stdout_handler) -configure_logger('openmmtools.multistate.multistatereporter', level=logging.DEBUG, handler=stdout_handler) -configure_logger('openmmtools.multistate.multistatesampler', level=logging.DEBUG, handler=stdout_handler) +configure_logger("gufekey", handler=stdout_handler) +configure_logger("gufe", handler=stdout_handler) +configure_logger("openfe", handler=stdout_handler) +configure_logger("openmmtools.multistate.multistatereporter", level=logging.DEBUG, handler=stdout_handler) # fmt: skip +configure_logger("openmmtools.multistate.multistatesampler", level=logging.DEBUG, handler=stdout_handler) # fmt: skip logger = logging.getLogger(__name__) LIGA = "[H]C([H])([H])C([H])([H])C(=O)C([H])([H])C([H])([H])[H]" LIGB = "[H]C([H])([H])C(=O)C([H])([H])C([H])([H])C([H])([H])[H]" -amber_rdkit = ToolkitRegistry( - [RDKitToolkitWrapper(), AmberToolsToolkitWrapper()] -) +amber_rdkit = ToolkitRegistry([RDKitToolkitWrapper(), AmberToolsToolkitWrapper()]) def get_molecule(smi, name): @@ -71,12 +66,14 @@ def get_molecule(smi, name): def get_hif2a_inputs(): - with gzip.open('inputs/hif2a_protein.pdb.gz', 'r') as f: - protcomp = openfe.ProteinComponent.from_pdb_file(f, name='hif2a_prot') + with gzip.open("inputs/hif2a_protein.pdb.gz", "r") as f: + protcomp = openfe.ProteinComponent.from_pdb_file(f, name="hif2a_prot") - with gzip.open('inputs/hif2a_ligands.sdf.gz', 'r') as f: - smcs = [openfe.SmallMoleculeComponent(mol) for mol in - list(Chem.ForwardSDMolSupplier(f, removeHs=False))] + with gzip.open("inputs/hif2a_ligands.sdf.gz", "r") as f: + smcs = [ + openfe.SmallMoleculeComponent(mol) + for mol in list(Chem.ForwardSDMolSupplier(f, removeHs=False)) + ] return smcs, protcomp @@ -86,7 +83,7 @@ def execute_and_serialize( protocol, simname, new_serialization: bool = False -): +): # fmt: skip """ Execute & serialize a DAG @@ -127,9 +124,9 @@ def execute_and_serialize( unit.key: unit.to_keyed_dict() for unit in dagres.protocol_unit_results } - } + } # fmt: skip - with gzip.open(f"{simname}_json_results.gz", 'wt') as zipfile: + with gzip.open(f"{simname}_json_results.gz", "wt") as zipfile: json.dump(outdict, zipfile, cls=JSON_HANDLER.encoder) @@ -165,13 +162,13 @@ def generate_abfe_settings(): settings.complex_simulation_settings.equilibration_length = 100 * unit.picosecond settings.complex_simulation_settings.production_length = 500 * unit.picosecond settings.complex_simulation_settings.time_per_iteration = 2.5 * unit.ps - settings.solvent_solvation_settings.box_shape = 'dodecahedron' - settings.complex_solvation_settings.box_shape = 'dodecahedron' + settings.solvent_solvation_settings.box_shape = "dodecahedron" + settings.complex_solvation_settings.box_shape = "dodecahedron" settings.solvent_solvation_settings.solvent_padding = 1.5 * unit.nanometer settings.complex_solvation_settings.solvent_padding = 1.0 * unit.nanometer settings.forcefield_settings.nonbonded_cutoff = 0.8 * unit.nanometer settings.protocol_repeats = 3 - settings.engine_settings.compute_platform = 'CUDA' + settings.engine_settings.compute_platform = "CUDA" return settings @@ -210,29 +207,25 @@ def generate_ahfe_settings(): settings.vacuum_simulation_settings.production_length = 1000 * unit.picosecond settings.lambda_settings.lambda_elec = [0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0] + 1.0] # fmt: skip settings.lambda_settings.lambda_vdw = [0.0, 0.0, 0.0, 0.0, 0.0, 0.12, 0.24, 0.36, 0.48, 0.6, 0.7, 0.77, 0.85, - 1.0] + 1.0] # fmt: skip settings.protocol_repeats = 3 settings.solvent_simulation_settings.n_replicas = 14 settings.vacuum_simulation_settings.n_replicas = 14 - settings.solvent_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole - settings.vacuum_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole - settings.vacuum_engine_settings.compute_platform = 'CPU' - settings.solvent_engine_settings.compute_platform = 'CUDA' + settings.solvent_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole # fmt: skip + settings.vacuum_simulation_settings.early_termination_target_error = 0.12 * unit.kilocalorie_per_mole # fmt: skip + settings.vacuum_engine_settings.compute_platform = "CPU" + settings.solvent_engine_settings.compute_platform = "CUDA" return settings def generate_ahfe_json(smc): protocol = AbsoluteSolvationProtocol(settings=generate_ahfe_settings()) - sysA = openfe.ChemicalSystem( - {"ligand": smc, "solvent": openfe.SolventComponent()} - ) - sysB = openfe.ChemicalSystem( - {"solvent": openfe.SolventComponent()} - ) + sysA = openfe.ChemicalSystem({"ligand": smc, "solvent": openfe.SolventComponent()}) + sysB = openfe.ChemicalSystem({"solvent": openfe.SolventComponent()}) dag = protocol.create(stateA=sysA, stateB=sysB, mapping=None) @@ -255,12 +248,10 @@ def generate_rfe_json(smcA, smcB): mapper = KartografAtomMapper(atom_map_hydrogens=True) mapping = next(mapper.suggest_mappings(smcA, a_smcB)) - systemA = openfe.ChemicalSystem({'ligand': smcA}) - systemB = openfe.ChemicalSystem({'ligand': a_smcB}) + systemA = openfe.ChemicalSystem({"ligand": smcA}) + systemB = openfe.ChemicalSystem({"ligand": a_smcB}) - dag = protocol.create( - stateA=systemA, stateB=systemB, mapping=mapping - ) + dag = protocol.create(stateA=systemA, stateB=systemB, mapping=mapping) execute_and_serialize(dag, protocol, "RHFEProtocol") @@ -279,13 +270,13 @@ def generate_septop_settings(): settings.complex_simulation_settings.equilibration_length = 10 * unit.picosecond settings.complex_simulation_settings.production_length = 50 * unit.picosecond settings.complex_simulation_settings.time_per_iteration = 2.5 * unit.ps - settings.solvent_solvation_settings.box_shape = 'dodecahedron' - settings.complex_solvation_settings.box_shape = 'dodecahedron' + settings.solvent_solvation_settings.box_shape = "dodecahedron" + settings.complex_solvation_settings.box_shape = "dodecahedron" settings.solvent_solvation_settings.solvent_padding = 1.2 * unit.nanometer settings.complex_solvation_settings.solvent_padding = 1.0 * unit.nanometer settings.forcefield_settings.nonbonded_cutoff = 0.9 * unit.nanometer settings.protocol_repeats = 1 - settings.engine_settings.compute_platform = 'CUDA' + settings.engine_settings.compute_platform = "CUDA" return settings diff --git a/docs/conf.py b/docs/conf.py index 18eb2537..9fde336b 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -21,7 +21,7 @@ from git import Repo import nbsphinx import nbformat -sys.path.insert(0, os.path.abspath('../')) +sys.path.insert(0, os.path.abspath("../")) os.environ["SPHINX"] = "True" @@ -31,7 +31,7 @@ os.environ["SPHINX"] = "True" project = "OpenFE" copyright = "2022, The OpenFE Development Team" author = "The OpenFE Development Team" - # don't include patch version (https://github.com/OpenFreeEnergy/openfe/issues/1261) +# don't include patch version (https://github.com/OpenFreeEnergy/openfe/issues/1261) version = f"{parse(version('openfe')).major}.{parse(version('openfe')).minor}" # -- General configuration --------------------------------------------------- @@ -55,7 +55,7 @@ extensions = [ "nbsphinx_link", "sphinx.ext.mathjax", ] -suppress_warnings = ["config.cache"] # https://github.com/sphinx-doc/sphinx/issues/12300 +suppress_warnings = ["config.cache"] # https://github.com/sphinx-doc/sphinx/issues/12300 intersphinx_mapping = { "python": ("https://docs.python.org/3.9", None), @@ -147,11 +147,11 @@ html_theme_options = { "navigation_with_keys": False, } html_logo = "_static/OFE-color-icon.svg" -html_favicon = '_static/OFE-color-icon.svg' +html_favicon = "_static/OFE-color-icon.svg" # temporary fix, see https://github.com/pydata/pydata-sphinx-theme/issues/1662 html_sidebars = { "installation": [], - "CHANGELOG":[], + "CHANGELOG": [], } # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, @@ -183,11 +183,11 @@ example_notebooks_path = Path("ExampleNotebooks") try: if example_notebooks_path.exists(): repo = Repo(example_notebooks_path) - repo.remote('origin').pull() + repo.remote("origin").pull() else: repo = Repo.clone_from( "https://github.com/OpenFreeEnergy/ExampleNotebooks.git", - branch='oct-2025', + branch="oct-2025", to_path=example_notebooks_path, ) except Exception as e: @@ -195,7 +195,7 @@ except Exception as e: filename = e.__traceback__.tb_frame.f_code.co_filename lineno = e.__traceback__.tb_lineno - getLogger('sphinx.ext.openfe_git').warning( + getLogger("sphinx.ext.openfe_git").warning( f"Getting ExampleNotebooks failed in {filename} line {lineno}: {e}" ) diff --git a/openfe/__init__.py b/openfe/__init__.py index a944cd8f..c37c9f5b 100644 --- a/openfe/__init__.py +++ b/openfe/__init__.py @@ -1,9 +1,15 @@ # silence pymbar logging warnings import logging + + def _mute_timeseries(record): return not "Warning on use of the timeseries module:" in record.msg + + def _mute_jax(record): return not "****** PyMBAR will use 64-bit JAX! *******" in record.msg + + _mbar_log = logging.getLogger("pymbar.timeseries") _mbar_log.addFilter(_mute_timeseries) _mbar_log = logging.getLogger("pymbar.mbar_solvers") @@ -24,7 +30,8 @@ from gufe.protocols import ( Protocol, ProtocolDAG, ProtocolUnit, - ProtocolUnitResult, ProtocolUnitFailure, + ProtocolUnitResult, + ProtocolUnitFailure, ProtocolDAGResult, ProtocolResult, execute_DAG, @@ -45,4 +52,5 @@ from . import orchestration from . import analysis from importlib.metadata import version + __version__ = version("openfe") diff --git a/openfe/analysis/plotting.py b/openfe/analysis/plotting.py index e1676bc6..9d2351b7 100644 --- a/openfe/analysis/plotting.py +++ b/openfe/analysis/plotting.py @@ -38,15 +38,16 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: # Check if any row or column isn't close to 1.0 # Throw a warning if it's the case - if (not np.allclose(matrix.sum(axis=0), 1.0) or - not np.allclose(matrix.sum(axis=1), 1.0)): - wmsg = ("Overlap/probability matrix exceeds a sum of 1.0 in one or " - "more columns or rows of the matrix. This indicates an " - "incorrect overlap/probability matrix.") + if not np.allclose(matrix.sum(axis=0), 1.0) or not np.allclose(matrix.sum(axis=1), 1.0): + wmsg = ( + "Overlap/probability matrix exceeds a sum of 1.0 in one or " + "more columns or rows of the matrix. This indicates an " + "incorrect overlap/probability matrix." + ) warnings.warn(wmsg) fig, ax = plt.subplots(figsize=(num_states / 2, num_states / 2)) - ax.axis('off') + ax.axis("off") for i in range(num_states): if i != 0: ax.axvline(x=i, ls="-", lw=0.5, color="k", alpha=0.25) @@ -69,28 +70,37 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: # shade box ax.fill_between( - [i, i+1], [num_states - j, num_states - j], + [i, i + 1], + [num_states - j, num_states - j], [num_states - (j + 1), num_states - (j + 1)], - color='k', alpha=rel_prob + color="k", + alpha=rel_prob, ) # annotate box ax.annotate( - val_str, xy=(i, j), xytext=(i+0.5, num_states - (j + 0.5)), - size=8, va="center", ha="center", + val_str, + xy=(i, j), + xytext=(i + 0.5, num_states - (j + 0.5)), + size=8, + va="center", + ha="center", color=("k" if rel_prob < 0.5 else "w"), ) # anotate axes base_settings: dict[str, Union[str, int]] = { - 'size': 10, 'va': 'center', 'ha': 'center', 'color': 'k', - 'family': 'sans-serif' + "size": 10, + "va": "center", + "ha": "center", + "color": "k", + "family": "sans-serif", } for i in range(num_states): ax.annotate( text=f"{i}", xy=(i + 0.5, 1), xytext=(i + 0.5, num_states + 0.5), - xycoords='data', + xycoords="data", textcoords=None, arrowprops=None, annotation_clip=None, @@ -100,7 +110,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: text=f"{i}", xy=(-0.5, num_states - (num_states - 0.5)), xytext=(-0.5, num_states - (i + 0.5)), - xycoords='data', + xycoords="data", textcoords=None, arrowprops=None, annotation_clip=None, @@ -111,7 +121,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: r"$\lambda$", xy=(-0.5, num_states - (num_states - 0.5)), xytext=(-0.5, num_states + 0.5), - xycoords='data', + xycoords="data", textcoords=None, arrowprops=None, annotation_clip=None, @@ -128,8 +138,7 @@ def plot_lambda_transition_matrix(matrix: npt.NDArray) -> Axes: def plot_convergence( - forward_and_reverse: dict[str, Union[npt.NDArray, Quantity]], - units: Quantity + forward_and_reverse: dict[str, Union[npt.NDArray, Quantity]], units: Quantity ) -> Axes: """ Plot a Reverse and Forward convergence analysis of the @@ -154,18 +163,20 @@ def plot_convergence( Modified from `alchemical analysis <>`_ """ known_units = { - 'kilojoule_per_mole': 'kJ/mol', - 'kilojoules_per_mole': 'kJ/mol', - 'kilocalorie_per_mole': 'kcal/mol', - 'kilocalories_per_mole': 'kcal/mol', + "kilojoule_per_mole": "kJ/mol", + "kilojoules_per_mole": "kJ/mol", + "kilocalorie_per_mole": "kcal/mol", + "kilocalories_per_mole": "kcal/mol", } try: plt_units = known_units[str(units)] except KeyError: - errmsg = (f"Unknown plotting units {units} passed, acceptable " - "values are kilojoule(s)_per_mole and " - "kilocalorie(s)_per_mole") + errmsg = ( + f"Unknown plotting units {units} passed, acceptable " + "values are kilojoule(s)_per_mole and " + "kilocalorie(s)_per_mole" + ) raise ValueError(errmsg) fig, ax = plt.subplots(figsize=(8, 6)) @@ -181,38 +192,45 @@ def plot_convergence( ax.yaxis.set_ticks_position("left") # Set the overall error bar to the final error for the reverse results - overall_error = forward_and_reverse['reverse_dDGs'][-1].m # type: ignore - final_value = forward_and_reverse['reverse_DGs'][-1].m # type: ignore - ax.fill_between([0, 1], - final_value - overall_error, - final_value + overall_error, - color='#D2B9D3', zorder=1) - - ax.errorbar( - forward_and_reverse['fractions'], # type: ignore - [val.m - for val in forward_and_reverse['forward_DGs']], # type: ignore - yerr=[err.m - for err in forward_and_reverse['forward_dDGs']], # type: ignore - color="#736AFF", lw=3, zorder=2, - marker="o", mfc="w", mew=2.5, - mec="#736AFF", ms=8, label='Forward' + overall_error = forward_and_reverse["reverse_dDGs"][-1].m # type: ignore + final_value = forward_and_reverse["reverse_DGs"][-1].m # type: ignore + ax.fill_between( + [0, 1], final_value - overall_error, final_value + overall_error, color="#D2B9D3", zorder=1 ) ax.errorbar( - forward_and_reverse['fractions'], # type: ignore - [val.m - for val in forward_and_reverse['reverse_DGs']], # type: ignore - yerr=[err.m - for err in forward_and_reverse['reverse_dDGs']], # type: ignore - color="#C11B17", lw=3, zorder=2, - marker="o", mfc="w", mew=2.5, - mec="#C11B17", ms=8, label='Reverse', + forward_and_reverse["fractions"], # type: ignore + [val.m for val in forward_and_reverse["forward_DGs"]], # type: ignore + yerr=[err.m for err in forward_and_reverse["forward_dDGs"]], # type: ignore + color="#736AFF", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#736AFF", + ms=8, + label="Forward", + ) + + ax.errorbar( + forward_and_reverse["fractions"], # type: ignore + [val.m for val in forward_and_reverse["reverse_DGs"]], # type: ignore + yerr=[err.m for err in forward_and_reverse["reverse_dDGs"]], # type: ignore + color="#C11B17", + lw=3, + zorder=2, + marker="o", + mfc="w", + mew=2.5, + mec="#C11B17", + ms=8, + label="Reverse", ) ax.legend(frameon=False) - ax.set_ylabel(r'$\Delta G$' + f' ({plt_units})') - ax.set_xlabel('Fraction of uncorrelated samples') + ax.set_ylabel(r"$\Delta G$" + f" ({plt_units})") + ax.set_xlabel("Fraction of uncorrelated samples") return ax @@ -242,7 +260,7 @@ def plot_replica_timeseries( iterations = [i for i in range(len(state_timeseries))] for i in range(num_states): - ax.scatter(iterations, state_timeseries.T[i], label=f'replica {i}', s=8) + ax.scatter(iterations, state_timeseries.T[i], label=f"replica {i}", s=8) ax.set_xlabel("Number of simulation iterations") ax.set_ylabel("Lambda state") @@ -250,16 +268,14 @@ def plot_replica_timeseries( if equilibration_iterations is not None: ax.axvline( - x=equilibration_iterations, color='grey', - linestyle='--', label='equilibration limit' + x=equilibration_iterations, color="grey", linestyle="--", label="equilibration limit" ) - ax.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + ax.legend(loc="center left", bbox_to_anchor=(1, 0.5)) return ax -def plot_2D_rmsd(data: list[list[float]], - vmax=5.0) -> plt.Figure: +def plot_2D_rmsd(data: list[list[float]], vmax=5.0) -> plt.Figure: """Plots 2D RMSD for many states Parameters @@ -292,26 +308,25 @@ def plot_2D_rmsd(data: list[list[float]], fig, axes = plt.subplots(nrows, 4) - for i, (arr, ax) in enumerate( - zip(twod_rmsd_arrs, axes.flatten())): - ax.imshow(arr, - vmin=0, vmax=vmax, - cmap=plt.get_cmap('cividis')) - ax.axis('off') # turn off ticks/labels - ax.set_title(f'State {i}') + for i, (arr, ax) in enumerate(zip(twod_rmsd_arrs, axes.flatten())): + ax.imshow(arr, vmin=0, vmax=vmax, cmap=plt.get_cmap("cividis")) + ax.axis("off") # turn off ticks/labels + ax.set_title(f"State {i}") # if we have any leftover plots then we turn them off # except the last one! overage = len(axes.flatten()) - len(twod_rmsd_arrs) - for i in range(overage, len(axes.flatten())-1): + for i in range(overage, len(axes.flatten()) - 1): axes.flatten()[i].set_axis_off() - plt.colorbar(axes.flatten()[0].images[0], - cax=axes.flatten()[-1], - label="RMSD scale (A)", - orientation="horizontal") + plt.colorbar( + axes.flatten()[0].images[0], + cax=axes.flatten()[-1], + label="RMSD scale (A)", + orientation="horizontal", + ) - fig.suptitle('Protein 2D RMSD') + fig.suptitle("Protein 2D RMSD") fig.tight_layout() return fig @@ -321,12 +336,12 @@ def plot_ligand_COM_drift(time: list[float], data: list[list[float]]): fig, ax = plt.subplots() for i, s in enumerate(data): - ax.plot(time, s, label=f'State {i}') + ax.plot(time, s, label=f"State {i}") - ax.legend(loc='upper left') - ax.set_xlabel('Time (ps)') - ax.set_ylabel('Distance (A)') - ax.set_title('Ligand COM drift') + ax.legend(loc="upper left") + ax.set_xlabel("Time (ps)") + ax.set_ylabel("Distance (A)") + ax.set_title("Ligand COM drift") return fig @@ -335,11 +350,11 @@ def plot_ligand_RMSD(time: list[float], data: list[list[float]]): fig, ax = plt.subplots() for i, s in enumerate(data): - ax.plot(time, s, label=f'State {i}') + ax.plot(time, s, label=f"State {i}") - ax.legend(loc='upper left') - ax.set_xlabel('Time (ps)') - ax.set_ylabel('RMSD (A)') - ax.set_title('Ligand RMSD') + ax.legend(loc="upper left") + ax.set_xlabel("Time (ps)") + ax.set_ylabel("RMSD (A)") + ax.set_title("Ligand RMSD") return fig diff --git a/openfe/due.py b/openfe/due.py index f729f843..05060832 100644 --- a/openfe/due.py +++ b/openfe/due.py @@ -24,26 +24,29 @@ Copyright: 2015-2021 DueCredit developers License: BSD-2 """ -__version__ = '0.0.9' +__version__ = "0.0.9" class InactiveDueCreditCollector(object): """Just a stub at the Collector which would not do anything""" + def _donothing(self, *args, **kwargs): """Perform no good and no bad""" pass def dcite(self, *args, **kwargs): """If I could cite I would""" + def nondecorating_decorator(func): return func + return nondecorating_decorator active = False activate = add = cite = dump = load = _donothing def __repr__(self): - return self.__class__.__name__ + '()' + return self.__class__.__name__ + "()" def _donothing_func(*args, **kwargs): @@ -53,14 +56,14 @@ def _donothing_func(*args, **kwargs): try: from duecredit import due, BibTeX, Doi, Url, Text # lgtm [py/unused-import] - if 'due' in locals() and not hasattr(due, 'cite'): - raise RuntimeError( - "Imported due lacks .cite. DueCredit is now disabled") + + if "due" in locals() and not hasattr(due, "cite"): + raise RuntimeError("Imported due lacks .cite. DueCredit is now disabled") except Exception as e: if not isinstance(e, ImportError): import logging - logging.getLogger("duecredit").error( - "Failed to import duecredit due to %s" % str(e)) + + logging.getLogger("duecredit").error("Failed to import duecredit due to %s" % str(e)) # Initiate due stub due = InactiveDueCreditCollector() BibTeX = Doi = Url = Text = _donothing_func diff --git a/openfe/protocols/openmm_afe/base.py b/openfe/protocols/openmm_afe/base.py index 5de03bc7..d5a7b8f9 100644 --- a/openfe/protocols/openmm_afe/base.py +++ b/openfe/protocols/openmm_afe/base.py @@ -14,6 +14,7 @@ TODO as settings. * Allow for a more flexible setting of Lambda regions. """ + from __future__ import annotations import abc @@ -30,12 +31,17 @@ from openff.units import unit, Quantity from openff.units.openmm import from_openmm, to_openmm, ensure_quantity from openff.toolkit.topology import Molecule as OFFMolecule from openmmtools import multistate -from openmmtools.states import (SamplerState, - ThermodynamicState, - GlobalParameterState, - create_thermodynamic_state_protocol,) -from openmmtools.alchemy import (AlchemicalRegion, AbsoluteAlchemicalFactory, - AlchemicalState,) +from openmmtools.states import ( + SamplerState, + ThermodynamicState, + GlobalParameterState, + create_thermodynamic_state_protocol, +) +from openmmtools.alchemy import ( + AlchemicalRegion, + AbsoluteAlchemicalFactory, + AlchemicalState, +) from typing import Optional from openmm import app from openmm import unit as omm_unit @@ -45,10 +51,7 @@ from typing import Any import openmmtools import mdtraj as mdt -from gufe import ( - ChemicalSystem, SmallMoleculeComponent, - ProteinComponent, SolventComponent -) +from gufe import ChemicalSystem, SmallMoleculeComponent, ProteinComponent, SolventComponent from openfe.protocols.openmm_utils.omm_settings import ( SettingsBaseModel, ) @@ -57,21 +60,24 @@ from openfe.protocols.openmm_utils.omm_settings import ( ) from openfe.protocols.openmm_afe.equil_afe_settings import ( BaseSolvationSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, MultiStateOutputSettings, - ThermoSettings, OpenFFPartialChargeSettings, + MultiStateSimulationSettings, + OpenMMEngineSettings, + IntegratorSettings, + MultiStateOutputSettings, + ThermoSettings, + OpenFFPartialChargeSettings, OpenMMSystemGeneratorFFSettings, ) from openfe.protocols.openmm_md.plain_md_methods import PlainMDProtocolUnit from openfe.protocols.openmm_utils import ( - settings_validation, system_creation, - multistate_analysis, charge_generation, + settings_validation, + system_creation, + multistate_analysis, + charge_generation, omm_compute, ) from openfe.protocols.restraint_utils import geometry -from openfe.utils import ( - without_oechem_backend, log_system_probe -) +from openfe.utils import without_oechem_backend, log_system_probe logger = logging.getLogger(__name__) @@ -81,14 +87,18 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ Base class for ligand absolute free energy transformations. """ - def __init__(self, *, - protocol: gufe.Protocol, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - alchemical_components: dict[str, list[Component]], - generation: int = 0, - repeat_id: int = 0, - name: Optional[str] = None,): + + def __init__( + self, + *, + protocol: gufe.Protocol, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + alchemical_components: dict[str, list[Component]], + generation: int = 0, + repeat_id: int = 0, + name: Optional[str] = None, + ): """ Parameters ---------- @@ -123,10 +133,11 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ) @staticmethod - def _get_alchemical_indices(omm_top: openmm.Topology, - comp_resids: dict[Component, npt.NDArray], - alchem_comps: dict[str, list[Component]] - ) -> list[int]: + def _get_alchemical_indices( + omm_top: openmm.Topology, + comp_resids: dict[Component, npt.NDArray], + alchem_comps: dict[str, list[Component]], + ) -> list[int]: """ Get a list of atom indices for all the alchemical species @@ -146,9 +157,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ # concatenate a list of residue indexes for all alchemical components - residxs = np.concatenate( - [comp_resids[key] for key in alchem_comps['stateA']] - ) + residxs = np.concatenate([comp_resids[key] for key in alchem_comps["stateA"]]) # get the alchemicical atom ids atom_ids = [] @@ -165,7 +174,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): topology: openmm.app.Topology, positions: omm_unit.Quantity, settings: dict[str, SettingsBaseModel], - dry: bool + dry: bool, ) -> tuple[omm_unit.Quantity, omm_unit.Quantity]: """ Run a non-alchemical equilibration to get a stable system. @@ -199,17 +208,17 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ # Prep the simulation object # Restrict CPU count if running vacuum simulation - restrict_cpu = settings['forcefield_settings'].nonbonded_method.lower() == 'nocutoff' + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" platform = omm_compute.get_openmm_platform( - platform_name=settings['engine_settings'].compute_platform, - gpu_device_index=settings['engine_settings'].gpu_device_index, - restrict_cpu_count=restrict_cpu + platform_name=settings["engine_settings"].compute_platform, + gpu_device_index=settings["engine_settings"].gpu_device_index, + restrict_cpu_count=restrict_cpu, ) integrator = openmm.LangevinMiddleIntegrator( - to_openmm(settings['thermo_settings'].temperature), - to_openmm(settings['integrator_settings'].langevin_collision_rate), - to_openmm(settings['integrator_settings'].timestep), + to_openmm(settings["thermo_settings"].temperature), + to_openmm(settings["integrator_settings"].langevin_collision_rate), + to_openmm(settings["integrator_settings"].timestep), ) simulation = openmm.app.Simulation( @@ -220,25 +229,24 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ) # Get the necessary number of steps - if settings['equil_simulation_settings'].equilibration_length_nvt is not None: + if settings["equil_simulation_settings"].equilibration_length_nvt is not None: equil_steps_nvt = settings_validation.get_simsteps( - sim_length=settings[ - 'equil_simulation_settings'].equilibration_length_nvt, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["equil_simulation_settings"].equilibration_length_nvt, + timestep=settings["integrator_settings"].timestep, mc_steps=1, ) else: equil_steps_nvt = None equil_steps_npt = settings_validation.get_simsteps( - sim_length=settings['equil_simulation_settings'].equilibration_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["equil_simulation_settings"].equilibration_length, + timestep=settings["integrator_settings"].timestep, mc_steps=1, ) prod_steps_npt = settings_validation.get_simsteps( - sim_length=settings['equil_simulation_settings'].production_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["equil_simulation_settings"].production_length, + timestep=settings["integrator_settings"].timestep, mc_steps=1, ) @@ -254,11 +262,11 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): PlainMDProtocolUnit._run_MD( simulation=simulation, positions=positions, - simulation_settings=settings['equil_simulation_settings'], - output_settings=settings['equil_output_settings'], - temperature=settings['thermo_settings'].temperature, - barostat_frequency=settings['integrator_settings'].barostat_frequency, - timestep=settings['integrator_settings'].timestep, + simulation_settings=settings["equil_simulation_settings"], + output_settings=settings["equil_output_settings"], + temperature=settings["thermo_settings"].temperature, + barostat_frequency=settings["integrator_settings"].barostat_frequency, + timestep=settings["integrator_settings"].timestep, equil_steps_nvt=equil_steps_nvt, equil_steps_npt=equil_steps_npt, prod_steps=prod_steps_npt, @@ -278,7 +286,8 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): return equilibrated_positions, box def _prepare( - self, verbose: bool, + self, + verbose: bool, scratch_basepath: Optional[pathlib.Path], shared_basepath: Optional[pathlib.Path], ): @@ -301,17 +310,21 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # set basepaths def _set_optional_path(basepath): if basepath is None: - return pathlib.Path('.') + return pathlib.Path(".") return basepath self.scratch_basepath = _set_optional_path(scratch_basepath) self.shared_basepath = _set_optional_path(shared_basepath) @abc.abstractmethod - def _get_components(self) -> tuple[dict[str, list[Component]], - Optional[gufe.SolventComponent], - Optional[gufe.ProteinComponent], - dict[SmallMoleculeComponent, OFFMolecule]]: + def _get_components( + self, + ) -> tuple[ + dict[str, list[Component]], + Optional[gufe.SolventComponent], + Optional[gufe.ProteinComponent], + dict[SmallMoleculeComponent, OFFMolecule], + ]: """ Get the relevant components to create the alchemical system with. @@ -349,8 +362,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ... def _get_system_generator( - self, settings: dict[str, SettingsBaseModel], - solvent_comp: Optional[SolventComponent] + self, settings: dict[str, SettingsBaseModel], solvent_comp: Optional[SolventComponent] ) -> SystemGenerator: """ Get a system generator through the system creation @@ -368,7 +380,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): system_generator : openmmforcefields.generator.SystemGenerator System Generator to parameterise this unit. """ - ffcache = settings['output_settings'].forcefield_cache + ffcache = settings["output_settings"].forcefield_cache if ffcache is not None: ffcache = self.shared_basepath / ffcache @@ -376,9 +388,9 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # smiles roundtripping between rdkit and oechem with without_oechem_backend(): system_generator = system_creation.get_system_generator( - forcefield_settings=settings['forcefield_settings'], - integrator_settings=settings['integrator_settings'], - thermo_settings=settings['thermo_settings'], + forcefield_settings=settings["forcefield_settings"], + integrator_settings=settings["integrator_settings"], + thermo_settings=settings["thermo_settings"], cache=ffcache, has_solvent=solvent_comp is not None, ) @@ -417,7 +429,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): smc_components: dict[SmallMoleculeComponent, OFFMolecule], system_generator: SystemGenerator, partial_charge_settings: BasePartialChargeSettings, - solvation_settings: BaseSolvationSettings + solvation_settings: BaseSolvationSettings, ) -> tuple[app.Modeller, dict[Component, npt.NDArray]]: """ Get an OpenMM Modeller object and a list of residue indices @@ -462,9 +474,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # smiles roundtripping between rdkit and oechem with without_oechem_backend(): for mol in smc_components.values(): - system_generator.create_system( - mol.to_topology().to_openmm(), molecules=[mol] - ) + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # get OpenMM modeller + dictionary of resids for each component system_modeller, comp_resids = system_creation.get_omm_modeller( @@ -519,8 +529,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): self.logger.info("Parameterizing system") system_generator = self._get_system_generator( - settings=settings, - solvent_comp=solvent_component + settings=settings, solvent_comp=solvent_component ) modeller, comp_resids = self._get_modeller( @@ -528,8 +537,8 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): solvent_component=solvent_component, smc_components=smc_components, system_generator=system_generator, - partial_charge_settings=settings['charge_settings'], - solvation_settings=settings['solvation_settings'] + partial_charge_settings=settings["charge_settings"], + solvation_settings=settings["solvation_settings"], ) topology = modeller.getTopology() @@ -546,16 +555,10 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # Check and fail early on the presence of virtual sites # and multistate sampler not using velocity restart - if not settings['integrator_settings'].reassign_velocities: - has_vsite = any( - system.isVirtualSite(i) - for i in range(system.getNumParticles()) - ) + if not settings["integrator_settings"].reassign_velocities: + has_vsite = any(system.isVirtualSite(i) for i in range(system.getNumParticles())) if has_vsite: - errmsg = ( - "Simulations with virtual sites without " - "velocity reassignment are unstable" - ) + errmsg = "Simulations with virtual sites without velocity reassignment are unstable" raise ValueError(errmsg) return topology, system, positions, comp_resids @@ -582,16 +585,16 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ lambdas = dict() - lambda_elec = settings['lambda_settings'].lambda_elec - lambda_vdw = settings['lambda_settings'].lambda_vdw - lambda_rest = settings['lambda_settings'].lambda_restraints + lambda_elec = settings["lambda_settings"].lambda_elec + lambda_vdw = settings["lambda_settings"].lambda_vdw + lambda_rest = settings["lambda_settings"].lambda_restraints # Reverse lambda schedule for vdw, elect, and restraints # since in AbsoluteAlchemicalFactory 1 means fully # interacting (which would be non-interacting for us) - lambdas['lambda_electrostatics'] = [1-x for x in lambda_elec] - lambdas['lambda_sterics'] = [1-x for x in lambda_vdw] - lambdas['lambda_restraints'] = [x for x in lambda_rest] + lambdas["lambda_electrostatics"] = [1 - x for x in lambda_elec] + lambdas["lambda_sterics"] = [1 - x for x in lambda_vdw] + lambdas["lambda_restraints"] = [x for x in lambda_rest] return lambdas @@ -602,7 +605,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): positions: openmm.unit.Quantity, alchem_comps: dict[str, list[Component]], comp_resids: dict[Component, npt.NDArray], - settings: dict[str, SettingsBaseModel], + settings: dict[str, SettingsBaseModel], ) -> tuple[ Optional[GlobalParameterState], Optional[Quantity], @@ -619,7 +622,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): topology: app.Topology, system: openmm.System, comp_resids: dict[Component, npt.NDArray], - alchem_comps: dict[str, list[Component]] + alchem_comps: dict[str, list[Component]], ) -> tuple[AbsoluteAlchemicalFactory, openmm.System, list[int]]: """ Get an alchemically modified system and its associated factory @@ -649,18 +652,14 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ---- * Add support for all alchemical factory options """ - alchemical_indices = self._get_alchemical_indices( - topology, comp_resids, alchem_comps - ) + alchemical_indices = self._get_alchemical_indices(topology, comp_resids, alchem_comps) alchemical_region = AlchemicalRegion( alchemical_atoms=alchemical_indices, ) alchemical_factory = AbsoluteAlchemicalFactory() - alchemical_system = alchemical_factory.create_alchemical_system( - system, alchemical_region - ) + alchemical_system = alchemical_factory.create_alchemical_system(system, alchemical_region) return alchemical_factory, alchemical_system, alchemical_indices @@ -706,13 +705,13 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): alchemical_state = AlchemicalState.from_system(alchemical_system) # Set up the system constants - temperature = settings['thermo_settings'].temperature - pressure = settings['thermo_settings'].pressure + temperature = settings["thermo_settings"].temperature + pressure = settings["thermo_settings"].pressure constants = dict() - constants['temperature'] = ensure_quantity(temperature, 'openmm') + constants["temperature"] = ensure_quantity(temperature, "openmm") if solvent_comp is not None: - constants['pressure'] = ensure_quantity(pressure, 'openmm') + constants["pressure"] = ensure_quantity(pressure, "openmm") # Get the thermodynamic parameter protocol param_protocol = copy.deepcopy(lambdas) @@ -721,11 +720,11 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): if restraint_state is not None: composable_states = [alchemical_state, restraint_state] else: - composable_states = [alchemical_state,] + composable_states = [alchemical_state] # In this case we also don't have a restraint being controlled # so we drop it from the protocol - param_protocol.pop('lambda_restraints', None) + param_protocol.pop("lambda_restraints", None) cmp_states = create_thermodynamic_state_protocol( alchemical_system, @@ -773,9 +772,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # Store the selection indices in self to use later # when storing them in the unit results - self.selection_indices = mdt_top.select( - output_settings.output_indices - ) + self.selection_indices = mdt_top.select(output_settings.output_indices) nc = self.shared_basepath / output_settings.output_filename chk = output_settings.checkpoint_storage_filename @@ -789,7 +786,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): numerator=output_settings.positions_write_frequency, denominator=simulation_settings.time_per_iteration, numerator_name="output settings' position_write_frequency", - denominator_name="simulation settings' time_per_iteration" + denominator_name="simulation settings' time_per_iteration", ) else: pos_interval = 0 @@ -799,7 +796,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): numerator=output_settings.velocities_write_frequency, denominator=simulation_settings.time_per_iteration, numerator_name="output settings' velocity_write_frequency", - denominator_name="simulation settings' time_per_iteration" + denominator_name="simulation settings' time_per_iteration", ) else: vel_interval = 0 @@ -819,16 +816,14 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): positions[self.selection_indices, :], mdt_top.subset(self.selection_indices), ) - traj.save_pdb( - self.shared_basepath / output_settings.output_structure - ) + traj.save_pdb(self.shared_basepath / output_settings.output_structure) return reporter def _get_ctx_caches( self, forcefield_settings: OpenMMSystemGeneratorFFSettings, - engine_settings: OpenMMEngineSettings + engine_settings: OpenMMEngineSettings, ) -> tuple[openmmtools.cache.ContextCache, openmmtools.cache.ContextCache]: """ Set the context caches based on the chosen platform @@ -847,27 +842,30 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ # Get the compute platform # Set the number of CPUs to 1 if running a vacuum simulation - restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff' + restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" platform = omm_compute.get_openmm_platform( platform_name=engine_settings.compute_platform, gpu_device_index=engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu + restrict_cpu_count=restrict_cpu, ) energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) return energy_context_cache, sampler_context_cache @staticmethod def _get_integrator( - integrator_settings: IntegratorSettings, - simulation_settings: MultiStateSimulationSettings + integrator_settings: IntegratorSettings, simulation_settings: MultiStateSimulationSettings ) -> openmmtools.mcmc.LangevinDynamicsMove: """ Return a LangevinDynamicsMove integrator @@ -906,7 +904,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): cmp_states: list[ThermodynamicState], sampler_states: list[SamplerState], energy_context_cache: openmmtools.cache.ContextCache, - sampler_context_cache: openmmtools.cache.ContextCache + sampler_context_cache: openmmtools.cache.ContextCache, ) -> multistate.MultiStateSampler: """ Get a sampler based on the equilibrium sampling method requested. @@ -950,7 +948,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): mcmc_moves=integrator, online_analysis_interval=rta_its, online_analysis_target_error=et_target_err, - online_analysis_minimum_iterations=rta_min_its + online_analysis_minimum_iterations=rta_min_its, ) elif simulation_settings.sampler_method.lower() == "sams": sampler = multistate.SAMSSampler( @@ -960,7 +958,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): flatness_criteria=simulation_settings.sams_flatness_criteria, gamma0=simulation_settings.sams_gamma0, ) - elif simulation_settings.sampler_method.lower() == 'independent': + elif simulation_settings.sampler_method.lower() == "independent": sampler = multistate.MultiStateSampler( mcmc_moves=integrator, online_analysis_interval=rta_its, @@ -969,9 +967,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): ) sampler.create( - thermodynamic_states=cmp_states, - sampler_states=sampler_states, - storage=reporter + thermodynamic_states=cmp_states, sampler_states=sampler_states, storage=reporter ) sampler.energy_context_cache = energy_context_cache @@ -985,7 +981,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): reporter: multistate.MultiStateReporter, settings: dict[str, SettingsBaseModel], standard_state_corr: Optional[Quantity], - dry: bool + dry: bool, ): """ Run the simulation. @@ -1011,18 +1007,18 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): """ # Get the relevant simulation steps mc_steps = settings_validation.convert_steps_per_iteration( - simulation_settings=settings['simulation_settings'], - integrator_settings=settings['integrator_settings'], + simulation_settings=settings["simulation_settings"], + integrator_settings=settings["integrator_settings"], ) equil_steps = settings_validation.get_simsteps( - sim_length=settings['simulation_settings'].equilibration_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["simulation_settings"].equilibration_length, + timestep=settings["integrator_settings"].timestep, mc_steps=mc_steps, ) prod_steps = settings_validation.get_simsteps( - sim_length=settings['simulation_settings'].production_length, - timestep=settings['integrator_settings'].timestep, + sim_length=settings["simulation_settings"].production_length, + timestep=settings["integrator_settings"].timestep, mc_steps=mc_steps, ) @@ -1031,9 +1027,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): if self.verbose: self.logger.info("minimizing systems") - sampler.minimize( - max_iterations=settings['simulation_settings'].minimization_steps - ) + sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps) # equilibrate if self.verbose: @@ -1054,8 +1048,8 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): analyzer = multistate_analysis.MultistateEquilFEAnalysis( reporter, - sampling_method=settings['simulation_settings'].sampler_method.lower(), - result_units=unit.kilocalorie_per_mole + sampling_method=settings["simulation_settings"].sampler_method.lower(), + result_units=unit.kilocalorie_per_mole, ) analyzer.plot(filepath=self.shared_basepath, filename_prefix="") analyzer.close() @@ -1063,7 +1057,9 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): return_dict = analyzer.unit_results_dict if standard_state_corr is not None: - return_dict['standard_state_correction'] = standard_state_corr.to('kilocalorie_per_mole') + return_dict["standard_state_correction"] = standard_state_corr.to( + "kilocalorie_per_mole" + ) return return_dict @@ -1072,15 +1068,18 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): reporter.close() # clean up the reporter file - fns = [self.shared_basepath / settings['output_settings'].output_filename, - self.shared_basepath / settings['output_settings'].checkpoint_storage_filename] + fns = [ + self.shared_basepath / settings["output_settings"].output_filename, + self.shared_basepath / settings["output_settings"].checkpoint_storage_filename, + ] for fn in fns: os.remove(fn) return None - def run(self, dry=False, verbose=True, - scratch_basepath=None, shared_basepath=None) -> dict[str, Any]: + def run( + self, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: """Run the absolute free energy calculation. Parameters @@ -1130,7 +1129,12 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # 6. Add restraints # Note: when no restraint is applied, restrained_omm_system == omm_system - restraint_parameter_state, standard_state_corr, restrained_omm_system, restraint_geometry = self._add_restraints( + ( + restraint_parameter_state, + standard_state_corr, + restrained_omm_system, + restraint_geometry, + ) = self._add_restraints( omm_system, omm_topology, positions, @@ -1141,10 +1145,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # 7. Get alchemical system alchem_factory, alchem_system, alchem_indices = self._get_alchemical_system( - omm_topology, - restrained_omm_system, - comp_resids, - alchem_comps + omm_topology, restrained_omm_system, comp_resids, alchem_comps ) # 8. Get compound and sampler states @@ -1160,40 +1161,40 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): # 9. Create the multistate reporter & create PDB reporter = self._get_reporter( - omm_topology, positions, - settings['simulation_settings'], - settings['output_settings'], + omm_topology, + positions, + settings["simulation_settings"], + settings["output_settings"], ) # Wrap in try/finally to avoid memory leak issues try: # 10. Get context caches energy_ctx_cache, sampler_ctx_cache = self._get_ctx_caches( - settings['forcefield_settings'], - settings['engine_settings'] + settings["forcefield_settings"], settings["engine_settings"] ) # 11. Get integrator integrator = self._get_integrator( - settings['integrator_settings'], - settings['simulation_settings'], + settings["integrator_settings"], + settings["simulation_settings"], ) # 12. Get sampler sampler = self._get_sampler( - integrator, reporter, settings['simulation_settings'], - settings['thermo_settings'], - cmp_states, sampler_states, - energy_ctx_cache, sampler_ctx_cache + integrator, + reporter, + settings["simulation_settings"], + settings["thermo_settings"], + cmp_states, + sampler_states, + energy_ctx_cache, + sampler_ctx_cache, ) # 13. Run simulation unit_result_dict = self._run_simulation( - sampler, - reporter, - settings, - standard_state_corr, - dry + sampler, reporter, settings, standard_state_corr, dry ) finally: @@ -1207,8 +1208,7 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): for context in list(sampler_ctx_cache._lru._data.keys()): del sampler_ctx_cache._lru._data[context] # cautiously clear out the global context cache too - for context in list( - openmmtools.cache.global_context_cache._lru._data.keys()): + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] del sampler_ctx_cache, energy_ctx_cache @@ -1218,29 +1218,29 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): del integrator, sampler if not dry: - nc = self.shared_basepath / settings['output_settings'].output_filename - chk = settings['output_settings'].checkpoint_storage_filename - unit_result_dict['nc'] = nc - unit_result_dict['last_checkpoint'] = chk - unit_result_dict['selection_indices'] = self.selection_indices + nc = self.shared_basepath / settings["output_settings"].output_filename + chk = settings["output_settings"].checkpoint_storage_filename + unit_result_dict["nc"] = nc + unit_result_dict["last_checkpoint"] = chk + unit_result_dict["selection_indices"] = self.selection_indices if restraint_geometry is not None: - unit_result_dict['restraint_geometry'] = restraint_geometry.model_dump() + unit_result_dict["restraint_geometry"] = restraint_geometry.model_dump() return unit_result_dict else: return { - # Add in various objects we can used to test the system - 'debug': { - 'sampler': sampler, - 'system': omm_system, - 'restrained_system': restrained_omm_system, - 'alchem_system': alchem_system, - 'alchem_indices': alchem_indices, - 'alchem_factory': alchem_factory, - 'positions': positions - } - } + # Add in various objects we can used to test the system + "debug": { + "sampler": sampler, + "system": omm_system, + "restrained_system": restrained_omm_system, + "alchem_system": alchem_system, + "alchem_indices": alchem_indices, + "alchem_factory": alchem_factory, + "positions": positions, + } + } def _execute( self, @@ -1257,4 +1257,3 @@ class BaseAbsoluteUnit(gufe.ProtocolUnit): "simtype": self.simtype, **outputs, } - diff --git a/openfe/protocols/openmm_afe/equil_afe_settings.py b/openfe/protocols/openmm_afe/equil_afe_settings.py index 6e32ccd5..ba20db4d 100644 --- a/openfe/protocols/openmm_afe/equil_afe_settings.py +++ b/openfe/protocols/openmm_afe/equil_afe_settings.py @@ -15,6 +15,7 @@ TODO * Add support for restraints """ + from gufe.settings import ( SettingsBaseModel, OpenMMSystemGeneratorFFSettings, @@ -109,15 +110,13 @@ class LambdaSettings(SettingsBaseModel): for window in v: if not 0 <= window <= 1: errmsg = ( - "Lambda windows must be between 0 and 1, got a" - f" window with value {window}." + f"Lambda windows must be between 0 and 1, got a window with value {window}." ) raise ValueError(errmsg) return v @field_validator("lambda_elec", "lambda_vdw", "lambda_restraints") def must_be_monotonic(cls, v): - difference = np.diff(v) monotonic = np.all(difference >= 0) @@ -171,7 +170,7 @@ class ABFEPreEquilOutputSettings(MDOutputSettings): # Would be better if this was just changed to a Literal # but changing types in child classes in pydantic is messy if v != "all": - msg = "output_indices must be all for ABFE " "pre-equilibration simulations" + msg = "output_indices must be all for ABFE pre-equilibration simulations" raise ValueError(msg) return v diff --git a/openfe/protocols/openmm_afe/equil_binding_afe_method.py b/openfe/protocols/openmm_afe/equil_binding_afe_method.py index e6d49761..9eca0d7f 100644 --- a/openfe/protocols/openmm_afe/equil_binding_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_binding_afe_method.py @@ -23,6 +23,7 @@ Acknowledgements `Yank `_. """ + import itertools import logging import pathlib @@ -59,10 +60,7 @@ from openfe.protocols.openmm_afe.equil_afe_settings import ( OpenMMSolvationSettings, SettingsBaseModel, ) -from openfe.protocols.openmm_utils import ( - settings_validation, - system_validation -) +from openfe.protocols.openmm_utils import settings_validation, system_validation from openfe.protocols.restraint_utils import geometry from openfe.protocols.restraint_utils.geometry.boresch import BoreschRestraintGeometry from openfe.protocols.restraint_utils.openmm import omm_restraints @@ -225,7 +223,7 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): complex_dG = _get_average( self._add_complex_standard_state_corr( individual_estimates["complex"], - individual_estimates["standard_state_correction"] + individual_estimates["standard_state_correction"], ) ) solv_dG = _get_average(individual_estimates["solvent"]) @@ -255,8 +253,7 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): complex_err = _get_stdev( self._add_complex_standard_state_corr( - individual_estimates["complex"], - individual_estimates["standard_state_correction"] + individual_estimates["complex"], individual_estimates["standard_state_correction"] ) ) solv_err = _get_stdev(individual_estimates["solvent"]) @@ -299,14 +296,11 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): given thermodynamic cycle leg. """ - forward_reverse: dict[ - str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]] - ] = {} + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} for key in ["solvent", "complex"]: forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] - for pus in self.data[key].values() + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() ] if None in forward_reverse[key]: @@ -380,14 +374,10 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): try: for key in ["solvent", "complex"]: repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] - for pus in self.data[key].values() + pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() ] except KeyError: - errmsg = ( - "Replica exchange statistics were not found, " - "did you run a repex calculation?" - ) + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" raise ValueError(errmsg) return repex_stats @@ -454,8 +444,7 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): for key in ["solvent", "complex"]: equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] - for pus in self.data[key].values() + pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() ] return equilibration_lengths @@ -478,8 +467,7 @@ class AbsoluteBindingProtocolResult(gufe.ProtocolResult): for key in ["solvent", "complex"]: production_lengths[key] = [ - pus[0].outputs["production_iterations"] - for pus in self.data[key].values() + pus[0].outputs["production_iterations"] for pus in self.data[key].values() ] return production_lengths @@ -556,6 +544,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): Settings a set of default settings """ + # fmt: off return AbsoluteBindingSettings( protocol_repeats=3, forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), @@ -564,7 +553,6 @@ class AbsoluteBindingProtocol(gufe.Protocol): pressure=1 * offunit.bar, ), alchemical_settings=AlchemicalSettings(), - # fmt: off solvent_lambda_settings=LambdaSettings( lambda_elec=[ 0.0, 0.25, 0.5, 0.75, 1.0, @@ -596,7 +584,6 @@ class AbsoluteBindingProtocol(gufe.Protocol): 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0, 1.00, 1.0 ], ), - # fmt: on partial_charge_settings=OpenFFPartialChargeSettings(), complex_solvation_settings=OpenMMSolvationSettings( solvent_padding=1.0 * offunit.nanometer, @@ -638,6 +625,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): checkpoint_storage_filename="complex_checkpoint.nc", ), ) + # fmt: on @staticmethod def _validate_endstates( @@ -666,15 +654,11 @@ class AbsoluteBindingProtocol(gufe.Protocol): If stateB contains any unique Components. If the alchemical species is charged. """ - if not ( - stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent) - ): + if not (stateA.contains(ProteinComponent) and stateB.contains(ProteinComponent)): errmsg = "No ProteinComponent found" raise ValueError(errmsg) - if not ( - stateA.contains(SolventComponent) and stateB.contains(SolventComponent) - ): + if not (stateA.contains(SolventComponent) and stateB.contains(SolventComponent)): errmsg = "No SolventComponent found" raise ValueError(errmsg) @@ -776,9 +760,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): *, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[ - Union[gufe.ComponentMapping, list[gufe.ComponentMapping]] - ] = None, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ): # Check we're not extending @@ -804,12 +786,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): # If the complex restraints schedule is all zero, it might be bad # but we don't dissallow it. - if all( - [ - i == 0.0 - for i in self.settings.complex_lambda_settings.lambda_restraints - ] - ): + if all([i == 0.0 for i in self.settings.complex_lambda_settings.lambda_restraints]): wmsg = ( "No restraints are being applied in the complex phase, " "this will likely lead to problematic results." @@ -829,12 +806,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): # the list gets popped out later in the SolventUnit, because we # don't have a restraint parameter state. - if any( - [ - i != 0.0 - for i in self.settings.solvent_lambda_settings.lambda_restraints - ] - ): + if any([i != 0.0 for i in self.settings.solvent_lambda_settings.lambda_restraints]): wmsg = ( "There is an attempt to add restraints in the solvent " "phase. This protocol does not apply restraints in the " @@ -858,16 +830,14 @@ class AbsoluteBindingProtocol(gufe.Protocol): # Validate integrator things settings_validation.validate_timestep( self.settings.forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep + self.settings.integrator_settings.timestep, ) def _create( self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[ - Union[gufe.ComponentMapping, list[gufe.ComponentMapping]] - ] = None, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # Validate inputs @@ -892,10 +862,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): alchemical_components=alchem_comps, generation=0, repeat_id=int(uuid.uuid4()), - name=( - f"Absolute Binding, {alchname} solvent leg: " - f"repeat {i} generation 0" - ), + name=(f"Absolute Binding, {alchname} solvent leg: repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] @@ -908,10 +875,7 @@ class AbsoluteBindingProtocol(gufe.Protocol): alchemical_components=alchem_comps, generation=0, repeat_id=int(uuid.uuid4()), - name=( - f"Absolute Binding, {alchname} complex leg: " - f"repeat {i} generation 0" - ), + name=(f"Absolute Binding, {alchname} complex leg: repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] @@ -940,14 +904,10 @@ class AbsoluteBindingProtocol(gufe.Protocol): "complex": {}, } for k, v in unsorted_solvent_repeats.items(): - repeats["solvent"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_complex_repeats.items(): - repeats["complex"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) return repeats @@ -955,6 +915,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): """ Protocol Unit for the complex phase of an absolute binding free energy """ + simtype = "complex" def _get_components(self): @@ -1017,9 +978,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): settings["lambda_settings"] = prot_settings.complex_lambda_settings settings["engine_settings"] = prot_settings.engine_settings settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = ( - prot_settings.complex_equil_simulation_settings - ) + settings["equil_simulation_settings"] = prot_settings.complex_equil_simulation_settings settings["equil_output_settings"] = prot_settings.complex_equil_output_settings settings["simulation_settings"] = prot_settings.complex_simulation_settings settings["output_settings"] = prot_settings.complex_output_settings @@ -1238,9 +1197,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): # and the solvent. solv_comps = [c for c in comp_resids if isinstance(c, SolventComponent)] exclude_comps = [alchem_comps["stateA"]] + solv_comps - residxs = np.concatenate( - [v for i, v in comp_resids.items() if i not in exclude_comps] - ) + residxs = np.concatenate([v for i, v in comp_resids.items() if i not in exclude_comps]) host_atom_ids = self._get_idxs_from_residxs(topology, residxs) @@ -1251,8 +1208,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): univ = self._get_mda_universe( topology, positions, - self.shared_basepath - / settings["equil_output_settings"].production_trajectory_filename, + self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename, ) if isinstance(settings["restraint_settings"], BoreschRestraintSettings): @@ -1292,9 +1248,7 @@ class AbsoluteBindingComplexUnit(BaseAbsoluteUnit): ) # Get the GlobalParameterState for the restraint - restraint_parameter_state = omm_restraints.RestraintParameterState( - lambda_restraints=1.0 - ) + restraint_parameter_state = omm_restraints.RestraintParameterState(lambda_restraints=1.0) return ( restraint_parameter_state, correction, @@ -1309,6 +1263,7 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): """ Protocol Unit for the solvent phase of an absolute binding free energy """ + simtype = "solvent" def _get_components(self): @@ -1370,9 +1325,7 @@ class AbsoluteBindingSolventUnit(BaseAbsoluteUnit): settings["lambda_settings"] = prot_settings.solvent_lambda_settings settings["engine_settings"] = prot_settings.engine_settings settings["integrator_settings"] = prot_settings.integrator_settings - settings["equil_simulation_settings"] = ( - prot_settings.solvent_equil_simulation_settings - ) + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings settings["simulation_settings"] = prot_settings.solvent_simulation_settings settings["output_settings"] = prot_settings.solvent_output_settings diff --git a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py index c7f0b1d7..5819b232 100644 --- a/openfe/protocols/openmm_afe/equil_solvation_afe_method.py +++ b/openfe/protocols/openmm_afe/equil_solvation_afe_method.py @@ -27,6 +27,7 @@ Acknowledgements `espaloma_charge `_ """ + from __future__ import annotations import pathlib @@ -46,15 +47,22 @@ import uuid from gufe import ( settings, - ChemicalSystem, SmallMoleculeComponent, - ProteinComponent, SolventComponent + ChemicalSystem, + SmallMoleculeComponent, + ProteinComponent, + SolventComponent, ) from openfe.protocols.openmm_afe.equil_afe_settings import ( AbsoluteSolvationSettings, - OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, - MDSimulationSettings, MDOutputSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, MultiStateOutputSettings, + OpenMMSolvationSettings, + AlchemicalSettings, + LambdaSettings, + MDSimulationSettings, + MDOutputSettings, + MultiStateSimulationSettings, + OpenMMEngineSettings, + IntegratorSettings, + MultiStateOutputSettings, OpenFFPartialChargeSettings, SettingsBaseModel, ) @@ -63,38 +71,50 @@ from .base import BaseAbsoluteUnit from openfe.due import due, Doi -due.cite(Doi("10.5281/zenodo.596504"), - description="Yank", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596504"), + description="Yank", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.48550/ARXIV.2302.06758"), - description="EspalomaCharge", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.48550/ARXIV.2302.06758"), + description="EspalomaCharge", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) -due.cite(Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_afe.equil_solvation_afe_method", - cite_module=True) +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_afe.equil_solvation_afe_method", + cite_module=True, +) logger = logging.getLogger(__name__) class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): - """Dict-like container for the output of a AbsoluteSolvationProtocol - """ + """Dict-like container for the output of a AbsoluteSolvationProtocol""" + def __init__(self, **data): super().__init__(**data) # TODO: Detect when we have extensions and stitch these together? - if any(len(pur_list) > 2 for pur_list - in itertools.chain(self.data['solvent'].values(), self.data['vacuum'].values())): + if any( + len(pur_list) > 2 + for pur_list in itertools.chain( + self.data["solvent"].values(), self.data["vacuum"].values() + ) + ): raise NotImplementedError("Can't stitch together results yet") def get_individual_estimates(self) -> dict[str, list[tuple[Quantity, Quantity]]]: @@ -112,19 +132,15 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): vac_dGs = [] solv_dGs = [] - for pus in self.data['vacuum'].values(): - vac_dGs.append(( - pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error'] - )) + for pus in self.data["vacuum"].values(): + vac_dGs.append((pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"])) - for pus in self.data['solvent'].values(): - solv_dGs.append(( - pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error'] - )) + for pus in self.data["solvent"].values(): + solv_dGs.append( + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + ) - return {'solvent': solv_dGs, 'vacuum': vac_dGs} + return {"solvent": solv_dGs, "vacuum": vac_dGs} def get_estimate(self): """Get the solvation free energy estimate for this calculation. @@ -134,6 +150,7 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): dG : openff.units.Quantity The solvation free energy. This is a Quantity defined with units. """ + def _get_average(estimates): # Get the unit value of the first value in the estimates u = estimates[0][0].u @@ -144,8 +161,8 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): return np.average(dGs) * u individual_estimates = self.get_individual_estimates() - vac_dG = _get_average(individual_estimates['vacuum']) - solv_dG = _get_average(individual_estimates['solvent']) + vac_dG = _get_average(individual_estimates["vacuum"]) + solv_dG = _get_average(individual_estimates["solvent"]) return vac_dG - solv_dG @@ -158,6 +175,7 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): The standard deviation between estimates of the solvation free energy. This is a Quantity defined with units. """ + def _get_stdev(estimates): # Get the unit value of the first value in the estimates u = estimates[0][0].u @@ -168,13 +186,15 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): return np.std(dGs) * u individual_estimates = self.get_individual_estimates() - vac_err = _get_stdev(individual_estimates['vacuum']) - solv_err = _get_stdev(individual_estimates['solvent']) + vac_err = _get_stdev(individual_estimates["vacuum"]) + solv_err = _get_stdev(individual_estimates["solvent"]) # return the combined error return np.sqrt(vac_err**2 + solv_err**2) - def get_forward_and_reverse_energy_analysis(self) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: + def get_forward_and_reverse_energy_analysis( + self, + ) -> dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]]: """ Get the reverse and forward analysis of the free energies. @@ -209,10 +229,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: forward_reverse[key] = [ - pus[0].outputs['forward_and_reverse_energies'] - for pus in self.data[key].values() + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() ] if None in forward_reverse[key]: @@ -249,10 +268,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): # Loop through and get the repeats and get the matrices overlap_stats: dict[str, list[dict[str, npt.NDArray]]] = {} - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: overlap_stats[key] = [ - pus[0].outputs['unit_mbar_overlap'] - for pus in self.data[key].values() + pus[0].outputs["unit_mbar_overlap"] for pus in self.data[key].values() ] return overlap_stats @@ -283,14 +301,12 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): """ repex_stats: dict[str, list[dict[str, npt.NDArray]]] = {} try: - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: repex_stats[key] = [ - pus[0].outputs['replica_exchange_statistics'] - for pus in self.data[key].values() + pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() ] except KeyError: - errmsg = ("Replica exchange statistics were not found, " - "did you run a repex calculation?") + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" raise ValueError(errmsg) return repex_stats @@ -306,9 +322,7 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): the thermodynamic cycle, with lists of replica states timeseries for each repeat of that simulation type. """ - replica_states: dict[str, list[npt.NDArray]] = { - 'solvent': [], 'vacuum': [] - } + replica_states: dict[str, list[npt.NDArray]] = {"solvent": [], "vacuum": []} def is_file(filename: str): p = pathlib.Path(filename) @@ -325,7 +339,7 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): chk = is_file(dir_path / chk).name reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode='r' + storage=nc, checkpoint_storage=chk, open_mode="r" ) retval = np.asarray(reporter.read_replica_thermodynamic_states()) @@ -333,11 +347,11 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): return retval - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: for pus in self.data[key].values(): states = get_replica_state( - pus[0].outputs['nc'], - pus[0].outputs['last_checkpoint'], + pus[0].outputs["nc"], + pus[0].outputs["last_checkpoint"], ) replica_states[key].append(states) @@ -357,10 +371,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): """ equilibration_lengths: dict[str, list[float]] = {} - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: equilibration_lengths[key] = [ - pus[0].outputs['equilibration_iterations'] - for pus in self.data[key].values() + pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() ] return equilibration_lengths @@ -381,10 +394,9 @@ class AbsoluteSolvationProtocolResult(gufe.ProtocolResult): """ production_lengths: dict[str, list[float]] = {} - for key in ['solvent', 'vacuum']: + for key in ["solvent", "vacuum"]: production_lengths[key] = [ - pus[0].outputs['production_iterations'] - for pus in self.data[key].values() + pus[0].outputs["production_iterations"] for pus in self.data[key].values() ] return production_lengths @@ -402,6 +414,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol): :class:`openfe.protocols.openmm_afe.AbsoluteSolvationVacuumUnit` :class:`openfe.protocols.openmm_afe.AbsoluteSolvationSolventUnit` """ + result_cls = AbsoluteSolvationProtocolResult _settings_cls = AbsoluteSolvationSettings _settings: AbsoluteSolvationSettings @@ -423,7 +436,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol): protocol_repeats=3, solvent_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings(), vacuum_forcefield_settings=settings.OpenMMSystemGeneratorFFSettings( - nonbonded_method='nocutoff', + nonbonded_method="nocutoff", ), thermo_settings=settings.ThermoSettings( temperature=298.15 * unit.kelvin, @@ -452,10 +465,10 @@ class AbsoluteSolvationProtocol(gufe.Protocol): production_length=0.5 * unit.nanosecond, ), solvent_equil_output_settings=MDOutputSettings( - equil_nvt_structure='equil_nvt_structure.pdb', - equil_npt_structure='equil_npt_structure.pdb', - production_trajectory_filename='production_equil.xtc', - log_output='equil_simulation.log', + equil_nvt_structure="equil_nvt_structure.pdb", + equil_npt_structure="equil_npt_structure.pdb", + production_trajectory_filename="production_equil.xtc", + log_output="equil_simulation.log", ), solvent_simulation_settings=MultiStateSimulationSettings( n_replicas=14, @@ -463,8 +476,8 @@ class AbsoluteSolvationProtocol(gufe.Protocol): production_length=10.0 * unit.nanosecond, ), solvent_output_settings=MultiStateOutputSettings( - output_filename='solvent.nc', - checkpoint_storage_filename='solvent_checkpoint.nc', + output_filename="solvent.nc", + checkpoint_storage_filename="solvent_checkpoint.nc", ), vacuum_equil_simulation_settings=MDSimulationSettings( equilibration_length_nvt=None, @@ -473,9 +486,9 @@ class AbsoluteSolvationProtocol(gufe.Protocol): ), vacuum_equil_output_settings=MDOutputSettings( equil_nvt_structure=None, - equil_npt_structure='equil_structure.pdb', - production_trajectory_filename='production_equil.xtc', - log_output='equil_simulation.log', + equil_npt_structure="equil_structure.pdb", + production_trajectory_filename="production_equil.xtc", + log_output="equil_simulation.log", ), vacuum_simulation_settings=MultiStateSimulationSettings( n_replicas=14, @@ -483,14 +496,15 @@ class AbsoluteSolvationProtocol(gufe.Protocol): production_length=2.0 * unit.nanosecond, ), vacuum_output_settings=MultiStateOutputSettings( - output_filename='vacuum.nc', - checkpoint_storage_filename='vacuum_checkpoint.nc' + output_filename="vacuum.nc", + checkpoint_storage_filename="vacuum_checkpoint.nc", ), - ) + ) # fmt: skip @staticmethod def _validate_endstates( - stateA: ChemicalSystem, stateB: ChemicalSystem, + stateA: ChemicalSystem, + stateB: ChemicalSystem, ) -> None: """ A solvent transformation is defined (in terms of gufe components) @@ -528,8 +542,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol): """ # Check that there are no protein components if stateA.contains(ProteinComponent) or stateB.contains(ProteinComponent): - errmsg = ("Protein components are not allowed for " - "absolute solvation free energies.") + errmsg = "Protein components are not allowed for absolute solvation free energies." raise ValueError(errmsg) # Check that there is a solvent component in both end states @@ -569,14 +582,13 @@ class AbsoluteSolvationProtocol(gufe.Protocol): # If there are any alchemical Components in state B if len(diff[1]) > 0: - errmsg = ("Components appearing in state B are not " - "currently supported") + errmsg = "Components appearing in state B are not currently supported" raise ValueError(errmsg) @staticmethod def _validate_lambda_schedule( - lambda_settings: LambdaSettings, - simulation_settings: MultiStateSimulationSettings, + lambda_settings: LambdaSettings, + simulation_settings: MultiStateSimulationSettings, ) -> None: """ Checks that the lambda schedule is set up correctly. @@ -612,14 +624,17 @@ class AbsoluteSolvationProtocol(gufe.Protocol): "Components elec, vdw, and restraints must have equal amount" f" of lambda windows. Got {len(lambda_elec)} elec lambda" f" windows, {len(lambda_vdw)} vdw lambda windows, and" - f"{len(lambda_restraints)} restraints lambda windows.") + f"{len(lambda_restraints)} restraints lambda windows." + ) raise ValueError(errmsg) # Ensure that number of overall lambda windows matches number of lambda # windows for individual components if n_replicas != len(lambda_vdw): - errmsg = (f"Number of replicas {n_replicas} does not equal the" - f" number of lambda windows {len(lambda_vdw)}") + errmsg = ( + f"Number of replicas {n_replicas} does not equal the" + f" number of lambda windows {len(lambda_vdw)}" + ) raise ValueError(errmsg) # Check if there are lambda windows with naked charges @@ -629,15 +644,18 @@ class AbsoluteSolvationProtocol(gufe.Protocol): "There are states along this lambda schedule " "where there are atoms with charges but no LJ " f"interactions: lambda {inx}: " - f"elec {lam} vdW {lambda_vdw[inx]}") + f"elec {lam} vdW {lambda_vdw[inx]}" + ) raise ValueError(errmsg) # Check if there are lambda windows with non-zero restraints if len([r for r in lambda_restraints if r != 0]) > 0: - wmsg = ("Non-zero restraint lambdas applied. The absolute " - "solvation protocol doesn't apply restraints, " - "therefore restraints won't be applied. " - f"Given lambda_restraints: {lambda_restraints}") + wmsg = ( + "Non-zero restraint lambdas applied. The absolute " + "solvation protocol doesn't apply restraints, " + "therefore restraints won't be applied. " + f"Given lambda_restraints: {lambda_restraints}" + ) logger.warning(wmsg) warnings.warn(wmsg) @@ -667,7 +685,7 @@ class AbsoluteSolvationProtocol(gufe.Protocol): # Validate the lambda schedule for solv_sets in ( self.settings.solvent_simulation_settings, - self.settings.vacuum_simulation_settings + self.settings.vacuum_simulation_settings, ): self._validate_lambda_schedule( self.settings.lambda_settings, @@ -682,16 +700,16 @@ class AbsoluteSolvationProtocol(gufe.Protocol): system_validation.validate_solvent(stateA, solv_nonbonded_method) # Gas phase is always gas phase - if vac_nonbonded_method.lower() != 'nocutoff': - errmsg = ("Only the nocutoff nonbonded_method is supported for " - f"vacuum calculations, {vac_nonbonded_method} was " - "passed") + if vac_nonbonded_method.lower() != "nocutoff": + errmsg = ( + "Only the nocutoff nonbonded_method is supported for " + f"vacuum calculations, {vac_nonbonded_method} was " + "passed" + ) raise ValueError(errmsg) # Validate solvation settings - settings_validation.validate_openmm_solvation_settings( - self.settings.solvation_settings - ) + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) # Check vacuum equilibration MD settings is 0 ns nvt_time = self.settings.vacuum_equil_simulation_settings.equilibration_length_nvt @@ -703,15 +721,14 @@ class AbsoluteSolvationProtocol(gufe.Protocol): # Validate integrator things settings_validation.validate_timestep( self.settings.vacuum_forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep + self.settings.integrator_settings.timestep, ) settings_validation.validate_timestep( self.settings.solvent_forcefield_settings.hydrogen_mass, - self.settings.integrator_settings.timestep + self.settings.integrator_settings.timestep, ) - def _create( self, stateA: ChemicalSystem, @@ -720,17 +737,16 @@ class AbsoluteSolvationProtocol(gufe.Protocol): extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # Validate inputs - self.validate( - stateA=stateA, stateB=stateB, mapping=mapping, extends=extends - ) + self.validate(stateA=stateA, stateB=stateB, mapping=mapping, extends=extends) # Get the alchemical components alchem_comps = system_validation.get_alchemical_components( - stateA, stateB, + stateA, + stateB, ) # Get the name of the alchemical species - alchname = alchem_comps['stateA'][0].name + alchname = alchem_comps["stateA"][0].name # Create list units for vacuum and solvent transforms solvent_units = [ @@ -739,9 +755,9 @@ class AbsoluteSolvationProtocol(gufe.Protocol): stateA=stateA, stateB=stateB, alchemical_components=alchem_comps, - generation=0, repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} solvent leg: " - f"repeat {i} generation 0"), + generation=0, + repeat_id=int(uuid.uuid4()), + name=(f"Absolute Solvation, {alchname} solvent leg: repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] @@ -754,9 +770,9 @@ class AbsoluteSolvationProtocol(gufe.Protocol): stateA=stateA, stateB=stateB, alchemical_components=alchem_comps, - generation=0, repeat_id=int(uuid.uuid4()), - name=(f"Absolute Solvation, {alchname} vacuum leg: " - f"repeat {i} generation 0"), + generation=0, + repeat_id=int(uuid.uuid4()), + name=(f"Absolute Solvation, {alchname} vacuum leg: repeat {i} generation 0"), ) for i in range(self.settings.protocol_repeats) ] @@ -775,19 +791,20 @@ class AbsoluteSolvationProtocol(gufe.Protocol): for pu in d.protocol_unit_results: if not pu.ok(): continue - if pu.outputs['simtype'] == 'solvent': - unsorted_solvent_repeats[pu.outputs['repeat_id']].append(pu) + if pu.outputs["simtype"] == "solvent": + unsorted_solvent_repeats[pu.outputs["repeat_id"]].append(pu) else: - unsorted_vacuum_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_vacuum_repeats[pu.outputs["repeat_id"]].append(pu) repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { - 'solvent': {}, 'vacuum': {}, + "solvent": {}, + "vacuum": {}, } for k, v in unsorted_solvent_repeats.items(): - repeats['solvent'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_vacuum_repeats.items(): - repeats['vacuum'][str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats["vacuum"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) return repeats @@ -795,6 +812,7 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): """ Protocol Unit for the vacuum phase of an absolute solvation free energy """ + simtype = "vacuum" def _get_components(self): @@ -815,11 +833,10 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): is equivalent to the alchemical components in stateA (since we only allow for disappearing ligands). """ - stateA = self._inputs['stateA'] - alchem_comps = self._inputs['alchemical_components'] + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] - off_comps = {m: m.to_openff() - for m in alchem_comps['stateA']} + off_comps = {m: m.to_openff() for m in alchem_comps["stateA"]} _, prot_comp, _ = system_validation.get_components(stateA) @@ -851,21 +868,21 @@ class AbsoluteSolvationVacuumUnit(BaseAbsoluteUnit): * simulation_settings : SimulationSettings * output_settings: MultiStateOutputSettings """ - prot_settings = self._inputs['protocol'].settings + prot_settings = self._inputs["protocol"].settings settings = {} - settings['forcefield_settings'] = prot_settings.vacuum_forcefield_settings - settings['thermo_settings'] = prot_settings.thermo_settings - settings['charge_settings'] = prot_settings.partial_charge_settings - settings['solvation_settings'] = prot_settings.solvation_settings - settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.vacuum_engine_settings - settings['integrator_settings'] = prot_settings.integrator_settings - settings['equil_simulation_settings'] = prot_settings.vacuum_equil_simulation_settings - settings['equil_output_settings'] = prot_settings.vacuum_equil_output_settings - settings['simulation_settings'] = prot_settings.vacuum_simulation_settings - settings['output_settings'] = prot_settings.vacuum_output_settings + settings["forcefield_settings"] = prot_settings.vacuum_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.vacuum_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.vacuum_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.vacuum_equil_output_settings + settings["simulation_settings"] = prot_settings.vacuum_simulation_settings + settings["output_settings"] = prot_settings.vacuum_output_settings return settings @@ -874,6 +891,7 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): """ Protocol Unit for the solvent phase of an absolute solvation free energy """ + simtype = "solvent" def _get_components(self): @@ -891,8 +909,8 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): small_mols : dict[SmallMoleculeComponent: OFFMolecule] SmallMoleculeComponents to add to the system. """ - stateA = self._inputs['stateA'] - alchem_comps = self._inputs['alchemical_components'] + stateA = self._inputs["stateA"] + alchem_comps = self._inputs["alchemical_components"] solv_comp, prot_comp, small_mols = system_validation.get_components(stateA) off_comps = {m: m.to_openff() for m in small_mols} @@ -925,20 +943,20 @@ class AbsoluteSolvationSolventUnit(BaseAbsoluteUnit): * simulation_settings : MultiStateSimulationSettings * output_settings: MultiStateOutputSettings """ - prot_settings = self._inputs['protocol'].settings + prot_settings = self._inputs["protocol"].settings settings = {} - settings['forcefield_settings'] = prot_settings.solvent_forcefield_settings - settings['thermo_settings'] = prot_settings.thermo_settings - settings['charge_settings'] = prot_settings.partial_charge_settings - settings['solvation_settings'] = prot_settings.solvation_settings - settings['alchemical_settings'] = prot_settings.alchemical_settings - settings['lambda_settings'] = prot_settings.lambda_settings - settings['engine_settings'] = prot_settings.solvent_engine_settings - settings['integrator_settings'] = prot_settings.integrator_settings - settings['equil_simulation_settings'] = prot_settings.solvent_equil_simulation_settings - settings['equil_output_settings'] = prot_settings.solvent_equil_output_settings - settings['simulation_settings'] = prot_settings.solvent_simulation_settings - settings['output_settings'] = prot_settings.solvent_output_settings + settings["forcefield_settings"] = prot_settings.solvent_forcefield_settings + settings["thermo_settings"] = prot_settings.thermo_settings + settings["charge_settings"] = prot_settings.partial_charge_settings + settings["solvation_settings"] = prot_settings.solvation_settings + settings["alchemical_settings"] = prot_settings.alchemical_settings + settings["lambda_settings"] = prot_settings.lambda_settings + settings["engine_settings"] = prot_settings.solvent_engine_settings + settings["integrator_settings"] = prot_settings.integrator_settings + settings["equil_simulation_settings"] = prot_settings.solvent_equil_simulation_settings + settings["equil_output_settings"] = prot_settings.solvent_equil_output_settings + settings["simulation_settings"] = prot_settings.solvent_simulation_settings + settings["output_settings"] = prot_settings.solvent_output_settings return settings diff --git a/openfe/protocols/openmm_md/plain_md_methods.py b/openfe/protocols/openmm_md/plain_md_methods.py index 64bde6ab..bed130e9 100644 --- a/openfe/protocols/openmm_md/plain_md_methods.py +++ b/openfe/protocols/openmm_md/plain_md_methods.py @@ -8,6 +8,7 @@ This module implements the necessary methodology tools to run an MD simulation using OpenMM tools. """ + from __future__ import annotations import logging @@ -34,19 +35,25 @@ from gufe import ( from gufe.settings.typing import KelvinQuantity from openfe.protocols.openmm_utils.omm_settings import ( BasePartialChargeSettings, - FemtosecondQuantity + FemtosecondQuantity, ) from openfe.protocols.openmm_md.plain_md_settings import ( PlainMDProtocolSettings, OpenFFPartialChargeSettings, - OpenMMSolvationSettings, OpenMMEngineSettings, - IntegratorSettings, MDSimulationSettings, MDOutputSettings, + OpenMMSolvationSettings, + OpenMMEngineSettings, + IntegratorSettings, + MDSimulationSettings, + MDOutputSettings, ) from openff.toolkit.topology import Molecule as OFFMolecule from openfe.protocols.openmm_utils import ( - system_validation, settings_validation, system_creation, - charge_generation, omm_compute + system_validation, + settings_validation, + system_creation, + charge_generation, + omm_compute, ) logger = logging.getLogger(__name__) @@ -59,6 +66,7 @@ class PlainMDProtocolResult(gufe.ProtocolResult): Provides access to simulation outputs including the pre-minimized system PDB and production trajectory files. """ + def __init__(self, **data): super().__init__(**data) # data is mapping of str(repeat_id): list[protocolunitresults] @@ -89,7 +97,7 @@ class PlainMDProtocolResult(gufe.ProtocolResult): traj : list[pathlib.Path] list of paths (pathlib.Path) to the simulation trajectory """ - traj = [pus[0].outputs['nc'] for pus in self.data.values()] + traj = [pus[0].outputs["nc"] for pus in self.data.values()] return traj @@ -102,7 +110,7 @@ class PlainMDProtocolResult(gufe.ProtocolResult): pdbs : list[pathlib.Path] list of paths (pathlib.Path) to the pdb files """ - pdbs = [pus[0].outputs['system_pdb'] for pus in self.data.values()] + pdbs = [pus[0].outputs["system_pdb"] for pus in self.data.values()] return pdbs @@ -118,6 +126,7 @@ class PlainMDProtocol(gufe.Protocol): :class:`openfe.protocols.openmm_md.PlainMDProtocolUnit` :class:`openfe.protocols.openmm_md.PlainMDProtocolResult` """ + result_cls = PlainMDProtocolResult _settings_cls = PlainMDProtocolSettings _settings: PlainMDProtocolSettings @@ -155,11 +164,11 @@ class PlainMDProtocol(gufe.Protocol): ) def _create( - self, - stateA: ChemicalSystem, - stateB: ChemicalSystem, - mapping: Optional[dict[str, gufe.ComponentMapping]] = None, - extends: Optional[gufe.ProtocolDAGResult] = None, + self, + stateA: ChemicalSystem, + stateB: ChemicalSystem, + mapping: Optional[dict[str, gufe.ComponentMapping]] = None, + extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: Extensions? if extends: @@ -173,9 +182,7 @@ class PlainMDProtocol(gufe.Protocol): system_validation.validate_protein(stateA) # Validate solvation settings - settings_validation.validate_openmm_solvation_settings( - self.settings.solvation_settings - ) + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) # actually create and return Units # TODO: Deal with multiple ProteinComponents @@ -187,25 +194,27 @@ class PlainMDProtocol(gufe.Protocol): if comp is not None: comp_type = comp.__class__.__name__ if len(comp.name) == 0: - comp_name = 'NoName' + comp_name = "NoName" else: comp_name = comp.name system_name += f" {comp_type}:{comp_name}" # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats - units = [PlainMDProtocolUnit( - protocol=self, - stateA=stateA, - generation=0, repeat_id=int(uuid.uuid4()), - name=f'{system_name} repeat {i} generation 0') - for i in range(n_repeats)] + units = [ + PlainMDProtocolUnit( + protocol=self, + stateA=stateA, + generation=0, + repeat_id=int(uuid.uuid4()), + name=f"{system_name} repeat {i} generation 0", + ) + for i in range(n_repeats) + ] return units - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: # result units will have a repeat_id and generations within this # repeat_id # first group according to repeat_id @@ -216,12 +225,12 @@ class PlainMDProtocol(gufe.Protocol): if not pu.ok(): continue - unsorted_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) # then sort by generation within each repeat_id list repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) # returns a dict of repeat_id: sorted list of ProtocolUnitResult return repeats @@ -266,23 +275,24 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): protocol=protocol, stateA=stateA, repeat_id=repeat_id, - generation=generation + generation=generation, ) @staticmethod - def _run_MD(simulation: openmm.app.Simulation, - positions: omm_unit.Quantity, - simulation_settings: MDSimulationSettings, - output_settings: MDOutputSettings, - temperature: KelvinQuantity, - barostat_frequency: Quantity, - timestep: FemtosecondQuantity, - equil_steps_nvt: Optional[int], - equil_steps_npt: int, - prod_steps: int, - verbose=True, - shared_basepath=None) -> None: - + def _run_MD( + simulation: openmm.app.Simulation, + positions: omm_unit.Quantity, + simulation_settings: MDSimulationSettings, + output_settings: MDOutputSettings, + temperature: KelvinQuantity, + barostat_frequency: Quantity, + timestep: FemtosecondQuantity, + equil_steps_nvt: Optional[int], + equil_steps_npt: int, + prod_steps: int, + verbose=True, + shared_basepath=None, + ) -> None: """ Energy minimization, Equilibration and Production MD to be reused in multiple protocols @@ -318,24 +328,26 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): """ if shared_basepath is None: - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") simulation.context.setPositions(positions) # minimize if verbose: logger.info("minimizing systems") - simulation.minimizeEnergy( - maxIterations=simulation_settings.minimization_steps - ) + simulation.minimizeEnergy(maxIterations=simulation_settings.minimization_steps) # Get the sub selection of the system to save coords for - selection_indices = mdtraj.Topology.from_openmm( - simulation.topology).select(output_settings.output_indices) + selection_indices = mdtraj.Topology.from_openmm(simulation.topology).select( + output_settings.output_indices + ) - positions = to_openmm(from_openmm( - simulation.context.getState(getPositions=True, - enforcePeriodicBox=False - ).getPositions())) + positions = to_openmm( + from_openmm( + simulation.context.getState( + getPositions=True, enforcePeriodicBox=False + ).getPositions() + ) + ) # Store subset of atoms, specified in input, as PDB file mdtraj_top = mdtraj.Topology.from_openmm(simulation.topology) traj = mdtraj.Trajectory( @@ -344,9 +356,7 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): ) if output_settings.minimized_structure: - traj.save_pdb( - shared_basepath / output_settings.minimized_structure - ) + traj.save_pdb(shared_basepath / output_settings.minimized_structure) # equilibrate # NVT equilibration if equil_steps_nvt: @@ -355,66 +365,64 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): # Set barostat frequency to zero for NVT for x in simulation.context.getSystem().getForces(): - if x.getName() == 'MonteCarloBarostat': + if x.getName() == "MonteCarloBarostat": x.setFrequency(0) - simulation.context.setVelocitiesToTemperature( - to_openmm(temperature)) + simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) t0 = time.time() simulation.step(equil_steps_nvt) t1 = time.time() if verbose: - logger.info( - f"Completed NVT equilibration in {t1 - t0} seconds") + logger.info(f"Completed NVT equilibration in {t1 - t0} seconds") # Save last frame NVT equilibration positions = to_openmm( - from_openmm(simulation.context.getState( - getPositions=True, enforcePeriodicBox=False - ).getPositions())) + from_openmm( + simulation.context.getState( + getPositions=True, enforcePeriodicBox=False + ).getPositions() + ) + ) traj = mdtraj.Trajectory( positions[selection_indices, :], mdtraj_top.subset(selection_indices), ) if output_settings.equil_nvt_structure is not None: - traj.save_pdb( - shared_basepath / output_settings.equil_nvt_structure - ) + traj.save_pdb(shared_basepath / output_settings.equil_nvt_structure) # NPT equilibration if verbose: logger.info("Running NPT equilibration") - simulation.context.setVelocitiesToTemperature( - to_openmm(temperature)) + simulation.context.setVelocitiesToTemperature(to_openmm(temperature)) # Enable the barostat for NPT for x in simulation.context.getSystem().getForces(): - if x.getName() == 'MonteCarloBarostat': + if x.getName() == "MonteCarloBarostat": x.setFrequency(barostat_frequency.m) t0 = time.time() simulation.step(equil_steps_npt) t1 = time.time() if verbose: - logger.info( - f"Completed NPT equilibration in {t1 - t0} seconds") + logger.info(f"Completed NPT equilibration in {t1 - t0} seconds") # Save last frame NPT equilibration positions = to_openmm( - from_openmm(simulation.context.getState( - getPositions=True, enforcePeriodicBox=False - ).getPositions())) + from_openmm( + simulation.context.getState( + getPositions=True, enforcePeriodicBox=False + ).getPositions() + ) + ) traj = mdtraj.Trajectory( positions[selection_indices, :], mdtraj_top.subset(selection_indices), ) if output_settings.equil_npt_structure is not None: - traj.save_pdb( - shared_basepath / output_settings.equil_npt_structure - ) + traj.save_pdb(shared_basepath / output_settings.equil_npt_structure) # production if verbose: @@ -436,33 +444,34 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): if output_settings.production_trajectory_filename: xtc_reporter = XTCReporter( - file=str( - shared_basepath / - output_settings.production_trajectory_filename), + file=str(shared_basepath / output_settings.production_trajectory_filename), reportInterval=write_interval, - atomSubset=selection_indices + atomSubset=selection_indices, ) simulation.reporters.append(xtc_reporter) if output_settings.checkpoint_storage_filename: - simulation.reporters.append(openmm.app.CheckpointReporter( - file=str( - shared_basepath / - output_settings.checkpoint_storage_filename), - reportInterval=checkpoint_interval)) + simulation.reporters.append( + openmm.app.CheckpointReporter( + file=str(shared_basepath / output_settings.checkpoint_storage_filename), + reportInterval=checkpoint_interval, + ) + ) if output_settings.log_output: - simulation.reporters.append(openmm.app.StateDataReporter( - str(shared_basepath / output_settings.log_output), - checkpoint_interval, - step=True, - time=True, - potentialEnergy=True, - kineticEnergy=True, - totalEnergy=True, - temperature=True, - volume=True, - density=True, - speed=True, - )) + simulation.reporters.append( + openmm.app.StateDataReporter( + str(shared_basepath / output_settings.log_output), + checkpoint_interval, + step=True, + time=True, + potentialEnergy=True, + kineticEnergy=True, + totalEnergy=True, + temperature=True, + volume=True, + density=True, + speed=True, + ) + ) t0 = time.time() simulation.step(prod_steps) t1 = time.time() @@ -498,9 +507,9 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): nagl_model=charge_settings.nagl_model, ) - def run(self, *, dry=False, verbose=True, - scratch_basepath=None, - shared_basepath=None) -> dict[str, Any]: + def run( + self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: """Run the MD simulation. Parameters @@ -532,15 +541,17 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): self.logger.info("Creating system") if shared_basepath is None: # use cwd - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") # 0. General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: PlainMDProtocolSettings = self._inputs['protocol'].settings - stateA = self._inputs['stateA'] + protocol_settings: PlainMDProtocolSettings = self._inputs["protocol"].settings + stateA = self._inputs["stateA"] - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings + forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( + protocol_settings.forcefield_settings + ) thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings solvation_settings: OpenMMSolvationSettings = protocol_settings.solvation_settings charge_settings: BasePartialChargeSettings = protocol_settings.partial_charge_settings @@ -550,25 +561,26 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): integrator_settings = protocol_settings.integrator_settings # is the timestep good for the mass? - settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, timestep - ) + settings_validation.validate_timestep(forcefield_settings.hydrogen_mass, timestep) if sim_settings.equilibration_length_nvt is not None: equil_steps_nvt = settings_validation.get_simsteps( sim_length=sim_settings.equilibration_length_nvt, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) else: equil_steps_nvt = None equil_steps_npt = settings_validation.get_simsteps( sim_length=sim_settings.equilibration_length, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) prod_steps = settings_validation.get_simsteps( sim_length=sim_settings.production_length, - timestep=timestep, mc_steps=1, + timestep=timestep, + mc_steps=1, ) solvent_comp, protein_comp, small_mols = system_validation.get_components(stateA) @@ -602,9 +614,7 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): # Force creation of smc templates so we can solvate later for mol in smc_components.values(): - system_generator.create_system( - mol.to_topology().to_openmm(), molecules=[mol] - ) + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # c. get OpenMM Modeller + a resids dictionary for each component stateA_modeller, comp_resids = system_creation.get_omm_modeller( @@ -618,9 +628,7 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): # d. get topology & positions # Note: roundtrip positions to remove vec3 issues stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm( - from_openmm(stateA_modeller.getPositions()) - ) + stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) # e. create the stateA System stateA_system = system_generator.create_system( @@ -630,19 +638,17 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): # f. Save pdb of entire system if output_settings.preminimized_structure: - with open( - shared_basepath / - output_settings.preminimized_structure, "w") as f: + with open(shared_basepath / output_settings.preminimized_structure, "w") as f: openmm.app.PDBFile.writeFile( stateA_topology, stateA_positions, file=f, keepIds=True ) # 10. Get platform - restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff' + restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" platform = omm_compute.get_openmm_platform( platform_name=protocol_settings.engine_settings.compute_platform, gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu + restrict_cpu_count=restrict_cpu, ) # 11. Set the integrator @@ -653,14 +659,10 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): ) simulation = openmm.app.Simulation( - stateA_modeller.topology, - stateA_system, - integrator, - platform=platform + stateA_modeller.topology, stateA_system, integrator, platform=platform ) try: - if not dry: # pragma: no-cover self._run_MD( simulation, @@ -677,20 +679,19 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): ) finally: - if not dry: del integrator, simulation if not dry: # pragma: no-cover output = { - 'system_pdb': shared_basepath / output_settings.preminimized_structure, - 'minimized_pdb': shared_basepath / output_settings.minimized_structure, - 'nc': shared_basepath / output_settings.production_trajectory_filename, - 'last_checkpoint': shared_basepath / output_settings.checkpoint_storage_filename, + "system_pdb": shared_basepath / output_settings.preminimized_structure, + "minimized_pdb": shared_basepath / output_settings.minimized_structure, + "nc": shared_basepath / output_settings.production_trajectory_filename, + "last_checkpoint": shared_basepath / output_settings.checkpoint_storage_filename, } # The checkpoint file can not exist if frequency > sim length - if not output['last_checkpoint'].exists(): - output['last_checkpoint'] = None + if not output["last_checkpoint"].exists(): + output["last_checkpoint"] = None # The NVT PDB can be ommitted if we don't run the simulation # Note: we could also just check the file exist @@ -698,27 +699,28 @@ class PlainMDProtocolUnit(gufe.ProtocolUnit): output_settings.equil_nvt_structure and sim_settings.equilibration_length_nvt is not None ): - output['nvt_equil_pdb'] = shared_basepath / output_settings.equil_nvt_structure + output["nvt_equil_pdb"] = shared_basepath / output_settings.equil_nvt_structure else: - output['nvt_equil_pdb'] = None + output["nvt_equil_pdb"] = None if output_settings.equil_npt_structure: - output['npt_equil_pdb'] = shared_basepath / output_settings.equil_npt_structure + output["npt_equil_pdb"] = shared_basepath / output_settings.equil_npt_structure return output else: - return {'debug': {'system': stateA_system}} + return {"debug": {"system": stateA_system}} def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], - **outputs + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], + **outputs, } diff --git a/openfe/protocols/openmm_md/plain_md_settings.py b/openfe/protocols/openmm_md/plain_md_settings.py index 568e54f2..34512412 100644 --- a/openfe/protocols/openmm_md/plain_md_settings.py +++ b/openfe/protocols/openmm_md/plain_md_settings.py @@ -7,19 +7,19 @@ This module implements the settings necessary to run MD simulations using :class:`openfe.protocols.openmm_md.plain_md_methods.py` """ + from pydantic import ConfigDict, field_validator from openfe.protocols.openmm_utils.omm_settings import ( Settings, OpenMMSolvationSettings, OpenMMEngineSettings, MDSimulationSettings, - IntegratorSettings, MDOutputSettings, + IntegratorSettings, + MDOutputSettings, OpenFFPartialChargeSettings, ) -from gufe.settings import ( - SettingsBaseModel, - OpenMMSystemGeneratorFFSettings -) +from gufe.settings import SettingsBaseModel, OpenMMSystemGeneratorFFSettings + class PlainMDProtocolSettings(Settings): model_config = ConfigDict(arbitrary_types_allowed=True) @@ -29,7 +29,7 @@ class PlainMDProtocolSettings(Settings): Number of independent MD runs to perform. """ - @field_validator('protocol_repeats') + @field_validator("protocol_repeats") def must_be_positive(cls, v): if v <= 0: errmsg = f"protocol_repeats must be a positive value, got {v}." @@ -52,4 +52,3 @@ class PlainMDProtocolSettings(Settings): # Simulations output settings output_settings: MDOutputSettings - diff --git a/openfe/protocols/openmm_rfe/__init__.py b/openfe/protocols/openmm_rfe/__init__.py index eb50c071..9bb203e1 100644 --- a/openfe/protocols/openmm_rfe/__init__.py +++ b/openfe/protocols/openmm_rfe/__init__.py @@ -12,4 +12,3 @@ from .equil_rfe_methods import ( RelativeHybridTopologyProtocolResult, RelativeHybridTopologyProtocolUnit, ) - diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py b/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py index a8e764dd..188ccb7a 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/lambdaprotocol.py @@ -2,6 +2,9 @@ # License: MIT # OpenFE note: eventually we aim to move this to openmmtools where possible +# turn off formatting since this is mostly vendored code +# fmt: off + import numpy as np import warnings import copy diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py index e7426fa7..6608ba81 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/multistate.py @@ -5,6 +5,8 @@ This is adapted from Perses: https://github.com/choderalab/perses/ See here for the license: https://github.com/choderalab/perses/blob/main/LICENSE """ +# turn off formatting since this is mostly vendored code +# fmt: off import copy import warnings diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/relative.py b/openfe/protocols/openmm_rfe/_rfe_utils/relative.py index d9cdb872..08f7ac97 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/relative.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/relative.py @@ -3,6 +3,9 @@ # The eventual goal is to move a version of this towards openmmtools # LICENSE: MIT +# turn off formatting since this is mostly vendored code +# fmt: off + import logging import openmm from openmm import unit, app diff --git a/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py b/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py index c7078db0..12f94705 100644 --- a/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py +++ b/openfe/protocols/openmm_rfe/_rfe_utils/topologyhelpers.py @@ -4,6 +4,9 @@ # building toolsets. # LICENSE: MIT +# turn off formatting since this is mostly vendored code +# fmt: off + from copy import deepcopy import itertools import logging diff --git a/openfe/protocols/openmm_rfe/equil_rfe_methods.py b/openfe/protocols/openmm_rfe/equil_rfe_methods.py index eb505729..5f821582 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_methods.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_methods.py @@ -19,6 +19,7 @@ Acknowledgements This Protocol is based on, and leverages components originating from the Perses toolkit (https://github.com/choderalab/perses). """ + from __future__ import annotations import os @@ -45,23 +46,37 @@ from rdkit import Chem import gufe from gufe import ( - settings, ChemicalSystem, LigandAtomMapping, Component, ComponentMapping, - SmallMoleculeComponent, SolventComponent, ProteinComponent, + settings, + ChemicalSystem, + LigandAtomMapping, + Component, + ComponentMapping, + SmallMoleculeComponent, + SolventComponent, + ProteinComponent, ) from .equil_rfe_settings import ( RelativeHybridTopologyProtocolSettings, - OpenMMSolvationSettings, AlchemicalSettings, LambdaSettings, - MultiStateSimulationSettings, OpenMMEngineSettings, - IntegratorSettings, MultiStateOutputSettings, + OpenMMSolvationSettings, + AlchemicalSettings, + LambdaSettings, + MultiStateSimulationSettings, + OpenMMEngineSettings, + IntegratorSettings, + MultiStateOutputSettings, OpenFFPartialChargeSettings, ) from openfe.protocols.openmm_utils.omm_settings import ( BasePartialChargeSettings, ) from ..openmm_utils import ( - system_validation, settings_validation, system_creation, - multistate_analysis, charge_generation, omm_compute, + system_validation, + settings_validation, + system_creation, + multistate_analysis, + charge_generation, + omm_compute, ) from . import _rfe_utils from ...utils import without_oechem_backend, log_system_probe @@ -72,20 +87,26 @@ from openfe.due import due, Doi logger = logging.getLogger(__name__) -due.cite(Doi("10.5281/zenodo.1297683"), - description="Perses", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.1297683"), + description="Perses", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) -due.cite(Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) -due.cite(Doi("10.1371/journal.pcbi.1005659"), - description="OpenMM", - path="openfe.protocols.openmm_rfe.equil_rfe_methods", - cite_module=True) +due.cite( + Doi("10.1371/journal.pcbi.1005659"), + description="OpenMM", + path="openfe.protocols.openmm_rfe.equil_rfe_methods", + cite_module=True, +) def _get_resname(off_mol) -> str: @@ -101,7 +122,7 @@ def _get_alchemical_charge_difference( mapping: LigandAtomMapping, nonbonded_method: str, explicit_charge_correction: bool, - solvent_component: SolventComponent + solvent_component: SolventComponent, ) -> int: """ Checks and returns the difference in formal charge between state A and B. @@ -139,28 +160,34 @@ def _get_alchemical_charge_difference( if abs(difference) > 0: if explicit_charge_correction: if nonbonded_method.lower() != "pme": - errmsg = ("Explicit charge correction when not using PME is " - "not currently supported.") + errmsg = "Explicit charge correction when not using PME is not currently supported." raise ValueError(errmsg) if abs(difference) > 1: - errmsg = (f"A charge difference of {difference} is observed " - "between the end states and an explicit charge " - "correction has been requested. Unfortunately " - "only absolute differences of 1 are supported.") + errmsg = ( + f"A charge difference of {difference} is observed " + "between the end states and an explicit charge " + "correction has been requested. Unfortunately " + "only absolute differences of 1 are supported." + ) raise ValueError(errmsg) - ion = {-1: solvent_component.positive_ion, - 1: solvent_component.negative_ion}[difference] - wmsg = (f"A charge difference of {difference} is observed " - "between the end states. This will be addressed by " - f"transforming a water into a {ion} ion") + ion = {-1: solvent_component.positive_ion, 1: solvent_component.negative_ion}[ + difference + ] + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. This will be addressed by " + f"transforming a water into a {ion} ion" + ) logger.warning(wmsg) warnings.warn(wmsg) else: - wmsg = (f"A charge difference of {difference} is observed " - "between the end states. No charge correction has " - "been requested, please account for this in your " - "final results.") + wmsg = ( + f"A charge difference of {difference} is observed " + "between the end states. No charge correction has " + "been requested, please account for this in your " + "final results." + ) logger.warning(wmsg) warnings.warn(wmsg) @@ -206,10 +233,12 @@ def _validate_alchemical_components( raise ValueError(errmsg) # Check that all alchemical components are mapped & small molecules - mapped = {'stateA': [m.componentA for m in mapping], - 'stateB': [m.componentB for m in mapping]} + mapped = { + "stateA": [m.componentA for m in mapping], + "stateB": [m.componentB for m in mapping], + } - for idx in ['stateA', 'stateB']: + for idx in ["stateA", "stateB"]: if len(alchemical_components[idx]) != len(mapped[idx]): errmsg = f"missing alchemical components in {idx}" raise ValueError(errmsg) @@ -217,9 +246,11 @@ def _validate_alchemical_components( if comp not in mapped[idx]: raise ValueError(f"Unmapped alchemical component {comp}") if not isinstance(comp, SmallMoleculeComponent): # pragma: no-cover - errmsg = ("Transformations involving non " - "SmallMoleculeComponent species {comp} " - "are not currently supported") + errmsg = ( + "Transformations involving non " + "SmallMoleculeComponent species {comp} " + "are not currently supported" + ) raise ValueError(errmsg) # Validate element changes in mappings @@ -236,13 +267,15 @@ def _validate_alchemical_components( f"Ligand B: {j} (element {atomB.GetAtomicNum()})\n" "No mass scaling is attempted in the hybrid topology, " "the average mass of the two atoms will be used in the " - "simulation") + "simulation" + ) logger.warning(wmsg) warnings.warn(wmsg) # TODO: remove this once logging is fixed class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): """Dict-like container for the output of a RelativeHybridTopologyProtocol""" + def __init__(self, **data): super().__init__(**data) # data is mapping of str(repeat_id): list[protocolunitresults] @@ -251,7 +284,7 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): raise NotImplementedError("Can't stitch together results yet") @staticmethod - def compute_mean_estimate(dGs:list[Quantity]) -> Quantity: + def compute_mean_estimate(dGs: list[Quantity]) -> Quantity: u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -269,11 +302,11 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): a Quantity defined with units. """ # TODO: Check this holds up completely for SAMS. - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] return self.compute_mean_estimate(dGs) @staticmethod - def compute_uncertainty(dGs:list[Quantity]) -> Quantity: + def compute_uncertainty(dGs: list[Quantity]) -> Quantity: u = dGs[0].u # convert all values to units of the first value, then take average of magnitude # this would avoid a screwy case where each value was in different units @@ -286,10 +319,9 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): each independent repeat """ - dGs = [pus[0].outputs['unit_estimate'] for pus in self.data.values()] + dGs = [pus[0].outputs["unit_estimate"] for pus in self.data.values()] return self.compute_uncertainty(dGs) - def get_individual_estimates(self) -> list[tuple[Quantity, Quantity]]: """Return a list of tuples containing the individual free energy estimates and associated MBAR errors for each repeat. @@ -301,12 +333,15 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): estimates (first entry) and associated MBAR estimate errors (second entry). """ - dGs = [(pus[0].outputs['unit_estimate'], - pus[0].outputs['unit_estimate_error']) - for pus in self.data.values()] + dGs = [ + (pus[0].outputs["unit_estimate"], pus[0].outputs["unit_estimate_error"]) + for pus in self.data.values() + ] return dGs - def get_forward_and_reverse_energy_analysis(self) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: + def get_forward_and_reverse_energy_analysis( + self, + ) -> list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]: """ Get a list of forward and reverse analysis of the free energies for each repeat using uncorrelated production samples. @@ -335,8 +370,9 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): UserWarning If any of the forward and reverse entries are ``None``. """ - forward_reverse = [pus[0].outputs['forward_and_reverse_energies'] - for pus in self.data.values()] + forward_reverse = [ + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data.values() + ] if None in forward_reverse: wmsg = ( @@ -366,8 +402,7 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): state i in state j """ # Loop through and get the repeats and get the matrices - overlap_stats = [pus[0].outputs['unit_mbar_overlap'] - for pus in self.data.values()] + overlap_stats = [pus[0].outputs["unit_mbar_overlap"] for pus in self.data.values()] return overlap_stats @@ -389,11 +424,11 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): from state i to state j. """ try: - repex_stats = [pus[0].outputs['replica_exchange_statistics'] - for pus in self.data.values()] + repex_stats = [ + pus[0].outputs["replica_exchange_statistics"] for pus in self.data.values() + ] except KeyError: - errmsg = ("Replica exchange statistics were not found, " - "did you run a repex calculation?") + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" raise ValueError(errmsg) return repex_stats @@ -407,6 +442,7 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): replica_states : List[npt.NDArray] List of replica states for each repeat """ + def is_file(filename: str): p = pathlib.Path(filename) if not p.exists(): @@ -417,15 +453,13 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): replica_states = [] for pus in self.data.values(): - nc = is_file(pus[0].outputs['nc']) + nc = is_file(pus[0].outputs["nc"]) dir_path = nc.parents[0] - chk = is_file(dir_path / pus[0].outputs['last_checkpoint']).name + chk = is_file(dir_path / pus[0].outputs["last_checkpoint"]).name reporter = multistate.MultiStateReporter( - storage=nc, checkpoint_storage=chk, open_mode='r' - ) - replica_states.append( - np.asarray(reporter.read_replica_thermodynamic_states()) + storage=nc, checkpoint_storage=chk, open_mode="r" ) + replica_states.append(np.asarray(reporter.read_replica_thermodynamic_states())) reporter.close() return replica_states @@ -439,8 +473,9 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): ------- equilibration_lengths : list[float] """ - equilibration_lengths = [pus[0].outputs['equilibration_iterations'] - for pus in self.data.values()] + equilibration_lengths = [ + pus[0].outputs["equilibration_iterations"] for pus in self.data.values() + ] return equilibration_lengths @@ -453,8 +488,7 @@ class RelativeHybridTopologyProtocolResult(gufe.ProtocolResult): ------- production_lengths : list[float] """ - production_lengths = [pus[0].outputs['production_iterations'] - for pus in self.data.values()] + production_lengths = [pus[0].outputs["production_iterations"] for pus in self.data.values()] return production_lengths @@ -472,6 +506,7 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyResult` :class:`openfe.protocols.openmm_rfe.RelativeHybridTopologyProtocolUnit` """ + result_cls = RelativeHybridTopologyProtocolResult _settings_cls = RelativeHybridTopologyProtocolSettings _settings: RelativeHybridTopologyProtocolSettings @@ -498,7 +533,7 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): ), partial_charge_settings=OpenFFPartialChargeSettings(), solvation_settings=OpenMMSolvationSettings(), - alchemical_settings=AlchemicalSettings(softcore_LJ='gapsys'), + alchemical_settings=AlchemicalSettings(softcore_LJ="gapsys"), lambda_settings=LambdaSettings(), simulation_settings=MultiStateSimulationSettings( equilibration_length=1.0 * unit.nanosecond, @@ -560,10 +595,12 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): if mapping.get_alchemical_charge_difference() != 0: # apply the recommended charge change settings taken from the industry benchmarking as fast settings not validated # - info = ("Charge changing transformation between ligands " - f"{mapping.componentA.name} and {mapping.componentB.name}. " - "A more expensive protocol with 22 lambda windows, sampled " - "for 20 ns each, will be used here.") + info = ( + "Charge changing transformation between ligands " + f"{mapping.componentA.name} and {mapping.componentB.name}. " + "A more expensive protocol with 22 lambda windows, sampled " + "for 20 ns each, will be used here." + ) logger.info(info) protocol_settings.alchemical_settings.explicit_charge_correction = True protocol_settings.simulation_settings.production_length = 20 * unit.nanosecond @@ -576,7 +613,6 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): return protocol_settings - def _create( self, stateA: ChemicalSystem, @@ -589,9 +625,7 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): raise NotImplementedError("Can't extend simulations yet") # Get alchemical components & validate them + mapping - alchem_comps = system_validation.get_alchemical_components( - stateA, stateB - ) + alchem_comps = system_validation.get_alchemical_components(stateA, stateB) _validate_alchemical_components(alchem_comps, mapping) ligandmapping = mapping[0] if isinstance(mapping, list) else mapping @@ -600,31 +634,32 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): system_validation.validate_solvent(stateA, nonbond) # Validate solvation settings - settings_validation.validate_openmm_solvation_settings( - self.settings.solvation_settings - ) + settings_validation.validate_openmm_solvation_settings(self.settings.solvation_settings) # Validate protein component system_validation.validate_protein(stateA) # actually create and return Units - Anames = ','.join(c.name for c in alchem_comps['stateA']) - Bnames = ','.join(c.name for c in alchem_comps['stateB']) + Anames = ",".join(c.name for c in alchem_comps["stateA"]) + Bnames = ",".join(c.name for c in alchem_comps["stateB"]) # our DAG has no dependencies, so just list units n_repeats = self.settings.protocol_repeats - units = [RelativeHybridTopologyProtocolUnit( - protocol=self, - stateA=stateA, stateB=stateB, - ligandmapping=ligandmapping, - generation=0, repeat_id=int(uuid.uuid4()), - name=f'{Anames} to {Bnames} repeat {i} generation 0') - for i in range(n_repeats)] + units = [ + RelativeHybridTopologyProtocolUnit( + protocol=self, + stateA=stateA, + stateB=stateB, + ligandmapping=ligandmapping, + generation=0, + repeat_id=int(uuid.uuid4()), + name=f"{Anames} to {Bnames} repeat {i} generation 0", + ) + for i in range(n_repeats) + ] return units - def _gather( - self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult] - ) -> dict[str, Any]: + def _gather(self, protocol_dag_results: Iterable[gufe.ProtocolDAGResult]) -> dict[str, Any]: # result units will have a repeat_id and generations within this repeat_id # first group according to repeat_id unsorted_repeats = defaultdict(list) @@ -634,12 +669,12 @@ class RelativeHybridTopologyProtocol(gufe.Protocol): if not pu.ok(): continue - unsorted_repeats[pu.outputs['repeat_id']].append(pu) + unsorted_repeats[pu.outputs["repeat_id"]].append(pu) # then sort by generation within each repeat_id list repeats: dict[str, list[gufe.ProtocolUnitResult]] = {} for k, v in unsorted_repeats.items(): - repeats[str(k)] = sorted(v, key=lambda x: x.outputs['generation']) + repeats[str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) # returns a dict of repeat_id: sorted list of ProtocolUnitResult return repeats @@ -691,7 +726,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): stateB=stateB, ligandmapping=ligandmapping, repeat_id=repeat_id, - generation=generation + generation=generation, ) @staticmethod @@ -710,9 +745,9 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): Dictionary of dictionary of OpenFF Molecules to add, keyed by state and SmallMoleculeComponent. """ - for smc, mol in chain(off_small_mols['stateA'], - off_small_mols['stateB'], - off_small_mols['both']): + for smc, mol in chain( + off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] + ): charge_generation.assign_offmol_partial_charges( offmol=mol, overwrite=False, @@ -722,9 +757,9 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): nagl_model=charge_settings.nagl_model, ) - def run(self, *, dry=False, verbose=True, - scratch_basepath=None, - shared_basepath=None) -> dict[str, Any]: + def run( + self, *, dry=False, verbose=True, scratch_basepath=None, shared_basepath=None + ) -> dict[str, Any]: """Run the relative free energy calculation. Parameters @@ -755,20 +790,24 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): if verbose: self.logger.info("Preparing the hybrid topology simulation") if scratch_basepath is None: - scratch_basepath = pathlib.Path('.') + scratch_basepath = pathlib.Path(".") if shared_basepath is None: # use cwd - shared_basepath = pathlib.Path('.') + shared_basepath = pathlib.Path(".") # 0. General setup and settings dependency resolution step # Extract relevant settings - protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs['protocol'].settings - stateA = self._inputs['stateA'] - stateB = self._inputs['stateB'] - mapping = self._inputs['ligandmapping'] + protocol_settings: RelativeHybridTopologyProtocolSettings = self._inputs[ + "protocol" + ].settings + stateA = self._inputs["stateA"] + stateB = self._inputs["stateB"] + mapping = self._inputs["ligandmapping"] - forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = protocol_settings.forcefield_settings + forcefield_settings: settings.OpenMMSystemGeneratorFFSettings = ( + protocol_settings.forcefield_settings + ) thermo_settings: settings.ThermoSettings = protocol_settings.thermo_settings alchem_settings: AlchemicalSettings = protocol_settings.alchemical_settings lambda_settings: LambdaSettings = protocol_settings.lambda_settings @@ -780,8 +819,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # is the timestep good for the mass? settings_validation.validate_timestep( - forcefield_settings.hydrogen_mass, - integrator_settings.timestep + forcefield_settings.hydrogen_mass, integrator_settings.timestep ) # TODO: Also validate various conversions? # Convert various time based inputs to steps/iterations @@ -823,10 +861,13 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # and keep the molecule around to maintain the partial charges off_small_mols: dict[str, list[tuple[SmallMoleculeComponent, OFFMolecule]]] off_small_mols = { - 'stateA': [(mapping.componentA, mapping.componentA.to_openff())], - 'stateB': [(mapping.componentB, mapping.componentB.to_openff())], - 'both': [(m, m.to_openff()) for m in small_mols - if (m != mapping.componentA and m != mapping.componentB)] + "stateA": [(mapping.componentA, mapping.componentA.to_openff())], + "stateB": [(mapping.componentB, mapping.componentB.to_openff())], + "both": [ + (m, m.to_openff()) + for m in small_mols + if (m != mapping.componentA and m != mapping.componentB) + ], } self._assign_partial_charges(charge_settings, off_small_mols) @@ -851,18 +892,16 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # c. force the creation of parameters # This is necessary because we need to have the FF templates # registered ahead of solvating the system. - for smc, mol in chain(off_small_mols['stateA'], - off_small_mols['stateB'], - off_small_mols['both']): - system_generator.create_system(mol.to_topology().to_openmm(), - molecules=[mol]) + for smc, mol in chain( + off_small_mols["stateA"], off_small_mols["stateB"], off_small_mols["both"] + ): + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # c. get OpenMM Modeller + a dictionary of resids for each component stateA_modeller, comp_resids = system_creation.get_omm_modeller( protein_comp=protein_comp, solvent_comp=solvent_comp, - small_mols=dict(chain(off_small_mols['stateA'], - off_small_mols['both'])), + small_mols=dict(chain(off_small_mols["stateA"], off_small_mols["both"])), omm_forcefield=system_generator.forcefield, solvent_settings=solvation_settings, ) @@ -870,9 +909,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # d. get topology & positions # Note: roundtrip positions to remove vec3 issues stateA_topology = stateA_modeller.getTopology() - stateA_positions = to_openmm( - from_openmm(stateA_modeller.getPositions()) - ) + stateA_positions = to_openmm(from_openmm(stateA_modeller.getPositions())) # e. create the stateA System # Block out oechem backend in system_generator calls to avoid @@ -880,8 +917,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): with without_oechem_backend(): stateA_system = system_generator.create_system( stateA_modeller.topology, - molecules=[m for _, m in chain(off_small_mols['stateA'], - off_small_mols['both'])], + molecules=[m for _, m in chain(off_small_mols["stateA"], off_small_mols["both"])], ) # 2. Get stateB system @@ -889,7 +925,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): stateB_topology, stateB_alchem_resids = _rfe_utils.topologyhelpers.combined_topology( stateA_topology, # zeroth item (there's only one) then get the OFF representation - off_small_mols['stateB'][0][1].to_topology().to_openmm(), + off_small_mols["stateB"][0][1].to_topology().to_openmm(), exclude_resids=comp_resids[mapping.componentA], ) @@ -899,15 +935,18 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): with without_oechem_backend(): stateB_system = system_generator.create_system( stateB_topology, - molecules=[m for _, m in chain(off_small_mols['stateB'], - off_small_mols['both'])], + molecules=[m for _, m in chain(off_small_mols["stateB"], off_small_mols["both"])], ) # c. Define correspondence mappings between the two systems ligand_mappings = _rfe_utils.topologyhelpers.get_system_mappings( mapping.componentA_to_componentB, - stateA_system, stateA_topology, comp_resids[mapping.componentA], - stateB_system, stateB_topology, stateB_alchem_resids, + stateA_system, + stateA_topology, + comp_resids[mapping.componentA], + stateB_system, + stateB_topology, + stateB_alchem_resids, # These are non-optional settings for this method fix_constraints=True, ) @@ -916,35 +955,47 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # and transform them if alchem_settings.explicit_charge_correction: alchem_water_resids = _rfe_utils.topologyhelpers.get_alchemical_waters( - stateA_topology, stateA_positions, + stateA_topology, + stateA_positions, charge_difference, alchem_settings.explicit_charge_correction_cutoff, ) _rfe_utils.topologyhelpers.handle_alchemical_waters( - alchem_water_resids, stateB_topology, stateB_system, - ligand_mappings, charge_difference, + alchem_water_resids, + stateB_topology, + stateB_system, + ligand_mappings, + charge_difference, solvent_comp, ) # e. Finally get the positions stateB_positions = _rfe_utils.topologyhelpers.set_and_check_new_positions( - ligand_mappings, stateA_topology, stateB_topology, - old_positions=ensure_quantity(stateA_positions, 'openmm'), - insert_positions=ensure_quantity(off_small_mols['stateB'][0][1].conformers[0], 'openmm'), + ligand_mappings, + stateA_topology, + stateB_topology, + old_positions=ensure_quantity(stateA_positions, "openmm"), + insert_positions=ensure_quantity( + off_small_mols["stateB"][0][1].conformers[0], "openmm" + ), ) # 3. Create the hybrid topology # a. Get softcore potential settings - if alchem_settings.softcore_LJ.lower() == 'gapsys': + if alchem_settings.softcore_LJ.lower() == "gapsys": softcore_LJ_v2 = True - elif alchem_settings.softcore_LJ.lower() == 'beutler': + elif alchem_settings.softcore_LJ.lower() == "beutler": softcore_LJ_v2 = False # b. Get hybrid topology factory hybrid_factory = _rfe_utils.relative.HybridTopologyFactory( - stateA_system, stateA_positions, stateA_topology, - stateB_system, stateB_positions, stateB_topology, - old_to_new_atom_map=ligand_mappings['old_to_new_atom_map'], - old_to_new_core_atom_map=ligand_mappings['old_to_new_core_atom_map'], + stateA_system, + stateA_positions, + stateA_topology, + stateB_system, + stateB_positions, + stateB_topology, + old_to_new_atom_map=ligand_mappings["old_to_new_atom_map"], + old_to_new_core_atom_map=ligand_mappings["old_to_new_core_atom_map"], use_dispersion_correction=alchem_settings.use_dispersion_correction, softcore_alpha=alchem_settings.softcore_alpha, softcore_LJ_v2=softcore_LJ_v2, @@ -955,24 +1006,25 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # 4. Create lambda schedule # TODO - this should be exposed to users, maybe we should offer the # ability to print the schedule directly in settings? + # fmt: off lambdas = _rfe_utils.lambdaprotocol.LambdaProtocol( functions=lambda_settings.lambda_functions, windows=lambda_settings.lambda_windows ) - + # fmt: on # PR #125 temporarily pin lambda schedule spacing to n_replicas n_replicas = sampler_settings.n_replicas if n_replicas != len(lambdas.lambda_schedule): - errmsg = (f"Number of replicas {n_replicas} " - f"does not equal the number of lambda windows " - f"{len(lambdas.lambda_schedule)}") + errmsg = ( + f"Number of replicas {n_replicas} " + f"does not equal the number of lambda windows " + f"{len(lambdas.lambda_schedule)}" + ) raise ValueError(errmsg) # 9. Create the multistate reporter # Get the sub selection of the system to print coords for - selection_indices = hybrid_factory.hybrid_topology.select( - output_settings.output_indices - ) + selection_indices = hybrid_factory.hybrid_topology.select(output_settings.output_indices) # a. Create the multistate reporter # convert checkpoint_interval from time to iterations @@ -989,7 +1041,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): numerator=output_settings.positions_write_frequency, denominator=sampler_settings.time_per_iteration, numerator_name="output settings' position_write_frequency", - denominator_name="sampler settings' time_per_iteration" + denominator_name="sampler settings' time_per_iteration", ) else: pos_interval = 0 @@ -999,7 +1051,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): numerator=output_settings.velocities_write_frequency, denominator=sampler_settings.time_per_iteration, numerator_name="output settings' velocity_write_frequency", - denominator_name="sampler settings' time_per_iteration" + denominator_name="sampler settings' time_per_iteration", ) else: vel_interval = 0 @@ -1014,28 +1066,29 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): ) # b. Write out a PDB containing the subsampled hybrid state + # fmt: off bfactors = np.zeros_like(selection_indices, dtype=float) # solvent bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_old_atoms']))] = 0.25 # lig A bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['core_atoms']))] = 0.50 # core bfactors[np.in1d(selection_indices, list(hybrid_factory._atom_classes['unique_new_atoms']))] = 0.75 # lig B # bfactors[np.in1d(selection_indices, protein)] = 1.0 # prot+cofactor - if len(selection_indices) > 0: traj = mdtraj.Trajectory( - hybrid_factory.hybrid_positions[selection_indices, :], - hybrid_factory.hybrid_topology.subset(selection_indices), + hybrid_factory.hybrid_positions[selection_indices, :], + hybrid_factory.hybrid_topology.subset(selection_indices), ).save_pdb( shared_basepath / output_settings.output_structure, bfactors=bfactors, ) + # fmt: on # 10. Get compute platform # restrict to a single CPU if running vacuum - restrict_cpu = forcefield_settings.nonbonded_method.lower() == 'nocutoff' + restrict_cpu = forcefield_settings.nonbonded_method.lower() == "nocutoff" platform = omm_compute.get_openmm_platform( platform_name=protocol_settings.engine_settings.compute_platform, gpu_device_index=protocol_settings.engine_settings.gpu_device_index, - restrict_cpu_count=restrict_cpu + restrict_cpu_count=restrict_cpu, ) # 11. Set the integrator @@ -1044,8 +1097,10 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): # there are virtual sites in the system if hybrid_factory.has_virtual_sites: if not integrator_settings.reassign_velocities: - errmsg = ("Simulations with virtual sites without velocity " - "reassignments are unstable in openmmtools") + errmsg = ( + "Simulations with virtual sites without velocity " + "reassignments are unstable in openmmtools" + ) raise ValueError(errmsg) # b. create langevin integrator @@ -1064,9 +1119,11 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): simulation_settings=sampler_settings, ) # convert early_termination_target_error from kcal/mol to kT - early_termination_target_error = settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermo_settings.temperature, - sampler_settings.early_termination_target_error, + early_termination_target_error = ( + settings_validation.convert_target_error_from_kcal_per_mole_to_kT( + thermo_settings.temperature, + sampler_settings.early_termination_target_error, + ) ) if sampler_settings.sampler_method.lower() == "repex": @@ -1086,7 +1143,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): flatness_criteria=sampler_settings.sams_flatness_criteria, gamma0=sampler_settings.sams_gamma0, ) - elif sampler_settings.sampler_method.lower() == 'independent': + elif sampler_settings.sampler_method.lower() == "independent": sampler = _rfe_utils.multistate.HybridMultiStateSampler( mcmc_moves=integrator, hybrid_factory=hybrid_factory, @@ -1107,17 +1164,21 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): minimization_platform=platform.getName(), # Set minimization steps to None when running in dry mode # otherwise do a very small one to avoid NaNs - minimization_steps=100 if not dry else None + minimization_steps=100 if not dry else None, ) try: # Create context caches (energy + sampler) energy_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler_context_cache = openmmtools.cache.ContextCache( - capacity=None, time_to_live=None, platform=platform, + capacity=None, + time_to_live=None, + platform=platform, ) sampler.energy_context_cache = energy_context_cache @@ -1134,17 +1195,13 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): if verbose: self.logger.info("Running equilibration phase") - sampler.equilibrate( - int(equil_steps / steps_per_iteration) - ) + sampler.equilibrate(int(equil_steps / steps_per_iteration)) # production if verbose: self.logger.info("Running production phase") - sampler.extend( - int(prod_steps / steps_per_iteration) - ) + sampler.extend(int(prod_steps / steps_per_iteration)) self.logger.info("Production phase complete") @@ -1161,8 +1218,10 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): else: # clean up the reporter file - fns = [shared_basepath / output_settings.output_filename, - shared_basepath / output_settings.checkpoint_storage_filename] + fns = [ + shared_basepath / output_settings.output_filename, + shared_basepath / output_settings.checkpoint_storage_filename, + ] for fn in fns: os.remove(fn) finally: @@ -1178,8 +1237,7 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): for context in list(sampler_context_cache._lru._data.keys()): del sampler_context_cache._lru._data[context] # cautiously clear out the global context cache too - for context in list( - openmmtools.cache.global_context_cache._lru._data.keys()): + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] del sampler_context_cache, energy_context_cache @@ -1188,87 +1246,72 @@ class RelativeHybridTopologyProtocolUnit(gufe.ProtocolUnit): del integrator, sampler if not dry: # pragma: no-cover - return { - 'nc': nc, - 'last_checkpoint': chk, - **analyzer.unit_results_dict - } + return {"nc": nc, "last_checkpoint": chk, **analyzer.unit_results_dict} else: - return {'debug': {'sampler': sampler}} + return {"debug": {"sampler": sampler}} @staticmethod def structural_analysis(scratch, shared) -> dict: # don't put energy analysis in here, it uses the open file reporter # whereas structural stuff requires that the file handle is closed # TODO: we should just make openfe_analysis write an npz instead! - analysis_out = scratch / 'structural_analysis.json' + analysis_out = scratch / "structural_analysis.json" ret = subprocess.run( [ - 'openfe_analysis', # CLI entry point - 'RFE_analysis', # CLI option + "openfe_analysis", # CLI entry point + "RFE_analysis", # CLI option str(shared), # Where the simulation.nc fille - str(analysis_out) # Where the analysis json file is written + str(analysis_out), # Where the analysis json file is written ], stdout=subprocess.PIPE, - stderr=subprocess.PIPE + stderr=subprocess.PIPE, ) if ret.returncode: - return {'structural_analysis_error': ret.stderr} + return {"structural_analysis_error": ret.stderr} - with open(analysis_out, 'rb') as f: + with open(analysis_out, "rb") as f: data = json.load(f) savedir = pathlib.Path(shared) - if d := data['protein_2D_RMSD']: + if d := data["protein_2D_RMSD"]: fig = plotting.plot_2D_rmsd(d) fig.savefig(savedir / "protein_2D_RMSD.png") plt.close(fig) - f2 = plotting.plot_ligand_COM_drift(data['time(ps)'], data['ligand_wander']) + f2 = plotting.plot_ligand_COM_drift(data["time(ps)"], data["ligand_wander"]) f2.savefig(savedir / "ligand_COM_drift.png") plt.close(f2) - f3 = plotting.plot_ligand_RMSD(data['time(ps)'], data['ligand_RMSD']) + f3 = plotting.plot_ligand_RMSD(data["time(ps)"], data["ligand_RMSD"]) f3.savefig(savedir / "ligand_RMSD.png") plt.close(f3) # Save to numpy compressed format (~ 6x more space efficient than JSON) np.savez_compressed( shared / "structural_analysis.npz", - protein_RMSD=np.asarray( - data["protein_RMSD"], dtype=np.float32 - ), - ligand_RMSD=np.asarray( - data["ligand_RMSD"], dtype=np.float32 - ), - ligand_COM_drift=np.asarray( - data["ligand_wander"], dtype=np.float32 - ), - protein_2D_RMSD=np.asarray( - data["protein_2D_RMSD"], dtype=np.float32 - ), - time_ps=np.asarray( - data["time(ps)"], dtype=np.float32 - ), + protein_RMSD=np.asarray(data["protein_RMSD"], dtype=np.float32), + ligand_RMSD=np.asarray(data["ligand_RMSD"], dtype=np.float32), + ligand_COM_drift=np.asarray(data["ligand_wander"], dtype=np.float32), + protein_2D_RMSD=np.asarray(data["protein_2D_RMSD"], dtype=np.float32), + time_ps=np.asarray(data["time(ps)"], dtype=np.float32), ) - return {'structural_analysis': shared / "structural_analysis.npz"} + return {"structural_analysis": shared / "structural_analysis.npz"} def _execute( - self, ctx: gufe.Context, **kwargs, + self, + ctx: gufe.Context, + **kwargs, ) -> dict[str, Any]: log_system_probe(logging.INFO, paths=[ctx.scratch]) - outputs = self.run(scratch_basepath=ctx.scratch, - shared_basepath=ctx.shared) + outputs = self.run(scratch_basepath=ctx.scratch, shared_basepath=ctx.shared) - structural_analysis_outputs = self.structural_analysis( - ctx.scratch, ctx.shared - ) + structural_analysis_outputs = self.structural_analysis(ctx.scratch, ctx.shared) return { - 'repeat_id': self._inputs['repeat_id'], - 'generation': self._inputs['generation'], + "repeat_id": self._inputs["repeat_id"], + "generation": self._inputs["generation"], **outputs, **structural_analysis_outputs, } diff --git a/openfe/protocols/openmm_rfe/equil_rfe_settings.py b/openfe/protocols/openmm_rfe/equil_rfe_settings.py index 739b4e2f..aee01ae7 100644 --- a/openfe/protocols/openmm_rfe/equil_rfe_settings.py +++ b/openfe/protocols/openmm_rfe/equil_rfe_settings.py @@ -6,6 +6,7 @@ This module implements the necessary settings necessary to run relative free energies using :class:`openfe.protocols.openmm_rfe.equil_rfe_methods.py` """ + from __future__ import annotations from typing import Literal @@ -28,15 +29,16 @@ from openfe.protocols.openmm_utils.omm_settings import ( OpenFFPartialChargeSettings, ) + class LambdaSettings(SettingsBaseModel): - model_config = ConfigDict(extra='ignore', arbitrary_types_allowed=True) + model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True) """Lambda schedule settings. Settings controlling the lambda schedule, these include the switching function type, and the number of windows. """ - lambda_functions: str = 'default' + lambda_functions: str = "default" """ Key of which switching functions to use for alchemical mutation. Default 'default'. @@ -46,7 +48,7 @@ class LambdaSettings(SettingsBaseModel): class AlchemicalSettings(SettingsBaseModel): - model_config = ConfigDict(extra='ignore', arbitrary_types_allowed=True) + model_config = ConfigDict(extra="ignore", arbitrary_types_allowed=True) """Settings for the alchemical protocol @@ -65,7 +67,7 @@ class AlchemicalSettings(SettingsBaseModel): Whether to use dispersion correction in the hybrid topology state. Default False. """ - softcore_LJ: Literal['gapsys', 'beutler'] + softcore_LJ: Literal["gapsys", "beutler"] """ Whether to use the LJ softcore function as defined by Gapsys et al. JCTC 2012, or the one by Beutler et al. Chem. Phys. Lett. 1994. @@ -73,7 +75,7 @@ class AlchemicalSettings(SettingsBaseModel): """ softcore_alpha: float = 0.85 """Softcore alpha parameter. Default 0.85""" - turn_off_core_unique_exceptions:bool = False + turn_off_core_unique_exceptions: bool = False """ Whether to turn off interactions for new exceptions (not just 1,4s) at lambda 0 and old exceptions at lambda 1 between unique atoms and core @@ -107,7 +109,7 @@ class RelativeHybridTopologyProtocolSettings(Settings): difference, while the variance between repeats is used as the uncertainty. """ - @field_validator('protocol_repeats') + @field_validator("protocol_repeats") def must_be_positive(cls, v): if v <= 0: errmsg = f"protocol_repeats must be a positive value, got {v}." diff --git a/openfe/protocols/openmm_septop/base.py b/openfe/protocols/openmm_septop/base.py index 5fdec16f..5a4799d6 100644 --- a/openfe/protocols/openmm_septop/base.py +++ b/openfe/protocols/openmm_septop/base.py @@ -13,6 +13,7 @@ TODO * Add in all the AlchemicalFactory and AlchemicalRegion kwargs as settings. """ + import abc import logging import pathlib @@ -118,9 +119,7 @@ def _pre_equilibrate( """ # Prep the simulation object # Restrict CPU count if no cutoff - restrict_cpu = ( - settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" - ) + restrict_cpu = settings["forcefield_settings"].nonbonded_method.lower() == "nocutoff" platform = omm_compute.get_openmm_platform( platform_name=settings["engine_settings"].compute_platform, gpu_device_index=settings["engine_settings"].gpu_device_index, @@ -177,8 +176,7 @@ def _pre_equilibrate( if endstate == "A" or endstate == "B" or endstate == "AB": if unfrozen_outsettings.production_trajectory_filename: unfrozen_outsettings.production_trajectory_filename = ( - unfrozen_outsettings.production_trajectory_filename - + f"_state{endstate}.xtc" + unfrozen_outsettings.production_trajectory_filename + f"_state{endstate}.xtc" ) if unfrozen_outsettings.preminimized_structure: unfrozen_outsettings.preminimized_structure = ( @@ -318,13 +316,9 @@ class BaseSepTopSetupUnit(gufe.ProtocolUnit): split_alchemical_forces=True, ) # Alchemical Region for ligand A - alchemical_region_A = AlchemicalRegion( - alchemical_atoms=alchem_indices_A, name="A" - ) + alchemical_region_A = AlchemicalRegion(alchemical_atoms=alchem_indices_A, name="A") # Alchemical Region for ligand B - alchemical_region_B = AlchemicalRegion( - alchemical_atoms=alchem_indices_B, name="B" - ) + alchemical_region_B = AlchemicalRegion(alchemical_atoms=alchem_indices_B, name="B") alchemical_system = alchemical_factory.create_alchemical_system( system, [alchemical_region_A, alchemical_region_B] ) @@ -470,7 +464,7 @@ class BaseSepTopSetupUnit(gufe.ProtocolUnit): # there are virtual sites in the system if integrator_settings.reassign_velocities: return - + for ix in range(system.getNumParticles()): if system.isVirtualSite(ix): errmsg = ( @@ -553,9 +547,7 @@ class BaseSepTopSetupUnit(gufe.ProtocolUnit): # smiles roundtripping between rdkit and oechem with without_oechem_backend(): for mol in smc_components.values(): - system_generator.create_system( - mol.to_topology().to_openmm(), molecules=[mol] - ) + system_generator.create_system(mol.to_topology().to_openmm(), molecules=[mol]) # get OpenMM modeller + dictionary of resids for each component system_modeller, comp_resids = system_creation.get_omm_modeller( @@ -1068,16 +1060,12 @@ class BaseSepTopRunUnit(gufe.ProtocolUnit): sampler : multistate.MultistateSampler A sampler configured for the chosen sampling method. """ - rta_its, rta_min_its = ( - settings_validation.convert_real_time_analysis_iterations( - simulation_settings=simulation_settings, - ) + rta_its, rta_min_its = settings_validation.convert_real_time_analysis_iterations( + simulation_settings=simulation_settings, ) - et_target_err = ( - settings_validation.convert_target_error_from_kcal_per_mole_to_kT( - thermo_settings.temperature, - simulation_settings.early_termination_target_error, - ) + et_target_err = settings_validation.convert_target_error_from_kcal_per_mole_to_kT( + thermo_settings.temperature, + simulation_settings.early_termination_target_error, ) # Select the right sampler @@ -1164,9 +1152,7 @@ class BaseSepTopRunUnit(gufe.ProtocolUnit): # minimize if self.verbose: self.logger.info("minimizing systems") - sampler.minimize( - max_iterations=settings["simulation_settings"].minimization_steps - ) + sampler.minimize(max_iterations=settings["simulation_settings"].minimization_steps) # equilibrate if self.verbose: self.logger.info("equilibrating systems") @@ -1201,8 +1187,7 @@ class BaseSepTopRunUnit(gufe.ProtocolUnit): # clean up the reporter file fns = [ self.shared_basepath / settings["output_settings"].output_filename, - self.shared_basepath - / settings["output_settings"].checkpoint_storage_filename, + self.shared_basepath / settings["output_settings"].checkpoint_storage_filename, ] for fn in fns: fn.unlink() @@ -1335,9 +1320,7 @@ class BaseSepTopRunUnit(gufe.ProtocolUnit): for context in list(sampler_ctx_cache._lru._data.keys()): del sampler_ctx_cache._lru._data[context] # cautiously clear out the global context cache too - for context in list( - openmmtools.cache.global_context_cache._lru._data.keys() - ): + for context in list(openmmtools.cache.global_context_cache._lru._data.keys()): del openmmtools.cache.global_context_cache._lru._data[context] del sampler_ctx_cache, energy_ctx_cache diff --git a/openfe/protocols/openmm_septop/equil_septop_method.py b/openfe/protocols/openmm_septop/equil_septop_method.py index 02775273..501a4687 100644 --- a/openfe/protocols/openmm_septop/equil_septop_method.py +++ b/openfe/protocols/openmm_septop/equil_septop_method.py @@ -28,6 +28,7 @@ the Mobleylab (https://github.com/MobleyLab/SeparatedTopologies) as well as femto (https://github.com/Psivant/femto). """ + from __future__ import annotations import copy @@ -158,8 +159,7 @@ def _get_mdtraj_from_openmm( positions_in_mdtraj_format, mdtraj_topology, unitcell_lengths=np.array([lx, ly, lz]), - unitcell_angles=np.array( - [np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)]), + unitcell_angles=np.array([np.rad2deg(alpha), np.rad2deg(beta), np.rad2deg(gamma)]), ) return mdtraj_system @@ -250,7 +250,7 @@ class SepTopComplexMixin: * output_settings: MultiStateOutputSettings * restraint_settings: BoreschRestraintSettings """ - prot_settings = self._inputs["protocol"].settings # type: ignore + prot_settings = self._inputs["protocol"].settings # type: ignore settings = { "forcefield_settings": prot_settings.forcefield_settings, @@ -336,7 +336,7 @@ class SepTopSolventMixin: * output_settings: MultiStateOutputSettings * restraint_settings: BaseRestraintsSettings """ - prot_settings = self._inputs["protocol"].settings # type: ignore + prot_settings = self._inputs["protocol"].settings # type: ignore settings = { "forcefield_settings": prot_settings.forcefield_settings, @@ -472,9 +472,7 @@ class SepTopProtocolResult(gufe.ProtocolResult): We assume that both list of items are in the right order. """ combined_dG: list[tuple[Quantity, Quantity]] = [] - for comp, corrA, corrB in zip( - complex_dG, standard_state_corrA_dG, standard_state_corrB_dG - ): + for comp, corrA, corrB in zip(complex_dG, standard_state_corrA_dG, standard_state_corrB_dG): # No need to convert unit types, since pint takes care of that # except that mypy hates it because pint isn't typed properly... # No need to add errors since there's just the one @@ -629,14 +627,11 @@ class SepTopProtocolResult(gufe.ProtocolResult): given thermodynamic cycle leg. """ - forward_reverse: dict[ - str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]] - ] = {} + forward_reverse: dict[str, list[Optional[dict[str, Union[npt.NDArray, Quantity]]]]] = {} for key in ["complex", "solvent"]: forward_reverse[key] = [ - pus[0].outputs["forward_and_reverse_energies"] - for pus in self.data[key].values() + pus[0].outputs["forward_and_reverse_energies"] for pus in self.data[key].values() ] if None in forward_reverse[key]: @@ -710,14 +705,10 @@ class SepTopProtocolResult(gufe.ProtocolResult): try: for key in ["complex", "solvent"]: repex_stats[key] = [ - pus[0].outputs["replica_exchange_statistics"] - for pus in self.data[key].values() + pus[0].outputs["replica_exchange_statistics"] for pus in self.data[key].values() ] except KeyError: - errmsg = ( - "Replica exchange statistics were not found, " - "did you run a repex calculation?" - ) + errmsg = "Replica exchange statistics were not found, did you run a repex calculation?" raise ValueError(errmsg) return repex_stats @@ -784,8 +775,7 @@ class SepTopProtocolResult(gufe.ProtocolResult): for key in ["complex", "solvent"]: equilibration_lengths[key] = [ - pus[0].outputs["equilibration_iterations"] - for pus in self.data[key].values() + pus[0].outputs["equilibration_iterations"] for pus in self.data[key].values() ] return equilibration_lengths @@ -808,13 +798,14 @@ class SepTopProtocolResult(gufe.ProtocolResult): for key in ["complex", "solvent"]: production_lengths[key] = [ - pus[0].outputs["production_iterations"] - for pus in self.data[key].values() + pus[0].outputs["production_iterations"] for pus in self.data[key].values() ] return production_lengths - def restraint_geometries(self) -> tuple[list[BoreschRestraintGeometry], list[BoreschRestraintGeometry]]: + def restraint_geometries( + self, + ) -> tuple[list[BoreschRestraintGeometry], list[BoreschRestraintGeometry]]: """ Get a list of the restraint geometries for the complex simulations. These define the atoms that have @@ -830,15 +821,11 @@ class SepTopProtocolResult(gufe.ProtocolResult): in the system that are involved in the restraint of ligand B. """ geometry_A = [ - BoreschRestraintGeometry.model_validate( - pus[0].outputs["restraint_geometry_A"] - ) + BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_A"]) for pus in self.data["complex_setup"].values() ] geometry_B = [ - BoreschRestraintGeometry.model_validate( - pus[0].outputs["restraint_geometry_B"] - ) + BoreschRestraintGeometry.model_validate(pus[0].outputs["restraint_geometry_B"]) for pus in self.data["complex_setup"].values() ] @@ -862,9 +849,7 @@ class SepTopProtocolResult(gufe.ProtocolResult): for key in ["complex", "solvent"]: indices[key] = [] for pus in self.data[key].values(): - indices[key].append( - pus[0].outputs["selection_indices"] - ) + indices[key].append(pus[0].outputs["selection_indices"]) return indices @@ -1184,9 +1169,7 @@ class SepTopProtocol(gufe.Protocol): raise ValueError(errmsg) @staticmethod - def _validate_alchemical_components( - alchemical_components: dict[str, list[Component]] - ) -> None: + def _validate_alchemical_components(alchemical_components: dict[str, list[Component]]) -> None: """ Checks that the ChemicalSystem alchemical components are correct. @@ -1316,9 +1299,7 @@ class SepTopProtocol(gufe.Protocol): self, stateA: ChemicalSystem, stateB: ChemicalSystem, - mapping: Optional[ - Union[gufe.ComponentMapping, list[gufe.ComponentMapping]] - ] = None, + mapping: Optional[Union[gufe.ComponentMapping, list[gufe.ComponentMapping]]] = None, extends: Optional[gufe.ProtocolDAGResult] = None, ) -> list[gufe.ProtocolUnit]: # TODO: extensions @@ -1396,13 +1377,9 @@ class SepTopProtocol(gufe.Protocol): alchname_B = alchem_comps["stateB"][0].name solvent_setup = create_setup_units(SepTopSolventSetupUnit, "solvent") - solvent_run = create_run_units( - SepTopSolventRunUnit, "solvent", setup=solvent_setup - ) + solvent_run = create_run_units(SepTopSolventRunUnit, "solvent", setup=solvent_setup) complex_setup = create_setup_units(SepTopComplexSetupUnit, "complex") - complex_run = create_run_units( - SepTopComplexRunUnit, "complex", setup=complex_setup - ) + complex_run = create_run_units(SepTopComplexRunUnit, "complex", setup=complex_setup) return solvent_setup + solvent_run + complex_setup + complex_run @@ -1424,16 +1401,12 @@ class SepTopProtocol(gufe.Protocol): if "Run" in pu.name: unsorted_solvent_repeats_run[pu.outputs["repeat_id"]].append(pu) elif "Setup" in pu.name: - unsorted_solvent_repeats_setup[pu.outputs["repeat_id"]].append( - pu - ) + unsorted_solvent_repeats_setup[pu.outputs["repeat_id"]].append(pu) else: if "Run" in pu.name: unsorted_complex_repeats_run[pu.outputs["repeat_id"]].append(pu) elif "Setup" in pu.name: - unsorted_complex_repeats_setup[pu.outputs["repeat_id"]].append( - pu - ) + unsorted_complex_repeats_setup[pu.outputs["repeat_id"]].append(pu) repeats: dict[str, dict[str, list[gufe.ProtocolUnitResult]]] = { "solvent_setup": {}, @@ -1442,22 +1415,14 @@ class SepTopProtocol(gufe.Protocol): "complex": {}, } for k, v in unsorted_solvent_repeats_setup.items(): - repeats["solvent_setup"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["solvent_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_solvent_repeats_run.items(): - repeats["solvent"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["solvent"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_complex_repeats_setup.items(): - repeats["complex_setup"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["complex_setup"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) for k, v in unsorted_complex_repeats_run.items(): - repeats["complex"][str(k)] = sorted( - v, key=lambda x: x.outputs["generation"] - ) + repeats["complex"][str(k)] = sorted(v, key=lambda x: x.outputs["generation"]) return repeats @@ -1586,9 +1551,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): """ mdtraj_complex_A = _get_mdtraj_from_openmm(omm_topology_A, positions_A) mdtraj_complex_B = _get_mdtraj_from_openmm(omm_topology_B, positions_B) - alignment_indices = SepTopComplexSetupUnit._get_selection_atom_indices( - mdtraj_complex_A - ) + alignment_indices = SepTopComplexSetupUnit._get_selection_atom_indices(mdtraj_complex_A) imaged_complex_B = mdtraj_complex_B.image_molecules() imaged_complex_B.superpose( mdtraj_complex_A, @@ -1773,8 +1736,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): # In some cases (debugging / dry runs) this won't be available # so we'll default to using input positions. out_traj = ( - self.shared_basepath - / settings["equil_output_settings"].production_trajectory_filename + self.shared_basepath / settings["equil_output_settings"].production_trajectory_filename ) u_A = self._get_mda_universe( topology_A, @@ -1811,15 +1773,12 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): settings["restraint_settings"], ) # We have to update the indices for ligand B to match the AB complex - new_boresch_B_indices = [ - ligand_B_inxs_B.index(i) for i in rest_geom_B.guest_atoms - ] + new_boresch_B_indices = [ligand_B_inxs_B.index(i) for i in rest_geom_B.guest_atoms] rest_geom_B.guest_atoms = [ligand_B_inxs[i] for i in new_boresch_B_indices] if self.verbose: self.logger.info( - f"restraint geometry is: ligand A: {rest_geom_A}" - f"and ligand B: {rest_geom_B}." + f"restraint geometry is: ligand A: {rest_geom_A}and ligand B: {rest_geom_B}." ) # We need a temporary thermodynamic state to add the restraint @@ -1852,7 +1811,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): ) # Multiply the correction for ligand B by -1 as for this ligands, # Boresch restraint has to be turned on in the analytical corr. - correction_B = -correction_B # type: ignore[operator] + correction_B = -correction_B # type: ignore[operator] # Get the system # Note: you have to remove the thermostat, otherwise you end up @@ -1903,9 +1862,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): # 1. Get components self.logger.info("Creating and setting up the OpenMM systems") alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components() - smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps( - alchem_comps, smc_comps - ) + smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps) # 3. Get settings settings = self._handle_settings() @@ -1921,7 +1878,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): smc_comps_A, settings, ) - ) + ) # fmt: skip omm_system_B, omm_topology_B, positions_B, modeller_B, comp_resids_B = ( self.get_system( @@ -1930,7 +1887,7 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): smc_comps_B, settings, ) - ) + ) # fmt: skip smc_B_unique_keys = smc_comps_B.keys() - smc_comps_A.keys() smc_comp_B_unique = {key: smc_comps_B[key] for key in smc_B_unique_keys} @@ -1943,15 +1900,15 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): ) # Virtual sites sanity check - ensure we restart velocities when # there are virtual sites in the system - self.check_assign_velocities_with_virtual_site(omm_system_AB, settings["integrator_settings"]) + self.check_assign_velocities_with_virtual_site( + omm_system_AB, settings["integrator_settings"] + ) # Get the comp_resids of the AB system resids_A = list(itertools.chain(*comp_resids_A.values())) resids_AB = [r.index for r in modeller_AB.topology.residues()] diff_resids = list(set(resids_AB) - set(resids_A)) - comp_resids_AB = comp_resids_A | { - alchem_comps["stateB"][0]: np.array(diff_resids) - } + comp_resids_AB = comp_resids_A | {alchem_comps["stateB"][0]: np.array(diff_resids)} # 6. Pre-equilbrate System (for restraint selection) self.logger.info("Pre-equilibrating the systems") @@ -2002,9 +1959,9 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): # Update positions from AB system positions_AB[all_atom_ids_A[0] : all_atom_ids_A[-1] + 1, :] = equil_positions_A - positions_AB[atom_indices_AB_B[0] : atom_indices_AB_B[-1] + 1, :] = ( - updated_positions_B[atom_indices_B[0] : atom_indices_B[-1] + 1] - ) + positions_AB[atom_indices_AB_B[0] : atom_indices_AB_B[-1] + 1, :] = updated_positions_B[ + atom_indices_B[0] : atom_indices_B[-1] + 1 + ] # 9. Create the alchemical system self.logger.info("Creating the alchemical system and applying restraints") @@ -2016,21 +1973,19 @@ class SepTopComplexSetupUnit(SepTopComplexMixin, BaseSepTopSetupUnit): ) # 10. Apply Restraints - corr_A, corr_B, system, restraint_geom_A, restraint_geom_B = ( - self._add_restraints( - alchemical_system, - omm_topology_A, - omm_topology_B, - equil_positions_A, - equil_positions_B, - alchem_comps["stateA"][0], - alchem_comps["stateB"][0], - atom_indices_AB_A, - atom_indices_AB_B, - atom_indices_B, - comp_atomids_AB[prot_comp], - settings, - ) + corr_A, corr_B, system, restraint_geom_A, restraint_geom_B = self._add_restraints( + alchemical_system, + omm_topology_A, + omm_topology_B, + equil_positions_A, + equil_positions_B, + alchem_comps["stateA"][0], + alchem_comps["stateB"][0], + atom_indices_AB_A, + atom_indices_AB_B, + atom_indices_B, + comp_atomids_AB[prot_comp], + settings, ) equil_positions_AB, box_AB = _pre_equilibrate( @@ -2120,12 +2075,8 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit): pos_ligandA = rdmol_A.GetConformers()[0].GetPositions() pos_ligandB = rdmol_B.GetConformers()[0].GetPositions() - ligand_1_radius = np.linalg.norm( - pos_ligandA - pos_ligandA.mean(axis=0), axis=1 - ).max() - ligand_2_radius = np.linalg.norm( - pos_ligandB - pos_ligandB.mean(axis=0), axis=1 - ).max() + ligand_1_radius = np.linalg.norm(pos_ligandA - pos_ligandA.mean(axis=0), axis=1).max() + ligand_2_radius = np.linalg.norm(pos_ligandB - pos_ligandB.mean(axis=0), axis=1).max() ligand_distance = (ligand_1_radius + ligand_2_radius) * 1.5 ligand_offset = pos_ligandA.mean(0) - pos_ligandB.mean(0) @@ -2183,7 +2134,6 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit): """ if isinstance(settings["restraint_settings"], DistanceRestraintSettings): - rest_geom = geometry.harmonic.get_molecule_centers_restraint( molA_rdmol=ligand_1, molB_rdmol=ligand_2, @@ -2199,8 +2149,7 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit): self.logger.info(f"restraint geometry is: {rest_geom}") distance = np.linalg.norm( - positions_AB[rest_geom.guest_atoms[0]] - - positions_AB[rest_geom.host_atoms[0]] + positions_AB[rest_geom.guest_atoms[0]] - positions_AB[rest_geom.host_atoms[0]] ) k_distance = to_openmm(settings["restraint_settings"].spring_constant) @@ -2259,9 +2208,7 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit): # 1. Get components self.logger.info("Creating and setting up the OpenMM systems") alchem_comps, solv_comp, prot_comp, smc_comps = self._get_components() - smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps( - alchem_comps, smc_comps - ) + smc_comps_A, smc_comps_B, smc_comps_AB = self.get_smc_comps(alchem_comps, smc_comps) # 2. Get settings settings = self._handle_settings() @@ -2285,10 +2232,13 @@ class SepTopSolventSetupUnit(SepTopSolventMixin, BaseSepTopSetupUnit): smc_comps_A | smc_off_B, settings, ) - ) + ) # fmt: skip + # Virtual sites sanity check - ensure we restart velocities when # there are virtual sites in the system - self.check_assign_velocities_with_virtual_site(omm_system_AB, settings["integrator_settings"]) + self.check_assign_velocities_with_virtual_site( + omm_system_AB, settings["integrator_settings"] + ) # 6. Get atom indices for ligand A and ligand B and the solvent in the # system AB @@ -2364,7 +2314,6 @@ class SepTopSolventRunUnit(SepTopSolventMixin, BaseSepTopRunUnit): def _get_lambda_schedule( self, settings: dict[str, SettingsBaseModel] ) -> dict[str, list[float]]: - lambdas = dict() lambda_elec_A = settings["lambda_settings"].lambda_elec_A diff --git a/openfe/protocols/openmm_septop/equil_septop_settings.py b/openfe/protocols/openmm_septop/equil_septop_settings.py index 91c36efb..4f54aeca 100644 --- a/openfe/protocols/openmm_septop/equil_septop_settings.py +++ b/openfe/protocols/openmm_septop/equil_septop_settings.py @@ -10,6 +10,7 @@ See Also -------- openfe.protocols.openmm_septop.SepTopProtocol """ + from typing import Optional import numpy as np @@ -34,6 +35,7 @@ from gufe.settings.typing import PicosecondQuantity from openff.units import unit as offunit from pydantic import field_validator + class AlchemicalSettings(SettingsBaseModel): """Settings for the alchemical protocol @@ -228,8 +230,7 @@ class LambdaSettings(SettingsBaseModel): for window in v: if not 0 <= window <= 1: errmsg = ( - "Lambda windows must be between 0 and 1, got a" - f" window with value {window}." + f"Lambda windows must be between 0 and 1, got a window with value {window}." ) raise ValueError(errmsg) return v @@ -240,14 +241,15 @@ class LambdaSettings(SettingsBaseModel): "lambda_restraints_A", ) def must_be_monotonically_increasing_A(cls, v): - difference = np.diff(v) monotonic = np.all(difference >= 0) if not monotonic: - errmsg = ("The lambda schedule for ligand A is not monotonically" - f" increasing, got schedule {v}.") + errmsg = ( + "The lambda schedule for ligand A is not monotonically" + f" increasing, got schedule {v}." + ) raise ValueError(errmsg) return v @@ -258,14 +260,15 @@ class LambdaSettings(SettingsBaseModel): "lambda_restraints_B", ) def must_be_monotonically_decreasing_B(cls, v): - difference = np.diff(v) monotonic = np.all(difference <= 0) if not monotonic: - errmsg = ("The lambda schedule for ligand B is not monotonically" - f" decreasing, got schedule {v}.") + errmsg = ( + "The lambda schedule for ligand B is not monotonically" + f" decreasing, got schedule {v}." + ) raise ValueError(errmsg) return v @@ -325,8 +328,7 @@ class SepTopEquilOutputSettings(MDOutputSettings): @field_validator("output_indices") def must_be_all(cls, v): if v != "all": - errmsg = ("Equilibration simulations need to output the full " - f"system, got {v}.") + errmsg = f"Equilibration simulations need to output the full system, got {v}." raise ValueError(errmsg) return v @@ -444,4 +446,3 @@ class SepTopSettings(SettingsBaseModel): """ Settings for the Boresch restraints in the complex """ - diff --git a/openfe/protocols/openmm_septop/utils.py b/openfe/protocols/openmm_septop/utils.py index b7e95691..bd49af38 100644 --- a/openfe/protocols/openmm_septop/utils.py +++ b/openfe/protocols/openmm_septop/utils.py @@ -131,18 +131,14 @@ class SepTopParameterState(GlobalParameterState): value 1.""" def __init__(self, parameter_name): - super().__init__( - parameter_name, standard_value=1.0, validator=self.lambda_validator - ) + super().__init__(parameter_name, standard_value=1.0, validator=self.lambda_validator) @staticmethod def lambda_validator(self, instance, parameter_value): if parameter_value is None: return parameter_value if not (0.0 <= parameter_value <= 1.0): - raise ValueError( - "{} must be between 0 and 1.".format(self.parameter_name) - ) + raise ValueError("{} must be between 0 and 1.".format(self.parameter_name)) return float(parameter_value) # Lambda parameters for ligand A diff --git a/openfe/protocols/openmm_utils/charge_generation.py b/openfe/protocols/openmm_utils/charge_generation.py index 0f059758..86ecab61 100644 --- a/openfe/protocols/openmm_utils/charge_generation.py +++ b/openfe/protocols/openmm_utils/charge_generation.py @@ -3,6 +3,7 @@ """ Reusable utilities for assigning partial charges to ChemicalComponents. """ + import copy from typing import Union, Optional, Literal, Callable import sys @@ -15,7 +16,7 @@ from openff.toolkit.utils.base_wrapper import ToolkitWrapper from openff.toolkit.utils.toolkits import ( AmberToolsToolkitWrapper, OpenEyeToolkitWrapper, - RDKitToolkitWrapper + RDKitToolkitWrapper, ) from openff.toolkit.utils.toolkit_registry import ToolkitRegistry from threadpoolctl import threadpool_limits @@ -34,14 +35,15 @@ try: except ImportError: # toolkit_registry_manager was made non private in 0.14.4 from openff.toolkit.utils.toolkit_registry import ( - _toolkit_registry_manager as toolkit_registry_manager + _toolkit_registry_manager as toolkit_registry_manager, ) try: from openff.toolkit.utils.nagl_wrapper import NAGLToolkitWrapper from openff.nagl_models import ( - get_models_by_type, validate_nagl_model_path, + get_models_by_type, + validate_nagl_model_path, ) except ImportError: HAS_NAGL = False @@ -68,10 +70,7 @@ BACKEND_OPTIONS: dict[str, list[ToolkitWrapper]] = { } -def assign_offmol_espaloma_charges( - offmol: OFFMol, - toolkit_registry: ToolkitRegistry -) -> None: +def assign_offmol_espaloma_charges(offmol: OFFMol, toolkit_registry: ToolkitRegistry) -> None: """ Assign Espaloma charges using the OpenFF toolkit. @@ -86,12 +85,10 @@ def assign_offmol_espaloma_charges( assignment stage. """ if not HAS_ESPALOMA_CHARGE: - errmsg = ("The Espaloma ToolkiWrapper is not available, " - "please install espaloma_charge") + errmsg = "The Espaloma ToolkiWrapper is not available, please install espaloma_charge" raise ImportError(errmsg) - warnings.warn("Using espaloma to assign charges is not well tested", - category=RuntimeWarning) + warnings.warn("Using espaloma to assign charges is not well tested", category=RuntimeWarning) # make a copy to remove conformers as espaloma enforces # a 0 conformer check @@ -103,7 +100,7 @@ def assign_offmol_espaloma_charges( # https://github.com/openforcefield/openff-nagl/issues/69 with toolkit_registry_manager(toolkit_registry): offmol_copy.assign_partial_charges( - partial_charge_method='espaloma-am1bcc', + partial_charge_method="espaloma-am1bcc", toolkit_registry=EspalomaChargeToolkitWrapper(), ) @@ -133,20 +130,22 @@ def assign_offmol_nagl_charges( If ``None``, will fetch the latest production "am1bcc" model. """ if not HAS_NAGL: - errmsg = ("The NAGL toolkit is not available, you may " - "be using an older version of the OpenFF " - "toolkit - you need v0.14.4 or above") + errmsg = ( + "The NAGL toolkit is not available, you may " + "be using an older version of the OpenFF " + "toolkit - you need v0.14.4 or above" + ) raise ImportError(errmsg) if nagl_model is None: - prod_models = get_models_by_type( - model_type='am1bcc', production_only=True - ) + prod_models = get_models_by_type(model_type="am1bcc", production_only=True) try: nagl_model = prod_models[-1] except IndexError: - errmsg = ("No production am1bcc NAGL models were found, " - "please manually select a candidate release model.") + errmsg = ( + "No production am1bcc NAGL models were found, " + "please manually select a candidate release model." + ) raise ValueError(errmsg) model_path = validate_nagl_model_path(nagl_model) @@ -163,7 +162,7 @@ def assign_offmol_nagl_charges( def assign_offmol_am1bcc_charges( offmol: OFFMol, - partial_charge_method: Literal['am1bcc', 'am1bccelf10'], + partial_charge_method: Literal["am1bcc", "am1bccelf10"], toolkit_registry: ToolkitRegistry, ) -> None: """ @@ -200,7 +199,7 @@ def assign_offmol_am1bcc_charges( offmol.assign_partial_charges( partial_charge_method=partial_charge_method, use_conformers=offmol.conformers, - toolkit_registry=toolkit_registry + toolkit_registry=toolkit_registry, ) @@ -238,32 +237,36 @@ def _generate_offmol_conformers( # Check number of conformers if generate_n_conformers is None and return if generate_n_conformers is None: if offmol.n_conformers == 0: - errmsg = ("No conformers are associated with input OpenFF " - "Molecule. Need at least one for partial charge " - "assignment") + errmsg = ( + "No conformers are associated with input OpenFF " + "Molecule. Need at least one for partial charge " + "assignment" + ) raise ValueError(errmsg) if offmol.n_conformers > max_conf: - errmsg = ("OpenFF Molecule has too many conformers: " - f"{offmol.n_conformers}, selected partial charge " - f"method can only support a maximum of {max_conf} " - "conformers.") + errmsg = ( + "OpenFF Molecule has too many conformers: " + f"{offmol.n_conformers}, selected partial charge " + f"method can only support a maximum of {max_conf} " + "conformers." + ) raise ValueError(errmsg) return - # Check that generate_n_conformers < max_conf if generate_n_conformers > max_conf: - errmsg = (f"{generate_n_conformers} conformers were requested " - "for partial charge generation, but the selected " - "method only supports up to {max_conf} conformers.") + errmsg = ( + f"{generate_n_conformers} conformers were requested " + "for partial charge generation, but the selected " + "method only supports up to {max_conf} conformers." + ) raise ValueError(errmsg) # Generate conformers # OpenEye tk needs cis carboxylic acids make_carbox_cis = any( - [isinstance(i, OpenEyeToolkitWrapper) - for i in toolkit_registry.registered_toolkits] + [isinstance(i, OpenEyeToolkitWrapper) for i in toolkit_registry.registered_toolkits] ) # We are being overly cautious by both passing the @@ -282,8 +285,8 @@ def _generate_offmol_conformers( def assign_offmol_partial_charges( offmol: OFFMol, overwrite: bool, - method: Literal['am1bcc', 'am1bccelf10', 'nagl', 'espaloma'], - toolkit_backend: Literal['ambertools', 'openeye', 'rdkit'], + method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"], + toolkit_backend: Literal["ambertools", "openeye", "rdkit"], generate_n_conformers: Optional[int], nagl_model: Optional[str], ) -> OFFMol: @@ -331,7 +334,7 @@ def assign_offmol_partial_charges( """ # If you have non-zero charges and not overwriting, just return - if (offmol.partial_charges is not None and np.any(offmol.partial_charges)): + if offmol.partial_charges is not None and np.any(offmol.partial_charges): if not overwrite: return offmol @@ -351,28 +354,28 @@ def assign_offmol_partial_charges( "am1bcc": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_am1bcc_charges, - "backends": ['ambertools', 'openeye'], + "backends": ["ambertools", "openeye"], "max_conf": 1, - "charge_extra_kwargs": {'partial_charge_method': 'am1bcc'}, + "charge_extra_kwargs": {"partial_charge_method": "am1bcc"}, }, "am1bccelf10": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_am1bcc_charges, - "backends": ['openeye'], + "backends": ["openeye"], "max_conf": sys.maxsize, - "charge_extra_kwargs": {'partial_charge_method': 'am1bccelf10'}, + "charge_extra_kwargs": {"partial_charge_method": "am1bccelf10"}, }, "nagl": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_nagl_charges, - "backends": ['openeye', 'rdkit', 'ambertools'], + "backends": ["openeye", "rdkit", "ambertools"], "max_conf": 1, "charge_extra_kwargs": {"nagl_model": nagl_model}, }, "espaloma": { "confgen_func": _generate_offmol_conformers, "charge_func": assign_offmol_espaloma_charges, - "backends": ['rdkit', 'ambertools'], + "backends": ["rdkit", "ambertools"], "max_conf": 1, "charge_extra_kwargs": {}, }, @@ -380,35 +383,35 @@ def assign_offmol_partial_charges( # Grab the backends and also check our method try: - backends = CHARGE_METHODS[method.lower()]['backends'] + backends = CHARGE_METHODS[method.lower()]["backends"] except KeyError: errmsg = f"Unknown partial charge method {method}" raise ValueError(errmsg) # Check our method actually supports the toolkit backend selected if toolkit_backend.lower() not in backends: # type: ignore - errmsg = (f"Selected toolkit_backend ({toolkit_backend}) cannot " - f"be used with the selected method ({method}). " - f"Available backends are: {backends}") + errmsg = ( + f"Selected toolkit_backend ({toolkit_backend}) cannot " + f"be used with the selected method ({method}). " + f"Available backends are: {backends}" + ) raise ValueError(errmsg) # OpenEye is the only optional dependency in the toolkit backends - if toolkit_backend.lower() == 'openeye' and not HAS_OPENEYE: + if toolkit_backend.lower() == "openeye" and not HAS_OPENEYE: errmsg = "OpenEye is not available and cannot be selected as a backend" raise ImportError(errmsg) - toolkits = ToolkitRegistry( - [i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]] - ) + toolkits = ToolkitRegistry([i() for i in BACKEND_OPTIONS[toolkit_backend.lower()]]) # We make a copy of the molecule since we're going to modify conformers offmol_copy = copy.deepcopy(offmol) # Generate conformers - note this method may differ based on the partial # charge method employed - CHARGE_METHODS[method.lower()]['confgen_func']( + CHARGE_METHODS[method.lower()]["confgen_func"]( offmol=offmol_copy, - max_conf=CHARGE_METHODS[method.lower()]['max_conf'], + max_conf=CHARGE_METHODS[method.lower()]["max_conf"], toolkit_registry=toolkits, generate_n_conformers=generate_n_conformers, ) # type: ignore @@ -417,10 +420,10 @@ def assign_offmol_partial_charges( # with threadpool_limits(limits=1): # Call selected method to assign partial charges - CHARGE_METHODS[method.lower()]['charge_func']( + CHARGE_METHODS[method.lower()]["charge_func"]( offmol=offmol_copy, toolkit_registry=toolkits, - **CHARGE_METHODS[method.lower()]['charge_extra_kwargs'], + **CHARGE_METHODS[method.lower()]["charge_extra_kwargs"], ) # type: ignore # Copy partial charges back @@ -431,8 +434,8 @@ def assign_offmol_partial_charges( def bulk_assign_partial_charges( molecules: list[SmallMoleculeComponent], overwrite: bool, - method: Literal['am1bcc', 'am1bccelf10', 'nagl', 'espaloma'], - toolkit_backend: Literal['ambertools', 'openeye', 'rdkit'], + method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"], + toolkit_backend: Literal["ambertools", "openeye", "rdkit"], generate_n_conformers: Optional[int], nagl_model: Optional[str], processors: int = 1, @@ -488,7 +491,7 @@ def bulk_assign_partial_charges( "method": method, "toolkit_backend": toolkit_backend, "generate_n_conformers": generate_n_conformers, - "nagl_model": nagl_model + "nagl_model": nagl_model, } charged_ligands = [] @@ -496,22 +499,23 @@ def bulk_assign_partial_charges( from concurrent.futures import ProcessPoolExecutor, as_completed with ProcessPoolExecutor(max_workers=processors) as pool: - work_list = [ pool.submit( assign_offmol_partial_charges, m.to_openff(), - **charge_keywords, # type: ignore + **charge_keywords, # type: ignore ) for m in molecules ] - for work in tqdm.tqdm(as_completed(work_list), desc="Generating charges", ncols=80, total=len(molecules)): + for work in tqdm.tqdm( + as_completed(work_list), desc="Generating charges", ncols=80, total=len(molecules) + ): charged_ligands.append(SmallMoleculeComponent.from_openff(work.result())) else: for m in tqdm.tqdm(molecules, desc="Generating charges", ncols=80, total=len(molecules)): - mol_with_charge = assign_offmol_partial_charges(m.to_openff(), **charge_keywords) # type: ignore + mol_with_charge = assign_offmol_partial_charges(m.to_openff(), **charge_keywords) # type: ignore charged_ligands.append(SmallMoleculeComponent.from_openff(mol_with_charge)) - return charged_ligands \ No newline at end of file + return charged_ligands diff --git a/openfe/protocols/openmm_utils/multistate_analysis.py b/openfe/protocols/openmm_utils/multistate_analysis.py index b8f486bb..015b257a 100644 --- a/openfe/protocols/openmm_utils/multistate_analysis.py +++ b/openfe/protocols/openmm_utils/multistate_analysis.py @@ -3,6 +3,7 @@ """ Reusable utility methods to analyze results from multistate calculations. """ + from pathlib import Path import warnings import matplotlib.pyplot as plt @@ -18,34 +19,44 @@ from typing import Optional, Union from openfe.due import due, Doi -due.cite(Doi("10.5281/zenodo.596622"), - description="OpenMMTools", - path="openfe.protocols.openmm_utils.multistate_analysis", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596622"), + description="OpenMMTools", + path="openfe.protocols.openmm_utils.multistate_analysis", + cite_module=True, +) -due.cite(Doi("10.1063/1.2978177"), - description="MBAR paper", - path="openfe.protocols.openmm_utils.multistate_analysis", - cite_module=True) +due.cite( + Doi("10.1063/1.2978177"), + description="MBAR paper", + path="openfe.protocols.openmm_utils.multistate_analysis", + cite_module=True, +) -due.cite(Doi("10.1021/ct0502864"), - description="MBAR timeseries algorithms", - path="openfe.protocols.openmm_utils.multistate_analysis", - cite_module=True) +due.cite( + Doi("10.1021/ct0502864"), + description="MBAR timeseries algorithms", + path="openfe.protocols.openmm_utils.multistate_analysis", + cite_module=True, +) -due.cite(Doi("10.1021/acs.jctc.5b00784"), - description="Automatic equilibration detection method", - path="openfe.protocols.openmm_utils.multistate_analysis", - cite_module=True) +due.cite( + Doi("10.1021/acs.jctc.5b00784"), + description="Automatic equilibration detection method", + path="openfe.protocols.openmm_utils.multistate_analysis", + cite_module=True, +) -due.cite(Doi("10.5281/zenodo.596220"), - description="pyMBAR zenodo", - path="openfe.protocols.openmm_utils.multistate_analysis", - cite_module=True) +due.cite( + Doi("10.5281/zenodo.596220"), + description="pyMBAR zenodo", + path="openfe.protocols.openmm_utils.multistate_analysis", + cite_module=True, +) class MultistateEquilFEAnalysis: @@ -73,13 +84,18 @@ class MultistateEquilFEAnalysis: The number of samples to use in the forward and reverse analysis of the free energies. Default 10. """ - def __init__(self, reporter: multistate.MultiStateReporter, - sampling_method: str, result_units: Quantity, - forward_reverse_samples: int = 10): + + def __init__( + self, + reporter: multistate.MultiStateReporter, + sampling_method: str, + result_units: Quantity, + forward_reverse_samples: int = 10, + ): self.analyzer = multistate.MultiStateSamplerAnalyzer(reporter) self.units = result_units - if sampling_method.lower() not in ['repex', 'sams', 'independent']: + if sampling_method.lower() not in ["repex", "sams", "independent"]: wmsg = f"Unknown sampling method {sampling_method}" warnings.warn(wmsg) self.sampling_method = sampling_method.lower() @@ -105,42 +121,36 @@ class MultistateEquilFEAnalysis: A prefix for the written filenames. """ # MBAR overlap matrix - ax = plotting.plot_lambda_transition_matrix(self.free_energy_overlaps['matrix']) - ax.set_title('MBAR overlap matrix') + ax = plotting.plot_lambda_transition_matrix(self.free_energy_overlaps["matrix"]) + ax.set_title("MBAR overlap matrix") ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'mbar_overlap_matrix.png') + filepath / (filename_prefix + "mbar_overlap_matrix.png") ) plt.close(ax.figure) # type: ignore # Reverse and forward analysis if self.forward_and_reverse_free_energies is not None: - ax = plotting.plot_convergence( - self.forward_and_reverse_free_energies, self.units - ) - ax.set_title('Forward and Reverse free energy convergence') + ax = plotting.plot_convergence(self.forward_and_reverse_free_energies, self.units) + ax.set_title("Forward and Reverse free energy convergence") ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'forward_reverse_convergence.png') + filepath / (filename_prefix + "forward_reverse_convergence.png") ) plt.close(ax.figure) # type: ignore # Replica state timeseries plot - ax = plotting.plot_replica_timeseries( - self.replica_states, self.equilibration_iterations - ) - ax.set_title('Change in replica state over time') + ax = plotting.plot_replica_timeseries(self.replica_states, self.equilibration_iterations) + ax.set_title("Change in replica state over time") ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'replica_state_timeseries.png') + filepath / (filename_prefix + "replica_state_timeseries.png") ) plt.close(ax.figure) # type: ignore # Replica exchange transition matrix - if self.sampling_method == 'repex': - ax = plotting.plot_lambda_transition_matrix( - self.replica_exchange_statistics['matrix'] - ) - ax.set_title('Replica exchange transition matrix') + if self.sampling_method == "repex": + ax = plotting.plot_lambda_transition_matrix(self.replica_exchange_statistics["matrix"]) + ax.set_title("Replica exchange transition matrix") ax.figure.savefig( # type: ignore - filepath / (filename_prefix + 'replica_exchange_matrix.png') + filepath / (filename_prefix + "replica_exchange_matrix.png") ) plt.close(ax.figure) # type: ignore @@ -174,9 +184,7 @@ class MultistateEquilFEAnalysis: self._free_energy, self._free_energy_err = self.get_equil_free_energy() # forward and reverse analysis - self._forward_reverse = self.get_forward_and_reverse_analysis( - forward_reverse_samples - ) + self._forward_reverse = self.get_forward_and_reverse_analysis(forward_reverse_samples) # Gather overlap matrix self._overlap_matrix = self.get_overlap_matrix() @@ -184,7 +192,7 @@ class MultistateEquilFEAnalysis: # Gather exchange transition matrix # Note we only generate these for replica exchange calculations # TODO: consider if this would also work for SAMS - if self.sampling_method == 'repex': + if self.sampling_method == "repex": self._exchange_matrix = self.get_exchanges() @staticmethod @@ -234,26 +242,25 @@ class MultistateEquilFEAnalysis: N_l, solver_protocol="robust", n_bootstraps=bootstraps, - bootstrap_solver_protocol="robust" + bootstrap_solver_protocol="robust", ) if bootstraps > 0: - uncertainty_method='bootstrap' + uncertainty_method = "bootstrap" else: - uncertainty_method=None + uncertainty_method = None r = mbar.compute_free_energy_differences( compute_uncertainty=True, uncertainty_method=uncertainty_method, ) - DF_ij = r['Delta_f'] - dDF_ij = r['dDelta_f'] + DF_ij = r["Delta_f"] + dDF_ij = r["dDelta_f"] DG = DF_ij[0, -1] * analyzer.kT dDG = dDF_ij[0, -1] * analyzer.kT - return (from_openmm(DG).to(return_units), - from_openmm(dDG).to(return_units)) + return (from_openmm(DG).to(return_units), from_openmm(dDG).to(return_units)) def get_equil_free_energy(self) -> tuple[Quantity, Quantity]: """ @@ -270,13 +277,7 @@ class MultistateEquilFEAnalysis: u_ln_decorr = self.analyzer._unbiased_decorrelated_u_ln N_l_decorr = self.analyzer._unbiased_decorrelated_N_l - DG, dDG = self._get_free_energy( - self.analyzer, - u_ln_decorr, - N_l_decorr, - 1000, - self.units - ) + DG, dDG = self._get_free_energy(self.analyzer, u_ln_decorr, N_l_decorr, 1000, self.units) return DG, dDG @@ -317,14 +318,12 @@ class MultistateEquilFEAnalysis: # Check that the N_l is the same across all states if not np.all(N_l == N_l[0]): - errmsg = ("The number of samples is not equivalent across all " - f"states {N_l}") + errmsg = f"The number of samples is not equivalent across all states {N_l}" raise ValueError(errmsg) # Get the chunks of N_l going from 10% to ~ 100% # Note: you always lose out a few data points but it's fine - chunks = [max(int(N_l[0] / num_samples * i), 1) - for i in range(1, num_samples + 1)] + chunks = [max(int(N_l[0] / num_samples * i), 1) for i in range(1, num_samples + 1)] forward_DGs = [] forward_dDGs = [] @@ -363,11 +362,11 @@ class MultistateEquilFEAnalysis: return None forward_reverse = { - 'fractions': np.array(fractions), - 'forward_DGs': Quantity.from_list(forward_DGs), # type: ignore - 'forward_dDGs': Quantity.from_list(forward_dDGs), # type: ignore - 'reverse_DGs': Quantity.from_list(reverse_DGs), # type: ignore - 'reverse_dDGs': Quantity.from_list(reverse_dDGs) # type: ignore + "fractions": np.array(fractions), + "forward_DGs": Quantity.from_list(forward_DGs), # type: ignore + "forward_dDGs": Quantity.from_list(forward_dDGs), # type: ignore + "reverse_DGs": Quantity.from_list(reverse_DGs), # type: ignore + "reverse_dDGs": Quantity.from_list(reverse_dDGs), # type: ignore } return forward_reverse @@ -387,7 +386,6 @@ class MultistateEquilFEAnalysis: """ return self.analyzer.mbar.compute_overlap() - def get_exchanges(self) -> dict[str, npt.NDArray]: """ Gather both the transition matrix (and relevant eigenvalues) between @@ -404,8 +402,10 @@ class MultistateEquilFEAnalysis: """ # Get replica mixing statistics mixing_stats = self.analyzer.generate_mixing_statistics() - transition_matrix = {'eigenvalues': mixing_stats.eigenvalues, - 'matrix': mixing_stats.transition_matrix} + transition_matrix = { + "eigenvalues": mixing_stats.eigenvalues, + "matrix": mixing_stats.transition_matrix, + } return transition_matrix @property @@ -465,26 +465,28 @@ class MultistateEquilFEAnalysis: A dictionary containing the estimated replica exchange matrix and corresponding eigenvalues. """ - if hasattr(self, '_exchange_matrix'): + if hasattr(self, "_exchange_matrix"): return self._exchange_matrix else: - errmsg = ("Exchange matrix was not generated, this is likely " - f"{self.sampling_method} is not repex.") + errmsg = ( + "Exchange matrix was not generated, this is likely " + f"{self.sampling_method} is not repex." + ) raise ValueError(errmsg) @property def unit_results_dict(self): results_dict = { - 'unit_estimate': self.free_energy, - 'unit_estimate_error': self.free_energy_error, - 'unit_mbar_overlap': self.free_energy_overlaps, - 'forward_and_reverse_energies': self.forward_and_reverse_free_energies, - 'production_iterations': self.production_iterations, - 'equilibration_iterations': self.equilibration_iterations, + "unit_estimate": self.free_energy, + "unit_estimate_error": self.free_energy_error, + "unit_mbar_overlap": self.free_energy_overlaps, + "forward_and_reverse_energies": self.forward_and_reverse_free_energies, + "production_iterations": self.production_iterations, + "equilibration_iterations": self.equilibration_iterations, } - if hasattr(self, '_exchange_matrix'): - results_dict['replica_exchange_statistics'] = self.replica_exchange_statistics + if hasattr(self, "_exchange_matrix"): + results_dict["replica_exchange_statistics"] = self.replica_exchange_statistics return results_dict diff --git a/openfe/protocols/openmm_utils/omm_compute.py b/openfe/protocols/openmm_utils/omm_compute.py index af48a96a..244cc541 100644 --- a/openfe/protocols/openmm_utils/omm_compute.py +++ b/openfe/protocols/openmm_utils/omm_compute.py @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) def get_openmm_platform( platform_name: Optional[str] = None, gpu_device_index: Optional[list[int]] = None, - restrict_cpu_count: bool = False + restrict_cpu_count: bool = False, ): """ Return OpenMM's platform object based on given name. Setting to mixed @@ -45,39 +45,42 @@ def get_openmm_platform( # No platform is specified, so retrieve fastest platform that supports # 'mixed' precision from openmmtools.utils import get_fastest_platform - platform = get_fastest_platform(minimum_precision='mixed') + + platform = get_fastest_platform(minimum_precision="mixed") else: try: platform_name = { - 'cpu': 'CPU', - 'opencl': 'OpenCL', - 'cuda': 'CUDA', + "cpu": "CPU", + "opencl": "OpenCL", + "cuda": "CUDA", }[str(platform_name).lower()] except KeyError: pass from openmm import Platform + platform = Platform.getPlatformByName(platform_name) # Set precision and properties name = platform.getName() - if name in ['CUDA', 'OpenCL']: - platform.setPropertyDefaultValue('Precision', 'mixed') + if name in ["CUDA", "OpenCL"]: + platform.setPropertyDefaultValue("Precision", "mixed") if gpu_device_index is not None: - index_list = ','.join(str(i) for i in gpu_device_index) - platform.setPropertyDefaultValue('DeviceIndex', index_list) + index_list = ",".join(str(i) for i in gpu_device_index) + platform.setPropertyDefaultValue("DeviceIndex", index_list) - if name == 'CUDA': - platform.setPropertyDefaultValue( - 'DeterministicForces', 'true') + if name == "CUDA": + platform.setPropertyDefaultValue("DeterministicForces", "true") - if name != 'CUDA': - wmsg = (f"Non-CUDA platform selected: {name}, this may significantly " - "impact simulation performance") + if name != "CUDA": + wmsg = ( + f"Non-CUDA platform selected: {name}, this may significantly " + "impact simulation performance" + ) warnings.warn(wmsg) logging.warning(wmsg) - if name == 'CPU' and restrict_cpu_count: - threads = os.getenv("OPENMM_CPU_THREADS", '1') - platform.setPropertyDefaultValue('Threads', threads) + if name == "CPU" and restrict_cpu_count: + threads = os.getenv("OPENMM_CPU_THREADS", "1") + platform.setPropertyDefaultValue("Threads", threads) return platform diff --git a/openfe/protocols/openmm_utils/omm_settings.py b/openfe/protocols/openmm_utils/omm_settings.py index 653821eb..2e04e62c 100644 --- a/openfe/protocols/openmm_utils/omm_settings.py +++ b/openfe/protocols/openmm_utils/omm_settings.py @@ -33,14 +33,19 @@ from openff.interchange.components._packmol import _box_vectors_are_in_reduced_f from openff.units import unit FemtosecondQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("femtosecond")] -InversePicosecondQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("1/picosecond")] -TimestepQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("timestep")] +InversePicosecondQuantity: TypeAlias = Annotated[ + GufeQuantity, specify_quantity_units("1/picosecond") +] +TimestepQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("timestep")] + class BaseSolvationSettings(SettingsBaseModel): """ Base class for SolvationSettings objects. """ - model_config = ConfigDict(arbitrary_types_allowed = True) + + model_config = ConfigDict(arbitrary_types_allowed=True) + class OpenMMSolvationSettings(BaseSolvationSettings): """Settings for controlling how a system is solvated using OpenMM tooling. @@ -111,7 +116,8 @@ class OpenMMSolvationSettings(BaseSolvationSettings): :mod:`openmm.app.Modeller` Base class for SolvationSettings objects """ - solvent_model: Literal['tip3p', 'spce', 'tip4pew', 'tip5p'] = 'tip3p' + + solvent_model: Literal["tip3p", "spce", "tip4pew", "tip5p"] = "tip3p" """ Force field water model to use when solvating and defining the model properties (e.g. adding virtual site particles). @@ -127,7 +133,7 @@ class OpenMMSolvationSettings(BaseSolvationSettings): * Cannot be defined alongside ``number_of_solvent_molecules``, ``box_size``, or ``box_vectors``. """ - box_shape: Optional[Literal['cube', 'dodecahedron', 'octahedron']] = 'dodecahedron' + box_shape: Optional[Literal["cube", "dodecahedron", "octahedron"]] = "dodecahedron" """ The shape of the periodic box to create. @@ -169,7 +175,7 @@ class OpenMMSolvationSettings(BaseSolvationSettings): ``number_of_solvent_molecules``, or ``box_vectors``. """ - @field_validator('box_vectors') + @field_validator("box_vectors") def supported_vectors(cls, v): if v is not None: if not _box_vectors_are_in_reduced_form(v): @@ -177,22 +183,21 @@ class OpenMMSolvationSettings(BaseSolvationSettings): raise ValueError(errmsg) return v - @field_validator('solvent_padding') + @field_validator("solvent_padding") def is_positive_distance(cls, v): # these are time units, not simulation steps if v is None: return v if not v.is_compatible_with(unit.nanometer): - raise ValueError("solvent_padding must be in distance units " - "(i.e. nanometers)") + raise ValueError("solvent_padding must be in distance units (i.e. nanometers)") if v < 0: errmsg = "solvent_padding must be a positive value" raise ValueError(errmsg) return v - @field_validator('number_of_solvent_molecules') + @field_validator("number_of_solvent_molecules") def positive_solvent_number(cls, v): if v is None: return v @@ -203,14 +208,13 @@ class OpenMMSolvationSettings(BaseSolvationSettings): return v - @field_validator('box_size') + @field_validator("box_size") def box_size_properties(cls, v): if v is None: return v - if v.shape != (3, ): - errmsg = (f"box_size must be a 1-D array of length 3 " - f"got {v} with shape {v.shape}") + if v.shape != (3,): + errmsg = f"box_size must be a 1-D array of length 3 got {v} with shape {v.shape}" raise ValueError(errmsg) return v @@ -220,6 +224,7 @@ class BasePartialChargeSettings(SettingsBaseModel): """ Base class for partial charge assignment. """ + model_config = ConfigDict(arbitrary_types_allowed=True) @@ -227,7 +232,8 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): """ Settings for controlling partial charge assignment using the OpenFF tooling """ - partial_charge_method: Literal['am1bcc', 'am1bccelf10', 'nagl', 'espaloma'] = 'am1bcc' + + partial_charge_method: Literal["am1bcc", "am1bccelf10", "nagl", "espaloma"] = "am1bcc" """ Selection of method for partial charge generation. @@ -261,7 +267,7 @@ class OpenFFPartialChargeSettings(BasePartialChargeSettings): are supported. A maximum of one conformer is allowed. """ - off_toolkit_backend: Literal['ambertools', 'openeye', 'rdkit'] = 'ambertools' + off_toolkit_backend: Literal["ambertools", "openeye", "rdkit"] = "ambertools" """ The OpenFF toolkit registry backend to use for partial charge generation. @@ -314,7 +320,7 @@ class OpenMMEngineSettings(SettingsBaseModel): * In the future make precision and deterministic forces user defined too. """ - compute_platform: Optional[str] = 'cuda' + compute_platform: Optional[str] = "cuda" """ OpenMM compute platform to perform MD integration with. If ``None``, will choose fastest available platform. @@ -335,12 +341,11 @@ class OpenMMEngineSettings(SettingsBaseModel): Default ``None``. """ - @field_validator('compute_platform') + @field_validator("compute_platform") def supported_sampler(cls, v): - supported = ['cpu', 'opencl', 'cuda'] + supported = ["cpu", "opencl", "cuda"] if v is not None and v.lower() not in supported: - errmsg = ("Only the following OpenMM compute backends are " - f"supported: {supported}") + errmsg = f"Only the following OpenMM compute backends are supported: {supported}" raise ValueError(errmsg) return v @@ -384,35 +389,34 @@ class IntegratorSettings(SettingsBaseModel): Whether or not to remove the center of mass motion. Default ``False``. """ - @field_validator('langevin_collision_rate', 'n_restart_attempts') + @field_validator("langevin_collision_rate", "n_restart_attempts") def must_be_positive_or_zero(cls, v): if v < 0: - errmsg = ("langevin_collision_rate, and n_restart_attempts must be" - f" zero or positive values, got {v}.") + errmsg = ( + "langevin_collision_rate, and n_restart_attempts must be" + f" zero or positive values, got {v}." + ) raise ValueError(errmsg) return v - @field_validator('timestep', 'constraint_tolerance') + @field_validator("timestep", "constraint_tolerance") def must_be_positive(cls, v): if v <= 0: - errmsg = ("timestep, and constraint_tolerance " - f"must be positive values, got {v}.") + errmsg = f"timestep, and constraint_tolerance must be positive values, got {v}." raise ValueError(errmsg) return v - @field_validator('timestep') + @field_validator("timestep") def is_time(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.picosecond): - raise ValueError("timestep must be in time units " - "(i.e. picoseconds)") + raise ValueError("timestep must be in time units (i.e. picoseconds)") return v - @field_validator('langevin_collision_rate') + @field_validator("langevin_collision_rate") def must_be_inverse_time(cls, v): if not v.is_compatible_with(1 / unit.picosecond): - raise ValueError("langevin collision_rate must be in inverse time " - "(i.e. 1/picoseconds)") + raise ValueError("langevin collision_rate must be in inverse time (i.e. 1/picoseconds)") return v @@ -421,10 +425,11 @@ class OutputSettings(SettingsBaseModel): Settings for simulation output settings, writing to disk, etc... """ + model_config = ConfigDict(arbitrary_types_allowed=True) # reporter settings - output_indices: str = 'not water' + output_indices: str = "not water" """ Selection string for which part of the system to write coordinates for. Default 'not water'. @@ -433,18 +438,18 @@ class OutputSettings(SettingsBaseModel): """ Frequency to write the checkpoint file. Default 1 * unit.nanosecond. """ - checkpoint_storage_filename: str = 'checkpoint.chk' + checkpoint_storage_filename: str = "checkpoint.chk" """ Separate filename for the checkpoint file. Note, this should not be a full path, just a filename. Default 'checkpoint.chk'. """ - forcefield_cache: Optional[str] = 'db.json' + forcefield_cache: Optional[str] = "db.json" """ Filename for caching small molecule residue templates so they can be later reused. """ - @field_validator('checkpoint_interval') + @field_validator("checkpoint_interval") def must_be_positive(cls, v): if v <= 0: errmsg = f"Checkpoint intervals must be positive, got {v}." @@ -457,12 +462,13 @@ class MultiStateOutputSettings(OutputSettings): Settings for MultiState simulation output settings, writing to disk, etc... """ + model_config = ConfigDict(arbitrary_types_allowed=True) # reporter settings - output_filename: str = 'simulation.nc' + output_filename: str = "simulation.nc" """Path to the trajectory storage file. Default 'simulation.nc'.""" - output_structure: str = 'hybrid_system.pdb' + output_structure: str = "hybrid_system.pdb" """ Path of the output hybrid topology structure file. This is used to visualise and further manipulate the system. @@ -489,12 +495,13 @@ class MultiStateOutputSettings(OutputSettings): ``MultiStateSimulationSettings.time_per_iteration``. """ - - @field_validator('positions_write_frequency', 'velocities_write_frequency') + @field_validator("positions_write_frequency", "velocities_write_frequency") def must_be_positive(cls, v): if v is not None and v < 0: - errmsg = ("Position_write_frequency and velocities_write_frequency" - f" must be positive (or None), got {v}.") + errmsg = ( + "Position_write_frequency and velocities_write_frequency" + f" must be positive (or None), got {v}." + ) raise ValueError(errmsg) return v @@ -503,6 +510,7 @@ class SimulationSettings(SettingsBaseModel): """ Settings for simulation control, including lengths, etc... """ + model_config = ConfigDict(arbitrary_types_allowed=True) minimization_steps: int = 5000 @@ -518,19 +526,17 @@ class SimulationSettings(SettingsBaseModel): Must be divisible by the :class:`IntegratorSettings.timestep`. """ - @field_validator('equilibration_length', 'production_length') + @field_validator("equilibration_length", "production_length") def is_time(cls, v): # these are time units, not simulation steps if not v.is_compatible_with(unit.picosecond): raise ValueError("Durations must be in time units") return v - @field_validator('minimization_steps', 'equilibration_length', - 'production_length') + @field_validator("minimization_steps", "equilibration_length", "production_length") def must_be_positive(cls, v): if v <= 0: - errmsg = ("Minimization steps, and MD lengths must be positive, " - f"got {v}") + errmsg = f"Minimization steps, and MD lengths must be positive, got {v}" raise ValueError(errmsg) return v @@ -567,7 +573,7 @@ class MultiStateSimulationSettings(SimulationSettings): """ Simulation time between each MCMC move attempt. Default 2.5 * unit.picosecond. """ - real_time_analysis_interval: PicosecondQuantity | None= 250.0 * unit.picosecond + real_time_analysis_interval: PicosecondQuantity | None = 250.0 * unit.picosecond # todo: Add validators in the protocol """ Time interval at which to perform an analysis of the free energies. @@ -605,7 +611,7 @@ class MultiStateSimulationSettings(SimulationSettings): Default 500 * unit.picosecond. """ - sams_flatness_criteria: str = 'logZ-flatness' + sams_flatness_criteria: str = "logZ-flatness" """ SAMS only. Method for assessing when to switch to asymptomatically optimal scheme. @@ -617,41 +623,41 @@ class MultiStateSimulationSettings(SimulationSettings): n_replicas: int = 11 """Number of replicas to use. Default 11.""" - @field_validator('sams_flatness_criteria') + @field_validator("sams_flatness_criteria") def supported_flatness(cls, v): - supported = [ - 'logz-flatness', 'minimum-visits', 'histogram-flatness' - ] + supported = ["logz-flatness", "minimum-visits", "histogram-flatness"] if v.lower() not in supported: - errmsg = ("Only the following sams_flatness_criteria are " - f"supported: {supported}") + errmsg = f"Only the following sams_flatness_criteria are supported: {supported}" raise ValueError(errmsg) return v - @field_validator('sampler_method') + @field_validator("sampler_method") def supported_sampler(cls, v): - supported = ['repex', 'sams', 'independent'] + supported = ["repex", "sams", "independent"] if v.lower() not in supported: - errmsg = ("Only the following sampler_method values are " - f"supported: {supported}") + errmsg = f"Only the following sampler_method values are supported: {supported}" raise ValueError(errmsg) return v - @field_validator('n_replicas', 'time_per_iteration') + @field_validator("n_replicas", "time_per_iteration") def must_be_positive(cls, v): if v <= 0: - errmsg = "n_replicas and steps_per_iteration must be positive " \ - f"values, got {v}." + errmsg = f"n_replicas and steps_per_iteration must be positive values, got {v}." raise ValueError(errmsg) return v - @field_validator('early_termination_target_error', - 'real_time_analysis_minimum_time', 'sams_gamma0', - 'n_replicas') + @field_validator( + "early_termination_target_error", + "real_time_analysis_minimum_time", + "sams_gamma0", + "n_replicas", + ) def must_be_zero_or_positive(cls, v): if v < 0: - errmsg = ("Early termination target error, minimum iteration and" - f" SAMS gamma0 must be 0 or positive values, got {v}.") + errmsg = ( + "Early termination target error, minimum iteration and" + f" SAMS gamma0 must be 0 or positive values, got {v}." + ) raise ValueError(errmsg) return v @@ -660,6 +666,7 @@ class MDSimulationSettings(SimulationSettings): """ Settings for simulation control for plain MD simulations """ + model_config = ConfigDict(arbitrary_types_allowed=True) equilibration_length_nvt: NanosecondQuantity | None @@ -672,29 +679,30 @@ class MDSimulationSettings(SimulationSettings): class MDOutputSettings(OutputSettings): - """ Settings for simulation output settings for plain MD simulations.""" + """Settings for simulation output settings for plain MD simulations.""" + model_config = ConfigDict(arbitrary_types_allowed=True) # reporter settings - production_trajectory_filename: Optional[str] = 'simulation.xtc' + production_trajectory_filename: Optional[str] = "simulation.xtc" """Path to the storage file for analysis. Default 'simulation.xtc'.""" trajectory_write_interval: PicosecondQuantity = 20.0 * unit.picosecond """ Frequency to write the xtc file. Default 5000 * unit.timestep. """ - preminimized_structure: Optional[str] = 'system.pdb' + preminimized_structure: Optional[str] = "system.pdb" """Path to the pdb file of the full pre-minimized system. Default 'system.pdb'.""" - minimized_structure: Optional[str] = 'minimized.pdb' + minimized_structure: Optional[str] = "minimized.pdb" """Path to the pdb file of the system after minimization. Only the specified atom subset is saved. Default 'minimized.pdb'.""" - equil_nvt_structure: Optional[str] = 'equil_nvt.pdb' + equil_nvt_structure: Optional[str] = "equil_nvt.pdb" """Path to the pdb file of the system after NVT equilibration. Only the specified atom subset is saved. Default 'equil_nvt.pdb'.""" - equil_npt_structure: Optional[str] = 'equil_npt.pdb' + equil_npt_structure: Optional[str] = "equil_npt.pdb" """Path to the pdb file of the system after NPT equilibration. Only the specified atom subset is saved. Default 'equil_npt.pdb'.""" - log_output: Optional[str] = 'simulation.log' + log_output: Optional[str] = "simulation.log" """ Filename for writing the log of the MD simulation, including timesteps, energies, density, etc. diff --git a/openfe/protocols/openmm_utils/settings_validation.py b/openfe/protocols/openmm_utils/settings_validation.py index 39c1c37d..8e330d14 100644 --- a/openfe/protocols/openmm_utils/settings_validation.py +++ b/openfe/protocols/openmm_utils/settings_validation.py @@ -4,6 +4,7 @@ Reusable utility methods to validate input settings to OpenMM-based alchemical Protocols. """ + from openff.units import unit, Quantity from typing import Optional from .omm_settings import ( @@ -13,9 +14,7 @@ from .omm_settings import ( from openfe.protocols.openmm_utils.omm_settings import OpenMMSolvationSettings -def validate_openmm_solvation_settings( - settings: OpenMMSolvationSettings -) -> None: +def validate_openmm_solvation_settings(settings: OpenMMSolvationSettings) -> None: """ Checks that the OpenMMSolvation settings are correct. @@ -28,19 +27,25 @@ def validate_openmm_solvation_settings( or ``box_size``. """ unique_attributes = ( - settings.solvent_padding, settings.number_of_solvent_molecules, - settings.box_vectors, settings.box_size, + settings.solvent_padding, + settings.number_of_solvent_molecules, + settings.box_vectors, + settings.box_size, ) if len([x for x in unique_attributes if x is not None]) > 1: - errmsg = ("Only one of solvent_padding, number_of_solvent_molecules, " - "box_vectors, and box_size can be defined in the solvation " - "settings.") + errmsg = ( + "Only one of solvent_padding, number_of_solvent_molecules, " + "box_vectors, and box_size can be defined in the solvation " + "settings." + ) raise ValueError(errmsg) if settings.box_shape is not None: if settings.box_size is not None or settings.box_vectors is not None: - errmsg = ("box_shape cannot be defined alongside either box_size " - "or box_vectors in the solvation settings.") + errmsg = ( + "box_shape cannot be defined alongside either box_size " + "or box_vectors in the solvation settings." + ) raise ValueError(errmsg) @@ -69,8 +74,7 @@ def validate_timestep(hmass: float, timestep: Quantity): raise ValueError(errmsg) -def get_simsteps(sim_length: Quantity, - timestep: Quantity, mc_steps: int) -> int: +def get_simsteps(sim_length: Quantity, timestep: Quantity, mc_steps: int) -> int: """ Gets and validates the number of simulation steps. @@ -89,17 +93,19 @@ def get_simsteps(sim_length: Quantity, The number of simulation timesteps. """ - sim_time = round(sim_length.to('attosecond').m) # type: ignore - ts = round(timestep.to('attosecond').m) # type: ignore + sim_time = round(sim_length.to("attosecond").m) # type: ignore + ts = round(timestep.to("attosecond").m) # type: ignore sim_steps, mod = divmod(sim_time, ts) if mod != 0: raise ValueError("Simulation time not divisible by timestep") if (sim_steps % mc_steps) != 0: - errmsg = (f"Simulation time {sim_time/1000000} ps should contain a " - "number of steps divisible by the number of integrator " - f"timesteps between MC moves {mc_steps}") + errmsg = ( + f"Simulation time {sim_time / 1000000} ps should contain a " + "number of steps divisible by the number of integrator " + f"timesteps between MC moves {mc_steps}" + ) raise ValueError(errmsg) return sim_steps @@ -134,8 +140,9 @@ def divmod_time( return iterations, remainder -def divmod_time_and_check(numerator: Quantity, denominator: Quantity, - numerator_name: str, denominator_name: str) -> int: +def divmod_time_and_check( + numerator: Quantity, denominator: Quantity, numerator_name: str, denominator_name: str +) -> int: """Perform a division of time, failing if there is a remainder For example numerator 20.0 ps and denominator 4.0 fs gives 5000 @@ -161,9 +168,11 @@ def divmod_time_and_check(numerator: Quantity, denominator: Quantity, its, rem = divmod_time(numerator, denominator) if rem: - errmsg = (f"The {numerator_name} ({numerator}) " - "does not evenly divide by the " - f"{denominator_name} ({denominator})") + errmsg = ( + f"The {numerator_name} ({numerator}) " + "does not evenly divide by the " + f"{denominator_name} ({denominator})" + ) raise ValueError(errmsg) return its @@ -195,9 +204,11 @@ def convert_checkpoint_interval_to_iterations( iterations, rem = divmod_time(checkpoint_interval, time_per_iteration) if rem: - errmsg = (f"The amount of time per checkpoint {checkpoint_interval} " - "does not evenly divide by the amount of time per " - f"state MCMC move attempt {time_per_iteration}") + errmsg = ( + f"The amount of time per checkpoint {checkpoint_interval} " + "does not evenly divide by the amount of time per " + f"state MCMC move attempt {time_per_iteration}" + ) raise ValueError(errmsg) return iterations diff --git a/openfe/protocols/openmm_utils/system_creation.py b/openfe/protocols/openmm_utils/system_creation.py index f11ed85f..f80782f8 100644 --- a/openfe/protocols/openmm_utils/system_creation.py +++ b/openfe/protocols/openmm_utils/system_creation.py @@ -4,6 +4,7 @@ Reusable utility methods to create Systems for OpenMM-based alchemical Protocols. """ + import numpy as np import numpy.typing as npt from openmm import app, MonteCarloBarostat @@ -14,9 +15,7 @@ from openmmforcefields.generators import SystemGenerator from typing import Optional from pathlib import Path from gufe.settings import OpenMMSystemGeneratorFFSettings, ThermoSettings -from gufe import ( - Component, ProteinComponent, SolventComponent, SmallMoleculeComponent -) +from gufe import Component, ProteinComponent, SolventComponent, SmallMoleculeComponent from openfe.protocols.openmm_utils.omm_settings import ( IntegratorSettings, OpenMMSolvationSettings, @@ -63,28 +62,28 @@ def get_system_generator( """ # get the right constraint constraints = { - 'hbonds': app.HBonds, - 'none': None, - 'allbonds': app.AllBonds, - 'hangles': app.HAngles + "hbonds": app.HBonds, + "none": None, + "allbonds": app.AllBonds, + "hangles": app.HAngles, # vvv can be None so string it }[str(forcefield_settings.constraints).lower()] # create forcefield_kwargs entry forcefield_kwargs = { - 'constraints': constraints, - 'rigidWater': forcefield_settings.rigid_water, - 'removeCMMotion': integrator_settings.remove_com, - 'hydrogenMass': forcefield_settings.hydrogen_mass * omm_unit.amu, + "constraints": constraints, + "rigidWater": forcefield_settings.rigid_water, + "removeCMMotion": integrator_settings.remove_com, + "hydrogenMass": forcefield_settings.hydrogen_mass * omm_unit.amu, } # get the right nonbonded method nonbonded_method = { - 'pme': app.PME, - 'nocutoff': app.NoCutoff, - 'cutoffnonperiodic': app.CutoffNonPeriodic, - 'cutoffperiodic': app.CutoffPeriodic, - 'ewald': app.Ewald + "pme": app.PME, + "nocutoff": app.NoCutoff, + "cutoffnonperiodic": app.CutoffNonPeriodic, + "cutoffperiodic": app.CutoffPeriodic, + "ewald": app.Ewald, }[forcefield_settings.nonbonded_method.lower()] nonbonded_cutoff = to_openmm( @@ -93,15 +92,15 @@ def get_system_generator( # create the periodic_kwarg entry periodic_kwargs = { - 'nonbondedMethod': nonbonded_method, - 'nonbondedCutoff': nonbonded_cutoff, + "nonbondedMethod": nonbonded_method, + "nonbondedCutoff": nonbonded_cutoff, } # Currently the else is a dead branch, we will want to investigate the # possibility of using CutoffNonPeriodic at some point though (for RF) if nonbonded_method is not app.CutoffNonPeriodic: nonperiodic_kwargs = { - 'nonbondedMethod': app.NoCutoff, + "nonbondedMethod": app.NoCutoff, } else: # pragma: no-cover nonperiodic_kwargs = periodic_kwargs @@ -110,8 +109,8 @@ def get_system_generator( # TODO: move this to its own place where we can handle membranes if has_solvent: barostat = MonteCarloBarostat( - ensure_quantity(thermo_settings.pressure, 'openmm'), - ensure_quantity(thermo_settings.temperature, 'openmm'), + ensure_quantity(thermo_settings.pressure, "openmm"), + ensure_quantity(thermo_settings.temperature, "openmm"), integrator_settings.barostat_frequency.m, ) else: @@ -137,8 +136,8 @@ def get_omm_modeller( protein_comp: Optional[ProteinComponent], solvent_comp: Optional[SolventComponent], small_mols: dict[SmallMoleculeComponent, OFFMol], - omm_forcefield : app.ForceField, - solvent_settings : OpenMMSolvationSettings + omm_forcefield: app.ForceField, + solvent_settings: OpenMMSolvationSettings, ) -> ModellerReturn: """ Generate an OpenMM Modeller class based on a potential input ProteinComponent, @@ -167,19 +166,15 @@ def get_omm_modeller( """ component_resids = {} - def _add_small_mol(comp, - mol, - system_modeller: app.Modeller, - comp_resids: dict[Component, npt.NDArray]): + def _add_small_mol( + comp, mol, system_modeller: app.Modeller, comp_resids: dict[Component, npt.NDArray] + ): """ Helper method to add OFFMol to an existing Modeller object and update a dictionary tracking residue indices for each component. """ omm_top = mol.to_topology().to_openmm() - system_modeller.add( - omm_top, - ensure_quantity(mol.conformers[0], 'openmm') - ) + system_modeller.add(omm_top, ensure_quantity(mol.conformers[0], "openmm")) nres = omm_top.getNumResidues() resids = [res.index for res in system_modeller.topology.residues()] @@ -190,19 +185,18 @@ def get_omm_modeller( # If there's a protein in the system, we add it first to the Modeller if protein_comp is not None: - system_modeller.add(protein_comp.to_openmm_topology(), - protein_comp.to_openmm_positions()) + system_modeller.add(protein_comp.to_openmm_topology(), protein_comp.to_openmm_positions()) # add missing virtual particles (from crystal waters) system_modeller.addExtraParticles(omm_forcefield) component_resids[protein_comp] = np.array( - [r.index for r in system_modeller.topology.residues()] + [r.index for r in system_modeller.topology.residues()] ) # if we solvate temporarily rename water molecules to 'WAT' # see openmm issue #4103 if solvent_comp is not None: for r in system_modeller.topology.residues(): - if r.name == 'HOH': - r.name = 'WAT' + if r.name == "HOH": + r.name = "WAT" # Now loop through small mols for comp, mol in small_mols.items(): @@ -238,21 +232,14 @@ def get_omm_modeller( numAdded=solvent_settings.number_of_solvent_molecules, ) - all_resids = np.array( - [r.index for r in system_modeller.topology.residues()] - ) + all_resids = np.array([r.index for r in system_modeller.topology.residues()]) - existing_resids = np.concatenate( - [resids for resids in component_resids.values()] - ) + existing_resids = np.concatenate([resids for resids in component_resids.values()]) - component_resids[solvent_comp] = np.setdiff1d( - all_resids, existing_resids - ) + component_resids[solvent_comp] = np.setdiff1d(all_resids, existing_resids) # undo rename of pre-existing waters for r in system_modeller.topology.residues(): - if r.name == 'WAT': - r.name = 'HOH' + if r.name == "WAT": + r.name = "HOH" return system_modeller, component_resids - diff --git a/openfe/protocols/openmm_utils/system_validation.py b/openfe/protocols/openmm_utils/system_validation.py index be3a8bed..ea341e9e 100644 --- a/openfe/protocols/openmm_utils/system_validation.py +++ b/openfe/protocols/openmm_utils/system_validation.py @@ -4,11 +4,15 @@ Reusable utility methods to validate input systems to OpenMM-based alchemical Protocols. """ + from typing import Optional, Tuple from openff.toolkit import Molecule as OFFMol from gufe import ( - Component, ChemicalSystem, SolventComponent, ProteinComponent, - SmallMoleculeComponent + Component, + ChemicalSystem, + SolventComponent, + ProteinComponent, + SmallMoleculeComponent, ) @@ -37,9 +41,10 @@ def get_alchemical_components( ValueError If there are any duplicate components in states A or B. """ - matched_components: dict[Component, Component] = {} + matched_components: dict[Component, Component] = {} alchemical_components: dict[str, list[Component]] = { - 'stateA': [], 'stateB': [], + "stateA": [], + "stateB": [], } for keyA, valA in stateA.components.items(): @@ -50,19 +55,21 @@ def get_alchemical_components( else: # Could be that either we have a duplicate component # in stateA or in stateB - errmsg = (f"state A components {keyA}: {valA} matches " - "multiple components in stateA or stateB") + errmsg = ( + f"state A components {keyA}: {valA} matches " + "multiple components in stateA or stateB" + ) raise ValueError(errmsg) # populate stateA alchemical components for valA in stateA.components.values(): if valA not in matched_components.keys(): - alchemical_components['stateA'].append(valA) + alchemical_components["stateA"].append(valA) # populate stateB alchemical components for valB in stateB.components.values(): if valB not in matched_components.values(): - alchemical_components['stateB'].append(valB) + alchemical_components["stateB"].append(valB) return alchemical_components @@ -87,14 +94,13 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): `nocutoff`. * If the SolventComponent solvent is not water. """ - solv = [comp for comp in state.values() - if isinstance(comp, SolventComponent)] + solv = [comp for comp in state.values() if isinstance(comp, SolventComponent)] if len(solv) > 0 and nonbonded_method.lower() == "nocutoff": errmsg = "nocutoff cannot be used for solvent transformations" raise ValueError(errmsg) - if len(solv) == 0 and nonbonded_method.lower() == 'pme': + if len(solv) == 0 and nonbonded_method.lower() == "pme": errmsg = "PME cannot be used for vacuum transform" raise ValueError(errmsg) @@ -102,7 +108,7 @@ def validate_solvent(state: ChemicalSystem, nonbonded_method: str): errmsg = "Multiple SolventComponent found, only one is supported" raise ValueError(errmsg) - if len(solv) > 0 and solv[0].smiles != 'O': + if len(solv) > 0 and solv[0].smiles != "O": errmsg = "Non water solvent is not currently supported" raise ValueError(errmsg) @@ -122,8 +128,7 @@ def validate_protein(state: ChemicalSystem): ValueError If there are multiple ProteinComponent in the ChemicalSystem. """ - nprot = sum(1 for comp in state.values() - if isinstance(comp, ProteinComponent)) + nprot = sum(1 for comp in state.values() if isinstance(comp, ProteinComponent)) if nprot > 1: errmsg = "Multiple ProteinComponent found, only one is supported" @@ -131,7 +136,8 @@ def validate_protein(state: ChemicalSystem): ParseCompRet = Tuple[ - Optional[SolventComponent], Optional[ProteinComponent], + Optional[SolventComponent], + Optional[ProteinComponent], list[SmallMoleculeComponent], ] @@ -153,9 +159,9 @@ def get_components(state: ChemicalSystem) -> ParseCompRet: If it exists, the ProteinComponent for the state, otherwise None. small_mols : list[SmallMoleculeComponent] """ + def _get_single_comps(comp_list, comptype): - ret_comps = [comp for comp in comp_list - if isinstance(comp, comptype)] + ret_comps = [comp for comp in comp_list if isinstance(comp, comptype)] if ret_comps: return ret_comps[0] else: diff --git a/openfe/protocols/restraint_utils/geometry/base.py b/openfe/protocols/restraint_utils/geometry/base.py index 15e12fa5..8c291de6 100644 --- a/openfe/protocols/restraint_utils/geometry/base.py +++ b/openfe/protocols/restraint_utils/geometry/base.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + import abc from pydantic import BaseModel, ConfigDict, field_validator @@ -16,6 +17,7 @@ class BaseRestraintGeometry(BaseModel, abc.ABC): """ A base class for a restraint geometry. """ + model_config = ConfigDict(arbitrary_types_allowed=True) @@ -42,7 +44,7 @@ class HostGuestRestraintGeometry(BaseRestraintGeometry): @field_validator("guest_atoms", "host_atoms") def positive_idxs(cls, v): - if v is not None and any([i < 0 for i in v]): #TODO: when would None be valid here? + if v is not None and any([i < 0 for i in v]): # TODO: when would None be valid here? errmsg = "negative indices passed" raise ValueError(errmsg) return v diff --git a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py index f8dbc3e7..bc648ed0 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch/geometry.py +++ b/openfe/protocols/restraint_utils/geometry/boresch/geometry.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + from typing import Annotated, Literal, Optional, TypeAlias import MDAnalysis as mda @@ -23,7 +24,8 @@ from .host import ( find_host_atom_candidates, ) -RadiansQuantity:TypeAlias = Annotated[GufeQuantity, specify_quantity_units("radians")] +RadiansQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("radians")] + class BoreschRestraintGeometry(HostGuestRestraintGeometry): """ @@ -138,14 +140,12 @@ def find_boresch_restraint( guest_restraint_atoms_idxs: Optional[list[int]] = None, host_restraint_atoms_idxs: Optional[list[int]] = None, host_selection: str = "all", - anchor_finding_strategy: Literal['multi-residue', 'bonded'] = 'multi-residue', + anchor_finding_strategy: Literal["multi-residue", "bonded"] = "multi-residue", dssp_filter: bool = False, rmsf_cutoff: Quantity = 0.1 * unit.nanometer, host_min_distance: Quantity = 1 * unit.nanometer, host_max_distance: Quantity = 3 * unit.nanometer, - angle_force_constant: Quantity = ( - 83.68 * unit.kilojoule_per_mole / unit.radians**2 - ), + angle_force_constant: Quantity = (83.68 * unit.kilojoule_per_mole / unit.radians**2), temperature: Quantity = 298.15 * unit.kelvin, ) -> BoreschRestraintGeometry: """ @@ -268,7 +268,7 @@ def find_boresch_restraint( max_search_distance=host_max_distance, ) - if anchor_finding_strategy == 'multi-residue': + if anchor_finding_strategy == "multi-residue": host_anchor = find_host_anchor_multi( guest_atoms=universe.atoms[list(guest_anchor)], host_atom_pool=universe.atoms[list(host_pool)], @@ -278,7 +278,7 @@ def find_boresch_restraint( angle_force_constant=angle_force_constant, temperature=temperature, ) - elif anchor_finding_strategy == 'bonded': + elif anchor_finding_strategy == "bonded": host_anchor = find_host_anchor_bonded( guest_atoms=universe.atoms[list(guest_anchor)], host_atom_pool=universe.atoms[list(host_pool)], @@ -288,9 +288,7 @@ def find_boresch_restraint( ) else: # We're doing something we shouldn't be - errmsg = ( - f"Unknown anchor finding strategy: {anchor_finding_strategy}" - ) + errmsg = f"Unknown anchor finding strategy: {anchor_finding_strategy}" raise NotImplementedError(errmsg) # continue if it's empty, otherwise stop diff --git a/openfe/protocols/restraint_utils/geometry/boresch/guest.py b/openfe/protocols/restraint_utils/geometry/boresch/guest.py index 9ca0c462..69380f23 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch/guest.py +++ b/openfe/protocols/restraint_utils/geometry/boresch/guest.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + from typing import Iterable, Optional import MDAnalysis as mda @@ -96,9 +97,7 @@ def _bonded_angles_from_pool( # are from the central atom for at2 in atom_pool: if at2 in at1_neighbors: - at2_neighbors = [ - at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors() - ] + at2_neighbors = [at.GetIdx() for at in rdmol.GetAtomWithIdx(at2).GetNeighbors()] for at3 in atom_pool: if at3 != atom_idx and at3 in at2_neighbors: angles.append((atom_idx, at2, at3)) @@ -124,8 +123,8 @@ def _bonded_angles_from_pool( def _get_guest_atom_pool( rdmol: Chem.Mol, - rmsf, #: ArrayQuantity, TODO: new pydantic v2-compatible quantity needed here. - rmsf_cutoff: Quantity + rmsf, #: ArrayQuantity, TODO: new pydantic v2-compatible quantity needed here. + rmsf_cutoff: Quantity, ) -> tuple[Optional[set[int]], bool]: """ Filter atoms based on rmsf & rings, defaulting to heavy atoms if diff --git a/openfe/protocols/restraint_utils/geometry/boresch/host.py b/openfe/protocols/restraint_utils/geometry/boresch/host.py index 626eab40..2b80e534 100644 --- a/openfe/protocols/restraint_utils/geometry/boresch/host.py +++ b/openfe/protocols/restraint_utils/geometry/boresch/host.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + import warnings from typing import Optional @@ -162,7 +163,7 @@ def find_host_atom_candidates( if len(filtered_host_idxs) < 20: wmsg = ( "Restraint generation: protein chain filter found too few " - f"host atoms ({len(filtered_host_idxs)} found). Will attempt to use all host atoms in " + f"host atoms ({len(filtered_host_idxs)} found). Will attempt to use all host atoms in " f"selection: {host_selection}." ) warnings.warn(wmsg) @@ -209,6 +210,7 @@ class EvaluateBoreschAtoms(AnalysisBase): temperature : openff.units.Quanity The system temperature in units compatible with Kelvin. """ + def __init__( self, restraints: list[mda.AtomGroup], @@ -258,10 +260,10 @@ class EvaluateBoreschAtoms(AnalysisBase): # angles for i in range(1, 3): - self.results.angles[ridx, i-1, self._frame_index] = calc_angles( + self.results.angles[ridx, i - 1, self._frame_index] = calc_angles( restraint.atoms[i].position, - restraint.atoms[i+1].position, - restraint.atoms[i+2].position, + restraint.atoms[i + 1].position, + restraint.atoms[i + 2].position, box=restraint.dimensions, ) @@ -269,9 +271,9 @@ class EvaluateBoreschAtoms(AnalysisBase): for i in range(3): self.results.dihedrals[ridx, i, self._frame_index] = calc_dihedrals( restraint.atoms[i].position, - restraint.atoms[i+1].position, - restraint.atoms[i+2].position, - restraint.atoms[i+3].position, + restraint.atoms[i + 1].position, + restraint.atoms[i + 2].position, + restraint.atoms[i + 3].position, box=restraint.dimensions, ) @@ -437,7 +439,7 @@ class EvaluateHostAtoms1(AnalysisBase): dihed_bounds = all( check_dihedral_bounds(dihed * unit.radians) for dihed in self.results.dihedrals[i] - ) + ) # fmt: skip dihed_variance = check_angular_variance( self.results.dihedrals[i] * unit.radians, upper_bound=np.pi * unit.radians, @@ -528,7 +530,7 @@ class EvaluateHostAtoms2(EvaluateHostAtoms1): dihed_bounds = all( check_dihedral_bounds(dihed * unit.radians) for dihed in self.results.dihedrals[i] - ) + ) # fmt: skip dihed_variance = check_angular_variance( self.results.dihedrals[i] * unit.radians, upper_bound=np.pi * unit.radians, @@ -552,7 +554,7 @@ def _get_lowest_variance_restraint_hostanchor( proposed_restraints: list[mda.AtomGroup], angle_force_constant: Quantity, temperature: Quantity -) -> list[int] | None: +) -> list[int] | None: # fmt: skip """ Evaluate a list of proposed restraints and return the lowest variance valid restraint. @@ -586,19 +588,23 @@ def _get_lowest_variance_restraint_hostanchor( valid_indices = [] valid_variances = [] - for ridx in range(len(proposed_restraints)): + for ridx in range(len(proposed_restraints)): if restraints_eval.results.valid[ridx]: valid_indices.append(ridx) - dih_variance = sum([ - circvar(diheds, high=np.pi, low=-np.pi) - for diheds in restraints_eval.results.dihedrals[ridx] - ]) + dih_variance = sum( + [ + circvar(diheds, high=np.pi, low=-np.pi) + for diheds in restraints_eval.results.dihedrals[ridx] + ] + ) - ang_variance = sum([ - circvar(angles, high=np.pi, low=0) - for angles in restraints_eval.results.angles[ridx] - ]) + ang_variance = sum( + [ + circvar(angles, high=np.pi, low=0) + for angles in restraints_eval.results.angles[ridx] + ] + ) bond_variance = np.var(restraints_eval.results.bonds[ridx]) @@ -641,9 +647,9 @@ def find_host_anchor_bonded( Optional[list[int]] A list of indices for a selected combination of H0, H1, and H2. """ - if not hasattr(guest_atoms, 'angles'): + if not hasattr(guest_atoms, "angles"): warnings.warn("no angles found - will attempt to guess") - guest_atoms.universe.guess_TopologyAttrs(context='default', to_guess=['angles']) + guest_atoms.universe.guess_TopologyAttrs(context="default", to_guess=["angles"]) # Evaluate the host_atom_pool for suitability as H0 atoms h0_eval = EvaluateHostAtoms1( @@ -671,9 +677,7 @@ def find_host_anchor_bonded( else: continue - proposed_restraints.append( - host_atom_pool.universe.atoms[indices] + guest_atoms - ) + proposed_restraints.append(host_atom_pool.universe.atoms[indices] + guest_atoms) # If there are no proposed restraints, return with nothing if len(proposed_restraints) == 0: @@ -760,12 +764,8 @@ def find_host_anchor_multi( if any(h2_eval.results.valid): # Get the sum of the average distances (dsum_avgs) # for all the host_atom_pool atoms - distance1_avgs = np.array( - [d.mean() for d in h2_eval.results.distances1] - ) - distance2_avgs = np.array( - [d.mean() for d in h2_eval.results.distances2] - ) + distance1_avgs = np.array([d.mean() for d in h2_eval.results.distances1]) + distance2_avgs = np.array([d.mean() for d in h2_eval.results.distances2]) dsum_avgs = distance1_avgs + distance2_avgs # Now filter by validity as H2 atom diff --git a/openfe/protocols/restraint_utils/geometry/flatbottom.py b/openfe/protocols/restraint_utils/geometry/flatbottom.py index 6c4f6432..fc2f756d 100644 --- a/openfe/protocols/restraint_utils/geometry/flatbottom.py +++ b/openfe/protocols/restraint_utils/geometry/flatbottom.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + from typing import Optional import MDAnalysis as mda diff --git a/openfe/protocols/restraint_utils/geometry/harmonic.py b/openfe/protocols/restraint_utils/geometry/harmonic.py index d97b9ba9..0957e3e5 100644 --- a/openfe/protocols/restraint_utils/geometry/harmonic.py +++ b/openfe/protocols/restraint_utils/geometry/harmonic.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + from typing import Optional import MDAnalysis as mda diff --git a/openfe/protocols/restraint_utils/geometry/utils.py b/openfe/protocols/restraint_utils/geometry/utils.py index d6836565..cd9937d4 100644 --- a/openfe/protocols/restraint_utils/geometry/utils.py +++ b/openfe/protocols/restraint_utils/geometry/utils.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + import warnings from itertools import combinations, groupby from typing import Optional, Union @@ -15,6 +16,7 @@ import MDAnalysis as mda import networkx as nx import numpy as np import numpy.typing as npt + # from gufe.vendor.openff.models.types import ArrayQuantity TODO: write a custom quantity to replace this in pydantic v2 from MDAnalysis.analysis.base import AnalysisBase from MDAnalysis.analysis.dssp import DSSP @@ -61,9 +63,7 @@ def _get_mda_selection( """ if atom_list is None: if selection is None: - raise ValueError( - "one of either the atom lists or selections must be defined" - ) + raise ValueError("one of either the atom lists or selections must be defined") ag = universe.select_atoms(selection) else: @@ -241,9 +241,7 @@ def is_collinear( v1 = minimize_vectors(v1, box=dimensions) v2 = minimize_vectors(v2, box=dimensions) - normalized_inner_product = np.dot(v1, v2) / np.sqrt( - np.dot(v1, v1) * np.dot(v2, v2) - ) + normalized_inner_product = np.dot(v1, v2) / np.sqrt(np.dot(v1, v1) * np.dot(v2, v2)) result = result or (np.abs(normalized_inner_product) > threshold) return result @@ -519,7 +517,7 @@ class FindHostAtoms(AnalysisBase): # TODO: needs custom type https://github.com/OpenFreeEnergy/openfe/issues/1569 -def get_local_rmsf(atomgroup: mda.AtomGroup): # -> ArrayQuantity: +def get_local_rmsf(atomgroup: mda.AtomGroup): # -> ArrayQuantity: """ Get the RMSF of an AtomGroup when aligned upon itself. @@ -570,11 +568,7 @@ def _atomgroup_has_bonds(atomgroup: Union[mda.AtomGroup, mda.Universe]) -> bool: return False # Assume that any residue with more than one atom should have a bond - if not all( - len(r.atoms.bonds) > 0 - for r in atomgroup.residues - if len(r.atoms) > 1 - ): + if not all(len(r.atoms.bonds) > 0 for r in atomgroup.residues if len(r.atoms) > 1): return False return True diff --git a/openfe/protocols/restraint_utils/openmm/omm_forces.py b/openfe/protocols/restraint_utils/openmm/omm_forces.py index 2df4bc5e..dbd3d130 100644 --- a/openfe/protocols/restraint_utils/openmm/omm_forces.py +++ b/openfe/protocols/restraint_utils/openmm/omm_forces.py @@ -7,6 +7,7 @@ TODO ---- * Add relevant duecredit entries. """ + import numpy as np import openmm diff --git a/openfe/protocols/restraint_utils/openmm/omm_restraints.py b/openfe/protocols/restraint_utils/openmm/omm_restraints.py index a1167736..87bedff5 100644 --- a/openfe/protocols/restraint_utils/openmm/omm_restraints.py +++ b/openfe/protocols/restraint_utils/openmm/omm_restraints.py @@ -14,6 +14,7 @@ TODO * Add relevant duecredit entries. * Add Periodic Torsion Boresch class """ + import abc import numpy as np @@ -81,9 +82,7 @@ class RestraintParameterState(GlobalParameterState): @lambda_restraints.validator # type: ignore def lambda_restraints(self, instance, new_value): if new_value is not None and not (0.0 <= new_value <= 1.0): - errmsg = ( - "lambda_restraints must be between 0.0 and 1.0 " f"and got {new_value}" - ) + errmsg = f"lambda_restraints must be between 0.0 and 1.0 and got {new_value}" raise ValueError(errmsg) # Not crashing out on None to match upstream behaviour return new_value @@ -273,9 +272,7 @@ class BaseRadiallySymmetricRestraintForce(BaseHostGuestRestraints): # Note: this is a throw-away force, so we hard code the # controlling parameter name force = self._get_force(geometry, "lambda_restraints") - corr = force.compute_standard_state_correction( - thermodynamic_state, max_volume="system" - ) + corr = force.compute_standard_state_correction(thermodynamic_state, max_volume="system") dg = corr * thermodynamic_state.kT return from_openmm(dg).to("kilojoule_per_mole") @@ -370,9 +367,7 @@ class FlatBottomBondRestraint(SingleBondMixin, BaseRadiallySymmetricRestraintFor spring_constant = to_openmm(self.settings.spring_constant).value_in_unit_system( omm_unit.md_unit_system ) - well_radius = to_openmm(geometry.well_radius).value_in_unit_system( - omm_unit.md_unit_system - ) + well_radius = to_openmm(geometry.well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintBondForce( spring_constant=spring_constant, well_radius=well_radius, @@ -468,9 +463,7 @@ class CentroidFlatBottomRestraint(BaseRadiallySymmetricRestraintForce): ) # the geometry will take precedence over the settings well_radius = self.settings.well_radius or geometry.well_radius - well_radius = to_openmm(well_radius).value_in_unit_system( - omm_unit.md_unit_system - ) + well_radius = to_openmm(well_radius).value_in_unit_system(omm_unit.md_unit_system) return FlatBottomRestraintForce( spring_constant=spring_constant, well_radius=well_radius, @@ -617,9 +610,7 @@ class BoreschRestraint(BaseHostGuestRestraints): "phi_C0": geometry.phi_C0, } for key, val in parameter_dict.items(): - param_values.append( - to_openmm(val).value_in_unit_system(omm_unit.md_unit_system) - ) + param_values.append(to_openmm(val).value_in_unit_system(omm_unit.md_unit_system)) force.addPerBondParameter(key) force.addGlobalParameter(controlling_parameter_name, 1.0) @@ -632,7 +623,7 @@ class BoreschRestraint(BaseHostGuestRestraints): geometry.guest_atoms[2], ] force.addBond(atoms, param_values) - force.setName('Boresch-like') + force.setName("Boresch-like") return force def get_standard_state_correction( diff --git a/openfe/protocols/restraint_utils/settings.py b/openfe/protocols/restraint_utils/settings.py index 6b5271ac..b51a8140 100644 --- a/openfe/protocols/restraint_utils/settings.py +++ b/openfe/protocols/restraint_utils/settings.py @@ -8,6 +8,7 @@ TODO * Rename from host/guest to molA/molB? * Add all the restraint settings entries. """ + from typing import Annotated, Literal, Optional, TypeAlias from gufe.settings import SettingsBaseModel @@ -15,15 +16,21 @@ from gufe.settings.typing import NanometerQuantity, GufeQuantity, specify_quanti from openff.units import unit from pydantic import ConfigDict, field_validator -SpringConstantLinearQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("kilojoule_per_mole / nm ** 2")] -SpringConstantAngularQuantity: TypeAlias = Annotated[GufeQuantity, specify_quantity_units("kilojoule_per_mole / radians ** 2")] +SpringConstantLinearQuantity: TypeAlias = Annotated[ + GufeQuantity, specify_quantity_units("kilojoule_per_mole / nm ** 2") +] +SpringConstantAngularQuantity: TypeAlias = Annotated[ + GufeQuantity, specify_quantity_units("kilojoule_per_mole / radians ** 2") +] class BaseRestraintSettings(SettingsBaseModel): """ Base class for RestraintSettings objects. """ - model_config = ConfigDict( arbitrary_types_allowed=True) + + model_config = ConfigDict(arbitrary_types_allowed=True) + class DistanceRestraintSettings(BaseRestraintSettings): """ @@ -113,43 +120,31 @@ class BoreschRestraintSettings(BaseRestraintSettings): (2025; DOI 10.26434/chemrxiv-2025-q08ld-v2) """ - K_r: SpringConstantLinearQuantity = ( - 4184.0 * unit.kilojoule_per_mole / unit.nm**2 - ) + K_r: SpringConstantLinearQuantity = 4184.0 * unit.kilojoule_per_mole / unit.nm**2 """ The bond spring constant between H0 and G0. Default 10 kcal/mol/Ų """ - K_thetaA: SpringConstantAngularQuantity = ( - 334.72 * unit.kilojoule_per_mole / unit.radians**2 - ) + K_thetaA: SpringConstantAngularQuantity = 334.72 * unit.kilojoule_per_mole / unit.radians**2 """ The spring constant for the angle formed by H1-H0-G0. Default 80 kcal/mol/rad² """ - K_thetaB: SpringConstantAngularQuantity = ( - 334.72 * unit.kilojoule_per_mole / unit.radians**2 - ) + K_thetaB: SpringConstantAngularQuantity = 334.72 * unit.kilojoule_per_mole / unit.radians**2 """ The spring constant for the angle formed by H0-G0-G1. Default 80 kcal/mol/rad² """ - K_phiA: SpringConstantAngularQuantity = ( - 334.72 * unit.kilojoule_per_mole / unit.radians**2 - ) + K_phiA: SpringConstantAngularQuantity = 334.72 * unit.kilojoule_per_mole / unit.radians**2 """ The equilibrium force constant for the dihedral formed by H2-H1-H0-G0. Default 80 kcal/mol/rad² """ - K_phiB: SpringConstantAngularQuantity = ( - 334.72 * unit.kilojoule_per_mole / unit.radians**2 - ) + K_phiB: SpringConstantAngularQuantity = 334.72 * unit.kilojoule_per_mole / unit.radians**2 """ The equilibrium force constant for the dihedral formed by H1-H0-G0-G1. Default 80 kcal/mol/rad² """ - K_phiC: SpringConstantAngularQuantity = ( - 334.72 * unit.kilojoule_per_mole / unit.radians**2 - ) + K_phiC: SpringConstantAngularQuantity = 334.72 * unit.kilojoule_per_mole / unit.radians**2 """ The equilibrium force constant for the dihedral formed by H0-G0-G1-G2. Default 80 kcal/mol/rad² @@ -190,7 +185,7 @@ class BoreschRestraintSettings(BaseRestraintSettings): # The indices of the guest component atoms to restraint. # If defined, these will override any automatic selection. # """ - anchor_finding_strategy: Literal['multi-residue', 'bonded'] = 'bonded' + anchor_finding_strategy: Literal["multi-residue", "bonded"] = "bonded" """ The Boresch atom picking strategy to use. @@ -199,6 +194,7 @@ class BoreschRestraintSettings(BaseRestraintSettings): * `multi-residue`: pick host atoms which can span multiple residues. """ + # @field_validator("guest_atoms", "host_atoms") # def positive_idxs_list(cls, v): # if v is not None and any([i < 0 for i in v]): diff --git a/pyproject.toml b/pyproject.toml index 7b24b3ff..10ed9371 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", + "Programming Language :: Python :: 3.14", "Topic :: Scientific/Engineering :: Bio-Informatics", "Topic :: Scientific/Engineering :: Chemistry", ] @@ -46,6 +47,28 @@ openfe = [ '"./openfe/tests/data/lomap_basic/toluene.mol2"' ] [tool.setuptools_scm] fallback_version = "0.0.0" +[tool.ruff] +line-length = 100 + +format.exclude = [ "openfe/setup/*", "openfe/storage/*", "openfe/tests/*", "openfe/utils/*", "openfecli/*" ] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +lint.select = [ + # "F", # Pyflakes + # "I", # isort + # "W", # pycodestyle warnings + # "E", # pycodestyle errors + # "C901" # mccabe complexity TODO: add this back in + # "UP", # TODO: add this in +] +lint.ignore = [ + "E402", # module-level import not at top (conflicts w/ isort) + "E722", # bare excepts (TODO: we should fix these in a follow-up PR) + "E731", # lambda expressions (TODO: we should fix these) + "F401", # unused imports (TODO: we should fix these) + "UP03", # pyupgrade linting (TODO: we should fix these) +] +lint.isort.known-first-party = [ "openfe" ] + [tool.coverage.run] omit = [ "openfe/due.py",