Recover mmseqs species identifiers for AF2 pairing

This commit is contained in:
Dima
2026-03-27 12:13:19 +01:00
parent 269c15b215
commit 53a75e14c8
3 changed files with 402 additions and 0 deletions

View File

@@ -21,6 +21,9 @@ from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_fea
from alphapulldown.utils.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
prepare_multimeric_template_meta_info)
from alphapulldown.utils.file_handling import temp_fasta_file
from alphapulldown.utils.mmseqs_species_identifiers import (
enrich_mmseq_feature_dict_with_identifiers,
)
class MonomericObject:
"""
@@ -253,6 +256,10 @@ class MonomericObject:
# Remove header lines starting with '#' if present.
a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
enrich_mmseq_feature_dict_with_identifiers(
self.feature_dict,
unpaired_msa[0],
)
# Fix: Change tuple to list so that we can concatenate with msa_pairing.MSA_FEATURES.
valid_feats = msa_pairing.MSA_FEATURES + ("msa_species_identifiers", "msa_uniprot_accession_identifiers")

View File

@@ -0,0 +1,278 @@
"""Helpers for recovering species identifiers from mmseqs-derived A3Ms."""
from __future__ import annotations
import json
import re
from typing import Callable, Iterable, Sequence
from urllib import parse
from urllib import request
from absl import logging
from alphafold.data import parsers
import numpy as np
_UNIPROT_HEADER_PATTERN = re.compile(
r"""
^
(?:tr|sp)
\|
(?P<accession>[A-Za-z0-9]{6,10})
(?:_\d)?
\|
(?:[A-Za-z0-9]+)
_
(?P<species>[A-Za-z0-9]{1,5})
(?:_\d+)?
$
""",
re.VERBOSE,
)
_UNIREF_HEADER_PATTERN = re.compile(
r'^UniRef\d+_(?P<accession>[A-Za-z0-9]+)$'
)
_UNIPARC_HEADER_PATTERN = re.compile(r'^(?P<accession>UPI[A-Z0-9]+)$')
_GENERIC_ACCESSION_PATTERN = re.compile(
r'^(?P<accession>[A-Za-z0-9]{6,16})$'
)
_UNIPROT_BATCH_SIZE = 32
_UNIPARC_BATCH_SIZE = 16
_UNIPROT_TIMEOUT_SECONDS = 30
_SPECIES_ID_CACHE: dict[str, str] = {}
def _extract_sequence_identifier(description: str) -> str:
split_description = description.split()
if not split_description:
return ""
return split_description[0].partition('/')[0]
def _extract_accession_and_species(description: str) -> tuple[str, str]:
sequence_identifier = _extract_sequence_identifier(description)
if not sequence_identifier:
return "", ""
matches = _UNIPROT_HEADER_PATTERN.search(sequence_identifier.strip())
if matches:
return matches.group("accession"), matches.group("species")
for pattern in (
_UNIREF_HEADER_PATTERN,
_UNIPARC_HEADER_PATTERN,
_GENERIC_ACCESSION_PATTERN,
):
matches = pattern.search(sequence_identifier.strip())
if matches:
return matches.group("accession"), ""
return "", ""
def _batched(items: Sequence[str], batch_size: int) -> Iterable[Sequence[str]]:
for start in range(0, len(items), batch_size):
yield items[start : start + batch_size]
def _query_uniprot_batch(
accessions: Sequence[str],
*,
urlopen: Callable[..., object],
) -> dict[str, object]:
query = " OR ".join(f"accession:{accession}" for accession in accessions)
url = (
"https://rest.uniprot.org/uniprotkb/search?query="
f"{parse.quote(query)}"
"&fields=accession,organism_id"
"&format=json"
f"&size={len(accessions)}"
)
with urlopen(url, timeout=_UNIPROT_TIMEOUT_SECONDS) as response:
return json.load(response)
def _query_uniparc_batch(
accessions: Sequence[str],
*,
urlopen: Callable[..., object],
) -> dict[str, object]:
query = " OR ".join(f"upi:{accession}" for accession in accessions)
url = (
"https://rest.uniprot.org/uniparc/search?query="
f"{parse.quote(query)}"
"&fields=upi,organism_id"
"&format=json"
f"&size={len(accessions)}"
)
with urlopen(url, timeout=_UNIPROT_TIMEOUT_SECONDS) as response:
return json.load(response)
def _query_uniprot_species_ids(
accessions: Sequence[str],
*,
urlopen: Callable[..., object],
) -> dict[str, str]:
resolved: dict[str, str] = {}
for batch in _batched(sorted(set(accessions)), _UNIPROT_BATCH_SIZE):
try:
payload = _query_uniprot_batch(batch, urlopen=urlopen)
except Exception as exc: # pragma: no cover - best-effort network fallback
logging.warning(
"Unable to resolve UniProtKB taxonomy for %d accessions: %s",
len(batch),
exc,
)
payload = {"results": []}
for accession in batch:
try:
payload["results"].extend(
_query_uniprot_batch([accession], urlopen=urlopen)["results"]
)
except Exception:
continue
for result in payload.get("results", []):
accession = result.get("primaryAccession")
taxon_id = result.get("organism", {}).get("taxonId")
if accession and taxon_id is not None:
resolved[accession] = str(taxon_id)
return resolved
def _query_uniparc_species_ids(
accessions: Sequence[str],
*,
urlopen: Callable[..., object],
) -> dict[str, str]:
resolved: dict[str, str] = {}
for batch in _batched(sorted(set(accessions)), _UNIPARC_BATCH_SIZE):
try:
payload = _query_uniparc_batch(batch, urlopen=urlopen)
except Exception as exc: # pragma: no cover - best-effort network fallback
logging.warning(
"Unable to resolve UniParc taxonomy for %d accessions: %s",
len(batch),
exc,
)
payload = {"results": []}
for accession in batch:
try:
payload["results"].extend(
_query_uniparc_batch([accession], urlopen=urlopen)["results"]
)
except Exception:
continue
for result in payload.get("results", []):
accession = result.get("uniParcId")
organisms = result.get("organisms", [])
taxon_ids = {
str(organism.get("taxonId"))
for organism in organisms
if organism.get("taxonId") is not None
}
if accession and len(taxon_ids) == 1:
resolved[accession] = next(iter(taxon_ids))
return resolved
def resolve_species_ids_by_accession(
accessions: Sequence[str],
*,
urlopen: Callable[..., object] = request.urlopen,
) -> dict[str, str]:
unresolved = [
accession
for accession in sorted(set(accessions))
if accession and accession not in _SPECIES_ID_CACHE
]
if unresolved:
uniprot_accessions = [
accession for accession in unresolved if not accession.startswith("UPI")
]
uniparc_accessions = [
accession for accession in unresolved if accession.startswith("UPI")
]
resolved = _query_uniprot_species_ids(
uniprot_accessions, urlopen=urlopen
)
resolved.update(
_query_uniparc_species_ids(uniparc_accessions, urlopen=urlopen)
)
for accession in unresolved:
_SPECIES_ID_CACHE[accession] = resolved.get(accession, "")
return {
accession: _SPECIES_ID_CACHE.get(accession, "")
for accession in accessions
if accession
}
def build_mmseq_identifier_features(
a3m_string: str,
*,
species_resolver: Callable[[Sequence[str]], dict[str, str]] = (
resolve_species_ids_by_accession
),
) -> dict[str, np.ndarray]:
msa = parsers.parse_a3m(a3m_string)
seen_sequences: set[str] = set()
accessions: list[str] = []
species_ids: list[str] = []
for sequence, description in zip(
msa.sequences, msa.descriptions, strict=True
):
if sequence in seen_sequences:
continue
seen_sequences.add(sequence)
accession_id, species_id = _extract_accession_and_species(description)
accessions.append(accession_id)
species_ids.append(species_id)
resolved_species_ids = species_resolver(
[accession for accession, species_id in zip(accessions, species_ids, strict=True)
if accession and not species_id]
)
species_ids = [
species_id or resolved_species_ids.get(accession_id, "")
for accession_id, species_id in zip(accessions, species_ids, strict=True)
]
return {
"msa_species_identifiers": np.array(
[species_id.encode("utf-8") for species_id in species_ids],
dtype=np.object_,
),
"msa_uniprot_accession_identifiers": np.array(
[accession_id.encode("utf-8") for accession_id in accessions],
dtype=np.object_,
),
}
def enrich_mmseq_feature_dict_with_identifiers(
feature_dict: dict[str, np.ndarray],
a3m_string: str,
*,
species_resolver: Callable[[Sequence[str]], dict[str, str]] = (
resolve_species_ids_by_accession
),
) -> None:
identifier_features = build_mmseq_identifier_features(
a3m_string,
species_resolver=species_resolver,
)
msa = feature_dict.get("msa")
if msa is None:
return
if len(identifier_features["msa_species_identifiers"]) != msa.shape[0]:
logging.warning(
"Skipping mmseqs species enrichment because identifier rows do not "
"match MSA rows: %d != %d",
len(identifier_features["msa_species_identifiers"]),
msa.shape[0],
)
return
feature_dict.update(identifier_features)

View File

@@ -0,0 +1,117 @@
import numpy as np
from alphafold.data import msa_pairing
from alphafold.data import parsers
from alphafold.data import pipeline
from alphapulldown.utils import mmseqs_species_identifiers
def _feature_dict_from_a3m(
sequence: str,
a3m: str,
*,
species_resolver,
) -> dict[str, np.ndarray]:
feature_dict = {
**pipeline.make_sequence_features(sequence, 'none', len(sequence)),
**pipeline.make_msa_features([parsers.parse_a3m(a3m)]),
}
mmseqs_species_identifiers.enrich_mmseq_feature_dict_with_identifiers(
feature_dict,
a3m,
species_resolver=species_resolver,
)
valid_feats = msa_pairing.MSA_FEATURES + (
'msa_species_identifiers',
'msa_uniprot_accession_identifiers',
)
feature_dict.update(
{
f'{key}_all_seq': value
for key, value in feature_dict.items()
if key in valid_feats
}
)
return feature_dict
def test_make_msa_features_resolves_mmseqs_species_identifiers(monkeypatch):
monkeypatch.setattr(
mmseqs_species_identifiers,
'resolve_species_ids_by_accession',
lambda accessions, **_: {
'A0A636IKY3': '108619',
'UPI001118B830': '562',
},
)
a3m = '\n'.join([
'>101',
'ACDE',
'>UniRef100_A0A636IKY3\t136\t0.883',
'ACDF',
'>UniRef100_UPI001118B830\t855\t0.990',
'AC-E',
'',
])
features = mmseqs_species_identifiers.build_mmseq_identifier_features(a3m)
assert features['msa_species_identifiers'].tolist() == [
b'',
b'108619',
b'562',
]
assert features['msa_uniprot_accession_identifiers'].tolist() == [
b'',
b'A0A636IKY3',
b'UPI001118B830',
]
def test_pair_sequences_works_with_mmseqs_accession_species_resolution(
monkeypatch,
):
monkeypatch.setattr(
mmseqs_species_identifiers,
'resolve_species_ids_by_accession',
lambda accessions, **_: {
'A0A636IKY3': '562',
'A0A743YDY2': '573',
'UPI001118B830': '562',
'UPI00101273C6': '573',
},
)
chain_a = _feature_dict_from_a3m(
'ACDE',
'\n'.join([
'>101',
'ACDE',
'>UniRef100_A0A636IKY3\t136\t0.883',
'ACDF',
'>UniRef100_A0A743YDY2\t134\t0.932',
'AC-E',
'',
]),
species_resolver=mmseqs_species_identifiers.resolve_species_ids_by_accession,
)
chain_b = _feature_dict_from_a3m(
'WXYZ',
'\n'.join([
'>101',
'WXYZ',
'>UniRef100_UPI001118B830\t855\t0.990',
'WXYW',
'>UniRef100_UPI00101273C6\t833\t0.919',
'WX-Z',
'',
]),
species_resolver=mmseqs_species_identifiers.resolve_species_ids_by_accession,
)
paired_rows = msa_pairing.pair_sequences([chain_a, chain_b])[2]
assert paired_rows.shape == (3, 2)
assert tuple(paired_rows[0]) == (0, 0)
assert {tuple(row) for row in paired_rows[1:]} == {(1, 1), (2, 2)}