mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
add 'boltzgen merge' command
This commit is contained in:
28
README.md
28
README.md
@@ -475,6 +475,34 @@ boltzgen execute [-h] [--no_subprocess] [--steps {design,inverse_folding,design_
|
||||
- `--no_subprocess` - Run each step in the main process. Will cause issues when devices >1.
|
||||
- `--steps {design,inverse_folding,design_folding,folding,affinity,analysis,filtering} [{design,inverse_folding,design_folding,folding,affinity,analysis,filtering} ...]` - Run only the specified pipeline steps (default: run all steps)
|
||||
|
||||
## `boltzgen merge`
|
||||
|
||||
If you produced designs across multiple pipeline runs (e.g. to parallelize generation), you can merge the finished outputs into one directory and then rerun the fast filtering step on the combined set.
|
||||
|
||||
### Example
|
||||
```bash
|
||||
boltzgen merge workbench/run_a workbench/run_b workbench/run_c \
|
||||
--output workbench/merged_run
|
||||
|
||||
# Now rerun filtering (with any tweaked parameters you like)
|
||||
boltzgen run example/vanilla_protein/1g13prot.yaml \
|
||||
--steps filtering \
|
||||
--output workbench/merged_run \
|
||||
--protocol protein-anything \
|
||||
--budget 60 \
|
||||
--alpha 0.05
|
||||
```
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
boltzgen merge [-h] [--overwrite] --output OUTPUT source [source ...]
|
||||
```
|
||||
|
||||
### Arguments
|
||||
- `source` (positional) – One or more BoltzGen output directories that already contain folded/analyzed results (i.e., the directories you previously passed to `--output` when running the pipeline).
|
||||
- `--output OUTPUT` – Destination directory for the merged data. The command creates (or replaces) the design artifacts inside this folder so that `boltzgen run --steps filtering --output OUTPUT ...` can be executed afterwards.
|
||||
- `--overwrite` – Delete the destination directory and recreate it if it already exists. Without this flag, the command fails when the destination exists and is not empty.
|
||||
|
||||
# Training BoltzGen models
|
||||
Install in dev mode which will install additional packages like `wandb`.
|
||||
```bash
|
||||
|
||||
@@ -37,8 +37,11 @@ import subprocess
|
||||
import os
|
||||
import time
|
||||
import math
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple
|
||||
import yaml
|
||||
@@ -484,6 +487,32 @@ def build_check_parser(subparsers) -> argparse.ArgumentParser:
|
||||
return check_parser
|
||||
|
||||
|
||||
def build_merge_parser(subparsers) -> argparse.ArgumentParser:
    """Register the ``merge`` subcommand and return its parser.

    Args:
        subparsers: The object returned by ``ArgumentParser.add_subparsers()``
            that the ``merge`` command is attached to.

    Returns:
        The parser of the ``merge`` subcommand. Parsed namespaces expose
        ``sources`` (list of ``Path``), ``output`` (``Path``) and
        ``overwrite`` (``bool``).
    """
    merge_parser = subparsers.add_parser(
        "merge",
        description="Merge multiple BoltzGen output directories so filtering can be rerun on the combined set.",
        help="Combine finished pipeline outputs into a single directory",
    )
    merge_parser.add_argument(
        "sources",
        nargs="+",
        type=Path,
        # Shown as `source [source ...]` in the usage line, matching the README;
        # the values are still stored on `args.sources`.
        metavar="source",
        help="Paths to completed BoltzGen output directories (results of 'run' or 'execute')",
    )
    merge_parser.add_argument(
        "--output",
        type=Path,
        required=True,
        help="Destination directory for the merged outputs",
    )
    merge_parser.add_argument(
        "--overwrite",
        action="store_true",
        # The merge command removes the whole destination tree, not only the
        # design sub-directories, so the help text says so explicitly.
        help="Delete the destination directory and recreate it if it already exists.",
    )
    return merge_parser
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="boltzgen",
|
||||
@@ -508,6 +537,7 @@ def build_parser() -> argparse.ArgumentParser:
|
||||
build_execute_parser(subparsers)
|
||||
build_download_parser(subparsers)
|
||||
build_check_parser(subparsers)
|
||||
build_merge_parser(subparsers)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -1442,6 +1472,225 @@ def parse_size_buckets(value_list):
|
||||
return size_buckets
|
||||
|
||||
|
||||
def merge_command(args: argparse.Namespace) -> None:
    """Merge multiple BoltzGen output directories into one destination directory.

    Every design is renamed to ``<run_tag>_<original_id>`` so that designs coming
    from different runs cannot clash. The per-directory metrics tables are
    rewritten with the new ids and concatenated, so the filtering step can be
    rerun over the combined set of designs.

    Args:
        args: Parsed CLI namespace providing ``sources`` (paths of finished
            output directories), ``output`` (destination directory) and
            ``overwrite`` (delete an already existing destination first).

    Raises:
        ValueError: No sources were given; the output path is not a directory,
            is non-empty without ``overwrite``, or is a source directory or a
            parent of one; a non-empty metrics CSV lacks the ``id`` or
            ``file_name`` column.
        FileNotFoundError: A source directory does not exist, or a structure
            file selected for copying is missing.
    """

    def _merge_design_dir(
        sources: list[Path],
        run_tags: dict[Path, str],
        dir_name: str,
        dest_dir: Path,
        id_map: dict[tuple[Path, str], str],
    ) -> int:
        """Merge ``<source>/<dir_name>`` of every source into ``dest_dir``.

        ``id_map`` maps ``(source root, original id)`` to the merged id; it is
        filled here and shared between calls so a later directory reuses the
        ids assigned by an earlier one. Returns the number of metrics rows merged.
        """
        metrics_frames: list[pd.DataFrame] = []
        seq_frames: list[pd.DataFrame] = []
        per_target_frames: list[pd.DataFrame] = []
        merged_count = 0

        for root in sources:
            src_dir = root / dir_name
            if not src_dir.exists():
                continue

            run_tag = run_tags[root]
            # (original_id, new_id, original_file, new_file) per design.
            source_mappings: list[tuple[str, str, str, str]] = []

            metrics_path = src_dir / "aggregate_metrics_analyze.csv"
            if metrics_path.exists():
                df = pd.read_csv(metrics_path)
                if not df.empty:
                    # Validate once per table instead of once per row.
                    if "id" not in df.columns or "file_name" not in df.columns:
                        raise ValueError(
                            "aggregate_metrics_analyze.csv must contain 'id' and 'file_name' columns."
                        )
                    original_ids = df["id"].astype(str).tolist()
                    original_files = df["file_name"].astype(str).tolist()
                    new_ids = [
                        id_map.setdefault((root, original_id), f"{run_tag}_{original_id}")
                        for original_id in original_ids
                    ]
                    new_files = [
                        _make_new_file_name(original_file, new_id)
                        for original_file, new_id in zip(original_files, new_ids)
                    ]
                    # Column-wise assignment keeps the dtypes of all other
                    # columns, which a row-by-row rebuild does not guarantee.
                    df["id"] = new_ids
                    df["file_name"] = new_files
                    metrics_frames.append(df)
                    source_mappings.extend(
                        zip(original_ids, new_ids, original_files, new_files)
                    )
                    merged_count += len(original_ids)
            else:
                # No metrics table in this directory: fall back to the ids that
                # were already assigned for this source by an earlier directory.
                known_ids = [
                    (orig, new_id)
                    for (src, orig), new_id in id_map.items()
                    if src == root
                ]
                for original_id, new_id in known_ids:
                    original_file = f"{original_id}.cif"
                    new_file = _make_new_file_name(original_file, new_id)
                    source_mappings.append(
                        (original_id, new_id, original_file, new_file)
                    )

            if not source_mappings:
                continue

            dest_dir.mkdir(parents=True, exist_ok=True)

            seq_path = src_dir / "ca_coords_sequences.pkl.gz"
            if seq_path.exists():
                # NOTE(security): unpickling executes arbitrary code; only merge
                # source directories whose origin you trust.
                seq_df = pd.read_pickle(seq_path)
                id_lookup = {orig: new for orig, new, _, _ in source_mappings}
                seq_ids = seq_df["id"].astype(str)
                seq_subset = seq_df[seq_ids.isin(id_lookup.keys())].copy()
                if not seq_subset.empty:
                    seq_subset["id"] = seq_subset["id"].astype(str).map(id_lookup)
                    seq_frames.append(seq_subset)

            per_target_path = src_dir / "per_target_metrics_analyze.csv"
            if per_target_path.exists():
                per_target_frames.append(pd.read_csv(per_target_path))

            for original_id, new_id, original_file, new_file in source_mappings:
                _copy_design_files(
                    src_dir=src_dir,
                    dest_dir=dest_dir,
                    original_id=original_id,
                    new_id=new_id,
                    original_file=original_file,
                    new_file=new_file,
                    include_refold=True,
                )

        if metrics_frames:
            pd.concat(metrics_frames, ignore_index=True).to_csv(
                dest_dir / "aggregate_metrics_analyze.csv", index=False
            )
        if seq_frames:
            pd.concat(seq_frames, ignore_index=True).to_pickle(
                dest_dir / "ca_coords_sequences.pkl.gz", compression="gzip"
            )
        if per_target_frames:
            pd.concat(per_target_frames, ignore_index=True).to_csv(
                dest_dir / "per_target_metrics_analyze.csv", index=False
            )

        return merged_count

    def _copy_design_files(
        src_dir: Path,
        dest_dir: Path,
        original_id: str,
        new_id: str,
        original_file: str,
        new_file: str,
        include_refold: bool,
    ) -> None:
        """Copy one design under its new name.

        The structure file itself is required; the ``.npz``, ``_native.cif``,
        ``_native.pdb`` companions and the refold copies are taken along only
        when they exist.
        """
        _copy_path(src_dir / original_file, dest_dir / new_file, required=True)
        for suffix in (".npz", "_native.cif", "_native.pdb"):
            _copy_path(
                src_dir / f"{original_id}{suffix}",
                dest_dir / f"{new_id}{suffix}",
                required=False,
            )
        if include_refold:
            for sub_dir in (const.refold_cif_dirname, const.refold_design_cif_dirname):
                _copy_path(
                    src_dir / sub_dir / original_file,
                    dest_dir / sub_dir / new_file,
                    required=False,
                )

    def _make_new_file_name(original_file: str, new_id: str) -> str:
        """Return ``new_id`` followed by all suffixes of ``original_file``."""
        suffix = "".join(Path(original_file).suffixes)
        return f"{new_id}{suffix}" if suffix else new_id

    def _slugify_run_tag(path: Path, index: int) -> str:
        """Build a lowercase alphanumeric/hyphen tag from the directory name."""
        slug = re.sub(r"[^0-9A-Za-z]+", "-", path.name).strip("-").lower()
        return slug or f"run{index}"

    def _unique_run_tags(roots: list[Path]) -> dict[Path, str]:
        """Assign every source a distinct tag.

        Two sources may share a directory name (``a/run`` and ``b/run``); with
        identical tags their designs would receive identical merged ids and
        overwrite each other, so clashing tags get a numeric suffix.
        """
        tags: dict[Path, str] = {}
        taken: set[str] = set()
        for index, root in enumerate(roots, start=1):
            base = _slugify_run_tag(root, index)
            tag = base
            counter = index
            while tag in taken:
                tag = f"{base}-{counter}"
                counter += 1
            taken.add(tag)
            tags[root] = tag
        return tags

    def _copy_path(src: Path, dst: Path, *, required: bool) -> None:
        """Copy ``src`` to ``dst`` (creating parents); raise only if required and missing."""
        if src.exists():
            dst.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(src, dst)
        elif required:
            raise FileNotFoundError(f"Required file missing during merge: {src}")

    if not args.sources:
        raise ValueError("Provide at least one source directory to merge.")

    source_roots: list[Path] = []
    for src in args.sources:
        root = Path(src).expanduser().resolve()
        if not root.is_dir():
            raise FileNotFoundError(f"Source directory not found: {root}")
        # A source listed twice would otherwise be merged twice.
        if root not in source_roots:
            source_roots.append(root)

    dest_root = Path(args.output).expanduser().resolve()
    # Refuse destinations whose removal (with --overwrite) would delete an input.
    for root in source_roots:
        if dest_root == root or dest_root in root.parents:
            raise ValueError(
                f"Output directory {dest_root} must not be a source directory "
                f"or contain one ({root})."
            )

    if dest_root.exists():
        if not dest_root.is_dir():
            raise ValueError(f"Output path exists and is not a directory: {dest_root}")
        if args.overwrite:
            shutil.rmtree(dest_root)
            dest_root.mkdir(parents=True)
        elif any(dest_root.iterdir()):
            raise ValueError(
                f"Output directory {dest_root} already exists. "
                "Use --overwrite to reuse it."
            )
    else:
        dest_root.mkdir(parents=True)

    run_tags = _unique_run_tags(source_roots)
    id_map: dict[tuple[Path, str], str] = {}

    total_designs = 0
    # The inverse-folded directory goes first so the ids it assigns are known
    # when a later directory without a metrics table is processed.
    for dir_name in [
        "intermediate_designs_inverse_folded",
        "intermediate_designs",
    ]:
        dest_dir = dest_root / dir_name
        merged = _merge_design_dir(
            sources=source_roots,
            run_tags=run_tags,
            dir_name=dir_name,
            dest_dir=dest_dir,
            id_map=id_map,
        )
        if merged:
            total_designs += merged
            print(f"- merged {merged} designs into {dest_dir}")

    if total_designs == 0:
        print("No designs found to merge.")
    else:
        print("===============================================")
        print(f"Merged {len(source_roots)} source(s) into {dest_root}")
        print(f"Total designs available for filtering: {total_designs}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args()
|
||||
@@ -1456,6 +1705,8 @@ def main() -> None:
|
||||
download_command(args)
|
||||
elif args.command == "check":
|
||||
check_command(args)
|
||||
elif args.command == "merge":
|
||||
merge_command(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user