mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Fix AF3 viewer numbering and output names
This commit is contained in:
@@ -14,6 +14,7 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import time
|
||||
import typing
|
||||
from collections.abc import Sequence
|
||||
@@ -190,6 +191,10 @@ def write_outputs(
|
||||
post_processing.write_output(
|
||||
inference_result=result, output_dir=sample_dir
|
||||
)
|
||||
_augment_confidence_json_with_author_numbering(
|
||||
os.path.join(sample_dir, 'confidences.json'),
|
||||
result,
|
||||
)
|
||||
ranking_score = float(result.metadata['ranking_score'])
|
||||
ranking_scores.append((seed, sample_idx, ranking_score))
|
||||
if max_ranking_score is None or ranking_score > max_ranking_score:
|
||||
@@ -219,6 +224,10 @@ def write_outputs(
|
||||
terms_of_use=output_terms,
|
||||
name=job_name,
|
||||
)
|
||||
_augment_confidence_json_with_author_numbering(
|
||||
os.path.join(output_dir, f'{job_name}_confidences.json'),
|
||||
max_ranking_result,
|
||||
)
|
||||
with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['seed', 'sample', 'ranking_score'])
|
||||
@@ -238,36 +247,163 @@ def _duplicate_occurrence_to_insertion_code(occurrence_index: int) -> str:
|
||||
return chr(ord('A') + offset)
|
||||
|
||||
|
||||
def _has_duplicate_residue_ids_within_chain(struc) -> bool:
|
||||
"""Returns True if a structure reuses residue IDs inside any one chain."""
|
||||
residue_chain_ids = struc.chains_table.apply_array_to_column(
|
||||
column_name='id',
|
||||
arr=struc.residues_table.chain_key,
|
||||
)
|
||||
seen_by_chain: dict[str, set[int]] = {}
|
||||
for chain_id, residue_id in zip(
|
||||
residue_chain_ids,
|
||||
struc.residues_table.id,
|
||||
strict=True,
|
||||
def _normalise_output_name_fragment(raw_name: str) -> str:
|
||||
"""Normalises one output-name fragment while preserving readable IDs."""
|
||||
cleaned = re.sub(r"\s+", "_", raw_name.strip())
|
||||
cleaned = re.sub(r"[^A-Za-z0-9_.-]+", "_", cleaned)
|
||||
cleaned = cleaned.strip("_.-")
|
||||
return cleaned or "ranked_0"
|
||||
|
||||
|
||||
def _regions_to_name_fragment(regions: Sequence[tuple[int, int]]) -> str:
|
||||
"""Returns a readable name fragment for a set of closed residue intervals."""
|
||||
return "_".join(f"{start}-{end}" for start, end in regions)
|
||||
|
||||
|
||||
def _json_input_basename(json_path: str) -> str:
|
||||
"""Returns a readable basename for an AF3 JSON input path."""
|
||||
stem = pathlib.Path(json_path).stem
|
||||
for suffix in ("_af3_input", "_input"):
|
||||
if stem.endswith(suffix):
|
||||
stem = stem[: -len(suffix)]
|
||||
break
|
||||
return stem or pathlib.Path(json_path).stem
|
||||
|
||||
|
||||
def _object_name_fragment(obj: typing.Any) -> str:
|
||||
"""Builds a deterministic output-name fragment for one modelling object."""
|
||||
if isinstance(obj, dict) and "json_input" in obj:
|
||||
fragment = _json_input_basename(obj["json_input"])
|
||||
regions = obj.get("regions")
|
||||
if isinstance(regions, Sequence) and regions:
|
||||
fragment = f"{fragment}__{_regions_to_name_fragment(regions)}"
|
||||
return _normalise_output_name_fragment(fragment)
|
||||
|
||||
if isinstance(obj, MultimericObject):
|
||||
return _normalise_output_name_fragment(obj.description or "multimer")
|
||||
|
||||
if isinstance(obj, (MonomericObject, ChoppedObject)):
|
||||
return _normalise_output_name_fragment(obj.description or "monomer")
|
||||
|
||||
if isinstance(obj, folding_input.Input):
|
||||
return _normalise_output_name_fragment(obj.name)
|
||||
|
||||
return _normalise_output_name_fragment(type(obj).__name__)
|
||||
|
||||
|
||||
def _build_output_job_name(objects_to_model: Sequence[dict]) -> str:
|
||||
"""Builds a readable AF3 job name from the requested modelling objects."""
|
||||
fragments: list[str] = []
|
||||
for entry in objects_to_model:
|
||||
object_to_model = entry["object"]
|
||||
if isinstance(object_to_model, list):
|
||||
fragments.extend(_object_name_fragment(obj) for obj in object_to_model)
|
||||
else:
|
||||
fragments.append(_object_name_fragment(object_to_model))
|
||||
fragments = [fragment for fragment in fragments if fragment]
|
||||
if not fragments:
|
||||
return "ranked_0"
|
||||
return "_and_".join(fragments)
|
||||
|
||||
|
||||
def _residue_author_ids(struc) -> list[str]:
|
||||
"""Returns author-facing residue IDs, falling back to residue IDs if unset."""
|
||||
author_residue_ids = [str(residue_id) for residue_id in struc.residues_table.auth_seq_id]
|
||||
if all(residue_id in {".", "?"} for residue_id in author_residue_ids):
|
||||
return [str(int(residue_id)) for residue_id in struc.residues_table.id]
|
||||
return author_residue_ids
|
||||
|
||||
|
||||
def _existing_insertion_codes(struc) -> list[str]:
|
||||
"""Returns normalised residue insertion codes from a structure."""
|
||||
return [
|
||||
"."
|
||||
if insertion_code in {".", "?", ""}
|
||||
else str(insertion_code)
|
||||
for insertion_code in struc.residues_table.insertion_code
|
||||
]
|
||||
|
||||
|
||||
def _author_ids_with_insertion_codes(
|
||||
chain_ids: Sequence[str],
|
||||
author_residue_ids: Sequence[str],
|
||||
existing_insertion_codes: Sequence[str] | None = None,
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
"""Returns author IDs, insertion codes, and combined author labels."""
|
||||
occurrence_count_by_residue: dict[tuple[str, str], int] = {}
|
||||
insertion_codes: list[str] = []
|
||||
combined_labels: list[str] = []
|
||||
|
||||
for index, (chain_id, residue_id) in enumerate(
|
||||
zip(chain_ids, author_residue_ids, strict=True)
|
||||
):
|
||||
chain_id = str(chain_id)
|
||||
residue_id = int(residue_id)
|
||||
seen = seen_by_chain.setdefault(chain_id, set())
|
||||
if residue_id in seen:
|
||||
return True
|
||||
seen.add(residue_id)
|
||||
return False
|
||||
explicit_insertion_code = "."
|
||||
if existing_insertion_codes is not None:
|
||||
explicit_insertion_code = existing_insertion_codes[index]
|
||||
|
||||
if explicit_insertion_code not in {".", "?", ""}:
|
||||
insertion_code = explicit_insertion_code
|
||||
else:
|
||||
key = (chain_id, residue_id)
|
||||
occurrence = occurrence_count_by_residue.get(key, 0) + 1
|
||||
occurrence_count_by_residue[key] = occurrence
|
||||
insertion_code = _duplicate_occurrence_to_insertion_code(occurrence)
|
||||
|
||||
insertion_codes.append(insertion_code)
|
||||
if insertion_code == ".":
|
||||
combined_labels.append(residue_id)
|
||||
else:
|
||||
combined_labels.append(f"{residue_id}{insertion_code}")
|
||||
|
||||
return list(author_residue_ids), insertion_codes, combined_labels
|
||||
|
||||
|
||||
def _coerce_json_scalar(value: str) -> int | str:
|
||||
"""Converts a stringified integer back to int where possible."""
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return value
|
||||
|
||||
|
||||
def _augment_confidence_json_with_author_numbering(
|
||||
confidences_path: os.PathLike[str] | str,
|
||||
inference_result: model.InferenceResult,
|
||||
) -> None:
|
||||
"""Adds preserved author numbering to the confidence sidecar JSON."""
|
||||
token_auth_res_ids = inference_result.metadata.get("token_auth_res_ids")
|
||||
token_pdb_ins_codes = inference_result.metadata.get("token_pdb_ins_codes")
|
||||
token_auth_res_labels = inference_result.metadata.get("token_auth_res_labels")
|
||||
if (
|
||||
token_auth_res_ids is None
|
||||
or token_pdb_ins_codes is None
|
||||
or token_auth_res_labels is None
|
||||
):
|
||||
return
|
||||
|
||||
with open(confidences_path, "rt", encoding="utf-8") as handle:
|
||||
confidence_data = json.load(handle)
|
||||
|
||||
confidence_data["token_label_seq_ids"] = [
|
||||
int(token_id) for token_id in confidence_data.get("token_res_ids", [])
|
||||
]
|
||||
confidence_data["token_auth_res_ids"] = [
|
||||
_coerce_json_scalar(str(token_id)) for token_id in token_auth_res_ids
|
||||
]
|
||||
confidence_data["token_pdb_ins_codes"] = [str(code) for code in token_pdb_ins_codes]
|
||||
confidence_data["token_auth_res_labels"] = [
|
||||
str(label) for label in token_auth_res_labels
|
||||
]
|
||||
|
||||
with open(confidences_path, "wt", encoding="utf-8") as handle:
|
||||
json.dump(confidence_data, handle, indent=1)
|
||||
handle.write("\n")
|
||||
|
||||
|
||||
def _make_viewer_compatible_inference_result(
|
||||
inference_result: model.InferenceResult,
|
||||
) -> model.InferenceResult:
|
||||
"""Creates a viewer-safe copy with unique label IDs and original auth IDs."""
|
||||
if not _has_duplicate_residue_ids_within_chain(
|
||||
inference_result.predicted_structure
|
||||
):
|
||||
return inference_result
|
||||
|
||||
"""Creates a viewer-safe copy with sequential label IDs and preserved auth IDs."""
|
||||
struc = inference_result.predicted_structure
|
||||
residue_chain_ids = [
|
||||
str(chain_id)
|
||||
@@ -276,30 +412,22 @@ def _make_viewer_compatible_inference_result(
|
||||
arr=struc.residues_table.chain_key,
|
||||
)
|
||||
]
|
||||
author_residue_ids = [
|
||||
str(residue_id)
|
||||
for residue_id in struc.residues_table.auth_seq_id
|
||||
]
|
||||
if all(residue_id == '.' for residue_id in author_residue_ids):
|
||||
author_residue_ids = [
|
||||
str(int(residue_id)) for residue_id in struc.residues_table.id
|
||||
]
|
||||
author_residue_ids = _residue_author_ids(struc)
|
||||
existing_insertion_codes = _existing_insertion_codes(struc)
|
||||
|
||||
sequential_label_ids = np.asarray(
|
||||
_sequential_residue_ids_per_chain(residue_chain_ids),
|
||||
dtype=np.int32,
|
||||
)
|
||||
occurrence_count_by_residue: dict[tuple[str, str], int] = {}
|
||||
insertion_codes = []
|
||||
for chain_id, residue_id in zip(
|
||||
(
|
||||
author_residue_ids,
|
||||
insertion_codes,
|
||||
_,
|
||||
) = _author_ids_with_insertion_codes(
|
||||
residue_chain_ids,
|
||||
author_residue_ids,
|
||||
strict=True,
|
||||
):
|
||||
key = (chain_id, residue_id)
|
||||
occurrence = occurrence_count_by_residue.get(key, 0) + 1
|
||||
occurrence_count_by_residue[key] = occurrence
|
||||
insertion_codes.append(_duplicate_occurrence_to_insertion_code(occurrence))
|
||||
existing_insertion_codes,
|
||||
)
|
||||
|
||||
viewer_structure = struc.copy_and_update_residues(
|
||||
res_id=sequential_label_ids,
|
||||
@@ -307,12 +435,25 @@ def _make_viewer_compatible_inference_result(
|
||||
res_insertion_code=np.asarray(insertion_codes, dtype=object),
|
||||
)
|
||||
|
||||
token_chain_ids = [
|
||||
str(chain_id) for chain_id in inference_result.metadata['token_chain_ids']
|
||||
]
|
||||
sequential_token_ids = _sequential_residue_ids_per_chain(token_chain_ids)
|
||||
metadata = dict(inference_result.metadata)
|
||||
metadata['token_res_ids'] = sequential_token_ids
|
||||
token_chain_ids = [
|
||||
str(chain_id)
|
||||
for chain_id in metadata.get("token_chain_ids", [])
|
||||
]
|
||||
if token_chain_ids and "token_res_ids" in metadata:
|
||||
token_author_ids = [str(token_id) for token_id in metadata["token_res_ids"]]
|
||||
(
|
||||
token_author_ids,
|
||||
token_insertion_codes,
|
||||
token_author_labels,
|
||||
) = _author_ids_with_insertion_codes(
|
||||
token_chain_ids,
|
||||
token_author_ids,
|
||||
)
|
||||
metadata["token_res_ids"] = _sequential_residue_ids_per_chain(token_chain_ids)
|
||||
metadata["token_auth_res_ids"] = token_author_ids
|
||||
metadata["token_pdb_ins_codes"] = token_insertion_codes
|
||||
metadata["token_auth_res_labels"] = token_author_labels
|
||||
return dataclasses.replace(
|
||||
inference_result,
|
||||
predicted_structure=viewer_structure,
|
||||
@@ -1256,7 +1397,7 @@ class AlphaFold3Backend(FoldingBackend):
|
||||
return [mono_obj]
|
||||
|
||||
def _process_single_object(obj, chain_id_counter_ref, used_chain_ids: set):
|
||||
nonlocal all_chains, job_name
|
||||
nonlocal all_chains
|
||||
if isinstance(obj, dict) and 'json_input' in obj:
|
||||
json_path = obj['json_input']
|
||||
json_regions = obj.get('regions')
|
||||
@@ -1296,10 +1437,6 @@ class AlphaFold3Backend(FoldingBackend):
|
||||
modified_chains.append(_clone_chain_with_id(chain, new_id))
|
||||
|
||||
all_chains.extend(modified_chains)
|
||||
if len(all_chains) == len(modified_chains):
|
||||
job_name = input_obj.name
|
||||
else:
|
||||
job_name = f"{job_name}_and_{input_obj.name}"
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to parse JSON file {json_path}: {e}")
|
||||
raise
|
||||
@@ -1404,7 +1541,6 @@ class AlphaFold3Backend(FoldingBackend):
|
||||
chain_id_counter = [0] # Use a list to allow pass-by-reference
|
||||
used_chain_ids = set() # Track used chain IDs
|
||||
all_chains = []
|
||||
job_name = "ranked_0"
|
||||
# Track chains whose MSAs were translated from AF2 features; they must
|
||||
# not be rewritten by the promotion heuristic below.
|
||||
af2_translated_msa_chain_ids: set[str] = set()
|
||||
@@ -1459,6 +1595,7 @@ class AlphaFold3Backend(FoldingBackend):
|
||||
promoted_chains.append(ch)
|
||||
|
||||
all_chains = promoted_chains
|
||||
job_name = _build_output_job_name(objects_to_model)
|
||||
combined_input = folding_input.Input(
|
||||
name=job_name,
|
||||
rng_seeds=[random_seed],
|
||||
|
||||
@@ -1376,6 +1376,66 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
|
||||
self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids)
|
||||
|
||||
def test_af3_viewer_output_renumbers_gapped_residue_ids_for_viewers(self):
|
||||
"""Viewer-safe AF3 output must use sequential label IDs for gapped chains."""
|
||||
from alphafold3.common import folding_input
|
||||
from alphafold3.constants import chemical_components
|
||||
from alphafold3.model import model as af3_model
|
||||
from alphapulldown.folding_backend.alphafold3_backend import (
|
||||
_make_viewer_compatible_inference_result,
|
||||
)
|
||||
|
||||
original_residue_ids = [2, 3, 4, 5, 8, 9, 10]
|
||||
chain = folding_input.ProteinChain(
|
||||
id="A",
|
||||
sequence="ACDEFGH",
|
||||
ptms=[],
|
||||
residue_ids=original_residue_ids,
|
||||
unpaired_msa="",
|
||||
paired_msa="",
|
||||
templates=[],
|
||||
)
|
||||
fold_input = folding_input.Input(
|
||||
name="gapped_residue_ids_for_viewers",
|
||||
chains=[chain],
|
||||
rng_seeds=[1],
|
||||
)
|
||||
struc = fold_input.to_structure(ccd=chemical_components.Ccd())
|
||||
inference_result = af3_model.InferenceResult(
|
||||
predicted_structure=struc,
|
||||
metadata={
|
||||
"token_chain_ids": ["A"] * len(original_residue_ids),
|
||||
"token_res_ids": original_residue_ids,
|
||||
},
|
||||
)
|
||||
|
||||
viewer_result = _make_viewer_compatible_inference_result(inference_result)
|
||||
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.present_residues.id.tolist(),
|
||||
list(range(1, len(original_residue_ids) + 1)),
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_res_ids"],
|
||||
list(range(1, len(original_residue_ids) + 1)),
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.residues_table.auth_seq_id.tolist(),
|
||||
[str(residue_id) for residue_id in original_residue_ids],
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.predicted_structure.residues_table.insertion_code.tolist(),
|
||||
["."] * len(original_residue_ids),
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_auth_res_ids"],
|
||||
[str(residue_id) for residue_id in original_residue_ids],
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_auth_res_labels"],
|
||||
[str(residue_id) for residue_id in original_residue_ids],
|
||||
)
|
||||
|
||||
def test_af3_viewer_output_uses_insertion_codes_for_duplicate_residue_ids(self):
|
||||
"""Viewer-safe AF3 output must preserve IDs and disambiguate with insertions."""
|
||||
from alphafold3.common import folding_input
|
||||
@@ -1431,6 +1491,20 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
viewer_result.predicted_structure.residues_table.insertion_code.tolist(),
|
||||
['.'] * 10 + ['A'] * 4 + ['.'] * 4,
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_auth_res_ids"],
|
||||
[str(residue_id) for residue_id in original_residue_ids],
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_pdb_ins_codes"],
|
||||
['.'] * 10 + ['A'] * 4 + ['.'] * 4,
|
||||
)
|
||||
self.assertEqual(
|
||||
viewer_result.metadata["token_auth_res_labels"],
|
||||
[str(i) for i in range(1, 11)]
|
||||
+ [f"{i}A" for i in range(2, 6)]
|
||||
+ [str(i) for i in range(12, 16)],
|
||||
)
|
||||
|
||||
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
|
||||
"""AF3 must keep multi-region chopped inputs as one gapped protein chain."""
|
||||
@@ -1568,6 +1642,10 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
[chain.sequence for chain in fold_input_obj.chains],
|
||||
[expected_sequence],
|
||||
)
|
||||
self.assertEqual(
|
||||
fold_input_obj.sanitised_name(),
|
||||
"A0A024R1R8__2-5_8-10",
|
||||
)
|
||||
self.assertEqual(
|
||||
[list(chain.residue_ids) for chain in fold_input_obj.chains],
|
||||
[expected_residue_ids],
|
||||
|
||||
Reference in New Issue
Block a user