mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Feat: cyclic inference (#552)
* feat: inference for cyclic peptides * chore: update README * chore: lint * chore: lint * chore: cleanup if statement
This commit is contained in:
committed by
GitHub
parent
d10cc62c69
commit
de64112d89
@@ -575,6 +575,8 @@ If inputs are given in a form that specifies chirality, the model will receive t
|
||||
|
||||
- **`template_selection`** *(optional)*: Selection syntax to provide token-level templates (for both polymers and non-polymers). Uses `AtomSelection` format. Similar to traditional homology-style templates, but also can be applied to small molecules (less rigidly adhered to than the `ground_truth_conformer_selection` approach). See <a>TODO: HREF TO TEMPLATING, SET ID</a> for more information.
|
||||
|
||||
- **`cyclic_chains`** *(optional)*: List of strings of chain ids that should be cyclized. If given a cif, the model will automatically detect cyclization.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
|
||||
@@ -27,3 +27,5 @@ ligand_dropout_prob: 0.0
|
||||
|
||||
add_residue_is_paired_feature: true
|
||||
|
||||
add_cyclic_bonds: true
|
||||
|
||||
|
||||
@@ -29,3 +29,4 @@ dataset:
|
||||
atomization_prob: ${datasets.atomization_prob}
|
||||
ligand_dropout_prob: ${datasets.ligand_dropout_prob}
|
||||
add_residue_is_paired_feature: ${datasets.add_residue_is_paired_feature}
|
||||
add_cyclic_bonds: ${datasets.add_cyclic_bonds}
|
||||
|
||||
@@ -21,6 +21,7 @@ early_stopping_plddt_threshold: 0.5
|
||||
seed: null
|
||||
print_config: true
|
||||
raise_if_missing_msa_for_protein_of_length_n: null
|
||||
cyclic_chains: []
|
||||
|
||||
# Metrics
|
||||
metrics_cfg:
|
||||
|
||||
@@ -98,7 +98,7 @@ from atomworks.ml.transforms.random_atomize_residues import RandomAtomizeResidue
|
||||
from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
|
||||
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
|
||||
from omegaconf import DictConfig
|
||||
from rf3.data.cyclic_transorm import AddCyclicBonds
|
||||
from rf3.data.cyclic_transform import AddCyclicBonds
|
||||
from rf3.data.extra_xforms import CheckForNaNsInInputs
|
||||
from rf3.data.pipeline_utils import (
|
||||
annotate_post_crop_hash,
|
||||
@@ -186,7 +186,7 @@ def build_af3_transform_pipeline(
|
||||
p_dropout_atom_level_embeddings: float = 0.0,
|
||||
embedding_dim: int = 384,
|
||||
n_conformers: int = 8,
|
||||
add_cyclic_bonds: bool = False,
|
||||
add_cyclic_bonds: bool = True,
|
||||
p_dropout_ref_conf: float = 0.0, # Unused
|
||||
):
|
||||
"""Build the AF3 pipeline with specified parameters.
|
||||
|
||||
@@ -93,6 +93,7 @@ class RF3InferenceEngine:
|
||||
metrics_cfg: dict | OmegaConf | None = None,
|
||||
num_nodes: int = 1,
|
||||
devices_per_node: int = 1,
|
||||
cyclic_chains: list[str] = [],
|
||||
# Debug
|
||||
print_config: bool = False,
|
||||
raise_if_missing_msa_for_protein_of_length_n: int | None = None,
|
||||
@@ -171,6 +172,8 @@ class RF3InferenceEngine:
|
||||
"p_dropout_ref_conf": 0.0,
|
||||
}
|
||||
|
||||
self.cyclic_chains = cyclic_chains
|
||||
|
||||
self.print_config = print_config
|
||||
|
||||
# Set random seed (only if seed is not None)
|
||||
@@ -307,6 +310,11 @@ class RF3InferenceEngine:
|
||||
else:
|
||||
raise ValueError(f"Unsupported inputs type: {type(inputs)}")
|
||||
|
||||
# Flag chains for cyclization if specified
|
||||
if self.cyclic_chains:
|
||||
for input_spec in inference_inputs:
|
||||
input_spec.cyclic_chains = self.cyclic_chains
|
||||
|
||||
# make InferenceInputDataset
|
||||
inference_dataset = InferenceInputDataset(inference_inputs)
|
||||
ranked_logger.info(f"Found {len(inference_dataset)} structures to predict!")
|
||||
|
||||
@@ -66,6 +66,7 @@ class InferenceInput:
|
||||
example_id: str
|
||||
template_selection: list[str] | None = None
|
||||
ground_truth_conformer_selection: list[str] | None = None
|
||||
cyclic_chains: list[str] | None = None
|
||||
|
||||
@classmethod
|
||||
def from_cif_path(
|
||||
@@ -271,6 +272,9 @@ class InferenceInput:
|
||||
ground_truth_conformer_selection=self.ground_truth_conformer_selection,
|
||||
)
|
||||
|
||||
if self.cyclic_chains:
|
||||
atom_array = cyclize_atom_array(atom_array, self.cyclic_chains)
|
||||
|
||||
return {
|
||||
"example_id": self.example_id,
|
||||
"atom_array": atom_array,
|
||||
@@ -580,6 +584,47 @@ def apply_conformer_and_template_selections(
|
||||
return atom_array
|
||||
|
||||
|
||||
def cyclize_atom_array(atom_array: AtomArray, cyclic_chains: list[str]) -> AtomArray:
|
||||
"""Cyclize the atom array by positioining the termini properly if not already done.
|
||||
|
||||
Behavior:
|
||||
- Positions the last carbon atom in the chain to be 1.3 Angstroms away from the first nitrogen atom if they are not already close.
|
||||
- Adds a bond between the termini for proper cif output.
|
||||
|
||||
Args:
|
||||
atom_array: AtomArray to cyclize.
|
||||
cyclic_chains: List of chain IDs to cyclize.
|
||||
|
||||
Returns:
|
||||
The same AtomArray with the specified chains cyclized.
|
||||
"""
|
||||
for chain in cyclic_chains:
|
||||
# Find the first nitrogen atom in the chain
|
||||
nitrogen_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "N")
|
||||
nitrogen_mask_indices = np.where(nitrogen_mask)[0]
|
||||
first_nitrogen_index = nitrogen_mask_indices[0]
|
||||
nitrogen_coord = atom_array.coord[first_nitrogen_index]
|
||||
|
||||
# move the last carbon atom in the chain to be 1.3 Angstroms away from the nitrogen
|
||||
carbon_mask = (atom_array.chain_id == chain) & (atom_array.atom_name == "C")
|
||||
carbon_mask_indices = np.where(carbon_mask)[0]
|
||||
last_carbon_index = carbon_mask_indices[-1]
|
||||
# check if the last carbon is already close to the nitrogen
|
||||
termini_distance = np.linalg.norm(
|
||||
atom_array.coord[last_carbon_index] - nitrogen_coord
|
||||
)
|
||||
if not (termini_distance < 1.5 and termini_distance > 0.5):
|
||||
atom_array.coord[last_carbon_index] = nitrogen_coord + np.array(
|
||||
[1.3, 0.0, 0.0]
|
||||
)
|
||||
|
||||
# add a bond between the nitrogen and carbon so output cif has a connection
|
||||
atom_array.bonds.add_bond(first_nitrogen_index, last_carbon_index)
|
||||
atom_array.bonds.add_bond(last_carbon_index, first_nitrogen_index)
|
||||
|
||||
return atom_array
|
||||
|
||||
|
||||
class InferenceInputDataset(Dataset):
|
||||
"""
|
||||
Dataset for inference inputs. Also has a length key telling you the number of tokens in each example for LoadBalancedDistributedSampler.
|
||||
|
||||
Reference in New Issue
Block a user