mirror of
https://github.com/RomeroLab/alphafast.git
synced 2026-06-04 13:30:25 +08:00
adding batch size patch to catch low N edge cases
This commit is contained in:
@@ -49,9 +49,9 @@ The Modal deployment uses an alternative **producer-consumer** architecture via
|
||||
|
||||
## Batch Size Tuning
|
||||
|
||||
The `--batch_size` flag controls how many input JSON files are processed together in a single MMseqs2 GPU search. All protein sequences from a batch are collected into a single MMseqs2 query database, which is significantly more efficient than sequential processing.
|
||||
The `--batch_size` flag controls how many unique protein sequences are processed together in a single MMseqs2 GPU search. Protein chains are collected and deduplicated across the loaded input JSON files, then split into MMseqs2 query batches of up to `--batch_size` unique sequences. This also batches multiple protein chains from a single JSON file.
|
||||
|
||||
When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir`. For manual runs:
|
||||
When using `run_alphafast.sh`, the batch size defaults to the number of input JSON files in `--input_dir` as a simple heuristic. For manual runs:
|
||||
|
||||
```bash
|
||||
python run_data_pipeline.py \
|
||||
@@ -118,7 +118,7 @@ Note: If `CUDA_VISIBLE_DEVICES` is set, `--gpu_device` refers to the index withi
|
||||
| `--use_mmseqs_gpu` | `true` | Use GPU-accelerated MMseqs2 |
|
||||
| `--mmseqs_n_threads` | all CPUs | CPU threads for MMseqs2 non-GPU operations |
|
||||
| `--mmseqs_sequential` | `false` | Run database searches sequentially (lower memory) |
|
||||
| `--batch_size` | -- | Batch multiple inputs into one MMseqs2 search |
|
||||
| `--batch_size` | -- | Unique protein sequences per MMseqs2 batch |
|
||||
|
||||
### Template Search
|
||||
|
||||
|
||||
@@ -203,4 +203,4 @@ inputs/
|
||||
protein_3.json
|
||||
```
|
||||
|
||||
Each file is processed independently. When `--batch_size` is set, protein sequences from multiple files are grouped into efficient batched MMseqs2-GPU searches.
|
||||
Each file is loaded as a separate fold input. When `--batch_size` is set, protein chains are collected and deduplicated across the loaded inputs, then grouped into efficient batched MMseqs2-GPU searches. Multiple protein chains from a single JSON file can therefore be searched in the same MMseqs2 batch.
|
||||
|
||||
@@ -147,11 +147,11 @@ _MMSEQS_SEQUENTIAL = flags.DEFINE_bool(
|
||||
_BATCH_SIZE = flags.DEFINE_integer(
|
||||
"batch_size",
|
||||
None,
|
||||
"Number of fold inputs to process together in a single batch. When set, "
|
||||
"all protein sequences from up to batch_size fold inputs are collected "
|
||||
"into a single MMseqs2 queryDB for GPU-accelerated batch search. This is "
|
||||
"much more efficient than sequential processing. If not set, processes "
|
||||
"each fold input sequentially (current default behavior).",
|
||||
"Maximum number of unique protein sequences to process in a single "
|
||||
"MMseqs2 query batch. Protein chains are collected and deduplicated across "
|
||||
"all loaded fold inputs before batching, so multi-chain JSON files are "
|
||||
"batched efficiently. If not set, processes each fold input sequentially "
|
||||
"(current default behavior).",
|
||||
lower_bound=1,
|
||||
)
|
||||
|
||||
@@ -681,42 +681,31 @@ def main(_):
|
||||
expanded_fold_inputs.append(fold_input)
|
||||
|
||||
# Check if batch mode is enabled
|
||||
if _BATCH_SIZE.value is not None and len(expanded_fold_inputs) > 1:
|
||||
# Batch mode: process multiple fold inputs together
|
||||
if _BATCH_SIZE.value is not None:
|
||||
# Batch mode: process all fold inputs together, chunking by unique
|
||||
# protein sequences inside DataPipeline.process_batch().
|
||||
batch_size = _BATCH_SIZE.value
|
||||
print(
|
||||
f"\nBATCH MODE: Processing {len(expanded_fold_inputs)} fold inputs "
|
||||
f"in batches of {batch_size}\n"
|
||||
f"with up to {batch_size} unique protein sequences per MMseqs2 batch\n"
|
||||
)
|
||||
|
||||
data_pipeline = pipeline.DataPipeline(data_pipeline_config)
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, len(expanded_fold_inputs), batch_size):
|
||||
batch_end = min(
|
||||
batch_start + batch_size, len(expanded_fold_inputs)
|
||||
processed_inputs = data_pipeline.process_batch(
|
||||
expanded_fold_inputs,
|
||||
max_unique_sequences_per_batch=batch_size,
|
||||
)
|
||||
|
||||
# Store each processed input with its output directory
|
||||
for processed_input in processed_inputs:
|
||||
output_dir = os.path.join(
|
||||
_OUTPUT_DIR.value, processed_input.sanitised_name()
|
||||
)
|
||||
batch = expanded_fold_inputs[batch_start:batch_end]
|
||||
|
||||
print(
|
||||
f"\n--- Processing batch {batch_start // batch_size + 1}: "
|
||||
f"fold inputs {batch_start + 1} to {batch_end} ---\n"
|
||||
)
|
||||
|
||||
# Run batch processing
|
||||
processed_inputs = data_pipeline.process_batch(batch)
|
||||
|
||||
# Store each processed input with its output directory
|
||||
for processed_input in processed_inputs:
|
||||
output_dir = os.path.join(
|
||||
_OUTPUT_DIR.value, processed_input.sanitised_name()
|
||||
)
|
||||
# Write the processed data JSON
|
||||
write_fold_input_json(processed_input, output_dir)
|
||||
processed_fold_inputs.append((processed_input, output_dir))
|
||||
print(
|
||||
f"Fold job {processed_input.name} data pipeline done.\n"
|
||||
)
|
||||
# Write the processed data JSON
|
||||
write_fold_input_json(processed_input, output_dir)
|
||||
processed_fold_inputs.append((processed_input, output_dir))
|
||||
print(f"Fold job {processed_input.name} data pipeline done.\n")
|
||||
else:
|
||||
# Sequential mode: process each fold input individually
|
||||
if _BATCH_SIZE.value is None and len(expanded_fold_inputs) > 1:
|
||||
|
||||
@@ -138,11 +138,11 @@ _TEMP_DIR = flags.DEFINE_string(
|
||||
_BATCH_SIZE = flags.DEFINE_integer(
|
||||
"batch_size",
|
||||
512,
|
||||
"Number of fold inputs to process together in a single batch. When set, "
|
||||
"all protein sequences from up to batch_size fold inputs are collected "
|
||||
"into a single MMseqs2 queryDB for GPU-accelerated batch search. This is "
|
||||
"much more efficient than sequential processing. Set to 0 to disable "
|
||||
"batch mode and process each fold input sequentially.",
|
||||
"Maximum number of unique protein sequences to process in a single "
|
||||
"MMseqs2 query batch. Protein chains are collected and deduplicated across "
|
||||
"all loaded fold inputs before batching, so multi-chain JSON files are "
|
||||
"batched efficiently. Set to 0 to disable batch mode and process each fold "
|
||||
"input sequentially.",
|
||||
lower_bound=0,
|
||||
)
|
||||
|
||||
@@ -610,49 +610,44 @@ def main(_):
|
||||
rna_mmseqs_db_dir=rna_mmseqs_db_dir,
|
||||
)
|
||||
|
||||
# Process fold inputs - either in batch mode or sequentially
|
||||
# Process fold inputs - either in unique-sequence batch mode or sequentially
|
||||
output_paths = []
|
||||
data_pipeline = pipeline.DataPipeline(data_pipeline_config)
|
||||
|
||||
pipeline_start_time = time.time()
|
||||
use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0 and len(fold_inputs) > 1
|
||||
use_batch = _BATCH_SIZE.value and _BATCH_SIZE.value > 0
|
||||
mode = "batch" if use_batch else "sequential"
|
||||
|
||||
if use_batch:
|
||||
# Batch mode: process multiple fold inputs together
|
||||
# Batch mode: process all fold inputs together, chunking by unique
|
||||
# protein sequences inside DataPipeline.process_batch().
|
||||
batch_size = _BATCH_SIZE.value
|
||||
print(f"\n{'=' * 60}")
|
||||
print(
|
||||
f"BATCH MODE: Processing {len(fold_inputs)} fold inputs in batches of {batch_size}"
|
||||
f"BATCH MODE: Processing {len(fold_inputs)} fold inputs with up to "
|
||||
f"{batch_size} unique protein sequences per MMseqs2 batch"
|
||||
)
|
||||
print(f"{'=' * 60}\n")
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, len(fold_inputs), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(fold_inputs))
|
||||
batch = fold_inputs[batch_start:batch_end]
|
||||
processed_inputs = data_pipeline.process_batch(
|
||||
fold_inputs,
|
||||
max_unique_sequences_per_batch=batch_size,
|
||||
)
|
||||
|
||||
print(
|
||||
f"\n--- Processing batch {batch_start // batch_size + 1}: "
|
||||
f"fold inputs {batch_start + 1} to {batch_end} ---\n"
|
||||
# Write output for each processed input
|
||||
for fold_input in processed_inputs:
|
||||
output_subdir = os.path.join(
|
||||
_OUTPUT_DIR.value, fold_input.sanitised_name()
|
||||
)
|
||||
|
||||
processed_inputs = data_pipeline.process_batch(batch)
|
||||
|
||||
# Write output for each processed input
|
||||
for fold_input in processed_inputs:
|
||||
output_subdir = os.path.join(
|
||||
_OUTPUT_DIR.value, fold_input.sanitised_name()
|
||||
output_path = write_fold_input_json(fold_input, output_subdir)
|
||||
output_paths.append(output_path)
|
||||
if _QUEUE_DIR.value is not None:
|
||||
_write_queue_token(
|
||||
queue_dir=_QUEUE_DIR.value,
|
||||
fold_input=fold_input,
|
||||
data_json_path=output_path,
|
||||
)
|
||||
output_path = write_fold_input_json(fold_input, output_subdir)
|
||||
output_paths.append(output_path)
|
||||
if _QUEUE_DIR.value is not None:
|
||||
_write_queue_token(
|
||||
queue_dir=_QUEUE_DIR.value,
|
||||
fold_input=fold_input,
|
||||
data_json_path=output_path,
|
||||
)
|
||||
print(f"Fold job {fold_input.name} done.\n")
|
||||
print(f"Fold job {fold_input.name} done.\n")
|
||||
else:
|
||||
# Sequential mode: process each fold input individually
|
||||
if not use_batch:
|
||||
|
||||
@@ -71,7 +71,8 @@ usage() {
|
||||
echo " If both given, --gpu_devices takes precedence."
|
||||
echo " --container IMAGE Container image or .sif path"
|
||||
echo " (default: romerolabduke/alphafast:latest)"
|
||||
echo " --batch_size N MSA batch size (default: auto = number of inputs)"
|
||||
echo " --batch_size N Unique protein sequences per MSA batch"
|
||||
echo " (default: auto = number of inputs)"
|
||||
echo " --backend TYPE Force 'docker' or 'singularity' (default: auto-detect)"
|
||||
echo " --temp_dir DIR Directory for MMseqs temporary files."
|
||||
echo " Recommended on HPC: point this to fast local scratch"
|
||||
@@ -166,7 +167,8 @@ if [ -n "$TEMP_DIR" ]; then
|
||||
TEMP_DIR="$(cd "$TEMP_DIR" && pwd)"
|
||||
fi
|
||||
|
||||
# Auto batch size: count input JSON files
|
||||
# Auto batch size heuristic: count input JSON files. The Python pipeline applies
|
||||
# this as the maximum number of unique protein sequences per MMseqs2 batch.
|
||||
if [ -z "$BATCH_SIZE" ]; then
|
||||
BATCH_SIZE=$(find "$INPUT_DIR" -maxdepth 1 -name "*.json" -type f | wc -l | tr -d ' ')
|
||||
if [ "$BATCH_SIZE" -eq 0 ]; then
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
# msa_output_dir - Output directory for MSA JSONs
|
||||
# af_output_dir - Output directory for inference outputs
|
||||
# num_gpus - Number of GPUs (derived from gpu_list by caller)
|
||||
# batch_size - Batch size per GPU for MSA (default: 512)
|
||||
# batch_size - Unique protein sequences per MSA batch per GPU (default: 512)
|
||||
# gpu_list - Comma-separated GPU device indices (e.g. "6,7")
|
||||
# mmseqs_threads - CPU threads per GPU (default: total_cores / num_gpus)
|
||||
|
||||
@@ -32,7 +32,7 @@ usage() {
|
||||
echo " msa_output_dir: output directory for MSA JSONs"
|
||||
echo " af_output_dir: output directory for inference outputs"
|
||||
echo " num_gpus: number of GPUs to use"
|
||||
echo " batch_size: batch size per GPU for MSA (default: partition size)"
|
||||
echo " batch_size: unique protein sequences per MSA batch per GPU (default: 512)"
|
||||
echo " gpu_list: comma-separated GPU indices (default: 0,1,...,N-1)"
|
||||
echo " mmseqs_threads: CPU threads per GPU (default: total_cores / num_gpus)"
|
||||
}
|
||||
|
||||
@@ -1393,23 +1393,33 @@ class DataPipeline:
|
||||
return dataclasses.replace(fold_input, chains=ordered_chains)
|
||||
|
||||
def process_batch(
|
||||
self, fold_inputs: Sequence[folding_input.Input]
|
||||
self,
|
||||
fold_inputs: Sequence[folding_input.Input],
|
||||
max_unique_sequences_per_batch: int | None = None,
|
||||
) -> Sequence[folding_input.Input]:
|
||||
"""Process multiple fold inputs with batched MSA search.
|
||||
|
||||
This method is more efficient than processing individually because:
|
||||
1. Single createdb call for ALL protein sequences
|
||||
2. GPU processes all sequences in parallel
|
||||
1. Single createdb call per unique sequence batch
|
||||
2. GPU processes sequences in parallel within each batch
|
||||
3. Amortizes GPU kernel launch overhead
|
||||
|
||||
Args:
|
||||
fold_inputs: Sequence of fold inputs to process together.
|
||||
max_unique_sequences_per_batch: Maximum number of unique protein
|
||||
sequences per MMseqs2 query batch. If None, all unique protein
|
||||
sequences are searched in a single MMseqs2 batch.
|
||||
|
||||
Returns:
|
||||
Sequence of processed fold inputs with MSA and templates.
|
||||
"""
|
||||
if not fold_inputs:
|
||||
return []
|
||||
if (
|
||||
max_unique_sequences_per_batch is not None
|
||||
and max_unique_sequences_per_batch < 1
|
||||
):
|
||||
raise ValueError("max_unique_sequences_per_batch must be at least 1")
|
||||
|
||||
# Check if using MMseqs2 (required for batch mode)
|
||||
is_mmseqs = isinstance(
|
||||
@@ -1548,8 +1558,35 @@ class DataPipeline:
|
||||
temp_dir=self._temp_dir,
|
||||
)
|
||||
|
||||
logging.info("Running batch MSA search across all databases...")
|
||||
msa_results = batch_searcher.search_all_databases_pipelined(all_sequences)
|
||||
sequence_items = list(all_sequences.items())
|
||||
if max_unique_sequences_per_batch is None:
|
||||
sequence_batches = [all_sequences]
|
||||
else:
|
||||
sequence_batches = [
|
||||
dict(sequence_items[i : i + max_unique_sequences_per_batch])
|
||||
for i in range(
|
||||
0, len(sequence_items), max_unique_sequences_per_batch
|
||||
)
|
||||
]
|
||||
|
||||
logging.info(
|
||||
"Running batch MSA search across all databases in %d "
|
||||
"unique-sequence batch(es)...",
|
||||
len(sequence_batches),
|
||||
)
|
||||
msa_results_by_db = {}
|
||||
for batch_idx, sequence_batch in enumerate(sequence_batches, start=1):
|
||||
logging.info(
|
||||
"Running MMseqs2 unique-sequence batch %d/%d (%d sequences)...",
|
||||
batch_idx,
|
||||
len(sequence_batches),
|
||||
len(sequence_batch),
|
||||
)
|
||||
batch_msa_results = batch_searcher.search_all_databases_pipelined(
|
||||
sequence_batch
|
||||
)
|
||||
for db_name, db_result in batch_msa_results.items():
|
||||
msa_results_by_db.setdefault(db_name, {}).update(db_result.results)
|
||||
logging.info(
|
||||
"Batch MSA search completed in %.2f seconds",
|
||||
time.time() - batch_start_time,
|
||||
@@ -1561,20 +1598,24 @@ class DataPipeline:
|
||||
|
||||
for seq_id, sequence in all_sequences.items():
|
||||
# Get results for each database
|
||||
uniref90_result = msa_results.get("uniref90")
|
||||
mgnify_result = msa_results.get("mgnify")
|
||||
small_bfd_result = msa_results.get("small_bfd")
|
||||
uniprot_result = msa_results.get("uniprot")
|
||||
uniref90_results = msa_results_by_db.get("uniref90", {})
|
||||
mgnify_results = msa_results_by_db.get("mgnify", {})
|
||||
small_bfd_results = msa_results_by_db.get("small_bfd", {})
|
||||
uniprot_results = msa_results_by_db.get("uniprot", {})
|
||||
|
||||
# Parse A3M to Msa objects
|
||||
uniref90_a3m = (
|
||||
uniref90_result.results[seq_id].a3m if uniref90_result else ""
|
||||
uniref90_results[seq_id].a3m if seq_id in uniref90_results else ""
|
||||
)
|
||||
mgnify_a3m = (
|
||||
mgnify_results[seq_id].a3m if seq_id in mgnify_results else ""
|
||||
)
|
||||
mgnify_a3m = mgnify_result.results[seq_id].a3m if mgnify_result else ""
|
||||
small_bfd_a3m = (
|
||||
small_bfd_result.results[seq_id].a3m if small_bfd_result else ""
|
||||
small_bfd_results[seq_id].a3m if seq_id in small_bfd_results else ""
|
||||
)
|
||||
uniprot_a3m = (
|
||||
uniprot_results[seq_id].a3m if seq_id in uniprot_results else ""
|
||||
)
|
||||
uniprot_a3m = uniprot_result.results[seq_id].a3m if uniprot_result else ""
|
||||
|
||||
uniref90_msa = msa.Msa.from_a3m(
|
||||
query_sequence=sequence,
|
||||
|
||||
Reference in New Issue
Block a user