Update to aa_design/main latest (post loop fix)

This commit is contained in:
jbutch
2025-11-23 19:30:17 -08:00
parent 3dba499b6d
commit e520df0a97
12 changed files with 367 additions and 76 deletions

View File

@@ -6,5 +6,5 @@ defaults:
dataset:
name: rigid-ligand-enzymes
eval_every_n: 1
data: ${paths.data.design_benchmark_data_dir}/mcsa_41_short_rigid.json
data: ${paths.data.design_benchmark_data_dir}/mcsa_41_short_rigid_new.json

View File

@@ -26,6 +26,9 @@ OPTIONAL_CONDITIONING_VALUES = {
"ref_plddt": 0,
"is_non_loopy": 0,
"partial_t": np.nan,
# kept for legacy reasons
"is_motif_token": 1,
"is_motif_atom": 1,
}
"""Optional conditioning annotations and their default values if not provided."""

View File

@@ -44,7 +44,7 @@
| Field | Type | Description |
| -------------------------------------------------------------- | ----------------- | --------------------------------------------------------------------- |
| `input` | `str?` | Path to input **PDB/CIF**. Required if you provide contig+length. |
| `input` | `str?` | Path to input **PDB/CIF**. Required if you provide contig+length. |
| `atom_array_input` | internal | Pre-loaded `AtomArray` (not recommended). |
| `contig` | `InputSelection?` | Indexed motif specification, e.g., `"A1-80,10,\0,B5-12"`. |
| `unindex` | `InputSelection?` | Unindexed motif components (unknown sequence placement). |

View File

@@ -46,6 +46,7 @@ from rfd3.transforms.conditioning_base import (
)
from rfd3.transforms.util_transforms import assign_types_
from rfd3.utils.inference import (
_restore_bonds_for_nonstandard_residues,
extract_ligand_array,
inference_load_,
set_com,
@@ -81,7 +82,7 @@ class LegacySpecification(BaseModel):
def build(self, *args, **kwargs):
"""Build atom array using legacy input parsing."""
atom_array = create_atom_array_from_design_specification_legacy(
design_specification=self.model_dump(),
**self.model_dump(),
)
return atom_array, self.model_dump()
@@ -377,7 +378,6 @@ class DesignInputSpecification(BaseModel):
"RASA",
("select_buried", "select_partially_buried", "select_exposed"),
),
("Hydrogen bonds", ("select_hbond_acceptor", "select_hbond_donor")),
]
for name, excl_set in exclusive_sets:
@@ -573,6 +573,8 @@ class DesignInputSpecification(BaseModel):
unindexed_tokens=unindexed_tokens,
atom_array_accum=[],
unindexed_breaks=unindexed_breaks,
start_chain="A",
start_resid=1,
)
else:
# ... Set common annotations
@@ -662,6 +664,14 @@ class DesignInputSpecification(BaseModel):
+ list(atom_array_input_annotated.get_annotation_categories())
),
)
# Offset ligand residue ids based on the original input to avoid clashes
# with any newly created residues (matches legacy behaviour).
ligand_array.res_id = (
ligand_array.res_id
- np.min(ligand_array.res_id)
+ np.max(atom_array.res_id)
+ 1
)
atom_array = atom_array + ligand_array
return atom_array
@@ -802,6 +812,15 @@ def prepare_pipeline_input_from_atom_array( # see atomworks.ml.datasets.parsers
atom_array = convert_existing_annotations_to_bool(atom_array)
atom_array.set_annotation("chain_iid", [f"{c}_1" for c in atom_array.chain_id])
atom_array.set_annotation("pn_unit_iid", [f"{c}_1" for c in atom_array.pn_unit_id])
# Ensure motif annotations are removed
atom_array.del_annotation(
"is_motif_token"
) if "is_motif_token" in atom_array.get_annotation_categories() else None
atom_array.del_annotation(
"is_motif_atom"
) if "is_motif_atom" in atom_array.get_annotation_categories() else None
data = {
"atom_array": atom_array, # First model
"chain_info": result_dict["chain_info"],
@@ -933,6 +952,8 @@ def accumulate_components(
start_chain: str = "A",
start_resid: int = 1,
unindexed_breaks: Optional[List[bool]] = [],
src_atom_array: Optional[AtomArray] = None,
strip_sidechains_by_default: bool = False,
**kwargs,
) -> AtomArray:
# ... Create list of components
@@ -950,6 +971,12 @@ def accumulate_components(
for tok in all_tokens.values()
]
all_annots = set(all_annots)
atom_array_accum = [] if atom_array_accum is None else atom_array_accum
unindexed_breaks = (
[None] * len(components_to_accumulate)
if unindexed_breaks is None
else unindexed_breaks
)
# ... For-loop accum variables
unindexed_components_started = (
@@ -958,12 +985,21 @@ def accumulate_components(
chain = start_chain
res_id = start_resid
molecule_id = 0
source_to_accum_idx: Dict[int, int] = {}
current_accum_idx = sum(len(arr) for arr in atom_array_accum)
# ... Insert contig information one- by one-
assert len(components_to_accumulate) == len(
unindexed_breaks
), "Mismatch in number of components to accumulate and breaks"
for component, is_break in zip(components_to_accumulate, unindexed_breaks):
src_indices = None
if exists(is_break) and is_break:
if not unindexed_components_started:
chain = start_chain
res_id = start_resid
unindexed_components_started = True
if component == "/0":
# Reset iterators on next chain
chain = chr(ord(chain) + 1)
@@ -977,12 +1013,21 @@ def accumulate_components(
# ... Fetch the motif residue
token = all_tokens[component]
if src_atom_array is not None:
src_mask = fetch_mask_from_idx(component, atom_array=src_atom_array)
src_indices = np.where(src_mask)[0]
# try:
# except ComponentValidationError as e:
# src_indices = None
# print(e)
# ... Ensure motif residues are set properly
token = create_motif_residue(
token, strip_sidechains_by_default=strip_sidechains_by_default
)
# ... Insert breakpoint when break clause is met
if exists(is_break) and is_break:
if not unindexed_components_started:
chain = start_chain
unindexed_components_started = True
token.set_annotation(
"is_motif_atom_unindexed_motif_breakpoint",
np.ones(token.shape[0], dtype=int),
@@ -1015,15 +1060,42 @@ def accumulate_components(
len(get_token_starts(token)) == n
), f"Mismatch in number of residues: expected {n}, got {len(get_token_starts(token))} in \n{token}"
if (
src_atom_array is not None
and str(component)[0].isalpha()
and src_indices is not None
and len(src_indices) == len(token)
):
for i, src_idx in enumerate(src_indices):
source_to_accum_idx[int(src_idx)] = current_accum_idx + i
# ... Insert & Increment residue ID
atom_array_accum.append(token)
res_id += n
current_accum_idx += len(token)
# ... Concatenate all components
atom_array_accum = struc.concatenate(atom_array_accum)
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
# Reset res_id for unindexed residues to avoid duplicates
should_restore_bonds = (
src_atom_array is not None
and bool(source_to_accum_idx)
and _check_has_backbone_connections_to_nonstandard_residues(
atom_array_accum, src_atom_array
)
)
if should_restore_bonds:
assert not unindexed_tokens, (
"PTM backbone bond restoration is not compatible with unindexed components. "
"PTMs must be specified as indexed components (using 'contig' parameter, not 'unindex'). "
f"Found unindexed components: {list(unindexed_tokens.keys())}"
)
atom_array_accum = _restore_bonds_for_nonstandard_residues(
atom_array_accum, src_atom_array, source_to_accum_idx
)
# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
atom_array_accum.is_motif_atom_unindexed.astype(bool)
):
@@ -1032,9 +1104,14 @@ def accumulate_components(
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
].res_id
)
min_id_udx = np.min(
atom_array_accum[
atom_array_accum.is_motif_atom_unindexed.astype(bool)
].res_id
)
atom_array_accum.res_id[
atom_array_accum.is_motif_atom_unindexed.astype(bool)
] += max_id + 1
] += max_id - min_id_udx + 1
# ... Bonds
if atom_array_accum.bonds is None:

View File

@@ -10,7 +10,6 @@ from atomworks.io.utils.io_utils import to_cif_file
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.utils.token import (
get_token_starts,
spread_token_wise,
)
from rfd3.constants import (
INFERENCE_ANNOTATIONS,
@@ -372,7 +371,7 @@ def accumulate_components(
atom_array_accum = struc.concatenate(atom_array_accum)
atom_array_accum.set_annotation("pn_unit_iid", atom_array_accum.chain_id)
# Reset res_id for unindexed residues to avoid duplicates
# Reset res_id for unindexed residues to avoid duplicates (ridiculously long lines of code, cleanup later)
if np.any(atom_array_accum.is_motif_atom_unindexed.astype(bool)) and not np.all(
atom_array_accum.is_motif_atom_unindexed.astype(bool)
):
@@ -381,9 +380,14 @@ def accumulate_components(
~atom_array_accum.is_motif_atom_unindexed.astype(bool)
].res_id
)
min_id_udx = np.min(
atom_array_accum[
atom_array_accum.is_motif_atom_unindexed.astype(bool)
].res_id
)
atom_array_accum.res_id[
atom_array_accum.is_motif_atom_unindexed.astype(bool)
] += max_id + 1
] += max_id - min_id_udx + 1
return atom_array_accum
@@ -406,7 +410,6 @@ def create_atom_array_from_design_specification_legacy(
unfix_all=False,
unfix_specific: str = None,
flexible_backbone: bool = False,
is_d_amino_acid=None,
# Args for biomolecular design (Enzymes, DNA/PNA):
ligand: str = None,
ori_token: list[float] = None,
@@ -556,7 +559,7 @@ def create_atom_array_from_design_specification_legacy(
if indexed_components and indexed_components_provided:
for component in indexed_components:
if str(component)[0].isalpha():
mask = fetch_residue_mask(atom_array, component)
mask = fetch_mask_from_component(component, atom_array=atom_array)
# Set the component as a motif token
set_default_conditioning_annotations(
@@ -624,31 +627,24 @@ def create_atom_array_from_design_specification_legacy(
is_motif_atom = f["is_motif_atom"]
atom_array.set_annotation("is_motif_atom", is_motif_atom.astype(int))
# This is an annotation on the diffused regions, so must be added after accumulate_components
if spoof_helical_bundle_ss_conditioning:
is_helix = spoof_helical_bundle_ss_conditioning_fn(atom_array)
is_sheet = None
is_loop = None
if exists(is_helix):
set_atom_level_argument(atom_array, is_helix, "is_helix")
if exists(is_sheet):
set_atom_level_argument(atom_array, is_sheet, "is_sheet")
optional_conditions.append("is_sheet")
if exists(is_loop):
set_atom_level_argument(atom_array, is_loop, "is_loop")
optional_conditions.append("is_loop")
is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
diffused_region_mask = ~(is_motif_token.astype(bool))
if exists(is_non_loopy):
is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1
atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
# ... If ligand, post-pend it
if exists(ligand):
ligand_array = extract_ligand_array(atom_array_input, ligand, fixed_atoms)
ligand_array = extract_ligand_array(
atom_array_input,
ligand,
fixed_atoms,
additional_annotations=set(
list(atom_array.get_annotation_categories())
+ list(atom_array_input.get_annotation_categories())
+ ["is_motif_atom", "is_motif_token"]
),
)
ligand_array.res_id = (
ligand_array.res_id
- np.min(ligand_array.res_id)
+ np.max(atom_array.res_id)
+ 1
)
atom_array = atom_array + ligand_array
# ... Apply symmetry if it exists ahead of any other processing
@@ -674,13 +670,31 @@ def create_atom_array_from_design_specification_legacy(
atom_array, ori_token=ori_token, infer_ori_strategy=infer_ori_strategy
)
# diffused atoms initialized at origin
atom_array.coord[~atom_array.is_motif_token.astype(bool)] = 0.0
atom_array.coord[~atom_array.is_motif_atom_with_fixed_coord.astype(bool), :] = (
0.0
)
# ... Add is_d_amino_acid annotation if specified
if is_d_amino_acid is not None:
# convert to nd_array
is_d_amino_acid = np.asarray(is_d_amino_acid, dtype=bool)
is_d_amino_acid = spread_token_wise(atom_array, is_d_amino_acid)
# This is an annotation on the diffused regions, so must be added after accumulate_components
if spoof_helical_bundle_ss_conditioning:
is_helix = spoof_helical_bundle_ss_conditioning_fn(atom_array)
is_sheet = None
is_loop = None
if exists(is_helix):
set_atom_level_argument(atom_array, is_helix, "is_helix")
if exists(is_sheet):
set_atom_level_argument(atom_array, is_sheet, "is_sheet")
optional_conditions.append("is_sheet")
if exists(is_loop):
set_atom_level_argument(atom_array, is_loop, "is_loop")
optional_conditions.append("is_loop")
is_non_loopy_annot = np.zeros(atom_array.array_length(), dtype=int)
diffused_region_mask = ~(atom_array.is_motif_token.astype(bool))
if exists(is_non_loopy):
is_non_loopy_annot[diffused_region_mask] = 1 if is_non_loopy else -1
atom_array.set_annotation("is_non_loopy", is_non_loopy_annot)
atom_array.set_annotation("is_non_loopy_atom_level", is_non_loopy_annot)
if plddt_enhanced:
atom_array.set_annotation(

View File

@@ -111,7 +111,12 @@ DIRS = [
def load_test_json():
test_files = ["demo.json", "demo_extended.json", "tests.json"]
test_files += ["mcsa_41.json", "rfd_unindexed.json", "sym_tests.json"]
test_files += [
"mcsa_41.json",
"rfd_unindexed.json",
"sym_tests.json",
"brk_regression.json",
]
test_json_data = {}
for dir in DIRS:
test_data_dir = Path(dir, "test_data")
@@ -163,10 +168,6 @@ def instantiate_example(args, is_inference=True):
args = copy.deepcopy(args)
if is_inference:
# Keep only the kwargs that the function actually accepts
# args = filter_inference_args(args)
# atom_array, spec = create_atom_array_from_design_specification(**args)
# input = prepare_pipeline_input_from_atom_array(atom_array)
input = DesignInputSpecification.safe_init(**args).to_pipeline_input(
example_id=args.get("example_id", "example")
)

View File

@@ -84,11 +84,6 @@ from rfd3.transforms.design_transforms import (
)
from rfd3.transforms.dna_crop import ProteinDNAContactContiguousCrop
from rfd3.transforms.hbonds_hbplus import CalculateHbondsPlus
from rfd3.transforms.ncaa_transforms import (
AddIsDAminoAcidFeat,
RandomlyMirrorInputs,
StrtoBoolforIsDAminoAcidFeature,
)
from rfd3.transforms.ppi_transforms import (
Add1DSSFeature,
AddGlobalIsNonLoopyFeature,
@@ -152,7 +147,6 @@ def get_pre_crop_transforms(
):
return [
InferenceRoute(StrtoBoolforIsXFeatures()),
InferenceRoute(StrtoBoolforIsDAminoAcidFeature()),
RemoveHydrogens(),
FilterToSpecifiedPNUnits(
extra_info_key_with_pn_unit_iids_to_keep="all_pn_unit_iids_after_processing"
@@ -543,22 +537,6 @@ def build_atom14_base_pipeline_(
diffusion_batch_size=diffusion_batch_size,
)
# ... Mixed chirality handling
transforms += [
TrainingRoute(
ConditionalRoute(
condition_func=lambda data: data["conditions"].get(
"mirror_input", False
),
transform_map={
True: RandomlyMirrorInputs(),
False: Identity(),
},
)
),
AddIsDAminoAcidFeat(),
]
# ... Random augmentation accounting for motif
transforms += [
MotifCenterRandomAugmentation(

View File

@@ -522,7 +522,9 @@ def sample_unindexed_breaks(
token_idxs = np.arange(len(starts))
breaks_all = np.zeros(len(starts), dtype=bool)
if np.any(is_unindexed_token):
if is_unindexed_token.sum() == 1:
breaks_all = is_unindexed_token
elif np.any(is_unindexed_token):
# ... Subset to unindexed tokens
unindexed_token_starts = starts[is_unindexed_token]
unindexed_token_resid = atom_array[unindexed_token_starts].res_id

View File

@@ -5,12 +5,15 @@ Utilities for inference input preparation
import logging
import os
from os import PathLike
from typing import Dict
import biotite.structure as struc
import numpy as np
from atomworks import parse
from atomworks.constants import STANDARD_DNA
from atomworks.io.parser import STANDARD_PARSER_ARGS
from atomworks.constants import STANDARD_AA, STANDARD_DNA
from atomworks.io.parser import (
STANDARD_PARSER_ARGS,
)
from atomworks.ml.encoding_definitions import AF3SequenceEncoding
from atomworks.ml.preprocessing.utils.structure_utils import (
get_atom_mask_from_cell_list,
@@ -162,6 +165,150 @@ def extract_na_array(atom_array_input):
)
def _restore_bonds_for_nonstandard_residues(
    atom_array_accum: struc.AtomArray,
    src_atom_array: struc.AtomArray | None,
    source_to_accum_idx: Dict[int, int],
) -> struc.AtomArray:
    """
    Restore and create bonds involving non-standard residues (PTMs, modified AAs, etc.).

    This function:
    1. Copies bonds from the source structure when both bonded atoms are present in
       ``source_to_accum_idx`` and at least one of them belongs to a residue whose
       name is not in ``STANDARD_AA``. Note that no inter-/intra-residue filter is
       applied: every such mapped bond is carried over.
    2. Adds a single C-N backbone bond between consecutive tokens (same chain,
       ``res_id`` differing by exactly 1) where at least one of the two is
       non-standard and the bond is not already present.

    Args:
        atom_array_accum: The accumulated atom array to add bonds to. Its ``bonds``
            attribute is initialised to an empty bond list if it is ``None``.
        src_atom_array: The source atom array containing original bond information,
            or ``None`` to skip step 1.
        source_to_accum_idx: Mapping from source atom indices to accumulated array indices.

    Returns:
        ``atom_array_accum`` with bonds added (the same object that was passed in).
    """
    # Initialize bonds if needed
    if atom_array_accum.bonds is None:
        atom_array_accum.bonds = struc.BondList(atom_array_accum.array_length())

    # Step 1: Restore bonds from the source atom array (only those touching non-standard residues)
    # getattr() also covers `src_atom_array is None` and a source without a `bonds` attribute.
    src_bonds = getattr(src_atom_array, "bonds", None)
    if src_bonds is not None:
        # Read residue names straight from the annotation array instead of
        # materialising an Atom object for every bond endpoint.
        src_res_names = src_atom_array.res_name
        bonds_to_add = []
        for atom_i, atom_j, bond_type in src_bonds.as_array():
            atom_i, atom_j = int(atom_i), int(atom_j)
            # Both atoms must have been carried over into the accumulated array
            if (
                atom_i not in source_to_accum_idx
                or atom_j not in source_to_accum_idx
            ):
                continue
            # Only preserve if at least one residue is non-standard
            if (
                src_res_names[atom_i] in STANDARD_AA
                and src_res_names[atom_j] in STANDARD_AA
            ):
                continue
            bonds_to_add.append(
                [
                    source_to_accum_idx[atom_i],
                    source_to_accum_idx[atom_j],
                    int(bond_type),
                ]
            )

        if bonds_to_add:
            # Add the preserved bonds
            new_bonds = struc.BondList(
                atom_array_accum.array_length(),
                np.array(bonds_to_add, dtype=np.int64),
            )
            atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
            logger.info(
                f"Preserved {len(bonds_to_add)} inter-residue bonds involving non-standard residues from source structure"
            )

    # Step 2: Add backbone bonds between consecutive residues where at least one is non-standard
    # This handles: PTM-to-diffused, diffused-to-PTM, PTM-to-PTM, ligand-to-protein
    bonds_to_add = []

    # The bond list is not modified until after the loop, so the set of already
    # bonded atom pairs is computed once (order-independent) instead of being
    # re-scanned for every residue pair.
    existing_pairs = {
        frozenset((int(a), int(b)))
        for a, b in atom_array_accum.bonds.as_array()[:, :2]
    }

    # Group by residue; the exclusive stop gives every token an end index.
    token_starts = get_token_starts(atom_array_accum, add_exclusive_stop=True)
    for curr_start, next_start, next_end in zip(
        token_starts, token_starts[1:], token_starts[2:]
    ):
        curr_residue = atom_array_accum[curr_start:next_start]
        next_residue = atom_array_accum[next_start:next_end]

        # Check if at least one residue is non-standard (PTMs, modified AAs, etc.)
        curr_is_nonstandard = curr_residue.res_name[0] not in STANDARD_AA
        next_is_nonstandard = next_residue.res_name[0] not in STANDARD_AA

        # Only add bonds if at least one residue is non-standard
        if not (curr_is_nonstandard or next_is_nonstandard):
            continue

        # Check if consecutive in same chain
        if curr_residue.chain_id[0] != next_residue.chain_id[0]:
            continue
        if next_residue.res_id[0] - curr_residue.res_id[0] != 1:
            continue

        # Find C atom in current residue (C-terminus connection point)
        c_mask = curr_residue.atom_name == "C"
        if not np.any(c_mask):
            # Without a C atom there is nothing to connect to the next residue.
            # This is expected for some atomized residues or ligands at chain termini.
            if curr_is_nonstandard and next_is_nonstandard:
                # Both are non-standard but no C in current - might be an atomized region without proper termini
                logger.debug(
                    f"Non-standard residue {curr_residue.res_name[0]} (res_id {curr_residue.res_id[0]}) "
                    f"has no C atom - cannot form backbone bond to next residue"
                )
            continue
        c_idx = curr_start + np.where(c_mask)[0][0]

        # Find N atom in next residue (N-terminus connection point)
        n_mask = next_residue.atom_name == "N"
        if not np.any(n_mask):
            # Without an N atom the next residue cannot connect to the previous one.
            # This is expected for some atomized residues or ligands at chain termini.
            if curr_is_nonstandard and next_is_nonstandard:
                # Both are non-standard but no N in next - might be an atomized region without proper termini
                logger.debug(
                    f"Non-standard residue {next_residue.res_name[0]} (res_id {next_residue.res_id[0]}) "
                    f"has no N atom - cannot form backbone bond from previous residue"
                )
            continue
        n_idx = next_start + np.where(n_mask)[0][0]

        # Skip bonds that already exist (e.g. from source preservation in step 1)
        if frozenset((int(c_idx), int(n_idx))) not in existing_pairs:
            bonds_to_add.append([c_idx, n_idx, struc.BondType.SINGLE])

    if bonds_to_add:
        new_bonds = struc.BondList(
            atom_array_accum.array_length(), np.array(bonds_to_add, dtype=np.int64)
        )
        atom_array_accum.bonds = atom_array_accum.bonds.merge(new_bonds)
        logger.info(
            f"Added {len(bonds_to_add)} backbone bonds involving non-standard residues"
        )

    return atom_array_accum
#################################################################################
# File IO utilities
#################################################################################

View File

@@ -10,8 +10,7 @@ def pytest_configure(config):
paths_to_add = [
root / "src",
root / "lib" / "cifutils" / "src",
root / "lib" / "datahub" / "src",
root / "models" / "rfd3" / "tests",
]
for path in paths_to_add:

View File

@@ -47,7 +47,6 @@ def test_glycine_features_and_is_x(example, is_inference):
assert bad_feats == [
"is_central",
"is_d_amino_acid",
], "Expected only is_central to differ: {}".format(bad_feats)
assert (
actual["is_central"].sum() == actual["is_ca"].sum()

View File

@@ -0,0 +1,71 @@
from rfd3.inference.input_parsing import DesignInputSpecification
from rfd3.testing.testing_utils import PIPES, TEST_JSON_DATA
def test_legacy_pipeline_equivalence():
    """Check that the "brk-new" and "brk-legacy" test specs give identical results.

    Both entries are loaded from ``TEST_JSON_DATA`` (presumably the same design
    written in the new and the legacy input format - confirm against
    ``brk_regression.json``). The test compares:

    1. the atom arrays produced by ``to_pipeline_input``;
    2. the features, atom arrays and full results after running the inference
       pipeline on both inputs.
    """
    # Imported lazily so that importing this module does not require the
    # regression-test helpers to be importable.
    from transforms.test_pipeline_regression import (
        _assert_features_equal,
        _assert_pipeline_results_equal,
        assert_same_atom_array,
    )

    def _assert_arrays_match(candidate, reference):
        # Compare coordinates, bonds, and every annotation category present
        # in either of the two arrays.
        categories = set(
            list(reference.get_annotation_categories())
            + list(candidate.get_annotation_categories())
        )
        assert_same_atom_array(
            candidate,
            reference,
            compare_coords=True,
            compare_bonds=True,
            annotations_to_compare=categories,
        )

    new_kwargs = TEST_JSON_DATA["brk-new"]
    legacy_kwargs = TEST_JSON_DATA["brk-legacy"]

    new_spec = DesignInputSpecification.safe_init(**new_kwargs)
    legacy_spec = DesignInputSpecification.safe_init(**legacy_kwargs)

    new_input = new_spec.to_pipeline_input("new")
    legacy_input = legacy_spec.to_pipeline_input("legacy")

    # ... Inputs to the pipeline must already be equivalent
    _assert_arrays_match(new_input["atom_array"], legacy_input["atom_array"])

    # ... Run both inputs through the inference pipeline
    run_inference = True
    new_example = PIPES[run_inference](new_input)
    legacy_example = PIPES[run_inference](legacy_input)

    _assert_features_equal(
        new_example["feats"],
        legacy_example["feats"],
        "Brianne King's features for non-loopy",
        "inference",
    )

    _assert_arrays_match(new_example["atom_array"], legacy_example["atom_array"])

    _assert_pipeline_results_equal(
        new_example,
        legacy_example,
        "Brianne King's example for non-loopy",
        "inference",
    )


# Allow running the regression check directly, without pytest.
if __name__ == "__main__":
    test_legacy_pipeline_equivalence()