Harden precomputed MSA identifier enrichment

This commit is contained in:
Dima Molodenskiy
2026-04-23 09:52:54 +02:00
committed by Dima
parent ce4be4866b
commit fb09717cf3
8 changed files with 741 additions and 18 deletions

View File

@@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
) )
def _ensure_identifier_feature_arrays(
feature_dict: Dict[str, np.ndarray],
feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
) -> Dict[str, np.ndarray]:
"""Backfill missing identifier arrays to match the corresponding MSA rows."""
normalized = dict(feature_dict)
for msa_key, identifier_keys in feature_groups:
msa = normalized.get(msa_key)
if msa is None:
continue
num_rows = int(np.asarray(msa).shape[0])
for key in identifier_keys:
if key not in normalized:
normalized[key] = np.array([b""] * num_rows, dtype=object)
return normalized
class MonomericObject: class MonomericObject:
""" """
monomeric objects monomeric objects
@@ -145,7 +162,18 @@ class MonomericObject:
msa = parsers.parse_stockholm(result["sto"]) msa = parsers.parse_stockholm(result["sto"])
msa = msa.truncate(max_seqs=50000) msa = msa.truncate(max_seqs=50000)
all_seq_features = pipeline.make_msa_features([msa]) all_seq_features = _ensure_identifier_feature_arrays(
pipeline.make_msa_features([msa]),
(
(
"msa",
(
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
),
),
),
)
valid_feats = msa_pairing.MSA_FEATURES + ( valid_feats = msa_pairing.MSA_FEATURES + (
"msa_species_identifiers", "msa_species_identifiers",
"msa_uniprot_accession_identifiers", "msa_uniprot_accession_identifiers",
@@ -362,9 +390,12 @@ class MonomericObject:
# Remove header lines starting with '#' if present. # Remove header lines starting with '#' if present.
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0]) a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0]) self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
# Enrich from the same A3M string that build_monomer_feature parsed, so
# the identifier rows go through the same parse_a3m dedup as msa_features
# and their count matches feature_dict['msa'] exactly.
enrich_mmseq_feature_dict_with_identifiers( enrich_mmseq_feature_dict_with_identifiers(
self.feature_dict, self.feature_dict,
a3m_lines[0], unpaired_msa[0],
cache_path=os.path.join( cache_path=os.path.join(
result_dir, f"{self.description}.mmseq_ids.json" result_dir, f"{self.description}.mmseq_ids.json"
), ),
@@ -811,12 +842,48 @@ create_individual_features.py
output_list.append(new_chain) output_list.append(new_chain)
return output_list return output_list
@staticmethod
def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]:
    """Backfill identifier arrays so every chain exposes the same keys.

    Some feature sources provide species identifiers but omit UniProt
    accession IDs, while DeepMind's multimer pairing and merge code assumes
    both the unpaired and the `_all_seq` identifier keys are present on
    every chain. Each chain dict is passed through
    `_ensure_identifier_feature_arrays`, and the results are returned as a
    new list in the same order.
    """
    identifier_groups = (
        (
            "msa",
            (
                "msa_species_identifiers",
                "msa_uniprot_accession_identifiers",
            ),
        ),
        (
            "msa_all_seq",
            (
                "msa_species_identifiers_all_seq",
                "msa_uniprot_accession_identifiers_all_seq",
            ),
        ),
    )
    return [
        _ensure_identifier_feature_arrays(chain_features, identifier_groups)
        for chain_features in np_chain_list
    ]
def pair_and_merge(self, all_chain_features): def pair_and_merge(self, all_chain_features):
"""merge all chain features""" """merge all chain features"""
feature_processing.process_unmerged_features(all_chain_features) feature_processing.process_unmerged_features(all_chain_features)
MAX_TEMPLATES = 4 MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048 MSA_CROP_SIZE = 2048
np_chains_list = list(all_chain_features.values()) np_chains_list = MultimericObject.normalize_all_seq_identifier_features(
list(all_chain_features.values())
)
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer( pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
np_chains_list) np_chains_list)
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}") logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")

View File

@@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict:
return payload return payload
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]:
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
if len(matches) != 1:
raise FileNotFoundError(
f"Expected one feature metadata file for {protein_id} in {feature_dir}, "
f"found {matches}"
)
metadata_path = matches[0]
opener = lzma.open if metadata_path.suffix == ".xz" else open
with opener(metadata_path, "rt", encoding="utf-8") as handle:
return metadata_path, json.load(handle)
def _metadata_bool(value) -> bool:
return str(value).strip().lower() in {"1", "true", "yes"}
def _non_empty_identifier_count(values) -> int: def _non_empty_identifier_count(values) -> int:
count = 0 count = 0
for value in values: for value in values:
@@ -105,11 +122,12 @@ def _non_empty_identifier_count(values) -> int:
def _af2_subprocess_env() -> dict[str, str]: def _af2_subprocess_env() -> dict[str, str]:
"""Return stable GPU/JAX defaults for AF2 functional subprocesses.""" """Return stable GPU/JAX defaults for AF2 functional subprocesses."""
env = os.environ.copy() env = os.environ.copy()
env.setdefault("OMP_NUM_THREADS", "4") env.setdefault("OMP_NUM_THREADS", "1")
env.setdefault("MKL_NUM_THREADS", "4") env.setdefault("OPENBLAS_NUM_THREADS", "1")
env.setdefault("NUMEXPR_NUM_THREADS", "4") env.setdefault("MKL_NUM_THREADS", "1")
env.setdefault("TF_NUM_INTEROP_THREADS", "4") env.setdefault("NUMEXPR_NUM_THREADS", "1")
env.setdefault("TF_NUM_INTRAOP_THREADS", "4") env.setdefault("TF_NUM_INTEROP_THREADS", "1")
env.setdefault("TF_NUM_INTRAOP_THREADS", "1")
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true") env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2") env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false") env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
@@ -117,7 +135,8 @@ def _af2_subprocess_env() -> dict[str, str]:
env.setdefault("JAX_PLATFORM_NAME", "gpu") env.setdefault("JAX_PLATFORM_NAME", "gpu")
env.setdefault( env.setdefault(
"XLA_FLAGS", "XLA_FLAGS",
"--xla_gpu_force_compilation_parallelism=0 " "--xla_gpu_force_compilation_parallelism=1 "
"--xla_force_host_platform_device_count=1 "
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1", "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
) )
return env return env
@@ -544,6 +563,87 @@ class TestMmseqsIssue588Inference(_TestBase):
) )
return feature_dir return feature_dir
def _generate_issue_588_precomputed_mmseq_features(self) -> Path:
# Runs the feature-creation script twice and returns the output directory of
# the second run. The first run (--use_mmseqs2=True, --save_msa_files=True)
# writes into source_dir. Its A3M files, plus any *.mmseq_ids.json sidecars,
# are copied into precomputed_dir, where the second run is started with
# --use_precomputed_msas=True.
source_dir = self.output_dir / "issue_588_mmseq_source_features"
precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
source_dir.mkdir(parents=True, exist_ok=True)
precomputed_dir.mkdir(parents=True, exist_ok=True)
# Comma-separated FASTA paths, one per entry in ISSUE_588_IDS.
fasta_paths = ",".join(
str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
for protein_id in self.ISSUE_588_IDS
)
# First run: generate features and keep the MSA files on disk.
source_res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_create_features),
f"--fasta_paths={fasta_paths}",
f"--output_dir={source_dir}",
f"--data_dir={DATA_DIR}",
"--max_template_date=2024-05-02",
"--use_mmseqs2=True",
"--data_pipeline=alphafold2",
"--save_msa_files=True",
"--compress_features=True",
"--skip_existing=False",
]
)
self.assertEqual(
source_res.returncode,
0,
"MMseqs source feature generation failed.\n"
f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
)
# Every protein must have produced an A3M and a compressed pickle; the A3M
# (and the identifier sidecar, when one exists) seeds the precomputed run.
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(source_dir / f"{protein_id}.a3m").is_file(),
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
)
self.assertTrue(
(source_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
)
shutil.copy2(
source_dir / f"{protein_id}.a3m",
precomputed_dir / f"{protein_id}.a3m",
)
sidecar = source_dir / f"{protein_id}.mmseq_ids.json"
if sidecar.is_file():
shutil.copy2(sidecar, precomputed_dir / sidecar.name)
# Second run: same inputs, but reading the copied A3Ms via
# --use_precomputed_msas=True (no --save_msa_files flag this time).
precomputed_res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_create_features),
f"--fasta_paths={fasta_paths}",
f"--output_dir={precomputed_dir}",
f"--data_dir={DATA_DIR}",
"--max_template_date=2024-05-02",
"--use_mmseqs2=True",
"--use_precomputed_msas=True",
"--data_pipeline=alphafold2",
"--compress_features=True",
"--skip_existing=False",
]
)
self.assertEqual(
precomputed_res.returncode,
0,
"Precomputed-MMseq feature generation failed.\n"
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
)
# The copied A3Ms must still be present and a feature pickle must now exist
# next to each of them.
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(precomputed_dir / f"{protein_id}.a3m").is_file(),
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
)
self.assertTrue(
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
)
return precomputed_dir
def _resolve_af2_result_dir(self, root: Path) -> Path: def _resolve_af2_result_dir(self, root: Path) -> Path:
if (root / "ranking_debug.json").exists(): if (root / "ranking_debug.json").exists():
return root return root
@@ -642,5 +742,73 @@ class TestMmseqsIssue588Inference(_TestBase):
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}", f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
) )
def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self):
"""Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms."""
# NOTE(review): presumably skips the test when the MMseqs functional
# environment is unavailable; the helper is defined outside this view.
self._require_mmseqs_functional_environment()
feature_dir = self._generate_issue_588_precomputed_mmseq_features()
# Per protein: the metadata must record use_precomputed_msas as true, and
# the pickled features must hold at least one non-empty species identifier
# and one non-empty accession identifier in the `_all_seq` arrays.
for protein_id in self.ISSUE_588_IDS:
metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
self.assertTrue(
_metadata_bool(metadata["other"]["use_precomputed_msas"]),
f"{metadata_path} should record use_precomputed_msas=True",
)
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
self.assertGreater(
_non_empty_identifier_count(
feature_dict["msa_species_identifiers_all_seq"]
),
0,
f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
)
self.assertGreater(
_non_empty_identifier_count(
feature_dict["msa_uniprot_accession_identifiers_all_seq"]
),
0,
f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
)
# Run the single-job prediction script on the two-chain input, reading
# features from the precomputed directory.
prediction_dir = self.output_dir / "af2_precomputed_prediction"
prediction_dir.mkdir(parents=True, exist_ok=True)
res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_single),
"--input=A0ABD7FQG0+P18004",
f"--output_directory={prediction_dir}",
"--num_cycle=1",
"--num_predictions_per_model=1",
"--model_names=model_4_multimer_v3",
f"--data_directory={DATA_DIR}",
f"--features_directory={feature_dir}",
"--random_seed=42",
]
)
self.assertEqual(
res.returncode,
0,
"AF2 inference from precomputed MMseq features failed.\n"
f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
)
# Validate the outputs: a non-empty ranking order, exactly one result
# pickle, and an ipTM above 0.6 inside that pickle.
result_dir = self._resolve_af2_result_dir(prediction_dir)
ranking_payload = json.loads(
(result_dir / "ranking_debug.json").read_text(encoding="utf-8")
)
self.assertTrue(ranking_payload["order"])
result_pickles = sorted(result_dir.glob("result_*.pkl"))
self.assertLen(result_pickles, 1)
with result_pickles[0].open("rb") as handle:
result_payload = pickle.load(handle)
self.assertIn("iptm", result_payload)
self.assertIn("ranking_confidence", result_payload)
self.assertGreater(
result_payload["iptm"],
0.6,
f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}",
)
if __name__ == "__main__": if __name__ == "__main__":
absltest.main() absltest.main()

View File

@@ -20,7 +20,9 @@ import json
import numpy as np import numpy as np
import re import re
import unittest import unittest
from types import SimpleNamespace
from typing import Dict, List, Tuple, Any from typing import Dict, List, Tuple, Any
from unittest import mock
from absl.testing import absltest, parameterized from absl.testing import absltest, parameterized
@@ -151,14 +153,24 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
def _load_feature_dict(feature_path: Path) -> dict[str, Any]: def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
opener = lzma.open if feature_path.suffix == ".xz" else open payload = _load_feature_payload(feature_path)
with opener(feature_path, "rb") as handle:
payload = pickle.load(handle)
if hasattr(payload, "feature_dict"): if hasattr(payload, "feature_dict"):
return payload.feature_dict return payload.feature_dict
return payload return payload
def _load_feature_payload(feature_path: Path) -> Any:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "rb") as handle:
return pickle.load(handle)
def _write_feature_payload(feature_path: Path, payload: Any) -> None:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "wb") as handle:
pickle.dump(payload, handle)
def _non_empty_identifier_count(values) -> int: def _non_empty_identifier_count(values) -> int:
count = 0 count = 0
for value in values: for value in values:
@@ -1407,6 +1419,106 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
fold_input_obj, _ = next(iter(mappings[0].items())) fold_input_obj, _ = next(iter(mappings[0].items()))
return fold_input_obj return fold_input_obj
def _copy_real_feature_fixture(
self,
*,
source_dir: Path,
protein_id: str,
target_dir: Path,
) -> Path:
# Copies the fixture files for protein_id (feature pickle, optional A3M and
# feature-metadata JSON) from source_dir into target_dir and returns the
# path of the copied feature pickle. Fails the test if no pickle was found.
copied_feature_path = None
for pattern in (
f"{protein_id}.pkl",
f"{protein_id}.pkl.xz",
f"{protein_id}.a3m",
f"{protein_id}_feature_metadata_*.json*",
):
for source_path in sorted(source_dir.glob(pattern)):
target_path = target_dir / source_path.name
shutil.copy2(source_path, target_path)
# Matches both "<id>.pkl" and "<id>.pkl.xz"; when both exist, the one
# copied last is the one returned.
if source_path.name.startswith(f"{protein_id}.pkl"):
copied_feature_path = target_path
self.assertIsNotNone(
copied_feature_path,
f"Missing real feature fixture for {protein_id} in {source_dir}",
)
return copied_feature_path
@staticmethod
def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray:
identifiers = []
for index, value in enumerate(species_ids):
if isinstance(value, bytes):
value = value.decode("utf-8")
identifiers.append(
f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b""
)
return np.asarray(identifiers, dtype=object)
def _prepare_mixed_identifier_fixture_dir(self) -> Path:
"""Materialize real AF2 fixtures with mixed identifier enrichment.
The underlying MSA rows come from repo fixtures in `test/test_data`.
We only adjust the identifier sidecars so one chain looks enriched while
the other reproduces the "no species enrichment / no accession IDs"
failure mode from issue #614's AF3 follow-up comment.
"""
# Returns the directory holding the two rewritten feature pickles.
feature_dir = self.output_dir / "mixed_identifier_features"
feature_dir.mkdir(parents=True, exist_ok=True)
source_dir = self.test_features_dir / "af2_features" / "protein"
enriched_feature_path = self._copy_real_feature_fixture(
source_dir=source_dir,
protein_id="A0A024R1R8",
target_dir=feature_dir,
)
unenriched_feature_path = self._copy_real_feature_fixture(
source_dir=source_dir,
protein_id="P61626",
target_dir=feature_dir,
)
# Enriched chain (A0A024R1R8): overwrite both accession arrays with
# synthetic IDs derived from the matching species arrays, then re-pickle.
# The payload may be a plain dict or an object exposing `.feature_dict`.
enriched_payload = _load_feature_payload(enriched_feature_path)
enriched_feature_dict = (
enriched_payload.feature_dict
if hasattr(enriched_payload, "feature_dict")
else enriched_payload
)
enriched_feature_dict["msa_uniprot_accession_identifiers"] = (
self._synthetic_accession_ids(
np.asarray(enriched_feature_dict["msa_species_identifiers"])
)
)
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] = (
self._synthetic_accession_ids(
np.asarray(enriched_feature_dict["msa_species_identifiers_all_seq"])
)
)
_write_feature_payload(enriched_feature_path, enriched_payload)
# Unenriched chain (P61626): replace the species arrays with one empty byte
# string per MSA row and remove both accession arrays, then re-pickle.
unenriched_payload = _load_feature_payload(unenriched_feature_path)
unenriched_feature_dict = (
unenriched_payload.feature_dict
if hasattr(unenriched_payload, "feature_dict")
else unenriched_payload
)
unenriched_feature_dict["msa_species_identifiers"] = np.asarray(
[b""] * int(np.asarray(unenriched_feature_dict["msa"]).shape[0]),
dtype=object,
)
unenriched_feature_dict["msa_species_identifiers_all_seq"] = np.asarray(
[b""] * int(np.asarray(unenriched_feature_dict["msa_all_seq"]).shape[0]),
dtype=object,
)
unenriched_feature_dict.pop("msa_uniprot_accession_identifiers", None)
unenriched_feature_dict.pop(
"msa_uniprot_accession_identifiers_all_seq", None
)
_write_feature_payload(unenriched_feature_path, unenriched_payload)
return feature_dir
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self): def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures.""" """Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
@@ -1638,6 +1750,117 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
all(template["mmcif"] for template in protein_entries[0]["templates"]) all(template["mmcif"] for template in protein_entries[0]["templates"])
) )
def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self):
"""AF3 prep should tolerate a real mixed-enrichment multimer feature set."""
from alphapulldown.folding_backend.alphafold3_backend import (
AlphaFold3Backend,
process_fold_input,
)
from alphapulldown.scripts import run_structure_prediction
feature_dir = self._prepare_mixed_identifier_fixture_dir()
# Sanity-check the fixture before exercising the pipeline: the enriched
# chain has non-empty `_all_seq` accession IDs, while the unenriched chain
# has no non-empty species IDs and no `_all_seq` accession key at all.
# NOTE(review): both lookups assume the fixtures are uncompressed ".pkl".
enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl")
self.assertGreater(
_non_empty_identifier_count(
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"]
),
0,
)
unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl")
self.assertEqual(
_non_empty_identifier_count(
unenriched_feature_dict["msa_species_identifiers_all_seq"]
),
0,
)
self.assertNotIn(
"msa_uniprot_accession_identifiers_all_seq",
unenriched_feature_dict,
)
# Stand-in object that is patched over run_structure_prediction.FLAGS below.
script_flags = SimpleNamespace(
pair_msa=True,
multimeric_template=False,
description_file=None,
path_to_mmt=None,
threshold_clashes=1000,
hb_allowance=0.4,
plddt_threshold=0,
save_features_for_multimeric_object=False,
features_directory=[str(feature_dir)],
use_ap_style=False,
)
# Drive the script's own helpers (parse -> custom info -> interactors ->
# pre-modelling setup) for the two-chain job "A0A024R1R8+P61626".
with mock.patch.object(run_structure_prediction, "FLAGS", script_flags):
parsed = run_structure_prediction.parse_fold(
["A0A024R1R8+P61626"],
[str(feature_dir)],
"+",
)
data = run_structure_prediction.create_custom_info(parsed)
all_interactors = run_structure_prediction.create_interactors(
data,
[str(feature_dir)],
)
self.assertLen(all_interactors, 1)
self.assertLen(all_interactors[0], 2)
object_to_model, prepared_output_dir = (
run_structure_prediction.pre_modelling_setup(
all_interactors[0],
output_dir=str(self.output_dir / "mixed_identifier_prediction"),
)
)
# Build the AF3 fold input from the prepared object.
mappings = AlphaFold3Backend.prepare_input(
objects_to_model=[
{"object": object_to_model, "output_dir": prepared_output_dir}
],
random_seed=42,
debug_msas=True,
)
self.assertLen(mappings, 1)
fold_input_obj, (
prepared_output_dir,
resolve_msa_overlaps,
) = next(iter(mappings[0].items()))
# model_runner=None is passed, so no model runner object is supplied.
# NOTE(review): presumably this restricts the call to input processing and
# file writing without inference; confirm against process_fold_input.
process_fold_input(
fold_input=fold_input_obj,
model_runner=None,
output_dir=prepared_output_dir,
buckets=(512,),
resolve_msa_overlaps=resolve_msa_overlaps,
)
# The translation summary must exist, list two chains and flag the
# unpaired rows as valid.
job_name = fold_input_obj.sanitised_name()
summary_path = (
Path(prepared_output_dir)
/ f"{job_name}_af2_to_af3_translation_summary.json"
)
self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
summary = json.loads(summary_path.read_text(encoding="utf-8"))
self.assertLen(summary["chains"], 2)
self.assertTrue(summary["unpaired_rows_valid"])
# The written AF3 input JSON must contain protein entries "A" and "B" whose
# sequence and unpaired-MSA query row both equal the chain sequence.
input_json = Path(prepared_output_dir) / f"{job_name}_data.json"
self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}")
written = json.loads(input_json.read_text(encoding="utf-8"))
protein_entries = {
protein_entry["id"]: protein_entry
for protein_entry in _protein_entries_from_af3_input(written)
}
self.assertEqual(set(protein_entries), {"A", "B"})
for chain in fold_input_obj.chains:
# Chains without a `sequence` attribute are skipped.
if not hasattr(chain, "sequence"):
continue
protein_entry = protein_entries[chain.id]
self.assertEqual(protein_entry["sequence"], chain.sequence)
self.assertEqual(
_a3m_query_sequence(protein_entry["unpairedMsa"]),
chain.sequence,
)
class TestAlphaFold3MmseqsIssue588Inference(_TestBase): class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
"""Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features.""" """Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""

View File

@@ -0,0 +1,4 @@
>sp|P04737|PIL1_ECOLI Pilin OS=Escherichia coli (strain K12) OX=83333 GN=traA PE=1 SV=1
MNAVLSVQGASAPVKKKSFFSKFTRLNMLRLARAVIPAAVLMMFFPQLAMAAGSSGQDLM
ASGNTTVKATFGKDSSVVKWVVLAEVLVGAVMYMMTKNVKFLAGFAIISVFIAVGMAVVG
L

View File

@@ -0,0 +1,9 @@
>sp|P15069|TRAH1_ECOLI Protein TraH OS=Escherichia coli (strain K12) OX=83333 GN=traH PE=3 SV=2
MMPRIKPLLVLCAALLTVTPAASADVNSDMNQFFNKLGFASNTTQPGVWQGQAAGYAYGG
SLYARTQVKNVQLISMTLPDINAGCGGIDAYLGSFSFINGEQLQRFVKQIMSNAAGYFFD
LALQTTVPEIKTAKDFLQKMASDINSMNLSSCQAAQGIIGGLFPRTQVSQQKVCQDIAGE
SNIFADWAASRQGCTVGGKSDSVRDKASDKDKERVTKNINIMWNALSKNRMFDGNKELKE
FVMTLTGSLVFGPNGEITPLSARTTDRSIIRAMMEGGTAKISHCNDSDKCLKVVADTPVT
ISRDNALKSQITKLLASIQNKAVSDTPLDDKEKGFISSTTIPVFKYLVDPQMLGVSNSMI
YQLTDYIGYDILLQYIQELIQQARAMVATGNYDEAVIGHINDNMNDATRQIAAFQSQVQV
QQDALLVVDRQMSYMRQQLSARMLSRYQNNYHFGGSTL

View File

@@ -1,4 +1,5 @@
import json import json
import os
from urllib import error from urllib import error
import numpy as np import numpy as np
@@ -11,6 +12,32 @@ from alphapulldown.objects import MonomericObject
from alphapulldown.utils import mmseqs_species_identifiers from alphapulldown.utils import mmseqs_species_identifiers
# FASTA fixture directory: <parent of this file's directory>/test_data/fastas.
_FASTAS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'test_data',
'fastas',
)
def _read_uniprot_fasta(uniprot_id: str) -> str:
    """Return the sequence stored in ``<_FASTAS_DIR>/<uniprot_id>.fasta``.

    The first line of the file is always skipped. Of the remaining lines,
    empty ones and those starting with ``>`` are dropped, and the rest are
    concatenated without separators.
    """
    fasta_path = os.path.join(_FASTAS_DIR, f'{uniprot_id}.fasta')
    with open(fasta_path, encoding='utf-8') as handle:
        records = handle.read().splitlines()
    residues = []
    for row in records[1:]:
        if not row or row.startswith('>'):
            continue
        residues.append(row)
    return ''.join(residues)
def _build_colabfold_server_a3m(
query_sequence: str, hits: list[tuple[str, str]]
) -> str:
"""Assemble a ColabFold-server-style A3M ('#<len>\\t1' header, one-line rows)."""
lines = [f'#{len(query_sequence)}\t1', '>101', query_sequence]
for header, aligned in hits:
lines.append(f'>{header}')
lines.append(aligned)
lines.append('')
return '\n'.join(lines)
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_species_id_cache(): def clear_species_id_cache():
mmseqs_species_identifiers._SPECIES_ID_CACHE.clear() mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
@@ -419,7 +446,7 @@ def test_make_mmseq_features_researches_templates_for_precomputed_msa(
'template_feature': 'TEMPLATE_FROM_RESEARCH', 'template_feature': 'TEMPLATE_FROM_RESEARCH',
} }
assert calls['enrich_mmseq_feature_dict_with_identifiers'] == { assert calls['enrich_mmseq_feature_dict_with_identifiers'] == {
'a3m': '>101\nACDE', 'a3m': 'PRECOMPUTED_UNPAIRED',
'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')}, 'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')},
} }
assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray) assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray)
@@ -579,3 +606,110 @@ def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
} }
assert uniprot_calls == [('A0A636IKY3',)] assert uniprot_calls == [('A0A636IKY3',)]
assert uniparc_calls == [('UPI001118B830',)] assert uniparc_calls == [('UPI001118B830',)]
# Both parameter sets carry the same accession -> species mapping; only the
# query protein (and therefore the FASTA file that is read) differs.
@pytest.mark.parametrize(
'uniprot_id,accession_species',
[
(
'P04737',
{
'A0A636IKY3': '562',
'A0A743YDY2': '573',
'UPI001118B830': '562',
},
),
(
'P15069',
{
'A0A636IKY3': '562',
'A0A743YDY2': '573',
'UPI001118B830': '562',
},
),
],
)
def test_make_mmseq_features_precomputed_colabfold_a3m_enriches_identifiers(
monkeypatch, tmp_path, uniprot_id, accession_species
):
"""Regression for issue #613: precomputed ColabFold-server A3Ms must enrich.
Before the fix, make_mmseq_features parsed the raw A3M for identifiers but
fed a different processed string to build_monomer_feature, so dedup rules
disagreed and identifier rows did not match MSA rows. This drove
enrich_mmseq_feature_dict_with_identifiers to log a warning and skip
enrichment, leaving species pairing unusable.
"""
query = _read_uniprot_fasta(uniprot_id)
assert len(query) > 0
# Realistic ColabFold-server-style A3M: '#<len>\t1' header, one header/seq per
# line pair, a mix of exact-duplicate hits, insertion-variant hits (lowercase
# letters) that strip to the query, an all-gap row, and point-mutation hits
# with real-format UniProt accessions so enrichment has something to resolve.
hits = [
(f'sp|{uniprot_id}|QUERY_DUP', query),
('UniRef100_A0A636IKY3', query[:10] + 'a' + query[10:]),
('UniRef100_A0A743YDY2', query[:15] + 'bc' + query[15:]),
('UniRef100_UPI001118B830', query[:20] + 'def' + query[20:]),
('UniRef100_A0A100XYZ0', query[:5] + 'X' + query[6:]),
('UniRef100_A0A200ABC5', query[:8] + 'Y' + query[9:]),
('UniRef100_GAP_ROW', '-' * len(query)),
('UniRef100_ALL_LOWER', 'a' * len(query)),
]
# Extend the parametrized mapping with species for the two point-mutation hits.
accession_species = {
**accession_species,
'A0A100XYZ0': '9606',
'A0A200ABC5': '10090',
}
# Write the A3M under tmp_path, which is also passed as output_dir below.
a3m = _build_colabfold_server_a3m(query, hits)
precomputed_a3m = tmp_path / f'{uniprot_id}.a3m'
precomputed_a3m.write_text(a3m, encoding='utf-8')
# unzip_msa_files is stubbed to always return False.
monkeypatch.setattr(
MonomericObject, 'unzip_msa_files', staticmethod(lambda _path: False)
)
# The species resolver is replaced by a lookup into accession_species;
# accessions missing from the mapping resolve to ''.
monkeypatch.setattr(
mmseqs_species_identifiers,
'resolve_species_ids_by_accession',
lambda accessions, **_: {
accession: accession_species.get(accession, '')
for accession in accessions
},
)
monomer = MonomericObject(uniprot_id, query)
monomer.make_mmseq_features(
DEFAULT_API_SERVER='https://unused.example',
output_dir=str(tmp_path),
use_precomputed_msa=True,
use_templates=False,
)
# Identifier arrays must have exactly one entry per MSA row.
msa = monomer.feature_dict['msa']
species = monomer.feature_dict['msa_species_identifiers']
accessions = monomer.feature_dict['msa_uniprot_accession_identifiers']
assert species.shape[0] == msa.shape[0], (
f'enrichment row count {species.shape[0]} != msa rows {msa.shape[0]}'
)
assert accessions.shape[0] == msa.shape[0]
# '_all_seq' mirrors the enriched rows, used for pairing downstream.
assert (
monomer.feature_dict['msa_species_identifiers_all_seq'].shape[0]
== msa.shape[0]
)
# The resolver is called with the real UniProt-format accessions only —
# insertion-variant hits collapse onto the query row but the point-mutation
# hit survives, so at least one of the resolvable accessions lands in the
# deduped identifier rows.
resolvable = {
a.decode('utf-8')
for a in accessions.tolist()
if a.decode('utf-8') in accession_species
}
assert resolvable, (
f'expected at least one resolvable accession in {accessions.tolist()}'
)

View File

@@ -397,7 +397,7 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE") assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE")
assert calls["enrich"] == { assert calls["enrich"] == {
"a3m": ">101\nACDE\n>hit\nAC-E", "a3m": "UNPAIRED",
"kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")}, "kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")},
} }
assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101") assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101")
@@ -421,7 +421,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
def fake_get_msa_and_templates(**kwargs): def fake_get_msa_and_templates(**kwargs):
calls["get_msa_and_templates"] = kwargs calls["get_msa_and_templates"] = kwargs
return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"]) return ([">101\nACDE\n"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates) monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
monkeypatch.setattr( monkeypatch.setattr(
@@ -462,7 +462,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
assert calls["get_msa_and_templates"]["pair_mode"] == "none" assert calls["get_msa_and_templates"]["pair_mode"] == "none"
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"] assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
assert calls["get_msa_and_templates"]["use_templates"] is True assert calls["get_msa_and_templates"]["use_templates"] is True
assert calls["enrich"]["a3m"] == ">101\nACDE" assert calls["enrich"]["a3m"] == ">101\nACDE\n"
assert monomer.skip_msa is True assert monomer.skip_msa is True
assert monomer.feature_dict["msa"].shape == (1, 4) assert monomer.feature_dict["msa"].shape == (1, 4)
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4) assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
@@ -757,7 +757,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
objects_mod, objects_mod,
"get_msa_and_templates", "get_msa_and_templates",
lambda **_kwargs: ( lambda **_kwargs: (
["UNPAIRED"], [a3m_text],
["PAIRED"], ["PAIRED"],
["UNIQUE"], ["UNIQUE"],
["CARD"], ["CARD"],
@@ -774,7 +774,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
objects_mod, objects_mod,
"unserialize_msa", "unserialize_msa",
lambda a3m_lines, sequence: ( lambda a3m_lines, sequence: (
["PRECOMP_MSA"], [a3m_text],
["PRECOMP_PAIRED"], ["PRECOMP_PAIRED"],
["UNIQUE"], ["UNIQUE"],
["CARD"], ["CARD"],
@@ -1417,6 +1417,81 @@ def test_pair_and_merge_pairs_and_deduplicates_for_heteromer(monkeypatch):
assert output == {"processed": {"merged": True}} assert output == {"processed": {"merged": True}}
def test_pair_and_merge_backfills_missing_all_seq_accession_ids_before_pairing(
monkeypatch,
):
# Checks that pair_and_merge fills in accession identifier arrays for a
# chain that lacks them before the chains reach create_paired_features.
# __new__ creates the object without running MultimericObject.__init__.
multimer = MultimericObject.__new__(MultimericObject)
multimer.pair_msa = True
calls = {}
chain_a = _feature_dict(sequence="ACDE", msa_rows=1, all_seq_rows=2, template_count=0)
chain_b = _feature_dict(sequence="FGHI", msa_rows=1, all_seq_rows=2, template_count=0)
chain_a["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
chain_b["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
# Only chain_a is given accession identifier arrays here; the assertions at
# the end expect chain_b's to be backfilled with empty byte strings.
chain_a["msa_uniprot_accession_identifiers"] = np.asarray([b"P12345"], dtype=object)
chain_a["msa_uniprot_accession_identifiers_all_seq"] = np.asarray(
[b"", b"P12345"],
dtype=object,
)
# Keep a handle on the real pairing function so the wrapper can delegate to it.
real_create_paired_features = objects_mod.msa_pairing.create_paired_features
monkeypatch.setattr(
objects_mod.feature_processing,
"process_unmerged_features",
lambda _features: None,
)
# Force the heteromer path regardless of the chain contents.
monkeypatch.setattr(
objects_mod.feature_processing,
"_is_homomer_or_monomer",
lambda _chains: False,
)
def wrapped_create_paired_features(*, chains):
# Record the chains that were passed in, then run the real implementation.
calls["create_paired_features"] = chains
return real_create_paired_features(chains)
monkeypatch.setattr(
objects_mod.msa_pairing,
"create_paired_features",
wrapped_create_paired_features,
)
# The remaining pipeline stages are replaced with pass-through stubs.
monkeypatch.setattr(
objects_mod.msa_pairing,
"deduplicate_unpaired_sequences",
lambda chains: chains,
)
monkeypatch.setattr(
objects_mod.feature_processing,
"crop_chains",
lambda chains, **kwargs: chains,
)
monkeypatch.setattr(
objects_mod.msa_pairing,
"merge_chain_features",
lambda **kwargs: {"chains": kwargs["np_chains_list"]},
)
monkeypatch.setattr(
objects_mod.feature_processing,
"process_final",
lambda example: example,
)
output = multimer.pair_and_merge({"A": chain_a, "B": chain_b})
# Index 1 is chain_b: one empty accession for its single MSA row and two
# for its two `_all_seq` rows, both at pairing time and in the final output.
assert calls["create_paired_features"][1][
"msa_uniprot_accession_identifiers"
].tolist() == [b""]
assert calls["create_paired_features"][1][
"msa_uniprot_accession_identifiers_all_seq"
].tolist() == [b"", b""]
assert output["chains"][1]["msa_uniprot_accession_identifiers"].tolist() == [b""]
assert output["chains"][1]["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
b"",
b"",
]
def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch): def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch):
multimer = MultimericObject.__new__(MultimericObject) multimer = MultimericObject.__new__(MultimericObject)
multimer.pair_msa = False multimer.pair_msa = False

View File

@@ -135,3 +135,46 @@ def test_all_seq_msa_features_keeps_only_pairing_related_keys(monkeypatch, tmp_p
True, True,
) )
assert run_kwargs == {} assert run_kwargs == {}
def test_all_seq_msa_features_backfills_missing_uniprot_accession_identifiers(
monkeypatch, tmp_path
):
# Checks that all_seq_msa_features returns an accession identifier array
# even when make_msa_features does not produce one.
monomer = MonomericObject("desc", "ACDE")
input_fasta_path = str(tmp_path / "input.fasta")
Path(input_fasta_path).write_text(">x\nACDE\n", encoding="utf-8")
class FakeMsa:
# Minimal stand-in for the parsed MSA; also checks the truncation limit.
def truncate(self, max_seqs):
assert max_seqs == 50000
return self
# Stub the MSA tool and the Stockholm parser so no external tool is run.
monkeypatch.setattr(
"alphapulldown.objects.pipeline.run_msa_tool",
lambda *args, **kwargs: {"sto": "fake"},
)
monkeypatch.setattr(
"alphapulldown.objects.parsers.parse_stockholm",
lambda sto: FakeMsa(),
)
# The stubbed feature builder returns two MSA rows with species identifiers
# but deliberately omits "msa_uniprot_accession_identifiers".
monkeypatch.setattr(
"alphapulldown.objects.pipeline.make_msa_features",
lambda _msas: {
"msa": np.asarray([[1, 2], [1, 3]], dtype=np.int32),
"msa_species_identifiers": np.asarray([b"", b"9606"], dtype=object),
"deletion_matrix_int": np.asarray([[0, 0], [0, 0]], dtype=np.int32),
},
)
features = monomer.all_seq_msa_features(
input_fasta_path=input_fasta_path,
uniprot_msa_runner="runner",
output_dir=str(tmp_path),
use_precomputed_msa=False,
)
# Species identifiers pass through under the `_all_seq` suffix, and the
# missing accession array is backfilled with one empty byte string per row.
assert features["msa_species_identifiers_all_seq"].tolist() == [b"", b"9606"]
assert features["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
b"",
b"",
]