Add MMseqs regression coverage for pairing and template reuse

Add opt-in AF2 and AF3 inference regressions that regenerate fresh MMseqs-derived AF2 features for the issue #588 A0ABD7FQG0/P18004 reproducer, verify recovered species identifiers, and exercise the existing wrapper entrypoints.

Add FASTA fixtures for that reproducer so the tests do not depend on a live UniProt download.

Add focused coverage for the MMseqs precomputed-MSA plus template re-search branch in make_mmseq_features(), and relax the outdated AF3 wrapper so it can submit Slurm jobs from login nodes without requiring a local GPU.
This commit is contained in:
Dima
2026-03-27 13:07:47 +01:00
parent 53a75e14c8
commit 6cd6511a77
6 changed files with 438 additions and 3 deletions

View File

@@ -14,6 +14,7 @@ import sys
import tempfile
import logging
import unittest
import lzma
from pathlib import Path
from absl.testing import absltest, parameterized
@@ -70,6 +71,34 @@ def _gpu_functional_test_skip_reason() -> str | None:
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
return None
def _mmseqs_functional_test_skip_reason() -> str | None:
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
return None
return (
"MMseqs functional inference tests are disabled by default. "
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
)
def _load_feature_dict(feature_path: Path) -> dict:
    """Unpickle a feature file and return its feature dictionary.

    Files whose suffix is ``.xz`` are read through ``lzma``; anything else is
    read as a plain binary file. If the unpickled object exposes a
    ``feature_dict`` attribute that attribute is returned, otherwise the
    unpickled object itself is returned.
    """
    # NOTE: pickle is only safe on trusted input; callers in this file pass
    # feature files produced by the test run itself.
    if feature_path.suffix == ".xz":
        stream = lzma.open(feature_path, "rb")
    else:
        stream = open(feature_path, "rb")
    with stream as handle:
        loaded = pickle.load(handle)
    return getattr(loaded, "feature_dict", loaded)
def _non_empty_identifier_count(values) -> int:
    """Count the entries of *values* that are non-blank once rendered as text.

    ``bytes`` entries are decoded as UTF-8; every other entry is passed through
    ``str()``. An entry counts when its text is non-empty after stripping
    whitespace.
    """

    def _has_text(entry) -> bool:
        text = entry.decode("utf-8") if isinstance(entry, bytes) else entry
        return bool(str(text).strip())

    return sum(1 for entry in values if _has_text(entry))
# --------------------------------------------------------------------------- #
# common helper mix-in / assertions #
# --------------------------------------------------------------------------- #
@@ -125,6 +154,9 @@ class _TestBase(parameterized.TestCase):
apd_path = Path(alphapulldown.__path__[0])
self.script_multimer = apd_path / "scripts" / "run_multimer_jobs.py"
self.script_single = apd_path / "scripts" / "run_structure_prediction.py"
self.script_create_features = (
apd_path / "scripts" / "create_individual_features.py"
)
# ---------------- assertions reused by all subclasses ----------------- #
def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirname: str | None = None):
@@ -439,5 +471,140 @@ class TestDropoutDiversity(_TestBase):
# The test passes if calculations succeed - the diversity check is informational
class TestMmseqsIssue588Inference(_TestBase):
    """Opt-in AF2 end-to-end regression on freshly generated MMseqs features.

    The test is skipped whenever ``_mmseqs_functional_test_skip_reason()``
    returns a message. When enabled it runs the feature-generation and
    structure-prediction scripts as subprocesses for the issue #588 pair.
    """

    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")

    def _require_mmseqs_functional_environment(self) -> None:
        """Skip unless opted in, then assert that both FASTA fixtures exist."""
        if skip_reason := _mmseqs_functional_test_skip_reason():
            self.skipTest(skip_reason)

        fasta_root = self.test_data_dir / "fastas"
        for protein_id in self.ISSUE_588_IDS:
            fasta_path = fasta_root / f"{protein_id}.fasta"
            self.assertTrue(
                fasta_path.is_file(),
                f"Missing FASTA fixture {fasta_path}",
            )

    def _generate_issue_588_mmseq_features(self) -> Path:
        """Run the feature-creation script with MMseqs2 and return its output dir.

        Fails the test (with captured stdout/stderr) if the script exits non-zero.
        """
        feature_dir = self.output_dir / "issue_588_mmseq_features"
        feature_dir.mkdir(parents=True, exist_ok=True)

        fasta_root = self.test_data_dir / "fastas"
        fasta_paths = ",".join(
            [
                str(fasta_root / f"{protein_id}.fasta")
                for protein_id in self.ISSUE_588_IDS
            ]
        )

        res = subprocess.run(
            [
                sys.executable,
                str(self.script_create_features),
                f"--fasta_paths={fasta_paths}",
                f"--output_dir={feature_dir}",
                f"--data_dir={DATA_DIR}",
                "--max_template_date=2024-05-02",
                "--use_mmseqs2=True",
                "--data_pipeline=alphafold2",
                "--compress_features=True",
                "--skip_existing=False",
            ],
            capture_output=True,
            text=True,
        )
        self.assertEqual(
            res.returncode,
            0,
            f"MMseqs feature generation failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )
        return feature_dir

    def _resolve_af2_result_dir(self, root: Path) -> Path:
        """Return the directory holding ``ranking_debug.json`` at or below *root*.

        *root* itself wins when it contains the file; otherwise exactly one
        matching directory must exist somewhere underneath it.
        """
        if not (root / "ranking_debug.json").exists():
            candidates = sorted(
                path.parent for path in root.rglob("ranking_debug.json")
            )
            self.assertEqual(
                len(candidates),
                1,
                f"Expected one AF2 result directory under {root}, found {candidates}",
            )
            return candidates[0]
        return root

    def test_issue_588_mmseqs_generated_features_enable_af2_multimer_inference(self):
        """Regenerated features keep identifiers, pair across chains, and fold."""
        from alphafold.data import feature_processing
        from alphafold.data import msa_pairing
        from alphafold.data import pipeline_multimer

        self._require_mmseqs_functional_environment()
        feature_dir = self._generate_issue_588_mmseq_features()

        # Step 1: each chain's pickle must carry recovered identifiers.
        converted_chains = {}
        for chain_id, protein_id in zip(("A", "B"), self.ISSUE_588_IDS):
            feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
            expectations = {
                "msa_species_identifiers_all_seq": (
                    f"{protein_id} should keep recovered species IDs in msa_species_identifiers_all_seq"
                ),
                "msa_uniprot_accession_identifiers_all_seq": (
                    f"{protein_id} should keep recovered accession IDs in msa_uniprot_accession_identifiers_all_seq"
                ),
            }
            for feature_key, failure_message in expectations.items():
                self.assertGreater(
                    _non_empty_identifier_count(feature_dict[feature_key]),
                    0,
                    failure_message,
                )
            converted_chains[chain_id] = pipeline_multimer.convert_monomer_features(
                feature_dict,
                chain_id,
            )

        # Step 2: pairing must yield more rows than the query row alone.
        assembly_features = pipeline_multimer.add_assembly_features(converted_chains)
        np_chains = list(assembly_features.values())
        feature_processing.process_unmerged_features(np_chains)
        paired_rows = msa_pairing.pair_sequences(np_chains)
        self.assertGreater(
            paired_rows.shape[0],
            1,
            "Fresh mmseq AF2 features should produce paired rows beyond the query",
        )

        # Step 3: the structure-prediction script must succeed on those features.
        prediction_dir = self.output_dir / "af2_prediction"
        prediction_dir.mkdir(parents=True, exist_ok=True)
        prediction_command = [
            sys.executable,
            str(self.script_single),
            "--input=A0ABD7FQG0+P18004",
            f"--output_directory={prediction_dir}",
            "--num_cycle=1",
            "--num_predictions_per_model=1",
            "--model_names=model_4_multimer_v3",
            f"--data_directory={DATA_DIR}",
            f"--features_directory={feature_dir}",
            "--random_seed=42",
        ]
        res = subprocess.run(prediction_command, capture_output=True, text=True)
        self.assertEqual(
            res.returncode,
            0,
            f"AF2 inference failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )

        # Step 4: check the ranking file and the single result pickle.
        result_dir = self._resolve_af2_result_dir(prediction_dir)
        ranking_text = (result_dir / "ranking_debug.json").read_text(encoding="utf-8")
        ranking_payload = json.loads(ranking_text)
        self.assertTrue(ranking_payload["order"])

        result_pickles = sorted(result_dir.glob("result_*.pkl"))
        self.assertLen(result_pickles, 1)
        with result_pickles[0].open("rb") as handle:
            result_payload = pickle.load(handle)
        for metric_name in ("iptm", "ranking_confidence"):
            self.assertIn(metric_name, result_payload)
if __name__ == "__main__":
absltest.main()

View File

@@ -78,6 +78,15 @@ def _gpu_functional_test_skip_reason() -> str | None:
return None
def _mmseqs_functional_test_skip_reason() -> str | None:
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
return None
return (
"MMseqs functional inference tests are disabled by default. "
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
)
def _a3m_sequences(a3m_text: str) -> list[str]:
if not a3m_text:
return []
@@ -139,6 +148,25 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
return _a3m_payload_sequences(a3m_text) if a3m_text else []
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
    """Load a pickled feature file, transparently handling ``.xz`` compression.

    Returns the ``feature_dict`` attribute of the unpickled object when it has
    one, and the unpickled object itself otherwise.
    """
    # NOTE: pickle is only safe on trusted input; callers in this file pass
    # feature files produced by the test run itself.
    is_compressed = feature_path.suffix == ".xz"
    open_binary = lzma.open if is_compressed else open
    with open_binary(feature_path, "rb") as stream:
        unpickled = pickle.load(stream)
    if not hasattr(unpickled, "feature_dict"):
        return unpickled
    return unpickled.feature_dict
def _non_empty_identifier_count(values) -> int:
    """Return how many items in *values* contain non-whitespace text.

    ``bytes`` items are decoded as UTF-8 first; all items are then converted
    with ``str()`` and stripped before the emptiness check.
    """
    as_text = (
        item.decode("utf-8") if isinstance(item, bytes) else item
        for item in values
    )
    return len([text for text in as_text if str(text).strip()])
# --------------------------------------------------------------------------- #
# common helper mix-in / assertions #
# --------------------------------------------------------------------------- #
@@ -185,6 +213,9 @@ class _TestBase(parameterized.TestCase):
apd_path = Path(alphapulldown.__path__[0])
self.script_multimer = apd_path / "scripts" / "run_multimer_jobs.py"
self.script_single = apd_path / "scripts" / "run_structure_prediction.py"
self.script_create_features = (
apd_path / "scripts" / "create_individual_features.py"
)
@classmethod
def tearDownClass(cls):
@@ -1438,6 +1469,124 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
)
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
    """Opt-in AF3 smoke test driven by freshly regenerated MMseqs AF2 features.

    Requires both the AF3 functional environment check from the base class and
    the MMseqs opt-in gate (``_mmseqs_functional_test_skip_reason()``).
    """

    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")

    def _require_mmseqs_functional_environment(self) -> None:
        """Apply the AF3 and MMseqs gates, then assert both FASTA fixtures exist."""
        self._require_af3_functional_environment()
        if skip_reason := _mmseqs_functional_test_skip_reason():
            self.skipTest(skip_reason)

        fasta_root = self.test_data_dir / "fastas"
        for protein_id in self.ISSUE_588_IDS:
            fasta_path = fasta_root / f"{protein_id}.fasta"
            self.assertTrue(
                fasta_path.is_file(),
                f"Missing FASTA fixture {fasta_path}",
            )

    def _generate_issue_588_mmseq_features(self, env: Dict[str, str]) -> Path:
        """Run the feature-creation script under *env* and return its output dir.

        Fails the test (with captured stdout/stderr) if the script exits non-zero.
        """
        feature_dir = self.output_dir / "issue_588_mmseq_features"
        feature_dir.mkdir(parents=True, exist_ok=True)

        fasta_root = self.test_data_dir / "fastas"
        fasta_paths = ",".join(
            [
                str(fasta_root / f"{protein_id}.fasta")
                for protein_id in self.ISSUE_588_IDS
            ]
        )

        feature_command = [
            sys.executable,
            str(self.script_create_features),
            f"--fasta_paths={fasta_paths}",
            f"--output_dir={feature_dir}",
            f"--data_dir={DATA_DIR}",
            "--max_template_date=2024-05-02",
            "--use_mmseqs2=True",
            "--data_pipeline=alphafold2",
            "--compress_features=True",
            "--skip_existing=False",
        ]
        res = subprocess.run(
            feature_command,
            capture_output=True,
            text=True,
            env=env,
        )
        self.assertEqual(
            res.returncode,
            0,
            f"MMseqs feature generation failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )
        return feature_dir

    def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference(self):
        """Regenerated AF2 features keep identifiers and drive AF3 species pairing."""
        self._require_mmseqs_functional_environment()
        env = self._make_af3_test_env()
        feature_dir = self._generate_issue_588_mmseq_features(env)

        # Step 1: each chain's pickle must carry recovered identifiers.
        for protein_id in self.ISSUE_588_IDS:
            feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
            expectations = {
                "msa_species_identifiers_all_seq": (
                    f"{protein_id} should keep recovered species IDs in msa_species_identifiers_all_seq"
                ),
                "msa_uniprot_accession_identifiers_all_seq": (
                    f"{protein_id} should keep recovered accession IDs in msa_uniprot_accession_identifiers_all_seq"
                ),
            }
            for feature_key, failure_message in expectations.items():
                self.assertGreater(
                    _non_empty_identifier_count(feature_dict[feature_key]),
                    0,
                    failure_message,
                )

        # Step 2: run the AF3 backend through the single-prediction script.
        flash_impl = self._af3_flash_attention_impl()
        inference_command = [
            sys.executable,
            str(self.script_single),
            "--input=A0ABD7FQG0+P18004",
            f"--output_directory={self.output_dir}",
            f"--data_directory={DATA_DIR}",
            f"--features_directory={feature_dir}",
            "--fold_backend=alphafold3",
            f"--flash_attention_implementation={flash_impl}",
            "--num_diffusion_samples=1",
            "--random_seed=42",
            "--debug_msas",
        ]
        res = subprocess.run(
            inference_command,
            capture_output=True,
            text=True,
            env=env,
        )
        self._runCommonTests(res)

        # Step 3: the translation summary must report valid species pairing.
        result_dir = self._resolve_single_af3_result_dir()
        summary_paths = sorted(
            result_dir.glob("*_af2_to_af3_translation_summary.json")
        )
        self.assertLen(summary_paths, 1)
        summary = json.loads(summary_paths[0].read_text(encoding="utf-8"))
        self.assertEqual(
            summary["translation_modes"],
            ["af3_species_pairing_from_af2_individual_msas"],
        )
        for validity_flag in ("paired_rows_valid", "unpaired_rows_valid"):
            self.assertTrue(summary[validity_flag])
        for chain_summary in summary["chains"]:
            for count_key in (
                "paired_msa_row_count",
                "unpaired_msa_row_count",
                "paired_species_identifier_count",
            ):
                self.assertGreater(chain_summary[count_key], 0)

        # Step 4: exactly one confidence summary with an ipTM value in [0, 1].
        confidence_files = sorted(result_dir.glob("*_summary_confidences.json"))
        self.assertLen(confidence_files, 1)
        confidence_text = confidence_files[0].read_text(encoding="utf-8")
        confidence_payload = json.loads(confidence_text)
        self.assertIn("iptm", confidence_payload)
        self.assertGreaterEqual(confidence_payload["iptm"], 0.0)
        self.assertLessEqual(confidence_payload["iptm"], 1.0)
# --------------------------------------------------------------------------- #
# parameterised "run mode" tests #
# --------------------------------------------------------------------------- #

View File

@@ -154,7 +154,7 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
# ---------- per-test set-up ------------------------------------------- #
def setUp(self):
super().setUp()
if not _has_gpu():
if not _has_gpu() and not self._is_slurm_available():
self.skipTest("NVIDIA GPU not detected skipping Alphafold3 tests")
# Check for correct conda environment
@@ -239,7 +239,7 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
log_path = self.case_dir / f"test_{idx}_{cls_name}_{test_name}.log"
res = subprocess.run(
["sbatch", f"--output={log_path}", str(script_path)],
["sbatch", "--export=ALL", f"--output={log_path}", str(script_path)],
text=True,
capture_output=True,
check=True,
@@ -263,6 +263,12 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
{"testcase_name": "homo_oligomer", "i": 5, "cls": "TestAlphaFold3RunModes", "test": "test__homo_oligomer"},
{"testcase_name": "chopped_dimer", "i": 6, "cls": "TestAlphaFold3RunModes", "test": "test__chopped_dimer"},
{"testcase_name": "long_name", "i": 7, "cls": "TestAlphaFold3RunModes", "test": "test__long_name"},
{
"testcase_name": "issue_588_mmseqs_inference",
"i": 8,
"cls": "TestAlphaFold3MmseqsIssue588Inference",
"test": "test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference",
},
)
def test_predict_structure(self, i: int, cls: str, test: str):
"""Route each parameterised test either through Slurm or local run."""
@@ -277,4 +283,4 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
absltest.main()
absltest.main()

View File

@@ -0,0 +1,2 @@
>A0ABD7FQG0
MSGDENKLKKYRFPETLTNQSRWFGLPLDELIPAAICIGWGITTSKYLFGIGAAVLVYFGIKKLKKGRGSSWLRDLIYWYMPTALLRGIFHNVPDSCFRQWIK

View File

@@ -0,0 +1,2 @@
>P18004
MNNPLEAVTQAVNSLVTALKLPDESAKANEVLGEMSFPQFSRLLPYRDYNQESGLFMNDTTMGFMLEAIPINGANESIVEALDHMLRTKLPRGIPLCIHLMSSQLVGDRIEYGLREFSWSGEQAERFNAITRAYYMKAAATQFPLPEGMNLPLTLRHYRVFISYCSPSKKKSRADILEMENLVKIIRASLQGASITTQTVDAQAFIDIVGEMINHNPDSLYPKRRQLDPYSDLNYQCVEDSFDLKVRADYLTLGLRENGRNSTARILNFHLARNPEIAFLWNMADNYSNLLNPELSISCPFILTLTLVVEDQVKTHSEANLKYMDLEKKSKTSYAKWFPSVEKEAKEWGELRQRLGSGQSSVVSYFLNITAFCKDNNETALEVEQDILNSFRKNGFELISPRFNHMRNFLTCLPFMAGKGLFKQLKEAGVVQRAESFNVANLMPLVADNPLTPAGLLAPTYRNQLAFIDIFFRGMNNTNYNMAVCGTSGAGKTGLIQPLIRSVLDSGGFAVVFDMGDGYKSLCENMGGVYLDGETLRFNPFANITDIDQSAERVRDQLSVMASPNGNLDEVHEGLLLQAVRASWLAKENRARIDDVVDFLKNASDSEQYAESPTIRSRLDEMIVLLDQYTANGTYGQYFNSDEPSLRDDAKMVVLELGGLEDRPSLLVAVMFSLIIYIENRMYRTPRNLKKLNVIDEGWRLLDFKNHKVGEFIEKGYRTARRHTGAYITITQNIVDFDSDKASSAARAAWGNSSYKIILKQSAKEFAKYNQLYPDQFLPLQRDMIGKFGAAKDQWFSSFLLQVENHSSWHRLFVDPLSRAMYSSDGPDFEFVQQKRKEGLSIHEAVWQLAWKKSGPEMASLEAWLEEHEKYRSVA

View File

@@ -3,6 +3,7 @@ import numpy as np
from alphafold.data import msa_pairing
from alphafold.data import parsers
from alphafold.data import pipeline
from alphapulldown.objects import MonomericObject
from alphapulldown.utils import mmseqs_species_identifiers
@@ -115,3 +116,111 @@ def test_pair_sequences_works_with_mmseqs_accession_species_resolution(
assert paired_rows.shape == (3, 2)
assert tuple(paired_rows[0]) == (0, 0)
assert {tuple(row) for row in paired_rows[1:]} == {(1, 1), (2, 2)}
def test_make_mmseq_features_researches_templates_for_precomputed_msa(
    monkeypatch,
    tmp_path,
):
    """Precomputed MSA plus templates: MSA is reused, templates are re-searched.

    Every collaborator that ``make_mmseq_features`` looks up on
    ``alphapulldown.objects`` is replaced by a recording fake, so the test
    checks only how the results are routed between them.
    """
    import alphapulldown.objects as objects_mod

    # Precomputed alignment on disk; the leading '#' line is expected to be
    # dropped before the text reaches unserialize_msa (asserted below).
    precomputed_a3m = tmp_path / 'dummy.a3m'
    precomputed_a3m.write_text(
        '\n'.join([
            '# header line that should be ignored later',
            '>101',
            'ACDE',
            '',
        ]),
        encoding='utf-8',
    )

    recorded = {}

    def fake_unserialize_msa(a3m_lines, sequence):
        recorded['unserialize_msa'] = {
            'a3m_lines': a3m_lines,
            'sequence': sequence,
        }
        return (
            ['PRECOMPUTED_UNPAIRED'],
            ['PRECOMPUTED_PAIRED'],
            ['PRECOMPUTED_UNIQUE'],
            ['PRECOMPUTED_CARDINALITY'],
            ['PRECOMPUTED_TEMPLATE'],
        )

    def fake_get_msa_and_templates(**kwargs):
        recorded['get_msa_and_templates'] = kwargs
        return (
            ['IGNORED_UNPAIRED'],
            ['IGNORED_PAIRED'],
            ['IGNORED_UNIQUE'],
            ['IGNORED_CARDINALITY'],
            ['TEMPLATE_FROM_RESEARCH'],
        )

    def fake_build_monomer_feature(sequence, msa, template_feature):
        recorded['build_monomer_feature'] = {
            'sequence': sequence,
            'msa': msa,
            'template_feature': template_feature,
        }
        return {
            'template_confidence_scores': None,
            'template_release_date': None,
        }

    def fake_enrich(feature_dict, a3m, **_kwargs):
        recorded['enrich_mmseq_feature_dict_with_identifiers'] = a3m
        feature_dict['msa_species_identifiers'] = np.asarray([b''])
        feature_dict['msa_uniprot_accession_identifiers'] = np.asarray([b''])

    replacements = {
        'unserialize_msa': fake_unserialize_msa,
        'get_msa_and_templates': fake_get_msa_and_templates,
        'build_monomer_feature': fake_build_monomer_feature,
        'enrich_mmseq_feature_dict_with_identifiers': fake_enrich,
    }
    for attribute_name, fake in replacements.items():
        monkeypatch.setattr(objects_mod, attribute_name, fake)

    monomer = MonomericObject('dummy', 'ACDE')
    monomer.make_mmseq_features(
        DEFAULT_API_SERVER='https://fake.server',
        output_dir=str(tmp_path),
        use_precomputed_msa=True,
        use_templates=True,
    )

    # The precomputed alignment is parsed without its comment header.
    msa_call = recorded['unserialize_msa']
    assert msa_call['sequence'] == 'ACDE'
    assert msa_call['a3m_lines'] == ['>101\nACDE']

    # Templates are searched again in single-sequence mode.
    expected_research_kwargs = {
        'jobname': 'dummy',
        'query_sequences': 'ACDE',
        'a3m_lines': False,
        'result_dir': tmp_path,
        'msa_mode': 'single_sequence',
        'use_templates': True,
        'custom_template_path': None,
        'pair_mode': 'none',
        'host_url': 'https://fake.server',
        'user_agent': 'alphapulldown',
    }
    assert recorded['get_msa_and_templates'] == expected_research_kwargs

    # Feature building combines the precomputed MSA with the fresh template.
    expected_build_call = {
        'sequence': 'ACDE',
        'msa': 'PRECOMPUTED_UNPAIRED',
        'template_feature': 'TEMPLATE_FROM_RESEARCH',
    }
    assert recorded['build_monomer_feature'] == expected_build_call
    assert (
        recorded['enrich_mmseq_feature_dict_with_identifiers']
        == 'PRECOMPUTED_UNPAIRED'
    )

    # The None placeholders returned by the fake builder get replaced.
    produced = monomer.feature_dict
    assert isinstance(produced['template_confidence_scores'], np.ndarray)
    assert produced['template_release_date'] == ['none']