add 'boltzgen merge' command

This commit is contained in:
Tim O'Donnell
2025-12-01 17:03:36 -05:00
parent e22556a099
commit 869a449582
2 changed files with 279 additions and 0 deletions

View File

@@ -475,6 +475,34 @@ boltzgen execute [-h] [--no_subprocess] [--steps {design,inverse_folding,design_
- `--no_subprocess` - Run each step in the main process. Will cause issues when devices >1.
- `--steps {design,inverse_folding,design_folding,folding,affinity,analysis,filtering} [{design,inverse_folding,design_folding,folding,affinity,analysis,filtering} ...]` - Run only the specified pipeline steps (default: run all steps)
## `boltzgen merge`
If you produced designs across multiple pipeline runs (e.g. for parallelization), you can merge the finished outputs into one directory and then rerun the fast filtering step on the combined set.
### Example
```bash
boltzgen merge workbench/run_a workbench/run_b workbench/run_c \
--output workbench/merged_run
# Now rerun filtering (with any tweaked parameters you like)
boltzgen run example/vanilla_protein/1g13prot.yaml \
--steps filtering \
--output workbench/merged_run \
--protocol protein-anything \
--budget 60 \
--alpha 0.05
```
### Usage
```bash
boltzgen merge [-h] [--overwrite] --output OUTPUT source [source ...]
```
### Arguments
- `source` (positional) One or more BoltzGen output directories that already contain folded/analyzed results (i.e., the directories you previously passed to `--output` when running the pipeline).
- `--output OUTPUT` Destination directory for the merged data. The command creates (or replaces) the design artifacts inside this folder so that `boltzgen run --steps filtering --output OUTPUT ...` can be executed afterwards.
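- `--overwrite` Allow an existing destination directory to be replaced. The whole destination directory is deleted and recreated before merging, so do not point `--output` at a directory that holds data you want to keep.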
# Training BoltzGen models
Install in dev mode which will install additional packages like `wandb`.
```bash

View File

@@ -37,8 +37,11 @@ import subprocess
import os
import time
import math
import re
import shutil
import sys
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Any, Dict, List, Tuple
import yaml
@@ -484,6 +487,32 @@ def build_check_parser(subparsers) -> argparse.ArgumentParser:
return check_parser
def build_merge_parser(subparsers) -> argparse.ArgumentParser:
    """Register the ``merge`` subcommand and return its parser.

    Args:
        subparsers: The object returned by ``ArgumentParser.add_subparsers()``
            that the new ``merge`` parser is attached to.

    Returns:
        The parser for ``boltzgen merge``. Parsed namespaces expose
        ``sources`` (list of ``Path``), ``output`` (``Path``) and
        ``overwrite`` (``bool``).
    """
    merge_parser = subparsers.add_parser(
        "merge",
        description="Merge multiple BoltzGen output directories so filtering can be rerun on the combined set.",
        help="Combine finished pipeline outputs into a single directory",
    )
    merge_parser.add_argument(
        "sources",
        # Only the displayed name changes: usage/help now read
        # "source [source ...]" while the attribute stays ``args.sources``.
        metavar="source",
        nargs="+",
        type=Path,
        help="Paths to completed BoltzGen output directories (results of 'run' or 'execute')",
    )
    merge_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Destination directory for the merged outputs",
    )
    merge_parser.add_argument(
        "--overwrite",
        action="store_true",
        # The merge command removes the whole destination tree, not just the
        # design subdirectory, so the help text has to say so.
        help=(
            "Replace the destination directory if it already exists; "
            "its current contents are deleted before merging."
        ),
    )
    return merge_parser
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
prog="boltzgen",
@@ -508,6 +537,7 @@ def build_parser() -> argparse.ArgumentParser:
build_execute_parser(subparsers)
build_download_parser(subparsers)
build_check_parser(subparsers)
build_merge_parser(subparsers)
return parser
@@ -1442,6 +1472,225 @@ def parse_size_buckets(value_list):
return size_buckets
def merge_command(args: argparse.Namespace) -> None:
    """Merge several BoltzGen output directories into one destination directory.

    For each of the design directories ``intermediate_designs_inverse_folded``
    and ``intermediate_designs`` found in the sources, the per-design files are
    copied into the destination under new, run-prefixed IDs, and
    ``aggregate_metrics_analyze.csv``, ``ca_coords_sequences.pkl.gz`` and
    ``per_target_metrics_analyze.csv`` are concatenated across sources.

    Args:
        args: Parsed CLI namespace with ``sources`` (iterable of paths),
            ``output`` (destination path) and ``overwrite`` (bool).

    Raises:
        ValueError: If no sources are given, the output path is not a
            directory, the output is non-empty without ``--overwrite``,
            ``--overwrite`` would delete one of the sources, or a metrics CSV
            lacks the ``id`` / ``file_name`` columns.
        FileNotFoundError: If a source directory does not exist or a design
            file listed in a metrics CSV is missing.
    """

    def _make_new_file_name(original_file: str, new_id: str) -> str:
        """Return ``new_id`` followed by every suffix of ``original_file``."""
        return f"{new_id}{''.join(Path(original_file).suffixes)}"

    def _slugify_run_tag(path: Path, index: int) -> str:
        """Lower-case the directory name and collapse non-alphanumerics to '-'."""
        slug = re.sub(r"[^0-9A-Za-z]+", "-", path.name).strip("-").lower()
        return slug or f"run{index}"

    def _assign_run_tags(roots: list[Path]) -> dict[Path, str]:
        """Give every source root a distinct tag.

        Different roots can share a basename (``a/run`` and ``b/run``) or
        slugify to the same string (``run_a`` and ``run-a``). Identical tags
        would yield identical merged IDs and make copied files overwrite each
        other, so later duplicates get a numeric ``-N`` suffix.
        """
        tags: dict[Path, str] = {}
        taken: set[str] = set()
        for index, root in enumerate(roots, start=1):
            base = _slugify_run_tag(root, index)
            tag = base
            attempt = 2
            while tag in taken:
                tag = f"{base}-{attempt}"
                attempt += 1
            taken.add(tag)
            tags[root] = tag
        return tags

    def _copy_path(src: Path, dst: Path, *, required: bool) -> None:
        """Copy ``src`` to ``dst``; a missing ``src`` is an error only if required."""
        if src.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dst)
        elif required:
            raise FileNotFoundError(f"Required file missing during merge: {src}")

    def _copy_design_files(
        src_dir: Path,
        dest_dir: Path,
        original_id: str,
        new_id: str,
        original_file: str,
        new_file: str,
        include_refold: bool,
    ) -> None:
        """Copy one design's structure file plus its optional companion files."""
        _copy_path(src_dir / original_file, dest_dir / new_file, required=True)
        for suffix in (".npz", "_native.cif", "_native.pdb"):
            _copy_path(
                src_dir / f"{original_id}{suffix}",
                dest_dir / f"{new_id}{suffix}",
                required=False,
            )
        if include_refold:
            # NOTE(review): `const` is not defined in this block; presumably a
            # module-level import that provides the refold directory names.
            for refold_dir in (
                const.refold_cif_dirname,
                const.refold_design_cif_dirname,
            ):
                _copy_path(
                    src_dir / refold_dir / original_file,
                    dest_dir / refold_dir / new_file,
                    required=False,
                )

    def _merge_design_dir(
        sources: list[Path],
        run_tags: dict[Path, str],
        dir_name: str,
        dest_dir: Path,
        id_map: dict[tuple[Path, str], str],
    ) -> int:
        """Merge ``dir_name`` from every source into ``dest_dir``.

        ``id_map`` is shared between calls so a design keeps the same new ID
        in every design directory. Returns the number of designs that were
        listed in a metrics CSV and merged.
        """
        metrics_frames: list[pd.DataFrame] = []
        seq_frames: list[pd.DataFrame] = []
        per_target_frames: list[pd.DataFrame] = []
        merged_count = 0

        for root in sources:
            src_dir = root / dir_name
            if not src_dir.exists():
                continue
            run_tag = run_tags[root]
            # (original_id, new_id, original_file, new_file) per design.
            source_mappings: list[tuple[str, str, str, str]] = []

            metrics_path = src_dir / "aggregate_metrics_analyze.csv"
            if metrics_path.exists():
                df = pd.read_csv(metrics_path)
                if not df.empty:
                    # Validate once per frame instead of once per row.
                    if "id" not in df.columns or "file_name" not in df.columns:
                        raise ValueError(
                            "aggregate_metrics_analyze.csv must contain 'id' and 'file_name' columns."
                        )
                    original_ids = [str(value) for value in df["id"]]
                    original_files = [str(value) for value in df["file_name"]]
                    new_ids = [
                        id_map.setdefault((root, original_id), f"{run_tag}_{original_id}")
                        for original_id in original_ids
                    ]
                    new_files = [
                        _make_new_file_name(original_file, new_id)
                        for original_file, new_id in zip(original_files, new_ids)
                    ]
                    # Rewrite only the two identifying columns; all other
                    # columns keep their values and dtypes.
                    renamed = df.copy()
                    renamed["id"] = new_ids
                    renamed["file_name"] = new_files
                    metrics_frames.append(renamed)
                    source_mappings.extend(
                        zip(original_ids, new_ids, original_files, new_files)
                    )
                    merged_count += len(source_mappings)
            else:
                # No metrics file here: fall back to the IDs already recorded
                # for this source and assume files are named "<id>.cif".
                for (src, original_id), new_id in list(id_map.items()):
                    if src != root:
                        continue
                    original_file = f"{original_id}.cif"
                    source_mappings.append(
                        (
                            original_id,
                            new_id,
                            original_file,
                            _make_new_file_name(original_file, new_id),
                        )
                    )

            if not source_mappings:
                continue

            dest_dir.mkdir(parents=True, exist_ok=True)

            seq_path = src_dir / "ca_coords_sequences.pkl.gz"
            if seq_path.exists():
                # NOTE: unpickling executes arbitrary code; only merge output
                # directories that come from a trusted pipeline run.
                seq_df = pd.read_pickle(seq_path)
                id_lookup = {orig: new for orig, new, _, _ in source_mappings}
                seq_ids = seq_df["id"].astype(str)
                seq_subset = seq_df[seq_ids.isin(list(id_lookup))].copy()
                if not seq_subset.empty:
                    seq_subset["id"] = seq_subset["id"].astype(str).map(id_lookup)
                    seq_frames.append(seq_subset)

            per_target_path = src_dir / "per_target_metrics_analyze.csv"
            if per_target_path.exists():
                per_target_frames.append(pd.read_csv(per_target_path))

            for original_id, new_id, original_file, new_file in source_mappings:
                _copy_design_files(
                    src_dir=src_dir,
                    dest_dir=dest_dir,
                    original_id=original_id,
                    new_id=new_id,
                    original_file=original_file,
                    new_file=new_file,
                    include_refold=True,
                )

        if metrics_frames:
            pd.concat(metrics_frames, ignore_index=True).to_csv(
                dest_dir / "aggregate_metrics_analyze.csv", index=False
            )
        if seq_frames:
            pd.concat(seq_frames, ignore_index=True).to_pickle(
                dest_dir / "ca_coords_sequences.pkl.gz", compression="gzip"
            )
        if per_target_frames:
            pd.concat(per_target_frames, ignore_index=True).to_csv(
                dest_dir / "per_target_metrics_analyze.csv", index=False
            )
        return merged_count

    if not args.sources:
        raise ValueError("Provide at least one source directory to merge.")

    source_roots: list[Path] = []
    for src in args.sources:
        root = Path(src).expanduser().resolve()
        if not root.is_dir():
            raise FileNotFoundError(f"Source directory not found: {root}")
        # The same directory given twice would be merged twice.
        if root not in source_roots:
            source_roots.append(root)

    dest_root = Path(args.output).expanduser().resolve()
    if dest_root.exists():
        if not dest_root.is_dir():
            raise ValueError(f"Output path exists and is not a directory: {dest_root}")
        if args.overwrite:
            # Removing the destination must never remove the inputs.
            for root in source_roots:
                if root == dest_root or dest_root in root.parents:
                    raise ValueError(
                        f"Refusing to overwrite {dest_root}: it is or contains "
                        f"the source directory {root}."
                    )
            shutil.rmtree(dest_root)
            dest_root.mkdir(parents=True)
        elif any(dest_root.iterdir()):
            raise ValueError(
                f"Output directory {dest_root} already exists. "
                "Use --overwrite to reuse it."
            )
    else:
        dest_root.mkdir(parents=True)

    run_tags = _assign_run_tags(source_roots)
    id_map: dict[tuple[Path, str], str] = {}
    total_designs = 0

    # The inverse-folded directory goes first, so IDs recorded from its
    # metrics file are already in ``id_map`` when "intermediate_designs" is
    # handled (that directory may lack a metrics file of its own).
    for dir_name in (
        "intermediate_designs_inverse_folded",
        "intermediate_designs",
    ):
        dest_dir = dest_root / dir_name
        merged = _merge_design_dir(
            sources=source_roots,
            run_tags=run_tags,
            dir_name=dir_name,
            dest_dir=dest_dir,
            id_map=id_map,
        )
        if merged:
            total_designs += merged
            print(f"- merged {merged} designs into {dest_dir}")

    if total_designs == 0:
        print("No designs found to merge.")
    else:
        print("===============================================")
        print(f"Merged {len(source_roots)} source(s) into {dest_root}")
        print(f"Total designs available for filtering: {total_designs}")
def main() -> None:
parser = build_parser()
args = parser.parse_args()
@@ -1456,6 +1705,8 @@ def main() -> None:
download_command(args)
elif args.command == "check":
check_command(args)
elif args.command == "merge":
merge_command(args)
else:
parser.print_help()