mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Add MMseqs regression coverage for pairing and template reuse
Add opt-in AF2 and AF3 inference regressions that regenerate fresh MMseqs-derived AF2 features for the issue #588 A0ABD7FQG0/P18004 reproducer, verify recovered species identifiers, and exercise the existing wrapper entrypoints.

Add FASTA fixtures for that reproducer so the tests do not depend on a live UniProt download.

Add focused coverage for the MMseqs precomputed-MSA plus template re-search branch in make_mmseq_features(), and relax the outdated AF3 wrapper so it can submit Slurm jobs from login nodes without requiring a local GPU.
This commit is contained in:
@@ -14,6 +14,7 @@ import sys
|
||||
import tempfile
|
||||
import logging
|
||||
import unittest
|
||||
import lzma
|
||||
from pathlib import Path
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
@@ -70,6 +71,34 @@ def _gpu_functional_test_skip_reason() -> str | None:
|
||||
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
|
||||
return None
|
||||
|
||||
|
||||
def _mmseqs_functional_test_skip_reason() -> str | None:
|
||||
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
||||
return None
|
||||
return (
|
||||
"MMseqs functional inference tests are disabled by default. "
|
||||
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
|
||||
)
|
||||
|
||||
|
||||
def _load_feature_dict(feature_path: Path) -> dict:
|
||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||
with opener(feature_path, "rb") as handle:
|
||||
payload = pickle.load(handle)
|
||||
if hasattr(payload, "feature_dict"):
|
||||
return payload.feature_dict
|
||||
return payload
|
||||
|
||||
|
||||
def _non_empty_identifier_count(values) -> int:
|
||||
count = 0
|
||||
for value in values:
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
if str(value).strip():
|
||||
count += 1
|
||||
return count
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -125,6 +154,9 @@ class _TestBase(parameterized.TestCase):
|
||||
apd_path = Path(alphapulldown.__path__[0])
|
||||
self.script_multimer = apd_path / "scripts" / "run_multimer_jobs.py"
|
||||
self.script_single = apd_path / "scripts" / "run_structure_prediction.py"
|
||||
self.script_create_features = (
|
||||
apd_path / "scripts" / "create_individual_features.py"
|
||||
)
|
||||
|
||||
# ---------------- assertions reused by all subclasses ----------------- #
|
||||
def _runCommonTests(self, res: subprocess.CompletedProcess, multimer: bool, dirname: str | None = None):
|
||||
@@ -439,5 +471,140 @@ class TestDropoutDiversity(_TestBase):
|
||||
|
||||
# The test passes if calculations succeed - the diversity check is informational
|
||||
|
||||
|
||||
class TestMmseqsIssue588Inference(_TestBase):
    """Opt-in end-to-end regression for freshly generated mmseq AF2 features."""

    # Identifiers of the two chains used by this regression; a FASTA fixture
    # named after each one is expected in the "fastas" test-data directory.
    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")

    def _require_mmseqs_functional_environment(self) -> None:
        """Skip unless the MMseqs functional tests are enabled, then check fixtures.

        Skips the running test when ``_mmseqs_functional_test_skip_reason()``
        returns a message. Otherwise fails the test if the FASTA fixture of any
        entry in ``ISSUE_588_IDS`` is missing.
        """
        skip_reason = _mmseqs_functional_test_skip_reason()
        if skip_reason:
            self.skipTest(skip_reason)
        for protein_id in self.ISSUE_588_IDS:
            fasta_path = self.test_data_dir / "fastas" / f"{protein_id}.fasta"
            self.assertTrue(
                fasta_path.is_file(),
                f"Missing FASTA fixture {fasta_path}",
            )

    def _generate_issue_588_mmseq_features(self) -> Path:
        """Run the feature-creation script with MMseqs2 enabled on the fixtures.

        The script is launched as a subprocess with compressed output and
        ``--skip_existing=False``; the test fails with the captured
        stdout/stderr if it exits non-zero.

        Returns:
            The directory the features were written to.
        """
        feature_dir = self.output_dir / "issue_588_mmseq_features"
        feature_dir.mkdir(parents=True, exist_ok=True)
        fasta_paths = ",".join(
            str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
            for protein_id in self.ISSUE_588_IDS
        )
        args = [
            sys.executable,
            str(self.script_create_features),
            f"--fasta_paths={fasta_paths}",
            f"--output_dir={feature_dir}",
            f"--data_dir={DATA_DIR}",
            "--max_template_date=2024-05-02",
            "--use_mmseqs2=True",
            "--data_pipeline=alphafold2",
            "--compress_features=True",
            "--skip_existing=False",
        ]
        res = subprocess.run(args, capture_output=True, text=True)
        self.assertEqual(
            res.returncode,
            0,
            f"MMseqs feature generation failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )
        return feature_dir

    def _resolve_af2_result_dir(self, root: Path) -> Path:
        """Locate the directory that holds ``ranking_debug.json``.

        Returns ``root`` when the file sits directly in it. Otherwise searches
        ``root`` recursively and fails the test unless exactly one directory
        contains the file.
        """
        if (root / "ranking_debug.json").exists():
            return root
        candidates = sorted(
            path.parent for path in root.rglob("ranking_debug.json")
        )
        self.assertEqual(
            len(candidates),
            1,
            f"Expected one AF2 result directory under {root}, found {candidates}",
        )
        return candidates[0]

    def test_issue_588_mmseqs_generated_features_enable_af2_multimer_inference(self):
        # Decide whether to skip *before* importing the AlphaFold data modules,
        # so a disabled (default) run neither pays for nor fails on them.
        self._require_mmseqs_functional_environment()

        from alphafold.data import feature_processing
        from alphafold.data import msa_pairing
        from alphafold.data import pipeline_multimer

        feature_dir = self._generate_issue_588_mmseq_features()

        # Each regenerated monomer must carry non-empty species and accession
        # identifiers before it is converted into a multimer chain.
        converted_chains = {}
        for chain_id, protein_id in zip(("A", "B"), self.ISSUE_588_IDS):
            feature_path = feature_dir / f"{protein_id}.pkl.xz"
            feature_dict = _load_feature_dict(feature_path)
            self.assertGreater(
                _non_empty_identifier_count(
                    feature_dict["msa_species_identifiers_all_seq"]
                ),
                0,
                f"{protein_id} should keep recovered species IDs in msa_species_identifiers_all_seq",
            )
            self.assertGreater(
                _non_empty_identifier_count(
                    feature_dict["msa_uniprot_accession_identifiers_all_seq"]
                ),
                0,
                f"{protein_id} should keep recovered accession IDs in msa_uniprot_accession_identifiers_all_seq",
            )
            converted_chains[chain_id] = pipeline_multimer.convert_monomer_features(
                feature_dict,
                chain_id,
            )

        assembly_features = pipeline_multimer.add_assembly_features(converted_chains)
        np_chains = list(assembly_features.values())
        feature_processing.process_unmerged_features(np_chains)
        # NOTE(review): assumes pair_sequences() returns an array-like with a
        # ``.shape`` in the alphafold version bundled with this project - confirm.
        paired_rows = msa_pairing.pair_sequences(np_chains)
        self.assertGreater(
            paired_rows.shape[0],
            1,
            "Fresh mmseq AF2 features should produce paired rows beyond the query",
        )

        prediction_dir = self.output_dir / "af2_prediction"
        prediction_dir.mkdir(parents=True, exist_ok=True)
        res = subprocess.run(
            [
                sys.executable,
                str(self.script_single),
                # Same value as the literal "A0ABD7FQG0+P18004", kept in sync
                # with the identifiers used to generate the features.
                "--input=" + "+".join(self.ISSUE_588_IDS),
                f"--output_directory={prediction_dir}",
                "--num_cycle=1",
                "--num_predictions_per_model=1",
                "--model_names=model_4_multimer_v3",
                f"--data_directory={DATA_DIR}",
                f"--features_directory={feature_dir}",
                "--random_seed=42",
            ],
            capture_output=True,
            text=True,
        )
        self.assertEqual(
            res.returncode,
            0,
            f"AF2 inference failed.\nSTDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
        )

        result_dir = self._resolve_af2_result_dir(prediction_dir)
        ranking_payload = json.loads(
            (result_dir / "ranking_debug.json").read_text(encoding="utf-8")
        )
        self.assertTrue(ranking_payload["order"])

        # One model and one prediction per model were requested, so exactly
        # one result pickle is expected.
        result_pickles = sorted(result_dir.glob("result_*.pkl"))
        self.assertLen(result_pickles, 1)
        with result_pickles[0].open("rb") as handle:
            result_payload = pickle.load(handle)
        self.assertIn("iptm", result_payload)
        self.assertIn("ranking_confidence", result_payload)
|
||||
|
||||
# Hand control to the absltest runner when this module is executed directly.
if __name__ == "__main__":
    absltest.main()
|
||||
|
||||
@@ -78,6 +78,15 @@ def _gpu_functional_test_skip_reason() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def _mmseqs_functional_test_skip_reason() -> str | None:
|
||||
if os.getenv("RUN_MMSEQS_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
||||
return None
|
||||
return (
|
||||
"MMseqs functional inference tests are disabled by default. "
|
||||
"Set RUN_MMSEQS_FUNCTIONAL_TESTS=1 to enable."
|
||||
)
|
||||
|
||||
|
||||
def _a3m_sequences(a3m_text: str) -> list[str]:
|
||||
if not a3m_text:
|
||||
return []
|
||||
@@ -139,6 +148,25 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
|
||||
return _a3m_payload_sequences(a3m_text) if a3m_text else []
|
||||
|
||||
|
||||
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
|
||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||
with opener(feature_path, "rb") as handle:
|
||||
payload = pickle.load(handle)
|
||||
if hasattr(payload, "feature_dict"):
|
||||
return payload.feature_dict
|
||||
return payload
|
||||
|
||||
|
||||
def _non_empty_identifier_count(values) -> int:
|
||||
count = 0
|
||||
for value in values:
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
if str(value).strip():
|
||||
count += 1
|
||||
return count
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -185,6 +213,9 @@ class _TestBase(parameterized.TestCase):
|
||||
apd_path = Path(alphapulldown.__path__[0])
|
||||
self.script_multimer = apd_path / "scripts" / "run_multimer_jobs.py"
|
||||
self.script_single = apd_path / "scripts" / "run_structure_prediction.py"
|
||||
self.script_create_features = (
|
||||
apd_path / "scripts" / "create_individual_features.py"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
@@ -1438,6 +1469,124 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
)
|
||||
|
||||
|
||||
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
    """Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""

    # Identifiers of the two chains used by this regression; a FASTA fixture
    # named after each one is expected in the "fastas" test-data directory.
    ISSUE_588_IDS = ("A0ABD7FQG0", "P18004")

    def _require_mmseqs_functional_environment(self) -> None:
        """Apply the AF3 and MMseqs opt-in gates, then check the FASTA fixtures.

        Runs the AF3 environment requirement first, skips the test when the
        MMseqs functional tests are disabled, and finally fails the test if a
        fixture for any entry of ``ISSUE_588_IDS`` is missing.
        """
        self._require_af3_functional_environment()
        reason = _mmseqs_functional_test_skip_reason()
        if reason:
            self.skipTest(reason)
        for accession in self.ISSUE_588_IDS:
            fixture = self.test_data_dir / "fastas" / f"{accession}.fasta"
            self.assertTrue(fixture.is_file(), f"Missing FASTA fixture {fixture}")

    def _generate_issue_588_mmseq_features(self, env: Dict[str, str]) -> Path:
        """Run the feature-creation script with MMseqs2 enabled on the fixtures.

        Args:
            env: Environment passed to the feature-creation subprocess.

        Returns:
            The directory the features were written to. The test fails with
            the captured stdout/stderr if the subprocess exits non-zero.
        """
        out_dir = self.output_dir / "issue_588_mmseq_features"
        out_dir.mkdir(parents=True, exist_ok=True)

        fixture_files = [
            str(self.test_data_dir / "fastas" / f"{accession}.fasta")
            for accession in self.ISSUE_588_IDS
        ]
        joined_fastas = ",".join(fixture_files)

        command = [
            sys.executable,
            str(self.script_create_features),
            f"--fasta_paths={joined_fastas}",
            f"--output_dir={out_dir}",
            f"--data_dir={DATA_DIR}",
            "--max_template_date=2024-05-02",
            "--use_mmseqs2=True",
            "--data_pipeline=alphafold2",
            "--compress_features=True",
            "--skip_existing=False",
        ]
        completed = subprocess.run(command, capture_output=True, text=True, env=env)
        failure_message = (
            f"MMseqs feature generation failed.\nSTDOUT:\n{completed.stdout}\nSTDERR:\n{completed.stderr}"
        )
        self.assertEqual(completed.returncode, 0, failure_message)
        return out_dir

    def test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference(self):
        self._require_mmseqs_functional_environment()
        env = self._make_af3_test_env()
        feature_dir = self._generate_issue_588_mmseq_features(env)

        # Every regenerated monomer must still carry non-empty species and
        # accession identifiers.
        identifier_checks = (
            ("species", "msa_species_identifiers_all_seq"),
            ("accession", "msa_uniprot_accession_identifiers_all_seq"),
        )
        for protein_id in self.ISSUE_588_IDS:
            feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
            for label, feature_key in identifier_checks:
                populated = _non_empty_identifier_count(feature_dict[feature_key])
                self.assertGreater(
                    populated,
                    0,
                    f"{protein_id} should keep recovered {label} IDs in {feature_key}",
                )

        flash_impl = self._af3_flash_attention_impl()
        command = [
            sys.executable,
            str(self.script_single),
            "--input=A0ABD7FQG0+P18004",
            f"--output_directory={self.output_dir}",
            f"--data_directory={DATA_DIR}",
            f"--features_directory={feature_dir}",
            "--fold_backend=alphafold3",
            f"--flash_attention_implementation={flash_impl}",
            "--num_diffusion_samples=1",
            "--random_seed=42",
            "--debug_msas",
        ]
        res = subprocess.run(command, capture_output=True, text=True, env=env)
        self._runCommonTests(res)

        result_dir = self._resolve_single_af3_result_dir()

        # The AF2-to-AF3 translation summary must report species pairing and
        # non-empty paired/unpaired MSAs for every chain.
        summary_paths = sorted(
            result_dir.glob("*_af2_to_af3_translation_summary.json")
        )
        self.assertLen(summary_paths, 1)
        summary_text = summary_paths[0].read_text(encoding="utf-8")
        summary = json.loads(summary_text)
        expected_modes = ["af3_species_pairing_from_af2_individual_msas"]
        self.assertEqual(summary["translation_modes"], expected_modes)
        self.assertTrue(summary["paired_rows_valid"])
        self.assertTrue(summary["unpaired_rows_valid"])
        positive_counters = (
            "paired_msa_row_count",
            "unpaired_msa_row_count",
            "paired_species_identifier_count",
        )
        for chain_summary in summary["chains"]:
            for counter in positive_counters:
                self.assertGreater(chain_summary[counter], 0)

        # Exactly one confidence summary is expected, with ipTM inside [0, 1].
        confidence_files = sorted(result_dir.glob("*_summary_confidences.json"))
        self.assertLen(confidence_files, 1)
        confidence_text = confidence_files[0].read_text(encoding="utf-8")
        confidence_payload = json.loads(confidence_text)
        self.assertIn("iptm", confidence_payload)
        self.assertGreaterEqual(confidence_payload["iptm"], 0.0)
        self.assertLessEqual(confidence_payload["iptm"], 1.0)
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# parameterised "run mode" tests #
|
||||
# --------------------------------------------------------------------------- #
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
|
||||
# ---------- per-test set-up ------------------------------------------- #
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not _has_gpu():
|
||||
if not _has_gpu() and not self._is_slurm_available():
|
||||
self.skipTest("NVIDIA GPU not detected – skipping Alphafold3 tests")
|
||||
|
||||
# Check for correct conda environment
|
||||
@@ -239,7 +239,7 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
|
||||
log_path = self.case_dir / f"test_{idx}_{cls_name}_{test_name}.log"
|
||||
|
||||
res = subprocess.run(
|
||||
["sbatch", f"--output={log_path}", str(script_path)],
|
||||
["sbatch", "--export=ALL", f"--output={log_path}", str(script_path)],
|
||||
text=True,
|
||||
capture_output=True,
|
||||
check=True,
|
||||
@@ -263,6 +263,12 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
|
||||
{"testcase_name": "homo_oligomer", "i": 5, "cls": "TestAlphaFold3RunModes", "test": "test__homo_oligomer"},
|
||||
{"testcase_name": "chopped_dimer", "i": 6, "cls": "TestAlphaFold3RunModes", "test": "test__chopped_dimer"},
|
||||
{"testcase_name": "long_name", "i": 7, "cls": "TestAlphaFold3RunModes", "test": "test__long_name"},
|
||||
{
|
||||
"testcase_name": "issue_588_mmseqs_inference",
|
||||
"i": 8,
|
||||
"cls": "TestAlphaFold3MmseqsIssue588Inference",
|
||||
"test": "test_issue_588_mmseqs_af2_features_enable_af3_species_pairing_inference",
|
||||
},
|
||||
)
|
||||
def test_predict_structure(self, i: int, cls: str, test: str):
|
||||
"""Route each parameterised test either through Slurm or local run."""
|
||||
@@ -277,4 +283,4 @@ class TestAlphaFold3PredictStructure(parameterized.TestCase):
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main()
|
||||
|
||||
2
test/test_data/fastas/A0ABD7FQG0.fasta
Normal file
2
test/test_data/fastas/A0ABD7FQG0.fasta
Normal file
@@ -0,0 +1,2 @@
|
||||
>A0ABD7FQG0
|
||||
MSGDENKLKKYRFPETLTNQSRWFGLPLDELIPAAICIGWGITTSKYLFGIGAAVLVYFGIKKLKKGRGSSWLRDLIYWYMPTALLRGIFHNVPDSCFRQWIK
|
||||
2
test/test_data/fastas/P18004.fasta
Normal file
2
test/test_data/fastas/P18004.fasta
Normal file
@@ -0,0 +1,2 @@
|
||||
>P18004
|
||||
MNNPLEAVTQAVNSLVTALKLPDESAKANEVLGEMSFPQFSRLLPYRDYNQESGLFMNDTTMGFMLEAIPINGANESIVEALDHMLRTKLPRGIPLCIHLMSSQLVGDRIEYGLREFSWSGEQAERFNAITRAYYMKAAATQFPLPEGMNLPLTLRHYRVFISYCSPSKKKSRADILEMENLVKIIRASLQGASITTQTVDAQAFIDIVGEMINHNPDSLYPKRRQLDPYSDLNYQCVEDSFDLKVRADYLTLGLRENGRNSTARILNFHLARNPEIAFLWNMADNYSNLLNPELSISCPFILTLTLVVEDQVKTHSEANLKYMDLEKKSKTSYAKWFPSVEKEAKEWGELRQRLGSGQSSVVSYFLNITAFCKDNNETALEVEQDILNSFRKNGFELISPRFNHMRNFLTCLPFMAGKGLFKQLKEAGVVQRAESFNVANLMPLVADNPLTPAGLLAPTYRNQLAFIDIFFRGMNNTNYNMAVCGTSGAGKTGLIQPLIRSVLDSGGFAVVFDMGDGYKSLCENMGGVYLDGETLRFNPFANITDIDQSAERVRDQLSVMASPNGNLDEVHEGLLLQAVRASWLAKENRARIDDVVDFLKNASDSEQYAESPTIRSRLDEMIVLLDQYTANGTYGQYFNSDEPSLRDDAKMVVLELGGLEDRPSLLVAVMFSLIIYIENRMYRTPRNLKKLNVIDEGWRLLDFKNHKVGEFIEKGYRTARRHTGAYITITQNIVDFDSDKASSAARAAWGNSSYKIILKQSAKEFAKYNQLYPDQFLPLQRDMIGKFGAAKDQWFSSFLLQVENHSSWHRLFVDPLSRAMYSSDGPDFEFVQQKRKEGLSIHEAVWQLAWKKSGPEMASLEAWLEEHEKYRSVA
|
||||
@@ -3,6 +3,7 @@ import numpy as np
|
||||
from alphafold.data import msa_pairing
|
||||
from alphafold.data import parsers
|
||||
from alphafold.data import pipeline
|
||||
from alphapulldown.objects import MonomericObject
|
||||
from alphapulldown.utils import mmseqs_species_identifiers
|
||||
|
||||
|
||||
@@ -115,3 +116,111 @@ def test_pair_sequences_works_with_mmseqs_accession_species_resolution(
|
||||
assert paired_rows.shape == (3, 2)
|
||||
assert tuple(paired_rows[0]) == (0, 0)
|
||||
assert {tuple(row) for row in paired_rows[1:]} == {(1, 1), (2, 2)}
|
||||
|
||||
|
||||
def test_make_mmseq_features_researches_templates_for_precomputed_msa(
    monkeypatch,
    tmp_path,
):
    """Precomputed MSA plus templates: MSA is reused, templates are re-searched.

    All collaborators of ``make_mmseq_features`` in ``alphapulldown.objects``
    are replaced by recording stubs. The assertions check that the MSA handed
    to feature building comes from the precomputed a3m file, while the
    template comes from a fresh ``get_msa_and_templates`` call made in
    ``single_sequence`` mode.
    """
    import alphapulldown.objects as objects_mod

    # Precomputed alignment: one comment line, one record, trailing newline.
    a3m_file = tmp_path / 'dummy.a3m'
    a3m_file.write_text(
        '# header line that should be ignored later\n>101\nACDE\n',
        encoding='utf-8',
    )

    calls = {}

    def fake_unserialize_msa(a3m_lines, sequence):
        calls['unserialize_msa'] = dict(a3m_lines=a3m_lines, sequence=sequence)
        unpaired = ['PRECOMPUTED_UNPAIRED']
        paired = ['PRECOMPUTED_PAIRED']
        unique = ['PRECOMPUTED_UNIQUE']
        cardinality = ['PRECOMPUTED_CARDINALITY']
        template = ['PRECOMPUTED_TEMPLATE']
        return unpaired, paired, unique, cardinality, template

    def fake_get_msa_and_templates(**kwargs):
        calls['get_msa_and_templates'] = kwargs
        unpaired = ['IGNORED_UNPAIRED']
        paired = ['IGNORED_PAIRED']
        unique = ['IGNORED_UNIQUE']
        cardinality = ['IGNORED_CARDINALITY']
        template = ['TEMPLATE_FROM_RESEARCH']
        return unpaired, paired, unique, cardinality, template

    def fake_build_monomer_feature(sequence, msa, template_feature):
        calls['build_monomer_feature'] = dict(
            sequence=sequence,
            msa=msa,
            template_feature=template_feature,
        )
        return dict(template_confidence_scores=None, template_release_date=None)

    def fake_enrich(feature_dict, a3m, **_kwargs):
        calls['enrich_mmseq_feature_dict_with_identifiers'] = a3m
        feature_dict['msa_species_identifiers'] = np.asarray([b''])
        feature_dict['msa_uniprot_accession_identifiers'] = np.asarray([b''])

    stubs = {
        'unserialize_msa': fake_unserialize_msa,
        'get_msa_and_templates': fake_get_msa_and_templates,
        'build_monomer_feature': fake_build_monomer_feature,
        'enrich_mmseq_feature_dict_with_identifiers': fake_enrich,
    }
    for attribute, stub in stubs.items():
        monkeypatch.setattr(objects_mod, attribute, stub)

    monomer = MonomericObject('dummy', 'ACDE')
    monomer.make_mmseq_features(
        DEFAULT_API_SERVER='https://fake.server',
        output_dir=str(tmp_path),
        use_precomputed_msa=True,
        use_templates=True,
    )

    # The a3m comment line is dropped before the MSA is unserialized.
    assert calls['unserialize_msa'] == {
        'a3m_lines': ['>101\nACDE'],
        'sequence': 'ACDE',
    }

    # Templates are searched again, without requesting a new MSA.
    expected_search_kwargs = dict(
        jobname='dummy',
        query_sequences='ACDE',
        a3m_lines=False,
        result_dir=tmp_path,
        msa_mode='single_sequence',
        use_templates=True,
        custom_template_path=None,
        pair_mode='none',
        host_url='https://fake.server',
        user_agent='alphapulldown',
    )
    assert calls['get_msa_and_templates'] == expected_search_kwargs

    # Feature building combines the precomputed MSA with the fresh template.
    expected_build_call = dict(
        sequence='ACDE',
        msa='PRECOMPUTED_UNPAIRED',
        template_feature='TEMPLATE_FROM_RESEARCH',
    )
    assert calls['build_monomer_feature'] == expected_build_call
    enriched_with = calls['enrich_mmseq_feature_dict_with_identifiers']
    assert enriched_with == 'PRECOMPUTED_UNPAIRED'

    features = monomer.feature_dict
    assert isinstance(features['template_confidence_scores'], np.ndarray)
    assert features['template_release_date'] == ['none']
|
||||
|
||||
Reference in New Issue
Block a user