From 805adc3863841d83d631ccd18136ad58ce3ecb34 Mon Sep 17 00:00:00 2001 From: Augustin Zidek Date: Wed, 29 Oct 2025 07:54:40 -0700 Subject: [PATCH] Add option to run Jackhmmer/Nhmmer genetic search in sharded mode (10-30x faster) PiperOrigin-RevId: 825546064 Change-Id: Ib421e47bb9ca7eea512c49a532e7e995a0f5721f --- docs/performance.md | 79 +++++++++ run_alphafold.py | 74 ++++++++ src/alphafold3/data/msa_config.py | 23 ++- src/alphafold3/data/pipeline.py | 52 +++++- src/alphafold3/data/tools/jackhmmer.py | 232 ++++++++++++++++++++++--- src/alphafold3/data/tools/msa_tool.py | 11 +- src/alphafold3/data/tools/nhmmer.py | 219 +++++++++++++++++++++-- src/alphafold3/data/tools/shards.py | 94 ++++++++++ 8 files changed, 737 insertions(+), 47 deletions(-) create mode 100644 src/alphafold3/data/tools/shards.py diff --git a/docs/performance.md b/docs/performance.md index 0e6de64..55caf7d 100644 --- a/docs/performance.md +++ b/docs/performance.md @@ -82,6 +82,85 @@ is the number of cores used for each Jackhmmer process times 4. Also note that for sequences with deep MSAs, Jackhmmer or Nhmmer may need a substantial amount of RAM beyond the recommended 64 GB of RAM. +### Sharded genetic databases + +The run time of the genetic database search can be *significantly* sped up by +splitting the genetic databases if a machine with many CPU cores is used and the +databases are on very fast SSD or in a RAM-backed filesystem. With this +technique you can make Jackhmmer/Nhmmer genetic search fully utilize your +hardware and take advantage of multi-core systems. + +Each genetic database with *n* sequences is split into *s* shards, each +containing roughly *n* / *s* sequences. We recommend splitting the sequences +between shards randomly to make sure each shard has similar sequence length +distribution. This could be achieved using standard tools: + +1. Shuffle the sequences in the fasta. This can be done for example by running: + `seqkit shuffle --two-pass ` +2. Split the shuffled fasta in *s* shards. This can be done for example by + running: `seqkit split2 --by-part ` + +Make sure the shards names follow this pattern: +`prefix--of-`, both `shard_index` and `total_shards` +having always 5 digits, with leading zeros as needed. The `shard_index` goes +from 0 to `total_shards - 1`. A file "path" (spec) for a sharded file is +`prefix@`. + +E.g. for a file named `uniprot.fasta` split into 3 shards, the names of the +shards should be: + +* `uniprot.fasta-00000-of-00003` +* `uniprot.fasta-00001-of-00003` +* `uniprot.fasta-00002-of-00003` + +The file spec for these files is `uniprot.fasta@3`. + +Save the total number of sequences in the protein databases, and the total +number of nucleic bases in the RNA databases – these will be needed later as a +flag to Jackhmmer/Nhmmer to correctly scale e-values across all shards. + +Save the sharded databases on a fast SSD or in a RAM-backed filesystem, then +launch AlphaFold with the sharded paths instead of normal paths and set the +Z-values. + +For instance with each database sharded into 16 shards: + +```bash +python run_alphafold.py \ + --small_bfd_database_path="bfd-first_non_consensus_sequences.fasta@64" \ + --small_bfd_z_value=65984053 \ + --mgnify_database_path="mgy_clusters_2022_05.fa@512" \ + --mgnify_z_value=623796864 \ + --uniprot_cluster_annot_database_path="uniprot_cluster_annot_2021_04.fasta@256" \ + --uniprot_cluster_annot_z_value=225619586 \ + --uniref90_database_path="uniref90_2022_05.fasta@128" \ + --uniref90_z_value=153742194 \ + --ntrna_database_path="nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta@256" \ + --ntrna_z_value=76752.808514 \ + --rfam_database_path="rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta@16" \ + --rfam_z_value=138.115553 \ + --rna_central_database_path="rnacentral_active_seq_id_90_cov_80_linclust.fasta@64" \ + --rna_central_z_value=13271.415730 + --jackhmmer_n_cpu=2 \ + --jackhmmer_max_parallel_shards=16 \ + --nhmmer_n_cpu=2 \ + --nhmmer_max_parallel_shards=16 +``` + +This run will utilize (2 CPUs) × (16 max parallel shards) × (4 protein dbs +searched in parallel) = 128 cores for each protein chain, and (2 CPUs) × (16 max +parallel shards) × (3 RNA dbs searched in parallel) = 96 cores for each RNA +chain. Make sure to tune: + +* the Jackhmmer/Nhmmer number of CPUs, +* the maximum number of shards searched in parallel, +* and the number of shards for each database + +so that the memory bandwidth and CPUs on your machine are optimally utilized. +You should aim for consistent shard sizes across all databases (so e.g. if +database A is split into 16 shards and is 3× smaller than database B, database B +should be split into 3 × 16 = 48 shards). + ## Model Inference Table 8 in the Supplementary Information of the diff --git a/run_alphafold.py b/run_alphafold.py index b7c6590..ca1030b 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -132,37 +132,86 @@ _SMALL_BFD_DATABASE_PATH = flags.DEFINE_string( '${DB_DIR}/bfd-first_non_consensus_sequences.fasta', 'Small BFD database path, used for protein MSA search.', ) +_SMALL_BFD_Z_VALUE = flags.DEFINE_integer( + 'small_bfd_z_value', + None, + 'The Z-value representing the database size in number of sequences for' + ' E-value calculation. Must be set for sharded databases.', + lower_bound=0, +) _MGNIFY_DATABASE_PATH = flags.DEFINE_string( 'mgnify_database_path', '${DB_DIR}/mgy_clusters_2022_05.fa', 'Mgnify database path, used for protein MSA search.', ) +_MGNIFY_Z_VALUE = flags.DEFINE_integer( + 'mgnify_z_value', + None, + 'The Z-value representing the database size in number of sequences for' + ' E-value calculation. Must be set for sharded databases.', + lower_bound=0, +) _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string( 'uniprot_cluster_annot_database_path', '${DB_DIR}/uniprot_all_2021_04.fa', 'UniProt database path, used for protein paired MSA search.', ) +_UNIPROT_CLUSTER_ANNOT_Z_VALUE = flags.DEFINE_integer( + 'uniprot_cluster_annot_z_value', + None, + 'The Z-value representing the database size in number of sequences for' + ' E-value calculation. Must be set for sharded databases.', + lower_bound=0, +) _UNIREF90_DATABASE_PATH = flags.DEFINE_string( 'uniref90_database_path', '${DB_DIR}/uniref90_2022_05.fa', 'UniRef90 database path, used for MSA search. The MSA obtained by ' 'searching it is used to construct the profile for template search.', ) +_UNIREF90_Z_VALUE = flags.DEFINE_integer( + 'uniref90_z_value', + None, + 'The Z-value representing the database size in number of sequences for' + ' E-value calculation. Must be set for sharded databases.', + lower_bound=0, +) _NTRNA_DATABASE_PATH = flags.DEFINE_string( 'ntrna_database_path', '${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta', 'NT-RNA database path, used for RNA MSA search.', ) +_NTRNA_Z_VALUE = flags.DEFINE_float( + 'ntrna_z_value', + None, + 'The Z-value representing the database size in megabases for E-value' + ' calculation. Must be set for sharded databases.', + lower_bound=0.0, +) _RFAM_DATABASE_PATH = flags.DEFINE_string( 'rfam_database_path', '${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta', 'Rfam database path, used for RNA MSA search.', ) +_RFAM_Z_VALUE = flags.DEFINE_float( + 'rfam_z_value', + None, + 'The Z-value representing the database size in megabases for E-value' + ' calculation. Must be set for sharded databases.', + lower_bound=0.0, +) _RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string( 'rna_central_database_path', '${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta', 'RNAcentral database path, used for RNA MSA search.', ) +_RNA_CENTRAL_Z_VALUE = flags.DEFINE_float( + 'rna_central_z_value', + None, + 'The Z-value representing the database size in megabases for E-value' + ' calculation. Must be set for sharded databases.', + lower_bound=0.0, +) _PDB_DATABASE_PATH = flags.DEFINE_string( 'pdb_database_path', '${DB_DIR}/mmcif_files', @@ -183,6 +232,14 @@ _JACKHMMER_N_CPU = flags.DEFINE_integer( ' above 8 CPUs provides very little additional speedup.', lower_bound=0, ) +_JACKHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer( + 'jackhmmer_max_parallel_shards', + None, + 'Maximum number of shards to search against in parallel. If unset, one' + ' Jackhmmer instance will be run per shard. Only applicable if the' + ' database is sharded.', + lower_bound=1, +) _NHMMER_N_CPU = flags.DEFINE_integer( 'nhmmer_n_cpu', # Unfortunately, os.process_cpu_count() is only available in Python 3.13+. @@ -191,6 +248,14 @@ _NHMMER_N_CPU = flags.DEFINE_integer( ' above 8 CPUs provides very little additional speedup.', lower_bound=0, ) +_NHMMER_MAX_PARALLEL_SHARDS = flags.DEFINE_integer( + 'nhmmer_max_parallel_shards', + None, + 'Maximum number of shards to search against in parallel. If unset, one' + ' Nhmmer instance will be run per shard. Only applicable if the' + ' database is sharded.', + lower_bound=1, +) # Data pipeline configuration. _RESOLVE_MSA_OVERLAPS = flags.DEFINE_bool( @@ -828,18 +893,27 @@ def main(_): hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value, hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value, small_bfd_database_path=expand_path(_SMALL_BFD_DATABASE_PATH.value), + small_bfd_z_value=_SMALL_BFD_Z_VALUE.value, mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value), + mgnify_z_value=_MGNIFY_Z_VALUE.value, uniprot_cluster_annot_database_path=expand_path( _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value ), + uniprot_cluster_annot_z_value=_UNIPROT_CLUSTER_ANNOT_Z_VALUE.value, uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value), + uniref90_z_value=_UNIREF90_Z_VALUE.value, ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value), + ntrna_z_value=_NTRNA_Z_VALUE.value, rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value), + rfam_z_value=_RFAM_Z_VALUE.value, rna_central_database_path=expand_path(_RNA_CENTRAL_DATABASE_PATH.value), + rna_central_z_value=_RNA_CENTRAL_Z_VALUE.value, pdb_database_path=expand_path(_PDB_DATABASE_PATH.value), seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value), jackhmmer_n_cpu=_JACKHMMER_N_CPU.value, + jackhmmer_max_parallel_shards=_JACKHMMER_MAX_PARALLEL_SHARDS.value, nhmmer_n_cpu=_NHMMER_N_CPU.value, + nhmmer_max_parallel_shards=_NHMMER_MAX_PARALLEL_SHARDS.value, max_template_date=max_template_date, ) else: diff --git a/src/alphafold3/data/msa_config.py b/src/alphafold3/data/msa_config.py index 4670010..c2b0333 100644 --- a/src/alphafold3/data/msa_config.py +++ b/src/alphafold3/data/msa_config.py @@ -42,9 +42,16 @@ class JackhmmerConfig: n_cpu: An integer with the number of CPUs to use. n_iter: An integer with the number of database search iterations. e_value: e-value for the database lookup. - z_value: The Z-value representing the number of comparisons done (i.e - correct database size) for E-value calculation. + z_value: The Z-value representing the database size in number of sequences + for E-value and domain E-value calculation. Must be set for sharded + databases. + dom_z_value: The Z-value representing the database size in number of + sequences for domain E-value calculation. Must be set for sharded + databases. max_sequences: Max sequences to return in MSA. + max_parallel_shards: If given, the maximum number of shards to search + against in parallel. If None, one Jackhmmer instance will be run per + shard. Only applicable if the database is sharded. """ binary_path: str @@ -52,8 +59,10 @@ class JackhmmerConfig: n_cpu: int n_iter: int e_value: float - z_value: float | int | None + z_value: int | None + dom_z_value: int | None max_sequences: int + max_parallel_shards: int | None = None @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) @@ -67,8 +76,14 @@ class NhmmerConfig: database_config: Database configuration. n_cpu: An integer with the number of CPUs to use. e_value: e-value for the database lookup. + z_value: The Z-value representing the database size in megabases for + E-value calculation. Allows fractional values. Must be set for sharded + databases. max_sequences: Max sequences to return in MSA. alphabet: The alphabet when building a profile with hmmbuild. + max_parallel_shards: If given, the maximum number of shards to search + against in parallel. If None, one Nhmmer instance will be run per shard. + Only applicable if the database is sharded. """ binary_path: str @@ -77,8 +92,10 @@ class NhmmerConfig: database_config: DatabaseConfig n_cpu: int e_value: float + z_value: float | None max_sequences: int alphabet: str | None + max_parallel_shards: int | None = None @dataclasses.dataclass(frozen=True, kw_only=True, slots=True) diff --git a/src/alphafold3/data/pipeline.py b/src/alphafold3/data/pipeline.py index 026ccfa..db48d03 100644 --- a/src/alphafold3/data/pipeline.py +++ b/src/alphafold3/data/pipeline.py @@ -214,21 +214,42 @@ class DataPipelineConfig: raw MSA in template search. small_bfd_database_path: Small BFD database path, used for protein MSA search. + small_bfd_z_value: The Z-value representing the database size in number of + sequences for E-value calculation. Must be set for sharded databases. mgnify_database_path: Mgnify database path, used for protein MSA search. + mgnify_z_value: The Z-value representing the database size in number of + sequences for E-value calculation. Must be set for sharded databases. uniprot_cluster_annot_database_path: Uniprot database path, used for protein paired MSA search. + uniprot_cluster_annot_z_value: The Z-value representing the database size in + number of sequences for E-value calculation. Must be set for sharded + databases. uniref90_database_path: UniRef90 database path, used for MSA search, and the MSA obtained by searching it is used to construct the profile for template search. + uniref90_z_value: The Z-value representing the database size in number of + sequences for E-value calculation. Must be set for sharded databases. ntrna_database_path: NT-RNA database path, used for RNA MSA search. + ntrna_z_value: The Z-value representing the database size in megabases for + E-value calculation. Must be set for sharded databases. rfam_database_path: Rfam database path, used for RNA MSA search. + rfam_z_value: The Z-value representing the database size in megabases for + E-value calculation. Must be set for sharded databases. rna_central_database_path: RNAcentral database path, used for RNA MSA search. + rna_central_z_value: The Z-value representing the database size in megabases + for E-value calculation. Must be set for sharded databases. seqres_database_path: PDB sequence database path, used for template search. pdb_database_path: PDB database directory with mmCIF files path, used for template search. jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer. + jackhmmer_max_parallel_shards: Maximum number of shards to search against in + parallel. If None, one Jackhmmer instance will be run per shard. Only + applicable if the database is sharded. nhmmer_n_cpu: Number of CPUs to use for Nhmmer. + nhmmer_max_parallel_shards: Maximum number of shards to search against in + parallel. If None, one Nhmmer instance will be run per shard. Only + applicable if the database is sharded. max_template_date: The latest date of templates to use. """ @@ -241,20 +262,29 @@ class DataPipelineConfig: # Jackhmmer databases. small_bfd_database_path: str + small_bfd_z_value: int | None = None mgnify_database_path: str + mgnify_z_value: int | None = None uniprot_cluster_annot_database_path: str + uniprot_cluster_annot_z_value: int | None = None uniref90_database_path: str + uniref90_z_value: int | None = None # Nhmmer databases. ntrna_database_path: str + ntrna_z_value: int | None = None rfam_database_path: str + rfam_z_value: int | None = None rna_central_database_path: str + rna_central_z_value: int | None = None # Template search databases. seqres_database_path: str pdb_database_path: str # Optional configuration for MSA tools. jackhmmer_n_cpu: int = 8 + jackhmmer_max_parallel_shards: int | None = None nhmmer_n_cpu: int = 8 + nhmmer_max_parallel_shards: int | None = None max_template_date: datetime.date @@ -274,8 +304,10 @@ class DataPipeline: n_cpu=data_pipeline_config.jackhmmer_n_cpu, n_iter=1, e_value=1e-4, - z_value=None, + z_value=data_pipeline_config.uniref90_z_value, + dom_z_value=data_pipeline_config.uniref90_z_value, max_sequences=10_000, + max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.PROTEIN_CHAIN, crop_size=None, @@ -290,8 +322,10 @@ class DataPipeline: n_cpu=data_pipeline_config.jackhmmer_n_cpu, n_iter=1, e_value=1e-4, - z_value=None, + z_value=data_pipeline_config.mgnify_z_value, + dom_z_value=data_pipeline_config.mgnify_z_value, max_sequences=5_000, + max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.PROTEIN_CHAIN, crop_size=None, @@ -308,8 +342,10 @@ class DataPipeline: e_value=1e-4, # Set z_value=138_515_945 to match the z_value used in the paper. # In practice, this has minimal impact on predicted structures. - z_value=None, + z_value=data_pipeline_config.small_bfd_z_value, + dom_z_value=data_pipeline_config.small_bfd_z_value, max_sequences=5_000, + max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.PROTEIN_CHAIN, crop_size=None, @@ -324,8 +360,10 @@ class DataPipeline: n_cpu=data_pipeline_config.jackhmmer_n_cpu, n_iter=1, e_value=1e-4, - z_value=None, + z_value=data_pipeline_config.uniprot_cluster_annot_z_value, + dom_z_value=data_pipeline_config.uniprot_cluster_annot_z_value, max_sequences=50_000, + max_parallel_shards=data_pipeline_config.jackhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.PROTEIN_CHAIN, crop_size=None, @@ -342,7 +380,9 @@ class DataPipeline: n_cpu=data_pipeline_config.nhmmer_n_cpu, e_value=1e-3, alphabet='rna', + z_value=data_pipeline_config.ntrna_z_value, max_sequences=10_000, + max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.RNA_CHAIN, crop_size=None, @@ -359,7 +399,9 @@ class DataPipeline: n_cpu=data_pipeline_config.nhmmer_n_cpu, e_value=1e-3, alphabet='rna', + z_value=data_pipeline_config.rfam_z_value, max_sequences=10_000, + max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.RNA_CHAIN, crop_size=None, @@ -376,7 +418,9 @@ class DataPipeline: n_cpu=data_pipeline_config.nhmmer_n_cpu, e_value=1e-3, alphabet='rna', + z_value=data_pipeline_config.rna_central_z_value, max_sequences=10_000, + max_parallel_shards=data_pipeline_config.nhmmer_max_parallel_shards, ), chain_poly_type=mmcif_names.RNA_CHAIN, crop_size=None, diff --git a/src/alphafold3/data/tools/jackhmmer.py b/src/alphafold3/data/tools/jackhmmer.py index 1ea94d7..bb2fec1 100644 --- a/src/alphafold3/data/tools/jackhmmer.py +++ b/src/alphafold3/data/tools/jackhmmer.py @@ -10,12 +10,19 @@ """Library to run Jackhmmer from Python.""" +from collections.abc import Iterable, Sequence +from concurrent import futures +import heapq import os +import pathlib +import shutil import tempfile +import time from absl import logging from alphafold3.data import parsers from alphafold3.data.tools import msa_tool +from alphafold3.data.tools import shards from alphafold3.data.tools import subprocess_utils @@ -31,43 +38,91 @@ class Jackhmmer(msa_tool.MsaTool): n_iter: int = 3, e_value: float | None = 1e-3, z_value: float | int | None = None, + dom_e: float | None = None, + dom_z_value: float | int | None = None, max_sequences: int = 5000, filter_f1: float = 5e-4, filter_f2: float = 5e-5, filter_f3: float = 5e-7, + max_threads: int | None = None, + **unused_kwargs, ): """Initializes the Python Jackhmmer wrapper. + NOTE: The MSA obtained by running against sharded dbs won't be always + exactly the same as the MSA obtained by running against an unsharded db. + This is because of Jackhmmer deduplication logic, which won't spot duplicate + hits across multiple shards. Usually this means that the sharded search + finds more hits (likely bounded by the number of shards), but this should + not pose an issue given how the results are used downstream. The problem is + more pronounced with deep MSAs and lower in the hit list (higher e-values). + + Make sure to set the Z and domZ values when searching against a sharded + database, otherwise the results won't match the normal unsharded search. + Args: binary_path: The path to the jackhmmer executable. - database_path: The path to the jackhmmer database (FASTA format). + database_path: The path to the jackhmmer database (FASTA format). Sharded + file specs, e.g. `@`, are supported. n_cpu: The number of CPUs to give Jackhmmer. n_iter: The number of Jackhmmer iterations. e_value: The E-value, see Jackhmmer docs for more details. z_value: The Z-value representing the number of comparisons done (i.e - correct database size) for E-value calculation. + correct database size) for E-value calculation. Make sure to set this + when searching against a sharded database, otherwise the e-values will + be incorrectly scaled. + dom_e: Domain e-value criteria for inclusion in tblout. + dom_z_value: Domain z-value representing the number of comparisons done + (i.e correct database size) for domain E-value calculation. Make sure to + set this when searching against a sharded database, otherwise the domain + e-values will be incorrectly scaled. max_sequences: Maximum number of sequences to return in the MSA. filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. filter_f2: Viterbi pre-filter, set to >1.0 to turn off. filter_f3: Forward pre-filter, set to >1.0 to turn off. + max_threads: If given, the maximum number of threads used when running + sharded databases. Raises: RuntimeError: If Jackhmmer binary not found within the path. + ValueError: If an invalid configuration is provided in the args. """ - self._binary_path = binary_path self._database_path = database_path + if shard_paths := shards.get_sharded_paths(self._database_path): + if n_iter != 1: + raise ValueError('For a sharded db, only n_iter=1 is supported.') + if z_value is None: + raise ValueError( + 'The Z-value must be set when searching against a sharded database ' + 'to correctly scale e-values.' + ) + if max_sequences <= 1: + raise ValueError( + 'max_sequences must be greater than 1 when running in sharded ' + 'mode, because each shard would return only the query sequence.' + ) + + self._shard_paths = shard_paths + self._max_threads = len(self._shard_paths) + if max_threads is not None: + self._max_threads = min(max_threads, self._max_threads) + logging.info('Jackhmmer running with max_threads = %d', self._max_threads) + else: + self._shard_paths = None + self._max_threads = None + + self._binary_path = binary_path subprocess_utils.check_binary_exists( path=self._binary_path, name='Jackhmmer' ) - if not os.path.exists(self._database_path): - raise ValueError(f'Could not find Jackhmmer database {database_path}') - self._n_cpu = n_cpu self._n_iter = n_iter self._e_value = e_value self._z_value = z_value + self._dom_e = dom_e + self._dom_z_value = dom_z_value self._max_sequences = max_sequences self._filter_f1 = filter_f1 self._filter_f2 = filter_f2 @@ -81,21 +136,73 @@ class Jackhmmer(msa_tool.MsaTool): ) def query(self, target_sequence: str) -> msa_tool.MsaToolResult: - """Queries the database using Jackhmmer.""" - logging.info( - 'Query sequence: %s', - target_sequence - if len(target_sequence) <= 16 - else f'{target_sequence[:16]}... (len {len(target_sequence)})', - ) + """Query the database (sharded or unsharded) using Jackhmmer.""" + if self._shard_paths: + # Sharded case, run the query against each database shard in parallel. + logging.info( + 'Query sequence (sharded db): %s', + target_sequence + if len(target_sequence) <= 16 + else f'{target_sequence[:16]}... (len {len(target_sequence)})', + ) - with tempfile.TemporaryDirectory() as query_tmp_dir: + global_temp_dir = tempfile.mkdtemp() + + def _query_shard_fn( + shard_path: str, + ) -> tuple[msa_tool.MsaToolResult, float]: + t_start = time.time() + result = self._query_db_shard( + target_sequence=target_sequence, + db_shard_path=shard_path, + get_tblout=True, # Tblout contains e-values needed for merging. + global_temp_dir=global_temp_dir, + ) + return result, time.time() - t_start + + with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex: + tool_outputs, timings = zip(*ex.map(_query_shard_fn, self._shard_paths)) + + logging.info( + 'Finished query for %d shards, shard timings (seconds): %s', + len(tool_outputs), + ', '.join(f'{t:.1f}' for t in timings), + ) + + shutil.rmtree(global_temp_dir, ignore_errors=True) + return _merge_jackhmmer_results(tool_outputs, self._max_sequences) + + else: + # Non-sharded case, run the query against the whole database. + logging.info( + 'Query sequence (non-sharded db): %s', + target_sequence + if len(target_sequence) <= 16 + else f'{target_sequence[:16]}... (len {len(target_sequence)})', + ) + return self._query_db_shard( + target_sequence=target_sequence, + db_shard_path=self._database_path, + get_tblout=False, + ) + + def _query_db_shard( + self, + *, + target_sequence: str, + db_shard_path: str, + get_tblout: bool, + global_temp_dir: str | None = None, + ) -> msa_tool.MsaToolResult: + """Query the database shard using Jackhmmer.""" + + with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir: input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta') subprocess_utils.create_query_fasta_file( sequence=target_sequence, path=input_fasta_path ) - output_sto_path = os.path.join(query_tmp_dir, 'output.sto') + pathlib.Path(output_sto_path).touch() # The F1/F2/F3 are the expected proportion to pass each of the filtering # stages (which get progressively more expensive), reducing these @@ -113,6 +220,13 @@ class Jackhmmer(msa_tool.MsaTool): *('-N', str(self._n_iter)), ] + if get_tblout: + output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt') + output_tblout_path.touch() + cmd_flags.extend(['--tblout', str(output_tblout_path)]) + else: + output_tblout_path = None + # Report only sequences with E-values <= x in per-sequence output. if self._e_value is not None: cmd_flags.extend(['-E', str(self._e_value)]) @@ -123,18 +237,21 @@ class Jackhmmer(msa_tool.MsaTool): if self._z_value is not None: cmd_flags.extend(['-Z', str(self._z_value)]) + if self._dom_z_value is not None: + cmd_flags.extend(['--domZ', str(self._dom_z_value)]) + + if self._dom_e is not None: + cmd_flags.extend(['--domE', str(self._dom_e)]) + if self._max_sequences is not None and self._supports_seq_limit: cmd_flags.extend(['--seq_limit', str(self._max_sequences)]) - cmd = ( - [self._binary_path] - + cmd_flags - + [input_fasta_path, self._database_path] - ) + # The input FASTA and the input db are the last two arguments. + cmd = [self._binary_path] + cmd_flags + [input_fasta_path, db_shard_path] subprocess_utils.run( cmd=cmd, - cmd_name=f'Jackhmmer ({os.path.basename(self._database_path)})', + cmd_name=f'Jackhmmer ({os.path.basename(db_shard_path)})', log_stdout=False, log_stderr=True, log_on_process_error=True, @@ -145,6 +262,73 @@ class Jackhmmer(msa_tool.MsaTool): f, max_sequences=self._max_sequences ) - return msa_tool.MsaToolResult( - target_sequence=target_sequence, a3m=a3m, e_value=self._e_value - ) + # Get the tabular output which has e.g. e-value for each target. + tbl = '' if output_tblout_path is None else output_tblout_path.read_text() + + return msa_tool.MsaToolResult( + target_sequence=target_sequence, + a3m=a3m, + e_value=self._e_value, + tblout=tbl, + ) + + +def _merge_jackhmmer_results( + jh_results: Sequence[msa_tool.MsaToolResult], max_sequences: int +) -> msa_tool.MsaToolResult: + """Merges Jackhmmer result protos into a single one.""" + assert len(set(jh_res.target_sequence for jh_res in jh_results)) == 1 + assert len(set(jh_res.e_value for jh_res in jh_results)) == 1 + + # Parse the TBL output, create a mapping from hit name to TBL line. + parsed_tbl = {} + for jh_result in jh_results: + assert jh_result.tblout is not None + for line in jh_result.tblout.splitlines(): + if not line.startswith('#'): + parsed_tbl[line.partition(' ')[0]] = line + + # Create an iterator and merge a3m info with tbl info. + def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]: + # Don't parse the entire a3m, lazily parse only as many sequences as needed. + iterator = iter(parsers.lazy_parse_fasta_string(a3m)) + next(iterator) # Skip the query which isn't present in tblout. + for sequence, description in iterator: + name = description.partition(' ')[0].partition('/')[0] + if tbl_info := parsed_tbl.get(name): + # Skip sequences for which we don't have tbl information. + yield sequence, description, tbl_info, name + + def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]: + unused_seq, unused_description, tbl_info, name = seq_data + # Tblout lines have 19 whitespace delimited columns. "-" used if no value + # present. We want e-value in column with index 4, so do only 5 splits. + # Use the name in case of a e-value tie. + return float(tbl_info.split(maxsplit=5)[4]), name + + # A3M/TBL is sorted by e-value and name, hence we can merge them efficiently. + merged_a3m_and_tblout = heapq.merge( + *[_merged_a3m_tbl_iter(res.a3m) for res in jh_results], + key=sort_key, + ) + + # Truncate the a3m to max_sequences. Do not truncate the tblout. + merged_tblout = [] + merged_a3m = [f'>query\n{jh_results[0].target_sequence}'] + for seq, description, tbl_info, _ in merged_a3m_and_tblout: + merged_tblout.append(tbl_info) + if len(merged_a3m) < max_sequences: + merged_a3m.append(f'>{description}\n{seq}') + + logging.info( + 'Limiting merged MSA depth from %d to %d', + len(merged_tblout), + max_sequences, + ) + + return msa_tool.MsaToolResult( + target_sequence=jh_results[0].target_sequence, + a3m='\n'.join(merged_a3m), + e_value=jh_results[0].e_value, + tblout=None, # We no longer need the tblout. + ) diff --git a/src/alphafold3/data/tools/msa_tool.py b/src/alphafold3/data/tools/msa_tool.py index a739f06..c6c076c 100644 --- a/src/alphafold3/data/tools/msa_tool.py +++ b/src/alphafold3/data/tools/msa_tool.py @@ -16,11 +16,20 @@ from typing import Protocol @dataclasses.dataclass(frozen=True, slots=True, kw_only=True) class MsaToolResult: - """The result of a MSA tool query.""" + """The result of a MSA tool query. + + Attributes: + target_sequence: The sequence that was used to query the MSA tool. + e_value: The e-value that was used to filter the MSA tool results. + a3m: The MSA output of the tool in the A3M format. + tblout: The optional tblout output of the MSA tool (needed for merging + results of queries against a sharded database). + """ target_sequence: str e_value: float a3m: str + tblout: str | None = None class MsaTool(Protocol): diff --git a/src/alphafold3/data/tools/nhmmer.py b/src/alphafold3/data/tools/nhmmer.py index c4989e2..f833ac7 100644 --- a/src/alphafold3/data/tools/nhmmer.py +++ b/src/alphafold3/data/tools/nhmmer.py @@ -10,9 +10,14 @@ """Library to run Nhmmer from Python.""" +from collections.abc import Iterable, Sequence +from concurrent import futures +import heapq import os import pathlib +import shutil import tempfile +import time from typing import Final from absl import logging @@ -20,8 +25,10 @@ from alphafold3.data import parsers from alphafold3.data.tools import hmmalign from alphafold3.data.tools import hmmbuild from alphafold3.data.tools import msa_tool +from alphafold3.data.tools import shards from alphafold3.data.tools import subprocess_utils + _SHORT_SEQUENCE_CUTOFF: Final[int] = 50 @@ -36,38 +43,83 @@ class Nhmmer(msa_tool.MsaTool): database_path: str, n_cpu: int = 8, e_value: float = 1e-3, + z_value: float | int | None = None, max_sequences: int = 5000, filter_f3: float = 1e-5, alphabet: str | None = None, strand: str | None = None, + max_threads: int | None = None, ): """Initializes the Python Nhmmer wrapper. + NOTE: The MSA obtained by running against sharded dbs won't be always + exactly the same as the MSA obtained by running against an unsharded db. + This is because of Jackhmmer deduplication logic, which won't spot duplicate + hits across multiple shards. Usually this means that the sharded search + finds more hits (likely bounded by the number of shards), but this should + not pose an issue given how the results are used downstream. The problem is + more pronounced with deep MSAs and lower in the hit list (higher e-values). + + Make sure to set the Z value when searching against a sharded database, + otherwise the results won't match the normal unsharded search. + Args: binary_path: Path to the Nhmmer binary. hmmalign_binary_path: Path to the Hmmalign binary. hmmbuild_binary_path: Path to the Hmmbuild binary. database_path: MSA database path to search against. This can be either a FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb - binary. The HMMERDB is ~10x faster but experimental. + binary. The HMMERDB is ~10x faster but experimental. Sharded file + specs, e.g. @, are supported. n_cpu: The number of CPUs to give Nhmmer. e_value: The E-value, see Nhmmer docs for more details. Will be overwritten if bit_score is set. + z_value: The Z-value representing the number of comparisons done (i.e + correct database size) for E-value calculation. Make sure to set this + when searching against a sharded database, otherwise the e-values will + be incorrectly scaled. max_sequences: Maximum number of sequences to return in the MSA. filter_f3: Forward pre-filter, set to >1.0 to turn off. alphabet: The alphabet to assert when building a profile with hmmbuild. This must be 'rna', 'dna', or None. strand: "watson" searches query sequence, "crick" searches reverse-compliment and default is None which means searching for both. + max_threads: If given, the maximum number of threads used when running + sharded databases. Raises: RuntimeError: If Nhmmer binary not found within the path. + ValueError: If an invalid configuration is provided in the args. """ + self._database_path = database_path + + if shard_paths := shards.get_sharded_paths(self._database_path): + if z_value is None: + raise ValueError( + 'The Z-value must be set when searching against a sharded database ' + 'to correctly scale e-values.' + ) + if 'hmmerdb' in self._database_path: + raise ValueError('HMMERDB is not supported in sharded mode.') + + if max_sequences <= 1: + raise ValueError( + 'max_sequences must be greater than 1 when running in sharded ' + 'mode, because each shard would return only the query sequence.' + ) + + self._shard_paths = shard_paths + self._max_threads = len(self._shard_paths) + if max_threads is not None: + self._max_threads = min(max_threads, self._max_threads) + logging.info('Nhmmer running with max_threads = %d', self._max_threads) + else: + self._shard_paths = None + self._max_threads = None + self._binary_path = binary_path self._hmmalign_binary_path = hmmalign_binary_path self._hmmbuild_binary_path = hmmbuild_binary_path - self._db_path = database_path - subprocess_utils.check_binary_exists(path=self._binary_path, name='Nhmmer') if strand and strand not in {'watson', 'crick'}: @@ -78,21 +130,75 @@ class Nhmmer(msa_tool.MsaTool): self._e_value = e_value self._n_cpu = n_cpu + self._z_value = z_value self._max_sequences = max_sequences self._filter_f3 = filter_f3 self._alphabet = alphabet self._strand = strand def query(self, target_sequence: str) -> msa_tool.MsaToolResult: - """Query the database using Nhmmer.""" - logging.info( - 'Query sequence: %s', - target_sequence - if len(target_sequence) <= 16 - else f'{target_sequence[:16]}... (len {len(target_sequence)})', - ) + """Query the database (sharded or unsharded) using Nhmmer.""" + if self._shard_paths: + # Sharded case, run the query against each database shard in parallel. + logging.info( + 'Query sequence (sharded db): %s', + target_sequence + if len(target_sequence) <= 16 + else f'{target_sequence[:16]}... (len {len(target_sequence)})', + ) - with tempfile.TemporaryDirectory() as query_tmp_dir: + global_temp_dir = tempfile.mkdtemp() + + def _query_shard_fn( + shard_path: str, + ) -> tuple[msa_tool.MsaToolResult, float]: + t_start = time.time() + # Get tblout as it contains e-values we need for merging sequences. + result = self._query_db_shard( + target_sequence=target_sequence, + db_shard_path=shard_path, + get_tblout=True, # Tblout contains e-values needed for merging. + global_temp_dir=global_temp_dir, + ) + return result, time.time() - t_start + + with futures.ThreadPoolExecutor(max_workers=self._max_threads) as ex: + tool_outputs, timings = zip(*ex.map(_query_shard_fn, self._shard_paths)) + + logging.info( + 'Finished query for %d shards, shard timings (seconds): %s', + len(tool_outputs), + ', '.join(f'{t:.1f}' for t in timings), + ) + + shutil.rmtree(global_temp_dir, ignore_errors=True) + return _merge_nhmmer_results(tool_outputs, self._max_sequences) + + else: + # Non-sharded case, run the query against the whole database. + logging.info( + 'Query sequence (non-sharded db): %s', + target_sequence + if len(target_sequence) <= 16 + else f'{target_sequence[:16]}... (len {len(target_sequence)})', + ) + return self._query_db_shard( + target_sequence=target_sequence, + db_shard_path=self._database_path, + get_tblout=False, + ) + + def _query_db_shard( + self, + *, + target_sequence: str, + db_shard_path: str, + get_tblout: bool, + global_temp_dir: str | None = None, + ) -> msa_tool.MsaToolResult: + """Query the database shard using Nhmmer.""" + + with tempfile.TemporaryDirectory(dir=global_temp_dir) as query_tmp_dir: input_a3m_path = os.path.join(query_tmp_dir, 'query.a3m') output_sto_path = os.path.join(query_tmp_dir, 'output.sto') pathlib.Path(output_sto_path).touch() @@ -106,8 +212,18 @@ class Nhmmer(msa_tool.MsaTool): *('--cpu', str(self._n_cpu)), ] + if get_tblout: + output_tblout_path = pathlib.Path(query_tmp_dir, 'tblout.txt') + output_tblout_path.touch() + cmd_flags.extend(['--tblout', str(output_tblout_path)]) + else: + output_tblout_path = None + cmd_flags.extend(['-E', str(self._e_value)]) + if self._z_value is not None: + cmd_flags.extend(['-Z', str(self._z_value)]) + if self._alphabet: cmd_flags.extend([f'--{self._alphabet}']) @@ -125,13 +241,12 @@ class Nhmmer(msa_tool.MsaTool): cmd_flags.extend(['--F3', str(self._filter_f3)]) # The input A3M and the db are the last two arguments. - cmd_flags.extend((input_a3m_path, self._db_path)) + cmd_flags.extend((input_a3m_path, db_shard_path)) cmd = [self._binary_path, *cmd_flags] - subprocess_utils.run( cmd=cmd, - cmd_name=f'Nhmmer ({os.path.basename(self._db_path)})', + cmd_name=f'Nhmmer ({os.path.basename(db_shard_path)})', log_stdout=False, log_stderr=True, log_on_process_error=True, @@ -166,6 +281,80 @@ class Nhmmer(msa_tool.MsaTool): # In this case return only the query sequence. a3m = f'>query\n{target_sequence}' + # Get the tabular output which has e.g. e-value for each target. + tbl = '' if output_tblout_path is None else output_tblout_path.read_text() + return msa_tool.MsaToolResult( - target_sequence=target_sequence, e_value=self._e_value, a3m=a3m + target_sequence=target_sequence, + e_value=self._e_value, + a3m=a3m, + tblout=tbl, ) + + +def _merge_nhmmer_results( + nhmmer_results: Sequence[msa_tool.MsaToolResult], + max_sequences: int, +) -> msa_tool.MsaToolResult: + """Merges nhmmer result protos into a single one.""" + assert len(set(nh_res.target_sequence for nh_res in nhmmer_results)) == 1 + assert len(set(nh_res.e_value for nh_res in nhmmer_results)) == 1 + + # Parse the TBL output, create a mapping from unique hit ID to TBL line. + parsed_tbl = {} + for nhmmer_result in nhmmer_results: + assert nhmmer_result.tblout is not None + for line in nhmmer_result.tblout.splitlines(): + if not line.startswith('#'): + line_fields = line.split(maxsplit=15) + accession = line_fields[0] + alignment_from = line_fields[6] + alignment_to = line_fields[7] + # This is the unique ID that is used in the output A3M. + unique_id = f'{accession}/{alignment_from}-{alignment_to}' + parsed_tbl[unique_id] = line + + # Create an iterator and merge a3m info with tbl info. + def _merged_a3m_tbl_iter(a3m: str) -> Iterable[tuple[str, str, str, str]]: + # Don't parse the entire a3m, lazily parse only as many sequences as needed. + iterator = iter(parsers.lazy_parse_fasta_string(a3m)) + next(iterator) # Skip the query which isn't present in tblout. + for sequence, description in iterator: + name = description.partition(' ')[0] + if tbl_info := parsed_tbl.get(name): + # Skip sequences for which we don't have tbl information. + yield sequence, description, tbl_info, name + + def sort_key(seq_data: tuple[str, str, str, str]) -> tuple[float, str]: + unused_seq, unused_description, tbl_info, name = seq_data + # Nucleic tblout has 16 space delimited columns. "-" used if no value + # present. We want e-value in column 12, so do only 13 splits. Use the name + # in case of an e-value tie. + return float(tbl_info.split(maxsplit=13)[12]), name + + # A3M/TBL is sorted by e-value and name, hence we can merge them efficiently. + merged_a3m_and_tblout = heapq.merge( + *[_merged_a3m_tbl_iter(res.a3m) for res in nhmmer_results], + key=sort_key, + ) + + # Truncate the a3m to max_sequences. Do not truncate the tblout. + merged_tblout = [] + merged_a3m = [f'>query\n{nhmmer_results[0].target_sequence}'] + for seq, description, tbl_info, _ in merged_a3m_and_tblout: + merged_tblout.append(tbl_info) + if len(merged_a3m) < max_sequences: + merged_a3m.append(f'>{description}\n{seq}') + + logging.info( + 'Limiting merged MSA depth from %d to %d', + len(merged_tblout), + max_sequences, + ) + + return msa_tool.MsaToolResult( + target_sequence=nhmmer_results[0].target_sequence, + a3m='\n'.join(merged_a3m), + e_value=nhmmer_results[0].e_value, + tblout=None, # We no longer need the tblout. + ) diff --git a/src/alphafold3/data/tools/shards.py b/src/alphafold3/data/tools/shards.py new file mode 100644 index 0000000..09e2387 --- /dev/null +++ b/src/alphafold3/data/tools/shards.py @@ -0,0 +1,94 @@ +# Copyright 2025 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""A library to handle shards of the format file_path@NUM_SHARDS. + +For instance, /path/to/file@20 will generate the following shards: + +- /path/to/file-00000-of-00020 +- /path/to/file-00001-of-00020 +- ... +- /path/to/file-00019-of-00020 + +This also supports @* pattern, which will determine the number of shards based +on the filesystem content. +""" + +from collections.abc import Sequence +import dataclasses +import pathlib +import re + + +_MAX_NUM_SHARDS = 99_999 +_SHARD_RE = re.compile( + r""" + ^(?P[^\?\],\*]+)@ + (?P(\d{1,5})|\*) + (?P[\._][^\?\]@\*\/]*)? + $""", + re.X, +) + + +@dataclasses.dataclass(frozen=True) +class ShardSpec: + prefix: str + num_shards: int + suffix: str + + +def parse_shard_spec(path: str) -> ShardSpec | None: + """Returns the shard spec or None if the path is not a shard spec. + + For instance, if the shard spec is '/path/to/file@20', the output will be + ('/path/to/file', 20). + + Args: + path: the path to parse, e.g. /path/to/file@20 or /path/to/file@*. + """ + parsed = re.fullmatch(_SHARD_RE, path) + if not parsed: + return None + prefix = parsed.group('prefix') + shards = parsed.group('shards') + suffix = parsed.group('suffix') or '' + + if shards != '*': + return ShardSpec(prefix=prefix, num_shards=int(shards), suffix=suffix) + shard_slice = slice(len(prefix) + 10, len(prefix) + 15) + shard_path = pathlib.Path(f'{prefix}-00000-of-?????{suffix}') + for shard in sorted(shard_path.parent.glob(shard_path.name), reverse=True): + try: + num_shards = int(str(shard)[shard_slice]) + return ShardSpec(prefix=prefix, num_shards=num_shards, suffix=suffix) + except ValueError: + continue + return None + + +def get_sharded_paths(shard_spec: str) -> Sequence[str] | None: + """Returns a list of file path or None if the input is not a shard spec. + + Args: + shard_spec: the specifications of the shard, e.g. /path/to/file@20. + """ + parsed_spec = parse_shard_spec(shard_spec) + if not parsed_spec: + return None + + prefix = parsed_spec.prefix + num_shards = parsed_spec.num_shards + suffix = parsed_spec.suffix + if num_shards > _MAX_NUM_SHARDS: + raise ValueError(f'Shard count for {shard_spec} exceeds {_MAX_NUM_SHARDS}') + return [ + f'{prefix}-{i:05d}-of-{num_shards:05d}{suffix}' for i in range(num_shards) + ]