mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Recover mmseqs species identifiers for AF2 pairing
This commit is contained in:
@@ -21,6 +21,9 @@ from colabfold.batch import get_msa_and_templates, msa_to_str, build_monomer_fea
|
||||
from alphapulldown.utils.multimeric_template_utils import (extract_multimeric_template_features_for_single_chain,
|
||||
prepare_multimeric_template_meta_info)
|
||||
from alphapulldown.utils.file_handling import temp_fasta_file
|
||||
from alphapulldown.utils.mmseqs_species_identifiers import (
|
||||
enrich_mmseq_feature_dict_with_identifiers,
|
||||
)
|
||||
|
||||
class MonomericObject:
|
||||
"""
|
||||
@@ -253,6 +256,10 @@ class MonomericObject:
|
||||
# Remove header lines starting with '#' if present.
|
||||
a3m_lines[0] = "\n".join([line for line in a3m_lines[0].splitlines() if not line.startswith("#")])
|
||||
self.feature_dict = build_monomer_feature(self.sequence, unpaired_msa[0], template_features[0])
|
||||
enrich_mmseq_feature_dict_with_identifiers(
|
||||
self.feature_dict,
|
||||
unpaired_msa[0],
|
||||
)
|
||||
|
||||
# Fix: Change tuple to list so that we can concatenate with msa_pairing.MSA_FEATURES.
|
||||
valid_feats = msa_pairing.MSA_FEATURES + ("msa_species_identifiers", "msa_uniprot_accession_identifiers")
|
||||
|
||||
278
alphapulldown/utils/mmseqs_species_identifiers.py
Normal file
278
alphapulldown/utils/mmseqs_species_identifiers.py
Normal file
@@ -0,0 +1,278 @@
|
||||
"""Helpers for recovering species identifiers from mmseqs-derived A3Ms."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Callable, Iterable, Sequence
|
||||
from urllib import parse
|
||||
from urllib import request
|
||||
|
||||
from absl import logging
|
||||
from alphafold.data import parsers
|
||||
import numpy as np
|
||||
|
||||
_UNIPROT_HEADER_PATTERN = re.compile(
|
||||
r"""
|
||||
^
|
||||
(?:tr|sp)
|
||||
\|
|
||||
(?P<accession>[A-Za-z0-9]{6,10})
|
||||
(?:_\d)?
|
||||
\|
|
||||
(?:[A-Za-z0-9]+)
|
||||
_
|
||||
(?P<species>[A-Za-z0-9]{1,5})
|
||||
(?:_\d+)?
|
||||
$
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
_UNIREF_HEADER_PATTERN = re.compile(
|
||||
r'^UniRef\d+_(?P<accession>[A-Za-z0-9]+)$'
|
||||
)
|
||||
_UNIPARC_HEADER_PATTERN = re.compile(r'^(?P<accession>UPI[A-Z0-9]+)$')
|
||||
_GENERIC_ACCESSION_PATTERN = re.compile(
|
||||
r'^(?P<accession>[A-Za-z0-9]{6,16})$'
|
||||
)
|
||||
|
||||
_UNIPROT_BATCH_SIZE = 32
|
||||
_UNIPARC_BATCH_SIZE = 16
|
||||
_UNIPROT_TIMEOUT_SECONDS = 30
|
||||
_SPECIES_ID_CACHE: dict[str, str] = {}
|
||||
|
||||
|
||||
def _extract_sequence_identifier(description: str) -> str:
|
||||
split_description = description.split()
|
||||
if not split_description:
|
||||
return ""
|
||||
return split_description[0].partition('/')[0]
|
||||
|
||||
|
||||
def _extract_accession_and_species(description: str) -> tuple[str, str]:
|
||||
sequence_identifier = _extract_sequence_identifier(description)
|
||||
if not sequence_identifier:
|
||||
return "", ""
|
||||
|
||||
matches = _UNIPROT_HEADER_PATTERN.search(sequence_identifier.strip())
|
||||
if matches:
|
||||
return matches.group("accession"), matches.group("species")
|
||||
|
||||
for pattern in (
|
||||
_UNIREF_HEADER_PATTERN,
|
||||
_UNIPARC_HEADER_PATTERN,
|
||||
_GENERIC_ACCESSION_PATTERN,
|
||||
):
|
||||
matches = pattern.search(sequence_identifier.strip())
|
||||
if matches:
|
||||
return matches.group("accession"), ""
|
||||
|
||||
return "", ""
|
||||
|
||||
|
||||
def _batched(items: Sequence[str], batch_size: int) -> Iterable[Sequence[str]]:
|
||||
for start in range(0, len(items), batch_size):
|
||||
yield items[start : start + batch_size]
|
||||
|
||||
|
||||
def _query_uniprot_batch(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
urlopen: Callable[..., object],
|
||||
) -> dict[str, object]:
|
||||
query = " OR ".join(f"accession:{accession}" for accession in accessions)
|
||||
url = (
|
||||
"https://rest.uniprot.org/uniprotkb/search?query="
|
||||
f"{parse.quote(query)}"
|
||||
"&fields=accession,organism_id"
|
||||
"&format=json"
|
||||
f"&size={len(accessions)}"
|
||||
)
|
||||
with urlopen(url, timeout=_UNIPROT_TIMEOUT_SECONDS) as response:
|
||||
return json.load(response)
|
||||
|
||||
|
||||
def _query_uniparc_batch(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
urlopen: Callable[..., object],
|
||||
) -> dict[str, object]:
|
||||
query = " OR ".join(f"upi:{accession}" for accession in accessions)
|
||||
url = (
|
||||
"https://rest.uniprot.org/uniparc/search?query="
|
||||
f"{parse.quote(query)}"
|
||||
"&fields=upi,organism_id"
|
||||
"&format=json"
|
||||
f"&size={len(accessions)}"
|
||||
)
|
||||
with urlopen(url, timeout=_UNIPROT_TIMEOUT_SECONDS) as response:
|
||||
return json.load(response)
|
||||
|
||||
|
||||
def _query_uniprot_species_ids(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
urlopen: Callable[..., object],
|
||||
) -> dict[str, str]:
|
||||
resolved: dict[str, str] = {}
|
||||
for batch in _batched(sorted(set(accessions)), _UNIPROT_BATCH_SIZE):
|
||||
try:
|
||||
payload = _query_uniprot_batch(batch, urlopen=urlopen)
|
||||
except Exception as exc: # pragma: no cover - best-effort network fallback
|
||||
logging.warning(
|
||||
"Unable to resolve UniProtKB taxonomy for %d accessions: %s",
|
||||
len(batch),
|
||||
exc,
|
||||
)
|
||||
payload = {"results": []}
|
||||
for accession in batch:
|
||||
try:
|
||||
payload["results"].extend(
|
||||
_query_uniprot_batch([accession], urlopen=urlopen)["results"]
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
for result in payload.get("results", []):
|
||||
accession = result.get("primaryAccession")
|
||||
taxon_id = result.get("organism", {}).get("taxonId")
|
||||
if accession and taxon_id is not None:
|
||||
resolved[accession] = str(taxon_id)
|
||||
return resolved
|
||||
|
||||
|
||||
def _query_uniparc_species_ids(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
urlopen: Callable[..., object],
|
||||
) -> dict[str, str]:
|
||||
resolved: dict[str, str] = {}
|
||||
for batch in _batched(sorted(set(accessions)), _UNIPARC_BATCH_SIZE):
|
||||
try:
|
||||
payload = _query_uniparc_batch(batch, urlopen=urlopen)
|
||||
except Exception as exc: # pragma: no cover - best-effort network fallback
|
||||
logging.warning(
|
||||
"Unable to resolve UniParc taxonomy for %d accessions: %s",
|
||||
len(batch),
|
||||
exc,
|
||||
)
|
||||
payload = {"results": []}
|
||||
for accession in batch:
|
||||
try:
|
||||
payload["results"].extend(
|
||||
_query_uniparc_batch([accession], urlopen=urlopen)["results"]
|
||||
)
|
||||
except Exception:
|
||||
continue
|
||||
for result in payload.get("results", []):
|
||||
accession = result.get("uniParcId")
|
||||
organisms = result.get("organisms", [])
|
||||
taxon_ids = {
|
||||
str(organism.get("taxonId"))
|
||||
for organism in organisms
|
||||
if organism.get("taxonId") is not None
|
||||
}
|
||||
if accession and len(taxon_ids) == 1:
|
||||
resolved[accession] = next(iter(taxon_ids))
|
||||
return resolved
|
||||
|
||||
|
||||
def resolve_species_ids_by_accession(
|
||||
accessions: Sequence[str],
|
||||
*,
|
||||
urlopen: Callable[..., object] = request.urlopen,
|
||||
) -> dict[str, str]:
|
||||
unresolved = [
|
||||
accession
|
||||
for accession in sorted(set(accessions))
|
||||
if accession and accession not in _SPECIES_ID_CACHE
|
||||
]
|
||||
if unresolved:
|
||||
uniprot_accessions = [
|
||||
accession for accession in unresolved if not accession.startswith("UPI")
|
||||
]
|
||||
uniparc_accessions = [
|
||||
accession for accession in unresolved if accession.startswith("UPI")
|
||||
]
|
||||
resolved = _query_uniprot_species_ids(
|
||||
uniprot_accessions, urlopen=urlopen
|
||||
)
|
||||
resolved.update(
|
||||
_query_uniparc_species_ids(uniparc_accessions, urlopen=urlopen)
|
||||
)
|
||||
for accession in unresolved:
|
||||
_SPECIES_ID_CACHE[accession] = resolved.get(accession, "")
|
||||
|
||||
return {
|
||||
accession: _SPECIES_ID_CACHE.get(accession, "")
|
||||
for accession in accessions
|
||||
if accession
|
||||
}
|
||||
|
||||
|
||||
def build_mmseq_identifier_features(
|
||||
a3m_string: str,
|
||||
*,
|
||||
species_resolver: Callable[[Sequence[str]], dict[str, str]] = (
|
||||
resolve_species_ids_by_accession
|
||||
),
|
||||
) -> dict[str, np.ndarray]:
|
||||
msa = parsers.parse_a3m(a3m_string)
|
||||
seen_sequences: set[str] = set()
|
||||
accessions: list[str] = []
|
||||
species_ids: list[str] = []
|
||||
|
||||
for sequence, description in zip(
|
||||
msa.sequences, msa.descriptions, strict=True
|
||||
):
|
||||
if sequence in seen_sequences:
|
||||
continue
|
||||
seen_sequences.add(sequence)
|
||||
accession_id, species_id = _extract_accession_and_species(description)
|
||||
accessions.append(accession_id)
|
||||
species_ids.append(species_id)
|
||||
|
||||
resolved_species_ids = species_resolver(
|
||||
[accession for accession, species_id in zip(accessions, species_ids, strict=True)
|
||||
if accession and not species_id]
|
||||
)
|
||||
species_ids = [
|
||||
species_id or resolved_species_ids.get(accession_id, "")
|
||||
for accession_id, species_id in zip(accessions, species_ids, strict=True)
|
||||
]
|
||||
|
||||
return {
|
||||
"msa_species_identifiers": np.array(
|
||||
[species_id.encode("utf-8") for species_id in species_ids],
|
||||
dtype=np.object_,
|
||||
),
|
||||
"msa_uniprot_accession_identifiers": np.array(
|
||||
[accession_id.encode("utf-8") for accession_id in accessions],
|
||||
dtype=np.object_,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def enrich_mmseq_feature_dict_with_identifiers(
|
||||
feature_dict: dict[str, np.ndarray],
|
||||
a3m_string: str,
|
||||
*,
|
||||
species_resolver: Callable[[Sequence[str]], dict[str, str]] = (
|
||||
resolve_species_ids_by_accession
|
||||
),
|
||||
) -> None:
|
||||
identifier_features = build_mmseq_identifier_features(
|
||||
a3m_string,
|
||||
species_resolver=species_resolver,
|
||||
)
|
||||
msa = feature_dict.get("msa")
|
||||
if msa is None:
|
||||
return
|
||||
if len(identifier_features["msa_species_identifiers"]) != msa.shape[0]:
|
||||
logging.warning(
|
||||
"Skipping mmseqs species enrichment because identifier rows do not "
|
||||
"match MSA rows: %d != %d",
|
||||
len(identifier_features["msa_species_identifiers"]),
|
||||
msa.shape[0],
|
||||
)
|
||||
return
|
||||
feature_dict.update(identifier_features)
|
||||
117
test/test_mmseqs_species_identifiers.py
Normal file
117
test/test_mmseqs_species_identifiers.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import numpy as np
|
||||
|
||||
from alphafold.data import msa_pairing
|
||||
from alphafold.data import parsers
|
||||
from alphafold.data import pipeline
|
||||
from alphapulldown.utils import mmseqs_species_identifiers
|
||||
|
||||
|
||||
def _feature_dict_from_a3m(
|
||||
sequence: str,
|
||||
a3m: str,
|
||||
*,
|
||||
species_resolver,
|
||||
) -> dict[str, np.ndarray]:
|
||||
feature_dict = {
|
||||
**pipeline.make_sequence_features(sequence, 'none', len(sequence)),
|
||||
**pipeline.make_msa_features([parsers.parse_a3m(a3m)]),
|
||||
}
|
||||
mmseqs_species_identifiers.enrich_mmseq_feature_dict_with_identifiers(
|
||||
feature_dict,
|
||||
a3m,
|
||||
species_resolver=species_resolver,
|
||||
)
|
||||
valid_feats = msa_pairing.MSA_FEATURES + (
|
||||
'msa_species_identifiers',
|
||||
'msa_uniprot_accession_identifiers',
|
||||
)
|
||||
feature_dict.update(
|
||||
{
|
||||
f'{key}_all_seq': value
|
||||
for key, value in feature_dict.items()
|
||||
if key in valid_feats
|
||||
}
|
||||
)
|
||||
return feature_dict
|
||||
|
||||
|
||||
def test_make_msa_features_resolves_mmseqs_species_identifiers(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
mmseqs_species_identifiers,
|
||||
'resolve_species_ids_by_accession',
|
||||
lambda accessions, **_: {
|
||||
'A0A636IKY3': '108619',
|
||||
'UPI001118B830': '562',
|
||||
},
|
||||
)
|
||||
|
||||
a3m = '\n'.join([
|
||||
'>101',
|
||||
'ACDE',
|
||||
'>UniRef100_A0A636IKY3\t136\t0.883',
|
||||
'ACDF',
|
||||
'>UniRef100_UPI001118B830\t855\t0.990',
|
||||
'AC-E',
|
||||
'',
|
||||
])
|
||||
|
||||
features = mmseqs_species_identifiers.build_mmseq_identifier_features(a3m)
|
||||
|
||||
assert features['msa_species_identifiers'].tolist() == [
|
||||
b'',
|
||||
b'108619',
|
||||
b'562',
|
||||
]
|
||||
assert features['msa_uniprot_accession_identifiers'].tolist() == [
|
||||
b'',
|
||||
b'A0A636IKY3',
|
||||
b'UPI001118B830',
|
||||
]
|
||||
|
||||
|
||||
def test_pair_sequences_works_with_mmseqs_accession_species_resolution(
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setattr(
|
||||
mmseqs_species_identifiers,
|
||||
'resolve_species_ids_by_accession',
|
||||
lambda accessions, **_: {
|
||||
'A0A636IKY3': '562',
|
||||
'A0A743YDY2': '573',
|
||||
'UPI001118B830': '562',
|
||||
'UPI00101273C6': '573',
|
||||
},
|
||||
)
|
||||
|
||||
chain_a = _feature_dict_from_a3m(
|
||||
'ACDE',
|
||||
'\n'.join([
|
||||
'>101',
|
||||
'ACDE',
|
||||
'>UniRef100_A0A636IKY3\t136\t0.883',
|
||||
'ACDF',
|
||||
'>UniRef100_A0A743YDY2\t134\t0.932',
|
||||
'AC-E',
|
||||
'',
|
||||
]),
|
||||
species_resolver=mmseqs_species_identifiers.resolve_species_ids_by_accession,
|
||||
)
|
||||
chain_b = _feature_dict_from_a3m(
|
||||
'WXYZ',
|
||||
'\n'.join([
|
||||
'>101',
|
||||
'WXYZ',
|
||||
'>UniRef100_UPI001118B830\t855\t0.990',
|
||||
'WXYW',
|
||||
'>UniRef100_UPI00101273C6\t833\t0.919',
|
||||
'WX-Z',
|
||||
'',
|
||||
]),
|
||||
species_resolver=mmseqs_species_identifiers.resolve_species_ids_by_accession,
|
||||
)
|
||||
|
||||
paired_rows = msa_pairing.pair_sequences([chain_a, chain_b])[2]
|
||||
|
||||
assert paired_rows.shape == (3, 2)
|
||||
assert tuple(paired_rows[0]) == (0, 0)
|
||||
assert {tuple(row) for row in paired_rows[1:]} == {(1, 1), (2, 2)}
|
||||
Reference in New Issue
Block a user