Add issue #588 AF3 regression fixtures and tests

This commit is contained in:
Dima
2026-03-27 09:25:30 +01:00
parent f387696724
commit 2223191a05
5 changed files with 290 additions and 0 deletions

View File

@@ -6,6 +6,7 @@ The script is identical for Slurm and workstation users; only the
wrapper decides *how* each case is executed.
"""
from __future__ import annotations
import lzma
import os
import subprocess
import time
@@ -77,6 +78,51 @@ def _gpu_functional_test_skip_reason() -> str | None:
return None
def _a3m_sequences(a3m_text: str) -> list[str]:
if not a3m_text:
return []
lines = [line.strip() for line in a3m_text.splitlines() if line.strip()]
return [lines[index] for index in range(1, len(lines), 2)]
def _a3m_query_sequence(a3m_text: str) -> str:
    """Return the first sequence row of *a3m_text*, or ``""`` if there is none."""
    return next(iter(_a3m_sequences(a3m_text)), "")
def _a3m_payload_sequences(a3m_text: str) -> list[str]:
    """Return every sequence row of *a3m_text* after the first one.

    The first row is the one ``_a3m_query_sequence`` reports; an input with
    zero or one row yields an empty list.
    """
    return _a3m_sequences(a3m_text)[1:]
def _aligned_a3m_row_length(a3m_row: str) -> int:
return len(re.sub(r"[a-z]", "", a3m_row))
def _protein_entries_from_af3_input(payload: dict[str, Any]) -> list[dict[str, Any]]:
return [
sequence_entry["protein"]
for sequence_entry in payload.get("sequences", [])
if "protein" in sequence_entry
]
def _load_json_payload(path: Path) -> dict[str, Any]:
if path.suffix == ".xz":
with lzma.open(path, "rt", encoding="utf-8") as handle:
return json.load(handle)
return json.loads(path.read_text(encoding="utf-8"))
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict[str, Any]]:
    """Locate and load the single feature-metadata JSON for *protein_id*.

    Looks in *feature_dir* for files matching
    ``<protein_id>_feature_metadata_*.json*`` (the trailing wildcard also
    admits suffixed variants such as ``.json.xz``).

    Returns:
        ``(path, payload)`` for the one matching file.

    Raises:
        AssertionError: If the number of matching files is not exactly one.
    """
    pattern = f"{protein_id}_feature_metadata_*.json*"
    candidates = sorted(feature_dir.glob(pattern))
    if len(candidates) == 1:
        (metadata_path,) = candidates
        return metadata_path, _load_json_payload(metadata_path)
    raise AssertionError(
        f"Expected exactly one metadata file for {protein_id} in {feature_dir}, found {candidates}"
    )
# --------------------------------------------------------------------------- #
# common helper mix-in / assertions #
# --------------------------------------------------------------------------- #
@@ -1157,6 +1203,195 @@ class _TestBase(parameterized.TestCase):
return args
# --------------------------------------------------------------------------- #
# backend-only AF3 preparation tests #
# --------------------------------------------------------------------------- #
class _BackendOnlyTestBase(_TestBase):
    """Backend-only AF3 preparation tests that do not run model inference."""

    # NOTE(review): the nesting below was reconstructed from a rendering that
    # dropped indentation; confirm the ``else`` body against the repository.
    @classmethod
    def setUpClass(cls):
        """Create ``cls.base_output_dir`` once for the whole test class.

        When ``cls.use_temp_dir`` is truthy a fresh temporary directory is
        used. Otherwise a fixed path under ``test/test_data/predictions`` is
        used; an existing directory there is removed on a best-effort basis
        and the path is then (re)created.
        """
        # Invokes the parameterized base class directly instead of super(),
        # which bypasses any setUpClass defined on _TestBase.
        parameterized.TestCase.setUpClass()
        if cls.use_temp_dir:
            cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_backend_test_"))
        else:
            cls.base_output_dir = Path("test/test_data/predictions/af3_backend")
            if cls.base_output_dir.exists():
                try:
                    shutil.rmtree(cls.base_output_dir)
                except (PermissionError, OSError) as e:
                    # Cleanup failure is only reported; setup continues and
                    # reuses whatever is left of the directory.
                    print(
                        "Warning: Could not remove existing output directory "
                        f"{cls.base_output_dir}: {e}"
                    )
            cls.base_output_dir.mkdir(parents=True, exist_ok=True)
class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
    """AF3 input-construction regressions; these tests do not assert end-to-end ipTM quality."""

    # NOTE(review): block nesting in this class was reconstructed from a
    # rendering that dropped indentation; confirm against the repository.

    def _prepare_fold_input(
        self,
        *,
        fold_spec: str,
        feature_dir: Path,
        debug_msas: bool = False,
    ):
        """Build the AF3 fold-input object for a single fold specification.

        Args:
            fold_spec: Fold specification handed to ``parse_fold`` (with
                ``"+"`` passed as its third argument).
            feature_dir: Directory passed to ``parse_fold`` and
                ``create_interactors`` as the only feature directory.
            debug_msas: Forwarded unchanged to
                ``AlphaFold3Backend.prepare_input``.

        Returns:
            The key of the first item in the single mapping returned by
            ``AlphaFold3Backend.prepare_input``.
        """
        from alphapulldown.folding_backend.alphafold3_backend import AlphaFold3Backend
        parsed = parse_fold([fold_spec], [str(feature_dir)], "+")
        data = create_custom_info(parsed)
        all_interactors = create_interactors(data, [str(feature_dir)])
        # Exactly one job is expected, containing at least one interactor.
        self.assertLen(all_interactors, 1)
        self.assertGreaterEqual(len(all_interactors[0]), 1)
        interactors = all_interactors[0]
        # A single interactor is modelled as-is; several are wrapped in a
        # MultimericObject with MSA pairing enabled.
        if len(interactors) == 1:
            object_to_model = interactors[0]
        else:
            object_to_model = MultimericObject(interactors=interactors, pair_msa=True)
        mappings = AlphaFold3Backend.prepare_input(
            objects_to_model=[
                {"object": object_to_model, "output_dir": str(self.output_dir)}
            ],
            random_seed=42,
            debug_msas=debug_msas,
        )
        self.assertLen(mappings, 1)
        # Only the key of the first (key, value) pair is needed here.
        fold_input_obj, _ = next(iter(mappings[0].items()))
        return fold_input_obj

    def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self) -> None:
        """Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures.

        Checks, in order: fixture metadata provenance, the translation
        summary and per-chain debug MSAs written with ``debug_msas=True``,
        and the ``*_data.json`` written by ``process_fold_input``.
        """
        from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
        issue_588_dir = self.test_features_dir / "issue_588"
        # Guard the fixtures themselves: both must be AF2-pipeline features
        # generated with mmseqs2 and without MMseqs template re-search.
        for protein_id in ("A0ABD7FQG0", "P18004"):
            metadata_path, metadata = _load_feature_metadata(issue_588_dir, protein_id)
            other = metadata["other"]
            self.assertTrue(
                other["use_mmseqs2"],
                f"{metadata_path} is not a mmseqs2-generated AF2 fixture.",
            )
            self.assertEqual(other["data_pipeline"], "alphafold2")
            self.assertFalse(other["re_search_templates_mmseqs2"])
        fold_input_obj = self._prepare_fold_input(
            fold_spec="A0ABD7FQG0+P18004",
            feature_dir=issue_588_dir,
            debug_msas=True,
        )
        # Keep only chains that expose a ``sequence`` attribute and index
        # their sequences by chain id.
        protein_chains = [chain for chain in fold_input_obj.chains if hasattr(chain, "sequence")]
        chain_sequences = {chain.id: chain.sequence for chain in protein_chains}
        self.assertEqual(sorted(chain_sequences), ["A", "B"])
        job_name = fold_input_obj.sanitised_name()
        # The translation summary is expected next to the other outputs
        # (presumably written because debug_msas=True - confirm in backend).
        summary_path = self.output_dir / f"{job_name}_af2_to_af3_translation_summary.json"
        self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
        summary = json.loads(summary_path.read_text(encoding="utf-8"))
        self.assertTrue(summary["paired_rows_valid"])
        self.assertTrue(summary["unpaired_rows_valid"])
        self.assertIn(
            "af3_species_pairing_from_af2_individual_msas",
            summary["translation_modes"],
        )
        self.assertLen(summary["chains"], 2)
        for chain_summary in summary["chains"]:
            chain_id = chain_summary["chain_id"]
            expected_sequence = chain_sequences[chain_id]
            self.assertGreater(
                chain_summary["paired_msa_row_count"],
                0,
                f"Expected non-empty paired MSA rows for chain {chain_id}",
            )
            self.assertGreater(
                chain_summary["unpaired_msa_row_count"],
                0,
                f"Expected non-empty unpaired MSA rows for chain {chain_id}",
            )
            # Each chain must have both debug A3M dumps; the first row must
            # equal the chain sequence and every further row must have the
            # same aligned length (lowercase characters not counted).
            for msa_kind in ("paired_input", "unpaired_input"):
                msa_path = self.output_dir / f"{job_name}_chain-{chain_id}_{msa_kind}.a3m"
                self.assertTrue(msa_path.is_file(), f"Missing debug MSA {msa_path}")
                msa_text = msa_path.read_text(encoding="utf-8")
                self.assertEqual(_a3m_query_sequence(msa_text), expected_sequence)
                payload_sequences = _a3m_payload_sequences(msa_text)
                self.assertGreater(
                    len(payload_sequences),
                    0,
                    f"Expected payload rows in {msa_path}",
                )
                for payload_sequence in payload_sequences:
                    self.assertEqual(
                        _aligned_a3m_row_length(payload_sequence),
                        len(expected_sequence),
                        f"Aligned row length mismatch in {msa_path}",
                    )
        # model_runner=None: no model is supplied; the test only inspects
        # the ``*_data.json`` file expected in the output directory.
        process_fold_input(
            fold_input=fold_input_obj,
            model_runner=None,
            output_dir=str(self.output_dir),
            buckets=(512,),
        )
        input_json = self.output_dir / f"{job_name}_data.json"
        written = json.loads(input_json.read_text(encoding="utf-8"))
        protein_entries = {
            protein_entry["id"]: protein_entry
            for protein_entry in _protein_entries_from_af3_input(written)
        }
        self.assertEqual(set(protein_entries), set(chain_sequences))
        for chain_id, protein_entry in protein_entries.items():
            expected_sequence = chain_sequences[chain_id]
            self.assertEqual(protein_entry["sequence"], expected_sequence)
            self.assertEqual(
                _a3m_query_sequence(protein_entry["pairedMsa"]),
                expected_sequence,
            )
            self.assertEqual(
                _a3m_query_sequence(protein_entry["unpairedMsa"]),
                expected_sequence,
            )
            # These exact issue-588 fixtures are AF2/mmseqs2-derived and were
            # generated without MMseqs template re-search. Empty templates
            # document fixture provenance here, not an AF3 conversion failure.
            self.assertEqual(protein_entry["templates"], [])

    def test_af3_prepare_input_preserves_templates_for_templated_af2_pkl_features(self) -> None:
        """Positive control: templated AF2 pkl inputs should keep templates in AF3 JSON.

        Verifies templates are present both on the prepared fold-input chain
        and in the ``*_data.json`` written by ``process_fold_input``.
        """
        from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
        feature_dir = self.test_features_dir / "af2_features" / "protein"
        fold_input_obj = self._prepare_fold_input(
            fold_spec="P61626",
            feature_dir=feature_dir,
        )
        self.assertLen(fold_input_obj.chains, 1)
        self.assertGreater(len(fold_input_obj.chains[0].templates), 0)
        # model_runner=None: no model is supplied; only the written JSON is
        # inspected below.
        process_fold_input(
            fold_input=fold_input_obj,
            model_runner=None,
            output_dir=str(self.output_dir),
            buckets=(512,),
        )
        input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
        written = json.loads(input_json.read_text(encoding="utf-8"))
        protein_entries = _protein_entries_from_af3_input(written)
        self.assertLen(protein_entries, 1)
        self.assertGreater(len(protein_entries[0]["templates"]), 0)
        # Every serialized template must carry a non-empty ``mmcif`` value.
        self.assertTrue(
            all(template["mmcif"] for template in protein_entries[0]["templates"])
        )
# --------------------------------------------------------------------------- #
# parameterised "run mode" tests #
# --------------------------------------------------------------------------- #
@@ -2387,6 +2622,61 @@ class TestAlphaFold3RunModes(_TestBase):
)
self._assert_af3_outputs_present(current_output_dir)
def test_af3_run_multimer_jobs_multiple_json_jobs_create_per_job_subdirs(self):
    """Two AF3 JSON jobs sharing one output root must each get their own subdirectory.

    Runs the multimer wrapper script once with a protein list naming two
    JSON inputs, then checks that no ``ranking_scores.csv`` was written
    directly into the shared root and that each job's derived
    subdirectory exists and passes ``_assert_af3_outputs_present``.
    """
    from alphapulldown.utils.output_paths import derive_af3_job_name_from_json

    self._require_af3_functional_environment()
    env = self._make_af3_test_env()
    flash_impl = self._af3_flash_attention_impl()

    json_inputs = [
        self.test_features_dir / file_name
        for file_name in ("protein_with_ptms.json", "P01308_af3_input.json")
    ]

    # One JSON file name per line, each line newline-terminated.
    protein_list = self.output_dir / "test_multiple_json_jobs.txt"
    listing = "".join(f"{job_json.name}\n" for job_json in json_inputs)
    protein_list.write_text(listing, encoding="utf-8")

    command = [
        sys.executable,
        str(self.script_multimer),
        "--num_cycle=1",
        "--num_predictions_per_model=1",
        f"--data_dir={DATA_DIR}",
        f"--monomer_objects_dir={self.test_features_dir}",
        f"--output_path={self.output_dir}",
        "--mode=custom",
        f"--protein_lists={protein_list}",
        "--fold_backend=alphafold3",
        f"--flash_attention_implementation={flash_impl}",
        "--num_diffusion_samples=1",
        "--use_ap_style",
    ]
    res = subprocess.run(command, capture_output=True, text=True, env=env)

    # Echo the child's output so failures can be diagnosed from the log.
    for captured in (res.stdout, res.stderr):
        print(captured)
    self.assertEqual(res.returncode, 0, "sub-process failed")

    flattened_ranking = self.output_dir / "ranking_scores.csv"
    self.assertFalse(
        flattened_ranking.exists(),
        "Shared wrapper output root should not contain flattened AF3 JSON outputs.",
    )

    for job_json in json_inputs:
        job_dir_name = derive_af3_job_name_from_json(str(job_json))
        current_output_dir = self.output_dir / job_dir_name
        self.assertTrue(
            current_output_dir.is_dir(),
            f"Expected per-job output directory {current_output_dir} to be created.",
        )
        self._assert_af3_outputs_present(current_output_dir)
@parameterized.named_parameters(
dict(testcase_name="monomer", protein_list="test_monomer.txt", script="run_structure_prediction.py"),
dict(testcase_name="dimer", protein_list="test_dimer.txt", script="run_structure_prediction.py"),

Binary file not shown.

Binary file not shown.