mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: reference conformers (#235)
This commit is contained in:
@@ -20,4 +20,6 @@ sharding_pattern: null
|
||||
skip_existing: false
|
||||
template_selection: null
|
||||
ground_truth_conformer_selection: null
|
||||
fallback_conformer_to_input_coords: true
|
||||
cyclic_chains: []
|
||||
add_missing_atoms: true
|
||||
|
||||
@@ -6,7 +6,7 @@ defaults:
|
||||
|
||||
_target_: rf3.inference_engines.rf3.RF3InferenceEngine
|
||||
|
||||
ckpt_path: rf3_foundry_01_24_latest.ckpt
|
||||
ckpt_path: rf3
|
||||
|
||||
# Transform arguments
|
||||
n_recycles: 10
|
||||
|
||||
@@ -1,9 +1,62 @@
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.ml.transforms._checks import (
|
||||
check_contains_keys,
|
||||
)
|
||||
from atomworks.ml.transforms.base import Transform
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def patch_conformer_fallback_to_input_coords() -> None:
|
||||
"""Monkey-patch sample_rdkit_conformer_for_atom_array to use input
|
||||
coordinates instead of zeros when conformer generation fails.
|
||||
|
||||
This is applied once at pipeline-build time when
|
||||
``fallback_conformer_to_input_coords=True``.
|
||||
"""
|
||||
import atomworks.ml.transforms.af3_reference_molecule as _af3_ref
|
||||
import atomworks.ml.transforms.rdkit_utils as _rdkit_utils
|
||||
|
||||
if getattr(
|
||||
_rdkit_utils.sample_rdkit_conformer_for_atom_array,
|
||||
"_input_coord_fallback_patched",
|
||||
False,
|
||||
):
|
||||
return # already patched
|
||||
|
||||
_orig = _rdkit_utils.sample_rdkit_conformer_for_atom_array
|
||||
|
||||
def _patched(atom_array, *args, **kwargs):
|
||||
original_coord = atom_array.coord.copy()
|
||||
result = _orig(atom_array, *args, **kwargs)
|
||||
# _orig may return (AtomArray, mol) when return_mol=True
|
||||
aa_result = result[0] if isinstance(result, tuple) else result
|
||||
mol = result[1] if isinstance(result, tuple) else None
|
||||
if np.all(aa_result.coord == 0) or np.any(np.isnan(aa_result.coord)):
|
||||
logger.warning(
|
||||
f"Conformer generation failed for {atom_array.res_name[0]}; "
|
||||
"using input coordinates as fallback."
|
||||
)
|
||||
aa_result.coord = original_coord
|
||||
# Also add the fallback coords as a conformer on the mol so that downstream
|
||||
# steps (e.g. GetRDKitChiralCenters) don't attempt to re-generate conformers.
|
||||
if mol is not None and mol.GetNumConformers() == 0:
|
||||
from rdkit.Chem import Conformer as _Conformer
|
||||
|
||||
conf = _Conformer(mol.GetNumAtoms())
|
||||
for i in range(min(mol.GetNumAtoms(), len(original_coord))):
|
||||
conf.SetAtomPosition(i, original_coord[i].tolist())
|
||||
mol.AddConformer(conf, assignId=True)
|
||||
return result
|
||||
|
||||
_patched._input_coord_fallback_patched = True
|
||||
_rdkit_utils.sample_rdkit_conformer_for_atom_array = _patched
|
||||
# af3_reference_molecule imports the function directly, so patch that reference too
|
||||
_af3_ref.sample_rdkit_conformer_for_atom_array = _patched
|
||||
|
||||
|
||||
class CheckForNaNsInInputs(Transform):
|
||||
"""
|
||||
|
||||
@@ -99,7 +99,10 @@ from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
|
||||
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
|
||||
from omegaconf import DictConfig
|
||||
from rf3.data.cyclic_transform import AddCyclicBonds
|
||||
from rf3.data.extra_xforms import CheckForNaNsInInputs
|
||||
from rf3.data.extra_xforms import (
|
||||
CheckForNaNsInInputs,
|
||||
patch_conformer_fallback_to_input_coords,
|
||||
)
|
||||
from rf3.data.pipeline_utils import (
|
||||
annotate_post_crop_hash,
|
||||
annotate_pre_crop_hash,
|
||||
@@ -189,6 +192,7 @@ def build_af3_transform_pipeline(
|
||||
add_cyclic_bonds: bool = True,
|
||||
metrics_tags: list[str] | set[str] | None = None,
|
||||
p_dropout_ref_conf: float = 0.0, # Unused
|
||||
fallback_conformer_to_input_coords: bool = True,
|
||||
):
|
||||
"""Build the AF3 pipeline with specified parameters.
|
||||
|
||||
@@ -238,6 +242,9 @@ def build_af3_transform_pipeline(
|
||||
crop_center_cutoff_distance > 0
|
||||
), "Crop center cutoff distance must be greater than 0"
|
||||
|
||||
if fallback_conformer_to_input_coords:
|
||||
patch_conformer_fallback_to_input_coords()
|
||||
|
||||
af3_sequence_encoding = AF3SequenceEncoding()
|
||||
rf2aa_sequence_encoding = RF2AA_ATOM36_ENCODING
|
||||
|
||||
|
||||
@@ -50,6 +50,7 @@ def run_inference(cfg: DictConfig) -> None:
|
||||
"ground_truth_conformer_selection", None
|
||||
),
|
||||
"cyclic_chains": cfg.get("cyclic_chains", []),
|
||||
"add_missing_atoms": cfg.get("add_missing_atoms", True),
|
||||
}
|
||||
|
||||
# Create init config with only __init__ params
|
||||
|
||||
@@ -247,6 +247,8 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
# Templating, MSAs, etc.
|
||||
template_noise_scale: float = 1e-5,
|
||||
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
|
||||
# Conformer generation
|
||||
fallback_conformer_to_input_coords: bool = True,
|
||||
# Output control
|
||||
compress_outputs: bool = False,
|
||||
early_stopping_plddt_threshold: float | None = None,
|
||||
@@ -264,6 +266,9 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
num_steps: Number of diffusion steps. Defaults to ``50``.
|
||||
template_noise_scale: Noise scale for template coordinates. Defaults to ``1e-5``.
|
||||
raise_if_missing_msa_for_protein_of_length_n: Debug flag for MSA checking. Defaults to ``None``.
|
||||
fallback_conformer_to_input_coords: If True, residues with unknown CCD codes that fail
|
||||
conformer generation will use their input PDB coordinates (centered) instead of zeros.
|
||||
Defaults to ``False``.
|
||||
compress_outputs: Whether to gzip output files. Defaults to ``False``.
|
||||
early_stopping_plddt_threshold: Stop early if pLDDT below threshold. Defaults to ``None``.
|
||||
metrics_cfg: Metrics configuration. Can be:
|
||||
@@ -296,6 +301,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
"diffusion_batch_size": diffusion_batch_size,
|
||||
"n_recycles": n_recycles,
|
||||
"raise_if_missing_msa_for_protein_of_length_n": raise_if_missing_msa_for_protein_of_length_n,
|
||||
"fallback_conformer_to_input_coords": fallback_conformer_to_input_coords,
|
||||
"undesired_res_names": [],
|
||||
"template_noise_scales": {
|
||||
"atomized": template_noise_scale,
|
||||
@@ -391,6 +397,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
template_selection: list[str] | str | None = None,
|
||||
ground_truth_conformer_selection: list[str] | str | None = None,
|
||||
cyclic_chains: list[str] = [],
|
||||
add_missing_atoms: bool = True,
|
||||
) -> dict[str, dict] | None:
|
||||
"""Run inference on inputs.
|
||||
|
||||
@@ -408,6 +415,9 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
template_selection: Template selection override. Defaults to ``None``.
|
||||
ground_truth_conformer_selection: Conformer selection override. Defaults to ``None``.
|
||||
cyclic_chains: List of chain IDs to cyclize. Defaults to ``[]``.
|
||||
add_missing_atoms: Whether to add missing atoms from the CCD when parsing CIF/PDB
|
||||
inputs. Has no effect when inputs are already InferenceInput or AtomArray objects.
|
||||
Defaults to ``True``.
|
||||
|
||||
Returns:
|
||||
If ``out_dir`` is None: Dict mapping example_id to list of RF3Output objects.
|
||||
@@ -472,6 +482,7 @@ class RF3InferenceEngine(BaseInferenceEngine):
|
||||
sharding_pattern=sharding_pattern,
|
||||
template_selection=template_selection,
|
||||
ground_truth_conformer_selection=ground_truth_conformer_selection,
|
||||
add_missing_atoms=add_missing_atoms,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported inputs type: {type(inputs)}")
|
||||
|
||||
@@ -75,6 +75,7 @@ class InferenceInput:
|
||||
example_id: str | None = None,
|
||||
template_selection: list[str] | str | None = None,
|
||||
ground_truth_conformer_selection: list[str] | str | None = None,
|
||||
add_missing_atoms: bool = True,
|
||||
) -> "InferenceInput":
|
||||
"""Load from CIF/PDB file.
|
||||
|
||||
@@ -83,11 +84,19 @@ class InferenceInput:
|
||||
example_id: Example ID. Defaults to filename stem.
|
||||
template_selection: Template selection override.
|
||||
ground_truth_conformer_selection: Conformer selection override.
|
||||
add_missing_atoms: Whether to add missing atoms from the CCD to partially/fully
|
||||
unresolved residues. Defaults to True (atomworks parser default). Set to False
|
||||
to keep only the atoms present in the file.
|
||||
|
||||
Returns:
|
||||
InferenceInput object.
|
||||
"""
|
||||
parsed = parse(path, hydrogen_policy="remove", keep_cif_block=True)
|
||||
parsed = parse(
|
||||
path,
|
||||
hydrogen_policy="remove",
|
||||
keep_cif_block=True,
|
||||
add_missing_atoms=add_missing_atoms,
|
||||
)
|
||||
|
||||
atom_array = (
|
||||
parsed["assemblies"]["1"][0]
|
||||
@@ -288,6 +297,7 @@ def _process_single_path(
|
||||
sharding_pattern: str | None,
|
||||
template_selection: list[str] | str | None,
|
||||
ground_truth_conformer_selection: list[str] | str | None,
|
||||
add_missing_atoms: bool = True,
|
||||
) -> list[InferenceInput]:
|
||||
"""Worker function to process a single input file path.
|
||||
|
||||
@@ -299,6 +309,7 @@ def _process_single_path(
|
||||
sharding_pattern: Sharding pattern for output paths.
|
||||
template_selection: Override for template selection.
|
||||
ground_truth_conformer_selection: Override for conformer selection.
|
||||
add_missing_atoms: Whether to add missing atoms from the CCD. Defaults to True.
|
||||
|
||||
Returns:
|
||||
List of InferenceInput objects (may be empty if file is skipped).
|
||||
@@ -345,6 +356,7 @@ def _process_single_path(
|
||||
example_id=example_id,
|
||||
template_selection=template_selection,
|
||||
ground_truth_conformer_selection=ground_truth_conformer_selection,
|
||||
add_missing_atoms=add_missing_atoms,
|
||||
)
|
||||
)
|
||||
else:
|
||||
@@ -362,6 +374,7 @@ def prepare_inference_inputs_from_paths(
|
||||
sharding_pattern: str | None = None,
|
||||
template_selection: list[str] | str | None = None,
|
||||
ground_truth_conformer_selection: list[str] | str | None = None,
|
||||
add_missing_atoms: bool = True,
|
||||
) -> list[InferenceInput]:
|
||||
"""Load InferenceInput objects from file paths.
|
||||
|
||||
@@ -374,6 +387,7 @@ def prepare_inference_inputs_from_paths(
|
||||
sharding_pattern: Sharding pattern for output paths.
|
||||
template_selection: Override for template selection (applied to all inputs).
|
||||
ground_truth_conformer_selection: Override for conformer selection (applied to all inputs).
|
||||
add_missing_atoms: Whether to add missing atoms from the CCD. Defaults to True.
|
||||
|
||||
Returns:
|
||||
List of InferenceInput objects.
|
||||
@@ -413,6 +427,7 @@ def prepare_inference_inputs_from_paths(
|
||||
sharding_pattern,
|
||||
template_selection,
|
||||
ground_truth_conformer_selection,
|
||||
add_missing_atoms,
|
||||
)
|
||||
for path in paths_to_raw_input_files
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user