mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Write ChimeraX-safe AF3 outputs
This commit is contained in:
@@ -189,6 +189,10 @@ def write_outputs(
|
||||
post_processing.write_output(
|
||||
inference_result=result, output_dir=sample_dir
|
||||
)
|
||||
_write_chimerax_output_if_needed(
|
||||
inference_result=result,
|
||||
output_dir=sample_dir,
|
||||
)
|
||||
ranking_score = float(result.metadata['ranking_score'])
|
||||
ranking_scores.append((seed, sample_idx, ranking_score))
|
||||
if max_ranking_score is None or ranking_score > max_ranking_score:
|
||||
@@ -218,12 +222,119 @@ def write_outputs(
|
||||
terms_of_use=output_terms,
|
||||
name=job_name,
|
||||
)
|
||||
_write_chimerax_output_if_needed(
|
||||
inference_result=max_ranking_result,
|
||||
output_dir=output_dir,
|
||||
name=job_name,
|
||||
)
|
||||
with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(['seed', 'sample', 'ranking_score'])
|
||||
writer.writerows(ranking_scores)
|
||||
|
||||
|
||||
def _sequential_residue_ids_per_chain(chain_ids: Sequence[str]) -> list[int]:
|
||||
"""Returns sequential residue IDs that are unique within each chain."""
|
||||
next_residue_id_by_chain: dict[str, int] = {}
|
||||
residue_ids = []
|
||||
for chain_id in chain_ids:
|
||||
next_residue_id = next_residue_id_by_chain.get(chain_id, 0) + 1
|
||||
next_residue_id_by_chain[chain_id] = next_residue_id
|
||||
residue_ids.append(next_residue_id)
|
||||
return residue_ids
|
||||
|
||||
|
||||
def _has_duplicate_residue_ids_within_chain(struc) -> bool:
|
||||
"""Returns True if a structure reuses residue IDs inside any one chain."""
|
||||
residue_chain_ids = struc.chains_table.apply_array_to_column(
|
||||
column_name='id',
|
||||
arr=struc.residues_table.chain_key,
|
||||
)
|
||||
seen_by_chain: dict[str, set[int]] = {}
|
||||
for chain_id, residue_id in zip(
|
||||
residue_chain_ids,
|
||||
struc.residues_table.id,
|
||||
strict=True,
|
||||
):
|
||||
chain_id = str(chain_id)
|
||||
residue_id = int(residue_id)
|
||||
seen = seen_by_chain.setdefault(chain_id, set())
|
||||
if residue_id in seen:
|
||||
return True
|
||||
seen.add(residue_id)
|
||||
return False
|
||||
|
||||
|
||||
def _make_chimerax_compatible_inference_result(
    inference_result: model.InferenceResult,
) -> model.InferenceResult:
  """Builds a copy of `inference_result` renumbered 1..N within each chain.

  The input object is left untouched: the structure is replaced through
  `copy_and_update_residues`, and the metadata is shallow-copied before its
  `token_res_ids` entry is overwritten.

  Args:
    inference_result: Result whose predicted structure and whose
      `token_chain_ids` metadata entry drive the renumbering.

  Returns:
    A new inference result with sequential per-chain residue IDs in the
    structure (`res_id` as int32, `res_auth_seq_id` as decimal strings) and
    sequential per-chain `token_res_ids` in the metadata.
  """
  original_structure = inference_result.predicted_structure
  chain_id_per_residue = original_structure.chains_table.apply_array_to_column(
      column_name='id',
      arr=original_structure.residues_table.chain_key,
  )
  new_residue_ids = np.asarray(
      _sequential_residue_ids_per_chain(
          [str(chain_id) for chain_id in chain_id_per_residue]
      ),
      dtype=np.int32,
  )
  # The author-facing IDs mirror the new numeric IDs as object-dtype strings.
  new_auth_seq_ids = np.char.mod('%d', new_residue_ids).astype(object)
  renumbered_structure = original_structure.copy_and_update_residues(
      res_id=new_residue_ids,
      res_auth_seq_id=new_auth_seq_ids,
  )

  # NOTE(review): tokens are numbered one per token. If several tokens can
  # belong to a single residue, these IDs would diverge from the residue
  # numbering above - confirm against the tokenizer.
  chain_id_per_token = [
      str(chain_id) for chain_id in inference_result.metadata['token_chain_ids']
  ]
  renumbered_metadata = dict(inference_result.metadata)
  renumbered_metadata['token_res_ids'] = _sequential_residue_ids_per_chain(
      chain_id_per_token
  )
  return dataclasses.replace(
      inference_result,
      predicted_structure=renumbered_structure,
      metadata=renumbered_metadata,
  )
|
||||
|
||||
|
||||
def _write_chimerax_output_if_needed(
    inference_result: model.InferenceResult,
    output_dir: os.PathLike[str] | str,
    name: str | None = None,
) -> None:
  """Writes renumbered output files when a chain reuses residue IDs.

  Does nothing when the predicted structure has no duplicate residue IDs
  within a chain. Otherwise the result is renumbered, post-processed, and
  written to three files in `output_dir`, in this order:
  `model_chimerax.cif`, `summary_confidences_chimerax.json` and
  `confidences_chimerax.json`.

  Args:
    inference_result: Result to inspect and, if needed, export.
    output_dir: Directory that receives the files.
    name: Optional job name. When it is not None, `<name>_` is prepended to
      each file name.
  """
  needs_renumbering = _has_duplicate_residue_ids_within_chain(
      inference_result.predicted_structure
  )
  if not needs_renumbering:
    return

  renumbered_result = _make_chimerax_compatible_inference_result(
      inference_result
  )
  processed_result = post_processing.post_process_inference_result(
      renumbered_result
  )
  prefix = '' if name is None else f'{name}_'

  # (file name, attribute of the processed result holding the bytes to write)
  outputs = (
      (f'{prefix}model_chimerax.cif', 'cif'),
      (
          f'{prefix}summary_confidences_chimerax.json',
          'structure_confidence_summary_json',
      ),
      (f'{prefix}confidences_chimerax.json', 'structure_full_data_json'),
  )
  for file_name, payload_attribute in outputs:
    with open(os.path.join(output_dir, file_name), 'wb') as f:
      f.write(getattr(processed_result, payload_attribute))

  logging.info(
      'Wrote ChimeraX-compatible outputs with unique residue numbering for %s',
      name or os.fspath(output_dir),
  )
|
||||
|
||||
|
||||
def predict_structure(
|
||||
fold_input: folding_input.Input,
|
||||
model_runner: ModelRunner,
|
||||
|
||||
@@ -1361,6 +1361,54 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
|
||||
self.assertEqual(rebuilt.present_residues.id.tolist(), expected_residue_ids)
|
||||
|
||||
def test_af3_chimerax_export_renumbers_duplicate_residue_ids(self):
    """ChimeraX export must renumber a chain with repeated residue IDs.

    Builds a single protein chain "A" whose residue IDs 2-5 occur twice,
    runs the renumbering helper, and expects both the structure and the
    token metadata to be numbered 1..N.
    """
    from alphafold3.common import folding_input
    from alphafold3.constants import chemical_components
    from alphafold3.model import model as af3_model
    from alphapulldown.folding_backend.alphafold3_backend import (
        _make_chimerax_compatible_inference_result,
    )

    # 18 residues: 1-10, then 2-5 again (the duplicates), then 12-15.
    original_residue_ids = [*range(1, 11), *range(2, 6), *range(12, 16)]
    expected_residue_ids = list(range(1, len(original_residue_ids) + 1))

    protein_chain = folding_input.ProteinChain(
        id="A",
        sequence="ACDEFGHIKLCDEFMNPQ",
        ptms=[],
        residue_ids=original_residue_ids,
        unpaired_msa="",
        paired_msa="",
        templates=[],
    )
    fold_input = folding_input.Input(
        name="duplicate_residue_ids_for_chimerax",
        chains=[protein_chain],
        rng_seeds=[1],
    )
    metadata = {
        "token_chain_ids": ["A"] * len(original_residue_ids),
        "token_res_ids": original_residue_ids,
    }
    inference_result = af3_model.InferenceResult(
        predicted_structure=fold_input.to_structure(
            ccd=chemical_components.Ccd()
        ),
        metadata=metadata,
    )

    chimerax_result = _make_chimerax_compatible_inference_result(
        inference_result
    )

    renumbered_ids = (
        chimerax_result.predicted_structure.present_residues.id.tolist()
    )
    self.assertEqual(renumbered_ids, expected_residue_ids)
    self.assertEqual(
        chimerax_result.metadata["token_res_ids"],
        expected_residue_ids,
    )
|
||||
|
||||
def test_af3_keeps_discontinuous_chopped_regions_in_one_gapped_chain(self):
|
||||
"""AF3 must keep multi-region chopped inputs as one gapped protein chain."""
|
||||
from alphapulldown.folding_backend.alphafold3_backend import (
|
||||
|
||||
Reference in New Issue
Block a user