mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Add option to run Jackhmmer/Nhmmer genetic search in sharded mode (10-30x faster)
PiperOrigin-RevId: 825546064 Change-Id: Ib421e47bb9ca7eea512c49a532e7e995a0f5721f
This commit is contained in:
committed by
Copybara-Service
parent
46fe3f0f60
commit
805adc3863
@@ -82,6 +82,85 @@ is the number of cores used for each Jackhmmer process times 4. Also note that
|
||||
for sequences with deep MSAs, Jackhmmer or Nhmmer may need a substantial amount
|
||||
of RAM beyond the recommended 64 GB of RAM.
|
||||
|
||||
### Sharded genetic databases
|
||||
|
||||
The run time of the genetic database search can be *significantly* sped up by
|
||||
splitting the genetic databases if a machine with many CPU cores is used and the
|
||||
databases are on very fast SSD or in a RAM-backed filesystem. With this
|
||||
technique you can make Jackhmmer/Nhmmer genetic search fully utilize your
|
||||
hardware and take advantage of multi-core systems.
|
||||
|
||||
Each genetic database with *n* sequences is split into *s* shards, each
|
||||
containing roughly *n* / *s* sequences. We recommend splitting the sequences
|
||||
between shards randomly to make sure each shard has similar sequence length
|
||||
distribution. This could be achieved using standard tools:
|
||||
|
||||
1. Shuffle the sequences in the fasta. This can be done for example by running:
|
||||
`seqkit shuffle --two-pass <db.fasta>`
|
||||
2. Split the shuffled fasta in *s* shards. This can be done for example by
|
||||
running: `seqkit split2 --by-part <s> <db.fasta>`
|
||||
|
||||
Make sure the shard names follow this pattern:
|
||||
`prefix-<shard_index>-of-<total_shards>`, both `shard_index` and `total_shards`
|
||||
having always 5 digits, with leading zeros as needed. The `shard_index` goes
|
||||
from 0 to `total_shards - 1`. A file "path" (spec) for a sharded file is
|
||||
`prefix@<total_shards>`.
|
||||
|
||||
E.g. for a file named `uniprot.fasta` split into 3 shards, the names of the
|
||||
shards should be:
|
||||
|
||||
* `uniprot.fasta-00000-of-00003`
|
||||
* `uniprot.fasta-00001-of-00003`
|
||||
* `uniprot.fasta-00002-of-00003`
|
||||
|
||||
The file spec for these files is `uniprot.fasta@3`.
|
||||
|
||||
Save the total number of sequences in the protein databases, and the total
|
||||
number of nucleic bases (in megabases) in the RNA databases – these will be needed later as a
|
||||
flag to Jackhmmer/Nhmmer to correctly scale e-values across all shards.
|
||||
|
||||
Save the sharded databases on a fast SSD or in a RAM-backed filesystem, then
|
||||
launch AlphaFold with the sharded paths instead of normal paths and set the
|
||||
Z-values.
|
||||
|
||||
For instance, with each database split into the number of shards given after `@` and at most 16 shards searched in parallel:
|
||||
|
||||
```bash
|
||||
python run_alphafold.py \
|
||||
--small_bfd_database_path="bfd-first_non_consensus_sequences.fasta@64" \
|
||||
--small_bfd_z_value=65984053 \
|
||||
--mgnify_database_path="mgy_clusters_2022_05.fa@512" \
|
||||
--mgnify_z_value=623796864 \
|
||||
--uniprot_cluster_annot_database_path="uniprot_cluster_annot_2021_04.fasta@256" \
|
||||
--uniprot_cluster_annot_z_value=225619586 \
|
||||
--uniref90_database_path="uniref90_2022_05.fasta@128" \
|
||||
--uniref90_z_value=153742194 \
|
||||
--ntrna_database_path="nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta@256" \
|
||||
--ntrna_z_value=76752.808514 \
|
||||
--rfam_database_path="rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta@16" \
|
||||
--rfam_z_value=138.115553 \
|
||||
--rna_central_database_path="rnacentral_active_seq_id_90_cov_80_linclust.fasta@64" \
|
||||
--rna_central_z_value=13271.415730 \
|
||||
--jackhmmer_n_cpu=2 \
|
||||
--jackhmmer_max_parallel_shards=16 \
|
||||
--nhmmer_n_cpu=2 \
|
||||
--nhmmer_max_parallel_shards=16
|
||||
```
|
||||
|
||||
This run will utilize (2 CPUs) × (16 max parallel shards) × (4 protein dbs
|
||||
searched in parallel) = 128 cores for each protein chain, and (2 CPUs) × (16 max
|
||||
parallel shards) × (3 RNA dbs searched in parallel) = 96 cores for each RNA
|
||||
chain. Make sure to tune:
|
||||
|
||||
* the Jackhmmer/Nhmmer number of CPUs,
|
||||
* the maximum number of shards searched in parallel,
|
||||
* and the number of shards for each database
|
||||
|
||||
so that the memory bandwidth and CPUs on your machine are optimally utilized.
|
||||
You should aim for consistent shard sizes across all databases (so e.g. if
|
||||
database A is split into 16 shards and is 3× smaller than database B, database B
|
||||
should be split into 3 × 16 = 48 shards).
|
||||
|
||||
## Model Inference
|
||||
|
||||
Table 8 in the Supplementary Information of the
|
||||
|
||||
@@ -132,37 +132,86 @@ _SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
|
||||
'${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
|
||||
'Small BFD database path, used for protein MSA search.',
|
||||
)
|
||||
_SMALL_BFD_Z_VALUE = flags.DEFINE_integer(
|
||||
'small_bfd_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in number of sequences for'
|
||||
' E-value calculation. Must be set for sharded databases.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_MGNIFY_DATABASE_PATH = flags.DEFINE_string(
|
||||
'mgnify_database_path',
|
||||
'${DB_DIR}/mgy_clusters_2022_05.fa',
|
||||
'Mgnify database path, used for protein MSA search.',
|
||||
)
|
||||
_MGNIFY_Z_VALUE = flags.DEFINE_integer(
|
||||
'mgnify_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in number of sequences for'
|
||||
' E-value calculation. Must be set for sharded databases.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string(
|
||||
'uniprot_cluster_annot_database_path',
|
||||
'${DB_DIR}/uniprot_all_2021_04.fa',
|
||||
'UniProt database path, used for protein paired MSA search.',
|
||||
)
|
||||
_UNIPROT_CLUSTER_ANNOT_Z_VALUE = flags.DEFINE_integer(
|
||||
'uniprot_cluster_annot_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in number of sequences for'
|
||||
' E-value calculation. Must be set for sharded databases.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_UNIREF90_DATABASE_PATH = flags.DEFINE_string(
|
||||
'uniref90_database_path',
|
||||
'${DB_DIR}/uniref90_2022_05.fa',
|
||||
'UniRef90 database path, used for MSA search. The MSA obtained by '
|
||||
'searching it is used to construct the profile for template search.',
|
||||
)
|
||||
_UNIREF90_Z_VALUE = flags.DEFINE_integer(
|
||||
'uniref90_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in number of sequences for'
|
||||
' E-value calculation. Must be set for sharded databases.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_NTRNA_DATABASE_PATH = flags.DEFINE_string(
|
||||
'ntrna_database_path',
|
||||
'${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta',
|
||||
'NT-RNA database path, used for RNA MSA search.',
|
||||
)
|
||||
_NTRNA_Z_VALUE = flags.DEFINE_float(
|
||||
'ntrna_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in megabases for E-value'
|
||||
' calculation. Must be set for sharded databases.',
|
||||
lower_bound=0.0,
|
||||
)
|
||||
_RFAM_DATABASE_PATH = flags.DEFINE_string(
|
||||
'rfam_database_path',
|
||||
'${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta',
|
||||
'Rfam database path, used for RNA MSA search.',
|
||||
)
|
||||
_RFAM_Z_VALUE = flags.DEFINE_float(
|
||||
'rfam_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in megabases for E-value'
|
||||
' calculation. Must be set for sharded databases.',
|
||||
lower_bound=0.0,
|
||||
)
|
||||
_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string(
|
||||
'rna_central_database_path',
|
||||
'${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta',
|
||||
'RNAcentral database path, used for RNA MSA search.',
|
||||
)
|
||||
_RNA_CENTRAL_Z_VALUE = flags.DEFINE_float(
|
||||
'rna_central_z_value',
|
||||
None,
|
||||
'The Z-value representing the database size in megabases for E-value'
|
||||
' calculation. Must be set for sharded databases.',
|
||||
lower_bound=0.0,
|
||||
)
|
||||
_PDB_DATABASE_PATH = flags.DEFINE_string(
|
||||
'pdb_database_path',
|
||||
'${DB_DIR}/mmcif_files',
|
||||
@@ -183,6 +232,14 @@ _JACKHMMER_N_CPU = flags.DEFINE_integer(
|
||||
' above 8 CPUs provides very little additional speedup.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_JACKHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
|
||||
'jackhmmer_max_parallel_shards',
|
||||
None,
|
||||
'Maximum number of shards to search against in parallel. If unset, one'
|
||||
' Jackhmmer instance will be run per shard. Only applicable if the'
|
||||
' database is sharded.',
|
||||
lower_bound=1,
|
||||
)
|
||||
_NHMMER_N_CPU = flags.DEFINE_integer(
|
||||
'nhmmer_n_cpu',
|
||||
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
|
||||
@@ -191,6 +248,14 @@ _NHMMER_N_CPU = flags.DEFINE_integer(
|
||||
' above 8 CPUs provides very little additional speedup.',
|
||||
lower_bound=0,
|
||||
)
|
||||
_NHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
|
||||
'nhmmer_max_parallel_shards',
|
||||
None,
|
||||
'Maximum number of shards to search against in parallel. If unset, one'
|
||||
' Nhmmer instance will be run per shard. Only applicable if the'
|
||||
' database is sharded.',
|
||||
lower_bound=1,
|
||||
)
|
||||
|
||||
# Data pipeline configuration.
|
||||
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
|
||||
@@ -828,18 +893,27 @@ def main(_):
|
||||
hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value,
|
||||
hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value,
|
||||
small_bfd_database_path=expand_path(_SMALL_BFD_DATABASE_PATH.value),
|
||||
small_bfd_z_value=_SMALL_BFD_Z_VALUE.value,
|
||||
mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value),
|
||||
mgnify_z_value=_MGNIFY_Z_VALUE.value,
|
||||
uniprot_cluster_annot_database_path=expand_path(
|
||||
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value
|
||||
),
|
||||
uniprot_cluster_annot_z_value=_UNIPROT_CLUSTER_ANNOT_Z_VALUE.value,
|
||||
uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value),
|
||||
uniref90_z_value=_UNIREF90_Z_VALUE.value,
|
||||
ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value),
|
||||
ntrna_z_value=_NTRNA_Z_VALUE.value,
|
||||
rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value),
|
||||
rfam_z_value=_RFAM_Z_VALUE.value,
|
||||
rna_central_database_path=expand_path(_RNA_CENTRAL_DATABASE_PATH.value),
|
||||
rna_central_z_value=_RNA_CENTRAL_Z_VALUE.value,
|
||||
pdb_database_path=expand_path(_PDB_DATABASE_PATH.value),
|
||||
seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value),
|
||||
jackhmmer_n_cpu=_JACKHMMER_N_CPU.value,
|
||||
jackhmmer_max_parallel_shards=_JACKHMMER_MAX_PARALLEL_SHARDS.value,
|
||||
nhmmer_n_cpu=_NHMMER_N_CPU.value,
|
||||
nhmmer_max_parallel_shards=_NHMMER_MAX_PARALLEL_SHARDS.value,
|
||||
max_template_date=max_template_date,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -42,9 +42,16 @@ class JackhmmerConfig:
|
||||
n_cpu: An integer with the number of CPUs to use.
|
||||
n_iter: An integer with the number of database search iterations.
|
||||
e_value: e-value for the database lookup.
|
||||
z_value: The Z-value representing the number of comparisons done (i.e
|
||||
correct database size) for E-value calculation.
|
||||
z_value: The Z-value representing the database size in number of sequences
|
||||
for E-value and domain E-value calculation. Must be set for sharded
|
||||
databases.
|
||||
dom_z_value: The Z-value representing the database size in number of
|
||||
sequences for domain E-value calculation. Must be set for sharded
|
||||
databases.
|
||||
max_sequences: Max sequences to return in MSA.
|
||||
max_parallel_shards: If given, the maximum number of shards to search
|
||||
against in parallel. If None, one Jackhmmer instance will be run per
|
||||
shard. Only applicable if the database is sharded.
|
||||
"""
|
||||
|
||||
binary_path: str
|
||||
@@ -52,8 +59,10 @@ class JackhmmerConfig:
|
||||
n_cpu: int
|
||||
n_iter: int
|
||||
e_value: float
|
||||
z_value: float | int | None
|
||||
z_value: int | None
|
||||
dom_z_value: int | None
|
||||
max_sequences: int
|
||||
max_parallel_shards: int | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
|
||||
@@ -67,8 +76,14 @@ class NhmmerConfig:
|
||||
database_config: Database configuration.
|
||||
n_cpu: An integer with the number of CPUs to use.
|
||||
e_value: e-value for the database lookup.
|
||||
z_value: The Z-value representing the database size in megabases for
|
||||
E-value calculation. Allows fractional values. Must be set for sharded
|
||||
databases.
|
||||
max_sequences: Max sequences to return in MSA.
|
||||
alphabet: The alphabet when building a profile with hmmbuild.
|
||||
max_parallel_shards: If given, the maximum number of shards to search
|
||||
against in parallel. If None, one Nhmmer instance will be run per shard.
|
||||
Only applicable if the database is sharded.
|
||||
"""
|
||||
|
||||
binary_path: str
|
||||
@@ -77,8 +92,10 @@ class NhmmerConfig:
|
||||
database_config: DatabaseConfig
|
||||
n_cpu: int
|
||||
e_value: float
|
||||
z_value: float | None
|
||||
max_sequences: int
|
||||
alphabet: str | None
|
||||
max_parallel_shards: int | None = None
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
|
||||
|
||||
@@ -214,21 +214,42 @@ class DataPipelineConfig:
|
||||
raw MSA in template search.
|
||||
small_bfd_database_path: Small BFD database path, used for protein MSA
|
||||
search.
|
||||
small_bfd_z_value: The Z-value representing the database size in number of
|
||||
sequences for E-value calculation. Must be set for sharded databases.
|
||||
mgnify_database_path: Mgnify database path, used for protein MSA search.
|
||||
mgnify_z_value: The Z-value representing the database size in number of
|
||||
sequences for E-value calculation. Must be set for sharded databases.
|
||||
uniprot_cluster_annot_database_path: Uniprot database path, used for protein
|
||||
paired MSA search.
|
||||
uniprot_cluster_annot_z_value: The Z-value representing the database size in
|
||||
number of sequences for E-value calculation. Must be set for sharded
|
||||
databases.
|
||||
uniref90_database_path: UniRef90 database path, used for MSA search, and the
|
||||
MSA obtained by searching it is used to construct the profile for template
|
||||
search.
|
||||
uniref90_z_value: The Z-value representing the database size in number of
|
||||
sequences for E-value calculation. Must be set for sharded databases.
|
||||
ntrna_database_path: NT-RNA database path, used for RNA MSA search.
|
||||
ntrna_z_value: The Z-value representing the database size in megabases for
|
||||
E-value calculation. Must be set for sharded databases.
|
||||
rfam_database_path: Rfam database path, used for RNA MSA search.
|
||||
rfam_z_value: The Z-value representing the database size in megabases for
|
||||
E-value calculation. Must be set for sharded databases.
|
||||
rna_central_database_path: RNAcentral database path, used for RNA MSA
|
||||
search.
|
||||
rna_central_z_value: The Z-value representing the database size in megabases
|
||||
for E-value calculation. Must be set for sharded databases.
|
||||
seqres_database_path: PDB sequence database path, used for template search.
|
||||
pdb_database_path: PDB database directory with mmCIF files path, used for
|
||||
template search.
|
||||
jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer.
|
||||
jackhmmer_max_parallel_shards: Maximum number of shards to search against in
|
||||
parallel. If None, one Jackhmmer instance will be run per shard. Only
|
||||
applicable if the database is sharded.
|
||||
nhmmer_n_cpu: Number of CPUs to use for Nhmmer.
|
||||
nhmmer_max_parallel_shards: Maximum number of shards to search against in
|
||||
parallel. If None, one Nhmmer instance will be run per shard. Only
|
||||
applicable if the database is sharded.
|
||||
max_template_date: The latest date of templates to use.
|
||||
"""
|
||||
|
||||
@@ -241,20 +262,29 @@ class DataPipelineConfig:
|
||||
|
||||
# Jackhmmer databases.
|
||||
small_bfd_database_path: str
|
||||
small_bfd_z_value: int | None = None
|
||||
mgnify_database_path: str
|
||||
mgnify_z_value: int | None = None
|
||||
uniprot_cluster_annot_database_path: str
|
||||
uniprot_cluster_annot_z_value: int | None = None
|
||||
uniref90_database_path: str
|
||||
uniref90_z_value: int | None = None
|
||||
# Nhmmer databases.
|
||||
ntrna_database_path: str
|
||||
ntrna_z_value: int | None = None
|
||||
rfam_database_path: str
|
||||
rfam_z_value: int | None = None
|
||||
rna_central_database_path: str
|
||||
rna_central_z_value: int | None = None
|
||||
# Template search databases.
|
||||
seqres_database_path: str
|
||||
pdb_database_path: str
|
||||
|
||||
# Optional configuration for MSA tools.
|
||||
jackhmmer_n_cpu: int = 8
|
||||
jackhmmer_max_parallel_shards: int | None = None
|
||||
nhmmer_n_cpu: int = 8
|
||||
nhmmer_max_parallel_shards: int | None = None
|
||||
|
||||
max_template_date: datetime.date
|
||||
|
||||
@@ -274,8 +304,10 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
|
||||
n_iter=1,
|
||||
e_value=1e-4,
|
||||
z_value=None,
|
||||
z_value=data_pipeline_config.uniref90_z_value,
|
||||
dom_z_value=data_pipeline_config.uniref90_z_value,
|
||||
max_sequences=10_000,
|
||||
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -290,8 +322,10 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
|
||||
n_iter=1,
|
||||
e_value=1e-4,
|
||||
z_value=None,
|
||||
z_value=data_pipeline_config.mgnify_z_value,
|
||||
dom_z_value=data_pipeline_config.mgnify_z_value,
|
||||
max_sequences=5_000,
|
||||
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -308,8 +342,10 @@ class DataPipeline:
|
||||
e_value=1e-4,
|
||||
# Set z_value=138_515_945 to match the z_value used in the paper.
|
||||
# In practice, this has minimal impact on predicted structures.
|
||||
z_value=None,
|
||||
z_value=data_pipeline_config.small_bfd_z_value,
|
||||
dom_z_value=data_pipeline_config.small_bfd_z_value,
|
||||
max_sequences=5_000,
|
||||
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -324,8 +360,10 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
|
||||
n_iter=1,
|
||||
e_value=1e-4,
|
||||
z_value=None,
|
||||
z_value=data_pipeline_config.uniprot_cluster_annot_z_value,
|
||||
dom_z_value=data_pipeline_config.uniprot_cluster_annot_z_value,
|
||||
max_sequences=50_000,
|
||||
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -342,7 +380,9 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.nhmmer_n_cpu,
|
||||
e_value=1e-3,
|
||||
alphabet='rna',
|
||||
z_value=data_pipeline_config.ntrna_z_value,
|
||||
max_sequences=10_000,
|
||||
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.RNA_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -359,7 +399,9 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.nhmmer_n_cpu,
|
||||
e_value=1e-3,
|
||||
alphabet='rna',
|
||||
z_value=data_pipeline_config.rfam_z_value,
|
||||
max_sequences=10_000,
|
||||
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.RNA_CHAIN,
|
||||
crop_size=None,
|
||||
@@ -376,7 +418,9 @@ class DataPipeline:
|
||||
n_cpu=data_pipeline_config.nhmmer_n_cpu,
|
||||
e_value=1e-3,
|
||||
alphabet='rna',
|
||||
z_value=data_pipeline_config.rna_central_z_value,
|
||||
max_sequences=10_000,
|
||||
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
|
||||
),
|
||||
chain_poly_type=mmcif_names.RNA_CHAIN,
|
||||
crop_size=None,
|
||||
|
||||
@@ -10,12 +10,19 @@
|
||||
|
||||
"""Library to run Jackhmmer from Python."""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent import futures
|
||||
import heapq
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
from absl import logging
|
||||
from alphafold3.data import parsers
|
||||
from alphafold3.data.tools import msa_tool
|
||||
from alphafold3.data.tools import shards
|
||||
from alphafold3.data.tools import subprocess_utils
|
||||
|
||||
|
||||
@@ -31,43 +38,91 @@ class Jackhmmer(msa_tool.MsaTool):
|
||||
n_iter: int = 3,
|
||||
e_value: float | None = 1e-3,
|
||||
z_value: float | int | None = None,
|
||||
dom_e: float | None = None,
|
||||
dom_z_value: float | int | None = None,
|
||||
max_sequences: int = 5000,
|
||||
filter_f1: float = 5e-4,
|
||||
filter_f2: float = 5e-5,
|
||||
filter_f3: float = 5e-7,
|
||||
max_threads: int | None = None,
|
||||
**unused_kwargs,
|
||||
):
|
||||
"""Initializes the Python Jackhmmer wrapper.
|
||||
|
||||
NOTE: The MSA obtained by running against sharded dbs won't be always
|
||||
exactly the same as the MSA obtained by running against an unsharded db.
|
||||
This is because of Jackhmmer deduplication logic, which won't spot duplicate
|
||||
hits across multiple shards. Usually this means that the sharded search
|
||||
finds more hits (likely bounded by the number of shards), but this should
|
||||
not pose an issue given how the results are used downstream. The problem is
|
||||
more pronounced with deep MSAs and lower in the hit list (higher e-values).
|
||||
|
||||
Make sure to set the Z and domZ values when searching against a sharded
|
||||
database, otherwise the results won't match the normal unsharded search.
|
||||
|
||||
Args:
|
||||
binary_path: The path to the jackhmmer executable.
|
||||
database_path: The path to the jackhmmer database (FASTA format).
|
||||
database_path: The path to the jackhmmer database (FASTA format). Sharded
|
||||
file specs, e.g. `<db_path>@<num_shards>`, are supported.
|
||||
n_cpu: The number of CPUs to give Jackhmmer.
|
||||
n_iter: The number of Jackhmmer iterations.
|
||||
e_value: The E-value, see Jackhmmer docs for more details.
|
||||
z_value: The Z-value representing the number of comparisons done (i.e
|
||||
correct database size) for E-value calculation.
|
||||
correct database size) for E-value calculation. Make sure to set this
|
||||
when searching against a sharded database, otherwise the e-values will
|
||||
be incorrectly scaled.
|
||||
dom_e: Domain e-value criteria for inclusion in tblout.
|
||||
dom_z_value: Domain z-value representing the number of comparisons done
|
||||
(i.e correct database size) for domain E-value calculation. Make sure to
|
||||
set this when searching against a sharded database, otherwise the domain
|
||||
e-values will be incorrectly scaled.
|
||||
max_sequences: Maximum number of sequences to return in the MSA.
|
||||
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
|
||||
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
|
||||
filter_f3: Forward pre-filter, set to >1.0 to turn off.
|
||||
max_threads: If given, the maximum number of threads used when running
|
||||
sharded databases.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Jackhmmer binary not found within the path.
|
||||
ValueError: If an invalid configuration is provided in the args.
|
||||
"""
|
||||
self._binary_path = binary_path
|
||||
self._database_path = database_path
|
||||
|
||||
if shard_paths := shards.get_sharded_paths(self._database_path):
|
||||
if n_iter != 1:
|
||||
raise ValueError('For a sharded db, only n_iter=1 is supported.')
|
||||
if z_value is None:
|
||||
raise ValueError(
|
||||
'The Z-value must be set when searching against a sharded database '
|
||||
'to correctly scale e-values.'
|
||||
)
|
||||
if max_sequences <= 1:
|
||||
raise ValueError(
|
||||
'max_sequences must be greater than 1 when running in sharded '
|
||||
'mode, because each shard would return only the query sequence.'
|
||||
)
|
||||
|
||||
self._shard_paths = shard_paths
|
||||
self._max_threads = len(self._shard_paths)
|
||||
if max_threads is not None:
|
||||
self._max_threads = min(max_threads, self._max_threads)
|
||||
logging.info('Jackhmmer running with max_threads = %d', self._max_threads)
|
||||
else:
|
||||
self._shard_paths = None
|
||||
self._max_threads = None
|
||||
|
||||
self._binary_path = binary_path
|
||||
subprocess_utils.check_binary_exists(
|
||||
path=self._binary_path, name='Jackhmmer'
|
||||
)
|
||||
|
||||
if not os.path.exists(self._database_path):
|
||||
raise ValueError(f'Could not find Jackhmmer database {database_path}')
|
||||
|
||||
self._n_cpu = n_cpu
|
||||
self._n_iter = n_iter
|
||||
self._e_value = e_value
|
||||
self._z_value = z_value
|
||||
self._dom_e = dom_e
|
||||
self._dom_z_value = dom_z_value
|
||||
self._max_sequences = max_sequences
|
||||
self._filter_f1 = filter_f1
|
||||
self._filter_f2 = filter_f2
|
||||
@@ -81,21 +136,73 @@ class Jackhmmer(msa_tool.MsaTool):
|
||||
)
|
||||
|
||||
def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
  """Queries the database (sharded or unsharded) using Jackhmmer.

  For a sharded database, each shard is searched by a separate Jackhmmer run,
  with at most `self._max_threads` runs in flight at once, and the per-shard
  results are merged into a single result. For an unsharded database a single
  Jackhmmer search is run against the whole database.

  Args:
    target_sequence: The query sequence to search with.

  Returns:
    The MSA tool result for the query sequence.
  """
  # Only log a short prefix of long sequences to keep the logs readable.
  query_preview = (
      target_sequence
      if len(target_sequence) <= 16
      else f'{target_sequence[:16]}... (len {len(target_sequence)})'
  )

  if self._shard_paths:
    # Sharded case, run the query against each database shard in parallel.
    logging.info('Query sequence (sharded db): %s', query_preview)

    # All per-shard temporary directories are created inside this directory.
    global_temp_dir = tempfile.mkdtemp()

    def _query_shard_fn(
        shard_path: str,
    ) -> tuple[msa_tool.MsaToolResult, float]:
      t_start = time.time()
      result = self._query_db_shard(
          target_sequence=target_sequence,
          db_shard_path=shard_path,
          get_tblout=True,  # Tblout contains e-values needed for merging.
          global_temp_dir=global_temp_dir,
      )
      return result, time.time() - t_start

    try:
      with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
        tool_outputs, timings = zip(
            *ex.map(_query_shard_fn, self._shard_paths)
        )
    finally:
      # Clean up even if a shard search failed, so that a failing query does
      # not leak the temporary directory.
      shutil.rmtree(global_temp_dir, ignore_errors=True)

    logging.info(
        'Finished query for %d shards, shard timings (seconds): %s',
        len(tool_outputs),
        ', '.join(f'{t:.1f}' for t in timings),
    )
    return _merge_jackhmmer_results(tool_outputs, self._max_sequences)

  else:
    # Non-sharded case, run the query against the whole database.
    logging.info('Query sequence (non-sharded db): %s', query_preview)
    return self._query_db_shard(
        target_sequence=target_sequence,
        db_shard_path=self._database_path,
        get_tblout=False,
    )
|
||||
|
||||
def _query_db_shard(
|
||||
self,
|
||||
*,
|
||||
target_sequence: str,
|
||||
db_shard_path: str,
|
||||
get_tblout: bool,
|
||||
global_temp_dir: str | None = None,
|
||||
) -> msa_tool.MsaToolResult:
|
||||
"""Query the database shard using Jackhmmer."""
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir:
|
||||
input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta')
|
||||
subprocess_utils.create_query_fasta_file(
|
||||
sequence=target_sequence, path=input_fasta_path
|
||||
)
|
||||
|
||||
output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
|
||||
pathlib.Path(output_sto_path).touch()
|
||||
|
||||
# The F1/F2/F3 are the expected proportion to pass each of the filtering
|
||||
# stages (which get progressively more expensive), reducing these
|
||||
@@ -113,6 +220,13 @@ class Jackhmmer(msa_tool.MsaTool):
|
||||
*('-N', str(self._n_iter)),
|
||||
]
|
||||
|
||||
if get_tblout:
|
||||
output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt')
|
||||
output_tblout_path.touch()
|
||||
cmd_flags.extend(['--tblout', str(output_tblout_path)])
|
||||
else:
|
||||
output_tblout_path = None
|
||||
|
||||
# Report only sequences with E-values <= x in per-sequence output.
|
||||
if self._e_value is not None:
|
||||
cmd_flags.extend(['-E', str(self._e_value)])
|
||||
@@ -123,18 +237,21 @@ class Jackhmmer(msa_tool.MsaTool):
|
||||
if self._z_value is not None:
|
||||
cmd_flags.extend(['-Z', str(self._z_value)])
|
||||
|
||||
if self._dom_z_value is not None:
|
||||
cmd_flags.extend(['--domZ', str(self._dom_z_value)])
|
||||
|
||||
if self._dom_e is not None:
|
||||
cmd_flags.extend(['--domE', str(self._dom_e)])
|
||||
|
||||
if self._max_sequences is not None and self._supports_seq_limit:
|
||||
cmd_flags.extend(['--seq_limit', str(self._max_sequences)])
|
||||
|
||||
cmd = (
|
||||
[self._binary_path]
|
||||
+ cmd_flags
|
||||
+ [input_fasta_path, self._database_path]
|
||||
)
|
||||
# The input FASTA and the input db are the last two arguments.
|
||||
cmd = [self._binary_path] + cmd_flags + [input_fasta_path, db_shard_path]
|
||||
|
||||
subprocess_utils.run(
|
||||
cmd=cmd,
|
||||
cmd_name=f'Jackhmmer ({os.path.basename(self._database_path)})',
|
||||
cmd_name=f'Jackhmmer ({os.path.basename(db_shard_path)})',
|
||||
log_stdout=False,
|
||||
log_stderr=True,
|
||||
log_on_process_error=True,
|
||||
@@ -145,6 +262,73 @@ class Jackhmmer(msa_tool.MsaTool):
|
||||
f, max_sequences=self._max_sequences
|
||||
)
|
||||
|
||||
return msa_tool.MsaToolResult(
|
||||
target_sequence=target_sequence, a3m=a3m, e_value=self._e_value
|
||||
)
|
||||
# Get the tabular output which has e.g. e-value for each target.
|
||||
tbl = '' if output_tblout_path is None else output_tblout_path.read_text()
|
||||
|
||||
return msa_tool.MsaToolResult(
|
||||
target_sequence=target_sequence,
|
||||
a3m=a3m,
|
||||
e_value=self._e_value,
|
||||
tblout=tbl,
|
||||
)
|
||||
|
||||
|
||||
def _merge_jackhmmer_results(
    jh_results: Sequence[msa_tool.MsaToolResult], max_sequences: int
) -> msa_tool.MsaToolResult:
  """Merges per-shard Jackhmmer results into a single result.

  Hits from all shards are interleaved in ascending (e-value, hit name) order
  using the e-values found in each result's tblout. Hits that have no matching
  tblout line are dropped.

  Args:
    jh_results: Results of querying each database shard. All of them must
      share the same target sequence and e-value, and have tblout set.
    max_sequences: Maximum number of sequences (including the query) to keep
      in the merged A3M.

  Returns:
    A result with the merged A3M, which starts with the query sequence. Its
    tblout is set to None.
  """
  assert len(set(jh_res.target_sequence for jh_res in jh_results)) == 1
  assert len(set(jh_res.e_value for jh_res in jh_results)) == 1

  # Parse the TBL output, create a mapping from hit name to TBL line.
  parsed_tbl = {}
  for jh_result in jh_results:
    assert jh_result.tblout is not None
    for line in jh_result.tblout.splitlines():
      # Skip blank lines and comment lines, they carry no hit information.
      if line and not line.startswith('#'):
        parsed_tbl[line.partition(' ')[0]] = line

  # Create an iterator and merge a3m info with tbl info.
  def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
    # Don't parse the entire a3m, lazily parse only as many sequences as needed.
    iterator = iter(parsers.lazy_parse_fasta_string(a3m))
    # Skip the query which isn't present in tblout. The default prevents an
    # empty a3m from raising StopIteration inside this generator, which Python
    # would turn into a RuntimeError.
    next(iterator, None)
    for sequence, description in iterator:
      name = description.partition(' ')[0].partition('/')[0]
      if tbl_info := parsed_tbl.get(name):
        # Skip sequences for which we don't have tbl information.
        yield sequence, description, tbl_info, name

  def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]:
    unused_seq, unused_description, tbl_info, name = seq_data
    # Tblout lines have 19 whitespace delimited columns. "-" used if no value
    # present. We want e-value in column with index 4, so do only 5 splits.
    # Use the name in case of a e-value tie.
    return float(tbl_info.split(maxsplit=5)[4]), name

  # heapq.merge only produces sorted output if every input is already sorted
  # by the key, i.e. this assumes each shard's A3M is ordered by e-value and
  # name.
  merged_a3m_and_tblout = heapq.merge(
      *[_merged_a3m_tbl_iter(res.a3m) for res in jh_results],
      key=sort_key,
  )

  # Truncate the a3m to max_sequences. Do not truncate the tblout.
  merged_tblout = []
  merged_a3m = [f'>query\n{jh_results[0].target_sequence}']
  for seq, description, tbl_info, _ in merged_a3m_and_tblout:
    merged_tblout.append(tbl_info)
    if len(merged_a3m) < max_sequences:
      merged_a3m.append(f'>{description}\n{seq}')

  logging.info(
      'Limiting merged MSA depth from %d to %d',
      len(merged_tblout),
      max_sequences,
  )

  return msa_tool.MsaToolResult(
      target_sequence=jh_results[0].target_sequence,
      a3m='\n'.join(merged_a3m),
      e_value=jh_results[0].e_value,
      tblout=None,  # We no longer need the tblout.
  )
|
||||
|
||||
@@ -16,11 +16,20 @@ from typing import Protocol
|
||||
|
||||
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
|
||||
class MsaToolResult:
|
||||
"""The result of a MSA tool query."""
|
||||
"""The result of a MSA tool query.
|
||||
|
||||
Attributes:
|
||||
target_sequence: The sequence that was used to query the MSA tool.
|
||||
e_value: The e-value that was used to filter the MSA tool results.
|
||||
a3m: The MSA output of the tool in the A3M format.
|
||||
tblout: The optional tblout output of the MSA tool (needed for merging
|
||||
results of queries against a sharded database).
|
||||
"""
|
||||
|
||||
target_sequence: str
|
||||
e_value: float
|
||||
a3m: str
|
||||
tblout: str | None = None
|
||||
|
||||
|
||||
class MsaTool(Protocol):
|
||||
|
||||
@@ -10,9 +10,14 @@
|
||||
|
||||
"""Library to run Nhmmer from Python."""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent import futures
|
||||
import heapq
|
||||
import os
|
||||
import pathlib
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Final
|
||||
|
||||
from absl import logging
|
||||
@@ -20,8 +25,10 @@ from alphafold3.data import parsers
|
||||
from alphafold3.data.tools import hmmalign
|
||||
from alphafold3.data.tools import hmmbuild
|
||||
from alphafold3.data.tools import msa_tool
|
||||
from alphafold3.data.tools import shards
|
||||
from alphafold3.data.tools import subprocess_utils
|
||||
|
||||
|
||||
_SHORT_SEQUENCE_CUTOFF: Final[int] = 50
|
||||
|
||||
|
||||
@@ -36,38 +43,83 @@ class Nhmmer(msa_tool.MsaTool):
|
||||
database_path: str,
|
||||
n_cpu: int = 8,
|
||||
e_value: float = 1e-3,
|
||||
z_value: float | int | None = None,
|
||||
max_sequences: int = 5000,
|
||||
filter_f3: float = 1e-5,
|
||||
alphabet: str | None = None,
|
||||
strand: str | None = None,
|
||||
max_threads: int | None = None,
|
||||
):
|
||||
"""Initializes the Python Nhmmer wrapper.
|
||||
|
||||
NOTE: The MSA obtained by running against sharded dbs won't be always
|
||||
exactly the same as the MSA obtained by running against an unsharded db.
|
||||
This is because of Nhmmer deduplication logic, which won't spot duplicate
|
||||
hits across multiple shards. Usually this means that the sharded search
|
||||
finds more hits (likely bounded by the number of shards), but this should
|
||||
not pose an issue given how the results are used downstream. The problem is
|
||||
more pronounced with deep MSAs and lower in the hit list (higher e-values).
|
||||
|
||||
Make sure to set the Z value when searching against a sharded database,
|
||||
otherwise the results won't match the normal unsharded search.
|
||||
|
||||
Args:
|
||||
binary_path: Path to the Nhmmer binary.
|
||||
hmmalign_binary_path: Path to the Hmmalign binary.
|
||||
hmmbuild_binary_path: Path to the Hmmbuild binary.
|
||||
database_path: MSA database path to search against. This can be either a
|
||||
FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb
|
||||
binary. The HMMERDB is ~10x faster but experimental.
|
||||
binary. The HMMERDB is ~10x faster but experimental. Sharded file
|
||||
specs, e.g. <db_path>@<num_shards>, are supported.
|
||||
n_cpu: The number of CPUs to give Nhmmer.
|
||||
e_value: The E-value, see Nhmmer docs for more details. Will be
|
||||
overwritten if bit_score is set.
|
||||
z_value: The Z-value representing the number of comparisons done (i.e
|
||||
correct database size) for E-value calculation. Make sure to set this
|
||||
when searching against a sharded database, otherwise the e-values will
|
||||
be incorrectly scaled.
|
||||
max_sequences: Maximum number of sequences to return in the MSA.
|
||||
filter_f3: Forward pre-filter, set to >1.0 to turn off.
|
||||
alphabet: The alphabet to assert when building a profile with hmmbuild.
|
||||
This must be 'rna', 'dna', or None.
|
||||
strand: "watson" searches query sequence, "crick" searches
|
||||
reverse-complement and default is None which means searching for both.
|
||||
max_threads: If given, the maximum number of threads used when running
|
||||
sharded databases.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If Nhmmer binary not found within the path.
|
||||
ValueError: If an invalid configuration is provided in the args.
|
||||
"""
|
||||
self._database_path = database_path
|
||||
|
||||
if shard_paths := shards.get_sharded_paths(self._database_path):
|
||||
if z_value is None:
|
||||
raise ValueError(
|
||||
'The Z-value must be set when searching against a sharded database '
|
||||
'to correctly scale e-values.'
|
||||
)
|
||||
if 'hmmerdb' in self._database_path:
|
||||
raise ValueError('HMMERDB is not supported in sharded mode.')
|
||||
|
||||
if max_sequences <= 1:
|
||||
raise ValueError(
|
||||
'max_sequences must be greater than 1 when running in sharded '
|
||||
'mode, because each shard would return only the query sequence.'
|
||||
)
|
||||
|
||||
self._shard_paths = shard_paths
|
||||
self._max_threads = len(self._shard_paths)
|
||||
if max_threads is not None:
|
||||
self._max_threads = min(max_threads, self._max_threads)
|
||||
logging.info('Nhmmer running with max_threads = %d', self._max_threads)
|
||||
else:
|
||||
self._shard_paths = None
|
||||
self._max_threads = None
|
||||
|
||||
self._binary_path = binary_path
|
||||
self._hmmalign_binary_path = hmmalign_binary_path
|
||||
self._hmmbuild_binary_path = hmmbuild_binary_path
|
||||
self._db_path = database_path
|
||||
|
||||
subprocess_utils.check_binary_exists(path=self._binary_path, name='Nhmmer')
|
||||
|
||||
if strand and strand not in {'watson', 'crick'}:
|
||||
@@ -78,21 +130,75 @@ class Nhmmer(msa_tool.MsaTool):
|
||||
|
||||
self._e_value = e_value
|
||||
self._n_cpu = n_cpu
|
||||
self._z_value = z_value
|
||||
self._max_sequences = max_sequences
|
||||
self._filter_f3 = filter_f3
|
||||
self._alphabet = alphabet
|
||||
self._strand = strand
|
||||
|
||||
def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
  """Query the database (sharded or unsharded) using Nhmmer.

  For a sharded database every shard is queried in its own thread (at most
  `self._max_threads` at a time) and the per-shard results are merged into one
  MSA truncated to `self._max_sequences`. Otherwise the whole database is
  queried with a single Nhmmer run.

  Args:
    target_sequence: The sequence to search the database with.

  Returns:
    The result of the search: merged across shards in the sharded case, or
    the single-run result otherwise.
  """
  # Abbreviate long sequences so that the log line stays readable.
  if len(target_sequence) <= 16:
    shown_sequence = target_sequence
  else:
    shown_sequence = f'{target_sequence[:16]}... (len {len(target_sequence)})'

  if not self._shard_paths:
    # Non-sharded case, run the query against the whole database.
    logging.info('Query sequence (non-sharded db): %s', shown_sequence)
    return self._query_db_shard(
        target_sequence=target_sequence,
        db_shard_path=self._database_path,
        get_tblout=False,
    )

  # Sharded case, run the query against each database shard in parallel.
  logging.info('Query sequence (sharded db): %s', shown_sequence)

  # Parent directory for the per-shard temporary directories.
  global_temp_dir = tempfile.mkdtemp()
  try:

    def _query_shard_fn(
        shard_path: str,
    ) -> tuple[msa_tool.MsaToolResult, float]:
      t_start = time.time()
      result = self._query_db_shard(
          target_sequence=target_sequence,
          db_shard_path=shard_path,
          get_tblout=True,  # Tblout contains e-values needed for merging.
          global_temp_dir=global_temp_dir,
      )
      return result, time.time() - t_start

    with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
      tool_outputs, timings = zip(*ex.map(_query_shard_fn, self._shard_paths))
  finally:
    # Clean up even if one of the shard queries raised, otherwise the
    # directory would be leaked.
    shutil.rmtree(global_temp_dir, ignore_errors=True)

  logging.info(
      'Finished query for %d shards, shard timings (seconds): %s',
      len(tool_outputs),
      ', '.join(f'{t:.1f}' for t in timings),
  )
  return _merge_nhmmer_results(tool_outputs, self._max_sequences)
|
||||
|
||||
def _query_db_shard(
|
||||
self,
|
||||
*,
|
||||
target_sequence: str,
|
||||
db_shard_path: str,
|
||||
get_tblout: bool,
|
||||
global_temp_dir: str | None = None,
|
||||
) -> msa_tool.MsaToolResult:
|
||||
"""Query the database shard using Nhmmer."""
|
||||
|
||||
with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir:
|
||||
input_a3m_path = os.path.join(query_tmp_dir, 'query.a3m')
|
||||
output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
|
||||
pathlib.Path(output_sto_path).touch()
|
||||
@@ -106,8 +212,18 @@ class Nhmmer(msa_tool.MsaTool):
|
||||
*('--cpu', str(self._n_cpu)),
|
||||
]
|
||||
|
||||
if get_tblout:
|
||||
output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt')
|
||||
output_tblout_path.touch()
|
||||
cmd_flags.extend(['--tblout', str(output_tblout_path)])
|
||||
else:
|
||||
output_tblout_path = None
|
||||
|
||||
cmd_flags.extend(['-E', str(self._e_value)])
|
||||
|
||||
if self._z_value is not None:
|
||||
cmd_flags.extend(['-Z', str(self._z_value)])
|
||||
|
||||
if self._alphabet:
|
||||
cmd_flags.extend([f'--{self._alphabet}'])
|
||||
|
||||
@@ -125,13 +241,12 @@ class Nhmmer(msa_tool.MsaTool):
|
||||
cmd_flags.extend(['--F3', str(self._filter_f3)])
|
||||
|
||||
# The input A3M and the db are the last two arguments.
|
||||
cmd_flags.extend((input_a3m_path, self._db_path))
|
||||
cmd_flags.extend((input_a3m_path, db_shard_path))
|
||||
|
||||
cmd = [self._binary_path, *cmd_flags]
|
||||
|
||||
subprocess_utils.run(
|
||||
cmd=cmd,
|
||||
cmd_name=f'Nhmmer ({os.path.basename(self._db_path)})',
|
||||
cmd_name=f'Nhmmer ({os.path.basename(db_shard_path)})',
|
||||
log_stdout=False,
|
||||
log_stderr=True,
|
||||
log_on_process_error=True,
|
||||
@@ -166,6 +281,80 @@ class Nhmmer(msa_tool.MsaTool):
|
||||
# In this case return only the query sequence.
|
||||
a3m = f'>query\n{target_sequence}'
|
||||
|
||||
# Get the tabular output which has e.g. e-value for each target.
|
||||
tbl = '' if output_tblout_path is None else output_tblout_path.read_text()
|
||||
|
||||
return msa_tool.MsaToolResult(
|
||||
target_sequence=target_sequence, e_value=self._e_value, a3m=a3m
|
||||
target_sequence=target_sequence,
|
||||
e_value=self._e_value,
|
||||
a3m=a3m,
|
||||
tblout=tbl,
|
||||
)
|
||||
|
||||
|
||||
def _merge_nhmmer_results(
    nhmmer_results: Sequence[msa_tool.MsaToolResult],
    max_sequences: int,
) -> msa_tool.MsaToolResult:
  """Merges Nhmmer results from several database shards into a single one.

  Hits from all shards are interleaved in ascending order of the e-value read
  from the tblout lines (ties are broken by hit ID). The merged A3M always
  starts with the query and holds at most `max_sequences` sequences in total.

  Args:
    nhmmer_results: Per-shard results. All of them must share the same target
      sequence and e-value, and each must carry a tblout.
    max_sequences: Maximum number of sequences (query included) kept in the
      merged A3M.

  Returns:
    The merged result. Its `tblout` is None since it is only needed for merging.

  Raises:
    ValueError: If `nhmmer_results` is empty, if the results disagree on the
      target sequence or the e-value, or if a result has no tblout.
  """
  # Explicit checks instead of asserts, which are stripped when running with -O.
  if not nhmmer_results:
    raise ValueError('At least one Nhmmer result is needed for merging.')
  if len({res.target_sequence for res in nhmmer_results}) != 1:
    raise ValueError('All Nhmmer results must share the target sequence.')
  if len({res.e_value for res in nhmmer_results}) != 1:
    raise ValueError('All Nhmmer results must share the e-value.')

  # Parse the TBL output, create a mapping from unique hit ID to TBL line.
  parsed_tbl = {}
  for nhmmer_result in nhmmer_results:
    if nhmmer_result.tblout is None:
      raise ValueError('Every Nhmmer result must have tblout for merging.')
    for line in nhmmer_result.tblout.splitlines():
      # Skip comment lines and blank lines, neither describes a hit. A blank
      # line would otherwise fail on the field lookup below.
      if not line.strip() or line.startswith('#'):
        continue
      line_fields = line.split(maxsplit=15)
      target_name = line_fields[0]
      alignment_from = line_fields[6]
      alignment_to = line_fields[7]
      # This is the unique ID that is used in the output A3M.
      unique_id = f'{target_name}/{alignment_from}-{alignment_to}'
      parsed_tbl[unique_id] = line

  # Create an iterator and merge a3m info with tbl info.
  def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
    # Don't parse the entire a3m, lazily parse only as many sequences as needed.
    iterator = iter(parsers.lazy_parse_fasta_string(a3m))
    # Skip the query which isn't present in tblout. The default keeps an empty
    # a3m from raising StopIteration inside this generator (a RuntimeError).
    next(iterator, None)
    for sequence, description in iterator:
      name = description.partition(' ')[0]
      if tbl_info := parsed_tbl.get(name):
        # Skip sequences for which we don't have tbl information.
        yield sequence, description, tbl_info, name

  def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]:
    unused_seq, unused_description, tbl_info, name = seq_data
    # Nucleic tblout has 16 space delimited columns. "-" used if no value
    # present. We want e-value in column 12, so do only 13 splits. Use the name
    # in case of an e-value tie.
    return float(tbl_info.split(maxsplit=13)[12]), name

  # heapq.merge only interleaves its inputs, it does not sort them: this
  # relies on each shard's A3M already being ordered by e-value and name.
  merged_hits = heapq.merge(
      *[_merged_a3m_tbl_iter(res.a3m) for res in nhmmer_results],
      key=sort_key,
  )

  # Truncate the a3m to max_sequences, but keep counting all merged hits so
  # that the amount of truncation can be reported.
  target_sequence = nhmmer_results[0].target_sequence
  merged_a3m = [f'>query\n{target_sequence}']
  num_hits = 0
  for seq, description, _, _ in merged_hits:
    num_hits += 1
    if len(merged_a3m) < max_sequences:
      merged_a3m.append(f'>{description}\n{seq}')

  # Only report truncation when some hits were actually dropped. The depth
  # before truncation is the number of hits plus the query.
  if num_hits + 1 > len(merged_a3m):
    logging.info(
        'Limiting merged MSA depth from %d to %d',
        num_hits + 1,
        max_sequences,
    )

  return msa_tool.MsaToolResult(
      target_sequence=target_sequence,
      a3m='\n'.join(merged_a3m),
      e_value=nhmmer_results[0].e_value,
      tblout=None,  # We no longer need the tblout.
  )
|
||||
|
||||
94
src/alphafold3/data/tools/shards.py
Normal file
94
src/alphafold3/data/tools/shards.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright 2025 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""A library to handle shards of the format file_path@NUM_SHARDS.
|
||||
|
||||
For instance, /path/to/file@20 will generate the following shards:
|
||||
|
||||
- /path/to/file-00000-of-00020
|
||||
- /path/to/file-00001-of-00020
|
||||
- ...
|
||||
- /path/to/file-00019-of-00020
|
||||
|
||||
This also supports @* pattern, which will determine the number of shards based
|
||||
on the filesystem content.
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import pathlib
|
||||
import re
|
||||
|
||||
|
||||
_MAX_NUM_SHARDS = 99_999
|
||||
_SHARD_RE = re.compile(
|
||||
r"""
|
||||
^(?P<prefix>[^\?\],\*]+)@
|
||||
(?P<shards>(\d{1,5})|\*)
|
||||
(?P<suffix>[\._][^\?\]@\*\/]*)?
|
||||
$""",
|
||||
re.X,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ShardSpec:
|
||||
prefix: str
|
||||
num_shards: int
|
||||
suffix: str
|
||||
|
||||
|
||||
def parse_shard_spec(path: str) -> ShardSpec | None:
|
||||
"""Returns the shard spec or None if the path is not a shard spec.
|
||||
|
||||
For instance, if the shard spec is '/path/to/file@20', the output will be
|
||||
('/path/to/file', 20).
|
||||
|
||||
Args:
|
||||
path: the path to parse, e.g. /path/to/file@20 or /path/to/file@*.
|
||||
"""
|
||||
parsed = re.fullmatch(_SHARD_RE, path)
|
||||
if not parsed:
|
||||
return None
|
||||
prefix = parsed.group('prefix')
|
||||
shards = parsed.group('shards')
|
||||
suffix = parsed.group('suffix') or ''
|
||||
|
||||
if shards != '*':
|
||||
return ShardSpec(prefix=prefix, num_shards=int(shards), suffix=suffix)
|
||||
shard_slice = slice(len(prefix) + 10, len(prefix) + 15)
|
||||
shard_path = pathlib.Path(f'{prefix}-00000-of-?????{suffix}')
|
||||
for shard in sorted(shard_path.parent.glob(shard_path.name), reverse=True):
|
||||
try:
|
||||
num_shards = int(str(shard)[shard_slice])
|
||||
return ShardSpec(prefix=prefix, num_shards=num_shards, suffix=suffix)
|
||||
except ValueError:
|
||||
continue
|
||||
return None
|
||||
|
||||
|
||||
def get_sharded_paths(shard_spec: str) -> Sequence[str] | None:
|
||||
"""Returns a list of file path or None if the input is not a shard spec.
|
||||
|
||||
Args:
|
||||
shard_spec: the specifications of the shard, e.g. /path/to/file@20.
|
||||
"""
|
||||
parsed_spec = parse_shard_spec(shard_spec)
|
||||
if not parsed_spec:
|
||||
return None
|
||||
|
||||
prefix = parsed_spec.prefix
|
||||
num_shards = parsed_spec.num_shards
|
||||
suffix = parsed_spec.suffix
|
||||
if num_shards > _MAX_NUM_SHARDS:
|
||||
raise ValueError(f'Shard count for {shard_spec} exceeds {_MAX_NUM_SHARDS}')
|
||||
return [
|
||||
f'{prefix}-{i:05d}-of-{num_shards:05d}{suffix}' for i in range(num_shards)
|
||||
]
|
||||
Reference in New Issue
Block a user