mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Add issue #588 AF3 regression fixtures and tests
This commit is contained in:
@@ -6,6 +6,7 @@ The script is identical for Slurm and workstation users – only the
|
||||
wrapper decides *how* each case is executed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
import lzma
|
||||
import os
|
||||
import subprocess
|
||||
import time
|
||||
@@ -77,6 +78,51 @@ def _gpu_functional_test_skip_reason() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _a3m_sequences(a3m_text: str) -> list[str]:
|
||||
if not a3m_text:
|
||||
return []
|
||||
lines = [line.strip() for line in a3m_text.splitlines() if line.strip()]
|
||||
return [lines[index] for index in range(1, len(lines), 2)]
|
||||
|
||||
|
||||
def _a3m_query_sequence(a3m_text: str) -> str:
|
||||
sequences = _a3m_sequences(a3m_text)
|
||||
return sequences[0] if sequences else ""
|
||||
|
||||
|
||||
def _a3m_payload_sequences(a3m_text: str) -> list[str]:
|
||||
sequences = _a3m_sequences(a3m_text)
|
||||
return sequences[1:]
|
||||
|
||||
|
||||
def _aligned_a3m_row_length(a3m_row: str) -> int:
|
||||
return len(re.sub(r"[a-z]", "", a3m_row))
|
||||
|
||||
|
||||
def _protein_entries_from_af3_input(payload: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
return [
|
||||
sequence_entry["protein"]
|
||||
for sequence_entry in payload.get("sequences", [])
|
||||
if "protein" in sequence_entry
|
||||
]
|
||||
|
||||
|
||||
def _load_json_payload(path: Path) -> dict[str, Any]:
|
||||
if path.suffix == ".xz":
|
||||
with lzma.open(path, "rt", encoding="utf-8") as handle:
|
||||
return json.load(handle)
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
|
||||
|
||||
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict[str, Any]]:
|
||||
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
|
||||
if len(matches) != 1:
|
||||
raise AssertionError(
|
||||
f"Expected exactly one metadata file for {protein_id} in {feature_dir}, found {matches}"
|
||||
)
|
||||
return matches[0], _load_json_payload(matches[0])
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -1157,6 +1203,195 @@ class _TestBase(parameterized.TestCase):
|
||||
return args
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# backend-only AF3 preparation tests #
|
||||
# --------------------------------------------------------------------------- #
|
||||
class _BackendOnlyTestBase(_TestBase):
|
||||
"""Backend-only AF3 preparation tests that do not run model inference."""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
parameterized.TestCase.setUpClass()
|
||||
if cls.use_temp_dir:
|
||||
cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_backend_test_"))
|
||||
else:
|
||||
cls.base_output_dir = Path("test/test_data/predictions/af3_backend")
|
||||
if cls.base_output_dir.exists():
|
||||
try:
|
||||
shutil.rmtree(cls.base_output_dir)
|
||||
except (PermissionError, OSError) as e:
|
||||
print(
|
||||
"Warning: Could not remove existing output directory "
|
||||
f"{cls.base_output_dir}: {e}"
|
||||
)
|
||||
cls.base_output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
"""AF3 input-construction regressions; these tests do not assert end-to-end ipTM quality."""
|
||||
|
||||
def _prepare_fold_input(
|
||||
self,
|
||||
*,
|
||||
fold_spec: str,
|
||||
feature_dir: Path,
|
||||
debug_msas: bool = False,
|
||||
):
|
||||
from alphapulldown.folding_backend.alphafold3_backend import AlphaFold3Backend
|
||||
|
||||
parsed = parse_fold([fold_spec], [str(feature_dir)], "+")
|
||||
data = create_custom_info(parsed)
|
||||
all_interactors = create_interactors(data, [str(feature_dir)])
|
||||
self.assertLen(all_interactors, 1)
|
||||
self.assertGreaterEqual(len(all_interactors[0]), 1)
|
||||
|
||||
interactors = all_interactors[0]
|
||||
if len(interactors) == 1:
|
||||
object_to_model = interactors[0]
|
||||
else:
|
||||
object_to_model = MultimericObject(interactors=interactors, pair_msa=True)
|
||||
|
||||
mappings = AlphaFold3Backend.prepare_input(
|
||||
objects_to_model=[
|
||||
{"object": object_to_model, "output_dir": str(self.output_dir)}
|
||||
],
|
||||
random_seed=42,
|
||||
debug_msas=debug_msas,
|
||||
)
|
||||
self.assertLen(mappings, 1)
|
||||
fold_input_obj, _ = next(iter(mappings[0].items()))
|
||||
return fold_input_obj
|
||||
|
||||
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
|
||||
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
|
||||
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
|
||||
|
||||
issue_588_dir = self.test_features_dir / "issue_588"
|
||||
for protein_id in ("A0ABD7FQG0", "P18004"):
|
||||
metadata_path, metadata = _load_feature_metadata(issue_588_dir, protein_id)
|
||||
other = metadata["other"]
|
||||
self.assertTrue(
|
||||
other["use_mmseqs2"],
|
||||
f"{metadata_path} is not a mmseqs2-generated AF2 fixture.",
|
||||
)
|
||||
self.assertEqual(other["data_pipeline"], "alphafold2")
|
||||
self.assertFalse(other["re_search_templates_mmseqs2"])
|
||||
|
||||
fold_input_obj = self._prepare_fold_input(
|
||||
fold_spec="A0ABD7FQG0+P18004",
|
||||
feature_dir=issue_588_dir,
|
||||
debug_msas=True,
|
||||
)
|
||||
|
||||
protein_chains = [chain for chain in fold_input_obj.chains if hasattr(chain, "sequence")]
|
||||
chain_sequences = {chain.id: chain.sequence for chain in protein_chains}
|
||||
self.assertEqual(sorted(chain_sequences), ["A", "B"])
|
||||
|
||||
job_name = fold_input_obj.sanitised_name()
|
||||
summary_path = self.output_dir / f"{job_name}_af2_to_af3_translation_summary.json"
|
||||
self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
|
||||
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
self.assertTrue(summary["paired_rows_valid"])
|
||||
self.assertTrue(summary["unpaired_rows_valid"])
|
||||
self.assertIn(
|
||||
"af3_species_pairing_from_af2_individual_msas",
|
||||
summary["translation_modes"],
|
||||
)
|
||||
self.assertLen(summary["chains"], 2)
|
||||
|
||||
for chain_summary in summary["chains"]:
|
||||
chain_id = chain_summary["chain_id"]
|
||||
expected_sequence = chain_sequences[chain_id]
|
||||
self.assertGreater(
|
||||
chain_summary["paired_msa_row_count"],
|
||||
0,
|
||||
f"Expected non-empty paired MSA rows for chain {chain_id}",
|
||||
)
|
||||
self.assertGreater(
|
||||
chain_summary["unpaired_msa_row_count"],
|
||||
0,
|
||||
f"Expected non-empty unpaired MSA rows for chain {chain_id}",
|
||||
)
|
||||
|
||||
for msa_kind in ("paired_input", "unpaired_input"):
|
||||
msa_path = self.output_dir / f"{job_name}_chain-{chain_id}_{msa_kind}.a3m"
|
||||
self.assertTrue(msa_path.is_file(), f"Missing debug MSA {msa_path}")
|
||||
msa_text = msa_path.read_text(encoding="utf-8")
|
||||
self.assertEqual(_a3m_query_sequence(msa_text), expected_sequence)
|
||||
payload_sequences = _a3m_payload_sequences(msa_text)
|
||||
self.assertGreater(
|
||||
len(payload_sequences),
|
||||
0,
|
||||
f"Expected payload rows in {msa_path}",
|
||||
)
|
||||
for payload_sequence in payload_sequences:
|
||||
self.assertEqual(
|
||||
_aligned_a3m_row_length(payload_sequence),
|
||||
len(expected_sequence),
|
||||
f"Aligned row length mismatch in {msa_path}",
|
||||
)
|
||||
|
||||
process_fold_input(
|
||||
fold_input=fold_input_obj,
|
||||
model_runner=None,
|
||||
output_dir=str(self.output_dir),
|
||||
buckets=(512,),
|
||||
)
|
||||
input_json = self.output_dir / f"{job_name}_data.json"
|
||||
written = json.loads(input_json.read_text(encoding="utf-8"))
|
||||
protein_entries = {
|
||||
protein_entry["id"]: protein_entry
|
||||
for protein_entry in _protein_entries_from_af3_input(written)
|
||||
}
|
||||
self.assertEqual(set(protein_entries), set(chain_sequences))
|
||||
|
||||
for chain_id, protein_entry in protein_entries.items():
|
||||
expected_sequence = chain_sequences[chain_id]
|
||||
self.assertEqual(protein_entry["sequence"], expected_sequence)
|
||||
self.assertEqual(
|
||||
_a3m_query_sequence(protein_entry["pairedMsa"]),
|
||||
expected_sequence,
|
||||
)
|
||||
self.assertEqual(
|
||||
_a3m_query_sequence(protein_entry["unpairedMsa"]),
|
||||
expected_sequence,
|
||||
)
|
||||
# These exact issue-588 fixtures are AF2/mmseqs2-derived and were
|
||||
# generated without MMseqs template re-search. Empty templates
|
||||
# document fixture provenance here, not an AF3 conversion failure.
|
||||
self.assertEqual(protein_entry["templates"], [])
|
||||
|
||||
def test_af3_prepare_input_preserves_templates_for_templated_af2_pkl_features(self):
|
||||
"""Positive control: templated AF2 pkl inputs should keep templates in AF3 JSON."""
|
||||
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
|
||||
|
||||
feature_dir = self.test_features_dir / "af2_features" / "protein"
|
||||
fold_input_obj = self._prepare_fold_input(
|
||||
fold_spec="P61626",
|
||||
feature_dir=feature_dir,
|
||||
)
|
||||
|
||||
self.assertLen(fold_input_obj.chains, 1)
|
||||
self.assertGreater(len(fold_input_obj.chains[0].templates), 0)
|
||||
|
||||
process_fold_input(
|
||||
fold_input=fold_input_obj,
|
||||
model_runner=None,
|
||||
output_dir=str(self.output_dir),
|
||||
buckets=(512,),
|
||||
)
|
||||
input_json = self.output_dir / f"{fold_input_obj.sanitised_name()}_data.json"
|
||||
written = json.loads(input_json.read_text(encoding="utf-8"))
|
||||
protein_entries = _protein_entries_from_af3_input(written)
|
||||
|
||||
self.assertLen(protein_entries, 1)
|
||||
self.assertGreater(len(protein_entries[0]["templates"]), 0)
|
||||
self.assertTrue(
|
||||
all(template["mmcif"] for template in protein_entries[0]["templates"])
|
||||
)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# parameterised "run mode" tests #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -2387,6 +2622,61 @@ class TestAlphaFold3RunModes(_TestBase):
|
||||
)
|
||||
self._assert_af3_outputs_present(current_output_dir)
|
||||
|
||||
def test_af3_run_multimer_jobs_multiple_json_jobs_create_per_job_subdirs(self):
|
||||
"""Shared AF3 wrapper output roots must isolate multiple JSON jobs by subdirectory."""
|
||||
from alphapulldown.utils.output_paths import derive_af3_job_name_from_json
|
||||
|
||||
self._require_af3_functional_environment()
|
||||
env = self._make_af3_test_env()
|
||||
flash_impl = self._af3_flash_attention_impl()
|
||||
json_inputs = [
|
||||
self.test_features_dir / "protein_with_ptms.json",
|
||||
self.test_features_dir / "P01308_af3_input.json",
|
||||
]
|
||||
protein_list = self.output_dir / "test_multiple_json_jobs.txt"
|
||||
protein_list.write_text(
|
||||
"\n".join(json_input.name for json_input in json_inputs) + "\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
res = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.script_multimer),
|
||||
"--num_cycle=1",
|
||||
"--num_predictions_per_model=1",
|
||||
f"--data_dir={DATA_DIR}",
|
||||
f"--monomer_objects_dir={self.test_features_dir}",
|
||||
f"--output_path={self.output_dir}",
|
||||
"--mode=custom",
|
||||
f"--protein_lists={protein_list}",
|
||||
"--fold_backend=alphafold3",
|
||||
f"--flash_attention_implementation={flash_impl}",
|
||||
"--num_diffusion_samples=1",
|
||||
"--use_ap_style",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
env=env,
|
||||
)
|
||||
print(res.stdout)
|
||||
print(res.stderr)
|
||||
self.assertEqual(res.returncode, 0, "sub-process failed")
|
||||
self.assertFalse(
|
||||
(self.output_dir / "ranking_scores.csv").exists(),
|
||||
"Shared wrapper output root should not contain flattened AF3 JSON outputs.",
|
||||
)
|
||||
|
||||
for json_input in json_inputs:
|
||||
current_output_dir = self.output_dir / derive_af3_job_name_from_json(
|
||||
str(json_input)
|
||||
)
|
||||
self.assertTrue(
|
||||
current_output_dir.is_dir(),
|
||||
f"Expected per-job output directory {current_output_dir} to be created.",
|
||||
)
|
||||
self._assert_af3_outputs_present(current_output_dir)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
dict(testcase_name="monomer", protein_list="test_monomer.txt", script="run_structure_prediction.py"),
|
||||
dict(testcase_name="dimer", protein_list="test_dimer.txt", script="run_structure_prediction.py"),
|
||||
|
||||
BIN
test/test_data/features/issue_588/A0ABD7FQG0.pkl.xz
Executable file
BIN
test/test_data/features/issue_588/A0ABD7FQG0.pkl.xz
Executable file
Binary file not shown.
BIN
test/test_data/features/issue_588/A0ABD7FQG0_feature_metadata_2026-03-16.json.xz
Executable file
BIN
test/test_data/features/issue_588/A0ABD7FQG0_feature_metadata_2026-03-16.json.xz
Executable file
Binary file not shown.
BIN
test/test_data/features/issue_588/P18004.pkl.xz
Executable file
BIN
test/test_data/features/issue_588/P18004.pkl.xz
Executable file
Binary file not shown.
BIN
test/test_data/features/issue_588/P18004_feature_metadata_2026-03-16.json.xz
Executable file
BIN
test/test_data/features/issue_588/P18004_feature_metadata_2026-03-16.json.xz
Executable file
Binary file not shown.
Reference in New Issue
Block a user