mirror of
https://github.com/RosettaCommons/RFdiffusion.git
synced 2026-06-04 18:44:21 +08:00
## Problem
When running RFdiffusion with variable-length contigs (e.g.
`contigmap.contigs=[A1-469/0 1-50]`) over hundreds or thousands of
designs, per-worker VRAM grows steadily from ~7 GB to 10–13 GB per
process. This limits how many workers can run in parallel on a single
GPU before exhausting VRAM.
Root cause: PyTorch's CUDA caching allocator accumulates fragmented
memory blocks across designs. With variable-length contigs each design
allocates differently-sized tensors; freed blocks are cached but cannot
be reused for different-sized allocations, causing steady VRAM growth.
## Fix
Add an optional `inference.empty_cache_per_design` flag (default
`False`, opt-in) that calls `torch.cuda.empty_cache()` at the end of
each design iteration. This releases all unused cached CUDA memory
blocks back to the CUDA memory manager, keeping each worker near its
initial VRAM footprint for the full run.
### Changes
**`config/inference/base.yaml`**
```yaml
write_trajectory: True
empty_cache_per_design: False # NEW
```
**`scripts/run_inference.py`** — after the trajectory/PDB write block,
before `log.info`:
```python
if conf.inference.empty_cache_per_design and torch.cuda.is_available():
    torch.cuda.empty_cache()
log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
```
## Measured impact
Tested on NVIDIA RTX 5090 32 GB running a long PPI campaign with
variable-length contigs:
| Setting | Per-worker VRAM (steady-state) |
|---------|-------------------------------|
| Without fix | 8–13 GB (grows over run) |
| With `empty_cache_per_design=True` | ~5.2 GB (stable) |
This allowed raising the number of parallel workers from 3 to 5 on a 32
GB GPU.
## Why opt-in
`torch.cuda.empty_cache()` adds a small per-design overhead (~1–2 ms)
and is only beneficial for long runs with variable-length contigs. For
short runs or fixed-length designs there is no fragmentation issue, so
the default remains `False` to preserve existing behavior.
## Testing
All 20 applicable tests in `tests/test_diffusion.py` pass with this
change. One test (`design_ppi_scaffolded`) was skipped because it fails
on a missing `ppi_scaffolds/` directory in the test fixture — a
pre-existing issue unrelated to this PR.
## Notes
- Placement is after both the PDB write (`writepdb`) and the optional
trajectory block — every consumer of `denoised_xyz_stack` /
`px0_xyz_stack` has already finished before the cache is cleared.
- This does not affect memory held by live tensors — only frees
cached-but-unused blocks.
- Compatible with all existing RFdiffusion design modes (PPI, motif
scaffolding, unconditional).
199 lines
6.3 KiB
Python
Executable File
199 lines
6.3 KiB
Python
Executable File
#!/usr/bin/env python
"""
Inference script.

To run with base.yaml as the config,

> python run_inference.py

To specify a different config,

> python run_inference.py --config-name symmetry

where symmetry can be the filename of any other config (without .yaml extension)
See https://hydra.cc/docs/advanced/hydra-command-line-flags/ for more options.

"""
|
|
|
|
import re
|
|
import os, time, pickle
|
|
import torch
|
|
from omegaconf import OmegaConf
|
|
import hydra
|
|
import logging
|
|
from rfdiffusion.util import writepdb_multi, writepdb
|
|
from rfdiffusion.inference import utils as iu
|
|
from hydra.core.hydra_config import HydraConfig
|
|
import numpy as np
|
|
import random
|
|
import glob
|
|
|
|
|
|
def make_deterministic(seed=0):
    """Seed the torch, NumPy and stdlib ``random`` generators with ``seed``.

    Calling this twice with the same ``seed`` makes subsequent draws from
    those three generators repeat.
    """
    # Same order as before: torch first, then NumPy, then the stdlib RNG.
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)
|
|
|
|
|
|
@hydra.main(version_base=None, config_path="../config/inference", config_name="base")
def main(conf: HydraConfig) -> None:
    """Sample designs with the configured sampler and write them to disk.

    For every design index the reverse-diffusion loop is run, then
    ``<output_prefix>_<i>.pdb`` and a pickled ``<output_prefix>_<i>.trb``
    metadata file are written. When ``write_trajectory`` is set, two
    trajectory PDBs are additionally written (via ``writepdb_multi``) into a
    ``traj/`` directory next to the design.

    Args:
        conf: Hydra-composed configuration. This function reads
            ``conf.inference.deterministic`` and
            ``conf.inference.empty_cache_per_design`` directly and passes the
            whole object to ``iu.sampler_selector``.
    """
    log = logging.getLogger(__name__)
    if conf.inference.deterministic:
        make_deterministic()

    # Check for available GPU and print result of check
    if torch.cuda.is_available():
        device_name = torch.cuda.get_device_name(torch.cuda.current_device())
        log.info(f"Found GPU with device_name {device_name}. Will run RFdiffusion on {device_name}")
    else:
        log.info("////////////////////////////////////////////////")
        log.info("///// NO GPU DETECTED! Falling back to CPU /////")
        log.info("////////////////////////////////////////////////")

    # Initialize sampler and target/contig.
    sampler = iu.sampler_selector(conf)

    # Loop over number of designs to sample.
    design_startnum = sampler.inf_conf.design_startnum
    if sampler.inf_conf.design_startnum == -1:
        # -1 means "continue numbering": start one past the highest index
        # found among existing "<output_prefix>*_<N>.pdb" files.
        index_pattern = re.compile(r".*_(\d+)\.pdb$")
        indices = [-1]
        for e in glob.glob(sampler.inf_conf.output_prefix + "*.pdb"):
            m = index_pattern.match(e)
            log.debug("Existing output %s -> index match %s", e, m)
            if not m:
                continue
            indices.append(int(m.group(1)))
        design_startnum = max(indices) + 1

    for i_des in range(design_startnum, design_startnum + sampler.inf_conf.num_designs):
        if conf.inference.deterministic:
            # Re-seed per design so each index is reproducible on its own.
            make_deterministic(i_des)

        start_time = time.time()
        out_prefix = f"{sampler.inf_conf.output_prefix}_{i_des}"
        log.info(f"Making design {out_prefix}")
        if sampler.inf_conf.cautious and os.path.exists(out_prefix + ".pdb"):
            log.info(
                f"(cautious mode) Skipping this design because {out_prefix}.pdb already exists."
            )
            continue

        x_init, seq_init = sampler.sample_init()
        denoised_xyz_stack = []
        px0_xyz_stack = []
        plddt_stack = []

        x_t = torch.clone(x_init)
        seq_t = torch.clone(seq_init)
        # Loop over number of reverse diffusion time steps.
        for t in range(int(sampler.t_step_input), sampler.inf_conf.final_step - 1, -1):
            px0, x_t, seq_t, plddt = sampler.sample_step(
                t=t, x_t=x_t, seq_init=seq_t, final_step=sampler.inf_conf.final_step
            )
            px0_xyz_stack.append(px0)
            denoised_xyz_stack.append(x_t)
            plddt_stack.append(plddt[0])  # remove singleton leading dimension

        # Flip order for better visualization in pymol
        denoised_xyz_stack = torch.flip(torch.stack(denoised_xyz_stack), [0])
        px0_xyz_stack = torch.flip(torch.stack(px0_xyz_stack), [0])

        # For logging -- don't flip
        plddt_stack = torch.stack(plddt_stack)

        # Save outputs. dirname is "" when output_prefix has no directory
        # component; os.makedirs("") would raise, so only create real paths.
        out_dir = os.path.dirname(out_prefix)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Output glycines, except for motif region
        init_tokens = torch.argmax(seq_init, dim=-1)
        final_seq = torch.where(init_tokens == 21, 7, init_tokens)  # 7 is glycine

        bfacts = torch.ones_like(final_seq.squeeze())
        # make bfact=0 for diffused coordinates
        bfacts[torch.where(init_tokens == 21, True, False)] = 0

        # Final design: first frame of the flipped x_t stack
        out = f"{out_prefix}.pdb"

        # Now don't output sidechains
        writepdb(
            out,
            denoised_xyz_stack[0, :, :4],
            final_seq,
            sampler.binderlen,
            chain_idx=sampler.chain_idx,
            bfacts=bfacts,
            idx_pdb=sampler.idx_pdb,
        )

        # run metadata
        trb = dict(
            config=OmegaConf.to_container(sampler._conf, resolve=True),
            plddt=plddt_stack.cpu().numpy(),
            device=torch.cuda.get_device_name(torch.cuda.current_device())
            if torch.cuda.is_available()
            else "CPU",
            time=time.time() - start_time,
        )
        if hasattr(sampler, "contig_map"):
            for key, value in sampler.contig_map.get_mappings().items():
                trb[key] = value
        with open(f"{out_prefix}.trb", "wb") as f_out:
            pickle.dump(trb, f_out)

        if sampler.inf_conf.write_trajectory:
            # trajectory pdbs; os.path.join keeps the path relative when
            # out_dir is empty instead of resolving to "/traj/..." at the root.
            traj_prefix = os.path.join(out_dir, "traj", os.path.basename(out_prefix))
            os.makedirs(os.path.dirname(traj_prefix), exist_ok=True)

            out = f"{traj_prefix}_Xt-1_traj.pdb"
            writepdb_multi(
                out,
                denoised_xyz_stack,
                bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

            out = f"{traj_prefix}_pX0_traj.pdb"
            writepdb_multi(
                out,
                px0_xyz_stack,
                bfacts,
                final_seq.squeeze(),
                use_hydrogens=False,
                backbone_only=False,
                chain_ids=sampler.chain_idx,
            )

        # Opt-in: hand cached-but-unused CUDA blocks back after all outputs
        # for this design are written, so the allocator cache does not keep
        # growing across designs. Live tensors are unaffected.
        if conf.inference.empty_cache_per_design and torch.cuda.is_available():
            torch.cuda.empty_cache()

        log.info(f"Finished design in {(time.time()-start_time)/60:.2f} minutes")
|
|
|
|
|
|
# Run the Hydra-decorated entry point only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|