fix: reference conformers (#235)

This commit is contained in:
Nathaniel Corley
2026-03-06 12:56:10 -08:00
committed by GitHub
parent 75986084a6
commit 4686f7300c
7 changed files with 92 additions and 3 deletions

View File

@@ -20,4 +20,6 @@ sharding_pattern: null
skip_existing: false
template_selection: null
ground_truth_conformer_selection: null
fallback_conformer_to_input_coords: true
cyclic_chains: []
add_missing_atoms: true

View File

@@ -6,7 +6,7 @@ defaults:
_target_: rf3.inference_engines.rf3.RF3InferenceEngine
ckpt_path: rf3_foundry_01_24_latest.ckpt
ckpt_path: rf3
# Transform arguments
n_recycles: 10

View File

@@ -1,9 +1,62 @@
import logging
import numpy as np
import torch
from atomworks.ml.transforms._checks import (
check_contains_keys,
)
from atomworks.ml.transforms.base import Transform
logger = logging.getLogger(__name__)
def patch_conformer_fallback_to_input_coords() -> None:
"""Monkey-patch sample_rdkit_conformer_for_atom_array to use input
coordinates instead of zeros when conformer generation fails.
This is applied once at pipeline-build time when
``fallback_conformer_to_input_coords=True``.
"""
import atomworks.ml.transforms.af3_reference_molecule as _af3_ref
import atomworks.ml.transforms.rdkit_utils as _rdkit_utils
if getattr(
_rdkit_utils.sample_rdkit_conformer_for_atom_array,
"_input_coord_fallback_patched",
False,
):
return # already patched
_orig = _rdkit_utils.sample_rdkit_conformer_for_atom_array
def _patched(atom_array, *args, **kwargs):
original_coord = atom_array.coord.copy()
result = _orig(atom_array, *args, **kwargs)
# _orig may return (AtomArray, mol) when return_mol=True
aa_result = result[0] if isinstance(result, tuple) else result
mol = result[1] if isinstance(result, tuple) else None
if np.all(aa_result.coord == 0) or np.any(np.isnan(aa_result.coord)):
logger.warning(
f"Conformer generation failed for {atom_array.res_name[0]}; "
"using input coordinates as fallback."
)
aa_result.coord = original_coord
# Also add the fallback coords as a conformer on the mol so that downstream
# steps (e.g. GetRDKitChiralCenters) don't attempt to re-generate conformers.
if mol is not None and mol.GetNumConformers() == 0:
from rdkit.Chem import Conformer as _Conformer
conf = _Conformer(mol.GetNumAtoms())
for i in range(min(mol.GetNumAtoms(), len(original_coord))):
conf.SetAtomPosition(i, original_coord[i].tolist())
mol.AddConformer(conf, assignId=True)
return result
_patched._input_coord_fallback_patched = True
_rdkit_utils.sample_rdkit_conformer_for_atom_array = _patched
# af3_reference_molecule imports the function directly, so patch that reference too
_af3_ref.sample_rdkit_conformer_for_atom_array = _patched
class CheckForNaNsInInputs(Transform):
"""

View File

@@ -99,7 +99,10 @@ from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
from omegaconf import DictConfig
from rf3.data.cyclic_transform import AddCyclicBonds
from rf3.data.extra_xforms import CheckForNaNsInInputs
from rf3.data.extra_xforms import (
CheckForNaNsInInputs,
patch_conformer_fallback_to_input_coords,
)
from rf3.data.pipeline_utils import (
annotate_post_crop_hash,
annotate_pre_crop_hash,
@@ -189,6 +192,7 @@ def build_af3_transform_pipeline(
add_cyclic_bonds: bool = True,
metrics_tags: list[str] | set[str] | None = None,
p_dropout_ref_conf: float = 0.0, # Unused
fallback_conformer_to_input_coords: bool = True,
):
"""Build the AF3 pipeline with specified parameters.
@@ -238,6 +242,9 @@ def build_af3_transform_pipeline(
crop_center_cutoff_distance > 0
), "Crop center cutoff distance must be greater than 0"
if fallback_conformer_to_input_coords:
patch_conformer_fallback_to_input_coords()
af3_sequence_encoding = AF3SequenceEncoding()
rf2aa_sequence_encoding = RF2AA_ATOM36_ENCODING

View File

@@ -50,6 +50,7 @@ def run_inference(cfg: DictConfig) -> None:
"ground_truth_conformer_selection", None
),
"cyclic_chains": cfg.get("cyclic_chains", []),
"add_missing_atoms": cfg.get("add_missing_atoms", True),
}
# Create init config with only __init__ params

View File

@@ -247,6 +247,8 @@ class RF3InferenceEngine(BaseInferenceEngine):
# Templating, MSAs, etc.
template_noise_scale: float = 1e-5,
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
# Conformer generation
fallback_conformer_to_input_coords: bool = True,
# Output control
compress_outputs: bool = False,
early_stopping_plddt_threshold: float | None = None,
@@ -264,6 +266,9 @@ class RF3InferenceEngine(BaseInferenceEngine):
num_steps: Number of diffusion steps. Defaults to ``50``.
template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
fallback_conformer_to_input_coords: If True, residues with unknown CCD codes that fail
conformer generation will use their input PDB coordinates (centered) instead of zeros.
Defaults to ``False``.
compress_outputs: Whether to gzip output files. Defaults to ``False``.
early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
metrics_cfg: Metrics configuration. Can be:
@@ -296,6 +301,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
"diffusion_batch_size": diffusion_batch_size,
"n_recycles": n_recycles,
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
"fallback_conformer_to_input_coords": fallback_conformer_to_input_coords,
"undesired_res_names": [],
"template_noise_scales": {
"atomized": template_noise_scale,
@@ -391,6 +397,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
template_selection: list[str] | str | None = None,
ground_truth_conformer_selection: list[str] | str | None = None,
cyclic_chains: list[str] = [],
add_missing_atoms: bool = True,
) -> dict[str, dict] | None:
"""Run inference on inputs.
@@ -408,6 +415,9 @@ class RF3InferenceEngine(BaseInferenceEngine):
template_selection: Template selection override. Defaults to ``None``.
ground_truth_conformer_selection: Conformer selection override. Defaults to ``None``.
cyclic_chains: List of chain IDs to cyclize. Defaults to ``[]``.
add_missing_atoms: Whether to add missing atoms from the CCD when parsing CIF/PDB
inputs. Has no effect when inputs are already InferenceInput or AtomArray objects.
Defaults to ``True``.
Returns:
If ``out_dir`` is None: Dict mapping example_id to list of RF3Output objects.
@@ -472,6 +482,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
sharding_pattern=sharding_pattern,
template_selection=template_selection,
ground_truth_conformer_selection=ground_truth_conformer_selection,
add_missing_atoms=add_missing_atoms,
)
else:
raise ValueError(f"Unsupported inputs type: {type(inputs)}")

View File

@@ -75,6 +75,7 @@ class InferenceInput:
example_id: str | None = None,
template_selection: list[str] | str | None = None,
ground_truth_conformer_selection: list[str] | str | None = None,
add_missing_atoms: bool = True,
) -> "InferenceInput":
"""Load from CIF/PDB file.
@@ -83,11 +84,19 @@ class InferenceInput:
example_id: Example ID. Defaults to filename stem.
template_selection: Template selection override.
ground_truth_conformer_selection: Conformer selection override.
add_missing_atoms: Whether to add missing atoms from the CCD to partially/fully
unresolved residues. Defaults to True (atomworks parser default). Set to False
to keep only the atoms present in the file.
Returns:
InferenceInput object.
"""
parsed = parse(path, hydrogen_policy="remove", keep_cif_block=True)
parsed = parse(
path,
hydrogen_policy="remove",
keep_cif_block=True,
add_missing_atoms=add_missing_atoms,
)
atom_array = (
parsed["assemblies"]["1"][0]
@@ -288,6 +297,7 @@ def _process_single_path(
sharding_pattern: str | None,
template_selection: list[str] | str | None,
ground_truth_conformer_selection: list[str] | str | None,
add_missing_atoms: bool = True,
) -> list[InferenceInput]:
"""Worker function to process a single input file path.
@@ -299,6 +309,7 @@ def _process_single_path(
sharding_pattern: Sharding pattern for output paths.
template_selection: Override for template selection.
ground_truth_conformer_selection: Override for conformer selection.
add_missing_atoms: Whether to add missing atoms from the CCD. Defaults to True.
Returns:
List of InferenceInput objects (may be empty if file is skipped).
@@ -345,6 +356,7 @@ def _process_single_path(
example_id=example_id,
template_selection=template_selection,
ground_truth_conformer_selection=ground_truth_conformer_selection,
add_missing_atoms=add_missing_atoms,
)
)
else:
@@ -362,6 +374,7 @@ def prepare_inference_inputs_from_paths(
sharding_pattern: str | None = None,
template_selection: list[str] | str | None = None,
ground_truth_conformer_selection: list[str] | str | None = None,
add_missing_atoms: bool = True,
) -> list[InferenceInput]:
"""Load InferenceInput objects from file paths.
@@ -374,6 +387,7 @@ def prepare_inference_inputs_from_paths(
sharding_pattern: Sharding pattern for output paths.
template_selection: Override for template selection (applied to all inputs).
ground_truth_conformer_selection: Override for conformer selection (applied to all inputs).
add_missing_atoms: Whether to add missing atoms from the CCD. Defaults to True.
Returns:
List of InferenceInput objects.
@@ -413,6 +427,7 @@ def prepare_inference_inputs_from_paths(
sharding_pattern,
template_selection,
ground_truth_conformer_selection,
add_missing_atoms,
)
for path in paths_to_raw_input_files
]