From 2cd8c52d9fe520c4cd3ac8f3bce95528603b38d0 Mon Sep 17 00:00:00 2001 From: Ben Perry Date: Mon, 4 May 2026 12:18:27 -0400 Subject: [PATCH] adding batch size patch to catch low N edge cases --- docs/advanced.md | 6 +-- docs/input_format.md | 2 +- run_alphafold.py | 55 +++++++++++---------------- run_data_pipeline.py | 59 +++++++++++++---------------- scripts/run_alphafast.sh | 6 ++- scripts/run_multigpu.sh | 4 +- src/alphafold3/data/pipeline.py | 67 ++++++++++++++++++++++++++------- 7 files changed, 113 insertions(+), 86 deletions(-) diff --git a/docs/advanced.md b/docs/advanced.md index b098c46..84fcac0 100644 --- a/docs/advanced.md +++ b/docs/advanced.md @@ -49,9 +49,9 @@ The Modal deployment uses an alternative **producer-consumer** architecture via ## Batch Size Tuning -The `--batch_size` flag controls how many input JSON files are processed together in a single MMseqs2 GPU search. All protein sequences from a batch are collected into a single MMseqs2 query database, which is significantly more efficient than sequential processing. +The `--batch_size` flag controls how many unique protein sequences are processed together in a single MMseqs2 GPU search. Protein chains are collected and deduplicated across the loaded input JSON files, then split into MMseqs2 query batches of up to `--batch_size` unique sequences. This also batches multiple protein chains from a single JSON file. -When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir`. For manual runs: +When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir` as a simple heuristic. For manual runs: ```bash python run_data_pipeline.py \ @@ -118,7 +118,7 @@ Note: If `CUDA_VISIBLE_DEVICES` is set, `--gpu_device` refers to the index withi | `--use_mmseqs_gpu` | `true` | Use GPU-accelerated MMseqs2 | | `--mmseqs_n_threads` | all CPUs | CPU threads for MMseqs2 non-GPU operations | | `--mmseqs_sequential` | `false` | Run database searches sequentially (lower memory) | -| `--batch_size` | -- | Batch multiple inputs into one MMseqs2 search | +| `--batch_size` | -- | Unique protein sequences per MMseqs2 batch | ### Template Search diff --git a/docs/input_format.md b/docs/input_format.md index 46541ac..71aa07c 100644 --- a/docs/input_format.md +++ b/docs/input_format.md @@ -203,4 +203,4 @@ inputs/ protein_3.json ``` -Each file is processed independently. When `--batch_size` is set, protein sequences from multiple files are grouped into efficient batched MMseqs2-GPU searches. +Each file is loaded as a separate fold input. When `--batch_size` is set, protein chains are collected and deduplicated across the loaded inputs, then grouped into efficient batched MMseqs2-GPU searches. Multiple protein chains from a single JSON file can therefore be searched in the same MMseqs2 batch. diff --git a/run_alphafold.py b/run_alphafold.py index 4e44321..ab4c972 100644 --- a/run_alphafold.py +++ b/run_alphafold.py @@ -147,11 +147,11 @@ _MMSEQS_SEQUENTIAL = flags.DEFINE_bool( _BATCH_SIZE = flags.DEFINE_integer( "batch_size", None, - "Number of fold inputs to process together in a single batch. When set, " - "all protein sequences from up to batch_size fold inputs are collected " - "into a single MMseqs2 queryDB for GPU-accelerated batch search. This is " - "much more efficient than sequential processing. If not set, processes " - "each fold input sequentially (current default behavior).", + "Maximum number of unique protein sequences to process in a single " + "MMseqs2 query batch. Protein chains are collected and deduplicated across " + "all loaded fold inputs before batching, so multi-chain JSON files are " + "batched efficiently. If not set, processes each fold input sequentially " + "(current default behavior).", lower_bound=1, ) @@ -681,42 +681,31 @@ def main(_): expanded_fold_inputs.append(fold_input) # Check if batch mode is enabled - if _BATCH_SIZE.value is not None and len(expanded_fold_inputs) > 1: - # Batch mode: process multiple fold inputs together + if _BATCH_SIZE.value is not None: + # Batch mode: process all fold inputs together, chunking by unique + # protein sequences inside DataPipeline.process_batch(). batch_size = _BATCH_SIZE.value print( f"\nBATCH MODE: Processing {len(expanded_fold_inputs)} fold inputs " - f"in batches of {batch_size}\n" + f"with up to {batch_size} unique protein sequences per MMseqs2 batch\n" ) data_pipeline = pipeline.DataPipeline(data_pipeline_config) - # Process in batches - for batch_start in range(0, len(expanded_fold_inputs), batch_size): - batch_end = min( - batch_start + batch_size, len(expanded_fold_inputs) + processed_inputs = data_pipeline.process_batch( + expanded_fold_inputs, + max_unique_sequences_per_batch=batch_size, + ) + + # Store each processed input with its output directory + for processed_input in processed_inputs: + output_dir = os.path.join( + _OUTPUT_DIR.value, processed_input.sanitised_name() ) - batch = expanded_fold_inputs[batch_start:batch_end] - - print( - f"\n--- Processing batch {batch_start // batch_size + 1}: " - f"fold inputs {batch_start + 1} to {batch_end} ---\n" - ) - - # Run batch processing - processed_inputs = data_pipeline.process_batch(batch) - - # Store each processed input with its output directory - for processed_input in processed_inputs: - output_dir = os.path.join( - _OUTPUT_DIR.value, processed_input.sanitised_name() - ) - # Write the processed data JSON - write_fold_input_json(processed_input, output_dir) - processed_fold_inputs.append((processed_input, output_dir)) - print( - f"Fold job {processed_input.name} data pipeline done.\n" - ) + # Write the processed data JSON + write_fold_input_json(processed_input, output_dir) + processed_fold_inputs.append((processed_input, output_dir)) + print(f"Fold job {processed_input.name} data pipeline done.\n") else: # Sequential mode: process each fold input individually if _BATCH_SIZE.value is None and len(expanded_fold_inputs) > 1: diff --git a/run_data_pipeline.py b/run_data_pipeline.py index 9c6712d..57091c3 100644 --- a/run_data_pipeline.py +++ b/run_data_pipeline.py @@ -138,11 +138,11 @@ _TEMP_DIR = flags.DEFINE_string( _BATCH_SIZE = flags.DEFINE_integer( "batch_size", 512, - "Number of fold inputs to process together in a single batch. When set, " - "all protein sequences from up to batch_size fold inputs are collected " - "into a single MMseqs2 queryDB for GPU-accelerated batch search. This is " - "much more efficient than sequential processing. Set to 0 to disable " - "batch mode and process each fold input sequentially.", + "Maximum number of unique protein sequences to process in a single " + "MMseqs2 query batch. Protein chains are collected and deduplicated across " + "all loaded fold inputs before batching, so multi-chain JSON files are " + "batched efficiently. Set to 0 to disable batch mode and process each fold " + "input sequentially.", lower_bound=0, ) @@ -610,49 +610,44 @@ def main(_): rna_mmseqs_db_dir=rna_mmseqs_db_dir, ) - # Process fold inputs - either in batch mode or sequentially + # Process fold inputs - either in unique-sequence batch mode or sequentially output_paths = [] data_pipeline = pipeline.DataPipeline(data_pipeline_config) pipeline_start_time = time.time() - use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0 and len(fold_inputs) > 1 + use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0 mode = "batch" if use_batch else "sequential" if use_batch: - # Batch mode: process multiple fold inputs together + # Batch mode: process all fold inputs together, chunking by unique + # protein sequences inside DataPipeline.process_batch(). batch_size = _BATCH_SIZE.value print(f"\n{'=' * 60}") print( - f"BATCH MODE: Processing {len(fold_inputs)} fold inputs in batches of {batch_size}" + f"BATCH MODE: Processing {len(fold_inputs)} fold inputs with up to " + f"{batch_size} unique protein sequences per MMseqs2 batch" ) print(f"{'=' * 60}\n") - # Process in batches - for batch_start in range(0, len(fold_inputs), batch_size): - batch_end = min(batch_start + batch_size, len(fold_inputs)) - batch = fold_inputs[batch_start:batch_end] + processed_inputs = data_pipeline.process_batch( + fold_inputs, + max_unique_sequences_per_batch=batch_size, + ) - print( - f"\n--- Processing batch {batch_start // batch_size + 1}: " - f"fold inputs {batch_start + 1} to {batch_end} ---\n" + # Write output for each processed input + for fold_input in processed_inputs: + output_subdir = os.path.join( + _OUTPUT_DIR.value, fold_input.sanitised_name() ) - - processed_inputs = data_pipeline.process_batch(batch) - - # Write output for each processed input - for fold_input in processed_inputs: - output_subdir = os.path.join( - _OUTPUT_DIR.value, fold_input.sanitised_name() + output_path = write_fold_input_json(fold_input, output_subdir) + output_paths.append(output_path) + if _QUEUE_DIR.value is not None: + _write_queue_token( + queue_dir=_QUEUE_DIR.value, + fold_input=fold_input, + data_json_path=output_path, ) - output_path = write_fold_input_json(fold_input, output_subdir) - output_paths.append(output_path) - if _QUEUE_DIR.value is not None: - _write_queue_token( - queue_dir=_QUEUE_DIR.value, - fold_input=fold_input, - data_json_path=output_path, - ) - print(f"Fold job {fold_input.name} done.\n") + print(f"Fold job {fold_input.name} done.\n") else: # Sequential mode: process each fold input individually if not use_batch: diff --git a/scripts/run_alphafast.sh b/scripts/run_alphafast.sh index 8148cb6..8209c15 100755 --- a/scripts/run_alphafast.sh +++ b/scripts/run_alphafast.sh @@ -71,7 +71,8 @@ usage() { echo " If both given, --gpu_devices takes precedence." echo " --container IMAGE Container image or .sif path" echo " (default: romerolabduke/alphafast:latest)" - echo " --batch_size N MSA batch size (default: auto = number of inputs)" + echo " --batch_size N Unique protein sequences per MSA batch" + echo " (default: auto = number of inputs)" echo " --backend TYPE Force 'docker' or 'singularity' (default: auto-detect)" echo " --temp_dir DIR Directory for MMseqs temporary files." echo " Recommended on HPC: point this to fast local scratch" @@ -166,7 +167,8 @@ if [ -n "$TEMP_DIR" ]; then TEMP_DIR="$(cd "$TEMP_DIR" && pwd)" fi -# Auto batch size: count input JSON files +# Auto batch size heuristic: count input JSON files. The Python pipeline applies +# this as the maximum number of unique protein sequences per MMseqs2 batch. if [ -z "$BATCH_SIZE" ]; then BATCH_SIZE=$(find "$INPUT_DIR" -maxdepth 1 -name "*.json" -type f | wc -l | tr -d ' ') if [ "$BATCH_SIZE" -eq 0 ]; then diff --git a/scripts/run_multigpu.sh b/scripts/run_multigpu.sh index 108c530..0c1e999 100755 --- a/scripts/run_multigpu.sh +++ b/scripts/run_multigpu.sh @@ -20,7 +20,7 @@ # msa_output_dir - Output directory for MSA JSONs # af_output_dir - Output directory for inference outputs # num_gpus - Number of GPUs (derived from gpu_list by caller) -# batch_size - Batch size per GPU for MSA (default: 512) +# batch_size - Unique protein sequences per MSA batch per GPU (default: 512) # gpu_list - Comma-separated GPU device indices (e.g. "6,7") # mmseqs_threads - CPU threads per GPU (default: total_cores / num_gpus) @@ -32,7 +32,7 @@ usage() { echo " msa_output_dir: output directory for MSA JSONs" echo " af_output_dir: output directory for inference outputs" echo " num_gpus: number of GPUs to use" - echo " batch_size: batch size per GPU for MSA (default: partition size)" + echo " batch_size: unique protein sequences per MSA batch per GPU (default: 512)" echo " gpu_list: comma-separated GPU indices (default: 0,1,...,N-1)" echo " mmseqs_threads: CPU threads per GPU (default: total_cores / num_gpus)" } diff --git a/src/alphafold3/data/pipeline.py b/src/alphafold3/data/pipeline.py index 4b1a814..9b32362 100644 --- a/src/alphafold3/data/pipeline.py +++ b/src/alphafold3/data/pipeline.py @@ -1393,23 +1393,33 @@ class DataPipeline: return dataclasses.replace(fold_input, chains=ordered_chains) def process_batch( - self, fold_inputs: Sequence[folding_input.Input] + self, + fold_inputs: Sequence[folding_input.Input], + max_unique_sequences_per_batch: int | None = None, ) -> Sequence[folding_input.Input]: """Process multiple fold inputs with batched MSA search. This method is more efficient than processing individually because: - 1. Single createdb call for ALL protein sequences - 2. GPU processes all sequences in parallel + 1. Single createdb call per unique sequence batch + 2. GPU processes sequences in parallel within each batch 3. Amortizes GPU kernel launch overhead Args: fold_inputs: Sequence of fold inputs to process together. + max_unique_sequences_per_batch: Maximum number of unique protein + sequences per MMseqs2 query batch. If None, all unique protein + sequences are searched in a single MMseqs2 batch. Returns: Sequence of processed fold inputs with MSA and templates. """ if not fold_inputs: return [] + if ( + max_unique_sequences_per_batch is not None + and max_unique_sequences_per_batch < 1 + ): + raise ValueError("max_unique_sequences_per_batch must be at least 1") # Check if using MMseqs2 (required for batch mode) is_mmseqs = isinstance( @@ -1548,8 +1558,35 @@ class DataPipeline: temp_dir=self._temp_dir, ) - logging.info("Running batch MSA search across all databases...") - msa_results = batch_searcher.search_all_databases_pipelined(all_sequences) + sequence_items = list(all_sequences.items()) + if max_unique_sequences_per_batch is None: + sequence_batches = [all_sequences] + else: + sequence_batches = [ + dict(sequence_items[i : i + max_unique_sequences_per_batch]) + for i in range( + 0, len(sequence_items), max_unique_sequences_per_batch + ) + ] + + logging.info( + "Running batch MSA search across all databases in %d " + "unique-sequence batch(es)...", + len(sequence_batches), + ) + msa_results_by_db = {} + for batch_idx, sequence_batch in enumerate(sequence_batches, start=1): + logging.info( + "Running MMseqs2 unique-sequence batch %d/%d (%d sequences)...", + batch_idx, + len(sequence_batches), + len(sequence_batch), + ) + batch_msa_results = batch_searcher.search_all_databases_pipelined( + sequence_batch + ) + for db_name, db_result in batch_msa_results.items(): + msa_results_by_db.setdefault(db_name, {}).update(db_result.results) logging.info( "Batch MSA search completed in %.2f seconds", time.time() - batch_start_time, @@ -1561,20 +1598,24 @@ class DataPipeline: for seq_id, sequence in all_sequences.items(): # Get results for each database - uniref90_result = msa_results.get("uniref90") - mgnify_result = msa_results.get("mgnify") - small_bfd_result = msa_results.get("small_bfd") - uniprot_result = msa_results.get("uniprot") + uniref90_results = msa_results_by_db.get("uniref90", {}) + mgnify_results = msa_results_by_db.get("mgnify", {}) + small_bfd_results = msa_results_by_db.get("small_bfd", {}) + uniprot_results = msa_results_by_db.get("uniprot", {}) # Parse A3M to Msa objects uniref90_a3m = ( - uniref90_result.results[seq_id].a3m if uniref90_result else "" + uniref90_results[seq_id].a3m if seq_id in uniref90_results else "" + ) + mgnify_a3m = ( + mgnify_results[seq_id].a3m if seq_id in mgnify_results else "" ) - mgnify_a3m = mgnify_result.results[seq_id].a3m if mgnify_result else "" small_bfd_a3m = ( - small_bfd_result.results[seq_id].a3m if small_bfd_result else "" + small_bfd_results[seq_id].a3m if seq_id in small_bfd_results else "" + ) + uniprot_a3m = ( + uniprot_results[seq_id].a3m if seq_id in uniprot_results else "" ) - uniprot_a3m = uniprot_result.results[seq_id].a3m if uniprot_result else "" uniref90_msa = msa.Msa.from_a3m( query_sequence=sequence,