Handle duplicate AF3 residue IDs during inference

This commit is contained in:
Dima
2026-03-25 15:48:06 +01:00
parent 04e945a3b7
commit 962bca9256
2 changed files with 34 additions and 1 deletions

View File

@@ -1262,6 +1262,39 @@ class TestAlphaFold3RunModes(_TestBase):
self.assertEqual(relative_encoding[3, 4, inter_chain_bin], 0)
self.assertEqual(np.argmax(relative_encoding[3, 4, : 2 * 4 + 2]), 0)
def test_af3_duplicate_residue_ids_survive_empty_structure_round_trip(self):
"""AF3 must preserve duplicate residue IDs when rebuilding empty structures."""
from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.model.atom_layout import atom_layout
expected_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16))
chain = folding_input.ProteinChain(
id="A",
sequence="ACDEFGHIKLCDEFMNPQ",
ptms=[],
residue_ids=expected_residue_ids,
unpaired_msa="",
paired_msa="",
templates=[],
)
fold_input = folding_input.Input(
name="duplicate_residue_ids_test",
chains=[chain],
rng_seeds=[1],
)
ccd = chemical_components.Ccd()
struc = fold_input.to_structure(ccd=ccd)
flat_layout = atom_layout.atom_layout_from_structure(struc)
rebuilt = atom_layout.make_structure(
flat_layout,
atom_coords=np.zeros((flat_layout.atom_name.shape[0], 3), dtype=np.float32),
name="duplicate_residue_ids_test",
all_physical_residues=struc.present_residues,
)
self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids)
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
"""AF3 must keep multi-region chopped inputs as one gapped protein chain."""
from alphapulldown.folding_backend.alphafold3_backend import (