Harden precomputed MSA identifier enrichment

This commit is contained in:
Dima Molodenskiy
2026-04-23 09:52:54 +02:00
committed by Dima
parent ce4be4866b
commit fb09717cf3
8 changed files with 741 additions and 18 deletions

View File

@@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
)
def _ensure_identifier_feature_arrays(
feature_dict: Dict[str, np.ndarray],
feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
) -> Dict[str, np.ndarray]:
"""Backfill missing identifier arrays to match the corresponding MSA rows."""
normalized = dict(feature_dict)
for msa_key, identifier_keys in feature_groups:
msa = normalized.get(msa_key)
if msa is None:
continue
num_rows = int(np.asarray(msa).shape[0])
for key in identifier_keys:
if key not in normalized:
normalized[key] = np.array([b""] * num_rows, dtype=object)
return normalized
class MonomericObject:
"""
monomeric objects
@@ -145,7 +162,18 @@ class MonomericObject:
msa = parsers.parse_stockholm(result["sto"])
msa = msa.truncate(max_seqs=50000)
all_seq_features = pipeline.make_msa_features([msa])
all_seq_features = _ensure_identifier_feature_arrays(
pipeline.make_msa_features([msa]),
(
(
"msa",
(
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
),
),
),
)
valid_feats = msa_pairing.MSA_FEATURES + (
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
@@ -362,9 +390,12 @@ class MonomericObject:
# Remove header lines starting with '#' if present.
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
# Enrich from the same A3M string that build_monomer_feature parsed, so
# the identifier rows go through the same parse_a3m dedup as msa_features
# and their count matches feature_dict['msa'] exactly.
enrich_mmseq_feature_dict_with_identifiers(
self.feature_dict,
a3m_lines[0],
unpaired_msa[0],
cache_path=os.path.join(
result_dir, f"{self.description}.mmseq_ids.json"
),
@@ -811,12 +842,48 @@ create_individual_features.py
output_list.append(new_chain)
return output_list
@staticmethod
def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]:
"""Ensure identifier arrays exist consistently across chains.
Some feature sources provide species identifiers but omit UniProt
accession IDs, while DeepMind's multimer pairing and merge code assumes
both unpaired and `_all_seq` identifier keys exist consistently across
chains.
"""
output_list = []
for feat_dict in np_chain_list:
output_list.append(
_ensure_identifier_feature_arrays(
feat_dict,
(
(
"msa",
(
"msa_species_identifiers",
"msa_uniprot_accession_identifiers",
),
),
(
"msa_all_seq",
(
"msa_species_identifiers_all_seq",
"msa_uniprot_accession_identifiers_all_seq",
),
),
),
)
)
return output_list
def pair_and_merge(self, all_chain_features):
"""merge all chain features"""
feature_processing.process_unmerged_features(all_chain_features)
MAX_TEMPLATES = 4
MSA_CROP_SIZE = 2048
np_chains_list = list(all_chain_features.values())
np_chains_list = MultimericObject.normalize_all_seq_identifier_features(
list(all_chain_features.values())
)
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
np_chains_list)
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")

View File

@@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict:
return payload
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]:
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
if len(matches) != 1:
raise FileNotFoundError(
f"Expected one feature metadata file for {protein_id} in {feature_dir}, "
f"found {matches}"
)
metadata_path = matches[0]
opener = lzma.open if metadata_path.suffix == ".xz" else open
with opener(metadata_path, "rt", encoding="utf-8") as handle:
return metadata_path, json.load(handle)
def _metadata_bool(value) -> bool:
return str(value).strip().lower() in {"1", "true", "yes"}
def _non_empty_identifier_count(values) -> int:
count = 0
for value in values:
@@ -105,11 +122,12 @@ def _non_empty_identifier_count(values) -> int:
def _af2_subprocess_env() -> dict[str, str]:
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
env = os.environ.copy()
env.setdefault("OMP_NUM_THREADS", "4")
env.setdefault("MKL_NUM_THREADS", "4")
env.setdefault("NUMEXPR_NUM_THREADS", "4")
env.setdefault("TF_NUM_INTEROP_THREADS", "4")
env.setdefault("TF_NUM_INTRAOP_THREADS", "4")
env.setdefault("OMP_NUM_THREADS", "1")
env.setdefault("OPENBLAS_NUM_THREADS", "1")
env.setdefault("MKL_NUM_THREADS", "1")
env.setdefault("NUMEXPR_NUM_THREADS", "1")
env.setdefault("TF_NUM_INTEROP_THREADS", "1")
env.setdefault("TF_NUM_INTRAOP_THREADS", "1")
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
@@ -117,7 +135,8 @@ def _af2_subprocess_env() -> dict[str, str]:
env.setdefault("JAX_PLATFORM_NAME", "gpu")
env.setdefault(
"XLA_FLAGS",
"--xla_gpu_force_compilation_parallelism=0 "
"--xla_gpu_force_compilation_parallelism=1 "
"--xla_force_host_platform_device_count=1 "
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
)
return env
@@ -544,6 +563,87 @@ class TestMmseqsIssue588Inference(_TestBase):
)
return feature_dir
def _generate_issue_588_precomputed_mmseq_features(self) -> Path:
source_dir = self.output_dir / "issue_588_mmseq_source_features"
precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
source_dir.mkdir(parents=True, exist_ok=True)
precomputed_dir.mkdir(parents=True, exist_ok=True)
fasta_paths = ",".join(
str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
for protein_id in self.ISSUE_588_IDS
)
source_res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_create_features),
f"--fasta_paths={fasta_paths}",
f"--output_dir={source_dir}",
f"--data_dir={DATA_DIR}",
"--max_template_date=2024-05-02",
"--use_mmseqs2=True",
"--data_pipeline=alphafold2",
"--save_msa_files=True",
"--compress_features=True",
"--skip_existing=False",
]
)
self.assertEqual(
source_res.returncode,
0,
"MMseqs source feature generation failed.\n"
f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
)
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(source_dir / f"{protein_id}.a3m").is_file(),
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
)
self.assertTrue(
(source_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
)
shutil.copy2(
source_dir / f"{protein_id}.a3m",
precomputed_dir / f"{protein_id}.a3m",
)
sidecar = source_dir / f"{protein_id}.mmseq_ids.json"
if sidecar.is_file():
shutil.copy2(sidecar, precomputed_dir / sidecar.name)
precomputed_res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_create_features),
f"--fasta_paths={fasta_paths}",
f"--output_dir={precomputed_dir}",
f"--data_dir={DATA_DIR}",
"--max_template_date=2024-05-02",
"--use_mmseqs2=True",
"--use_precomputed_msas=True",
"--data_pipeline=alphafold2",
"--compress_features=True",
"--skip_existing=False",
]
)
self.assertEqual(
precomputed_res.returncode,
0,
"Precomputed-MMseq feature generation failed.\n"
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
)
for protein_id in self.ISSUE_588_IDS:
self.assertTrue(
(precomputed_dir / f"{protein_id}.a3m").is_file(),
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
)
self.assertTrue(
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
)
return precomputed_dir
def _resolve_af2_result_dir(self, root: Path) -> Path:
if (root / "ranking_debug.json").exists():
return root
@@ -642,5 +742,73 @@ class TestMmseqsIssue588Inference(_TestBase):
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
)
def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self):
"""Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms."""
self._require_mmseqs_functional_environment()
feature_dir = self._generate_issue_588_precomputed_mmseq_features()
for protein_id in self.ISSUE_588_IDS:
metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
self.assertTrue(
_metadata_bool(metadata["other"]["use_precomputed_msas"]),
f"{metadata_path} should record use_precomputed_msas=True",
)
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
self.assertGreater(
_non_empty_identifier_count(
feature_dict["msa_species_identifiers_all_seq"]
),
0,
f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
)
self.assertGreater(
_non_empty_identifier_count(
feature_dict["msa_uniprot_accession_identifiers_all_seq"]
),
0,
f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
)
prediction_dir = self.output_dir / "af2_precomputed_prediction"
prediction_dir.mkdir(parents=True, exist_ok=True)
res = self._run_prediction_subprocess(
[
sys.executable,
str(self.script_single),
"--input=A0ABD7FQG0+P18004",
f"--output_directory={prediction_dir}",
"--num_cycle=1",
"--num_predictions_per_model=1",
"--model_names=model_4_multimer_v3",
f"--data_directory={DATA_DIR}",
f"--features_directory={feature_dir}",
"--random_seed=42",
]
)
self.assertEqual(
res.returncode,
0,
"AF2 inference from precomputed MMseq features failed.\n"
f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
)
result_dir = self._resolve_af2_result_dir(prediction_dir)
ranking_payload = json.loads(
(result_dir / "ranking_debug.json").read_text(encoding="utf-8")
)
self.assertTrue(ranking_payload["order"])
result_pickles = sorted(result_dir.glob("result_*.pkl"))
self.assertLen(result_pickles, 1)
with result_pickles[0].open("rb") as handle:
result_payload = pickle.load(handle)
self.assertIn("iptm", result_payload)
self.assertIn("ranking_confidence", result_payload)
self.assertGreater(
result_payload["iptm"],
0.6,
f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}",
)
if __name__ == "__main__":
absltest.main()

View File

@@ -20,7 +20,9 @@ import json
import numpy as np
import re
import unittest
from types import SimpleNamespace
from typing import Dict, List, Tuple, Any
from unittest import mock
from absl.testing import absltest, parameterized
@@ -151,14 +153,24 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "rb") as handle:
payload = pickle.load(handle)
payload = _load_feature_payload(feature_path)
if hasattr(payload, "feature_dict"):
return payload.feature_dict
return payload
def _load_feature_payload(feature_path: Path) -> Any:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "rb") as handle:
return pickle.load(handle)
def _write_feature_payload(feature_path: Path, payload: Any) -> None:
opener = lzma.open if feature_path.suffix == ".xz" else open
with opener(feature_path, "wb") as handle:
pickle.dump(payload, handle)
def _non_empty_identifier_count(values) -> int:
count = 0
for value in values:
@@ -1407,6 +1419,106 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
fold_input_obj, _ = next(iter(mappings[0].items()))
return fold_input_obj
def _copy_real_feature_fixture(
self,
*,
source_dir: Path,
protein_id: str,
target_dir: Path,
) -> Path:
copied_feature_path = None
for pattern in (
f"{protein_id}.pkl",
f"{protein_id}.pkl.xz",
f"{protein_id}.a3m",
f"{protein_id}_feature_metadata_*.json*",
):
for source_path in sorted(source_dir.glob(pattern)):
target_path = target_dir / source_path.name
shutil.copy2(source_path, target_path)
if source_path.name.startswith(f"{protein_id}.pkl"):
copied_feature_path = target_path
self.assertIsNotNone(
copied_feature_path,
f"Missing real feature fixture for {protein_id} in {source_dir}",
)
return copied_feature_path
@staticmethod
def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray:
identifiers = []
for index, value in enumerate(species_ids):
if isinstance(value, bytes):
value = value.decode("utf-8")
identifiers.append(
f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b""
)
return np.asarray(identifiers, dtype=object)
def _prepare_mixed_identifier_fixture_dir(self) -> Path:
"""Materialize real AF2 fixtures with mixed identifier enrichment.
The underlying MSA rows come from repo fixtures in `test/test_data`.
We only adjust the identifier sidecars so one chain looks enriched while
the other reproduces the "no species enrichment / no accession IDs"
failure mode from issue #614's AF3 follow-up comment.
"""
feature_dir = self.output_dir / "mixed_identifier_features"
feature_dir.mkdir(parents=True, exist_ok=True)
source_dir = self.test_features_dir / "af2_features" / "protein"
enriched_feature_path = self._copy_real_feature_fixture(
source_dir=source_dir,
protein_id="A0A024R1R8",
target_dir=feature_dir,
)
unenriched_feature_path = self._copy_real_feature_fixture(
source_dir=source_dir,
protein_id="P61626",
target_dir=feature_dir,
)
enriched_payload = _load_feature_payload(enriched_feature_path)
enriched_feature_dict = (
enriched_payload.feature_dict
if hasattr(enriched_payload, "feature_dict")
else enriched_payload
)
enriched_feature_dict["msa_uniprot_accession_identifiers"] = (
self._synthetic_accession_ids(
np.asarray(enriched_feature_dict["msa_species_identifiers"])
)
)
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] = (
self._synthetic_accession_ids(
np.asarray(enriched_feature_dict["msa_species_identifiers_all_seq"])
)
)
_write_feature_payload(enriched_feature_path, enriched_payload)
unenriched_payload = _load_feature_payload(unenriched_feature_path)
unenriched_feature_dict = (
unenriched_payload.feature_dict
if hasattr(unenriched_payload, "feature_dict")
else unenriched_payload
)
unenriched_feature_dict["msa_species_identifiers"] = np.asarray(
[b""] * int(np.asarray(unenriched_feature_dict["msa"]).shape[0]),
dtype=object,
)
unenriched_feature_dict["msa_species_identifiers_all_seq"] = np.asarray(
[b""] * int(np.asarray(unenriched_feature_dict["msa_all_seq"]).shape[0]),
dtype=object,
)
unenriched_feature_dict.pop("msa_uniprot_accession_identifiers", None)
unenriched_feature_dict.pop(
"msa_uniprot_accession_identifiers_all_seq", None
)
_write_feature_payload(unenriched_feature_path, unenriched_payload)
return feature_dir
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
@@ -1638,6 +1750,117 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
all(template["mmcif"] for template in protein_entries[0]["templates"])
)
def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self):
"""AF3 prep should tolerate a real mixed-enrichment multimer feature set."""
from alphapulldown.folding_backend.alphafold3_backend import (
AlphaFold3Backend,
process_fold_input,
)
from alphapulldown.scripts import run_structure_prediction
feature_dir = self._prepare_mixed_identifier_fixture_dir()
enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl")
self.assertGreater(
_non_empty_identifier_count(
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"]
),
0,
)
unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl")
self.assertEqual(
_non_empty_identifier_count(
unenriched_feature_dict["msa_species_identifiers_all_seq"]
),
0,
)
self.assertNotIn(
"msa_uniprot_accession_identifiers_all_seq",
unenriched_feature_dict,
)
script_flags = SimpleNamespace(
pair_msa=True,
multimeric_template=False,
description_file=None,
path_to_mmt=None,
threshold_clashes=1000,
hb_allowance=0.4,
plddt_threshold=0,
save_features_for_multimeric_object=False,
features_directory=[str(feature_dir)],
use_ap_style=False,
)
with mock.patch.object(run_structure_prediction, "FLAGS", script_flags):
parsed = run_structure_prediction.parse_fold(
["A0A024R1R8+P61626"],
[str(feature_dir)],
"+",
)
data = run_structure_prediction.create_custom_info(parsed)
all_interactors = run_structure_prediction.create_interactors(
data,
[str(feature_dir)],
)
self.assertLen(all_interactors, 1)
self.assertLen(all_interactors[0], 2)
object_to_model, prepared_output_dir = (
run_structure_prediction.pre_modelling_setup(
all_interactors[0],
output_dir=str(self.output_dir / "mixed_identifier_prediction"),
)
)
mappings = AlphaFold3Backend.prepare_input(
objects_to_model=[
{"object": object_to_model, "output_dir": prepared_output_dir}
],
random_seed=42,
debug_msas=True,
)
self.assertLen(mappings, 1)
fold_input_obj, (
prepared_output_dir,
resolve_msa_overlaps,
) = next(iter(mappings[0].items()))
process_fold_input(
fold_input=fold_input_obj,
model_runner=None,
output_dir=prepared_output_dir,
buckets=(512,),
resolve_msa_overlaps=resolve_msa_overlaps,
)
job_name = fold_input_obj.sanitised_name()
summary_path = (
Path(prepared_output_dir)
/ f"{job_name}_af2_to_af3_translation_summary.json"
)
self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
summary = json.loads(summary_path.read_text(encoding="utf-8"))
self.assertLen(summary["chains"], 2)
self.assertTrue(summary["unpaired_rows_valid"])
input_json = Path(prepared_output_dir) / f"{job_name}_data.json"
self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}")
written = json.loads(input_json.read_text(encoding="utf-8"))
protein_entries = {
protein_entry["id"]: protein_entry
for protein_entry in _protein_entries_from_af3_input(written)
}
self.assertEqual(set(protein_entries), {"A", "B"})
for chain in fold_input_obj.chains:
if not hasattr(chain, "sequence"):
continue
protein_entry = protein_entries[chain.id]
self.assertEqual(protein_entry["sequence"], chain.sequence)
self.assertEqual(
_a3m_query_sequence(protein_entry["unpairedMsa"]),
chain.sequence,
)
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
"""Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""

View File

@@ -0,0 +1,4 @@
>sp|P04737|PIL1_ECOLI Pilin OS=Escherichia coli (strain K12) OX=83333 GN=traA PE=1 SV=1
MNAVLSVQGASAPVKKKSFFSKFTRLNMLRLARAVIPAAVLMMFFPQLAMAAGSSGQDLM
ASGNTTVKATFGKDSSVVKWVVLAEVLVGAVMYMMTKNVKFLAGFAIISVFIAVGMAVVG
L

View File

@@ -0,0 +1,9 @@
>sp|P15069|TRAH1_ECOLI Protein TraH OS=Escherichia coli (strain K12) OX=83333 GN=traH PE=3 SV=2
MMPRIKPLLVLCAALLTVTPAASADVNSDMNQFFNKLGFASNTTQPGVWQGQAAGYAYGG
SLYARTQVKNVQLISMTLPDINAGCGGIDAYLGSFSFINGEQLQRFVKQIMSNAAGYFFD
LALQTTVPEIKTAKDFLQKMASDINSMNLSSCQAAQGIIGGLFPRTQVSQQKVCQDIAGE
SNIFADWAASRQGCTVGGKSDSVRDKASDKDKERVTKNINIMWNALSKNRMFDGNKELKE
FVMTLTGSLVFGPNGEITPLSARTTDRSIIRAMMEGGTAKISHCNDSDKCLKVVADTPVT
ISRDNALKSQITKLLASIQNKAVSDTPLDDKEKGFISSTTIPVFKYLVDPQMLGVSNSMI
YQLTDYIGYDILLQYIQELIQQARAMVATGNYDEAVIGHINDNMNDATRQIAAFQSQVQV
QQDALLVVDRQMSYMRQQLSARMLSRYQNNYHFGGSTL

View File

@@ -1,4 +1,5 @@
import json
import os
from urllib import error
import numpy as np
@@ -11,6 +12,32 @@ from alphapulldown.objects import MonomericObject
from alphapulldown.utils import mmseqs_species_identifiers
_FASTAS_DIR = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
'test_data',
'fastas',
)
def _read_uniprot_fasta(uniprot_id: str) -> str:
path = os.path.join(_FASTAS_DIR, f'{uniprot_id}.fasta')
with open(path, encoding='utf-8') as handle:
lines = handle.read().splitlines()
return ''.join(line for line in lines[1:] if line and not line.startswith('>'))
def _build_colabfold_server_a3m(
query_sequence: str, hits: list[tuple[str, str]]
) -> str:
"""Assemble a ColabFold-server-style A3M ('#<len>\\t1' header, one-line rows)."""
lines = [f'#{len(query_sequence)}\t1', '>101', query_sequence]
for header, aligned in hits:
lines.append(f'>{header}')
lines.append(aligned)
lines.append('')
return '\n'.join(lines)
@pytest.fixture(autouse=True)
def clear_species_id_cache():
mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
@@ -419,7 +446,7 @@ def test_make_mmseq_features_researches_templates_for_precomputed_msa(
'template_feature': 'TEMPLATE_FROM_RESEARCH',
}
assert calls['enrich_mmseq_feature_dict_with_identifiers'] == {
'a3m': '>101\nACDE',
'a3m': 'PRECOMPUTED_UNPAIRED',
'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')},
}
assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray)
@@ -579,3 +606,110 @@ def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
}
assert uniprot_calls == [('A0A636IKY3',)]
assert uniparc_calls == [('UPI001118B830',)]
@pytest.mark.parametrize(
'uniprot_id,accession_species',
[
(
'P04737',
{
'A0A636IKY3': '562',
'A0A743YDY2': '573',
'UPI001118B830': '562',
},
),
(
'P15069',
{
'A0A636IKY3': '562',
'A0A743YDY2': '573',
'UPI001118B830': '562',
},
),
],
)
def test_make_mmseq_features_precomputed_colabfold_a3m_enriches_identifiers(
monkeypatch, tmp_path, uniprot_id, accession_species
):
"""Regression for issue #613: precomputed ColabFold-server A3Ms must enrich.
Before the fix, make_mmseq_features parsed the raw A3M for identifiers but
fed a different processed string to build_monomer_feature, so dedup rules
disagreed and identifier rows did not match MSA rows. This drove
enrich_mmseq_feature_dict_with_identifiers to log a warning and skip
enrichment, leaving species pairing unusable.
"""
query = _read_uniprot_fasta(uniprot_id)
assert len(query) > 0
# Realistic ColabFold-server-style A3M: '#<len>\t1' header, one header/seq per
# line pair, a mix of exact-duplicate hits, insertion-variant hits (lowercase
# letters) that strip to the query, an all-gap row, and point-mutation hits
# with real-format UniProt accessions so enrichment has something to resolve.
hits = [
(f'sp|{uniprot_id}|QUERY_DUP', query),
('UniRef100_A0A636IKY3', query[:10] + 'a' + query[10:]),
('UniRef100_A0A743YDY2', query[:15] + 'bc' + query[15:]),
('UniRef100_UPI001118B830', query[:20] + 'def' + query[20:]),
('UniRef100_A0A100XYZ0', query[:5] + 'X' + query[6:]),
('UniRef100_A0A200ABC5', query[:8] + 'Y' + query[9:]),
('UniRef100_GAP_ROW', '-' * len(query)),
('UniRef100_ALL_LOWER', 'a' * len(query)),
]
accession_species = {
**accession_species,
'A0A100XYZ0': '9606',
'A0A200ABC5': '10090',
}
a3m = _build_colabfold_server_a3m(query, hits)
precomputed_a3m = tmp_path / f'{uniprot_id}.a3m'
precomputed_a3m.write_text(a3m, encoding='utf-8')
monkeypatch.setattr(
MonomericObject, 'unzip_msa_files', staticmethod(lambda _path: False)
)
monkeypatch.setattr(
mmseqs_species_identifiers,
'resolve_species_ids_by_accession',
lambda accessions, **_: {
accession: accession_species.get(accession, '')
for accession in accessions
},
)
monomer = MonomericObject(uniprot_id, query)
monomer.make_mmseq_features(
DEFAULT_API_SERVER='https://unused.example',
output_dir=str(tmp_path),
use_precomputed_msa=True,
use_templates=False,
)
msa = monomer.feature_dict['msa']
species = monomer.feature_dict['msa_species_identifiers']
accessions = monomer.feature_dict['msa_uniprot_accession_identifiers']
assert species.shape[0] == msa.shape[0], (
f'enrichment row count {species.shape[0]} != msa rows {msa.shape[0]}'
)
assert accessions.shape[0] == msa.shape[0]
# '_all_seq' mirrors the enriched rows, used for pairing downstream.
assert (
monomer.feature_dict['msa_species_identifiers_all_seq'].shape[0]
== msa.shape[0]
)
# The resolver is called with the real UniProt-format accessions only —
# insertion-variant hits collapse onto the query row but the point-mutation
# hit survives, so at least one of the resolvable accessions lands in the
# deduped identifier rows.
resolvable = {
a.decode('utf-8')
for a in accessions.tolist()
if a.decode('utf-8') in accession_species
}
assert resolvable, (
f'expected at least one resolvable accession in {accessions.tolist()}'
)

View File

@@ -397,7 +397,7 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE")
assert calls["enrich"] == {
"a3m": ">101\nACDE\n>hit\nAC-E",
"a3m": "UNPAIRED",
"kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")},
}
assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101")
@@ -421,7 +421,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
def fake_get_msa_and_templates(**kwargs):
calls["get_msa_and_templates"] = kwargs
return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
return ([">101\nACDE\n"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
monkeypatch.setattr(
@@ -462,7 +462,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
assert calls["get_msa_and_templates"]["pair_mode"] == "none"
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
assert calls["get_msa_and_templates"]["use_templates"] is True
assert calls["enrich"]["a3m"] == ">101\nACDE"
assert calls["enrich"]["a3m"] == ">101\nACDE\n"
assert monomer.skip_msa is True
assert monomer.feature_dict["msa"].shape == (1, 4)
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
@@ -757,7 +757,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
objects_mod,
"get_msa_and_templates",
lambda **_kwargs: (
["UNPAIRED"],
[a3m_text],
["PAIRED"],
["UNIQUE"],
["CARD"],
@@ -774,7 +774,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
objects_mod,
"unserialize_msa",
lambda a3m_lines, sequence: (
["PRECOMP_MSA"],
[a3m_text],
["PRECOMP_PAIRED"],
["UNIQUE"],
["CARD"],
@@ -1417,6 +1417,81 @@ def test_pair_and_merge_pairs_and_deduplicates_for_heteromer(monkeypatch):
assert output == {"processed": {"merged": True}}
def test_pair_and_merge_backfills_missing_all_seq_accession_ids_before_pairing(
monkeypatch,
):
multimer = MultimericObject.__new__(MultimericObject)
multimer.pair_msa = True
calls = {}
chain_a = _feature_dict(sequence="ACDE", msa_rows=1, all_seq_rows=2, template_count=0)
chain_b = _feature_dict(sequence="FGHI", msa_rows=1, all_seq_rows=2, template_count=0)
chain_a["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
chain_b["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
chain_a["msa_uniprot_accession_identifiers"] = np.asarray([b"P12345"], dtype=object)
chain_a["msa_uniprot_accession_identifiers_all_seq"] = np.asarray(
[b"", b"P12345"],
dtype=object,
)
real_create_paired_features = objects_mod.msa_pairing.create_paired_features
monkeypatch.setattr(
objects_mod.feature_processing,
"process_unmerged_features",
lambda _features: None,
)
monkeypatch.setattr(
objects_mod.feature_processing,
"_is_homomer_or_monomer",
lambda _chains: False,
)
def wrapped_create_paired_features(*, chains):
calls["create_paired_features"] = chains
return real_create_paired_features(chains)
monkeypatch.setattr(
objects_mod.msa_pairing,
"create_paired_features",
wrapped_create_paired_features,
)
monkeypatch.setattr(
objects_mod.msa_pairing,
"deduplicate_unpaired_sequences",
lambda chains: chains,
)
monkeypatch.setattr(
objects_mod.feature_processing,
"crop_chains",
lambda chains, **kwargs: chains,
)
monkeypatch.setattr(
objects_mod.msa_pairing,
"merge_chain_features",
lambda **kwargs: {"chains": kwargs["np_chains_list"]},
)
monkeypatch.setattr(
objects_mod.feature_processing,
"process_final",
lambda example: example,
)
output = multimer.pair_and_merge({"A": chain_a, "B": chain_b})
assert calls["create_paired_features"][1][
"msa_uniprot_accession_identifiers"
].tolist() == [b""]
assert calls["create_paired_features"][1][
"msa_uniprot_accession_identifiers_all_seq"
].tolist() == [b"", b""]
assert output["chains"][1]["msa_uniprot_accession_identifiers"].tolist() == [b""]
assert output["chains"][1]["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
b"",
b"",
]
def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch):
multimer = MultimericObject.__new__(MultimericObject)
multimer.pair_msa = False

View File

@@ -135,3 +135,46 @@ def test_all_seq_msa_features_keeps_only_pairing_related_keys(monkeypatch, tmp_p
True,
)
assert run_kwargs == {}
def test_all_seq_msa_features_backfills_missing_uniprot_accession_identifiers(
monkeypatch, tmp_path
):
monomer = MonomericObject("desc", "ACDE")
input_fasta_path = str(tmp_path / "input.fasta")
Path(input_fasta_path).write_text(">x\nACDE\n", encoding="utf-8")
class FakeMsa:
def truncate(self, max_seqs):
assert max_seqs == 50000
return self
monkeypatch.setattr(
"alphapulldown.objects.pipeline.run_msa_tool",
lambda *args, **kwargs: {"sto": "fake"},
)
monkeypatch.setattr(
"alphapulldown.objects.parsers.parse_stockholm",
lambda sto: FakeMsa(),
)
monkeypatch.setattr(
"alphapulldown.objects.pipeline.make_msa_features",
lambda _msas: {
"msa": np.asarray([[1, 2], [1, 3]], dtype=np.int32),
"msa_species_identifiers": np.asarray([b"", b"9606"], dtype=object),
"deletion_matrix_int": np.asarray([[0, 0], [0, 0]], dtype=np.int32),
},
)
features = monomer.all_seq_msa_features(
input_fasta_path=input_fasta_path,
uniprot_msa_runner="runner",
output_dir=str(tmp_path),
use_precomputed_msa=False,
)
assert features["msa_species_identifiers_all_seq"].tolist() == [b"", b"9606"]
assert features["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
b"",
b"",
]