Write ChimeraX-safe AF3 outputs

This commit is contained in:
Dima
2026-03-25 16:29:19 +01:00
parent e6016ac2a3
commit 8a6f4a1463
2 changed files with 159 additions and 0 deletions

View File

@@ -189,6 +189,10 @@ def write_outputs(
post_processing.write_output(
inference_result=result, output_dir=sample_dir
)
_write_chimerax_output_if_needed(
inference_result=result,
output_dir=sample_dir,
)
ranking_score = float(result.metadata['ranking_score'])
ranking_scores.append((seed, sample_idx, ranking_score))
if max_ranking_score is None or ranking_score > max_ranking_score:
@@ -218,12 +222,119 @@ def write_outputs(
terms_of_use=output_terms,
name=job_name,
)
_write_chimerax_output_if_needed(
inference_result=max_ranking_result,
output_dir=output_dir,
name=job_name,
)
with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f:
writer = csv.writer(f)
writer.writerow(['seed', 'sample', 'ranking_score'])
writer.writerows(ranking_scores)
def _sequential_residue_ids_per_chain(chain_ids: Sequence[str]) -> list[int]:
  """Numbers entries 1, 2, 3, ... independently for every chain ID.

  Args:
    chain_ids: One chain ID per entry, in output order. Entries belonging to
      the same chain do not have to be contiguous.

  Returns:
    A list as long as `chain_ids` whose i-th value is the number of times
    `chain_ids[i]` has occurred up to and including position i.
  """
  last_issued: dict[str, int] = {}

  def _issue(chain: str) -> int:
    # Advance this chain's private counter; the first ID handed out is 1.
    issued = last_issued.get(chain, 0) + 1
    last_issued[chain] = issued
    return issued

  return [_issue(chain) for chain in chain_ids]
def _has_duplicate_residue_ids_within_chain(struc) -> bool:
"""Returns True if a structure reuses residue IDs inside any one chain."""
residue_chain_ids = struc.chains_table.apply_array_to_column(
column_name='id',
arr=struc.residues_table.chain_key,
)
seen_by_chain: dict[str, set[int]] = {}
for chain_id, residue_id in zip(
residue_chain_ids,
struc.residues_table.id,
strict=True,
):
chain_id = str(chain_id)
residue_id = int(residue_id)
seen = seen_by_chain.setdefault(chain_id, set())
if residue_id in seen:
return True
seen.add(residue_id)
return False
def _make_chimerax_compatible_inference_result(
    inference_result: model.InferenceResult,
) -> model.InferenceResult:
  """Builds a copy of the result whose residue IDs are unique per chain.

  The predicted structure's residues are renumbered 1..N within each chain
  (both `res_id` and the string `res_auth_seq_id`), and the
  `token_res_ids` metadata entry is replaced by the same kind of per-chain
  sequential numbering derived from `token_chain_ids`.

  Args:
    inference_result: Result whose `predicted_structure` and `metadata`
      (must contain `token_chain_ids`) are used as the starting point.

  Returns:
    A new result produced with `dataclasses.replace`; the metadata mapping is
    shallow-copied before `token_res_ids` is overwritten.
  """
  source_structure = inference_result.predicted_structure
  chain_id_column = source_structure.chains_table.apply_array_to_column(
      column_name='id',
      arr=source_structure.residues_table.chain_key,
  )
  new_residue_ids = np.asarray(
      _sequential_residue_ids_per_chain(list(map(str, chain_id_column))),
      dtype=np.int32,
  )
  # Author sequence IDs are stored as Python strings in an object array.
  new_auth_seq_ids = np.char.mod('%d', new_residue_ids).astype(object)
  renumbered_structure = source_structure.copy_and_update_residues(
      res_id=new_residue_ids,
      res_auth_seq_id=new_auth_seq_ids,
  )

  updated_metadata = dict(inference_result.metadata)
  # NOTE(review): tokens are numbered one-by-one here; if several tokens can
  # belong to a single residue this will differ from the structure numbering
  # above -- confirm against the tokenisation used upstream.
  updated_metadata['token_res_ids'] = _sequential_residue_ids_per_chain(
      [str(token_chain) for token_chain in updated_metadata['token_chain_ids']]
  )
  return dataclasses.replace(
      inference_result,
      predicted_structure=renumbered_structure,
      metadata=updated_metadata,
  )
def _write_chimerax_output_if_needed(
    inference_result: model.InferenceResult,
    output_dir: os.PathLike[str] | str,
    name: str | None = None,
) -> None:
  """Writes `*_chimerax` output files if a chain reuses residue IDs.

  Nothing is written when every chain of the predicted structure already has
  unique residue IDs. Otherwise the result is renumbered, post-processed and
  three files are written into `output_dir`: the mmCIF model, the summary
  confidences JSON and the full confidences JSON.

  Args:
    inference_result: Result to inspect and, if needed, export.
    output_dir: Directory that receives the files; it is not created here.
    name: Optional job name. When given (not None), `<name>_` is prepended to
      each file name.
  """
  needs_export = _has_duplicate_residue_ids_within_chain(
      inference_result.predicted_structure
  )
  if not needs_export:
    return

  renumbered_result = _make_chimerax_compatible_inference_result(
      inference_result
  )
  processed_result = post_processing.post_process_inference_result(
      renumbered_result
  )
  prefix = '' if name is None else f'{name}_'

  def _dump(file_name: str, payload: bytes) -> None:
    # All three artefacts are written in binary mode.
    with open(os.path.join(output_dir, file_name), 'wb') as out:
      out.write(payload)

  _dump(f'{prefix}model_chimerax.cif', processed_result.cif)
  _dump(
      f'{prefix}summary_confidences_chimerax.json',
      processed_result.structure_confidence_summary_json,
  )
  _dump(
      f'{prefix}confidences_chimerax.json',
      processed_result.structure_full_data_json,
  )
  logging.info(
      'Wrote ChimeraX-compatible outputs with unique residue numbering for %s',
      name or os.fspath(output_dir),
  )
def predict_structure(
fold_input: folding_input.Input,
model_runner: ModelRunner,

View File

@@ -1361,6 +1361,54 @@ class TestAlphaFold3RunModes(_TestBase):
self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids)
def test_af3_chimerax_export_renumbers_duplicate_residue_ids(self):
  """ChimeraX export must assign unique sequential residue IDs per chain.

  Builds a single protein chain whose residue IDs 2-5 occur twice, runs the
  ChimeraX conversion and checks that both the structure's residue IDs and
  the `token_res_ids` metadata become 1..N.
  """
  from alphafold3.common import folding_input
  from alphafold3.constants import chemical_components
  from alphafold3.model import model as af3_model

  from alphapulldown.folding_backend.alphafold3_backend import (
      _make_chimerax_compatible_inference_result,
  )

  # Three numbering runs on one chain: 1-10, then 2-5 again, then 12-15.
  original_residue_ids = [*range(1, 11), *range(2, 6), *range(12, 16)]
  residue_count = len(original_residue_ids)
  expected_residue_ids = list(range(1, residue_count + 1))

  fold_input = folding_input.Input(
      name="duplicate_residue_ids_for_chimerax",
      chains=[
          folding_input.ProteinChain(
              id="A",
              sequence="ACDEFGHIKLCDEFMNPQ",
              ptms=[],
              residue_ids=original_residue_ids,
              unpaired_msa="",
              paired_msa="",
              templates=[],
          )
      ],
      rng_seeds=[1],
  )
  inference_result = af3_model.InferenceResult(
      predicted_structure=fold_input.to_structure(
          ccd=chemical_components.Ccd()
      ),
      metadata={
          "token_chain_ids": ["A"] * residue_count,
          "token_res_ids": original_residue_ids,
      },
  )

  chimerax_result = _make_chimerax_compatible_inference_result(
      inference_result
  )

  renumbered_structure = chimerax_result.predicted_structure
  self.assertEqual(
      renumbered_structure.present_residues.id.tolist(),
      expected_residue_ids,
  )
  self.assertEqual(
      chimerax_result.metadata["token_res_ids"],
      expected_residue_ids,
  )
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
"""AF3 must keep multi-region chopped inputs as one gapped protein chain."""
from alphapulldown.folding_backend.alphafold3_backend import (