diff --git a/README.md b/README.md index a91673b..eae8379 100755 --- a/README.md +++ b/README.md @@ -475,6 +475,34 @@ boltzgen execute [-h] [--no_subprocess] [--steps {design,inverse_folding,design_ - `--no_subprocess` - Run each step in the main process. Will cause issues when devices >1. - `--steps {design,inverse_folding,design_folding,folding,affinity,analysis,filtering} [{design,inverse_folding,design_folding,folding,affinity,analysis,filtering} ...]` - Run only the specified pipeline steps (default: run all steps) +## `boltzgen merge` + +If you produced designs across multiple pipeline runs (e.g. for parallelization) you can merge the finished outputs into one directory and then rerun the fast filtering step on the combined set. + +### Example +```bash +boltzgen merge workbench/run_a workbench/run_b workbench/run_c \ + --output workbench/merged_run + +# Now rerun filtering (with any tweaked parameters you like) +boltzgen run example/vanilla_protein/1g13prot.yaml \ + --steps filtering \ + --output workbench/merged_run \ + --protocol protein-anything \ + --budget 60 \ + --alpha 0.05 +``` + +### Usage +```bash +boltzgen merge [-h] [--overwrite] --output OUTPUT source [source ...] +``` + +### Arguments +- `source` (positional) – One or more BoltzGen output directories that already contain folded/analyzed results (i.e., the directories you previously passed to `--output` when running the pipeline). +- `--output OUTPUT` – Destination directory for the merged data. The command creates (or replaces) the design artifacts inside this folder so that `boltzgen run --steps filtering --output OUTPUT ...` can be executed afterwards. +- `--overwrite` – Allow the destination directory (and its design subdirectory) to be replaced if they already exist. + # Training BoltzGen models Install in dev mode which will install additional packages like `wandb`. ```bash diff --git a/src/boltzgen/cli/boltzgen.py b/src/boltzgen/cli/boltzgen.py index 7d4702e..e559f7d 100644 --- a/src/boltzgen/cli/boltzgen.py +++ b/src/boltzgen/cli/boltzgen.py @@ -37,8 +37,11 @@ import subprocess import os import time import math +import re +import shutil import sys import numpy as np +import pandas as pd from pathlib import Path from typing import Any, Dict, List, Tuple import yaml @@ -484,6 +487,32 @@ def build_check_parser(subparsers) -> argparse.ArgumentParser: return check_parser +def build_merge_parser(subparsers) -> argparse.ArgumentParser: + merge_parser = subparsers.add_parser( + "merge", + description="Merge multiple BoltzGen output directories so filtering can be rerun on the combined set.", + help="Combine finished pipeline outputs into a single directory", + ) + merge_parser.add_argument( + "sources", + nargs="+", + type=Path, + help="Paths to completed BoltzGen output directories (results of 'run' or 'execute')", + ) + merge_parser.add_argument( + "--output", + type=Path, + required=True, + help="Destination directory for the merged outputs", + ) + merge_parser.add_argument( + "--overwrite", + action="store_true", + help="Allow reusing an existing destination; the merged design directory will be replaced if it exists.", + ) + return merge_parser + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="boltzgen", @@ -508,6 +537,7 @@ def build_parser() -> argparse.ArgumentParser: build_execute_parser(subparsers) build_download_parser(subparsers) build_check_parser(subparsers) + build_merge_parser(subparsers) return parser @@ -1442,6 +1472,225 @@ def parse_size_buckets(value_list): return size_buckets +def merge_command(args: argparse.Namespace) -> None: + """ + Merge multiple BoltzGen output directories into a single destination directory so + the filtering step can be rerun over the combined set of designs. + """ + + def _merge_design_dir( + sources: list[Path], + run_tags: dict[Path, str], + dir_name: str, + dest_dir: Path, + id_map: dict[tuple[Path, str], str], + ) -> int: + metrics_frames: list[pd.DataFrame] = [] + seq_frames: list[pd.DataFrame] = [] + per_target_frames: list[pd.DataFrame] = [] + merged_count = 0 + + for root in sources: + src_dir = root / dir_name + if not src_dir.exists(): + continue + + run_tag = run_tags[root] + source_mappings: list[tuple[str, str, str, str]] = [] + + metrics_path = src_dir / "aggregate_metrics_analyze.csv" + if metrics_path.exists(): + df = pd.read_csv(metrics_path) + if not df.empty: + updated_rows = [] + for _, row in df.iterrows(): + if "id" not in row or "file_name" not in row: + raise ValueError( + "aggregate_metrics_analyze.csv must contain 'id' and 'file_name' columns." + ) + original_id = str(row["id"]) + original_file = str(row["file_name"]) + key = (root, original_id) + new_id = id_map.setdefault(key, f"{run_tag}_{original_id}") + new_file = _make_new_file_name(original_file, new_id) + updated_rows.append( + {**row, "id": new_id, "file_name": new_file} + ) + source_mappings.append( + (original_id, new_id, original_file, new_file) + ) + metrics_frames.append(pd.DataFrame(updated_rows)) + merged_count += len(source_mappings) + else: + known_ids = [ + (orig, new_id) + for (src, orig), new_id in id_map.items() + if src == root + ] + for original_id, new_id in known_ids: + original_file = f"{original_id}.cif" + new_file = _make_new_file_name(original_file, new_id) + source_mappings.append( + (original_id, new_id, original_file, new_file) + ) + + if not source_mappings: + continue + + dest_dir.mkdir(parents=True, exist_ok=True) + + seq_path = src_dir / "ca_coords_sequences.pkl.gz" + if seq_path.exists(): + seq_df = pd.read_pickle(seq_path) + original_ids = [orig for orig, _, _, _ in source_mappings] + seq_subset = seq_df[seq_df["id"].astype(str).isin(original_ids)].copy() + if not seq_subset.empty: + id_lookup = {orig: new for orig, new, _, _ in source_mappings} + seq_subset["id"] = seq_subset["id"].astype(str).map(id_lookup) + seq_frames.append(seq_subset) + + per_target_path = src_dir / "per_target_metrics_analyze.csv" + if per_target_path.exists(): + per_target_frames.append(pd.read_csv(per_target_path)) + + for original_id, new_id, original_file, new_file in source_mappings: + _copy_design_files( + src_dir=src_dir, + dest_dir=dest_dir, + original_id=original_id, + new_id=new_id, + original_file=original_file, + new_file=new_file, + include_refold=True, + ) + + if metrics_frames: + pd.concat(metrics_frames, ignore_index=True).to_csv( + dest_dir / "aggregate_metrics_analyze.csv", index=False + ) + if seq_frames: + pd.concat(seq_frames, ignore_index=True).to_pickle( + dest_dir / "ca_coords_sequences.pkl.gz", compression="gzip" + ) + if per_target_frames: + pd.concat(per_target_frames, ignore_index=True).to_csv( + dest_dir / "per_target_metrics_analyze.csv", index=False + ) + + return merged_count + + + def _copy_design_files( + src_dir: Path, + dest_dir: Path, + original_id: str, + new_id: str, + original_file: str, + new_file: str, + include_refold: bool, + ) -> None: + _copy_path(src_dir / original_file, dest_dir / new_file, required=True) + _copy_path( + src_dir / f"{original_id}.npz", + dest_dir / f"{new_id}.npz", + required=False, + ) + _copy_path( + src_dir / f"{original_id}_native.cif", + dest_dir / f"{new_id}_native.cif", + required=False, + ) + _copy_path( + src_dir / f"{original_id}_native.pdb", + dest_dir / f"{new_id}_native.pdb", + required=False, + ) + if include_refold: + _copy_path( + src_dir / const.refold_cif_dirname / original_file, + dest_dir / const.refold_cif_dirname / new_file, + required=False, + ) + _copy_path( + src_dir / const.refold_design_cif_dirname / original_file, + dest_dir / const.refold_design_cif_dirname / new_file, + required=False, + ) + + + def _make_new_file_name(original_file: str, new_id: str) -> str: + path = Path(original_file) + suffix = "".join(path.suffixes) + return f"{new_id}{suffix}" if suffix else new_id + + + def _slugify_run_tag(path: Path, index: int) -> str: + slug = re.sub(r"[^0-9A-Za-z]+", "-", path.name).strip("-").lower() + return slug or f"run{index}" + + + def _copy_path(src: Path, dst: Path, *, required: bool) -> None: + if src.exists(): + dst.parent.mkdir(parents=True, exist_ok=True) + shutil.copy2(src, dst) + elif required: + raise FileNotFoundError(f"Required file missing during merge: {src}") + + if not args.sources: + raise ValueError("Provide at least one source directory to merge.") + + source_roots: list[Path] = [] + for src in args.sources: + root = Path(src).expanduser().resolve() + if not root.exists() or not root.is_dir(): + raise FileNotFoundError(f"Source directory not found: {root}") + source_roots.append(root) + + dest_root = args.output.expanduser().resolve() + if dest_root.exists(): + if not dest_root.is_dir(): + raise ValueError(f"Output path exists and is not a directory: {dest_root}") + if args.overwrite: + shutil.rmtree(dest_root) + dest_root.mkdir(parents=True) + elif any(dest_root.iterdir()): + raise ValueError( + f"Output directory {dest_root} already exists. " + "Use --overwrite to reuse it." + ) + else: + dest_root.mkdir(parents=True) + + run_tags = { + root: _slugify_run_tag(root, idx + 1) for idx, root in enumerate(source_roots) + } + id_map: dict[tuple[Path, str], str] = {} + + total_designs = 0 + for dir_name in [ + "intermediate_designs_inverse_folded", + "intermediate_designs", + ]: + dest_dir = dest_root / dir_name + merged = _merge_design_dir( + sources=source_roots, + run_tags=run_tags, + dir_name=dir_name, + dest_dir=dest_dir, + id_map=id_map, + ) + if merged: + total_designs += merged + print(f"- merged {merged} designs into {dest_dir}") + + if total_designs == 0: + print("No designs found to merge.") + else: + print("===============================================") + print(f"Merged {len(source_roots)} source(s) into {dest_root}") + print(f"Total designs available for filtering: {total_designs}") + + def main() -> None: parser = build_parser() args = parser.parse_args() @@ -1456,6 +1705,8 @@ def main() -> None: download_command(args) elif args.command == "check": check_command(args) + elif args.command == "merge": + merge_command(args) else: parser.print_help()