mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
618 lines
25 KiB
Python
Executable File
618 lines
25 KiB
Python
Executable File
#!/usr/bin/env python
|
|
"""
|
|
Functional Alphapulldown alphafold2 backend tests.
|
|
Needs GPU(s) to run.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import os
|
|
import json
|
|
import pickle
|
|
import shutil
|
|
import subprocess
|
|
import sys
|
|
import tempfile
|
|
import logging
|
|
import unittest
|
|
import lzma
|
|
from pathlib import Path
|
|
|
|
from absl.testing import absltest, parameterized
|
|
|
|
import alphapulldown
|
|
from alphapulldown_input_parser import generate_fold_specifications
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# configuration / logging #
|
|
# --------------------------------------------------------------------------- #
|
|
logging.basicConfig(level=logging.INFO)
|
|
logger = logging.getLogger(__name__)
|
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
|
TEST_ROOT = REPO_ROOT / "test"
|
|
|
|
DATA_DIR = Path(os.getenv("ALPHAFOLD_DATA_DIR", "/scratch/AlphaFold_DBs/2.3.0"))
|
|
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/scratch/dima/jax_cache"
|
|
#os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=8"
|
|
#os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
|
|
#os.environ["XLA_CLIENT_MEM_FRACTION"] = "0.95"
|
|
#os.environ["JAX_FLASH_ATTENTION_IMPL"] = "xla"
|
|
|
|
#FAST = os.getenv("ALPHAFOLD_FAST", "1") != "0" # <- no difference in performance
|
|
#if FAST:
|
|
# from alphafold.model import config
|
|
# config.CONFIG_MULTIMER.model.embeddings_and_evoformer.evoformer_num_block = 1
|
|
|
|
|
|
def _has_nvidia_gpu() -> bool:
|
|
nvidia_smi = shutil.which("nvidia-smi")
|
|
if not nvidia_smi:
|
|
return False
|
|
try:
|
|
result = subprocess.run(
|
|
[nvidia_smi, "-L"],
|
|
capture_output=True,
|
|
text=True,
|
|
check=False,
|
|
)
|
|
except OSError:
|
|
return False
|
|
return result.returncode == 0 and bool(result.stdout.strip())
|
|
|
|
|
|
def _gpu_functional_test_skip_reason() -> str | None:
|
|
if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
|
return None
|
|
if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv(
|
|
"GITHUB_ACTIONS", ""
|
|
).lower() == "true":
|
|
return (
|
|
"GPU functional tests are disabled on CI/CD. "
|
|
"Set RUN_GPU_FUNCTIONAL_TESTS=1 to override."
|
|
)
|
|
if not _has_nvidia_gpu():
|
|
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
|
|
return None
|
|
|
|
|
|
def _mmseqs_functional_test_skip_reason() -> str | None:
|
|
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
|
return None
|
|
return (
|
|
"MMseqs functional inference tests are disabled by default. "
|
|
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
|
|
)
|
|
|
|
|
|
def _load_feature_dict(feature_path: Path) -> dict:
|
|
opener = lzma.open if feature_path.suffix == ".xz" else open
|
|
with opener(feature_path, "rb") as handle:
|
|
payload = pickle.load(handle)
|
|
if hasattr(payload, "feature_dict"):
|
|
return payload.feature_dict
|
|
return payload
|
|
|
|
|
|
def _non_empty_identifier_count(values) -> int:
|
|
count = 0
|
|
for value in values:
|
|
if isinstance(value, bytes):
|
|
value = value.decode("utf-8")
|
|
if str(value).strip():
|
|
count += 1
|
|
return count
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# common helper mix-in / assertions #
|
|
# --------------------------------------------------------------------------- #
|
|
class _TestBase(parameterized.TestCase):
|
|
use_temp_dir = True
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
skip_reason = _gpu_functional_test_skip_reason()
|
|
if skip_reason:
|
|
raise unittest.SkipTest(skip_reason)
|
|
# do the skip here so import-time doesn't abort discovery
|
|
#if not DATA_DIR.is_dir():
|
|
# cls.skipTest(f"set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests (tried {DATA_DIR!r})")
|
|
|
|
# Create base output dir
|
|
if cls.use_temp_dir:
|
|
cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af2_test_"))
|
|
else:
|
|
cls.base_output_dir = Path("test/test_data/predictions/af2_backend")
|
|
if cls.base_output_dir.exists():
|
|
try:
|
|
shutil.rmtree(cls.base_output_dir)
|
|
except (PermissionError, OSError) as e:
|
|
logger.warning("Could not remove %s: %s", cls.base_output_dir, e)
|
|
cls.base_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super().tearDownClass()
|
|
if cls.use_temp_dir and cls.base_output_dir.exists():
|
|
try:
|
|
shutil.rmtree(cls.base_output_dir)
|
|
except (PermissionError, OSError) as e:
|
|
logger.warning("Could not remove temporary directory %s: %s", cls.base_output_dir, e)
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
|
|
self.test_data_dir = TEST_ROOT / "test_data"
|
|
self.test_features_dir = self.test_data_dir / "features"
|
|
self.test_protein_lists_dir = self.test_data_dir / "protein_lists"
|
|
self.test_modelling_dir = self.test_data_dir / "predictions"
|
|
# setUpClass already resolved this to either a temp root or the legacy shared root
|
|
self.af2_backend_dir = self.base_output_dir
|
|
|
|
test_name = self._testMethodName
|
|
self.output_dir = self.af2_backend_dir / test_name
|
|
self.output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
apd_path = Path(alphapulldown.__path__[0])
|
|
self.script_multimer = apd_path / "scripts" / "run_multimer_jobs.py"
|
|
self.script_single = apd_path / "scripts" / "run_structure_prediction.py"
|
|
self.script_create_features = (
|
|
apd_path / "scripts" / "create_individual_features.py"
|
|
)
|
|
|
|
# ---------------- assertions reused by all subclasses ----------------- #
|
|
def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirname: str | None = None):
|
|
if res.returncode != 0:
|
|
self.fail(
|
|
f"Subprocess failed (code {res.returncode})\n"
|
|
f"STDOUT:\n{res.stdout}\n"
|
|
f"STDERR:\n{res.stderr}"
|
|
)
|
|
|
|
if dirname is not None:
|
|
folders = [self.output_dir / dirname]
|
|
else:
|
|
folders = [d for d in self.output_dir.iterdir() if d.is_dir()]
|
|
|
|
for folder in folders:
|
|
files = list(folder.iterdir())
|
|
|
|
self.assertEqual(
|
|
len([f for f in files if f.name.startswith("ranked") and f.suffix == ".pdb"]),
|
|
5
|
|
)
|
|
|
|
pkls = [f for f in files if f.name.startswith("result") and f.suffix == ".pkl"]
|
|
self.assertEqual(len(pkls), 5)
|
|
|
|
example = pickle.load(pkls[0].open("rb"))
|
|
keys_multimer = {
|
|
"experimentally_resolved",
|
|
"predicted_aligned_error",
|
|
"predicted_lddt",
|
|
"structure_module",
|
|
"plddt",
|
|
"max_predicted_aligned_error",
|
|
"seqs",
|
|
"iptm",
|
|
"ptm",
|
|
"ranking_confidence",
|
|
}
|
|
keys_monomer = keys_multimer - {"iptm"}
|
|
expected_keys = keys_multimer if multimer else keys_monomer
|
|
self.assertTrue(expected_keys <= example.keys())
|
|
|
|
self.assertEqual(len([f for f in files if f.name.startswith("pae") and f.suffix == ".json"]), 5)
|
|
self.assertEqual(len([f for f in files if f.suffix == ".png"]), 5)
|
|
names = {f.name for f in files}
|
|
self.assertIn("ranking_debug.json", names)
|
|
self.assertIn("timings.json", names)
|
|
|
|
ranking = json.loads((folder / "ranking_debug.json").read_text())
|
|
self.assertEqual(len(ranking["order"]), 5)
|
|
|
|
def _args(self, *, plist, mode, script):
|
|
if script.endswith("run_multimer_jobs.py"):
|
|
return [
|
|
sys.executable,
|
|
str(self.script_multimer),
|
|
"--num_cycle=1",
|
|
"--num_predictions_per_model=1",
|
|
f"--data_dir={DATA_DIR}",
|
|
f"--monomer_objects_dir={self.test_features_dir}",
|
|
"--job_index=1",
|
|
f"--output_path={self.output_dir}",
|
|
f"--mode={mode}",
|
|
(
|
|
"--oligomer_state_file"
|
|
if mode == "homo-oligomer"
|
|
else "--protein_lists"
|
|
) + f"={self.test_protein_lists_dir / plist}",
|
|
]
|
|
else:
|
|
specifications = generate_fold_specifications(
|
|
input_files=[str(self.test_protein_lists_dir / plist)],
|
|
delimiter="+",
|
|
exclude_permutations=True,
|
|
)
|
|
lines = [
|
|
spec.replace(",", ":").replace(";", "+")
|
|
for spec in specifications if spec.strip()
|
|
]
|
|
formatted_input = lines[0] if lines else ""
|
|
return [
|
|
sys.executable,
|
|
str(self.script_single),
|
|
f"--input={formatted_input}",
|
|
f"--output_directory={self.output_dir}",
|
|
"--num_cycle=1",
|
|
"--num_predictions_per_model=1",
|
|
f"--data_directory={DATA_DIR}",
|
|
f"--features_directory={self.test_features_dir}",
|
|
]
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# parameterised “run mode” tests #
|
|
# --------------------------------------------------------------------------- #
|
|
class TestRunModes(_TestBase):
|
|
@parameterized.named_parameters(
|
|
dict(testcase_name="monomer", protein_list="test_monomer.txt", mode="custom", script="run_multimer_jobs.py"),
|
|
dict(testcase_name="dimer", protein_list="test_dimer.txt", mode="custom", script="run_multimer_jobs.py"),
|
|
dict(testcase_name="trimer", protein_list="test_trimer.txt", mode="custom", script="run_multimer_jobs.py"),
|
|
dict(testcase_name="homo_oligomer", protein_list="test_homooligomer.txt", mode="homo-oligomer", script="run_multimer_jobs.py"),
|
|
dict(testcase_name="chopped_dimer", protein_list="test_dimer_chopped.txt", mode="custom", script="run_multimer_jobs.py"),
|
|
dict(testcase_name="long_name", protein_list="test_long_name.txt", mode="custom", script="run_structure_prediction.py"),
|
|
)
|
|
def test_(self, protein_list, mode, script):
|
|
multimer = "monomer" not in protein_list
|
|
res = subprocess.run(
|
|
self._args(plist=protein_list, mode=mode, script=script),
|
|
capture_output=True, text=True
|
|
)
|
|
self._runCommonTests(res, multimer)
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# parameterised “resume” / relaxation tests #
|
|
# --------------------------------------------------------------------------- #
|
|
class TestResume(_TestBase):
|
|
def setUp(self):
|
|
super().setUp()
|
|
self.protein_lists = self.test_protein_lists_dir / "test_dimer.txt"
|
|
# Resume tests need a pre-populated per-test output tree to continue from.
|
|
source = self.test_modelling_dir / "TEST_homo_2er"
|
|
target = self.output_dir / "TEST_homo_2er"
|
|
shutil.copytree(source, target, dirs_exist_ok=True)
|
|
|
|
self.base_args = [
|
|
sys.executable,
|
|
str(self.script_multimer),
|
|
"--mode=custom",
|
|
"--num_cycle=1",
|
|
"--num_predictions_per_model=1",
|
|
f"--data_dir={DATA_DIR}",
|
|
f"--protein_lists={self.protein_lists}",
|
|
f"--monomer_objects_dir={self.test_features_dir}",
|
|
"--job_index=1",
|
|
f"--output_path={self.output_dir}",
|
|
]
|
|
|
|
def _runAfterRelaxTests(self, relax_mode="All"):
|
|
expected = {"None": 0, "Best": 1, "All": 5}[relax_mode]
|
|
d = self.output_dir / "TEST_homo_2er"
|
|
got = len([f for f in d.iterdir() if f.name.startswith("relaxed") and f.suffix == ".pdb"])
|
|
self.assertEqual(got, expected)
|
|
|
|
@parameterized.named_parameters(
|
|
dict(
|
|
testcase_name="no_relax",
|
|
relax_mode="None",
|
|
remove=[
|
|
"relaxed_model_1_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_2_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_3_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_4_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_5_multimer_v3_pred_0.pdb",
|
|
],
|
|
),
|
|
dict(
|
|
testcase_name="relax_all",
|
|
relax_mode="All",
|
|
remove=[
|
|
"relaxed_model_1_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_2_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_3_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_4_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_5_multimer_v3_pred_0.pdb",
|
|
],
|
|
),
|
|
dict(
|
|
testcase_name="continue_relax",
|
|
relax_mode="All",
|
|
remove=["relaxed_model_5_multimer_v3_pred_0.pdb"],
|
|
),
|
|
dict(
|
|
testcase_name="continue_prediction",
|
|
relax_mode="Best",
|
|
remove=[
|
|
"unrelaxed_model_5_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_1_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_2_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_3_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_4_multimer_v3_pred_0.pdb",
|
|
"relaxed_model_5_multimer_v3_pred_0.pdb",
|
|
],
|
|
),
|
|
)
|
|
def test_(self, relax_mode, remove):
|
|
args = self.base_args + [f"--models_to_relax={relax_mode}"]
|
|
for fname in remove:
|
|
try:
|
|
(self.output_dir / "TEST_homo_2er" / fname).unlink()
|
|
except FileNotFoundError:
|
|
pass
|
|
res = subprocess.run(args, capture_output=True, text=True)
|
|
self._runCommonTests(res, multimer=True, dirname="TEST_homo_2er")
|
|
self._runAfterRelaxTests(relax_mode)
|
|
|
|
|
|
def _parse_test_args():
|
|
use_temp = '--use-temp-dir' in sys.argv or __import__("os").getenv('USE_TEMP_DIR', '').lower() in ('1','true','yes')
|
|
while '--use-temp-dir' in sys.argv:
|
|
sys.argv.remove('--use-temp-dir')
|
|
return use_temp
|
|
|
|
_TestBase.use_temp_dir = _parse_test_args()
|
|
|
|
|
|
# --------------------------------------------------------------------------- #
|
|
# dropout diversity tests #
|
|
# --------------------------------------------------------------------------- #
|
|
class TestDropoutDiversity(_TestBase):
|
|
"""Test that dropout flag generates more diverse models."""
|
|
|
|
def setUp(self):
|
|
super().setUp()
|
|
# Use dimer because for monomer we can't use num_predictions_per_model
|
|
self.protein_lists = self.test_protein_lists_dir / "test_dropout.txt"
|
|
|
|
def test_dropout_increases_diversity(self):
|
|
"""Test that using --dropout flag increases diversity between predictions."""
|
|
|
|
# Create separate output directories for with/without dropout
|
|
dropout_output_dir = self.output_dir / "dropout_test"
|
|
no_dropout_output_dir = self.output_dir / "no_dropout_test"
|
|
dropout_output_dir.mkdir(parents=True, exist_ok=True)
|
|
no_dropout_output_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
# Use simple test input
|
|
specifications = generate_fold_specifications(
|
|
input_files=[str(self.protein_lists)],
|
|
delimiter="+",
|
|
exclude_permutations=True,
|
|
)
|
|
lines = [
|
|
spec.replace(",", ":").replace(";", "+")
|
|
for spec in specifications if spec.strip()
|
|
]
|
|
formatted_input = lines[0] if lines else ""
|
|
|
|
# Base arguments for both runs
|
|
base_args = [
|
|
sys.executable,
|
|
str(self.script_single),
|
|
f"--input={formatted_input}",
|
|
"--num_cycle=1",
|
|
"--num_predictions_per_model=2", # Run 2 predictions to compare
|
|
f"--data_directory={DATA_DIR}",
|
|
f"--features_directory={self.test_features_dir}",
|
|
"--random_seed=42", # Fixed seed for reproducibility
|
|
"--model_names=model_2_multimer_v3",
|
|
]
|
|
|
|
# Run prediction without dropout
|
|
args_no_dropout = base_args + [f"--output_directory={no_dropout_output_dir}"]
|
|
|
|
# Run prediction with dropout
|
|
args_with_dropout = base_args + [
|
|
f"--output_directory={dropout_output_dir}",
|
|
"--dropout"
|
|
]
|
|
|
|
# Execute both predictions
|
|
logger.info("Running prediction without dropout...")
|
|
#logger.info("".join(args_no_dropout))
|
|
res_no_dropout = subprocess.run(args_no_dropout, capture_output=True, text=True)
|
|
self.assertEqual(res_no_dropout.returncode, 0,
|
|
f"No dropout prediction failed: {res_no_dropout.stderr}")
|
|
|
|
logger.info("Running prediction with dropout...")
|
|
res_with_dropout = subprocess.run(args_with_dropout, capture_output=True, text=True)
|
|
self.assertEqual(res_with_dropout.returncode, 0,
|
|
f"Dropout prediction failed: {res_with_dropout.stderr}")
|
|
|
|
# Find the generated PDB files
|
|
no_dropout_pdbs = sorted(list(no_dropout_output_dir.glob("**/unrelaxed_*.pdb")))
|
|
dropout_pdbs = sorted(list(dropout_output_dir.glob("**/unrelaxed_*.pdb")))
|
|
|
|
self.assertGreaterEqual(len(no_dropout_pdbs), 2, "Need at least 2 PDB files for no-dropout prediction")
|
|
self.assertGreaterEqual(len(dropout_pdbs), 2, "Need at least 2 PDB files for dropout prediction")
|
|
|
|
# Calculate RMSD between corresponding models
|
|
from alphapulldown.utils.calculate_rmsd import calculate_rmsd_and_superpose
|
|
|
|
# Create a temporary directory for RMSD calculation output
|
|
with tempfile.TemporaryDirectory() as temp_dir:
|
|
# Calculate RMSD between first and second prediction without dropout
|
|
rmsd_no_dropout = calculate_rmsd_and_superpose(
|
|
str(no_dropout_pdbs[0]), str(no_dropout_pdbs[1]), temp_dir
|
|
)
|
|
|
|
# Calculate RMSD between first and second prediction with dropout
|
|
rmsd_with_dropout = calculate_rmsd_and_superpose(
|
|
str(dropout_pdbs[0]), str(dropout_pdbs[1]), temp_dir
|
|
)
|
|
|
|
logger.info(f"RMSD without dropout (between pred_0 and pred_1): {rmsd_no_dropout:.4f}")
|
|
logger.info(f"RMSD with dropout (between pred_0 and pred_1): {rmsd_with_dropout:.4f}")
|
|
|
|
# Verify that dropout increases diversity (higher RMSD)
|
|
# Note: Due to randomness, this may not always be true, but it should be true on average
|
|
# For a robust test, we check that both calculations succeed and produce reasonable values
|
|
self.assertIsNotNone(rmsd_no_dropout, "RMSD calculation failed for no-dropout case")
|
|
self.assertIsNotNone(rmsd_with_dropout, "RMSD calculation failed for dropout case")
|
|
self.assertGreater(rmsd_no_dropout, 0, "RMSD should be positive for no-dropout case")
|
|
self.assertGreater(rmsd_with_dropout, 0, "RMSD should be positive for dropout case")
|
|
|
|
# Log the comparison result
|
|
if rmsd_with_dropout > rmsd_no_dropout:
|
|
logger.info("✓ Dropout increased structural diversity as expected")
|
|
else:
|
|
logger.info("⚠ Dropout did not increase diversity in this run (this can happen due to randomness)")
|
|
|
|
# The test passes if calculations succeed - the diversity check is informational
|
|
|
|
|
|
class TestMmseqsIssue588Inference(_TestBase):
|
|
"""Opt-in end-to-end regression for freshly generated mmseq AF2 features."""
|
|
|
|
ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")
|
|
|
|
def _require_mmseqs_functional_environment(self) -> None:
|
|
skip_reason = _mmseqs_functional_test_skip_reason()
|
|
if skip_reason:
|
|
self.skipTest(skip_reason)
|
|
for protein_id in self.ISSUE_588_IDS:
|
|
fasta_path = self.test_data_dir / "fastas" / f"{protein_id}.fasta"
|
|
self.assertTrue(
|
|
fasta_path.is_file(),
|
|
f"Missing FASTA fixture {fasta_path}",
|
|
)
|
|
|
|
def _generate_issue_588_mmseq_features(self) -> Path:
|
|
feature_dir = self.output_dir / "issue_588_mmseq_features"
|
|
feature_dir.mkdir(parents=True, exist_ok=True)
|
|
fasta_paths = ",".join(
|
|
str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
|
|
for protein_id in self.ISSUE_588_IDS
|
|
)
|
|
args = [
|
|
sys.executable,
|
|
str(self.script_create_features),
|
|
f"--fasta_paths={fasta_paths}",
|
|
f"--output_dir={feature_dir}",
|
|
f"--data_dir={DATA_DIR}",
|
|
"--max_template_date=2024-05-02",
|
|
"--use_mmseqs2=True",
|
|
"--data_pipeline=alphafold2",
|
|
"--compress_features=True",
|
|
"--skip_existing=False",
|
|
]
|
|
res = subprocess.run(args, capture_output=True, text=True)
|
|
self.assertEqual(
|
|
res.returncode,
|
|
0,
|
|
f"MMseqs feature generation failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
|
|
)
|
|
return feature_dir
|
|
|
|
def _resolve_af2_result_dir(self, root: Path) -> Path:
|
|
if (root / "ranking_debug.json").exists():
|
|
return root
|
|
candidates = sorted(
|
|
path.parent for path in root.rglob("ranking_debug.json")
|
|
)
|
|
self.assertEqual(
|
|
len(candidates),
|
|
1,
|
|
f"Expected one AF2 result directory under {root}, found {candidates}",
|
|
)
|
|
return candidates[0]
|
|
|
|
def test_issue_588_mmseqs_generated_features_enable_af2_multimer_inference(self):
|
|
from alphafold.data import feature_processing
|
|
from alphafold.data import msa_pairing
|
|
from alphafold.data import pipeline_multimer
|
|
|
|
self._require_mmseqs_functional_environment()
|
|
feature_dir = self._generate_issue_588_mmseq_features()
|
|
|
|
converted_chains = {}
|
|
for chain_id, protein_id in zip(("A", "B"), self.ISSUE_588_IDS):
|
|
feature_path = feature_dir / f"{protein_id}.pkl.xz"
|
|
feature_dict = _load_feature_dict(feature_path)
|
|
self.assertGreater(
|
|
_non_empty_identifier_count(
|
|
feature_dict["msa_species_identifiers_all_seq"]
|
|
),
|
|
0,
|
|
f"{protein_id} should keep recovered species IDs in msa_species_identifiers_all_seq",
|
|
)
|
|
self.assertGreater(
|
|
_non_empty_identifier_count(
|
|
feature_dict["msa_uniprot_accession_identifiers_all_seq"]
|
|
),
|
|
0,
|
|
f"{protein_id} should keep recovered accession IDs in msa_uniprot_accession_identifiers_all_seq",
|
|
)
|
|
converted_chains[chain_id] = pipeline_multimer.convert_monomer_features(
|
|
feature_dict,
|
|
chain_id,
|
|
)
|
|
|
|
assembly_features = pipeline_multimer.add_assembly_features(converted_chains)
|
|
feature_processing.process_unmerged_features(assembly_features)
|
|
np_chains = list(assembly_features.values())
|
|
paired_row_groups = msa_pairing.pair_sequences(np_chains)
|
|
paired_rows = msa_pairing.reorder_paired_rows(paired_row_groups)
|
|
self.assertGreater(
|
|
paired_rows.shape[0],
|
|
1,
|
|
"Fresh mmseq AF2 features should produce paired rows beyond the query",
|
|
)
|
|
|
|
prediction_dir = self.output_dir / "af2_prediction"
|
|
prediction_dir.mkdir(parents=True, exist_ok=True)
|
|
res = subprocess.run(
|
|
[
|
|
sys.executable,
|
|
str(self.script_single),
|
|
"--input=A0ABD7FQG0+P18004",
|
|
f"--output_directory={prediction_dir}",
|
|
"--num_cycle=1",
|
|
"--num_predictions_per_model=1",
|
|
"--model_names=model_4_multimer_v3",
|
|
f"--data_directory={DATA_DIR}",
|
|
f"--features_directory={feature_dir}",
|
|
"--random_seed=42",
|
|
],
|
|
capture_output=True,
|
|
text=True,
|
|
)
|
|
self.assertEqual(
|
|
res.returncode,
|
|
0,
|
|
f"AF2 inference failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
|
|
)
|
|
|
|
result_dir = self._resolve_af2_result_dir(prediction_dir)
|
|
ranking_payload = json.loads(
|
|
(result_dir / "ranking_debug.json").read_text(encoding="utf-8")
|
|
)
|
|
self.assertTrue(ranking_payload["order"])
|
|
|
|
result_pickles = sorted(result_dir.glob("result_*.pkl"))
|
|
self.assertLen(result_pickles, 1)
|
|
with result_pickles[0].open("rb") as handle:
|
|
result_payload = pickle.load(handle)
|
|
self.assertIn("iptm", result_payload)
|
|
self.assertIn("ranking_confidence", result_payload)
|
|
self.assertGreater(
|
|
result_payload["iptm"],
|
|
0.6,
|
|
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
absltest.main()
|