RFdiffusion/config/inference/base.yaml
Leonardo Marino-Ramirez 529b756796 feat: add inference.empty_cache_per_design flag to reduce CUDA allocator fragmentation (#451)
## Problem

When running RFdiffusion with variable-length contigs (e.g.
`contigmap.contigs=[A1-469/0 1-50]`) over hundreds or thousands of
designs, VRAM grows steadily from ~7 GB to 10–13 GB per worker
process. This limits how many workers can run in parallel on a single
GPU before exhausting VRAM.

Root cause: PyTorch's CUDA caching allocator accumulates fragmented
memory blocks across designs. With variable-length contigs each design
allocates differently sized tensors; freed blocks stay cached, but a
cached block that is too small cannot serve a larger request, so the
allocator reserves new memory and the reserved total keeps growing.
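
The growth pattern described above can be illustrated with a toy model. This is only a sketch: `ToyCachingAllocator` and `peak_reserved_per_design` are hypothetical names invented for illustration, the two sizes per design (`length` and `length * length`) are stand-ins for 1D and pair-like tensors, and PyTorch's real allocator is considerably more sophisticated (it splits and rounds blocks, among other things).

```python
class ToyCachingAllocator:
    """Toy model, NOT PyTorch's allocator: freed blocks stay cached and are
    reused only when large enough; otherwise new memory is reserved."""

    def __init__(self):
        self.free_blocks = []  # sizes of cached, currently unused blocks
        self.live_blocks = []  # sizes of blocks backing live "tensors"

    def alloc(self, size):
        fits = [b for b in self.free_blocks if b >= size]
        if fits:
            block = min(fits)  # best fit among cached blocks
            self.free_blocks.remove(block)
        else:
            block = size  # nothing cached fits: reserve new memory
        self.live_blocks.append(block)
        return block

    def free_all(self):
        # Tensors go out of scope; their blocks return to the cache.
        self.free_blocks.extend(self.live_blocks)
        self.live_blocks.clear()

    def empty_cache(self):
        # Analogue of torch.cuda.empty_cache(): drop unused cached blocks.
        self.free_blocks.clear()

    @property
    def reserved(self):
        return sum(self.free_blocks) + sum(self.live_blocks)


def peak_reserved_per_design(lengths, empty_cache_per_design=False):
    """Return the reserved total at the peak of each simulated design."""
    allocator = ToyCachingAllocator()
    peaks = []
    for length in lengths:
        for size in (length, length * length):
            allocator.alloc(size)
        peaks.append(allocator.reserved)
        allocator.free_all()
        if empty_cache_per_design:
            allocator.empty_cache()
    return peaks


print(peak_reserved_per_design([100, 100, 100]))        # [10100, 10100, 10100]
print(peak_reserved_per_design([100, 110, 120]))        # [10100, 22200, 36600]
print(peak_reserved_per_design([100, 110, 120], True))  # [10100, 12210, 14520]
```

In this toy model, fixed lengths reuse the cache perfectly, varying lengths leave behind blocks that later requests cannot use, and clearing the cache after each design keeps the reserved total close to what a single design needs.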

## Fix

Add an optional `inference.empty_cache_per_design` flag (default
`False`, opt-in) that calls `torch.cuda.empty_cache()` at the end of
each design iteration. This releases all unused cached CUDA memory
blocks back to the CUDA memory manager, keeping each worker near its
initial VRAM footprint for the full run.

### Changes

**`config/inference/base.yaml`**
```yaml
  write_trajectory: True
  empty_cache_per_design: False   # NEW
```

**`scripts/run_inference.py`** — after the trajectory/PDB write block,
before `log.info`:
```python
        if conf.inference.empty_cache_per_design and torch.cuda.is_available():
            torch.cuda.empty_cache()

        log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
```

## Measured impact

Tested on NVIDIA RTX 5090 32 GB running a long PPI campaign with
variable-length contigs:

| Setting | Per-worker VRAM |
|---------|-----------------|
| Without fix | 8–13 GB (grows over run) |
| With `empty_cache_per_design=True` | ~5.2 GB (stable) |

This allowed raising the number of parallel workers from 3 to 5 on a 32
GB GPU.
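
To check the effect on other hardware, one option is to log the allocator's own counters once per design. The sketch below is not part of this PR: `report_cuda_memory` is a hypothetical helper, and it assumes that `torch.cuda.memory_reserved()` is the number that tracks the growth described here. Per-process figures from `nvidia-smi` will typically read higher because they also include the CUDA context.

```python
def report_cuda_memory(tag):
    """Return a one-line summary of allocated vs. reserved CUDA memory.

    Hypothetical helper for illustration. Returns None when PyTorch or a
    CUDA device is unavailable, so it is safe to call on CPU-only machines.
    """
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    gib = 1024 ** 3
    allocated = torch.cuda.memory_allocated() / gib  # held by live tensors
    reserved = torch.cuda.memory_reserved() / gib    # live + cached blocks
    return f"{tag}: allocated={allocated:.2f} GiB reserved={reserved:.2f} GiB"


# Example: call once at the end of each design iteration.
line = report_cuda_memory("design_0")
if line is not None:
    print(line)
```

A widening gap between the reserved and allocated values over many designs would be consistent with the cached-block accumulation this flag targets.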

## Why opt-in

`torch.cuda.empty_cache()` adds a small per-design overhead (~1–2 ms)
and is only beneficial for long runs with variable-length contigs. For
short runs or fixed-length designs there is no fragmentation issue, so
the default remains `False` to preserve existing behavior.

## Testing

All 20 applicable tests in `tests/test_diffusion.py` pass with this
change. One test (`design_ppi_scaffolded`) was skipped because it fails
on a missing `ppi_scaffolds/` directory in the test fixture, a
pre-existing issue unrelated to this PR.

## Notes

- The call is placed after both the PDB write (`writepdb`) and the
  optional trajectory block, so every consumer of `denoised_xyz_stack` /
  `px0_xyz_stack` has finished before the cache is cleared.
- This does not affect memory held by live tensors; it frees only
  cached-but-unused blocks.
- Compatible with all existing RFdiffusion design modes (PPI, motif
scaffolding, unconditional).
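
The point about live tensors can be checked with a small experiment. This is a sketch rather than a test from the PR: `live_tensor_survives_empty_cache` is a hypothetical name, and the function returns `None` when PyTorch or a CUDA device is not available.

```python
def live_tensor_survives_empty_cache():
    """Check that empty_cache() leaves live tensors and their memory intact.

    Hypothetical helper for illustration; returns None without PyTorch/CUDA.
    """
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None

    live = torch.ones(1024, 1024, device="cuda")     # stays referenced
    scratch = torch.ones(1024, 1024, device="cuda")
    del scratch                                       # block returns to the cache

    allocated_before = torch.cuda.memory_allocated()
    torch.cuda.empty_cache()                          # frees only cached, unused blocks
    unchanged = torch.cuda.memory_allocated() == allocated_before

    # The live tensor must still hold its data after the cache is cleared.
    return unchanged and bool(live.sum().item() == 1024 * 1024)


print(live_tensor_survives_empty_cache())
```
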
2026-04-24 10:41:07 -06:00

# Base inference Configuration.

inference:
  input_pdb: null
  num_designs: 10
  design_startnum: 0
  ckpt_override_path: null
  symmetry: null
  recenter: True
  radius: 10.0
  model_only_neighbors: False
  output_prefix: samples/design
  write_trajectory: True
  empty_cache_per_design: False
  scaffold_guided: False
  model_runner: SelfConditioning
  cautious: True
  align_motif: True
  symmetric_self_cond: True
  final_step: 1
  deterministic: False
  trb_save_ckpt_path: null
  schedule_directory_path: null
  model_directory_path: null
  cyclic: False
  cyc_chains: 'a'

contigmap:
  contigs: null
  inpaint_seq: null
  inpaint_str: null
  inpaint_str_helix: null
  inpaint_str_strand: null
  inpaint_str_loop: null
  provide_seq: null
  length: null

model:
  n_extra_block: 4
  n_main_block: 32
  n_ref_block: 4
  d_msa: 256
  d_msa_full: 64
  d_pair: 128
  d_templ: 64
  n_head_msa: 8
  n_head_pair: 4
  n_head_templ: 4
  d_hidden: 32
  d_hidden_templ: 32
  p_drop: 0.15
  SE3_param_full:
    num_layers: 1
    num_channels: 32
    num_degrees: 2
    n_heads: 4
    div: 4
    l0_in_features: 8
    l0_out_features: 8
    l1_in_features: 3
    l1_out_features: 2
    num_edge_features: 32
  SE3_param_topk:
    num_layers: 1
    num_channels: 32
    num_degrees: 2
    n_heads: 4
    div: 4
    l0_in_features: 64
    l0_out_features: 64
    l1_in_features: 3
    l1_out_features: 2
    num_edge_features: 64
  freeze_track_motif: False
  use_motif_timestep: False

diffuser:
  T: 50
  b_0: 1e-2
  b_T: 7e-2
  schedule_type: linear
  so3_type: igso3
  crd_scale: 0.25
  partial_T: null
  so3_schedule_type: linear
  min_b: 1.5
  max_b: 2.5
  min_sigma: 0.02
  max_sigma: 1.5

denoiser:
  noise_scale_ca: 1
  final_noise_scale_ca: 1
  ca_noise_schedule_type: constant
  noise_scale_frame: 1
  final_noise_scale_frame: 1
  frame_noise_schedule_type: constant

ppi:
  hotspot_res: null

potentials:
  guiding_potentials: null
  guide_scale: 10
  guide_decay: constant
  olig_inter_all : null
  olig_intra_all : null
  olig_custom_contact : null
  substrate: null

contig_settings:
  ref_idx: null
  hal_idx: null
  idx_rf: null
  inpaint_seq_tensor: null

preprocess:
  sidechain_input: False
  motif_sidechain_input: True
  d_t1d: 22
  d_t2d: 44
  prob_self_cond: 0.0
  str_self_cond: False
  predict_previous: False

logging:
  inputs: False

scaffoldguided:
  scaffoldguided: False
  target_pdb: False
  target_path: null
  scaffold_list: null
  scaffold_dir: null
  sampled_insertion: 0
  sampled_N: 0
  sampled_C: 0
  ss_mask: 0
  systematic: False
  target_ss: null
  target_adj: null
  mask_loops: True
  contig_crop: null