mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Fix AF3 viewer annotation tables for duplicate residues
This commit is contained in:
@@ -262,37 +262,75 @@ def _has_duplicate_residue_ids_within_chain(struc) -> bool:
|
||||
def _make_viewer_compatible_inference_result(
|
||||
inference_result: model.InferenceResult,
|
||||
) -> model.InferenceResult:
|
||||
"""Creates a viewer-safe copy using insertion codes for duplicates."""
|
||||
"""Creates a viewer-safe copy with unique label IDs and original auth IDs."""
|
||||
if not _has_duplicate_residue_ids_within_chain(
|
||||
inference_result.predicted_structure
|
||||
):
|
||||
return inference_result
|
||||
|
||||
struc = inference_result.predicted_structure
|
||||
residue_chain_ids = struc.chains_table.apply_array_to_column(
|
||||
column_name='id',
|
||||
arr=struc.residues_table.chain_key,
|
||||
residue_chain_ids = [
|
||||
str(chain_id)
|
||||
for chain_id in struc.chains_table.apply_array_to_column(
|
||||
column_name='id',
|
||||
arr=struc.residues_table.chain_key,
|
||||
)
|
||||
]
|
||||
author_residue_ids = [
|
||||
str(residue_id)
|
||||
for residue_id in struc.residues_table.auth_seq_id
|
||||
]
|
||||
if all(residue_id == '.' for residue_id in author_residue_ids):
|
||||
author_residue_ids = [
|
||||
str(int(residue_id)) for residue_id in struc.residues_table.id
|
||||
]
|
||||
|
||||
sequential_label_ids = np.asarray(
|
||||
_sequential_residue_ids_per_chain(residue_chain_ids),
|
||||
dtype=np.int32,
|
||||
)
|
||||
occurrence_count_by_residue: dict[tuple[str, int], int] = {}
|
||||
occurrence_count_by_residue: dict[tuple[str, str], int] = {}
|
||||
insertion_codes = []
|
||||
for chain_id, residue_id in zip(
|
||||
residue_chain_ids,
|
||||
struc.residues_table.id,
|
||||
author_residue_ids,
|
||||
strict=True,
|
||||
):
|
||||
key = (str(chain_id), int(residue_id))
|
||||
key = (chain_id, residue_id)
|
||||
occurrence = occurrence_count_by_residue.get(key, 0) + 1
|
||||
occurrence_count_by_residue[key] = occurrence
|
||||
insertion_codes.append(_duplicate_occurrence_to_insertion_code(occurrence))
|
||||
|
||||
viewer_structure = struc.copy_and_update_residues(
|
||||
res_id=sequential_label_ids,
|
||||
res_auth_seq_id=np.asarray(author_residue_ids, dtype=object),
|
||||
res_insertion_code=np.asarray(insertion_codes, dtype=object),
|
||||
)
|
||||
|
||||
token_chain_ids = [
|
||||
str(chain_id) for chain_id in inference_result.metadata['token_chain_ids']
|
||||
]
|
||||
sequential_token_ids = _sequential_residue_ids_per_chain(token_chain_ids)
|
||||
metadata = dict(inference_result.metadata)
|
||||
metadata['token_res_ids'] = sequential_token_ids
|
||||
return dataclasses.replace(
|
||||
inference_result,
|
||||
predicted_structure=viewer_structure,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
def _sequential_residue_ids_per_chain(chain_ids: Sequence[str]) -> list[int]:
|
||||
"""Returns sequential residue IDs that are unique within each chain."""
|
||||
next_residue_id_by_chain: dict[str, int] = {}
|
||||
residue_ids = []
|
||||
for chain_id in chain_ids:
|
||||
next_residue_id = next_residue_id_by_chain.get(chain_id, 0) + 1
|
||||
next_residue_id_by_chain[chain_id] = next_residue_id
|
||||
residue_ids.append(next_residue_id)
|
||||
return residue_ids
|
||||
|
||||
|
||||
def predict_structure(
|
||||
fold_input: folding_input.Input,
|
||||
model_runner: ModelRunner,
|
||||
|
||||
@@ -642,8 +642,8 @@ class _TestBase(parameterized.TestCase):
|
||||
|
||||
return chains_and_sequences
|
||||
|
||||
def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[int]]]:
|
||||
"""Extract residue numbers for each polymer chain from a CIF file."""
|
||||
def _extract_cif_chain_residue_numbers(self, cif_path: Path) -> List[Tuple[str, List[Union[int, str]]]]:
|
||||
"""Extract author-facing residue numbers for each polymer chain from a CIF file."""
|
||||
try:
|
||||
from alphafold3.cpp import cif_dict
|
||||
|
||||
@@ -651,13 +651,28 @@ class _TestBase(parameterized.TestCase):
|
||||
cif = cif_dict.from_string(handle.read())
|
||||
|
||||
asym_ids = cif.get_array("_pdbx_poly_seq_scheme.asym_id", dtype=object)
|
||||
seq_ids = cif.get_array("_pdbx_poly_seq_scheme.seq_id", dtype=object)
|
||||
auth_seq_nums = cif.get_array(
|
||||
"_pdbx_poly_seq_scheme.auth_seq_num", dtype=object
|
||||
)
|
||||
ins_codes = cif.get_array(
|
||||
"_pdbx_poly_seq_scheme.pdb_ins_code", dtype=object
|
||||
)
|
||||
|
||||
chain_residue_numbers = []
|
||||
chain_to_numbers = {}
|
||||
for chain_id, seq_id in zip(asym_ids, seq_ids, strict=True):
|
||||
for chain_id, auth_seq_num, ins_code in zip(
|
||||
asym_ids,
|
||||
auth_seq_nums,
|
||||
ins_codes,
|
||||
strict=True,
|
||||
):
|
||||
residue_numbers = chain_to_numbers.setdefault(chain_id, [])
|
||||
residue_numbers.append(int(seq_id))
|
||||
ins_code = str(ins_code)
|
||||
auth_seq_num = int(auth_seq_num)
|
||||
if ins_code in {".", "?"}:
|
||||
residue_numbers.append(auth_seq_num)
|
||||
else:
|
||||
residue_numbers.append(f"{auth_seq_num}{ins_code}")
|
||||
|
||||
for chain_id, residue_numbers in chain_to_numbers.items():
|
||||
if residue_numbers:
|
||||
@@ -1402,11 +1417,15 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.present_residues.id.tolist(),
|
||||
original_residue_ids,
|
||||
list(range(1, len(original_residue_ids) + 1)),
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_res_ids"],
|
||||
original_residue_ids,
|
||||
list(range(1, len(original_residue_ids) + 1)),
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.residues_table.auth_seq_id.tolist(),
|
||||
[str(residue_id) for residue_id in original_residue_ids],
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.residues_table.insertion_code.tolist(),
|
||||
@@ -1449,7 +1468,11 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
self._get_sequence_for_protein("TEST"),
|
||||
concatenated_chopped_sequence,
|
||||
]
|
||||
expected_chopped_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16))
|
||||
expected_chopped_residue_ids = (
|
||||
list(range(1, 11))
|
||||
+ ["2A", "3A", "4A", "5A"]
|
||||
+ list(range(12, 16))
|
||||
)
|
||||
actual_sequences = [chain.sequence for chain in fold_input_obj.chains]
|
||||
self.assertCountEqual(actual_sequences, expected_sequences)
|
||||
self.assertLen(actual_sequences, 2)
|
||||
@@ -1650,7 +1673,11 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
self._get_sequence_for_protein("TEST"),
|
||||
concatenated_chopped_sequence,
|
||||
]
|
||||
expected_chopped_residue_ids = list(range(1, 11)) + list(range(2, 6)) + list(range(12, 16))
|
||||
expected_chopped_residue_ids = (
|
||||
list(range(1, 11))
|
||||
+ ["2A", "3A", "4A", "5A"]
|
||||
+ list(range(12, 16))
|
||||
)
|
||||
|
||||
result_dir = self._resolve_single_af3_result_dir()
|
||||
cif_files = list(result_dir.glob("*_model.cif"))
|
||||
|
||||
Reference in New Issue
Block a user