Add option to run Jackhmmer/Nhmmer genetic search in sharded mode (10-30x faster)

PiperOrigin-RevId: 825546064
Change-Id: Ib421e47bb9ca7eea512c49a532e7e995a0f5721f
This commit is contained in:
Augustin Zidek
2025-10-29 07:54:40 -07:00
committed by Copybara-Service
parent 46fe3f0f60
commit 805adc3863
8 changed files with 737 additions and 47 deletions

View File

@@ -82,6 +82,85 @@ is the number of cores used for each Jackhmmer process times 4. Also note that
for sequences with deep MSAs, Jackhmmer or Nhmmer may need a substantial amount
of RAM beyond the recommended 64 GB of RAM.
### Sharded genetic databases
The run time of the genetic database search can be *significantly* sped up by
splitting the genetic databases if a machine with many CPU cores is used and the
databases are on very fast SSD or in a RAM-backed filesystem. With this
technique you can make Jackhmmer/Nhmmer genetic search fully utilize your
hardware and take advantage of multi-core systems.
Each genetic database with *n* sequences is split into *s* shards, each
containing roughly *n* / *s* sequences. We recommend splitting the sequences
between shards randomly to make sure each shard has similar sequence length
distribution. This could be achieved using standard tools:
1. Shuffle the sequences in the fasta. This can be done for example by running:
`seqkit shuffle --two-pass <db.fasta>`
2. Split the shuffled fasta in *s* shards. This can be done for example by
running: `seqkit split2 --by-part <s> <db.fasta>`
Make sure the shards names follow this pattern:
`prefix-<shard_index>-of-<total_shards>`, both `shard_index` and `total_shards`
having always 5 digits, with leading zeros as needed. The `shard_index` goes
from 0 to `total_shards - 1`. A file "path" (spec) for a sharded file is
`prefix@<total_shards>`.
E.g. for a file named `uniprot.fasta` split into 3 shards, the names of the
shards should be:
* `uniprot.fasta-00000-of-00003`
* `uniprot.fasta-00001-of-00003`
* `uniprot.fasta-00002-of-00003`
The file spec for these files is `uniprot.fasta@3`.
Save the total number of sequences in the protein databases, and the total
number of nucleic bases in the RNA databases these will be needed later as a
flag to Jackhmmer/Nhmmer to correctly scale e-values across all shards.
Save the sharded databases on a fast SSD or in a RAM-backed filesystem, then
launch AlphaFold with the sharded paths instead of normal paths and set the
Z-values.
For instance with each database sharded into 16 shards:
```bash
python run_alphafold.py \
--small_bfd_database_path="bfd-first_non_consensus_sequences.fasta@64" \
--small_bfd_z_value=65984053 \
--mgnify_database_path="mgy_clusters_2022_05.fa@512" \
--mgnify_z_value=623796864 \
--uniprot_cluster_annot_database_path="uniprot_cluster_annot_2021_04.fasta@256" \
--uniprot_cluster_annot_z_value=225619586 \
--uniref90_database_path="uniref90_2022_05.fasta@128" \
--uniref90_z_value=153742194 \
--ntrna_database_path="nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta@256" \
--ntrna_z_value=76752.808514 \
--rfam_database_path="rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta@16" \
--rfam_z_value=138.115553 \
--rna_central_database_path="rnacentral_active_seq_id_90_cov_80_linclust.fasta@64" \
--rna_central_z_value=13271.415730
--jackhmmer_n_cpu=2 \
--jackhmmer_max_parallel_shards=16 \
--nhmmer_n_cpu=2 \
--nhmmer_max_parallel_shards=16
```
This run will utilize (2 CPUs) × (16 max parallel shards) × (4 protein dbs
searched in parallel) = 128 cores for each protein chain, and (2 CPUs) × (16 max
parallel shards) × (3 RNA dbs searched in parallel) = 96 cores for each RNA
chain. Make sure to tune:
* the Jackhmmer/Nhmmer number of CPUs,
* the maximum number of shards searched in parallel,
* and the number of shards for each database
so that the memory bandwidth and CPUs on your machine are optimally utilized.
You should aim for consistent shard sizes across all databases (so e.g. if
database A is split into 16 shards and is 3× smaller than database B, database B
should be split into 3 × 16 = 48 shards).
## Model Inference
Table 8 in the Supplementary Information of the

View File

@@ -132,37 +132,86 @@ _SMALL_BFD_DATABASE_PATH = flags.DEFINE_string(
'${DB_DIR}/bfd-first_non_consensus_sequences.fasta',
'Small BFD database path, used for protein MSA search.',
)
_SMALL_BFD_Z_VALUE = flags.DEFINE_integer(
'small_bfd_z_value',
None,
'The Z-value representing the database size in number of sequences for'
' E-value calculation. Must be set for sharded databases.',
lower_bound=0,
)
_MGNIFY_DATABASE_PATH = flags.DEFINE_string(
'mgnify_database_path',
'${DB_DIR}/mgy_clusters_2022_05.fa',
'Mgnify database path, used for protein MSA search.',
)
_MGNIFY_Z_VALUE = flags.DEFINE_integer(
'mgnify_z_value',
None,
'The Z-value representing the database size in number of sequences for'
' E-value calculation. Must be set for sharded databases.',
lower_bound=0,
)
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string(
'uniprot_cluster_annot_database_path',
'${DB_DIR}/uniprot_all_2021_04.fa',
'UniProt database path, used for protein paired MSA search.',
)
_UNIPROT_CLUSTER_ANNOT_Z_VALUE = flags.DEFINE_integer(
'uniprot_cluster_annot_z_value',
None,
'The Z-value representing the database size in number of sequences for'
' E-value calculation. Must be set for sharded databases.',
lower_bound=0,
)
_UNIREF90_DATABASE_PATH = flags.DEFINE_string(
'uniref90_database_path',
'${DB_DIR}/uniref90_2022_05.fa',
'UniRef90 database path, used for MSA search. The MSA obtained by '
'searching it is used to construct the profile for template search.',
)
_UNIREF90_Z_VALUE = flags.DEFINE_integer(
'uniref90_z_value',
None,
'The Z-value representing the database size in number of sequences for'
' E-value calculation. Must be set for sharded databases.',
lower_bound=0,
)
_NTRNA_DATABASE_PATH = flags.DEFINE_string(
'ntrna_database_path',
'${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta',
'NT-RNA database path, used for RNA MSA search.',
)
_NTRNA_Z_VALUE = flags.DEFINE_float(
'ntrna_z_value',
None,
'The Z-value representing the database size in megabases for E-value'
' calculation. Must be set for sharded databases.',
lower_bound=0.0,
)
_RFAM_DATABASE_PATH = flags.DEFINE_string(
'rfam_database_path',
'${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta',
'Rfam database path, used for RNA MSA search.',
)
_RFAM_Z_VALUE = flags.DEFINE_float(
'rfam_z_value',
None,
'The Z-value representing the database size in megabases for E-value'
' calculation. Must be set for sharded databases.',
lower_bound=0.0,
)
_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string(
'rna_central_database_path',
'${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta',
'RNAcentral database path, used for RNA MSA search.',
)
_RNA_CENTRAL_Z_VALUE = flags.DEFINE_float(
'rna_central_z_value',
None,
'The Z-value representing the database size in megabases for E-value'
' calculation. Must be set for sharded databases.',
lower_bound=0.0,
)
_PDB_DATABASE_PATH = flags.DEFINE_string(
'pdb_database_path',
'${DB_DIR}/mmcif_files',
@@ -183,6 +232,14 @@ _JACKHMMER_N_CPU = flags.DEFINE_integer(
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
)
_JACKHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
'jackhmmer_max_parallel_shards',
None,
'Maximum number of shards to search against in parallel. If unset, one'
' Jackhmmer instance will be run per shard. Only applicable if the'
' database is sharded.',
lower_bound=1,
)
_NHMMER_N_CPU = flags.DEFINE_integer(
'nhmmer_n_cpu',
# Unfortunately, os.process_cpu_count() is only available in Python 3.13+.
@@ -191,6 +248,14 @@ _NHMMER_N_CPU = flags.DEFINE_integer(
' above 8 CPUs provides very little additional speedup.',
lower_bound=0,
)
_NHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer(
'nhmmer_max_parallel_shards',
None,
'Maximum number of shards to search against in parallel. If unset, one'
' Nhmmer instance will be run per shard. Only applicable if the'
' database is sharded.',
lower_bound=1,
)
# Data pipeline configuration.
_RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool(
@@ -828,18 +893,27 @@ def main(_):
hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value,
hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value,
small_bfd_database_path=expand_path(_SMALL_BFD_DATABASE_PATH.value),
small_bfd_z_value=_SMALL_BFD_Z_VALUE.value,
mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value),
mgnify_z_value=_MGNIFY_Z_VALUE.value,
uniprot_cluster_annot_database_path=expand_path(
_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value
),
uniprot_cluster_annot_z_value=_UNIPROT_CLUSTER_ANNOT_Z_VALUE.value,
uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value),
uniref90_z_value=_UNIREF90_Z_VALUE.value,
ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value),
ntrna_z_value=_NTRNA_Z_VALUE.value,
rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value),
rfam_z_value=_RFAM_Z_VALUE.value,
rna_central_database_path=expand_path(_RNA_CENTRAL_DATABASE_PATH.value),
rna_central_z_value=_RNA_CENTRAL_Z_VALUE.value,
pdb_database_path=expand_path(_PDB_DATABASE_PATH.value),
seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value),
jackhmmer_n_cpu=_JACKHMMER_N_CPU.value,
jackhmmer_max_parallel_shards=_JACKHMMER_MAX_PARALLEL_SHARDS.value,
nhmmer_n_cpu=_NHMMER_N_CPU.value,
nhmmer_max_parallel_shards=_NHMMER_MAX_PARALLEL_SHARDS.value,
max_template_date=max_template_date,
)
else:

View File

@@ -42,9 +42,16 @@ class JackhmmerConfig:
n_cpu: An integer with the number of CPUs to use.
n_iter: An integer with the number of database search iterations.
e_value: e-value for the database lookup.
z_value: The Z-value representing the number of comparisons done (i.e
correct database size) for E-value calculation.
z_value: The Z-value representing the database size in number of sequences
for E-value and domain E-value calculation. Must be set for sharded
databases.
dom_z_value: The Z-value representing the database size in number of
sequences for domain E-value calculation. Must be set for sharded
databases.
max_sequences: Max sequences to return in MSA.
max_parallel_shards: If given, the maximum number of shards to search
against in parallel. If None, one Jackhmmer instance will be run per
shard. Only applicable if the database is sharded.
"""
binary_path: str
@@ -52,8 +59,10 @@ class JackhmmerConfig:
n_cpu: int
n_iter: int
e_value: float
z_value: float | int | None
z_value: int | None
dom_z_value: int | None
max_sequences: int
max_parallel_shards: int | None = None
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
@@ -67,8 +76,14 @@ class NhmmerConfig:
database_config: Database configuration.
n_cpu: An integer with the number of CPUs to use.
e_value: e-value for the database lookup.
z_value: The Z-value representing the database size in megabases for
E-value calculation. Allows fractional values. Must be set for sharded
databases.
max_sequences: Max sequences to return in MSA.
alphabet: The alphabet when building a profile with hmmbuild.
max_parallel_shards: If given, the maximum number of shards to search
against in parallel. If None, one Nhmmer instance will be run per shard.
Only applicable if the database is sharded.
"""
binary_path: str
@@ -77,8 +92,10 @@ class NhmmerConfig:
database_config: DatabaseConfig
n_cpu: int
e_value: float
z_value: float | None
max_sequences: int
alphabet: str | None
max_parallel_shards: int | None = None
@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)

View File

@@ -214,21 +214,42 @@ class DataPipelineConfig:
raw MSA in template search.
small_bfd_database_path: Small BFD database path, used for protein MSA
search.
small_bfd_z_value: The Z-value representing the database size in number of
sequences for E-value calculation. Must be set for sharded databases.
mgnify_database_path: Mgnify database path, used for protein MSA search.
mgnify_z_value: The Z-value representing the database size in number of
sequences for E-value calculation. Must be set for sharded databases.
uniprot_cluster_annot_database_path: Uniprot database path, used for protein
paired MSA search.
uniprot_cluster_annot_z_value: The Z-value representing the database size in
number of sequences for E-value calculation. Must be set for sharded
databases.
uniref90_database_path: UniRef90 database path, used for MSA search, and the
MSA obtained by searching it is used to construct the profile for template
search.
uniref90_z_value: The Z-value representing the database size in number of
sequences for E-value calculation. Must be set for sharded databases.
ntrna_database_path: NT-RNA database path, used for RNA MSA search.
ntrna_z_value: The Z-value representing the database size in megabases for
E-value calculation. Must be set for sharded databases.
rfam_database_path: Rfam database path, used for RNA MSA search.
rfam_z_value: The Z-value representing the database size in megabases for
E-value calculation. Must be set for sharded databases.
rna_central_database_path: RNAcentral database path, used for RNA MSA
search.
rna_central_z_value: The Z-value representing the database size in megabases
for E-value calculation. Must be set for sharded databases.
seqres_database_path: PDB sequence database path, used for template search.
pdb_database_path: PDB database directory with mmCIF files path, used for
template search.
jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer.
jackhmmer_max_parallel_shards: Maximum number of shards to search against in
parallel. If None, one Jackhmmer instance will be run per shard. Only
applicable if the database is sharded.
nhmmer_n_cpu: Number of CPUs to use for Nhmmer.
nhmmer_max_parallel_shards: Maximum number of shards to search against in
parallel. If None, one Nhmmer instance will be run per shard. Only
applicable if the database is sharded.
max_template_date: The latest date of templates to use.
"""
@@ -241,20 +262,29 @@ class DataPipelineConfig:
# Jackhmmer databases.
small_bfd_database_path: str
small_bfd_z_value: int | None = None
mgnify_database_path: str
mgnify_z_value: int | None = None
uniprot_cluster_annot_database_path: str
uniprot_cluster_annot_z_value: int | None = None
uniref90_database_path: str
uniref90_z_value: int | None = None
# Nhmmer databases.
ntrna_database_path: str
ntrna_z_value: int | None = None
rfam_database_path: str
rfam_z_value: int | None = None
rna_central_database_path: str
rna_central_z_value: int | None = None
# Template search databases.
seqres_database_path: str
pdb_database_path: str
# Optional configuration for MSA tools.
jackhmmer_n_cpu: int = 8
jackhmmer_max_parallel_shards: int | None = None
nhmmer_n_cpu: int = 8
nhmmer_max_parallel_shards: int | None = None
max_template_date: datetime.date
@@ -274,8 +304,10 @@ class DataPipeline:
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
n_iter=1,
e_value=1e-4,
z_value=None,
z_value=data_pipeline_config.uniref90_z_value,
dom_z_value=data_pipeline_config.uniref90_z_value,
max_sequences=10_000,
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
crop_size=None,
@@ -290,8 +322,10 @@ class DataPipeline:
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
n_iter=1,
e_value=1e-4,
z_value=None,
z_value=data_pipeline_config.mgnify_z_value,
dom_z_value=data_pipeline_config.mgnify_z_value,
max_sequences=5_000,
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
crop_size=None,
@@ -308,8 +342,10 @@ class DataPipeline:
e_value=1e-4,
# Set z_value=138_515_945 to match the z_value used in the paper.
# In practice, this has minimal impact on predicted structures.
z_value=None,
z_value=data_pipeline_config.small_bfd_z_value,
dom_z_value=data_pipeline_config.small_bfd_z_value,
max_sequences=5_000,
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
crop_size=None,
@@ -324,8 +360,10 @@ class DataPipeline:
n_cpu=data_pipeline_config.jackhmmer_n_cpu,
n_iter=1,
e_value=1e-4,
z_value=None,
z_value=data_pipeline_config.uniprot_cluster_annot_z_value,
dom_z_value=data_pipeline_config.uniprot_cluster_annot_z_value,
max_sequences=50_000,
max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.PROTEIN_CHAIN,
crop_size=None,
@@ -342,7 +380,9 @@ class DataPipeline:
n_cpu=data_pipeline_config.nhmmer_n_cpu,
e_value=1e-3,
alphabet='rna',
z_value=data_pipeline_config.ntrna_z_value,
max_sequences=10_000,
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.RNA_CHAIN,
crop_size=None,
@@ -359,7 +399,9 @@ class DataPipeline:
n_cpu=data_pipeline_config.nhmmer_n_cpu,
e_value=1e-3,
alphabet='rna',
z_value=data_pipeline_config.rfam_z_value,
max_sequences=10_000,
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.RNA_CHAIN,
crop_size=None,
@@ -376,7 +418,9 @@ class DataPipeline:
n_cpu=data_pipeline_config.nhmmer_n_cpu,
e_value=1e-3,
alphabet='rna',
z_value=data_pipeline_config.rna_central_z_value,
max_sequences=10_000,
max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards,
),
chain_poly_type=mmcif_names.RNA_CHAIN,
crop_size=None,

View File

@@ -10,12 +10,19 @@
"""Library to run Jackhmmer from Python."""
from collections.abc import Iterable, Sequence
from concurrent import futures
import heapq
import os
import pathlib
import shutil
import tempfile
import time
from absl import logging
from alphafold3.data import parsers
from alphafold3.data.tools import msa_tool
from alphafold3.data.tools import shards
from alphafold3.data.tools import subprocess_utils
@@ -31,43 +38,91 @@ class Jackhmmer(msa_tool.MsaTool):
n_iter: int = 3,
e_value: float | None = 1e-3,
z_value: float | int | None = None,
dom_e: float | None = None,
dom_z_value: float | int | None = None,
max_sequences: int = 5000,
filter_f1: float = 5e-4,
filter_f2: float = 5e-5,
filter_f3: float = 5e-7,
max_threads: int | None = None,
**unused_kwargs,
):
"""Initializes the Python Jackhmmer wrapper.
NOTE: The MSA obtained by running against sharded dbs won't be always
exactly the same as the MSA obtained by running against an unsharded db.
This is because of Jackhmmer deduplication logic, which won't spot duplicate
hits across multiple shards. Usually this means that the sharded search
finds more hits (likely bounded by the number of shards), but this should
not pose an issue given how the results are used downstream. The problem is
more pronounced with deep MSAs and lower in the hit list (higher e-values).
Make sure to set the Z and domZ values when searching against a sharded
database, otherwise the results won't match the normal unsharded search.
Args:
binary_path: The path to the jackhmmer executable.
database_path: The path to the jackhmmer database (FASTA format).
database_path: The path to the jackhmmer database (FASTA format). Sharded
file specs, e.g. `<db_path>@<num_shards>`, are supported.
n_cpu: The number of CPUs to give Jackhmmer.
n_iter: The number of Jackhmmer iterations.
e_value: The E-value, see Jackhmmer docs for more details.
z_value: The Z-value representing the number of comparisons done (i.e
correct database size) for E-value calculation.
correct database size) for E-value calculation. Make sure to set this
when searching against a sharded database, otherwise the e-values will
be incorrectly scaled.
dom_e: Domain e-value criteria for inclusion in tblout.
dom_z_value: Domain z-value representing the number of comparisons done
(i.e correct database size) for domain E-value calculation. Make sure to
set this when searching against a sharded database, otherwise the domain
e-values will be incorrectly scaled.
max_sequences: Maximum number of sequences to return in the MSA.
filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
max_threads: If given, the maximum number of threads used when running
sharded databases.
Raises:
RuntimeError: If Jackhmmer binary not found within the path.
ValueError: If an invalid configuration is provided in the args.
"""
self._binary_path = binary_path
self._database_path = database_path
if shard_paths := shards.get_sharded_paths(self._database_path):
if n_iter != 1:
raise ValueError('For a sharded db, only n_iter=1 is supported.')
if z_value is None:
raise ValueError(
'The Z-value must be set when searching against a sharded database '
'to correctly scale e-values.'
)
if max_sequences <= 1:
raise ValueError(
'max_sequences must be greater than 1 when running in sharded '
'mode, because each shard would return only the query sequence.'
)
self._shard_paths = shard_paths
self._max_threads = len(self._shard_paths)
if max_threads is not None:
self._max_threads = min(max_threads, self._max_threads)
logging.info('Jackhmmer running with max_threads = %d', self._max_threads)
else:
self._shard_paths = None
self._max_threads = None
self._binary_path = binary_path
subprocess_utils.check_binary_exists(
path=self._binary_path, name='Jackhmmer'
)
if not os.path.exists(self._database_path):
raise ValueError(f'Could not find Jackhmmer database {database_path}')
self._n_cpu = n_cpu
self._n_iter = n_iter
self._e_value = e_value
self._z_value = z_value
self._dom_e = dom_e
self._dom_z_value = dom_z_value
self._max_sequences = max_sequences
self._filter_f1 = filter_f1
self._filter_f2 = filter_f2
@@ -81,21 +136,73 @@ class Jackhmmer(msa_tool.MsaTool):
)
def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
"""Queries the database using Jackhmmer."""
logging.info(
'Query sequence: %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
"""Query the database (sharded or unsharded) using Jackhmmer."""
if self._shard_paths:
# Sharded case, run the query against each database shard in parallel.
logging.info(
'Query sequence (sharded db): %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
with tempfile.TemporaryDirectory() as query_tmp_dir:
global_temp_dir = tempfile.mkdtemp()
def _query_shard_fn(
shard_path: str,
) -> tuple[msa_tool.MsaToolResult, float]:
t_start = time.time()
result = self._query_db_shard(
target_sequence=target_sequence,
db_shard_path=shard_path,
get_tblout=True, # Tblout contains e-values needed for merging.
global_temp_dir=global_temp_dir,
)
return result, time.time() - t_start
with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
tool_outputs, timings = zip(*ex.map(_query_shard_fn, self._shard_paths))
logging.info(
'Finished query for %d shards, shard timings (seconds): %s',
len(tool_outputs),
', '.join(f'{t:.1f}' for t in timings),
)
shutil.rmtree(global_temp_dir, ignore_errors=True)
return _merge_jackhmmer_results(tool_outputs, self._max_sequences)
else:
# Non-sharded case, run the query against the whole database.
logging.info(
'Query sequence (non-sharded db): %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
return self._query_db_shard(
target_sequence=target_sequence,
db_shard_path=self._database_path,
get_tblout=False,
)
def _query_db_shard(
self,
*,
target_sequence: str,
db_shard_path: str,
get_tblout: bool,
global_temp_dir: str | None = None,
) -> msa_tool.MsaToolResult:
"""Query the database shard using Jackhmmer."""
with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir:
input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta')
subprocess_utils.create_query_fasta_file(
sequence=target_sequence, path=input_fasta_path
)
output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
pathlib.Path(output_sto_path).touch()
# The F1/F2/F3 are the expected proportion to pass each of the filtering
# stages (which get progressively more expensive), reducing these
@@ -113,6 +220,13 @@ class Jackhmmer(msa_tool.MsaTool):
*('-N', str(self._n_iter)),
]
if get_tblout:
output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt')
output_tblout_path.touch()
cmd_flags.extend(['--tblout', str(output_tblout_path)])
else:
output_tblout_path = None
# Report only sequences with E-values <= x in per-sequence output.
if self._e_value is not None:
cmd_flags.extend(['-E', str(self._e_value)])
@@ -123,18 +237,21 @@ class Jackhmmer(msa_tool.MsaTool):
if self._z_value is not None:
cmd_flags.extend(['-Z', str(self._z_value)])
if self._dom_z_value is not None:
cmd_flags.extend(['--domZ', str(self._dom_z_value)])
if self._dom_e is not None:
cmd_flags.extend(['--domE', str(self._dom_e)])
if self._max_sequences is not None and self._supports_seq_limit:
cmd_flags.extend(['--seq_limit', str(self._max_sequences)])
cmd = (
[self._binary_path]
+ cmd_flags
+ [input_fasta_path, self._database_path]
)
# The input FASTA and the input db are the last two arguments.
cmd = [self._binary_path] + cmd_flags + [input_fasta_path, db_shard_path]
subprocess_utils.run(
cmd=cmd,
cmd_name=f'Jackhmmer ({os.path.basename(self._database_path)})',
cmd_name=f'Jackhmmer ({os.path.basename(db_shard_path)})',
log_stdout=False,
log_stderr=True,
log_on_process_error=True,
@@ -145,6 +262,73 @@ class Jackhmmer(msa_tool.MsaTool):
f, max_sequences=self._max_sequences
)
return msa_tool.MsaToolResult(
target_sequence=target_sequence, a3m=a3m, e_value=self._e_value
)
# Get the tabular output which has e.g. e-value for each target.
tbl = '' if output_tblout_path is None else output_tblout_path.read_text()
return msa_tool.MsaToolResult(
target_sequence=target_sequence,
a3m=a3m,
e_value=self._e_value,
tblout=tbl,
)
def _merge_jackhmmer_results(
jh_results: Sequence[msa_tool.MsaToolResult], max_sequences: int
) -> msa_tool.MsaToolResult:
"""Merges Jackhmmer result protos into a single one."""
assert len(set(jh_res.target_sequence for jh_res in jh_results)) == 1
assert len(set(jh_res.e_value for jh_res in jh_results)) == 1
# Parse the TBL output, create a mapping from hit name to TBL line.
parsed_tbl = {}
for jh_result in jh_results:
assert jh_result.tblout is not None
for line in jh_result.tblout.splitlines():
if not line.startswith('#'):
parsed_tbl[line.partition(' ')[0]] = line
# Create an iterator and merge a3m info with tbl info.
def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
# Don't parse the entire a3m, lazily parse only as many sequences as needed.
iterator = iter(parsers.lazy_parse_fasta_string(a3m))
next(iterator) # Skip the query which isn't present in tblout.
for sequence, description in iterator:
name = description.partition(' ')[0].partition('/')[0]
if tbl_info := parsed_tbl.get(name):
# Skip sequences for which we don't have tbl information.
yield sequence, description, tbl_info, name
def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]:
unused_seq, unused_description, tbl_info, name = seq_data
# Tblout lines have 19 whitespace delimited columns. "-" used if no value
# present. We want e-value in column with index 4, so do only 5 splits.
# Use the name in case of a e-value tie.
return float(tbl_info.split(maxsplit=5)[4]), name
# A3M/TBL is sorted by e-value and name, hence we can merge them efficiently.
merged_a3m_and_tblout = heapq.merge(
*[_merged_a3m_tbl_iter(res.a3m) for res in jh_results],
key=sort_key,
)
# Truncate the a3m to max_sequences. Do not truncate the tblout.
merged_tblout = []
merged_a3m = [f'>query\n{jh_results[0].target_sequence}']
for seq, description, tbl_info, _ in merged_a3m_and_tblout:
merged_tblout.append(tbl_info)
if len(merged_a3m) < max_sequences:
merged_a3m.append(f'>{description}\n{seq}')
logging.info(
'Limiting merged MSA depth from %d to %d',
len(merged_tblout),
max_sequences,
)
return msa_tool.MsaToolResult(
target_sequence=jh_results[0].target_sequence,
a3m='\n'.join(merged_a3m),
e_value=jh_results[0].e_value,
tblout=None, # We no longer need the tblout.
)

View File

@@ -16,11 +16,20 @@ from typing import Protocol
@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
class MsaToolResult:
"""The result of a MSA tool query."""
"""The result of a MSA tool query.
Attributes:
target_sequence: The sequence that was used to query the MSA tool.
e_value: The e-value that was used to filter the MSA tool results.
a3m: The MSA output of the tool in the A3M format.
tblout: The optional tblout output of the MSA tool (needed for merging
results of queries against a sharded database).
"""
target_sequence: str
e_value: float
a3m: str
tblout: str | None = None
class MsaTool(Protocol):

View File

@@ -10,9 +10,14 @@
"""Library to run Nhmmer from Python."""
from collections.abc import Iterable, Sequence
from concurrent import futures
import heapq
import os
import pathlib
import shutil
import tempfile
import time
from typing import Final
from absl import logging
@@ -20,8 +25,10 @@ from alphafold3.data import parsers
from alphafold3.data.tools import hmmalign
from alphafold3.data.tools import hmmbuild
from alphafold3.data.tools import msa_tool
from alphafold3.data.tools import shards
from alphafold3.data.tools import subprocess_utils
_SHORT_SEQUENCE_CUTOFF: Final[int] = 50
@@ -36,38 +43,83 @@ class Nhmmer(msa_tool.MsaTool):
database_path: str,
n_cpu: int = 8,
e_value: float = 1e-3,
z_value: float | int | None = None,
max_sequences: int = 5000,
filter_f3: float = 1e-5,
alphabet: str | None = None,
strand: str | None = None,
max_threads: int | None = None,
):
"""Initializes the Python Nhmmer wrapper.
NOTE: The MSA obtained by running against sharded dbs won't be always
exactly the same as the MSA obtained by running against an unsharded db.
This is because of Jackhmmer deduplication logic, which won't spot duplicate
hits across multiple shards. Usually this means that the sharded search
finds more hits (likely bounded by the number of shards), but this should
not pose an issue given how the results are used downstream. The problem is
more pronounced with deep MSAs and lower in the hit list (higher e-values).
Make sure to set the Z value when searching against a sharded database,
otherwise the results won't match the normal unsharded search.
Args:
binary_path: Path to the Nhmmer binary.
hmmalign_binary_path: Path to the Hmmalign binary.
hmmbuild_binary_path: Path to the Hmmbuild binary.
database_path: MSA database path to search against. This can be either a
FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb
binary. The HMMERDB is ~10x faster but experimental.
binary. The HMMERDB is ~10x faster but experimental. Sharded file
specs, e.g. <db_path>@<num_shards>, are supported.
n_cpu: The number of CPUs to give Nhmmer.
e_value: The E-value, see Nhmmer docs for more details. Will be
overwritten if bit_score is set.
z_value: The Z-value representing the number of comparisons done (i.e
correct database size) for E-value calculation. Make sure to set this
when searching against a sharded database, otherwise the e-values will
be incorrectly scaled.
max_sequences: Maximum number of sequences to return in the MSA.
filter_f3: Forward pre-filter, set to >1.0 to turn off.
alphabet: The alphabet to assert when building a profile with hmmbuild.
This must be 'rna', 'dna', or None.
strand: "watson" searches query sequence, "crick" searches
reverse-compliment and default is None which means searching for both.
max_threads: If given, the maximum number of threads used when running
sharded databases.
Raises:
RuntimeError: If Nhmmer binary not found within the path.
ValueError: If an invalid configuration is provided in the args.
"""
self._database_path = database_path
if shard_paths := shards.get_sharded_paths(self._database_path):
if z_value is None:
raise ValueError(
'The Z-value must be set when searching against a sharded database '
'to correctly scale e-values.'
)
if 'hmmerdb' in self._database_path:
raise ValueError('HMMERDB is not supported in sharded mode.')
if max_sequences <= 1:
raise ValueError(
'max_sequences must be greater than 1 when running in sharded '
'mode, because each shard would return only the query sequence.'
)
self._shard_paths = shard_paths
self._max_threads = len(self._shard_paths)
if max_threads is not None:
self._max_threads = min(max_threads, self._max_threads)
logging.info('Nhmmer running with max_threads = %d', self._max_threads)
else:
self._shard_paths = None
self._max_threads = None
self._binary_path = binary_path
self._hmmalign_binary_path = hmmalign_binary_path
self._hmmbuild_binary_path = hmmbuild_binary_path
self._db_path = database_path
subprocess_utils.check_binary_exists(path=self._binary_path, name='Nhmmer')
if strand and strand not in {'watson', 'crick'}:
@@ -78,21 +130,75 @@ class Nhmmer(msa_tool.MsaTool):
self._e_value = e_value
self._n_cpu = n_cpu
self._z_value = z_value
self._max_sequences = max_sequences
self._filter_f3 = filter_f3
self._alphabet = alphabet
self._strand = strand
def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
"""Query the database using Nhmmer."""
logging.info(
'Query sequence: %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
"""Query the database (sharded or unsharded) using Nhmmer."""
if self._shard_paths:
# Sharded case, run the query against each database shard in parallel.
logging.info(
'Query sequence (sharded db): %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
with tempfile.TemporaryDirectory() as query_tmp_dir:
global_temp_dir = tempfile.mkdtemp()
def _query_shard_fn(
shard_path: str,
) -> tuple[msa_tool.MsaToolResult, float]:
t_start = time.time()
# Get tblout as it contains e-values we need for merging sequences.
result = self._query_db_shard(
target_sequence=target_sequence,
db_shard_path=shard_path,
get_tblout=True, # Tblout contains e-values needed for merging.
global_temp_dir=global_temp_dir,
)
return result, time.time() - t_start
with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex:
tool_outputs, timings = zip(*ex.map(_query_shard_fn, self._shard_paths))
logging.info(
'Finished query for %d shards, shard timings (seconds): %s',
len(tool_outputs),
', '.join(f'{t:.1f}' for t in timings),
)
shutil.rmtree(global_temp_dir, ignore_errors=True)
return _merge_nhmmer_results(tool_outputs, self._max_sequences)
else:
# Non-sharded case, run the query against the whole database.
logging.info(
'Query sequence (non-sharded db): %s',
target_sequence
if len(target_sequence) <= 16
else f'{target_sequence[:16]}... (len {len(target_sequence)})',
)
return self._query_db_shard(
target_sequence=target_sequence,
db_shard_path=self._database_path,
get_tblout=False,
)
def _query_db_shard(
self,
*,
target_sequence: str,
db_shard_path: str,
get_tblout: bool,
global_temp_dir: str | None = None,
) -> msa_tool.MsaToolResult:
"""Query the database shard using Nhmmer."""
with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir:
input_a3m_path = os.path.join(query_tmp_dir, 'query.a3m')
output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
pathlib.Path(output_sto_path).touch()
@@ -106,8 +212,18 @@ class Nhmmer(msa_tool.MsaTool):
*('--cpu', str(self._n_cpu)),
]
if get_tblout:
output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt')
output_tblout_path.touch()
cmd_flags.extend(['--tblout', str(output_tblout_path)])
else:
output_tblout_path = None
cmd_flags.extend(['-E', str(self._e_value)])
if self._z_value is not None:
cmd_flags.extend(['-Z', str(self._z_value)])
if self._alphabet:
cmd_flags.extend([f'--{self._alphabet}'])
@@ -125,13 +241,12 @@ class Nhmmer(msa_tool.MsaTool):
cmd_flags.extend(['--F3', str(self._filter_f3)])
# The input A3M and the db are the last two arguments.
cmd_flags.extend((input_a3m_path, self._db_path))
cmd_flags.extend((input_a3m_path, db_shard_path))
cmd = [self._binary_path, *cmd_flags]
subprocess_utils.run(
cmd=cmd,
cmd_name=f'Nhmmer ({os.path.basename(self._db_path)})',
cmd_name=f'Nhmmer ({os.path.basename(db_shard_path)})',
log_stdout=False,
log_stderr=True,
log_on_process_error=True,
@@ -166,6 +281,80 @@ class Nhmmer(msa_tool.MsaTool):
# In this case return only the query sequence.
a3m = f'>query\n{target_sequence}'
# Get the tabular output which has e.g. e-value for each target.
tbl = '' if output_tblout_path is None else output_tblout_path.read_text()
return msa_tool.MsaToolResult(
target_sequence=target_sequence, e_value=self._e_value, a3m=a3m
target_sequence=target_sequence,
e_value=self._e_value,
a3m=a3m,
tblout=tbl,
)
def _merge_nhmmer_results(
nhmmer_results: Sequence[msa_tool.MsaToolResult],
max_sequences: int,
) -> msa_tool.MsaToolResult:
"""Merges nhmmer result protos into a single one."""
assert len(set(nh_res.target_sequence for nh_res in nhmmer_results)) == 1
assert len(set(nh_res.e_value for nh_res in nhmmer_results)) == 1
# Parse the TBL output, create a mapping from unique hit ID to TBL line.
parsed_tbl = {}
for nhmmer_result in nhmmer_results:
assert nhmmer_result.tblout is not None
for line in nhmmer_result.tblout.splitlines():
if not line.startswith('#'):
line_fields = line.split(maxsplit=15)
accession = line_fields[0]
alignment_from = line_fields[6]
alignment_to = line_fields[7]
# This is the unique ID that is used in the output A3M.
unique_id = f'{accession}/{alignment_from}-{alignment_to}'
parsed_tbl[unique_id] = line
# Create an iterator and merge a3m info with tbl info.
def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]:
# Don't parse the entire a3m, lazily parse only as many sequences as needed.
iterator = iter(parsers.lazy_parse_fasta_string(a3m))
next(iterator) # Skip the query which isn't present in tblout.
for sequence, description in iterator:
name = description.partition(' ')[0]
if tbl_info := parsed_tbl.get(name):
# Skip sequences for which we don't have tbl information.
yield sequence, description, tbl_info, name
def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]:
unused_seq, unused_description, tbl_info, name = seq_data
# Nucleic tblout has 16 space delimited columns. "-" used if no value
# present. We want e-value in column 12, so do only 13 splits. Use the name
# in case of an e-value tie.
return float(tbl_info.split(maxsplit=13)[12]), name
# A3M/TBL is sorted by e-value and name, hence we can merge them efficiently.
merged_a3m_and_tblout = heapq.merge(
*[_merged_a3m_tbl_iter(res.a3m) for res in nhmmer_results],
key=sort_key,
)
# Truncate the a3m to max_sequences. Do not truncate the tblout.
merged_tblout = []
merged_a3m = [f'>query\n{nhmmer_results[0].target_sequence}']
for seq, description, tbl_info, _ in merged_a3m_and_tblout:
merged_tblout.append(tbl_info)
if len(merged_a3m) < max_sequences:
merged_a3m.append(f'>{description}\n{seq}')
logging.info(
'Limiting merged MSA depth from %d to %d',
len(merged_tblout),
max_sequences,
)
return msa_tool.MsaToolResult(
target_sequence=nhmmer_results[0].target_sequence,
a3m='\n'.join(merged_a3m),
e_value=nhmmer_results[0].e_value,
tblout=None, # We no longer need the tblout.
)

View File

@@ -0,0 +1,94 @@
# Copyright 2025 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""A library to handle shards of the format file_path@NUM_SHARDS.
For instance, /path/to/file@20 will generate the following shards:
- /path/to/file-00000-of-00020
- /path/to/file-00001-of-00020
- ...
- /path/to/file-00019-of-00020
This also supports @* pattern, which will determine the number of shards based
on the filesystem content.
"""
from collections.abc import Sequence
import dataclasses
import pathlib
import re
_MAX_NUM_SHARDS = 99_999
_SHARD_RE = re.compile(
r"""
^(?P<prefix>[^\?\],\*]+)@
(?P<shards>(\d{1,5})|\*)
(?P<suffix>[\._][^\?\]@\*\/]*)?
$""",
re.X,
)
@dataclasses.dataclass(frozen=True)
class ShardSpec:
prefix: str
num_shards: int
suffix: str
def parse_shard_spec(path: str) -> ShardSpec | None:
"""Returns the shard spec or None if the path is not a shard spec.
For instance, if the shard spec is '/path/to/file@20', the output will be
('/path/to/file', 20).
Args:
path: the path to parse, e.g. /path/to/file@20 or /path/to/file@*.
"""
parsed = re.fullmatch(_SHARD_RE, path)
if not parsed:
return None
prefix = parsed.group('prefix')
shards = parsed.group('shards')
suffix = parsed.group('suffix') or ''
if shards != '*':
return ShardSpec(prefix=prefix, num_shards=int(shards), suffix=suffix)
shard_slice = slice(len(prefix) + 10, len(prefix) + 15)
shard_path = pathlib.Path(f'{prefix}-00000-of-?????{suffix}')
for shard in sorted(shard_path.parent.glob(shard_path.name), reverse=True):
try:
num_shards = int(str(shard)[shard_slice])
return ShardSpec(prefix=prefix, num_shards=num_shards, suffix=suffix)
except ValueError:
continue
return None
def get_sharded_paths(shard_spec: str) -> Sequence[str] | None:
"""Returns a list of file path or None if the input is not a shard spec.
Args:
shard_spec: the specifications of the shard, e.g. /path/to/file@20.
"""
parsed_spec = parse_shard_spec(shard_spec)
if not parsed_spec:
return None
prefix = parsed_spec.prefix
num_shards = parsed_spec.num_shards
suffix = parsed_spec.suffix
if num_shards > _MAX_NUM_SHARDS:
raise ValueError(f'Shard count for {shard_spec} exceeds {_MAX_NUM_SHARDS}')
return [
f'{prefix}-{i:05d}-of-{num_shards:05d}{suffix}' for i in range(num_shards)
]