mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Update to aa_design/main latest (post loop fix)
This commit is contained in:
@@ -6,5 +6,5 @@ defaults:
|
||||
dataset:
|
||||
name: rigid-ligand-enzymes
|
||||
eval_every_n: 1
|
||||
data: ${paths.data.design_benchmark_data_dir}/mcsa_41_short_rigid.json
|
||||
data: ${paths.data.design_benchmark_data_dir}/mcsa_41_short_rigid_new.json
|
||||
|
||||
|
||||
@@ -26,6 +26,9 @@ OPTIONAL_CONDITIONING_VALUES = {
|
||||
"ref_plddt": 0,
|
||||
"is_non_loopy": 0,
|
||||
"partial_t": np.nan,
|
||||
# kept for legacy reasons
|
||||
"is_motif_token": 1,
|
||||
"is_motif_atom": 1,
|
||||
}
|
||||
"""Optional conditioning annotations and their default values if not provided."""
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@
|
||||
|
||||
| Field | Type | Description |
|
||||
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |
|
||||
| `input` | `str?` | Path to input **PDB/CIF**. Required if you provide contig+length. |
|
||||
| `input` | `str?` | Path to input **PDB/CIF**. Required if you provide contig+length. |
|
||||
| `atom_array_input` | internal | Pre-loaded `AtomArray` (not recommended). |
|
||||
| `contig` | `InputSelection?` | Indexed motif specification, e.g., `"A1-80,10,\0,B5-12"`. |
|
||||
| `unindex` | `InputSelection?` | Unindexed motif components (unknown sequence placement). |
|
||||
|
||||
@@ -46,6 +46,7 @@ from rfd3.transforms.conditioning_base import (
|
||||
)
|
||||
from rfd3.transforms.util_transforms import assign_types_
|
||||
from rfd3.utils.inference import (
|
||||
_restore_bonds_for_nonstandard_residues,
|
||||
extract_ligand_array,
|
||||
inference_load_,
|
||||
set_com,
|
||||
@@ -81,7 +82,7 @@ class LegacySpecification(BaseModel):
|
||||
def build(self, *args, **kwargs):
|
||||
"""Build atom array using legacy input parsing."""
|
||||
atom_array = create_atom_array_from_design_specification_legacy(
|
||||
design_specification=self.model_dump(),
|
||||
**self.model_dump(),
|
||||
)
|
||||
return atom_array, self.model_dump()
|
||||
|
||||
@@ -377,7 +378,6 @@ class DesignInputSpecification(BaseModel):
|
||||
"RASA",
|
||||
("select_buried", "select_partially_buried", "select_exposed"),
|
||||
),
|
||||
("Hydrogen bonds", ("select_hbond_acceptor", "select_hbond_donor")),
|
||||
]
|
||||
|
||||
for name, excl_set in exclusive_sets:
|
||||
@@ -573,6 +573,8 @@ class DesignInputSpecification(BaseModel):
|
||||
unindexed_tokens=unindexed_tokens,
|
||||
atom_array_accum=[],
|
||||
unindexed_breaks=unindexed_breaks,
|
||||
start_chain="A",
|
||||
start_resid=1,
|
||||
)
|
||||
else:
|
||||
# ... Set common annotations
|
||||
@@ -662,6 +664,14 @@ class DesignInputSpecification(BaseModel):
|
||||
+ list(atom_array_input_annotated.get_annotation_categories())
|
||||
),
|
||||
)
|
||||
# Offset ligand residue ids based on the original input to avoid clashes
|
||||
# with any newly created residues (matches legacy behaviour).
|
||||
ligand_array.res_id = (
|
||||
ligand_array.res_id
|
||||
- np.min(ligand_array.res_id)
|
||||
+ np.max(atom_array.res_id)
|
||||
+ 1
|
||||
)
|
||||
atom_array = atom_array + ligand_array
|
||||
return atom_array
|
||||
|
||||
@@ -802,6 +812,15 @@ def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers
|
||||
atom_array = convert_existing_annotations_to_bool(atom_array)
|
||||
atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
|
||||
atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
|
||||
|
||||
# Ensure motif annotations are removed
|
||||
atom_array.del_annotation(
|
||||
"is_motif_token"
|
||||
) if "is_motif_token" in atom_array.get_annotation_categories() else None
|
||||
atom_array.del_annotation(
|
||||
"is_motif_atom"
|
||||
) if "is_motif_atom" in atom_array.get_annotation_categories() else None
|
||||
|
||||
data = {
|
||||
"atom_array": atom_array, # First model
|
||||
"chain_info": result_dict["chain_info"],
|
||||
@@ -933,6 +952,8 @@ def accumulate_components(
|
||||
start_chain: str = "A",
|
||||
start_resid: int = 1,
|
||||
unindexed_breaks: Optional[List[bool]] = [],
|
||||
src_atom_array: Optional[AtomArray] = None,
|
||||
strip_sidechains_by_default: bool = False,
|
||||
**kwargs,
|
||||
) -> AtomArray:
|
||||
# ... Create list of components
|
||||
@@ -950,6 +971,12 @@ def accumulate_components(
|
||||
for tok in all_tokens.values()
|
||||
]
|
||||
all_annots = set(all_annots)
|
||||
atom_array_accum = [] if atom_array_accum is None else atom_array_accum
|
||||
unindexed_breaks = (
|
||||
[None] * len(components_to_accumulate)
|
||||
if unindexed_breaks is None
|
||||
else unindexed_breaks
|
||||
)
|
||||
|
||||
# ... For-loop accum variables
|
||||
unindexed_components_started = (
|
||||
@@ -958,12 +985,21 @@ def accumulate_components(
|
||||
chain = start_chain
|
||||
res_id = start_resid
|
||||
molecule_id = 0
|
||||
source_to_accum_idx: Dict[int, int] = {}
|
||||
current_accum_idx = sum(len(arr) for arr in atom_array_accum)
|
||||
|
||||
# ... Insert contig information one- by one-
|
||||
assert len(components_to_accumulate) == len(
|
||||
unindexed_breaks
|
||||
), "Mismatch in number of components to accumulate and breaks"
|
||||
for component, is_break in zip(components_to_accumulate, unindexed_breaks):
|
||||
src_indices = None
|
||||
if exists(is_break) and is_break:
|
||||
if not unindexed_components_started:
|
||||
chain = start_chain
|
||||
res_id = start_resid
|
||||
unindexed_components_started = True
|
||||
|
||||
if component == "/0":
|
||||
# Reset iterators on next chain
|
||||
chain = chr(ord(chain) + 1)
|
||||
@@ -977,12 +1013,21 @@ def accumulate_components(
|
||||
|
||||
# ... Fetch the motif residue
|
||||
token = all_tokens[component]
|
||||
if src_atom_array is not None:
|
||||
src_mask = fetch_mask_from_idx(component, atom_array=src_atom_array)
|
||||
src_indices = np.where(src_mask)[0]
|
||||
# try:
|
||||
# except ComponentValidationError as e:
|
||||
# src_indices = None
|
||||
# print(e)
|
||||
|
||||
# ... Ensure motif residues are set properly
|
||||
token = create_motif_residue(
|
||||
token, strip_sidechains_by_default=strip_sidechains_by_default
|
||||
)
|
||||
|
||||
# ... Insert breakpoint when break clause is met
|
||||
if exists(is_break) and is_break:
|
||||
if not unindexed_components_started:
|
||||
chain = start_chain
|
||||
unindexed_components_started = True
|
||||
token.set_annotation(
|
||||
"is_motif_atom_unindexed_motif_breakpoint",
|
||||
np.ones(token.shape[0], dtype=int),
|
||||
@@ -1015,15 +1060,42 @@ def accumulate_components(
|
||||
len(get_token_starts(token)) == n
|
||||
), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
|
||||
|
||||
if (
|
||||
src_atom_array is not None
|
||||
and str(component)[0].isalpha()
|
||||
and src_indices is not None
|
||||
and len(src_indices) == len(token)
|
||||
):
|
||||
for i, src_idx in enumerate(src_indices):
|
||||
source_to_accum_idx[int(src_idx)] = current_accum_idx + i
|
||||
|
||||
# ... Insert & Increment residue ID
|
||||
atom_array_accum.append(token)
|
||||
res_id += n
|
||||
current_accum_idx += len(token)
|
||||
|
||||
# ... Concatenate all components
|
||||
atom_array_accum = struc.concatenate(atom_array_accum)
|
||||
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
|
||||
|
||||
# Reset res_id for unindexed residues to avoid duplicates
|
||||
should_restore_bonds = (
|
||||
src_atom_array is not None
|
||||
and bool(source_to_accum_idx)
|
||||
and _check_has_backbone_connections_to_nonstandard_residues(
|
||||
atom_array_accum, src_atom_array
|
||||
)
|
||||
)
|
||||
if should_restore_bonds:
|
||||
assert not unindexed_tokens, (
|
||||
"PTM backbone bond restoration is not compatible with unindexed components. "
|
||||
"PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
|
||||
f"Found unindexed components: {list(unindexed_tokens.keys())}"
|
||||
)
|
||||
atom_array_accum = _restore_bonds_for_nonstandard_residues(
|
||||
atom_array_accum, src_atom_array, source_to_accum_idx
|
||||
)
|
||||
|
||||
# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
|
||||
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
):
|
||||
@@ -1032,9 +1104,14 @@ def accumulate_components(
|
||||
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
min_id_udx = np.min(
|
||||
atom_array_accum[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
atom_array_accum.res_id[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
] += max_id + 1
|
||||
] += max_id - min_id_udx + 1
|
||||
|
||||
# ... Bonds
|
||||
if atom_array_accum.bonds is None:
|
||||
|
||||
@@ -10,7 +10,6 @@ from atomworks.io.utils.io_utils import to_cif_file
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.utils.token import (
|
||||
get_token_starts,
|
||||
spread_token_wise,
|
||||
)
|
||||
from rfd3.constants import (
|
||||
INFERENCE_ANNOTATIONS,
|
||||
@@ -372,7 +371,7 @@ def accumulate_components(
|
||||
atom_array_accum = struc.concatenate(atom_array_accum)
|
||||
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
|
||||
|
||||
# Reset res_id for unindexed residues to avoid duplicates
|
||||
# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
|
||||
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
):
|
||||
@@ -381,9 +380,14 @@ def accumulate_components(
|
||||
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
min_id_udx = np.min(
|
||||
atom_array_accum[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
].res_id
|
||||
)
|
||||
atom_array_accum.res_id[
|
||||
atom_array_accum.is_motif_atom_unindexed.astype(bool)
|
||||
] += max_id + 1
|
||||
] += max_id - min_id_udx + 1
|
||||
|
||||
return atom_array_accum
|
||||
|
||||
@@ -406,7 +410,6 @@ def create_atom_array_from_design_specification_legacy(
|
||||
unfix_all=False,
|
||||
unfix_specific: str = None,
|
||||
flexible_backbone: bool = False,
|
||||
is_d_amino_acid=None,
|
||||
# Args for biomolecular design (Enzymes, DNA/PNA):
|
||||
ligand: str = None,
|
||||
ori_token: list[float] = None,
|
||||
@@ -556,7 +559,7 @@ def create_atom_array_from_design_specification_legacy(
|
||||
if indexed_components and indexed_components_provided:
|
||||
for component in indexed_components:
|
||||
if str(component)[0].isalpha():
|
||||
mask = fetch_residue_mask(atom_array, component)
|
||||
mask = fetch_mask_from_component(component, atom_array=atom_array)
|
||||
|
||||
# Set the component as a motif token
|
||||
set_default_conditioning_annotations(
|
||||
@@ -624,31 +627,24 @@ def create_atom_array_from_design_specification_legacy(
|
||||
is_motif_atom = f["is_motif_atom"]
|
||||
atom_array.set_annotation("is_motif_atom", is_motif_atom.astype(int))
|
||||
|
||||
# This is an annotation on the diffused regions, so must be added after accumulate_components
|
||||
if spoof_helical_bundle_ss_conditioning:
|
||||
is_helix = spoof_helical_bundle_ss_conditioning_fn(atom_array)
|
||||
is_sheet = None
|
||||
is_loop = None
|
||||
if exists(is_helix):
|
||||
set_atom_level_argument(atom_array, is_helix, "is_helix")
|
||||
if exists(is_sheet):
|
||||
set_atom_level_argument(atom_array, is_sheet, "is_sheet")
|
||||
optional_conditions.append("is_sheet")
|
||||
if exists(is_loop):
|
||||
set_atom_level_argument(atom_array, is_loop, "is_loop")
|
||||
optional_conditions.append("is_loop")
|
||||
|
||||
is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
|
||||
diffused_region_mask = ~(is_motif_token.astype(bool))
|
||||
if exists(is_non_loopy):
|
||||
is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1
|
||||
|
||||
atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
|
||||
atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
|
||||
|
||||
# ... If ligand, post-pend it
|
||||
if exists(ligand):
|
||||
ligand_array = extract_ligand_array(atom_array_input, ligand, fixed_atoms)
|
||||
ligand_array = extract_ligand_array(
|
||||
atom_array_input,
|
||||
ligand,
|
||||
fixed_atoms,
|
||||
additional_annotations=set(
|
||||
list(atom_array.get_annotation_categories())
|
||||
+ list(atom_array_input.get_annotation_categories())
|
||||
+ ["is_motif_atom", "is_motif_token"]
|
||||
),
|
||||
)
|
||||
ligand_array.res_id = (
|
||||
ligand_array.res_id
|
||||
- np.min(ligand_array.res_id)
|
||||
+ np.max(atom_array.res_id)
|
||||
+ 1
|
||||
)
|
||||
atom_array = atom_array + ligand_array
|
||||
|
||||
# ... Apply symmetry if it exists ahead of any other processing
|
||||
@@ -674,13 +670,31 @@ def create_atom_array_from_design_specification_legacy(
|
||||
atom_array, ori_token=ori_token, infer_ori_strategy=infer_ori_strategy
|
||||
)
|
||||
# diffused atoms initialized at origin
|
||||
atom_array.coord[~atom_array.is_motif_token.astype(bool)] = 0.0
|
||||
atom_array.coord[~atom_array.is_motif_atom_with_fixed_coord.astype(bool), :] = (
|
||||
0.0
|
||||
)
|
||||
|
||||
# ... Add is_d_amino_acid annotation if specified
|
||||
if is_d_amino_acid is not None:
|
||||
# convert to nd_array
|
||||
is_d_amino_acid = np.asarray(is_d_amino_acid, dtype=bool)
|
||||
is_d_amino_acid = spread_token_wise(atom_array, is_d_amino_acid)
|
||||
# This is an annotation on the diffused regions, so must be added after accumulate_components
|
||||
if spoof_helical_bundle_ss_conditioning:
|
||||
is_helix = spoof_helical_bundle_ss_conditioning_fn(atom_array)
|
||||
is_sheet = None
|
||||
is_loop = None
|
||||
if exists(is_helix):
|
||||
set_atom_level_argument(atom_array, is_helix, "is_helix")
|
||||
if exists(is_sheet):
|
||||
set_atom_level_argument(atom_array, is_sheet, "is_sheet")
|
||||
optional_conditions.append("is_sheet")
|
||||
if exists(is_loop):
|
||||
set_atom_level_argument(atom_array, is_loop, "is_loop")
|
||||
optional_conditions.append("is_loop")
|
||||
|
||||
is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
|
||||
diffused_region_mask = ~(atom_array.is_motif_token.astype(bool))
|
||||
if exists(is_non_loopy):
|
||||
is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1
|
||||
|
||||
atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
|
||||
atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
|
||||
|
||||
if plddt_enhanced:
|
||||
atom_array.set_annotation(
|
||||
|
||||
@@ -111,7 +111,12 @@ DIRS = [
|
||||
|
||||
def load_test_json():
|
||||
test_files = ["demo.json", "demo_extended.json", "tests.json"]
|
||||
test_files += ["mcsa_41.json", "rfd_unindexed.json", "sym_tests.json"]
|
||||
test_files += [
|
||||
"mcsa_41.json",
|
||||
"rfd_unindexed.json",
|
||||
"sym_tests.json",
|
||||
"brk_regression.json",
|
||||
]
|
||||
test_json_data = {}
|
||||
for dir in DIRS:
|
||||
test_data_dir = Path(dir, "test_data")
|
||||
@@ -163,10 +168,6 @@ def instantiate_example(args, is_inference=True):
|
||||
args = copy.deepcopy(args)
|
||||
|
||||
if is_inference:
|
||||
# Keep only the kwargs that the function actually accepts
|
||||
# args = filter_inference_args(args)
|
||||
# atom_array, spec = create_atom_array_from_design_specification(**args)
|
||||
# input = prepare_pipeline_input_from_atom_array(atom_array)
|
||||
input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
|
||||
example_id=args.get("example_id", "example")
|
||||
)
|
||||
|
||||
@@ -84,11 +84,6 @@ from rfd3.transforms.design_transforms import (
|
||||
)
|
||||
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
|
||||
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
|
||||
from rfd3.transforms.ncaa_transforms import (
|
||||
AddIsDAminoAcidFeat,
|
||||
RandomlyMirrorInputs,
|
||||
StrtoBoolforIsDAminoAcidFeature,
|
||||
)
|
||||
from rfd3.transforms.ppi_transforms import (
|
||||
Add1DSSFeature,
|
||||
AddGlobalIsNonLoopyFeature,
|
||||
@@ -152,7 +147,6 @@ def get_pre_crop_transforms(
|
||||
):
|
||||
return [
|
||||
InferenceRoute(StrtoBoolforIsXFeatures()),
|
||||
InferenceRoute(StrtoBoolforIsDAminoAcidFeature()),
|
||||
RemoveHydrogens(),
|
||||
FilterToSpecifiedPNUnits(
|
||||
extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
|
||||
@@ -543,22 +537,6 @@ def build_atom14_base_pipeline_(
|
||||
diffusion_batch_size=diffusion_batch_size,
|
||||
)
|
||||
|
||||
# ... Mixed chirality handling
|
||||
transforms += [
|
||||
TrainingRoute(
|
||||
ConditionalRoute(
|
||||
condition_func=lambda data: data["conditions"].get(
|
||||
"mirror_input", False
|
||||
),
|
||||
transform_map={
|
||||
True: RandomlyMirrorInputs(),
|
||||
False: Identity(),
|
||||
},
|
||||
)
|
||||
),
|
||||
AddIsDAminoAcidFeat(),
|
||||
]
|
||||
|
||||
# ... Random augmentation accounting for motif
|
||||
transforms += [
|
||||
MotifCenterRandomAugmentation(
|
||||
|
||||
@@ -522,7 +522,9 @@ def sample_unindexed_breaks(
|
||||
token_idxs = np.arange(len(starts))
|
||||
breaks_all = np.zeros(len(starts), dtype=bool)
|
||||
|
||||
if np.any(is_unindexed_token):
|
||||
if is_unindexed_token.sum() == 1:
|
||||
breaks_all = is_unindexed_token
|
||||
elif np.any(is_unindexed_token):
|
||||
# ... Subset to unindexed tokens
|
||||
unindexed_token_starts = starts[is_unindexed_token]
|
||||
unindexed_token_resid = atom_array[unindexed_token_starts].res_id
|
||||
|
||||
@@ -5,12 +5,15 @@ Utilities for inference input preparation
|
||||
import logging
|
||||
import os
|
||||
from os import PathLike
|
||||
from typing import Dict
|
||||
|
||||
import biotite.structure as struc
|
||||
import numpy as np
|
||||
from atomworks import parse
|
||||
from atomworks.constants import STANDARD_DNA
|
||||
from atomworks.io.parser import STANDARD_PARSER_ARGS
|
||||
from atomworks.constants import STANDARD_AA, STANDARD_DNA
|
||||
from atomworks.io.parser import (
|
||||
STANDARD_PARSER_ARGS,
|
||||
)
|
||||
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
|
||||
from atomworks.ml.preprocessing.utils.structure_utils import (
|
||||
get_atom_mask_from_cell_list,
|
||||
@@ -162,6 +165,150 @@ def extract_na_array(atom_array_input):
|
||||
)
|
||||
|
||||
|
||||
def _restore_bonds_for_nonstandard_residues(
|
||||
atom_array_accum: struc.AtomArray,
|
||||
src_atom_array: struc.AtomArray | None,
|
||||
source_to_accum_idx: Dict[int, int],
|
||||
) -> struc.AtomArray:
|
||||
"""
|
||||
Restores and creates bonds for non-standard residues (PTMs, modified AAs, etc.)
|
||||
from source structure and between consecutive residues.
|
||||
This function:
|
||||
1. Preserves inter-residue bonds from the source structure (if available)
|
||||
2. Adds backbone C-N bonds between consecutive residues where at least one is non-standard
|
||||
Args:
|
||||
atom_array_accum: The accumulated atom array to add bonds to
|
||||
src_atom_array: The source atom array containing original bond information
|
||||
source_to_accum_idx: Mapping from source atom indices to accumulated array indices
|
||||
Returns:
|
||||
atom_array_accum with bonds added
|
||||
"""
|
||||
# Initialize bonds if needed
|
||||
if atom_array_accum.bonds is None:
|
||||
atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())
|
||||
|
||||
# Step 1: Restore inter-residue bonds from the source atom array (only for non-standard residues)
|
||||
if (
|
||||
src_atom_array is not None
|
||||
and hasattr(src_atom_array, "bonds")
|
||||
and src_atom_array.bonds is not None
|
||||
):
|
||||
original_bonds = src_atom_array.bonds.as_array()
|
||||
if len(original_bonds) > 0:
|
||||
# Extract bonds where both atoms are in the accumulated array
|
||||
bonds_to_add = []
|
||||
for bond in original_bonds:
|
||||
atom_i, atom_j, bond_type = bond
|
||||
# Check if both atoms are in our mapping
|
||||
if (
|
||||
int(atom_i) in source_to_accum_idx
|
||||
and int(atom_j) in source_to_accum_idx
|
||||
):
|
||||
# Check if at least one atom is from a non-standard residue
|
||||
src_res_i = src_atom_array[int(atom_i)].res_name
|
||||
src_res_j = src_atom_array[int(atom_j)].res_name
|
||||
|
||||
# Only preserve if at least one residue is non-standard
|
||||
if src_res_i not in STANDARD_AA or src_res_j not in STANDARD_AA:
|
||||
new_i = source_to_accum_idx[int(atom_i)]
|
||||
new_j = source_to_accum_idx[int(atom_j)]
|
||||
bonds_to_add.append([new_i, new_j, int(bond_type)])
|
||||
|
||||
if bonds_to_add:
|
||||
# Add the preserved bonds
|
||||
new_bonds = struc.BondList(
|
||||
atom_array_accum.array_length(),
|
||||
np.array(bonds_to_add, dtype=np.int64),
|
||||
)
|
||||
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
|
||||
logger.info(
|
||||
f"Preserved {len(bonds_to_add)} inter-residue bonds involving non-standard residues from source structure"
|
||||
)
|
||||
|
||||
# Step 2: Add backbone bonds between consecutive residues where at least one is non-standard
|
||||
# This handles: PTM-to-diffused, diffused-to-PTM, PTM-to-PTM, ligand-to-protein
|
||||
bonds_to_add = []
|
||||
|
||||
# Group by residue
|
||||
token_starts = get_token_starts(atom_array_accum, add_exclusive_stop=True)
|
||||
|
||||
for i in range(
|
||||
len(token_starts) - 2
|
||||
): # -2 because we need pairs and token_starts has exclusive stop
|
||||
curr_start, curr_end = token_starts[i], token_starts[i + 1]
|
||||
next_start, next_end = token_starts[i + 1], token_starts[i + 2]
|
||||
|
||||
curr_residue = atom_array_accum[curr_start:curr_end]
|
||||
next_residue = atom_array_accum[next_start:next_end]
|
||||
|
||||
# Check if at least one residue is non-standard (PTMs, modified AAs, etc.)
|
||||
curr_is_nonstandard = curr_residue.res_name[0] not in STANDARD_AA
|
||||
next_is_nonstandard = next_residue.res_name[0] not in STANDARD_AA
|
||||
|
||||
# Only add bonds if at least one residue is non-standard
|
||||
if not (curr_is_nonstandard or next_is_nonstandard):
|
||||
continue
|
||||
|
||||
# Check if consecutive in same chain
|
||||
if curr_residue.chain_id[0] != next_residue.chain_id[0]:
|
||||
continue
|
||||
if next_residue.res_id[0] - curr_residue.res_id[0] != 1:
|
||||
continue
|
||||
|
||||
# Find C atom in current residue (C-terminus connection point)
|
||||
c_mask = curr_residue.atom_name == "C"
|
||||
if not np.any(c_mask):
|
||||
# If a non-standard residue doesn't have a C atom, it can't connect to next residue
|
||||
# This is expected for some atomized residues or ligands at chain termini
|
||||
if curr_is_nonstandard and next_is_nonstandard:
|
||||
# Both are non-standard but no C in current - might be an atomized region without proper termini
|
||||
logger.debug(
|
||||
f"Non-standard residue {curr_residue.res_name[0]} (res_id {curr_residue.res_id[0]}) "
|
||||
f"has no C atom - cannot form backbone bond to next residue"
|
||||
)
|
||||
continue
|
||||
c_idx = curr_start + np.where(c_mask)[0][0]
|
||||
|
||||
# Find N atom in next residue (N-terminus connection point)
|
||||
n_mask = next_residue.atom_name == "N"
|
||||
if not np.any(n_mask):
|
||||
# If a non-standard residue doesn't have an N atom, it can't connect to previous residue
|
||||
# This is expected for some atomized residues or ligands at chain termini
|
||||
if curr_is_nonstandard and next_is_nonstandard:
|
||||
# Both are non-standard but no N in next - might be an atomized region without proper termini
|
||||
logger.debug(
|
||||
f"Non-standard residue {next_residue.res_name[0]} (res_id {next_residue.res_id[0]}) "
|
||||
f"has no N atom - cannot form backbone bond from previous residue"
|
||||
)
|
||||
continue
|
||||
n_idx = next_start + np.where(n_mask)[0][0]
|
||||
|
||||
# Check if this bond already exists (from source preservation)
|
||||
existing_bonds = atom_array_accum.bonds.as_array()
|
||||
bond_exists = False
|
||||
if len(existing_bonds) > 0:
|
||||
for existing_bond in existing_bonds:
|
||||
if (existing_bond[0] == c_idx and existing_bond[1] == n_idx) or (
|
||||
existing_bond[0] == n_idx and existing_bond[1] == c_idx
|
||||
):
|
||||
bond_exists = True
|
||||
break
|
||||
|
||||
if not bond_exists:
|
||||
bonds_to_add.append([c_idx, n_idx, struc.BondType.SINGLE])
|
||||
|
||||
if bonds_to_add:
|
||||
new_bonds = struc.BondList(
|
||||
atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
|
||||
)
|
||||
atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
|
||||
logger.info(
|
||||
f"Added {len(bonds_to_add)} backbone bonds involving non-standard residues"
|
||||
)
|
||||
|
||||
return atom_array_accum
|
||||
|
||||
|
||||
#################################################################################
|
||||
# File IO utilities
|
||||
#################################################################################
|
||||
|
||||
@@ -10,8 +10,7 @@ def pytest_configure(config):
|
||||
|
||||
paths_to_add = [
|
||||
root / "src",
|
||||
root / "lib" / "cifutils" / "src",
|
||||
root / "lib" / "datahub" / "src",
|
||||
root / "models" / "rfd3" / "tests",
|
||||
]
|
||||
|
||||
for path in paths_to_add:
|
||||
|
||||
@@ -47,7 +47,6 @@ def test_glycine_features_and_is_x(example, is_inference):
|
||||
|
||||
assert bad_feats == [
|
||||
"is_central",
|
||||
"is_d_amino_acid",
|
||||
], "Expected only is_central to differ: {}".format(bad_feats)
|
||||
assert (
|
||||
actual["is_central"].sum() == actual["is_ca"].sum()
|
||||
|
||||
71
models/rfd3/tests/test_legacy_pipeline_equivalence.py
Normal file
71
models/rfd3/tests/test_legacy_pipeline_equivalence.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from rfd3.inference.input_parsing import DesignInputSpecification
|
||||
from rfd3.testing.testing_utils import PIPES, TEST_JSON_DATA
|
||||
|
||||
|
||||
def test_legacy_pipeline_equivalence():
|
||||
from transforms.test_pipeline_regression import (
|
||||
_assert_features_equal,
|
||||
_assert_pipeline_results_equal,
|
||||
assert_same_atom_array,
|
||||
)
|
||||
|
||||
args_new = TEST_JSON_DATA["brk-new"]
|
||||
args_legacy = TEST_JSON_DATA["brk-legacy"]
|
||||
|
||||
spec_new = DesignInputSpecification.safe_init(**args_new)
|
||||
spec_legacy = DesignInputSpecification.safe_init(**args_legacy)
|
||||
|
||||
spec_new_input = spec_new.to_pipeline_input("new")
|
||||
spec_legacy_input = spec_legacy.to_pipeline_input("legacy")
|
||||
|
||||
aa_in_new = spec_new_input["atom_array"]
|
||||
aa_in_old = spec_legacy_input["atom_array"]
|
||||
|
||||
# assert equivalent
|
||||
assert_same_atom_array(
|
||||
aa_in_new,
|
||||
aa_in_old,
|
||||
compare_coords=True,
|
||||
compare_bonds=True,
|
||||
# (All annotation categories present in the expected atom array are compared)
|
||||
annotations_to_compare=set(
|
||||
list(aa_in_old.get_annotation_categories())
|
||||
+ list(aa_in_new.get_annotation_categories())
|
||||
),
|
||||
)
|
||||
|
||||
is_inference = True
|
||||
example_new = PIPES[is_inference](spec_new_input)
|
||||
example_legacy = PIPES[is_inference](spec_legacy_input)
|
||||
|
||||
aa_new = example_new["atom_array"]
|
||||
aa_old = example_legacy["atom_array"]
|
||||
|
||||
_assert_features_equal(
|
||||
example_new["feats"],
|
||||
example_legacy["feats"],
|
||||
"Brianne King's features for non-loopy",
|
||||
"inference",
|
||||
)
|
||||
|
||||
assert_same_atom_array(
|
||||
aa_new,
|
||||
aa_old,
|
||||
compare_coords=True,
|
||||
compare_bonds=True,
|
||||
# (All annotation categories present in the expected atom array are compared)
|
||||
annotations_to_compare=set(
|
||||
list(aa_old.get_annotation_categories())
|
||||
+ list(aa_new.get_annotation_categories())
|
||||
),
|
||||
)
|
||||
_assert_pipeline_results_equal(
|
||||
example_new,
|
||||
example_legacy,
|
||||
"Brianne King's example for non-loopy",
|
||||
"inference",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_legacy_pipeline_equivalence()
|
||||
Reference in New Issue
Block a user