adding batch size patch to catch low N edge cases

This commit is contained in:
Ben Perry
2026-05-04 12:18:27 -04:00
parent ebacb8e45c
commit 2cd8c52d9f
7 changed files with 113 additions and 86 deletions

View File

@@ -49,9 +49,9 @@ The Modal deployment uses an alternative **producer-consumer** architecture via
## Batch Size Tuning
The `--batch_size` flag controls how many input JSON files are processed together in a single MMseqs2 GPU search. All protein sequences from a batch are collected into a single MMseqs2 query database, which is significantly more efficient than sequential processing.
The `--batch_size` flag controls how many unique protein sequences are processed together in a single MMseqs2 GPU search. Protein chains are collected and deduplicated across the loaded input JSON files, then split into MMseqs2 query batches of up to `--batch_size` unique sequences. This also batches multiple protein chains from a single JSON file.
When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir`. For manual runs:
When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir` as a simple heuristic. For manual runs:
```bash
python run_data_pipeline.py \
@@ -118,7 +118,7 @@ Note: If `CUDA_VISIBLE_DEVICES` is set, `--gpu_device` refers to the index withi
| `--use_mmseqs_gpu` | `true` | Use GPU-accelerated MMseqs2 |
| `--mmseqs_n_threads` | all CPUs | CPU threads for MMseqs2 non-GPU operations |
| `--mmseqs_sequential` | `false` | Run database searches sequentially (lower memory) |
| `--batch_size` | -- | Batch multiple inputs into one MMseqs2 search |
| `--batch_size` | -- | Unique protein sequences per MMseqs2 batch |
### Template Search

View File

@@ -203,4 +203,4 @@ inputs/
protein_3.json
```
Each file is processed independently. When `--batch_size` is set, protein sequences from multiple files are grouped into efficient batched MMseqs2-GPU searches.
Each file is loaded as a separate fold input. When `--batch_size` is set, protein chains are collected and deduplicated across the loaded inputs, then grouped into efficient batched MMseqs2-GPU searches. Multiple protein chains from a single JSON file can therefore be searched in the same MMseqs2 batch.

View File

@@ -147,11 +147,11 @@ _MMSEQS_SEQUENTIAL = flags.DEFINE_bool(
_BATCH_SIZE = flags.DEFINE_integer(
"batch_size",
None,
"Number of fold inputs to process together in a single batch. When set, "
"all protein sequences from up to batch_size fold inputs are collected "
"into a single MMseqs2 queryDB for GPU-accelerated batch search. This is "
"much more efficient than sequential processing. If not set, processes "
"each fold input sequentially (current default behavior).",
"Maximum number of unique protein sequences to process in a single "
"MMseqs2 query batch. Protein chains are collected and deduplicated across "
"all loaded fold inputs before batching, so multi-chain JSON files are "
"batched efficiently. If not set, processes each fold input sequentially "
"(current default behavior).",
lower_bound=1,
)
@@ -681,42 +681,31 @@ def main(_):
expanded_fold_inputs.append(fold_input)
# Check if batch mode is enabled
if _BATCH_SIZE.value is not None and len(expanded_fold_inputs) > 1:
# Batch mode: process multiple fold inputs together
if _BATCH_SIZE.value is not None:
# Batch mode: process all fold inputs together, chunking by unique
# protein sequences inside DataPipeline.process_batch().
batch_size = _BATCH_SIZE.value
print(
f"\nBATCH MODE: Processing {len(expanded_fold_inputs)} fold inputs "
f"in batches of {batch_size}\n"
f"with up to {batch_size} unique protein sequences per MMseqs2 batch\n"
)
data_pipeline = pipeline.DataPipeline(data_pipeline_config)
# Process in batches
for batch_start in range(0, len(expanded_fold_inputs), batch_size):
batch_end = min(
batch_start + batch_size, len(expanded_fold_inputs)
processed_inputs = data_pipeline.process_batch(
expanded_fold_inputs,
max_unique_sequences_per_batch=batch_size,
)
# Store each processed input with its output directory
for processed_input in processed_inputs:
output_dir = os.path.join(
_OUTPUT_DIR.value, processed_input.sanitised_name()
)
batch = expanded_fold_inputs[batch_start:batch_end]
print(
f"\n--- Processing batch {batch_start // batch_size + 1}: "
f"fold inputs {batch_start + 1} to {batch_end} ---\n"
)
# Run batch processing
processed_inputs = data_pipeline.process_batch(batch)
# Store each processed input with its output directory
for processed_input in processed_inputs:
output_dir = os.path.join(
_OUTPUT_DIR.value, processed_input.sanitised_name()
)
# Write the processed data JSON
write_fold_input_json(processed_input, output_dir)
processed_fold_inputs.append((processed_input, output_dir))
print(
f"Fold job {processed_input.name} data pipeline done.\n"
)
# Write the processed data JSON
write_fold_input_json(processed_input, output_dir)
processed_fold_inputs.append((processed_input, output_dir))
print(f"Fold job {processed_input.name} data pipeline done.\n")
else:
# Sequential mode: process each fold input individually
if _BATCH_SIZE.value is None and len(expanded_fold_inputs) > 1:

View File

@@ -138,11 +138,11 @@ _TEMP_DIR = flags.DEFINE_string(
_BATCH_SIZE = flags.DEFINE_integer(
"batch_size",
512,
"Number of fold inputs to process together in a single batch. When set, "
"all protein sequences from up to batch_size fold inputs are collected "
"into a single MMseqs2 queryDB for GPU-accelerated batch search. This is "
"much more efficient than sequential processing. Set to 0 to disable "
"batch mode and process each fold input sequentially.",
"Maximum number of unique protein sequences to process in a single "
"MMseqs2 query batch. Protein chains are collected and deduplicated across "
"all loaded fold inputs before batching, so multi-chain JSON files are "
"batched efficiently. Set to 0 to disable batch mode and process each fold "
"input sequentially.",
lower_bound=0,
)
@@ -610,49 +610,44 @@ def main(_):
rna_mmseqs_db_dir=rna_mmseqs_db_dir,
)
# Process fold inputs - either in batch mode or sequentially
# Process fold inputs - either in unique-sequence batch mode or sequentially
output_paths = []
data_pipeline = pipeline.DataPipeline(data_pipeline_config)
pipeline_start_time = time.time()
use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0 and len(fold_inputs) > 1
use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0
mode = "batch" if use_batch else "sequential"
if use_batch:
# Batch mode: process multiple fold inputs together
# Batch mode: process all fold inputs together, chunking by unique
# protein sequences inside DataPipeline.process_batch().
batch_size = _BATCH_SIZE.value
print(f"\n{'=' * 60}")
print(
f"BATCH MODE: Processing {len(fold_inputs)} fold inputs in batches of {batch_size}"
f"BATCH MODE: Processing {len(fold_inputs)} fold inputs with up to "
f"{batch_size} unique protein sequences per MMseqs2 batch"
)
print(f"{'=' * 60}\n")
# Process in batches
for batch_start in range(0, len(fold_inputs), batch_size):
batch_end = min(batch_start + batch_size, len(fold_inputs))
batch = fold_inputs[batch_start:batch_end]
processed_inputs = data_pipeline.process_batch(
fold_inputs,
max_unique_sequences_per_batch=batch_size,
)
print(
f"\n--- Processing batch {batch_start // batch_size + 1}: "
f"fold inputs {batch_start + 1} to {batch_end} ---\n"
# Write output for each processed input
for fold_input in processed_inputs:
output_subdir = os.path.join(
_OUTPUT_DIR.value, fold_input.sanitised_name()
)
processed_inputs = data_pipeline.process_batch(batch)
# Write output for each processed input
for fold_input in processed_inputs:
output_subdir = os.path.join(
_OUTPUT_DIR.value, fold_input.sanitised_name()
output_path = write_fold_input_json(fold_input, output_subdir)
output_paths.append(output_path)
if _QUEUE_DIR.value is not None:
_write_queue_token(
queue_dir=_QUEUE_DIR.value,
fold_input=fold_input,
data_json_path=output_path,
)
output_path = write_fold_input_json(fold_input, output_subdir)
output_paths.append(output_path)
if _QUEUE_DIR.value is not None:
_write_queue_token(
queue_dir=_QUEUE_DIR.value,
fold_input=fold_input,
data_json_path=output_path,
)
print(f"Fold job {fold_input.name} done.\n")
print(f"Fold job {fold_input.name} done.\n")
else:
# Sequential mode: process each fold input individually
if not use_batch:

View File

@@ -71,7 +71,8 @@ usage() {
echo " If both given, --gpu_devices takes precedence."
echo " --container IMAGE Container image or .sif path"
echo " (default: romerolabduke/alphafast:latest)"
echo " --batch_size N MSA batch size (default: auto = number of inputs)"
echo " --batch_size N Unique protein sequences per MSA batch"
echo " (default: auto = number of inputs)"
echo " --backend TYPE Force 'docker' or 'singularity' (default: auto-detect)"
echo " --temp_dir DIR Directory for MMseqs temporary files."
echo " Recommended on HPC: point this to fast local scratch"
@@ -166,7 +167,8 @@ if [ -n "$TEMP_DIR" ]; then
TEMP_DIR="$(cd "$TEMP_DIR" && pwd)"
fi
# Auto batch size: count input JSON files
# Auto batch size heuristic: count input JSON files. The Python pipeline applies
# this as the maximum number of unique protein sequences per MMseqs2 batch.
if [ -z "$BATCH_SIZE" ]; then
BATCH_SIZE=$(find "$INPUT_DIR" -maxdepth 1 -name "*.json" -type f | wc -l | tr -d ' ')
if [ "$BATCH_SIZE" -eq 0 ]; then

View File

@@ -20,7 +20,7 @@
# msa_output_dir - Output directory for MSA JSONs
# af_output_dir - Output directory for inference outputs
# num_gpus - Number of GPUs (derived from gpu_list by caller)
# batch_size - Batch size per GPU for MSA (default: 512)
# batch_size - Unique protein sequences per MSA batch per GPU (default: 512)
# gpu_list - Comma-separated GPU device indices (e.g. "6,7")
# mmseqs_threads - CPU threads per GPU (default: total_cores / num_gpus)
@@ -32,7 +32,7 @@ usage() {
echo " msa_output_dir: output directory for MSA JSONs"
echo " af_output_dir: output directory for inference outputs"
echo " num_gpus: number of GPUs to use"
echo " batch_size: batch size per GPU for MSA (default: partition size)"
echo " batch_size: unique protein sequences per MSA batch per GPU (default: 512)"
echo " gpu_list: comma-separated GPU indices (default: 0,1,...,N-1)"
echo " mmseqs_threads: CPU threads per GPU (default: total_cores / num_gpus)"
}

View File

@@ -1393,23 +1393,33 @@ class DataPipeline:
return dataclasses.replace(fold_input, chains=ordered_chains)
def process_batch(
self, fold_inputs: Sequence[folding_input.Input]
self,
fold_inputs: Sequence[folding_input.Input],
max_unique_sequences_per_batch: int | None = None,
) -> Sequence[folding_input.Input]:
"""Process multiple fold inputs with batched MSA search.
This method is more efficient than processing individually because:
1. Single createdb call for ALL protein sequences
2. GPU processes all sequences in parallel
1. Single createdb call per unique sequence batch
2. GPU processes sequences in parallel within each batch
3. Amortizes GPU kernel launch overhead
Args:
fold_inputs: Sequence of fold inputs to process together.
max_unique_sequences_per_batch: Maximum number of unique protein
sequences per MMseqs2 query batch. If None, all unique protein
sequences are searched in a single MMseqs2 batch.
Returns:
Sequence of processed fold inputs with MSA and templates.
"""
if not fold_inputs:
return []
if (
max_unique_sequences_per_batch is not None
and max_unique_sequences_per_batch < 1
):
raise ValueError("max_unique_sequences_per_batch must be at least 1")
# Check if using MMseqs2 (required for batch mode)
is_mmseqs = isinstance(
@@ -1548,8 +1558,35 @@ class DataPipeline:
temp_dir=self._temp_dir,
)
logging.info("Running batch MSA search across all databases...")
msa_results = batch_searcher.search_all_databases_pipelined(all_sequences)
sequence_items = list(all_sequences.items())
if max_unique_sequences_per_batch is None:
sequence_batches = [all_sequences]
else:
sequence_batches = [
dict(sequence_items[i : i + max_unique_sequences_per_batch])
for i in range(
0, len(sequence_items), max_unique_sequences_per_batch
)
]
logging.info(
"Running batch MSA search across all databases in %d "
"unique-sequence batch(es)...",
len(sequence_batches),
)
msa_results_by_db = {}
for batch_idx, sequence_batch in enumerate(sequence_batches, start=1):
logging.info(
"Running MMseqs2 unique-sequence batch %d/%d (%d sequences)...",
batch_idx,
len(sequence_batches),
len(sequence_batch),
)
batch_msa_results = batch_searcher.search_all_databases_pipelined(
sequence_batch
)
for db_name, db_result in batch_msa_results.items():
msa_results_by_db.setdefault(db_name, {}).update(db_result.results)
logging.info(
"Batch MSA search completed in %.2f seconds",
time.time() - batch_start_time,
@@ -1561,20 +1598,24 @@ class DataPipeline:
for seq_id, sequence in all_sequences.items():
# Get results for each database
uniref90_result = msa_results.get("uniref90")
mgnify_result = msa_results.get("mgnify")
small_bfd_result = msa_results.get("small_bfd")
uniprot_result = msa_results.get("uniprot")
uniref90_results = msa_results_by_db.get("uniref90", {})
mgnify_results = msa_results_by_db.get("mgnify", {})
small_bfd_results = msa_results_by_db.get("small_bfd", {})
uniprot_results = msa_results_by_db.get("uniprot", {})
# Parse A3M to Msa objects
uniref90_a3m = (
uniref90_result.results[seq_id].a3m if uniref90_result else ""
uniref90_results[seq_id].a3m if seq_id in uniref90_results else ""
)
mgnify_a3m = (
mgnify_results[seq_id].a3m if seq_id in mgnify_results else ""
)
mgnify_a3m = mgnify_result.results[seq_id].a3m if mgnify_result else ""
small_bfd_a3m = (
small_bfd_result.results[seq_id].a3m if small_bfd_result else ""
small_bfd_results[seq_id].a3m if seq_id in small_bfd_results else ""
)
uniprot_a3m = (
uniprot_results[seq_id].a3m if seq_id in uniprot_results else ""
)
uniprot_a3m = uniprot_result.results[seq_id].a3m if uniprot_result else ""
uniref90_msa = msa.Msa.from_a3m(
query_sequence=sequence,