From 529b756796001f54f446cebb1c2495fab2119c45 Mon Sep 17 00:00:00 2001 From: Leonardo Marino-Ramirez Date: Fri, 24 Apr 2026 11:41:07 -0500 Subject: [PATCH] feat: add inference.empty_cache_per_design flag to reduce CUDA allocator fragmentation (#451) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Problem When running RFdiffusion with variable-length contigs (e.g. `contigmap.contigs=[A1-469/0 1-50]`) over hundreds or thousands of designs, per-worker VRAM grows steadily from ~7 GB to 10–13 GB per process. This limits how many workers can run in parallel on a single GPU before exhausting VRAM. Root cause: PyTorch's CUDA caching allocator accumulates fragmented memory blocks across designs. With variable-length contigs each design allocates differently-sized tensors; freed blocks are cached but cannot be reused for different-sized allocations, causing steady VRAM growth. ## Fix Add an optional `inference.empty_cache_per_design` flag (default `False`, opt-in) that calls `torch.cuda.empty_cache()` at the end of each design iteration. This releases all unused cached CUDA memory blocks back to the CUDA memory manager, keeping each worker near its initial VRAM footprint for the full run. ### Changes **`config/inference/base.yaml`** ```yaml write_trajectory: True empty_cache_per_design: False # NEW ``` **`scripts/run_inference.py`** — after the trajectory/PDB write block, before `log.info`: ```python if conf.inference.empty_cache_per_design and torch.cuda.is_available(): torch.cuda.empty_cache() log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes") ``` ## Measured impact Tested on NVIDIA RTX 5090 32 GB running a long PPI campaign with variable-length contigs: | Setting | Per-worker VRAM (steady-state) | |---------|-------------------------------| | Without fix | 8–13 GB (grows over run) | | With `empty_cache_per_design=True` | ~5.2 GB (stable) | This allowed raising the number of parallel workers from 3 to 5 on a 32 GB GPU. ## Why opt-in `torch.cuda.empty_cache()` adds a small per-design overhead (~1–2 ms) and is only beneficial for long runs with variable-length contigs. For short runs or fixed-length designs there is no fragmentation issue, so the default remains `False` to preserve existing behavior. ## Testing All 20 applicable tests in `tests/test_diffusion.py` pass with this change. The one skipped test (`design_ppi_scaffolded`) fails due to a missing `ppi_scaffolds/` directory in the test fixture — a pre-existing issue unrelated to this PR. ## Notes - Placement is after both the PDB write (`writepdb`) and the optional trajectory block — every consumer of `denoised_xyz_stack` / `px0_xyz_stack` has already finished before the cache is cleared. - This does not affect memory held by live tensors — only frees cached-but-unused blocks. - Compatible with all existing RFdiffusion design modes (PPI, motif scaffolding, unconditional). --- config/inference/base.yaml | 1 + scripts/run_inference.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/config/inference/base.yaml b/config/inference/base.yaml index 3bb0a5c..686e71b 100644 --- a/config/inference/base.yaml +++ b/config/inference/base.yaml @@ -11,6 +11,7 @@ inference: model_only_neighbors: False output_prefix: samples/design write_trajectory: True + empty_cache_per_design: False scaffold_guided: False model_runner: SelfConditioning cautious: True diff --git a/scripts/run_inference.py b/scripts/run_inference.py index 3fb6466..3ebb5e3 100755 --- a/scripts/run_inference.py +++ b/scripts/run_inference.py @@ -188,6 +188,9 @@ def main(conf: HydraConfig) -> None: chain_ids=sampler.chain_idx, ) + if conf.inference.empty_cache_per_design and torch.cuda.is_available(): + torch.cuda.empty_cache() + log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")