Fix AF3 viewer numbering and output names

This commit is contained in:
Dima
2026-03-25 17:17:55 +01:00
parent f16189bcd8
commit c0cff91d23
2 changed files with 267 additions and 52 deletions

View File

@@ -14,6 +14,7 @@ import json
import logging
import os
import pathlib
import re
import time
import typing
from collections.abc import Sequence
@@ -190,6 +191,10 @@ def write_outputs(
post_processing.write_output(
inference_result=result, output_dir=sample_dir
)
_augment_confidence_json_with_author_numbering(
os.path.join(sample_dir, 'confidences.json'),
result,
)
ranking_score = float(result.metadata['ranking_score'])
ranking_scores.append((seed, sample_idx, ranking_score))
if max_ranking_score is None or ranking_score > max_ranking_score:
@@ -219,6 +224,10 @@ def write_outputs(
terms_of_use=output_terms,
name=job_name,
)
_augment_confidence_json_with_author_numbering(
os.path.join(output_dir, f'{job_name}_confidences.json'),
max_ranking_result,
)
with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f:
writer = csv.writer(f)
writer.writerow(['seed', 'sample', 'ranking_score'])
@@ -238,36 +247,163 @@ def _duplicate_occurrence_to_insertion_code(occurrence_index: int) -> str:
return chr(ord('A') + offset)
def _has_duplicate_residue_ids_within_chain(struc) -> bool:
"""Returns True if a structure reuses residue IDs inside any one chain."""
residue_chain_ids = struc.chains_table.apply_array_to_column(
column_name='id',
arr=struc.residues_table.chain_key,
)
seen_by_chain: dict[str, set[int]] = {}
for chain_id, residue_id in zip(
residue_chain_ids,
struc.residues_table.id,
strict=True,
def _normalise_output_name_fragment(raw_name: str) -> str:
"""Normalises one output-name fragment while preserving readable IDs."""
cleaned = re.sub(r"\s+", "_", raw_name.strip())
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", cleaned)
cleaned = cleaned.strip("_.-")
return cleaned or "ranked_0"
def _regions_to_name_fragment(regions: Sequence[tuple[int, int]]) -> str:
"""Returns a readable name fragment for a set of closed residue intervals."""
return "_".join(f"{start}-{end}" for start, end in regions)
def _json_input_basename(json_path: str) -> str:
"""Returns a readable basename for an AF3 JSON input path."""
stem = pathlib.Path(json_path).stem
for suffix in ("_af3_input", "_input"):
if stem.endswith(suffix):
stem = stem[: -len(suffix)]
break
return stem or pathlib.Path(json_path).stem
def _object_name_fragment(obj: typing.Any) -> str:
    """Derives the normalised output-name fragment for one modelling object.

    The raw name is chosen by the first matching case:
      * a dict with a ``"json_input"`` key: the JSON basename, followed by
        ``__<regions>`` when ``"regions"`` is a non-empty sequence;
      * a `MultimericObject`: its description, or ``"multimer"`` if falsy;
      * a `MonomericObject` or `ChoppedObject`: its description, or
        ``"monomer"`` if falsy;
      * a `folding_input.Input`: its name;
      * anything else: the name of the object's type.

    Args:
        obj: The modelling object to name.

    Returns:
        The raw name passed through `_normalise_output_name_fragment`.
    """
    if isinstance(obj, dict) and "json_input" in obj:
        raw_name = _json_input_basename(obj["json_input"])
        regions = obj.get("regions")
        if isinstance(regions, Sequence) and regions:
            raw_name = f"{raw_name}__{_regions_to_name_fragment(regions)}"
    elif isinstance(obj, MultimericObject):
        raw_name = obj.description or "multimer"
    elif isinstance(obj, (MonomericObject, ChoppedObject)):
        raw_name = obj.description or "monomer"
    elif isinstance(obj, folding_input.Input):
        raw_name = obj.name
    else:
        raw_name = type(obj).__name__
    return _normalise_output_name_fragment(raw_name)
def _build_output_job_name(objects_to_model: Sequence[dict]) -> str:
    """Assembles the job name from every requested modelling object.

    Args:
        objects_to_model: Entries whose ``"object"`` value is either a single
            modelling object or a list of them.

    Returns:
        The non-empty per-object fragments joined with ``"_and_"``, or
        ``"ranked_0"`` when there are no fragments.
    """
    fragments: list[str] = []
    for entry in objects_to_model:
        target = entry["object"]
        # A list entry contributes one fragment per member, in order.
        members = target if isinstance(target, list) else [target]
        for member in members:
            fragment = _object_name_fragment(member)
            if fragment:
                fragments.append(fragment)
    return "_and_".join(fragments) if fragments else "ranked_0"
def _residue_author_ids(struc) -> list[str]:
"""Returns author-facing residue IDs, falling back to residue IDs if unset."""
author_residue_ids = [str(residue_id) for residue_id in struc.residues_table.auth_seq_id]
if all(residue_id in {".", "?"} for residue_id in author_residue_ids):
return [str(int(residue_id)) for residue_id in struc.residues_table.id]
return author_residue_ids
def _existing_insertion_codes(struc) -> list[str]:
"""Returns normalised residue insertion codes from a structure."""
return [
"."
if insertion_code in {".", "?", ""}
else str(insertion_code)
for insertion_code in struc.residues_table.insertion_code
]
def _author_ids_with_insertion_codes(
chain_ids: Sequence[str],
author_residue_ids: Sequence[str],
existing_insertion_codes: Sequence[str] | None = None,
) -> tuple[list[str], list[str], list[str]]:
"""Returns author IDs, insertion codes, and combined author labels."""
occurrence_count_by_residue: dict[tuple[str, str], int] = {}
insertion_codes: list[str] = []
combined_labels: list[str] = []
for index, (chain_id, residue_id) in enumerate(
zip(chain_ids, author_residue_ids, strict=True)
):
chain_id = str(chain_id)
residue_id = int(residue_id)
seen = seen_by_chain.setdefault(chain_id, set())
if residue_id in seen:
return True
seen.add(residue_id)
return False
explicit_insertion_code = "."
if existing_insertion_codes is not None:
explicit_insertion_code = existing_insertion_codes[index]
if explicit_insertion_code not in {".", "?", ""}:
insertion_code = explicit_insertion_code
else:
key = (chain_id, residue_id)
occurrence = occurrence_count_by_residue.get(key, 0) + 1
occurrence_count_by_residue[key] = occurrence
insertion_code = _duplicate_occurrence_to_insertion_code(occurrence)
insertion_codes.append(insertion_code)
if insertion_code == ".":
combined_labels.append(residue_id)
else:
combined_labels.append(f"{residue_id}{insertion_code}")
return list(author_residue_ids), insertion_codes, combined_labels
def _coerce_json_scalar(value: str) -> int | str:
"""Converts a stringified integer back to int where possible."""
try:
return int(value)
except (TypeError, ValueError):
return value
def _augment_confidence_json_with_author_numbering(
    confidences_path: os.PathLike[str] | str,
    inference_result: model.InferenceResult,
) -> None:
    """Rewrites a confidences JSON file with author-numbering fields added.

    Nothing is read or written unless the result's metadata provides all of
    ``token_auth_res_ids``, ``token_pdb_ins_codes`` and
    ``token_auth_res_labels``.

    Args:
        confidences_path: Path of the JSON file to update in place.
        inference_result: Result whose ``metadata`` mapping supplies the
            author-numbering lists.
    """
    author_ids = inference_result.metadata.get("token_auth_res_ids")
    insertion_codes = inference_result.metadata.get("token_pdb_ins_codes")
    author_labels = inference_result.metadata.get("token_auth_res_labels")
    required = (author_ids, insertion_codes, author_labels)
    if any(entry is None for entry in required):
        return

    with open(confidences_path, "rt", encoding="utf-8") as handle:
        confidence_data = json.load(handle)

    confidence_data.update(
        {
            # The IDs already stored in the file, coerced to int.
            "token_label_seq_ids": [
                int(token_id)
                for token_id in confidence_data.get("token_res_ids", [])
            ],
            # Integer-like author IDs become JSON numbers; others stay strings.
            "token_auth_res_ids": [
                _coerce_json_scalar(str(token_id)) for token_id in author_ids
            ],
            "token_pdb_ins_codes": [str(code) for code in insertion_codes],
            "token_auth_res_labels": [str(label) for label in author_labels],
        }
    )

    with open(confidences_path, "wt", encoding="utf-8") as handle:
        json.dump(confidence_data, handle, indent=1)
        handle.write("\n")
def _make_viewer_compatible_inference_result(
inference_result: model.InferenceResult,
) -> model.InferenceResult:
"""Creates a viewer-safe copy with unique label IDs and original auth IDs."""
if not _has_duplicate_residue_ids_within_chain(
inference_result.predicted_structure
):
return inference_result
"""Creates a viewer-safe copy with sequential label IDs and preserved auth IDs."""
struc = inference_result.predicted_structure
residue_chain_ids = [
str(chain_id)
@@ -276,30 +412,22 @@ def _make_viewer_compatible_inference_result(
arr=struc.residues_table.chain_key,
)
]
author_residue_ids = [
str(residue_id)
for residue_id in struc.residues_table.auth_seq_id
]
if all(residue_id == '.' for residue_id in author_residue_ids):
author_residue_ids = [
str(int(residue_id)) for residue_id in struc.residues_table.id
]
author_residue_ids = _residue_author_ids(struc)
existing_insertion_codes = _existing_insertion_codes(struc)
sequential_label_ids = np.asarray(
_sequential_residue_ids_per_chain(residue_chain_ids),
dtype=np.int32,
)
occurrence_count_by_residue: dict[tuple[str, str], int] = {}
insertion_codes = []
for chain_id, residue_id in zip(
(
author_residue_ids,
insertion_codes,
_,
) = _author_ids_with_insertion_codes(
residue_chain_ids,
author_residue_ids,
strict=True,
):
key = (chain_id, residue_id)
occurrence = occurrence_count_by_residue.get(key, 0) + 1
occurrence_count_by_residue[key] = occurrence
insertion_codes.append(_duplicate_occurrence_to_insertion_code(occurrence))
existing_insertion_codes,
)
viewer_structure = struc.copy_and_update_residues(
res_id=sequential_label_ids,
@@ -307,12 +435,25 @@ def _make_viewer_compatible_inference_result(
res_insertion_code=np.asarray(insertion_codes, dtype=object),
)
token_chain_ids = [
str(chain_id) for chain_id in inference_result.metadata['token_chain_ids']
]
sequential_token_ids = _sequential_residue_ids_per_chain(token_chain_ids)
metadata = dict(inference_result.metadata)
metadata['token_res_ids'] = sequential_token_ids
token_chain_ids = [
str(chain_id)
for chain_id in metadata.get("token_chain_ids", [])
]
if token_chain_ids and "token_res_ids" in metadata:
token_author_ids = [str(token_id) for token_id in metadata["token_res_ids"]]
(
token_author_ids,
token_insertion_codes,
token_author_labels,
) = _author_ids_with_insertion_codes(
token_chain_ids,
token_author_ids,
)
metadata["token_res_ids"] = _sequential_residue_ids_per_chain(token_chain_ids)
metadata["token_auth_res_ids"] = token_author_ids
metadata["token_pdb_ins_codes"] = token_insertion_codes
metadata["token_auth_res_labels"] = token_author_labels
return dataclasses.replace(
inference_result,
predicted_structure=viewer_structure,
@@ -1256,7 +1397,7 @@ class AlphaFold3Backend(FoldingBackend):
return [mono_obj]
def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set):
nonlocal all_chains, job_name
nonlocal all_chains
if isinstance(obj, dict) and 'json_input' in obj:
json_path = obj['json_input']
json_regions = obj.get('regions')
@@ -1296,10 +1437,6 @@ class AlphaFold3Backend(FoldingBackend):
modified_chains.append(_clone_chain_with_id(chain, new_id))
all_chains.extend(modified_chains)
if len(all_chains) == len(modified_chains):
job_name = input_obj.name
else:
job_name = f"{job_name}_and_{input_obj.name}"
except Exception as e:
logging.error(f"Failed to parse JSON file {json_path}: {e}")
raise
@@ -1404,7 +1541,6 @@ class AlphaFold3Backend(FoldingBackend):
chain_id_counter = [0] # Use a list to allow pass-by-reference
used_chain_ids = set() # Track used chain IDs
all_chains = []
job_name = "ranked_0"
# Track chains whose MSAs were translated from AF2 features; they must
# not be rewritten by the promotion heuristic below.
af2_translated_msa_chain_ids: set[str] = set()
@@ -1459,6 +1595,7 @@ class AlphaFold3Backend(FoldingBackend):
promoted_chains.append(ch)
all_chains = promoted_chains
job_name = _build_output_job_name(objects_to_model)
combined_input = folding_input.Input(
name=job_name,
rng_seeds=[random_seed],

View File

@@ -1376,6 +1376,66 @@ class TestAlphaFold3RunModes(_TestBase):
self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids)
def test_af3_viewer_output_renumbers_gapped_residue_ids_for_viewers(self):
    """Gapped chains get 1..N label IDs while author IDs keep the originals."""
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import model as af3_model
    from alphapulldown.folding_backend.alphafold3_backend import (
        _make_viewer_compatible_inference_result,
    )

    # Residue numbering with a gap between 5 and 8.
    original_residue_ids = [2, 3, 4, 5, 8, 9, 10]
    residue_count = len(original_residue_ids)
    expected_sequential_ids = list(range(1, residue_count + 1))
    expected_author_ids = [str(residue_id) for residue_id in original_residue_ids]

    gapped_chain = folding_input.ProteinChain(
        id="A",
        sequence="ACDEFGH",
        ptms=[],
        residue_ids=original_residue_ids,
        unpaired_msa="",
        paired_msa="",
        templates=[],
    )
    fold_input = folding_input.Input(
        name="gapped_residue_ids_for_viewers",
        chains=[gapped_chain],
        rng_seeds=[1],
    )
    metadata = {
        "token_chain_ids": ["A"] * residue_count,
        "token_res_ids": original_residue_ids,
    }
    inference_result = af3_model.InferenceResult(
        predicted_structure=fold_input.to_structure(
            ccd=chemical_components.Ccd()
        ),
        metadata=metadata,
    )

    viewer_result = _make_viewer_compatible_inference_result(inference_result)
    viewer_structure = viewer_result.predicted_structure

    # Label IDs are renumbered sequentially in structure and token metadata.
    self.assertEqual(
        viewer_structure.present_residues.id.tolist(),
        expected_sequential_ids,
    )
    self.assertEqual(
        viewer_result.metadata["token_res_ids"],
        expected_sequential_ids,
    )
    # Author numbering keeps the original IDs; no insertion codes are needed.
    self.assertEqual(
        viewer_structure.residues_table.auth_seq_id.tolist(),
        expected_author_ids,
    )
    self.assertEqual(
        viewer_structure.residues_table.insertion_code.tolist(),
        ["."] * residue_count,
    )
    self.assertEqual(
        viewer_result.metadata["token_auth_res_ids"],
        expected_author_ids,
    )
    self.assertEqual(
        viewer_result.metadata["token_auth_res_labels"],
        expected_author_ids,
    )
def test_af3_viewer_output_uses_insertion_codes_for_duplicate_residue_ids(self):
"""Viewer-safe AF3 output must preserve IDs and disambiguate with insertions."""
from alphafold3.common import folding_input
@@ -1431,6 +1491,20 @@ class TestAlphaFold3RunModes(_TestBase):
viewer_result.predicted_structure.residues_table.insertion_code.tolist(),
['.'] * 10 + ['A'] * 4 + ['.'] * 4,
)
self.assertEqual(
viewer_result.metadata["token_auth_res_ids"],
[str(residue_id) for residue_id in original_residue_ids],
)
self.assertEqual(
viewer_result.metadata["token_pdb_ins_codes"],
['.'] * 10 + ['A'] * 4 + ['.'] * 4,
)
self.assertEqual(
viewer_result.metadata["token_auth_res_labels"],
[str(i) for i in range(1, 11)]
+ [f"{i}A" for i in range(2, 6)]
+ [str(i) for i in range(12, 16)],
)
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
"""AF3 must keep multi-region chopped inputs as one gapped protein chain."""
@@ -1568,6 +1642,10 @@ class TestAlphaFold3RunModes(_TestBase):
[chain.sequence for chain in fold_input_obj.chains],
[expected_sequence],
)
self.assertEqual(
fold_input_obj.sanitised_name(),
"A0A024R1R8__2-5_8-10",
)
self.assertEqual(
[list(chain.residue_ids) for chain in fold_input_obj.chains],
[expected_residue_ids],