Refactor/checkpoint defaults (#723)

* Registry rework

* Update README 12

* Mc

* Add all option correction
This commit is contained in:
Jasper Butcher
2025-12-02 23:43:51 -08:00
committed by Rohith Krishna
parent d07d003ae2
commit 94bb987998
11 changed files with 165 additions and 93 deletions

4
.env
View File

@@ -4,6 +4,8 @@
# - you will need to set the paths below as appropriate for your environment.
# We provide examples in the comments to give you an idea of the expected format.
# Foundry install dir for checkpoints
FOUNDRY_CHECKPOINTS_DIR=
# --- Mirrors to RCSB data ---
@@ -58,5 +60,3 @@ COLABFOLD_LOCAL_DB_PATH_CPU=
# Network access (fallback; may cause IO-related issues)
COLABFOLD_NET_DB_PATH_GPU=
COLABFOLD_NET_DB_PATH_CPU=

5
.gitignore vendored
View File

@@ -1,3 +1,8 @@
# For docs / outputs from example notebooks;
examples/*.cif
**.ckpt
**.pt
# Base .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore
*.lock
**.nfs**

View File

@@ -19,22 +19,23 @@ This will download all the models supported (including multiple checkpoints of r
foundry install rfd3 ligandmpnn rf3 --checkpoint_dir <path/to/ckpt/dir>
```
**Running a basic example of everything** See `examples/all.ipynb` for how to run each model in a notebook.
We include details on DNA, Ligands, Protein-Protein Interaction, Symmetry-conditioned systems and enzymes.
>*See `examples/all.ipynb` for how to run each model in a notebook.*
### RFdiffusion3
[RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints.
> See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.
> *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.*
<div align="center">
<img src="docs/_static/rfd3_trajectory.png" alt="RFdiffusion3 generation trajectory." width="400">
</div>
### ProteinMPNN
[ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be used to design diverse sequences for backbones under constrained conditions.
> *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.*
### RoseTTAFold3
[RF3](https://doi.org/10.1101/2025.08.14.670328) is a structure prediction neural network that narrows the gap between closed-source AF-3 and open-source alternatives.

View File

@@ -19,6 +19,8 @@
"\n",
"All models are unified through [AtomWorks](https://github.com/RosettaCommons/atomworks) (for both inference and training), relying on Biotite `AtomArray` objects.\n",
"\n",
"Note: This notebook assumes you have the base checkpoints downloaded: `foundry install rfd3 ligandmpnn rf3`. You can also specify the paths directly yourself if you wish.\n",
"\n",
"### Pipeline Flow\n",
"```\n",
"RFD3 (backbone) → MPNN (sequence) → RF3 (validation) → RMSD comparison\n",
@@ -29,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 13,
"id": "819e8193",
"metadata": {},
"outputs": [],
@@ -72,9 +74,8 @@
"\n",
"# Configure RFD3 inference\n",
"config = RFD3InferenceConfig(\n",
" ckpt_path='/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt',\n",
" specification={\n",
" 'length': 80, # Generate 80-residue proteins\n",
" 'length': 15, # Generate 80-residue proteins\n",
" },\n",
" diffusion_batch_size=2, # Generate 2 structures per batch\n",
")\n",
@@ -144,7 +145,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"id": "d07ae413",
"metadata": {},
"outputs": [],
@@ -155,8 +156,7 @@
"# See mpnn.utils.inference.MPNN_GLOBAL_INFERENCE_DEFAULTS for all options\n",
"engine_config = {\n",
" \"model_type\": \"ligand_mpnn\", # or \"protein_mpnn\" for vanilla ProteinMPNN\n",
" \"checkpoint_path\": \"/databases/mpnn/ligand_mpnn_model_weights/s25_r010_t300_p.pt\",\n",
" \"is_legacy_weights\": True,\n",
" \"is_legacy_weights\": True, # Required for now for ligand_mpnn and protein_mpnn\n",
" \"out_directory\": None, # Return results in memory\n",
" \"write_structures\": False,\n",
" \"write_fasta\": False,\n",
@@ -234,10 +234,9 @@
"from rf3.inference_engines.rf3 import RF3InferenceEngine\n",
"from rf3.utils.inference import InferenceInput\n",
"\n",
"CKPT_PATH = \"/net/software/containers/versions/modelhub_inference/ckpts/rf3-w-conf-run10-ep922-remapped.ckpt\"\n",
"\n",
"# Initialize RF3 inference engine\n",
"inference_engine = RF3InferenceEngine(ckpt_path=CKPT_PATH, verbose=False)\n",
"inference_engine = RF3InferenceEngine(ckpt_path='rf3_preprint_921', verbose=False)\n",
"\n",
"# Create input from the MPNN-designed structure (first design)\n",
"# This re-folds the sequence to validate it adopts the intended structure\n",
@@ -379,6 +378,12 @@
"\n",
"![Superimposed Protein](../docs/_static/superimposed_80_residue_protein.png)"
]
},
{
"cell_type": "markdown",
"id": "c439c90d",
"metadata": {},
"source": []
}
],
"metadata": {

View File

@@ -27,6 +27,7 @@ from mpnn.utils.inference import (
)
from mpnn.utils.weights import load_legacy_weights
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
from foundry.metrics.metric import MetricManager
from foundry.utils.ddp import RankedLogger
@@ -49,11 +50,15 @@ class MPNNInferenceEngine:
):
# Store raw configuration
self.model_type = model_type
self.checkpoint_path = checkpoint_path
self.is_legacy_weights = is_legacy_weights
self.out_directory = out_directory
self.write_fasta = write_fasta
self.write_structures = write_structures
# allow null for checkpoint path when foundry-installed
# TODO: Currently this assumes the model type is the key in the registered path. Rework needed
self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \
if not checkpoint_path else checkpoint_path
# Determine the device.
if device is not None:

View File

@@ -36,8 +36,8 @@ MPNN_GLOBAL_INFERENCE_DEFAULTS: dict[str, Any] = {
# Top-level Config JSON
"config_json": None,
# Model Type and Weights
"model_type": None,
"checkpoint_path": None,
"model_type": None,
"is_legacy_weights": None,
# Output controls
"out_directory": None,

View File

@@ -105,13 +105,6 @@ The output directory will automatically be created.
For full details on how to specify inputs, see the [input specification documentation](./docs/input.md). You can also see `models/rfd3/configs/inference_engine/rfdiffusion3.yaml`.
## Training (w & w/o WandB): #TODO make sure correct
To launch a training run, use:
```
uv run python models/rfd3/src/rfd3/train.py experiment=pretrain
```
See the paths [configs](/models/rfd3/configs/paths/) to customize the paths where data is read from and where logs are written. There is also a wandb config that can be enabled if you want to log training through wandb.
### Install HBPLUS for training with hydrogen bond conditioning:
@@ -120,10 +113,34 @@ See the paths [configs](/models/rfd3/configs/paths/) to customize the paths wher
2. Follow the installation instruction here: https://www.ebi.ac.uk/thornton-srv/software/HBPLUS/install.html
3. Update `HBPLUS_PATH` in `foundry/.env` file with the path to your `hbplus` executable.
## Training (w & w/o WandB): #TODO make sure correct
**Launching:** To launch distributed training on slurm, we recommend the following setup:
```
EFFECTIVE_BATCH_SIZE=16
DEVICES_PER_NODE= #INSERT NUMBER OF DEVICES PER NODE
NNODES= # INSERT NUMBER OF NODES
GRAD_ACCUM_STEPS=$((EFFECTIVE_BATCH_SIZE / (DEVICES_PER_NODE * NNODES)))
uv run python models/rfd3/src/rfd3/train.py \
experiment=$SLURM_JOB_NAME \
trainer.devices_per_node=$DEVICES_PER_NODE \
trainer.num_nodes=$SLURM_NNODES \
trainer.grad_accum_steps=$GRAD_ACCUM_STEPS
```
Notably, fabric must receive `devices_per_node` and the number of nodes (`num_nodes`) you're training on.
**Dataset Paths:** See the paths [configs](/models/rfd3/configs/paths/) to customize the paths where data is read from and where logs are written. There is also a wandb config that can be enabled if you want to log training through wandb.
**Hydra configs and experiments:** In the example above, the `experiment` argument is a hydra-native argument. For RFD3, it will look for config overrides in `/models/rfd3/configs/experiment/<experiment-name>.yaml` and apply them on top of the base configs.
**Conditioning during training:** RFD3 is trained on a multitude of conditioning tasks, and does so by randomly 'creating problems' for it to solve during training. For example, for a random training example it gets a random set of tokens to be 'motif tokens', then subsets those to whether specific atoms should be fixed, and further subsets the information to whether, say, sequence, coordinates or the sequence index should be fixed. It's pretty complicated to evaluate, and how this was put together is more of an art than a science, which means there is likely room for further optimization!
The shared configs for all datasets live in `models/rfd3/configs/datasets/design_base.yaml` under `global_transform_args`. The dials that control the conditioning described above go under `training_conditions`, where for example `tipatom` - a specific preset conditioning sampler which more frequently fixes few tokens with few atoms - and others can be found.
**Training with WandB:** We strongly recommend tracking your runs via wandb. To use it, simply have your WANDB_API_KEY set. For more details see [here](https://wandb.ai).
## Citation
If you use this code or data in your work, please cite:
If you use this code or data in your work, please consider citing:
```bibtex
@article {butcher2025_rfdiffusion3,

View File

@@ -15,6 +15,7 @@ from toolz import merge_with
from foundry.common import exists
from foundry.inference_engines.base import BaseInferenceEngine
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
from foundry.utils.alignment import weighted_rigid_align
from foundry.utils.ddp import RankedLogger
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
@@ -37,7 +38,7 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@dataclass(kw_only=True)
class RFD3InferenceConfig:
ckpt_path: str = "/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt"
ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation
diffusion_batch_size: int = 16
# RFD3 specific

View File

@@ -1,4 +1,5 @@
import logging
import os
from os import PathLike
from pathlib import Path
from typing import Any, Dict
@@ -9,6 +10,7 @@ from biotite.structure import AtomArray
from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability
from foundry.utils.logging import (
configure_minimal_inference_logging,
@@ -65,7 +67,18 @@ class BaseInferenceEngine:
self.trainer = None
self.pipeline = None
self.verbose = verbose
self.ckpt_path = ckpt_path
# Resolve checkpoint path
if '.' not in str(ckpt_path):
# Assume registered model
name = str(ckpt_path)
assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints'
ckpt = REGISTERED_CHECKPOINTS[name]
ckpt_path = ckpt.get_default_path()
ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path)))
assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path)
self.ckpt_path = Path(ckpt_path).resolve()
# Set random seed (only if seed is not None)
if seed is not None:

View File

@@ -0,0 +1,66 @@
'''Management of checkpoints'''
import os
from dataclasses import dataclass
from pathlib import Path
def get_default_checkpoint_dir() -> Path:
    """Return the directory in which foundry checkpoints are looked up.

    Priority:
        1. The ``FOUNDRY_CHECKPOINTS_DIR`` environment variable, when it is
           set to a non-empty value. A leading ``~`` is expanded and the
           result is made absolute.
        2. ``~/.foundry/checkpoints``.

    Returns:
        The resolved checkpoint directory. The directory is not created and
        its existence is not checked here.
    """
    # A single lookup covers both "unset" and "set but empty"; the latter is
    # what a bare `FOUNDRY_CHECKPOINTS_DIR=` line in .env produces.
    configured = os.environ.get("FOUNDRY_CHECKPOINTS_DIR")
    if configured:
        # expanduser() first: Path.absolute() alone would turn "~/ckpts" into
        # "<cwd>/~/ckpts" instead of a path under the home directory.
        return Path(configured).expanduser().absolute()
    return Path.home() / ".foundry" / "checkpoints"
@dataclass
class RegisteredCheckpoint:
url: str
filename: str
description: str
sha256: None = None # Optional: add checksum for verification
def get_default_path(self):
return get_default_checkpoint_dir() / self.filename
# Registry of installable checkpoints, keyed by the model name users pass in.
# The MPNN engine derives its key from `model_type` with underscores removed
# (e.g. "ligand_mpnn" -> "ligandmpnn"), so those keys must stay in that form.
REGISTERED_CHECKPOINTS: dict[str, RegisteredCheckpoint] = {
    "rfd3": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
        filename="rfd3_latest.ckpt",
        description="RFdiffusion3 checkpoint",
    ),
    "rf3": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
        filename="rf3_foundry_01_24_latest_remapped.ckpt",
        description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
    ),
    "proteinmpnn": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
        filename="proteinmpnn_v_48_020.pt",
        description="ProteinMPNN checkpoint",
    ),
    "ligandmpnn": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
        filename="ligandmpnn_v_32_010_25.pt",
        description="LigandMPNN checkpoint",
    ),
    # Other models
    "rf3_preprint_921": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
        filename="rf3_foundry_09_21_preprint_remapped.ckpt",
        description="RF3 preprint checkpoint trained with data until 9/2021",
    ),
    "rf3_preprint_124": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
        filename="rf3_foundry_01_24_preprint_remapped.ckpt",
        description="RF3 preprint checkpoint trained with data until 1/2024",
    ),
    "solublempnn": RegisteredCheckpoint(
        url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
        filename="solublempnn_v_48_020.pt",
        description="SolubleMPNN checkpoint",
    ),
}

View File

@@ -1,12 +1,13 @@
"""CLI for foundry model checkpoint installation and management."""
import hashlib
import os
from pathlib import Path
from typing import Optional
from urllib.request import urlopen
import rootutils
import typer
from dotenv import find_dotenv, load_dotenv, set_key
from rich.console import Console
from rich.progress import (
BarColumn,
@@ -18,69 +19,17 @@ from rich.progress import (
TransferSpeedColumn,
)
from foundry.inference_engines.checkpoint_registry import (
REGISTERED_CHECKPOINTS,
get_default_checkpoint_dir,
)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
load_dotenv(override=True)
app = typer.Typer(help="Foundry model checkpoint installation utilities")
console = Console()
# Checkpoint URLs and metadata
# TODO: Replace these with your actual checkpoint URLs
# One row per model: (name, url, filename, sha256, description).
# sha256 is optional; None means no checksum is recorded for that entry.
CHECKPOINTS = {
    model: {
        "url": link,
        "filename": fname,
        "sha256": digest,
        "description": summary,
    }
    for model, link, fname, digest, summary in (
        (
            "rfd3",
            "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
            "rfd3_latest.ckpt",
            None,
            "RFdiffusion3 checkpoint",
        ),
        (
            "rf3_preprint_921",
            "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
            "rf3_foundry_09_21_preprint_remapped.ckpt",
            None,
            "RF3 preprint checkpoint trained with data until 9/2021",
        ),
        (
            "rf3_preprint_124",
            "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
            "rf3_foundry_01_24_preprint_remapped.ckpt",
            None,
            "RF3 preprint checkpoint trained with data until 1/2024",
        ),
        (
            "rf3",
            "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
            "rf3_foundry_01_24_latest_remapped.ckpt",
            None,
            "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
        ),
        (
            "proteinmpnn",
            "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
            "proteinmpnn_v_48_020.pt",
            None,
            "ProteinMPNN checkpoint",
        ),
        (
            "ligandmpnn",
            "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
            "ligandmpnn_v_32_010_25.pt",
            None,
            "LigandMPNN checkpoint",
        ),
        (
            "solublempnn",
            "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
            "solublempnn_v_48_020.pt",
            None,
            "SolubleMPNN checkpoint",
        ),
    )
}
def get_default_checkpoint_dir() -> Path:
    """Get the default checkpoint directory.

    Priority:
        1. FOUNDRY_CHECKPOINTS_DIR environment variable (non-empty value;
           a leading ``~`` is expanded and the path is made absolute)
        2. ~/.foundry/checkpoints
    """
    # The shipped .env contains a blank `FOUNDRY_CHECKPOINTS_DIR=` entry, so
    # the variable can be present but empty. Treat that as "not configured";
    # otherwise Path("") would silently point at the current directory.
    configured = os.environ.get("FOUNDRY_CHECKPOINTS_DIR")
    if configured:
        return Path(configured).expanduser().absolute()
    return Path.home() / ".foundry" / "checkpoints"
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
"""Download a file with progress bar and optional hash verification.
@@ -143,12 +92,12 @@ def install_model(
checkpoint_dir: Directory to save checkpoints
force: Overwrite existing checkpoint if it exists
"""
if model_name not in CHECKPOINTS:
if model_name not in REGISTERED_CHECKPOINTS:
console.print(f"[red]Error:[/red] Unknown model '{model_name}'")
console.print(f"Available models: {', '.join(CHECKPOINTS.keys())}")
console.print(f"Available models: {', '.join(REGISTERED_CHECKPOINTS.keys())}")
raise typer.Exit(1)
checkpoint_info = CHECKPOINTS[model_name]
checkpoint_info = REGISTERED_CHECKPOINTS[model_name]
dest_path = checkpoint_dir / checkpoint_info["filename"]
# Check if already exists
@@ -210,7 +159,7 @@ def install(
# Expand 'all' to all available models
if "all" in models:
models_to_install = list(CHECKPOINTS.keys())
models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3']
else:
models_to_install = models
@@ -219,6 +168,16 @@ def install(
install_model(model_name, checkpoint_dir, force)
console.print()
set_key(
dotenv_path=find_dotenv(),
key_to_set='FOUNDRY_CHECKPOINTS_DIR',
value_to_set=str(checkpoint_dir),
export = False,
)
console.print(
f"Set checkpoint installation directory to: {checkpoint_dir}"
)
console.print("[bold green]Installation complete![/bold green]")
@@ -226,7 +185,7 @@ def install(
def list_models():
    """List available model checkpoints.

    Prints one line per entry in ``REGISTERED_CHECKPOINTS``: the model name
    followed by its description.
    """
    console.print("[bold]Available models:[/bold]\n")
    # Pad to the longest registered name (never below the previous fixed
    # width of 8) so the descriptions line up for long names as well.
    width = max(8, *(len(name) for name in REGISTERED_CHECKPOINTS))
    for name, info in REGISTERED_CHECKPOINTS.items():
        # Registry entries are dataclass instances, so read the field as an
        # attribute rather than subscripting.
        console.print(f" [cyan]{name:{width}}[/cyan] - {info.description}")