Fix AF3 viewer annotation tables for duplicate residues

This commit is contained in:
Dima
2026-03-25 16:54:34 +01:00
parent a9b9905e79
commit f16189bcd8
2 changed files with 81 additions and 16 deletions

View File

@@ -262,37 +262,75 @@ def _has_duplicate_residue_ids_within_chain(struc) -> bool:
def _make_viewer_compatible_inference_result(
    inference_result: model.InferenceResult,
) -> model.InferenceResult:
  """Creates a viewer-safe copy with unique label IDs and original auth IDs.

  When the predicted structure has no duplicate residue IDs within a chain,
  the input is returned unchanged. Otherwise a copy is built in which:

  * residue label IDs are renumbered sequentially (starting at 1) per chain,
  * the author-facing residue IDs are kept as strings in `auth_seq_id`,
  * every residue receives an insertion code derived from how many times its
    (chain, author residue ID) pair has been seen so far, and
  * `metadata['token_res_ids']` is replaced by per-chain sequential IDs.

  Args:
    inference_result: Result whose `predicted_structure` and `metadata` are
      inspected. It is not mutated; the metadata dict is shallow-copied.

  Returns:
    Either `inference_result` itself (no duplicates) or a replaced copy with
    the updated structure and metadata.
  """
  if not _has_duplicate_residue_ids_within_chain(
      inference_result.predicted_structure
  ):
    return inference_result

  struc = inference_result.predicted_structure

  # Chain ID of every residue row, normalised to `str` so it can be used as a
  # dictionary key below.
  residue_chain_ids = [
      str(chain_id)
      for chain_id in struc.chains_table.apply_array_to_column(
          column_name='id',
          arr=struc.residues_table.chain_key,
      )
  ]

  author_residue_ids = [
      str(residue_id) for residue_id in struc.residues_table.auth_seq_id
  ]
  # If no author numbering is present at all, fall back to the existing
  # (possibly duplicated) label IDs as the author-facing numbering.
  # NOTE(review): only '.' is treated as "unset", and only when *every* entry
  # is '.'; confirm that '?' or partially-unset columns cannot occur here.
  if all(residue_id == '.' for residue_id in author_residue_ids):
    author_residue_ids = [
        str(int(residue_id)) for residue_id in struc.residues_table.id
    ]

  sequential_label_ids = np.asarray(
      _sequential_residue_ids_per_chain(residue_chain_ids),
      dtype=np.int32,
  )

  # Count repeated (chain, author residue ID) pairs in table order; the
  # running count selects the insertion code for each residue.
  occurrence_count_by_residue: dict[tuple[str, str], int] = {}
  insertion_codes = []
  for chain_id, residue_id in zip(
      residue_chain_ids,
      author_residue_ids,
      strict=True,
  ):
    key = (chain_id, residue_id)
    occurrence = occurrence_count_by_residue.get(key, 0) + 1
    occurrence_count_by_residue[key] = occurrence
    insertion_codes.append(_duplicate_occurrence_to_insertion_code(occurrence))

  viewer_structure = struc.copy_and_update_residues(
      res_id=sequential_label_ids,
      res_auth_seq_id=np.asarray(author_residue_ids, dtype=object),
      res_insertion_code=np.asarray(insertion_codes, dtype=object),
  )

  # NOTE(review): this assigns one new ID per token, which presumably assumes
  # one token per residue; verify for inputs where a residue spans several
  # tokens.
  token_chain_ids = [
      str(chain_id) for chain_id in inference_result.metadata['token_chain_ids']
  ]
  sequential_token_ids = _sequential_residue_ids_per_chain(token_chain_ids)

  # Copy so the caller's metadata dict is left untouched.
  metadata = dict(inference_result.metadata)
  metadata['token_res_ids'] = sequential_token_ids

  return dataclasses.replace(
      inference_result,
      predicted_structure=viewer_structure,
      metadata=metadata,
  )
def _sequential_residue_ids_per_chain(chain_ids: Sequence[str]) -> list[int]:
"""Returns sequential residue IDs that are unique within each chain."""
next_residue_id_by_chain: dict[str, int] = {}
residue_ids = []
for chain_id in chain_ids:
next_residue_id = next_residue_id_by_chain.get(chain_id, 0) + 1
next_residue_id_by_chain[chain_id] = next_residue_id
residue_ids.append(next_residue_id)
return residue_ids
def predict_structure(
fold_input: folding_input.Input,
model_runner: ModelRunner,

View File

@@ -642,8 +642,8 @@ class _TestBase(parameterized.TestCase):
return chains_and_sequences
def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[int]]]:
"""Extract residue numbers for each polymer chain from a CIF file."""
def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[Union[int, str]]]]:
"""Extract author-facing residue numbers for each polymer chain from a CIF file."""
try:
from alphafold3.cpp import cif_dict
@@ -651,13 +651,28 @@ class _TestBase(parameterized.TestCase):
cif = cif_dict.from_string(handle.read())
asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object)
seq_ids = cif.get_array("_pdbx_poly_seq_scheme.seq_id", dtype=object)
auth_seq_nums = cif.get_array(
"_pdbx_poly_seq_scheme.auth_seq_num", dtype=object
)
ins_codes = cif.get_array(
"_pdbx_poly_seq_scheme.pdb_ins_code", dtype=object
)
chain_residue_numbers = []
chain_to_numbers = {}
for chain_id, seq_id in zip(asym_ids, seq_ids, strict=True):
for chain_id, auth_seq_num, ins_code in zip(
asym_ids,
auth_seq_nums,
ins_codes,
strict=True,
):
residue_numbers = chain_to_numbers.setdefault(chain_id, [])
residue_numbers.append(int(seq_id))
ins_code = str(ins_code)
auth_seq_num = int(auth_seq_num)
if ins_code in {".", "?"}:
residue_numbers.append(auth_seq_num)
else:
residue_numbers.append(f"{auth_seq_num}{ins_code}")
for chain_id, residue_numbers in chain_to_numbers.items():
if residue_numbers:
@@ -1402,11 +1417,15 @@ class TestAlphaFold3RunModes(_TestBase):
self.assertEqual(
viewer_result.predicted_structure.present_residues.id.tolist(),
original_residue_ids,
list(range(1, len(original_residue_ids) + 1)),
)
self.assertEqual(
viewer_result.metadata["token_res_ids"],
original_residue_ids,
list(range(1, len(original_residue_ids) + 1)),
)
self.assertEqual(
viewer_result.predicted_structure.residues_table.auth_seq_id.tolist(),
[str(residue_id) for residue_id in original_residue_ids],
)
self.assertEqual(
viewer_result.predicted_structure.residues_table.insertion_code.tolist(),
@@ -1449,7 +1468,11 @@ class TestAlphaFold3RunModes(_TestBase):
self._get_sequence_for_protein("TEST"),
concatenated_chopped_sequence,
]
expected_chopped_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16))
expected_chopped_residue_ids = (
list(range(1, 11))
+ ["2A", "3A", "4A", "5A"]
+ list(range(12, 16))
)
actual_sequences = [chain.sequence for chain in fold_input_obj.chains]
self.assertCountEqual(actual_sequences, expected_sequences)
self.assertLen(actual_sequences, 2)
@@ -1650,7 +1673,11 @@ class TestAlphaFold3RunModes(_TestBase):
self._get_sequence_for_protein("TEST"),
concatenated_chopped_sequence,
]
expected_chopped_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16))
expected_chopped_residue_ids = (
list(range(1, 11))
+ ["2A", "3A", "4A", "5A"]
+ list(range(12, 16))
)
result_dir = self._resolve_single_af3_result_dir()
cif_files = list(result_dir.glob("*_model.cif"))