Feat: cyclic inference (#552)

* feat: inference for cyclic peptides

* chore: update README

* chore: lint

* chore: lint

* chore: cleanup if statement
This commit is contained in:
Tuscan R Thompson
2025-10-24 16:27:42 -07:00
committed by GitHub
parent d10cc62c69
commit de64112d89
8 changed files with 61 additions and 2 deletions

View File

@@ -575,6 +575,8 @@ If inputs are given in a form that specifies chirality, the model will receive t
- **`template_selection`** *(optional)*: Selection syntax to provide token-level templates (for both polymers and non-polymers). Uses `AtomSelection` format. Similar to traditional homology-style templates, but also can be applied to small molecules (less rigidly adhered to than the `ground_truth_conformer_selection` approach). See <a>TODO: HREF TO TEMPLATING, SET ID</a> for more information.
- **`cyclic_chains`** *(optional)*: List of strings of chain ids that should be cyclized. If given a cif, the model will automatically detect cyclization.
</details>
<details>

View File

@@ -27,3 +27,5 @@ ligand_dropout_prob: 0.0
add_residue_is_paired_feature: true
add_cyclic_bonds: true

View File

@@ -29,3 +29,4 @@ dataset:
atomization_prob: ${datasets.atomization_prob}
ligand_dropout_prob: ${datasets.ligand_dropout_prob}
add_residue_is_paired_feature: ${datasets.add_residue_is_paired_feature}
add_cyclic_bonds: ${datasets.add_cyclic_bonds}

View File

@@ -21,6 +21,7 @@ early_stopping_plddt_threshold: 0.5
seed: null
print_config: true
raise_if_missing_msa_for_protein_of_length_n: null
cyclic_chains: []
# Metrics
metrics_cfg:

View File

@@ -98,7 +98,7 @@ from atomworks.ml.transforms.random_atomize_residues import RandomAtomizeResidue
from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
from omegaconf import DictConfig
from rf3.data.cyclic_transorm import AddCyclicBonds
from rf3.data.cyclic_transform import AddCyclicBonds
from rf3.data.extra_xforms import CheckForNaNsInInputs
from rf3.data.pipeline_utils import (
annotate_post_crop_hash,
@@ -186,7 +186,7 @@ def build_af3_transform_pipeline(
p_dropout_atom_level_embeddings: float = 0.0,
embedding_dim: int = 384,
n_conformers: int = 8,
add_cyclic_bonds: bool = False,
add_cyclic_bonds: bool = True,
p_dropout_ref_conf: float = 0.0, # Unused
):
"""Build the AF3 pipeline with specified parameters.

View File

@@ -93,6 +93,7 @@ class RF3InferenceEngine:
metrics_cfg: dict | OmegaConf | None = None,
num_nodes: int = 1,
devices_per_node: int = 1,
cyclic_chains: list[str] = [],
# Debug
print_config: bool = False,
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
@@ -171,6 +172,8 @@ class RF3InferenceEngine:
"p_dropout_ref_conf": 0.0,
}
self.cyclic_chains = cyclic_chains
self.print_config = print_config
# Set random seed (only if seed is not None)
@@ -307,6 +310,11 @@ class RF3InferenceEngine:
else:
raise ValueError(f"Unsupported inputs type: {type(inputs)}")
# Flag chains for cyclization if specified
if self.cyclic_chains:
for input_spec in inference_inputs:
input_spec.cyclic_chains = self.cyclic_chains
# make InferenceInputDataset
inference_dataset = InferenceInputDataset(inference_inputs)
ranked_logger.info(f"Found {len(inference_dataset)} structures to predict!")

View File

@@ -66,6 +66,7 @@ class InferenceInput:
example_id: str
template_selection: list[str] | None = None
ground_truth_conformer_selection: list[str] | None = None
cyclic_chains: list[str] | None = None
@classmethod
def from_cif_path(
@@ -271,6 +272,9 @@ class InferenceInput:
ground_truth_conformer_selection=self.ground_truth_conformer_selection,
)
if self.cyclic_chains:
atom_array = cyclize_atom_array(atom_array, self.cyclic_chains)
return {
"example_id": self.example_id,
"atom_array": atom_array,
@@ -580,6 +584,47 @@ def apply_conformer_and_template_selections(
return atom_array
def cyclize_atom_array(atom_array: AtomArray, cyclic_chains: list[str]) -> AtomArray:
"""Cyclize the atom array by positioining the termini properly if not already done.
Behavior:
- Positions the last carbon atom in the chain to be 1.3 Angstroms away from the first nitrogen atom if they are not already close.
- Adds a bond between the termini for proper cif output.
Args:
atom_array: AtomArray to cyclize.
cyclic_chains: List of chain IDs to cyclize.
Returns:
The same AtomArray with the specified chains cyclized.
"""
for chain in cyclic_chains:
# Find the first nitrogen atom in the chain
nitrogen_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "N")
nitrogen_mask_indices = np.where(nitrogen_mask)[0]
first_nitrogen_index = nitrogen_mask_indices[0]
nitrogen_coord = atom_array.coord[first_nitrogen_index]
# move the last carbon atom in the chain to be 1.3 Angstroms away from the nitrogen
carbon_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "C")
carbon_mask_indices = np.where(carbon_mask)[0]
last_carbon_index = carbon_mask_indices[-1]
# check if the last carbon is already close to the nitrogen
termini_distance = np.linalg.norm(
atom_array.coord[last_carbon_index] - nitrogen_coord
)
if not (termini_distance < 1.5 and termini_distance > 0.5):
atom_array.coord[last_carbon_index] = nitrogen_coord + np.array(
[1.3, 0.0, 0.0]
)
# add a bond between the nitrogen and carbon so output cif has a connection
atom_array.bonds.add_bond(first_nitrogen_index, last_carbon_index)
atom_array.bonds.add_bond(last_carbon_index, first_nitrogen_index)
return atom_array
class InferenceInputDataset(Dataset):
"""
Dataset for inference inputs. Also has a length key telling you the number of tokens in each example for LoadBalancedDistributedSampler.