mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Harden precomputed MSA identifier enrichment
This commit is contained in:
@@ -43,6 +43,23 @@ def _query_only_stockholm(sequence: str, query_id: str = "query") -> str:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_identifier_feature_arrays(
|
||||||
|
feature_dict: Dict[str, np.ndarray],
|
||||||
|
feature_groups: Tuple[Tuple[str, Tuple[str, ...]], ...],
|
||||||
|
) -> Dict[str, np.ndarray]:
|
||||||
|
"""Backfill missing identifier arrays to match the corresponding MSA rows."""
|
||||||
|
normalized = dict(feature_dict)
|
||||||
|
for msa_key, identifier_keys in feature_groups:
|
||||||
|
msa = normalized.get(msa_key)
|
||||||
|
if msa is None:
|
||||||
|
continue
|
||||||
|
num_rows = int(np.asarray(msa).shape[0])
|
||||||
|
for key in identifier_keys:
|
||||||
|
if key not in normalized:
|
||||||
|
normalized[key] = np.array([b""] * num_rows, dtype=object)
|
||||||
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
class MonomericObject:
|
class MonomericObject:
|
||||||
"""
|
"""
|
||||||
monomeric objects
|
monomeric objects
|
||||||
@@ -145,7 +162,18 @@ class MonomericObject:
|
|||||||
|
|
||||||
msa = parsers.parse_stockholm(result["sto"])
|
msa = parsers.parse_stockholm(result["sto"])
|
||||||
msa = msa.truncate(max_seqs=50000)
|
msa = msa.truncate(max_seqs=50000)
|
||||||
all_seq_features = pipeline.make_msa_features([msa])
|
all_seq_features = _ensure_identifier_feature_arrays(
|
||||||
|
pipeline.make_msa_features([msa]),
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"msa",
|
||||||
|
(
|
||||||
|
"msa_species_identifiers",
|
||||||
|
"msa_uniprot_accession_identifiers",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
valid_feats = msa_pairing.MSA_FEATURES + (
|
valid_feats = msa_pairing.MSA_FEATURES + (
|
||||||
"msa_species_identifiers",
|
"msa_species_identifiers",
|
||||||
"msa_uniprot_accession_identifiers",
|
"msa_uniprot_accession_identifiers",
|
||||||
@@ -362,9 +390,12 @@ class MonomericObject:
|
|||||||
# Remove header lines starting with '#' if present.
|
# Remove header lines starting with '#' if present.
|
||||||
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
|
a3m_lines[0] = strip_mmseq_comment_lines(a3m_lines[0])
|
||||||
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
|
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
|
||||||
|
# Enrich from the same A3M string that build_monomer_feature parsed, so
|
||||||
|
# the identifier rows go through the same parse_a3m dedup as msa_features
|
||||||
|
# and their count matches feature_dict['msa'] exactly.
|
||||||
enrich_mmseq_feature_dict_with_identifiers(
|
enrich_mmseq_feature_dict_with_identifiers(
|
||||||
self.feature_dict,
|
self.feature_dict,
|
||||||
a3m_lines[0],
|
unpaired_msa[0],
|
||||||
cache_path=os.path.join(
|
cache_path=os.path.join(
|
||||||
result_dir, f"{self.description}.mmseq_ids.json"
|
result_dir, f"{self.description}.mmseq_ids.json"
|
||||||
),
|
),
|
||||||
@@ -811,12 +842,48 @@ create_individual_features.py
|
|||||||
output_list.append(new_chain)
|
output_list.append(new_chain)
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def normalize_all_seq_identifier_features(np_chain_list: List[Dict]) -> List[Dict]:
|
||||||
|
"""Ensure identifier arrays exist consistently across chains.
|
||||||
|
|
||||||
|
Some feature sources provide species identifiers but omit UniProt
|
||||||
|
accession IDs, while DeepMind's multimer pairing and merge code assumes
|
||||||
|
both unpaired and `_all_seq` identifier keys exist consistently across
|
||||||
|
chains.
|
||||||
|
"""
|
||||||
|
output_list = []
|
||||||
|
for feat_dict in np_chain_list:
|
||||||
|
output_list.append(
|
||||||
|
_ensure_identifier_feature_arrays(
|
||||||
|
feat_dict,
|
||||||
|
(
|
||||||
|
(
|
||||||
|
"msa",
|
||||||
|
(
|
||||||
|
"msa_species_identifiers",
|
||||||
|
"msa_uniprot_accession_identifiers",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"msa_all_seq",
|
||||||
|
(
|
||||||
|
"msa_species_identifiers_all_seq",
|
||||||
|
"msa_uniprot_accession_identifiers_all_seq",
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return output_list
|
||||||
|
|
||||||
def pair_and_merge(self, all_chain_features):
|
def pair_and_merge(self, all_chain_features):
|
||||||
"""merge all chain features"""
|
"""merge all chain features"""
|
||||||
feature_processing.process_unmerged_features(all_chain_features)
|
feature_processing.process_unmerged_features(all_chain_features)
|
||||||
MAX_TEMPLATES = 4
|
MAX_TEMPLATES = 4
|
||||||
MSA_CROP_SIZE = 2048
|
MSA_CROP_SIZE = 2048
|
||||||
np_chains_list = list(all_chain_features.values())
|
np_chains_list = MultimericObject.normalize_all_seq_identifier_features(
|
||||||
|
list(all_chain_features.values())
|
||||||
|
)
|
||||||
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
|
pair_msa_sequences = self.pair_msa and not feature_processing._is_homomer_or_monomer(
|
||||||
np_chains_list)
|
np_chains_list)
|
||||||
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")
|
logging.debug(f"pair_msa_sequences is type : {type(pair_msa_sequences)} value: {pair_msa_sequences}")
|
||||||
|
|||||||
@@ -92,6 +92,23 @@ def _load_feature_dict(feature_path: Path) -> dict:
|
|||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _load_feature_metadata(feature_dir: Path, protein_id: str) -> tuple[Path, dict]:
|
||||||
|
matches = sorted(feature_dir.glob(f"{protein_id}_feature_metadata_*.json*"))
|
||||||
|
if len(matches) != 1:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"Expected one feature metadata file for {protein_id} in {feature_dir}, "
|
||||||
|
f"found {matches}"
|
||||||
|
)
|
||||||
|
metadata_path = matches[0]
|
||||||
|
opener = lzma.open if metadata_path.suffix == ".xz" else open
|
||||||
|
with opener(metadata_path, "rt", encoding="utf-8") as handle:
|
||||||
|
return metadata_path, json.load(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def _metadata_bool(value) -> bool:
|
||||||
|
return str(value).strip().lower() in {"1", "true", "yes"}
|
||||||
|
|
||||||
|
|
||||||
def _non_empty_identifier_count(values) -> int:
|
def _non_empty_identifier_count(values) -> int:
|
||||||
count = 0
|
count = 0
|
||||||
for value in values:
|
for value in values:
|
||||||
@@ -105,11 +122,12 @@ def _non_empty_identifier_count(values) -> int:
|
|||||||
def _af2_subprocess_env() -> dict[str, str]:
|
def _af2_subprocess_env() -> dict[str, str]:
|
||||||
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
|
"""Return stable GPU/JAX defaults for AF2 functional subprocesses."""
|
||||||
env = os.environ.copy()
|
env = os.environ.copy()
|
||||||
env.setdefault("OMP_NUM_THREADS", "4")
|
env.setdefault("OMP_NUM_THREADS", "1")
|
||||||
env.setdefault("MKL_NUM_THREADS", "4")
|
env.setdefault("OPENBLAS_NUM_THREADS", "1")
|
||||||
env.setdefault("NUMEXPR_NUM_THREADS", "4")
|
env.setdefault("MKL_NUM_THREADS", "1")
|
||||||
env.setdefault("TF_NUM_INTEROP_THREADS", "4")
|
env.setdefault("NUMEXPR_NUM_THREADS", "1")
|
||||||
env.setdefault("TF_NUM_INTRAOP_THREADS", "4")
|
env.setdefault("TF_NUM_INTEROP_THREADS", "1")
|
||||||
|
env.setdefault("TF_NUM_INTRAOP_THREADS", "1")
|
||||||
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
|
env.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
|
||||||
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
env.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
|
||||||
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
env.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
|
||||||
@@ -117,7 +135,8 @@ def _af2_subprocess_env() -> dict[str, str]:
|
|||||||
env.setdefault("JAX_PLATFORM_NAME", "gpu")
|
env.setdefault("JAX_PLATFORM_NAME", "gpu")
|
||||||
env.setdefault(
|
env.setdefault(
|
||||||
"XLA_FLAGS",
|
"XLA_FLAGS",
|
||||||
"--xla_gpu_force_compilation_parallelism=0 "
|
"--xla_gpu_force_compilation_parallelism=1 "
|
||||||
|
"--xla_force_host_platform_device_count=1 "
|
||||||
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
|
"--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1",
|
||||||
)
|
)
|
||||||
return env
|
return env
|
||||||
@@ -544,6 +563,87 @@ class TestMmseqsIssue588Inference(_TestBase):
|
|||||||
)
|
)
|
||||||
return feature_dir
|
return feature_dir
|
||||||
|
|
||||||
|
def _generate_issue_588_precomputed_mmseq_features(self) -> Path:
|
||||||
|
source_dir = self.output_dir / "issue_588_mmseq_source_features"
|
||||||
|
precomputed_dir = self.output_dir / "issue_588_mmseq_precomputed_features"
|
||||||
|
source_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
precomputed_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
fasta_paths = ",".join(
|
||||||
|
str(self.test_data_dir / "fastas" / f"{protein_id}.fasta")
|
||||||
|
for protein_id in self.ISSUE_588_IDS
|
||||||
|
)
|
||||||
|
|
||||||
|
source_res = self._run_prediction_subprocess(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(self.script_create_features),
|
||||||
|
f"--fasta_paths={fasta_paths}",
|
||||||
|
f"--output_dir={source_dir}",
|
||||||
|
f"--data_dir={DATA_DIR}",
|
||||||
|
"--max_template_date=2024-05-02",
|
||||||
|
"--use_mmseqs2=True",
|
||||||
|
"--data_pipeline=alphafold2",
|
||||||
|
"--save_msa_files=True",
|
||||||
|
"--compress_features=True",
|
||||||
|
"--skip_existing=False",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
source_res.returncode,
|
||||||
|
0,
|
||||||
|
"MMseqs source feature generation failed.\n"
|
||||||
|
f"STDOUT:\n{source_res.stdout}\nSTDERR:\n{source_res.stderr}",
|
||||||
|
)
|
||||||
|
|
||||||
|
for protein_id in self.ISSUE_588_IDS:
|
||||||
|
self.assertTrue(
|
||||||
|
(source_dir / f"{protein_id}.a3m").is_file(),
|
||||||
|
f"Expected MMseq A3M {source_dir / f'{protein_id}.a3m'} to be created.",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
(source_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||||
|
f"Expected compressed feature pickle {source_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||||
|
)
|
||||||
|
shutil.copy2(
|
||||||
|
source_dir / f"{protein_id}.a3m",
|
||||||
|
precomputed_dir / f"{protein_id}.a3m",
|
||||||
|
)
|
||||||
|
sidecar = source_dir / f"{protein_id}.mmseq_ids.json"
|
||||||
|
if sidecar.is_file():
|
||||||
|
shutil.copy2(sidecar, precomputed_dir / sidecar.name)
|
||||||
|
|
||||||
|
precomputed_res = self._run_prediction_subprocess(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(self.script_create_features),
|
||||||
|
f"--fasta_paths={fasta_paths}",
|
||||||
|
f"--output_dir={precomputed_dir}",
|
||||||
|
f"--data_dir={DATA_DIR}",
|
||||||
|
"--max_template_date=2024-05-02",
|
||||||
|
"--use_mmseqs2=True",
|
||||||
|
"--use_precomputed_msas=True",
|
||||||
|
"--data_pipeline=alphafold2",
|
||||||
|
"--compress_features=True",
|
||||||
|
"--skip_existing=False",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
precomputed_res.returncode,
|
||||||
|
0,
|
||||||
|
"Precomputed-MMseq feature generation failed.\n"
|
||||||
|
f"STDOUT:\n{precomputed_res.stdout}\nSTDERR:\n{precomputed_res.stderr}",
|
||||||
|
)
|
||||||
|
for protein_id in self.ISSUE_588_IDS:
|
||||||
|
self.assertTrue(
|
||||||
|
(precomputed_dir / f"{protein_id}.a3m").is_file(),
|
||||||
|
f"Expected copied MMseq A3M {precomputed_dir / f'{protein_id}.a3m'} to be present.",
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
(precomputed_dir / f"{protein_id}.pkl.xz").is_file(),
|
||||||
|
f"Expected precomputed feature pickle {precomputed_dir / f'{protein_id}.pkl.xz'} to be created.",
|
||||||
|
)
|
||||||
|
return precomputed_dir
|
||||||
|
|
||||||
def _resolve_af2_result_dir(self, root: Path) -> Path:
|
def _resolve_af2_result_dir(self, root: Path) -> Path:
|
||||||
if (root / "ranking_debug.json").exists():
|
if (root / "ranking_debug.json").exists():
|
||||||
return root
|
return root
|
||||||
@@ -642,5 +742,73 @@ class TestMmseqsIssue588Inference(_TestBase):
|
|||||||
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
|
f"Expected AF2 ipTM > 0.6, got {result_payload['iptm']}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_issue_614_precomputed_mmseqs_features_enable_af2_multimer_inference(self):
|
||||||
|
"""Issue #614 regression: AF2 should fold successfully from precomputed MMseq A3Ms."""
|
||||||
|
self._require_mmseqs_functional_environment()
|
||||||
|
feature_dir = self._generate_issue_588_precomputed_mmseq_features()
|
||||||
|
|
||||||
|
for protein_id in self.ISSUE_588_IDS:
|
||||||
|
metadata_path, metadata = _load_feature_metadata(feature_dir, protein_id)
|
||||||
|
self.assertTrue(
|
||||||
|
_metadata_bool(metadata["other"]["use_precomputed_msas"]),
|
||||||
|
f"{metadata_path} should record use_precomputed_msas=True",
|
||||||
|
)
|
||||||
|
feature_dict = _load_feature_dict(feature_dir / f"{protein_id}.pkl.xz")
|
||||||
|
self.assertGreater(
|
||||||
|
_non_empty_identifier_count(
|
||||||
|
feature_dict["msa_species_identifiers_all_seq"]
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
f"{protein_id} should keep recovered species IDs from cached MMseq A3Ms",
|
||||||
|
)
|
||||||
|
self.assertGreater(
|
||||||
|
_non_empty_identifier_count(
|
||||||
|
feature_dict["msa_uniprot_accession_identifiers_all_seq"]
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
f"{protein_id} should keep recovered accession IDs from cached MMseq A3Ms",
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction_dir = self.output_dir / "af2_precomputed_prediction"
|
||||||
|
prediction_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
res = self._run_prediction_subprocess(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
str(self.script_single),
|
||||||
|
"--input=A0ABD7FQG0+P18004",
|
||||||
|
f"--output_directory={prediction_dir}",
|
||||||
|
"--num_cycle=1",
|
||||||
|
"--num_predictions_per_model=1",
|
||||||
|
"--model_names=model_4_multimer_v3",
|
||||||
|
f"--data_directory={DATA_DIR}",
|
||||||
|
f"--features_directory={feature_dir}",
|
||||||
|
"--random_seed=42",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
res.returncode,
|
||||||
|
0,
|
||||||
|
"AF2 inference from precomputed MMseq features failed.\n"
|
||||||
|
f"STDOUT:\n{res.stdout}\nSTDERR:\n{res.stderr}",
|
||||||
|
)
|
||||||
|
|
||||||
|
result_dir = self._resolve_af2_result_dir(prediction_dir)
|
||||||
|
ranking_payload = json.loads(
|
||||||
|
(result_dir / "ranking_debug.json").read_text(encoding="utf-8")
|
||||||
|
)
|
||||||
|
self.assertTrue(ranking_payload["order"])
|
||||||
|
|
||||||
|
result_pickles = sorted(result_dir.glob("result_*.pkl"))
|
||||||
|
self.assertLen(result_pickles, 1)
|
||||||
|
with result_pickles[0].open("rb") as handle:
|
||||||
|
result_payload = pickle.load(handle)
|
||||||
|
self.assertIn("iptm", result_payload)
|
||||||
|
self.assertIn("ranking_confidence", result_payload)
|
||||||
|
self.assertGreater(
|
||||||
|
result_payload["iptm"],
|
||||||
|
0.6,
|
||||||
|
f"Expected AF2 ipTM > 0.6 from precomputed MMseq features, got {result_payload['iptm']}",
|
||||||
|
)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
absltest.main()
|
absltest.main()
|
||||||
|
|||||||
@@ -20,7 +20,9 @@ import json
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import re
|
import re
|
||||||
import unittest
|
import unittest
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Dict, List, Tuple, Any
|
from typing import Dict, List, Tuple, Any
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
from absl.testing import absltest, parameterized
|
from absl.testing import absltest, parameterized
|
||||||
|
|
||||||
@@ -151,14 +153,24 @@ def _non_empty_a3m_payload_rows(a3m_text: str) -> list[str]:
|
|||||||
|
|
||||||
|
|
||||||
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
|
def _load_feature_dict(feature_path: Path) -> dict[str, Any]:
|
||||||
opener = lzma.open if feature_path.suffix == ".xz" else open
|
payload = _load_feature_payload(feature_path)
|
||||||
with opener(feature_path, "rb") as handle:
|
|
||||||
payload = pickle.load(handle)
|
|
||||||
if hasattr(payload, "feature_dict"):
|
if hasattr(payload, "feature_dict"):
|
||||||
return payload.feature_dict
|
return payload.feature_dict
|
||||||
return payload
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def _load_feature_payload(feature_path: Path) -> Any:
|
||||||
|
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||||
|
with opener(feature_path, "rb") as handle:
|
||||||
|
return pickle.load(handle)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_feature_payload(feature_path: Path, payload: Any) -> None:
|
||||||
|
opener = lzma.open if feature_path.suffix == ".xz" else open
|
||||||
|
with opener(feature_path, "wb") as handle:
|
||||||
|
pickle.dump(payload, handle)
|
||||||
|
|
||||||
|
|
||||||
def _non_empty_identifier_count(values) -> int:
|
def _non_empty_identifier_count(values) -> int:
|
||||||
count = 0
|
count = 0
|
||||||
for value in values:
|
for value in values:
|
||||||
@@ -1407,6 +1419,106 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
|||||||
fold_input_obj, _ = next(iter(mappings[0].items()))
|
fold_input_obj, _ = next(iter(mappings[0].items()))
|
||||||
return fold_input_obj
|
return fold_input_obj
|
||||||
|
|
||||||
|
def _copy_real_feature_fixture(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
source_dir: Path,
|
||||||
|
protein_id: str,
|
||||||
|
target_dir: Path,
|
||||||
|
) -> Path:
|
||||||
|
copied_feature_path = None
|
||||||
|
for pattern in (
|
||||||
|
f"{protein_id}.pkl",
|
||||||
|
f"{protein_id}.pkl.xz",
|
||||||
|
f"{protein_id}.a3m",
|
||||||
|
f"{protein_id}_feature_metadata_*.json*",
|
||||||
|
):
|
||||||
|
for source_path in sorted(source_dir.glob(pattern)):
|
||||||
|
target_path = target_dir / source_path.name
|
||||||
|
shutil.copy2(source_path, target_path)
|
||||||
|
if source_path.name.startswith(f"{protein_id}.pkl"):
|
||||||
|
copied_feature_path = target_path
|
||||||
|
|
||||||
|
self.assertIsNotNone(
|
||||||
|
copied_feature_path,
|
||||||
|
f"Missing real feature fixture for {protein_id} in {source_dir}",
|
||||||
|
)
|
||||||
|
return copied_feature_path
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _synthetic_accession_ids(species_ids: np.ndarray) -> np.ndarray:
|
||||||
|
identifiers = []
|
||||||
|
for index, value in enumerate(species_ids):
|
||||||
|
if isinstance(value, bytes):
|
||||||
|
value = value.decode("utf-8")
|
||||||
|
identifiers.append(
|
||||||
|
f"ACC{index:05d}".encode("utf-8") if str(value).strip() else b""
|
||||||
|
)
|
||||||
|
return np.asarray(identifiers, dtype=object)
|
||||||
|
|
||||||
|
def _prepare_mixed_identifier_fixture_dir(self) -> Path:
|
||||||
|
"""Materialize real AF2 fixtures with mixed identifier enrichment.
|
||||||
|
|
||||||
|
The underlying MSA rows come from repo fixtures in `test/test_data`.
|
||||||
|
We only adjust the identifier sidecars so one chain looks enriched while
|
||||||
|
the other reproduces the "no species enrichment / no accession IDs"
|
||||||
|
failure mode from issue #614's AF3 follow-up comment.
|
||||||
|
"""
|
||||||
|
feature_dir = self.output_dir / "mixed_identifier_features"
|
||||||
|
feature_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
source_dir = self.test_features_dir / "af2_features" / "protein"
|
||||||
|
|
||||||
|
enriched_feature_path = self._copy_real_feature_fixture(
|
||||||
|
source_dir=source_dir,
|
||||||
|
protein_id="A0A024R1R8",
|
||||||
|
target_dir=feature_dir,
|
||||||
|
)
|
||||||
|
unenriched_feature_path = self._copy_real_feature_fixture(
|
||||||
|
source_dir=source_dir,
|
||||||
|
protein_id="P61626",
|
||||||
|
target_dir=feature_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
enriched_payload = _load_feature_payload(enriched_feature_path)
|
||||||
|
enriched_feature_dict = (
|
||||||
|
enriched_payload.feature_dict
|
||||||
|
if hasattr(enriched_payload, "feature_dict")
|
||||||
|
else enriched_payload
|
||||||
|
)
|
||||||
|
enriched_feature_dict["msa_uniprot_accession_identifiers"] = (
|
||||||
|
self._synthetic_accession_ids(
|
||||||
|
np.asarray(enriched_feature_dict["msa_species_identifiers"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"] = (
|
||||||
|
self._synthetic_accession_ids(
|
||||||
|
np.asarray(enriched_feature_dict["msa_species_identifiers_all_seq"])
|
||||||
|
)
|
||||||
|
)
|
||||||
|
_write_feature_payload(enriched_feature_path, enriched_payload)
|
||||||
|
|
||||||
|
unenriched_payload = _load_feature_payload(unenriched_feature_path)
|
||||||
|
unenriched_feature_dict = (
|
||||||
|
unenriched_payload.feature_dict
|
||||||
|
if hasattr(unenriched_payload, "feature_dict")
|
||||||
|
else unenriched_payload
|
||||||
|
)
|
||||||
|
unenriched_feature_dict["msa_species_identifiers"] = np.asarray(
|
||||||
|
[b""] * int(np.asarray(unenriched_feature_dict["msa"]).shape[0]),
|
||||||
|
dtype=object,
|
||||||
|
)
|
||||||
|
unenriched_feature_dict["msa_species_identifiers_all_seq"] = np.asarray(
|
||||||
|
[b""] * int(np.asarray(unenriched_feature_dict["msa_all_seq"]).shape[0]),
|
||||||
|
dtype=object,
|
||||||
|
)
|
||||||
|
unenriched_feature_dict.pop("msa_uniprot_accession_identifiers", None)
|
||||||
|
unenriched_feature_dict.pop(
|
||||||
|
"msa_uniprot_accession_identifiers_all_seq", None
|
||||||
|
)
|
||||||
|
_write_feature_payload(unenriched_feature_path, unenriched_payload)
|
||||||
|
|
||||||
|
return feature_dir
|
||||||
|
|
||||||
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
|
def test_issue_588_mmseqs_af2_features_produce_sane_af3_chain_input_msas(self):
|
||||||
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
|
"""Issue #588 regression: verify AF3 input construction from exact AF2/mmseqs2 pkl fixtures."""
|
||||||
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
|
from alphapulldown.folding_backend.alphafold3_backend import process_fold_input
|
||||||
@@ -1638,6 +1750,117 @@ class TestAlphaFold3BackendRegressions(_BackendOnlyTestBase):
|
|||||||
all(template["mmcif"] for template in protein_entries[0]["templates"])
|
all(template["mmcif"] for template in protein_entries[0]["templates"])
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_af3_real_fixture_pipeline_tolerates_mixed_missing_accession_ids(self):
|
||||||
|
"""AF3 prep should tolerate a real mixed-enrichment multimer feature set."""
|
||||||
|
from alphapulldown.folding_backend.alphafold3_backend import (
|
||||||
|
AlphaFold3Backend,
|
||||||
|
process_fold_input,
|
||||||
|
)
|
||||||
|
from alphapulldown.scripts import run_structure_prediction
|
||||||
|
|
||||||
|
feature_dir = self._prepare_mixed_identifier_fixture_dir()
|
||||||
|
|
||||||
|
enriched_feature_dict = _load_feature_dict(feature_dir / "A0A024R1R8.pkl")
|
||||||
|
self.assertGreater(
|
||||||
|
_non_empty_identifier_count(
|
||||||
|
enriched_feature_dict["msa_uniprot_accession_identifiers_all_seq"]
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
unenriched_feature_dict = _load_feature_dict(feature_dir / "P61626.pkl")
|
||||||
|
self.assertEqual(
|
||||||
|
_non_empty_identifier_count(
|
||||||
|
unenriched_feature_dict["msa_species_identifiers_all_seq"]
|
||||||
|
),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
self.assertNotIn(
|
||||||
|
"msa_uniprot_accession_identifiers_all_seq",
|
||||||
|
unenriched_feature_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
script_flags = SimpleNamespace(
|
||||||
|
pair_msa=True,
|
||||||
|
multimeric_template=False,
|
||||||
|
description_file=None,
|
||||||
|
path_to_mmt=None,
|
||||||
|
threshold_clashes=1000,
|
||||||
|
hb_allowance=0.4,
|
||||||
|
plddt_threshold=0,
|
||||||
|
save_features_for_multimeric_object=False,
|
||||||
|
features_directory=[str(feature_dir)],
|
||||||
|
use_ap_style=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with mock.patch.object(run_structure_prediction, "FLAGS", script_flags):
|
||||||
|
parsed = run_structure_prediction.parse_fold(
|
||||||
|
["A0A024R1R8+P61626"],
|
||||||
|
[str(feature_dir)],
|
||||||
|
"+",
|
||||||
|
)
|
||||||
|
data = run_structure_prediction.create_custom_info(parsed)
|
||||||
|
all_interactors = run_structure_prediction.create_interactors(
|
||||||
|
data,
|
||||||
|
[str(feature_dir)],
|
||||||
|
)
|
||||||
|
self.assertLen(all_interactors, 1)
|
||||||
|
self.assertLen(all_interactors[0], 2)
|
||||||
|
object_to_model, prepared_output_dir = (
|
||||||
|
run_structure_prediction.pre_modelling_setup(
|
||||||
|
all_interactors[0],
|
||||||
|
output_dir=str(self.output_dir / "mixed_identifier_prediction"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mappings = AlphaFold3Backend.prepare_input(
|
||||||
|
objects_to_model=[
|
||||||
|
{"object": object_to_model, "output_dir": prepared_output_dir}
|
||||||
|
],
|
||||||
|
random_seed=42,
|
||||||
|
debug_msas=True,
|
||||||
|
)
|
||||||
|
self.assertLen(mappings, 1)
|
||||||
|
fold_input_obj, (
|
||||||
|
prepared_output_dir,
|
||||||
|
resolve_msa_overlaps,
|
||||||
|
) = next(iter(mappings[0].items()))
|
||||||
|
|
||||||
|
process_fold_input(
|
||||||
|
fold_input=fold_input_obj,
|
||||||
|
model_runner=None,
|
||||||
|
output_dir=prepared_output_dir,
|
||||||
|
buckets=(512,),
|
||||||
|
resolve_msa_overlaps=resolve_msa_overlaps,
|
||||||
|
)
|
||||||
|
|
||||||
|
job_name = fold_input_obj.sanitised_name()
|
||||||
|
summary_path = (
|
||||||
|
Path(prepared_output_dir)
|
||||||
|
/ f"{job_name}_af2_to_af3_translation_summary.json"
|
||||||
|
)
|
||||||
|
self.assertTrue(summary_path.is_file(), f"Missing translation summary {summary_path}")
|
||||||
|
summary = json.loads(summary_path.read_text(encoding="utf-8"))
|
||||||
|
self.assertLen(summary["chains"], 2)
|
||||||
|
self.assertTrue(summary["unpaired_rows_valid"])
|
||||||
|
|
||||||
|
input_json = Path(prepared_output_dir) / f"{job_name}_data.json"
|
||||||
|
self.assertTrue(input_json.is_file(), f"Missing AF3 input JSON {input_json}")
|
||||||
|
written = json.loads(input_json.read_text(encoding="utf-8"))
|
||||||
|
protein_entries = {
|
||||||
|
protein_entry["id"]: protein_entry
|
||||||
|
for protein_entry in _protein_entries_from_af3_input(written)
|
||||||
|
}
|
||||||
|
self.assertEqual(set(protein_entries), {"A", "B"})
|
||||||
|
for chain in fold_input_obj.chains:
|
||||||
|
if not hasattr(chain, "sequence"):
|
||||||
|
continue
|
||||||
|
protein_entry = protein_entries[chain.id]
|
||||||
|
self.assertEqual(protein_entry["sequence"], chain.sequence)
|
||||||
|
self.assertEqual(
|
||||||
|
_a3m_query_sequence(protein_entry["unpairedMsa"]),
|
||||||
|
chain.sequence,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
|
class TestAlphaFold3MmseqsIssue588Inference(_TestBase):
|
||||||
"""Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""
|
"""Opt-in AF3 end-to-end smoke test for freshly regenerated mmseq AF2 features."""
|
||||||
|
|||||||
4
test/test_data/fastas/P04737.fasta
Normal file
4
test/test_data/fastas/P04737.fasta
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
>sp|P04737|PIL1_ECOLI Pilin OS=Escherichia coli (strain K12) OX=83333 GN=traA PE=1 SV=1
|
||||||
|
MNAVLSVQGASAPVKKKSFFSKFTRLNMLRLARAVIPAAVLMMFFPQLAMAAGSSGQDLM
|
||||||
|
ASGNTTVKATFGKDSSVVKWVVLAEVLVGAVMYMMTKNVKFLAGFAIISVFIAVGMAVVG
|
||||||
|
L
|
||||||
9
test/test_data/fastas/P15069.fasta
Normal file
9
test/test_data/fastas/P15069.fasta
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
>sp|P15069|TRAH1_ECOLI Protein TraH OS=Escherichia coli (strain K12) OX=83333 GN=traH PE=3 SV=2
|
||||||
|
MMPRIKPLLVLCAALLTVTPAASADVNSDMNQFFNKLGFASNTTQPGVWQGQAAGYAYGG
|
||||||
|
SLYARTQVKNVQLISMTLPDINAGCGGIDAYLGSFSFINGEQLQRFVKQIMSNAAGYFFD
|
||||||
|
LALQTTVPEIKTAKDFLQKMASDINSMNLSSCQAAQGIIGGLFPRTQVSQQKVCQDIAGE
|
||||||
|
SNIFADWAASRQGCTVGGKSDSVRDKASDKDKERVTKNINIMWNALSKNRMFDGNKELKE
|
||||||
|
FVMTLTGSLVFGPNGEITPLSARTTDRSIIRAMMEGGTAKISHCNDSDKCLKVVADTPVT
|
||||||
|
ISRDNALKSQITKLLASIQNKAVSDTPLDDKEKGFISSTTIPVFKYLVDPQMLGVSNSMI
|
||||||
|
YQLTDYIGYDILLQYIQELIQQARAMVATGNYDEAVIGHINDNMNDATRQIAAFQSQVQV
|
||||||
|
QQDALLVVDRQMSYMRQQLSARMLSRYQNNYHFGGSTL
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
from urllib import error
|
from urllib import error
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@@ -11,6 +12,32 @@ from alphapulldown.objects import MonomericObject
|
|||||||
from alphapulldown.utils import mmseqs_species_identifiers
|
from alphapulldown.utils import mmseqs_species_identifiers
|
||||||
|
|
||||||
|
|
||||||
|
_FASTAS_DIR = os.path.join(
|
||||||
|
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||||
|
'test_data',
|
||||||
|
'fastas',
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_uniprot_fasta(uniprot_id: str) -> str:
|
||||||
|
path = os.path.join(_FASTAS_DIR, f'{uniprot_id}.fasta')
|
||||||
|
with open(path, encoding='utf-8') as handle:
|
||||||
|
lines = handle.read().splitlines()
|
||||||
|
return ''.join(line for line in lines[1:] if line and not line.startswith('>'))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_colabfold_server_a3m(
|
||||||
|
query_sequence: str, hits: list[tuple[str, str]]
|
||||||
|
) -> str:
|
||||||
|
"""Assemble a ColabFold-server-style A3M ('#<len>\\t1' header, one-line rows)."""
|
||||||
|
lines = [f'#{len(query_sequence)}\t1', '>101', query_sequence]
|
||||||
|
for header, aligned in hits:
|
||||||
|
lines.append(f'>{header}')
|
||||||
|
lines.append(aligned)
|
||||||
|
lines.append('')
|
||||||
|
return '\n'.join(lines)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def clear_species_id_cache():
|
def clear_species_id_cache():
|
||||||
mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
|
mmseqs_species_identifiers._SPECIES_ID_CACHE.clear()
|
||||||
@@ -419,7 +446,7 @@ def test_make_mmseq_features_researches_templates_for_precomputed_msa(
|
|||||||
'template_feature': 'TEMPLATE_FROM_RESEARCH',
|
'template_feature': 'TEMPLATE_FROM_RESEARCH',
|
||||||
}
|
}
|
||||||
assert calls['enrich_mmseq_feature_dict_with_identifiers'] == {
|
assert calls['enrich_mmseq_feature_dict_with_identifiers'] == {
|
||||||
'a3m': '>101\nACDE',
|
'a3m': 'PRECOMPUTED_UNPAIRED',
|
||||||
'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')},
|
'kwargs': {'cache_path': str(tmp_path / 'dummy.mmseq_ids.json')},
|
||||||
}
|
}
|
||||||
assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray)
|
assert isinstance(monomer.feature_dict['template_confidence_scores'], np.ndarray)
|
||||||
@@ -579,3 +606,110 @@ def test_resolve_species_ids_by_accession_skips_unsupported_accessions(
|
|||||||
}
|
}
|
||||||
assert uniprot_calls == [('A0A636IKY3',)]
|
assert uniprot_calls == [('A0A636IKY3',)]
|
||||||
assert uniparc_calls == [('UPI001118B830',)]
|
assert uniparc_calls == [('UPI001118B830',)]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'uniprot_id,accession_species',
|
||||||
|
[
|
||||||
|
(
|
||||||
|
'P04737',
|
||||||
|
{
|
||||||
|
'A0A636IKY3': '562',
|
||||||
|
'A0A743YDY2': '573',
|
||||||
|
'UPI001118B830': '562',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
(
|
||||||
|
'P15069',
|
||||||
|
{
|
||||||
|
'A0A636IKY3': '562',
|
||||||
|
'A0A743YDY2': '573',
|
||||||
|
'UPI001118B830': '562',
|
||||||
|
},
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_make_mmseq_features_precomputed_colabfold_a3m_enriches_identifiers(
|
||||||
|
monkeypatch, tmp_path, uniprot_id, accession_species
|
||||||
|
):
|
||||||
|
"""Regression for issue #613: precomputed ColabFold-server A3Ms must enrich.
|
||||||
|
|
||||||
|
Before the fix, make_mmseq_features parsed the raw A3M for identifiers but
|
||||||
|
fed a different processed string to build_monomer_feature, so dedup rules
|
||||||
|
disagreed and identifier rows did not match MSA rows. This drove
|
||||||
|
enrich_mmseq_feature_dict_with_identifiers to log a warning and skip
|
||||||
|
enrichment, leaving species pairing unusable.
|
||||||
|
"""
|
||||||
|
|
||||||
|
query = _read_uniprot_fasta(uniprot_id)
|
||||||
|
assert len(query) > 0
|
||||||
|
|
||||||
|
# Realistic ColabFold-server-style A3M: '#<len>\t1' header, one header/seq per
|
||||||
|
# line pair, a mix of exact-duplicate hits, insertion-variant hits (lowercase
|
||||||
|
# letters) that strip to the query, an all-gap row, and point-mutation hits
|
||||||
|
# with real-format UniProt accessions so enrichment has something to resolve.
|
||||||
|
hits = [
|
||||||
|
(f'sp|{uniprot_id}|QUERY_DUP', query),
|
||||||
|
('UniRef100_A0A636IKY3', query[:10] + 'a' + query[10:]),
|
||||||
|
('UniRef100_A0A743YDY2', query[:15] + 'bc' + query[15:]),
|
||||||
|
('UniRef100_UPI001118B830', query[:20] + 'def' + query[20:]),
|
||||||
|
('UniRef100_A0A100XYZ0', query[:5] + 'X' + query[6:]),
|
||||||
|
('UniRef100_A0A200ABC5', query[:8] + 'Y' + query[9:]),
|
||||||
|
('UniRef100_GAP_ROW', '-' * len(query)),
|
||||||
|
('UniRef100_ALL_LOWER', 'a' * len(query)),
|
||||||
|
]
|
||||||
|
accession_species = {
|
||||||
|
**accession_species,
|
||||||
|
'A0A100XYZ0': '9606',
|
||||||
|
'A0A200ABC5': '10090',
|
||||||
|
}
|
||||||
|
a3m = _build_colabfold_server_a3m(query, hits)
|
||||||
|
|
||||||
|
precomputed_a3m = tmp_path / f'{uniprot_id}.a3m'
|
||||||
|
precomputed_a3m.write_text(a3m, encoding='utf-8')
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
MonomericObject, 'unzip_msa_files', staticmethod(lambda _path: False)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
mmseqs_species_identifiers,
|
||||||
|
'resolve_species_ids_by_accession',
|
||||||
|
lambda accessions, **_: {
|
||||||
|
accession: accession_species.get(accession, '')
|
||||||
|
for accession in accessions
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
monomer = MonomericObject(uniprot_id, query)
|
||||||
|
monomer.make_mmseq_features(
|
||||||
|
DEFAULT_API_SERVER='https://unused.example',
|
||||||
|
output_dir=str(tmp_path),
|
||||||
|
use_precomputed_msa=True,
|
||||||
|
use_templates=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
msa = monomer.feature_dict['msa']
|
||||||
|
species = monomer.feature_dict['msa_species_identifiers']
|
||||||
|
accessions = monomer.feature_dict['msa_uniprot_accession_identifiers']
|
||||||
|
|
||||||
|
assert species.shape[0] == msa.shape[0], (
|
||||||
|
f'enrichment row count {species.shape[0]} != msa rows {msa.shape[0]}'
|
||||||
|
)
|
||||||
|
assert accessions.shape[0] == msa.shape[0]
|
||||||
|
# '_all_seq' mirrors the enriched rows, used for pairing downstream.
|
||||||
|
assert (
|
||||||
|
monomer.feature_dict['msa_species_identifiers_all_seq'].shape[0]
|
||||||
|
== msa.shape[0]
|
||||||
|
)
|
||||||
|
# The resolver is called with the real UniProt-format accessions only —
|
||||||
|
# insertion-variant hits collapse onto the query row but the point-mutation
|
||||||
|
# hit survives, so at least one of the resolvable accessions lands in the
|
||||||
|
# deduped identifier rows.
|
||||||
|
resolvable = {
|
||||||
|
a.decode('utf-8')
|
||||||
|
for a in accessions.tolist()
|
||||||
|
if a.decode('utf-8') in accession_species
|
||||||
|
}
|
||||||
|
assert resolvable, (
|
||||||
|
f'expected at least one resolvable accession in {accessions.tolist()}'
|
||||||
|
)
|
||||||
|
|||||||
@@ -397,7 +397,7 @@ def test_make_mmseq_features_builds_all_seq_features_and_writes_a3m(
|
|||||||
|
|
||||||
assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE")
|
assert calls["build_monomer_feature"] == ("ACDE", "UNPAIRED", "TEMPLATE")
|
||||||
assert calls["enrich"] == {
|
assert calls["enrich"] == {
|
||||||
"a3m": ">101\nACDE\n>hit\nAC-E",
|
"a3m": "UNPAIRED",
|
||||||
"kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")},
|
"kwargs": {"cache_path": str(tmp_path / "proteinA.mmseq_ids.json")},
|
||||||
}
|
}
|
||||||
assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101")
|
assert (tmp_path / "proteinA.a3m").read_text(encoding="utf-8").startswith(">101")
|
||||||
@@ -421,7 +421,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
|
|||||||
|
|
||||||
def fake_get_msa_and_templates(**kwargs):
|
def fake_get_msa_and_templates(**kwargs):
|
||||||
calls["get_msa_and_templates"] = kwargs
|
calls["get_msa_and_templates"] = kwargs
|
||||||
return (["UNPAIRED"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
|
return ([">101\nACDE\n"], [""], ["UNIQUE"], ["CARD"], ["TEMPLATE"])
|
||||||
|
|
||||||
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
|
monkeypatch.setattr(objects_mod, "get_msa_and_templates", fake_get_msa_and_templates)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
@@ -462,7 +462,7 @@ def test_make_mmseq_features_skip_msa_uses_single_sequence_mode(
|
|||||||
assert calls["get_msa_and_templates"]["pair_mode"] == "none"
|
assert calls["get_msa_and_templates"]["pair_mode"] == "none"
|
||||||
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
|
assert calls["get_msa_and_templates"]["a3m_lines"] == [">101\nACDE"]
|
||||||
assert calls["get_msa_and_templates"]["use_templates"] is True
|
assert calls["get_msa_and_templates"]["use_templates"] is True
|
||||||
assert calls["enrich"]["a3m"] == ">101\nACDE"
|
assert calls["enrich"]["a3m"] == ">101\nACDE\n"
|
||||||
assert monomer.skip_msa is True
|
assert monomer.skip_msa is True
|
||||||
assert monomer.feature_dict["msa"].shape == (1, 4)
|
assert monomer.feature_dict["msa"].shape == (1, 4)
|
||||||
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
|
assert monomer.feature_dict["msa_all_seq"].shape == (1, 4)
|
||||||
@@ -757,7 +757,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
|
|||||||
objects_mod,
|
objects_mod,
|
||||||
"get_msa_and_templates",
|
"get_msa_and_templates",
|
||||||
lambda **_kwargs: (
|
lambda **_kwargs: (
|
||||||
["UNPAIRED"],
|
[a3m_text],
|
||||||
["PAIRED"],
|
["PAIRED"],
|
||||||
["UNIQUE"],
|
["UNIQUE"],
|
||||||
["CARD"],
|
["CARD"],
|
||||||
@@ -774,7 +774,7 @@ def test_make_mmseq_features_reuses_identifier_sidecar_on_precomputed_run(
|
|||||||
objects_mod,
|
objects_mod,
|
||||||
"unserialize_msa",
|
"unserialize_msa",
|
||||||
lambda a3m_lines, sequence: (
|
lambda a3m_lines, sequence: (
|
||||||
["PRECOMP_MSA"],
|
[a3m_text],
|
||||||
["PRECOMP_PAIRED"],
|
["PRECOMP_PAIRED"],
|
||||||
["UNIQUE"],
|
["UNIQUE"],
|
||||||
["CARD"],
|
["CARD"],
|
||||||
@@ -1417,6 +1417,81 @@ def test_pair_and_merge_pairs_and_deduplicates_for_heteromer(monkeypatch):
|
|||||||
assert output == {"processed": {"merged": True}}
|
assert output == {"processed": {"merged": True}}
|
||||||
|
|
||||||
|
|
||||||
|
def test_pair_and_merge_backfills_missing_all_seq_accession_ids_before_pairing(
|
||||||
|
monkeypatch,
|
||||||
|
):
|
||||||
|
multimer = MultimericObject.__new__(MultimericObject)
|
||||||
|
multimer.pair_msa = True
|
||||||
|
calls = {}
|
||||||
|
|
||||||
|
chain_a = _feature_dict(sequence="ACDE", msa_rows=1, all_seq_rows=2, template_count=0)
|
||||||
|
chain_b = _feature_dict(sequence="FGHI", msa_rows=1, all_seq_rows=2, template_count=0)
|
||||||
|
chain_a["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
|
||||||
|
chain_b["msa_species_identifiers_all_seq"] = np.asarray([b"", b"9606"], dtype=object)
|
||||||
|
chain_a["msa_uniprot_accession_identifiers"] = np.asarray([b"P12345"], dtype=object)
|
||||||
|
chain_a["msa_uniprot_accession_identifiers_all_seq"] = np.asarray(
|
||||||
|
[b"", b"P12345"],
|
||||||
|
dtype=object,
|
||||||
|
)
|
||||||
|
|
||||||
|
real_create_paired_features = objects_mod.msa_pairing.create_paired_features
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.feature_processing,
|
||||||
|
"process_unmerged_features",
|
||||||
|
lambda _features: None,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.feature_processing,
|
||||||
|
"_is_homomer_or_monomer",
|
||||||
|
lambda _chains: False,
|
||||||
|
)
|
||||||
|
|
||||||
|
def wrapped_create_paired_features(*, chains):
|
||||||
|
calls["create_paired_features"] = chains
|
||||||
|
return real_create_paired_features(chains)
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.msa_pairing,
|
||||||
|
"create_paired_features",
|
||||||
|
wrapped_create_paired_features,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.msa_pairing,
|
||||||
|
"deduplicate_unpaired_sequences",
|
||||||
|
lambda chains: chains,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.feature_processing,
|
||||||
|
"crop_chains",
|
||||||
|
lambda chains, **kwargs: chains,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.msa_pairing,
|
||||||
|
"merge_chain_features",
|
||||||
|
lambda **kwargs: {"chains": kwargs["np_chains_list"]},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
objects_mod.feature_processing,
|
||||||
|
"process_final",
|
||||||
|
lambda example: example,
|
||||||
|
)
|
||||||
|
|
||||||
|
output = multimer.pair_and_merge({"A": chain_a, "B": chain_b})
|
||||||
|
|
||||||
|
assert calls["create_paired_features"][1][
|
||||||
|
"msa_uniprot_accession_identifiers"
|
||||||
|
].tolist() == [b""]
|
||||||
|
assert calls["create_paired_features"][1][
|
||||||
|
"msa_uniprot_accession_identifiers_all_seq"
|
||||||
|
].tolist() == [b"", b""]
|
||||||
|
assert output["chains"][1]["msa_uniprot_accession_identifiers"].tolist() == [b""]
|
||||||
|
assert output["chains"][1]["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
|
||||||
|
b"",
|
||||||
|
b"",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch):
|
def test_pair_and_merge_removes_all_seq_features_when_pairing_disabled(monkeypatch):
|
||||||
multimer = MultimericObject.__new__(MultimericObject)
|
multimer = MultimericObject.__new__(MultimericObject)
|
||||||
multimer.pair_msa = False
|
multimer.pair_msa = False
|
||||||
|
|||||||
@@ -135,3 +135,46 @@ def test_all_seq_msa_features_keeps_only_pairing_related_keys(monkeypatch, tmp_p
|
|||||||
True,
|
True,
|
||||||
)
|
)
|
||||||
assert run_kwargs == {}
|
assert run_kwargs == {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_seq_msa_features_backfills_missing_uniprot_accession_identifiers(
|
||||||
|
monkeypatch, tmp_path
|
||||||
|
):
|
||||||
|
monomer = MonomericObject("desc", "ACDE")
|
||||||
|
input_fasta_path = str(tmp_path / "input.fasta")
|
||||||
|
Path(input_fasta_path).write_text(">x\nACDE\n", encoding="utf-8")
|
||||||
|
|
||||||
|
class FakeMsa:
|
||||||
|
def truncate(self, max_seqs):
|
||||||
|
assert max_seqs == 50000
|
||||||
|
return self
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"alphapulldown.objects.pipeline.run_msa_tool",
|
||||||
|
lambda *args, **kwargs: {"sto": "fake"},
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"alphapulldown.objects.parsers.parse_stockholm",
|
||||||
|
lambda sto: FakeMsa(),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"alphapulldown.objects.pipeline.make_msa_features",
|
||||||
|
lambda _msas: {
|
||||||
|
"msa": np.asarray([[1, 2], [1, 3]], dtype=np.int32),
|
||||||
|
"msa_species_identifiers": np.asarray([b"", b"9606"], dtype=object),
|
||||||
|
"deletion_matrix_int": np.asarray([[0, 0], [0, 0]], dtype=np.int32),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
features = monomer.all_seq_msa_features(
|
||||||
|
input_fasta_path=input_fasta_path,
|
||||||
|
uniprot_msa_runner="runner",
|
||||||
|
output_dir=str(tmp_path),
|
||||||
|
use_precomputed_msa=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert features["msa_species_identifiers_all_seq"].tolist() == [b"", b"9606"]
|
||||||
|
assert features["msa_uniprot_accession_identifiers_all_seq"].tolist() == [
|
||||||
|
b"",
|
||||||
|
b"",
|
||||||
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user