mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 05:58:11 +08:00
Harden precomputed MSA identifier enrichment
This commit is contained in:
@@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
|
||||
)
|
||||
|
||||
|
||||
def _ensure_identifier_feature_arrays(
|
||||
feature_dict: Dict[str, np.ndarray],
|
||||
feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
|
||||
) -> Dict[str, np.ndarray]:
|
||||
"""Backfill missing identifier arrays to match the corresponding MSA rows."""
|
||||
normalized = dict(feature_dict)
|
||||
for msa_key, identifier_keys in feature_groups:
|
||||
msa = normalized.get(msa_key)
|
||||
if msa is None:
|
||||
continue
|
||||
num_rows = int(np.asarray(msa).shape[0])
|
||||
for key in identifier_keys:
|
||||
if key not in normalized:
|
||||
normalized[key] = np.array([b""] * num_rows, dtype=object)
|
||||
return normalized
|
||||
|
||||
|
||||
class MonomericObject:
|
||||
"""
|
||||
monomeric objects
|
||||
@@ -145,7 +162,18 @@ class MonomericObject:
|
||||
|
||||
msa = parsers.parse_stockholm(result["sto"])
|
||||
msa = msa.truncate(max_seqs=50000)
|
||||
all_seq_features = pipeline.make_msa_features([msa])
|
||||
all_seq_features = _ensure_identifier_feature_arrays(
|
||||
pipeline.make_msa_features([msa]),
|
||||
(
|
||||
(
|
||||
"msa",
|
||||
(
|
||||
"msa_species_identifiers",
|
||||
"msa_uniprot_accession_identifiers",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
valid_feats = msa_pairing.MSA_FEATURES + (
|
||||
"msa_species_identifiers",
|
||||
"msa_uniprot_accession_identifiers",
|
||||
@@ -362,9 +390,12 @@ class MonomericObject:
|
||||
# Remove header lines starting with '#' if present.
|
||||
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
|
||||
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
|
||||
# Enrich from the same A3M string that build_monomer_feature parsed, so
|
||||
# the identifier rows go through the same parse_a3m dedup as msa_features
|
||||
# and their count matches feature_dict['msa'] exactly.
|
||||
enrich_mmseq_feature_dict_with_identifiers(
|
||||
self.feature_dict,
|
||||
a3m_lines[0],
|
||||
unpaired_msa[0],
|
||||
cache_path=os.path.join(
|
||||
result_dir, f"{self.description}.mmseq_ids.json"
|
||||
),
|
||||
@@ -811,12 +842,48 @@ create_individual_features.py
|
||||
output_list.append(new_chain)
|
||||
return output_list
|
||||
|
||||
@staticmethod
|
||||
def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]:
|
||||
"""Ensure identifier arrays exist consistently across chains.
|
||||
|
||||
Some feature sources provide species identifiers but omit UniProt
|
||||
accession IDs, while DeepMind's multimer pairing and merge code assumes
|
||||
both unpaired and `_all_seq` identifier keys exist consistently across
|
||||
chains.
|
||||
"""
|
||||
output_list = []
|
||||
for feat_dict in np_chain_list:
|
||||
output_list.append(
|
||||
_ensure_identifier_feature_arrays(
|
||||
feat_dict,
|
||||
(
|
||||
(
|
||||
"msa",
|
||||
(
|
||||
"msa_species_identifiers",
|
||||
"msa_uniprot_accession_identifiers",
|
||||
),
|
||||
),
|
||||
(
|
||||
"msa_all_seq",
|
||||
(
|
||||
"msa_species_identifiers_all_seq",
|
||||
"msa_uniprot_accession_identifiers_all_seq",
|
||||
),
|
||||
),
|
||||
),
|
||||
)
|
||||
)
|
||||
return output_list
|
||||
|
||||
def pair_and_merge(self, all_chain_features):
|
||||
"""merge all chain features"""
|
||||
feature_processing.process_unmerged_features(all_chain_features)
|
||||
MAX_TEMPLATES = 4
|
||||
MSA_CROP_SIZE = 2048
|
||||
np_chains_list = list(all_chain_features.values())
|
||||
np_chains_list = MultimericObject.normalize_all_seq_identifier_features(
|
||||
list(all_chain_features.values())
|
||||
)
|
||||
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
|
||||
np_chains_list)
|
||||
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")
|
||||
|
||||
@@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict:
|
||||
return payload
|
||||
|
||||
|
||||
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]:
|
||||
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
|
||||
if len(matches) != 1:
|
||||
raise FileNotFoundError(
|
||||
f"Expected one feature metadata file for {protein_id} in {feature_dir}, "
|
||||
f"found {matches}"
|
||||
)
|
||||
metadata_path = matches[0]
|
||||
opener = lzma.open if metadata_path.suffix == ".xz" else open
|
||||
with opener(metadata_path, "rt", encoding="utf-8") as handle:
|
||||
return metadata_path, json.load(handle)
|
||||
|
||||
|
||||
def _metadata_bool(value) -> bool:
|
||||
return str(value).strip().lower() in {"1", "true", "yes"}
|
||||
|
||||
|
||||
def _non_empty_identifier_count(values) -> int:
|
||||
count = 0
|
||||
for value in values:
|
||||
@@ -105,11 +122,12 @@ def _non_empty_identifier_count(values) -> int:
|
||||
def _af2_subprocess_env() -> dict[str, str]:
|
||||
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
|
||||
env = os.environ.copy()
|
||||
env.setdefault("OMP_NUM_THREADS", "4")
|
||||
env.setdefault("MKL_NUM_THREADS", "4")
|
||||
env.setdefault("NUMEXPR_NUM_THREADS", "4")
|
||||
env.setdefault("TF_NUM_INTEROP_THREADS", "4")
|
||||
env.setdefault("TF_NUM_INTRAOP_THREADS", "4")
|
||||
env.setdefault("OMP_NUM_THREADS", "1")
|
||||
env.setdefault("OPENBLAS_NUM_THREADS", "1")
|
||||
env.setdefault("MKL_NUM_THREADS", "1")
|
||||
env.setdefault("NUMEXPR_NUM_THREADS", "1")
|
||||
env.setdefault("TF_NUM_INTEROP_THREADS", "1")
|
||||
env.setdefault("TF_NUM_INTRAOP_THREADS", "1")
|
||||
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
|
||||
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
||||
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||
@@ -117,7 +135,8 @@ def _af2_subprocess_env() -> dict[str, str]:
|
||||
env.setdefault("JAX_PLATFORM_NAME", "gpu")
|
||||
env.setdefault(
|
||||
"XLA_FLAGS",
|
||||
"--xla_gpu_force_compilation_parallelism=0 "
|
||||
"--xla_gpu_force_compilation_parallelism=1 "
|
||||
"--xla_force_host_platform_device_count=1 "
|
||||
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
|
||||
)
|
||||
return env
|
||||
@@ -544,6 +563,87 @@ class TestMmseqsIssue588Inference(_TestBase):
|
||||
)
|
||||
return feature_dir
|
||||
|
||||
def _generate_issue_588_precomputed_mmseq_features(self) -> Path:
|
||||
source_dir = self.output_dir / "issue_588_mmseq_source_features"
|
||||
precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
|
||||
source_dir.mkdir(parents=True, exist_ok=True)
|
||||
precomputed_dir.mkdir(parents=True, exist_ok=True)
|
||||
fasta_paths = ",".join(
|
||||
str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
|
||||
for protein_id in self.ISSUE_588_IDS
|
||||
)
|
||||
|
||||
source_res = self._run_prediction_subprocess(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.script_create_features),
|
||||
f"--fasta_paths={fasta_paths}",
|
||||
f"--output_dir={source_dir}",
|
||||
f"--data_dir={DATA_DIR}",
|
||||
"--max_template_date=2024-05-02",
|
||||
"--use_mmseqs2=True",
|
||||
"--data_pipeline=alphafold2",
|
||||
"--save_msa_files=True",
|
||||
"--compress_features=True",
|
||||
"--skip_existing=False",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
source_res.returncode,
|
||||
0,
|
||||
"MMseqs source feature generation failed.\n"
|
||||
f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
|
||||
)
|
||||
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
self.assertTrue(
|
||||
(source_dir / f"{protein_id}.a3m").is_file(),
|
||||
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
|
||||
)
|
||||
self.assertTrue(
|
||||
(source_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||
)
|
||||
shutil.copy2(
|
||||
source_dir / f"{protein_id}.a3m",
|
||||
precomputed_dir / f"{protein_id}.a3m",
|
||||
)
|
||||
sidecar = source_dir / f"{protein_id}.mmseq_ids.json"
|
||||
if sidecar.is_file():
|
||||
shutil.copy2(sidecar, precomputed_dir / sidecar.name)
|
||||
|
||||
precomputed_res = self._run_prediction_subprocess(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.script_create_features),
|
||||
f"--fasta_paths={fasta_paths}",
|
||||
f"--output_dir={precomputed_dir}",
|
||||
f"--data_dir={DATA_DIR}",
|
||||
"--max_template_date=2024-05-02",
|
||||
"--use_mmseqs2=True",
|
||||
"--use_precomputed_msas=True",
|
||||
"--data_pipeline=alphafold2",
|
||||
"--compress_features=True",
|
||||
"--skip_existing=False",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
precomputed_res.returncode,
|
||||
0,
|
||||
"Precomputed-MMseq feature generation failed.\n"
|
||||
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
|
||||
)
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
self.assertTrue(
|
||||
(precomputed_dir / f"{protein_id}.a3m").is_file(),
|
||||
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
|
||||
)
|
||||
self.assertTrue(
|
||||
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||
)
|
||||
return precomputed_dir
|
||||
|
||||
def _resolve_af2_result_dir(self, root: Path) -> Path:
|
||||
if (root / "ranking_debug.json").exists():
|
||||
return root
|
||||
@@ -642,5 +742,73 @@ class TestMmseqsIssue588Inference(_TestBase):
|
||||
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
|
||||
)
|
||||
|
||||
def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self):
|
||||
"""Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms."""
|
||||
self._require_mmseqs_functional_environment()
|
||||
feature_dir = self._generate_issue_588_precomputed_mmseq_features()
|
||||
|
||||
for protein_id in self.ISSUE_588_IDS:
|
||||
metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
|
||||
self.assertTrue(
|
||||
_metadata_bool(metadata["other"]["use_precomputed_msas"]),
|
||||
f"{metadata_path} should record use_precomputed_msas=True",
|
||||
)
|
||||
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
|
||||
self.assertGreater(
|
||||
_non_empty_identifier_count(
|
||||
feature_dict["msa_species_identifiers_all_seq"]
|
||||
),
|
||||
0,
|
||||
f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
|
||||
)
|
||||
self.assertGreater(
|
||||
_non_empty_identifier_count(
|
||||
feature_dict["msa_uniprot_accession_identifiers_all_seq"]
|
||||
),
|
||||
0,
|
||||
f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
|
||||
)
|
||||
|
||||
prediction_dir = self.output_dir / "af2_precomputed_prediction"
|
||||
prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||
res = self._run_prediction_subprocess(
|
||||
[
|
||||
sys.executable,
|
||||
str(self.script_single),
|
||||
"--input=A0ABD7FQG0+P18004",
|
||||
f"--output_directory={prediction_dir}",
|
||||
"--num_cycle=1",
|
||||
"--num_predictions_per_model=1",
|
||||
"--model_names=model_4_multimer_v3",
|
||||
f"--data_directory={DATA_DIR}",
|
||||
f"--features_directory={feature_dir}",
|
||||
"--random_seed=42",
|
||||
]
|
||||
)
|
||||
self.assertEqual(
|
||||
res.returncode,
|
||||
0,
|
||||
"AF2 inference from precomputed MMseq features failed.\n"
|
||||
f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
|
||||
)
|
||||
|
||||
result_dir = self._resolve_af2_result_dir(prediction_dir)
|
||||
ranking_payload = json.loads(
|
||||
(result_dir / "ranking_debug.json").read_text(encoding="utf-8")
|
||||
)
|
||||
self.assertTrue(ranking_payload["order"])
|
||||
|
||||
result_pickles = sorted(result_dir.glob("result_*.pkl"))
|
||||
self.assertLen(result_pickles, 1)
|
||||
with result_pickles[0].open("rb") as handle:
|
||||
result_payload = pickle.load(handle)
|
||||
self.assertIn("iptm", result_payload)
|
||||
self.assertIn("ranking_confidence", result_payload)
|
||||
self.assertGreater(
|
||||
result_payload["iptm"],
|
||||
0.6,
|
||||
f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}",
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
||||
@@ -20,7 +20,9 @@ import json
|
||||
import numpy as np
|
||||
import re
|
||||
import unittest
|
||||
from types import SimpleNamespace
|
||||
from typing import Dict, List, Tuple, Any
|
||||
from unittest import mock
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
|
||||
@@ -151,14 +153,24 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
|
||||
|
||||
|
||||
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
|
||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||
with opener(feature_path, "rb") as handle:
|
||||
payload = pickle.load(handle)
|
||||
payload = _load_feature_payload(feature_path)
|
||||
if hasattr(payload, "feature_dict"):
|
||||
return payload.feature_dict
|
||||
return payload
|
||||
|
||||
|
||||
def _load_feature_payload(feature_path: Path) -> Any:
|
||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||
with opener(feature_path, "rb") as handle:
|
||||
return pickle.load(handle)
|
||||
|
||||
|
||||
def _write_feature_payload(feature_path: Path, payload: Any) -> None:
|
||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||
with opener(feature_path, "wb") as handle:
|
||||
pickle.dump(payload, handle)
|
||||
|
||||
|
||||
def _non_empty_identifier_count(values) -> int:
|
||||
count = 0
|
||||
for value in values:
|
||||
@@ -1407,6 +1419,106 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
fold_input_obj, _ = next(iter(mappings[0].items()))
|
||||
return fold_input_obj
|
||||
|
||||
def _copy_real_feature_fixture(
|
||||
self,
|
||||
*,
|
||||
source_dir: Path,
|
||||
protein_id: str,
|
||||
target_dir: Path,
|
||||
) -> Path:
|
||||
copied_feature_path = None
|
||||
for pattern in (
|
||||
f"{protein_id}.pkl",
|
||||
f"{protein_id}.pkl.xz",
|
||||
f"{protein_id}.a3m",
|
||||
f"{protein_id}_feature_metadata_*.json*",
|
||||
):
|
||||
for source_path in sorted(source_dir.glob(pattern)):
|
||||
target_path = target_dir / source_path.name
|
||||
shutil.copy2(source_path, target_path)
|
||||
if source_path.name.startswith(f"{protein_id}.pkl"):
|
||||
copied_feature_path = target_path
|
||||
|
||||
self.assertIsNotNone(
|
||||
copied_feature_path,
|
||||
f"Missing real feature fixture for {protein_id} in {source_dir}",
|
||||
)
|
||||
return copied_feature_path
|
||||
|
||||
@staticmethod
|
||||
def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray:
|
||||
identifiers = []
|
||||
for index, value in enumerate(species_ids):
|
||||
if isinstance(value, bytes):
|
||||
value = value.decode("utf-8")
|
||||
identifiers.append(
|
||||
f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b""
|
||||
)
|
||||
return np.asarray(identifiers, dtype=object)
|
||||
|
||||
def _prepare_mixed_identifier_fixture_dir(self) -> Path:
|
||||
"""Materialize real AF2 fixtures with mixed identifier enrichment.
|
||||
|
||||
The underlying MSA rows come from repo fixtures in `test/test_data`.
|
||||
We only adjust the identifier sidecars so one chain looks enriched while
|
||||
the other reproduces the "no species enrichment / no accession IDs"
|
||||
failure mode from issue #614's AF3 follow-up comment.
|
||||
"""
|
||||
feature_dir = self.output_dir / "mixed_identifier_features"
|
||||
feature_dir.mkdir(parents=True, exist_ok=True)
|
||||
source_dir = self.test_features_dir / "af2_features" / "protein"
|
||||
|
||||
enriched_feature_path = self._copy_real_feature_fixture(
|
||||
source_dir=source_dir,
|
||||
protein_id="A0A024R1R8",
|
||||
target_dir=feature_dir,
|
||||
)
|
||||
unenriched_feature_path = self._copy_real_feature_fixture(
|
||||
source_dir=source_dir,
|
||||
protein_id="P61626",
|
||||
target_dir=feature_dir,
|
||||
)
|
||||
|
||||
enriched_payload = _load_feature_payload(enriched_feature_path)
|
||||
enriched_feature_dict = (
|
||||
enriched_payload.feature_dict
|
||||
if hasattr(enriched_payload, "feature_dict")
|
||||
else enriched_payload
|
||||
)
|
||||
enriched_feature_dict["msa_uniprot_accession_identifiers"] = (
|
||||
self._synthetic_accession_ids(
|
||||
np.asarray(enriched_feature_dict["msa_species_identifiers"])
|
||||
)
|
||||
)
|
||||
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] = (
|
||||
self._synthetic_accession_ids(
|
||||
np.asarray(enriched_feature_dict["msa_species_identifiers_all_seq"])
|
||||
)
|
||||
)
|
||||
_write_feature_payload(enriched_feature_path, enriched_payload)
|
||||
|
||||
unenriched_payload = _load_feature_payload(unenriched_feature_path)
|
||||
unenriched_feature_dict = (
|
||||
unenriched_payload.feature_dict
|
||||
if hasattr(unenriched_payload, "feature_dict")
|
||||
else unenriched_payload
|
||||
)
|
||||
unenriched_feature_dict["msa_species_identifiers"] = np.asarray(
|
||||
[b""] * int(np.asarray(unenriched_feature_dict["msa"]).shape[0]),
|
||||
dtype=object,
|
||||
)
|
||||
unenriched_feature_dict["msa_species_identifiers_all_seq"] = np.asarray(
|
||||
[b""] * int(np.asarray(unenriched_feature_dict["msa_all_seq"]).shape[0]),
|
||||
dtype=object,
|
||||
)
|
||||
unenriched_feature_dict.pop("msa_uniprot_accession_identifiers", None)
|
||||
unenriched_feature_dict.pop(
|
||||
"msa_uniprot_accession_identifiers_all_seq", None
|
||||
)
|
||||
_write_feature_payload(unenriched_feature_path, unenriched_payload)
|
||||
|
||||
return feature_dir
|
||||
|
||||
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
|
||||
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
|
||||
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
|
||||
@@ -1638,6 +1750,117 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
||||
all(template["mmcif"] for template in protein_entries[0]["templates"])
|
||||
)
|
||||
|
||||
def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self):
|
||||
"""AF3 prep should tolerate a real mixed-enrichment multimer feature set."""
|
||||
from alphapulldown.folding_backend.alphafold3_backend import (
|
||||
AlphaFold3Backend,
|
||||
process_fold_input,
|
||||
)
|
||||
from alphapulldown.scripts import run_structure_prediction
|
||||
|
||||
feature_dir = self._prepare_mixed_identifier_fixture_dir()
|
||||
|
||||
enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl")
|
||||
self.assertGreater(
|
||||
_non_empty_identifier_count(
|
||||
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"]
|
||||
),
|
||||
0,
|
||||
)
|
||||
unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl")
|
||||
self.assertEqual(
|
||||
_non_empty_identifier_count(
|
||||
unenriched_feature_dict["msa_species_identifiers_all_seq"]
|
||||
),
|
||||
0,
|
||||
)
|
||||
self.assertNotIn(
|
||||
"msa_uniprot_accession_identifiers_all_seq",
|
||||
unenriched_feature_dict,
|
||||
)
|
||||
|
||||
script_flags = SimpleNamespace(
|
||||
pair_msa=True,
|
||||
multimeric_template=False,
|
||||
description_file=None,
|
||||
path_to_mmt=None,
|
||||
threshold_clashes=1000,
|
||||
hb_allowance=0.4,
|
||||
plddt_threshold=0,
|
||||
save_features_for_multimeric_object=False,
|
||||
features_directory=[str(feature_dir)],
|
||||
use_ap_style=False,
|
||||
)
|
||||
|
||||
with mock.patch.object(run_structure_prediction, "FLAGS", script_flags):
|
||||
parsed = run_structure_prediction.parse_fold(
|
||||
["A0A024R1R8+P61626"],
|
||||
[str(feature_dir)],
|
||||
"+",
|
||||
)
|
||||
data = run_structure_prediction.create_custom_info(parsed)
|
||||
all_interactors = run_structure_prediction.create_interactors(
|
||||
data,
|
||||
[str(feature_dir)],
|
||||
)
|
||||
self.assertLen(all_interactors, 1)
|
||||
self.assertLen(all_interactors[0], 2)
|
||||
object_to_model, prepared_output_dir = (
|
||||
run_structure_prediction.pre_modelling_setup(
|
||||
all_interactors[0],
|
||||
output_dir=str(self.output_dir / "mixed_identifier_prediction"),
|
||||
)
|
||||
)
|
||||
|
||||
mappings = AlphaFold3Backend.prepare_input(
|
||||
objects_to_model=[
|
||||
{"object": object_to_model, "output_dir": prepared_output_dir}
|
||||
],
|
||||
random_seed=42,
|
||||
debug_msas=True,
|
||||
)
|
||||
self.assertLen(mappings, 1)
|
||||
fold_input_obj, (
|
||||
prepared_output_dir,
|
||||
resolve_msa_overlaps,
|
||||
) = next(iter(mappings[0].items()))
|
||||
|
||||
process_fold_input(
|
||||
fold_input=fold_input_obj,
|
||||
model_runner=None,
|
||||
output_dir=prepared_output_dir,
|
||||
buckets=(512,),
|
||||
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||
)
|
||||
|
||||
job_name = fold_input_obj.sanitised_name()
|
||||
summary_path = (
|
||||
Path(prepared_output_dir)
|
||||
/ f"{job_name}_af2_to_af3_translation_summary.json"
|
||||
)
|
||||
self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
|
||||
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||
self.assertLen(summary["chains"], 2)
|
||||
self.assertTrue(summary["unpaired_rows_valid"])
|
||||
|
||||
input_json = Path(prepared_output_dir) / f"{job_name}_data.json"
|
||||
self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}")
|
||||
written = json.loads(input_json.read_text(encoding="utf-8"))
|
||||
protein_entries = {
|
||||
protein_entry["id"]: protein_entry
|
||||
for protein_entry in _protein_entries_from_af3_input(written)
|
||||
}
|
||||
self.assertEqual(set(protein_entries), {"A", "B"})
|
||||
for chain in fold_input_obj.chains:
|
||||
if not hasattr(chain, "sequence"):
|
||||
continue
|
||||
protein_entry = protein_entries[chain.id]
|
||||
self.assertEqual(protein_entry["sequence"], chain.sequence)
|
||||
self.assertEqual(
|
||||
_a3m_query_sequence(protein_entry["unpairedMsa"]),
|
||||
chain.sequence,
|
||||
)
|
||||
|
||||
|
||||
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
|
||||
"""Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""
|
||||
|
||||
4
test/test_data/fastas/P04737.fasta
Normal file
4
test/test_data/fastas/P04737.fasta
Normal file
@@ -0,0 +1,4 @@
|
||||
>sp|P04737|PIL1_ECOLI Pilin OS=Escherichia coli (strain K12) OX=83333 GN=traA PE=1 SV=1
|
||||
MNAVLSVQGASAPVKKKSFFSKFTRLNMLRLARAVIPAAVLMMFFPQLAMAAGSSGQDLM
|
||||
ASGNTTVKATFGKDSSVVKWVVLAEVLVGAVMYMMTKNVKFLAGFAIISVFIAVGMAVVG
|
||||
L
|
||||
9
test/test_data/fastas/P15069.fasta
Normal file
9
test/test_data/fastas/P15069.fasta
Normal file
@@ -0,0 +1,9 @@
|
||||
>sp|P15069|TRAH1_ECOLI Protein TraH OS=Escherichia coli (strain K12) OX=83333 GN=traH PE=3 SV=2
|
||||
MMPRIKPLLVLCAALLTVTPAASADVNSDMNQFFNKLGFASNTTQPGVWQGQAAGYAYGG
|
||||
SLYARTQVKNVQLISMTLPDINAGCGGIDAYLGSFSFINGEQLQRFVKQIMSNAAGYFFD
|
||||
LALQTTVPEIKTAKDFLQKMASDINSMNLSSCQAAQGIIGGLFPRTQVSQQKVCQDIAGE
|
||||
SNIFADWAASRQGCTVGGKSDSVRDKASDKDKERVTKNINIMWNALSKNRMFDGNKELKE
|
||||
FVMTLTGSLVFGPNGEITPLSARTTDRSIIRAMMEGGTAKISHCNDSDKCLKVVADTPVT
|
||||
ISRDNALKSQITKLLASIQNKAVSDTPLDDKEKGFISSTTIPVFKYLVDPQMLGVSNSMI
|
||||
YQLTDYIGYDILLQYIQELIQQARAMVATGNYDEAVIGHINDNMNDATRQIAAFQSQVQV
|
||||
QQDALLVVDRQMSYMRQQLSARMLSRYQNNYHFGGSTL
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import os
|
||||
from urllib import error
|
||||
|
||||
import numpy as np
|
||||
@@ -11,6 +12,32 @@ from alphapulldown.objects import MonomericObject
|
||||
from alphapulldown.utils import mmseqs_species_identifiers
|
||||
|
||||
|
||||
_FASTAS_DIR = os.path.join(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
'test_data',
|
||||
'fastas',
|
||||
)
|
||||
|
||||
|
||||
def _read_uniprot_fasta(uniprot_id: str) -> str:
|
||||
path = os.path.join(_FASTAS_DIR, f'{uniprot_id}.fasta')
|
||||
with open(path, encoding='utf-8') as handle:
|
||||
lines = handle.read().splitlines()
|
||||
return ''.join(line for line in lines[1:] if line and not line.startswith('>'))
|
||||
|
||||
|
||||
def _build_colabfold_server_a3m(
|
||||
query_sequence: str, hits: list[tuple[str, str]]
|
||||
) -> str:
|
||||
"""Assemble a ColabFold-server-style A3M ('#<len>\\t1' header, one-line rows)."""
|
||||
lines = [f'#{len(query_sequence)}\t1', '>101', query_sequence]
|
||||
for header, aligned in hits:
|
||||
lines.append(f'>{header}')
|
||||
lines.append(aligned)
|
||||
lines.append('')
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_species_id_cache():
|
||||
mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
|
||||
@@ -419,7 +446,7 @@ def test_make_mmseq_features_researches_templates_for_precomputed_msa(
|
||||
'template_feature': 'TEMPLATE_FROM_RESEARCH',
|
||||
}
|
||||
assert calls['enrich_mmseq_feature_dict_with_identifiers'] == {
|
||||
'a3m': '>101\nACDE',
|
||||
'a3m': 'PRECOMPUTED_UNPAIRED',
|
||||
'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')},
|
||||
}
|
||||
assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray)
|
||||
@@ -579,3 +606,110 @@ def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
|
||||
}
|
||||
assert uniprot_calls == [('A0A636IKY3',)]
|
||||
assert uniparc_calls == [('UPI001118B830',)]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
'uniprot_id,accession_species',
|
||||
[
|
||||
(
|
||||
'P04737',
|
||||
{
|
||||
'A0A636IKY3': '562',
|
||||
'A0A743YDY2': '573',
|
||||
'UPI001118B830': '562',
|
||||
},
|
||||
),
|
||||
(
|
||||
'P15069',
|
||||
{
|
||||
'A0A636IKY3': '562',
|
||||
'A0A743YDY2': '573',
|
||||
'UPI001118B830': '562',
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_make_mmseq_features_precomputed_colabfold_a3m_enriches_identifiers(
|
||||
monkeypatch, tmp_path, uniprot_id, accession_species
|
||||
):
|
||||
"""Regression for issue #613: precomputed ColabFold-server A3Ms must enrich.
|
||||
|
||||
Before the fix, make_mmseq_features parsed the raw A3M for identifiers but
|
||||
fed a different processed string to build_monomer_feature, so dedup rules
|
||||
disagreed and identifier rows did not match MSA rows. This drove
|
||||
enrich_mmseq_feature_dict_with_identifiers to log a warning and skip
|
||||
enrichment, leaving species pairing unusable.
|
||||
"""
|
||||
|
||||
query = _read_uniprot_fasta(uniprot_id)
|
||||
assert len(query) > 0
|
||||
|
||||
# Realistic ColabFold-server-style A3M: '#<len>\t1' header, one header/seq per
|
||||
# line pair, a mix of exact-duplicate hits, insertion-variant hits (lowercase
|
||||
# letters) that strip to the query, an all-gap row, and point-mutation hits
|
||||
# with real-format UniProt accessions so enrichment has something to resolve.
|
||||
hits = [
|
||||
(f'sp|{uniprot_id}|QUERY_DUP', query),
|
||||
('UniRef100_A0A636IKY3', query[:10] + 'a' + query[10:]),
|
||||
('UniRef100_A0A743YDY2', query[:15] + 'bc' + query[15:]),
|
||||
('UniRef100_UPI001118B830', query[:20] + 'def' + query[20:]),
|
||||
('UniRef100_A0A100XYZ0', query[:5] + 'X' + query[6:]),
|
||||
('UniRef100_A0A200ABC5', query[:8] + 'Y' + query[9:]),
|
||||
('UniRef100_GAP_ROW', '-' * len(query)),
|
||||
('UniRef100_ALL_LOWER', 'a' * len(query)),
|
||||
]
|
||||
accession_species = {
|
||||
**accession_species,
|
||||
'A0A100XYZ0': '9606',
|
||||
'A0A200ABC5': '10090',
|
||||
}
|
||||
a3m = _build_colabfold_server_a3m(query, hits)
|
||||
|
||||
precomputed_a3m = tmp_path / f'{uniprot_id}.a3m'
|
||||
precomputed_a3m.write_text(a3m, encoding='utf-8')
|
||||
|
||||
monkeypatch.setattr(
|
||||
MonomericObject, 'unzip_msa_files', staticmethod(lambda _path: False)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
mmseqs_species_identifiers,
|
||||
'resolve_species_ids_by_accession',
|
||||
lambda accessions, **_: {
|
||||
accession: accession_species.get(accession, '')
|
||||
for accession in accessions
|
||||
},
|
||||
)
|
||||
|
||||
monomer = MonomericObject(uniprot_id, query)
|
||||
monomer.make_mmseq_features(
|
||||
DEFAULT_API_SERVER='https://unused.example',
|
||||
output_dir=str(tmp_path),
|
||||
use_precomputed_msa=True,
|
||||
use_templates=False,
|
||||
)
|
||||
|
||||
msa = monomer.feature_dict['msa']
|
||||
species = monomer.feature_dict['msa_species_identifiers']
|
||||
accessions = monomer.feature_dict['msa_uniprot_accession_identifiers']
|
||||
|
||||
assert species.shape[0] == msa.shape[0], (
|
||||
f'enrichment row count {species.shape[0]} != msa rows {msa.shape[0]}'
|
||||
)
|
||||
assert accessions.shape[0] == msa.shape[0]
|
||||
# '_all_seq' mirrors the enriched rows, used for pairing downstream.
|
||||
assert (
|
||||
monomer.feature_dict['msa_species_identifiers_all_seq'].shape[0]
|
||||
== msa.shape[0]
|
||||
)
|
||||
# The resolver is called with the real UniProt-format accessions only —
|
||||
# insertion-variant hits collapse onto the query row but the point-mutation
|
||||
# hit survives, so at least one of the resolvable accessions lands in the
|
||||
# deduped identifier rows.
|
||||
resolvable = {
|
||||
a.decode('utf-8')
|
||||
for a in accessions.tolist()
|
||||
if a.decode('utf-8') in accession_species
|
||||
}
|
||||
assert resolvable, (
|
||||
f'expected at least one resolvable accession in {accessions.tolist()}'
|
||||
)
|
||||
|
||||
@@ -397,7 +397,7 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
|
||||
|
||||
assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE")
|
||||
assert calls["enrich"] == {
|
||||
"a3m": ">101\nACDE\n>hit\nAC-E",
|
||||
"a3m": "UNPAIRED",
|
||||
"kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")},
|
||||
}
|
||||
assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101")
|
||||
@@ -421,7 +421,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
|
||||
|
||||
def fake_get_msa_and_templates(**kwargs):
|
||||
calls["get_msa_and_templates"] = kwargs
|
||||
return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
|
||||
return ([">101\nACDE\n"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
|
||||
|
||||
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
|
||||
monkeypatch.setattr(
|
||||
@@ -462,7 +462,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
|
||||
assert calls["get_msa_and_templates"]["pair_mode"] == "none"
|
||||
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
|
||||
assert calls["get_msa_and_templates"]["use_templates"] is True
|
||||
assert calls["enrich"]["a3m"] == ">101\nACDE"
|
||||
assert calls["enrich"]["a3m"] == ">101\nACDE\n"
|
||||
assert monomer.skip_msa is True
|
||||
assert monomer.feature_dict["msa"].shape == (1, 4)
|
||||
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
|
||||
@@ -757,7 +757,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
|
||||
objects_mod,
|
||||
"get_msa_and_templates",
|
||||
lambda **_kwargs: (
|
||||
["UNPAIRED"],
|
||||
[a3m_text],
|
||||
["PAIRED"],
|
||||
["UNIQUE"],
|
||||
["CARD"],
|
||||
@@ -774,7 +774,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
|
||||
objects_mod,
|
||||
"unserialize_msa",
|
||||
lambda a3m_lines, sequence: (
|
||||
["PRECOMP_MSA"],
|
||||
[a3m_text],
|
||||
["PRECOMP_PAIRED"],
|
||||
["UNIQUE"],
|
||||
["CARD"],
|
||||
@@ -1417,6 +1417,81 @@ def test_pair_and_merge_pairs_and_deduplicates_for_heteromer(monkeypatch):
|
||||
assert output == {"processed": {"merged": True}}
|
||||
|
||||
|
||||
def test_pair_and_merge_backfills_missing_all_seq_accession_ids_before_pairing(
|
||||
monkeypatch,
|
||||
):
|
||||
multimer = MultimericObject.__new__(MultimericObject)
|
||||
multimer.pair_msa = True
|
||||
calls = {}
|
||||
|
||||
chain_a = _feature_dict(sequence="ACDE", msa_rows=1, all_seq_rows=2, template_count=0)
|
||||
chain_b = _feature_dict(sequence="FGHI", msa_rows=1, all_seq_rows=2, template_count=0)
|
||||
chain_a["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
|
||||
chain_b["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
|
||||
chain_a["msa_uniprot_accession_identifiers"] = np.asarray([b"P12345"], dtype=object)
|
||||
chain_a["msa_uniprot_accession_identifiers_all_seq"] = np.asarray(
|
||||
[b"", b"P12345"],
|
||||
dtype=object,
|
||||
)
|
||||
|
||||
real_create_paired_features = objects_mod.msa_pairing.create_paired_features
|
||||
|
||||
monkeypatch.setattr(
|
||||
objects_mod.feature_processing,
|
||||
"process_unmerged_features",
|
||||
lambda _features: None,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
objects_mod.feature_processing,
|
||||
"_is_homomer_or_monomer",
|
||||
lambda _chains: False,
|
||||
)
|
||||
|
||||
def wrapped_create_paired_features(*, chains):
|
||||
calls["create_paired_features"] = chains
|
||||
return real_create_paired_features(chains)
|
||||
|
||||
monkeypatch.setattr(
|
||||
objects_mod.msa_pairing,
|
||||
"create_paired_features",
|
||||
wrapped_create_paired_features,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
objects_mod.msa_pairing,
|
||||
"deduplicate_unpaired_sequences",
|
||||
lambda chains: chains,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
objects_mod.feature_processing,
|
||||
"crop_chains",
|
||||
lambda chains, **kwargs: chains,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
objects_mod.msa_pairing,
|
||||
"merge_chain_features",
|
||||
lambda **kwargs: {"chains": kwargs["np_chains_list"]},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
objects_mod.feature_processing,
|
||||
"process_final",
|
||||
lambda example: example,
|
||||
)
|
||||
|
||||
output = multimer.pair_and_merge({"A": chain_a, "B": chain_b})
|
||||
|
||||
assert calls["create_paired_features"][1][
|
||||
"msa_uniprot_accession_identifiers"
|
||||
].tolist() == [b""]
|
||||
assert calls["create_paired_features"][1][
|
||||
"msa_uniprot_accession_identifiers_all_seq"
|
||||
].tolist() == [b"", b""]
|
||||
assert output["chains"][1]["msa_uniprot_accession_identifiers"].tolist() == [b""]
|
||||
assert output["chains"][1]["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
|
||||
|
||||
def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch):
|
||||
multimer = MultimericObject.__new__(MultimericObject)
|
||||
multimer.pair_msa = False
|
||||
|
||||
@@ -135,3 +135,46 @@ def test_all_seq_msa_features_keeps_only_pairing_related_keys(monkeypatch, tmp_p
|
||||
True,
|
||||
)
|
||||
assert run_kwargs == {}
|
||||
|
||||
|
||||
def test_all_seq_msa_features_backfills_missing_uniprot_accession_identifiers(
|
||||
monkeypatch, tmp_path
|
||||
):
|
||||
monomer = MonomericObject("desc", "ACDE")
|
||||
input_fasta_path = str(tmp_path / "input.fasta")
|
||||
Path(input_fasta_path).write_text(">x\nACDE\n", encoding="utf-8")
|
||||
|
||||
class FakeMsa:
|
||||
def truncate(self, max_seqs):
|
||||
assert max_seqs == 50000
|
||||
return self
|
||||
|
||||
monkeypatch.setattr(
|
||||
"alphapulldown.objects.pipeline.run_msa_tool",
|
||||
lambda *args, **kwargs: {"sto": "fake"},
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"alphapulldown.objects.parsers.parse_stockholm",
|
||||
lambda sto: FakeMsa(),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"alphapulldown.objects.pipeline.make_msa_features",
|
||||
lambda _msas: {
|
||||
"msa": np.asarray([[1, 2], [1, 3]], dtype=np.int32),
|
||||
"msa_species_identifiers": np.asarray([b"", b"9606"], dtype=object),
|
||||
"deletion_matrix_int": np.asarray([[0, 0], [0, 0]], dtype=np.int32),
|
||||
},
|
||||
)
|
||||
|
||||
features = monomer.all_seq_msa_features(
|
||||
input_fasta_path=input_fasta_path,
|
||||
uniprot_msa_runner="runner",
|
||||
output_dir=str(tmp_path),
|
||||
use_precomputed_msa=False,
|
||||
)
|
||||
|
||||
assert features["msa_species_identifiers_all_seq"].tolist() == [b"", b"9606"]
|
||||
assert features["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
|
||||
b"",
|
||||
b"",
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user