diff --git a/.env b/.env index d771095..b4ad1a8 100644 --- a/.env +++ b/.env @@ -4,6 +4,8 @@ # - you will need to set the paths below as appropriate for your environment. # We provide examples in the comments to give you an idea of the expected format. +# Foundry install dir for checkpoints +FOUNDRY_CHECKPOINTS_DIR= # --- Mirrors to RCSB data --- @@ -58,5 +60,3 @@ COLABFOLD_LOCAL_DB_PATH_CPU= # Network access (fallback; may cause IO-related issues) COLABFOLD_NET_DB_PATH_GPU= COLABFOLD_NET_DB_PATH_CPU= - - diff --git a/.gitignore b/.gitignore index 245088f..cf405e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +# For docs / outputs from example notebooks; +examples/*.cif +**.ckpt +**.pt + # Base .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore *.lock **.nfs** diff --git a/README.md b/README.md index 19367ca..e668906 100644 --- a/README.md +++ b/README.md @@ -19,22 +19,23 @@ This will download all the models supported (including multiple checkpoints of r foundry install rfd3 ligandmpnn rf3 --checkpoint_dir ``` -**Running a basic example of everything** See `examples/all.ipynb` for how to run each model in a notebook. - -We include details DNA, Ligands, Protein-Protein Interaction, Symmetry-conditioned systems and enzymes. +>*See `examples/all.ipynb` for how to run each model in a notebook.* ### RFdiffusion3 [RFdiffusion3](https://www.biorxiv.org/content/10.1101/2025.09.18.676967v2) is an all-atom generative model capable of designing protein structures under complex constraints. -> See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation. +> *See [models/rfd3/README.md](models/rfd3/README.md) for complete documentation.* + +
+ RFdiffusion3 generation trajectory. +
### ProteinMPNN [ProteinMPNN](https://www.science.org/doi/10.1126/science.add2187) and [LigandMPNN](https://www.nature.com/articles/s41592-025-02626-1) are lightweight inverse-folding models which can be use to design diverse sequences for backbones under constrained conditions. > *See [models/mpnn/README.md](models/mpnn/README.md) for complete documentation.* - ### RosettaFold3 [RF3](https://doi.org/10.1101/2025.08.14.670328) is a structure prediction neural network that narrows the gap between closed-source AF-3 and open-source alternatives. diff --git a/examples/all.ipynb b/examples/all.ipynb index be2f80e..d191938 100644 --- a/examples/all.ipynb +++ b/examples/all.ipynb @@ -19,6 +19,8 @@ "\n", "All models are unified through [AtomWorks](https://github.com/RosettaCommons/atomworks) (for both inference and training), relying on Biotite `AtomArray` objects.\n", "\n", + "Note: This notebook assumes you have the base checkpoints downloaded: `foundry install rfd3 ligandmpnn rf3`. You can also specify the paths directly yourself if you wish.\n", + "\n", "### Pipeline Flow\n", "```\n", "RFD3 (backbone) → MPNN (sequence) → RF3 (validation) → RMSD comparison\n", @@ -29,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "id": "819e8193", "metadata": {}, "outputs": [], @@ -72,9 +74,8 @@ "\n", "# Configure RFD3 inference\n", "config = RFD3InferenceConfig(\n", - " ckpt_path='/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt',\n", " specification={\n", - " 'length': 80, # Generate 80-residue proteins\n", + " 'length': 15, # Generate 15-residue proteins\n", " },\n", " diffusion_batch_size=2, # Generate 2 structures per batch\n", ")\n", @@ -144,7 +145,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "d07ae413", "metadata": {}, "outputs": [], @@ -155,8 +156,7 @@ "# See mpnn.utils.inference.MPNN_GLOBAL_INFERENCE_DEFAULTS for all options\n", "engine_config = {\n", " \"model_type\": \"ligand_mpnn\", # 
or \"protein_mpnn\" for vanilla ProteinMPNN\n", - " \"checkpoint_path\": \"/databases/mpnn/ligand_mpnn_model_weights/s25_r010_t300_p.pt\",\n", - " \"is_legacy_weights\": True,\n", + " \"is_legacy_weights\": True, # Required for now for ligand_mpnn and protein_mpnn\n", " \"out_directory\": None, # Return results in memory\n", " \"write_structures\": False,\n", " \"write_fasta\": False,\n", @@ -234,10 +234,9 @@ "from rf3.inference_engines.rf3 import RF3InferenceEngine\n", "from rf3.utils.inference import InferenceInput\n", "\n", - "CKPT_PATH = \"/net/software/containers/versions/modelhub_inference/ckpts/rf3-w-conf-run10-ep922-remapped.ckpt\"\n", "\n", "# Initialize RF3 inference engine\n", - "inference_engine = RF3InferenceEngine(ckpt_path=CKPT_PATH, verbose=False)\n", + "inference_engine = RF3InferenceEngine(ckpt_path='rf3_preprint_921', verbose=False)\n", "\n", "# Create input from the MPNN-designed structure (first design)\n", "# This re-folds the sequence to validate it adopts the intended structure\n", @@ -379,6 +378,12 @@ "\n", "![Superimposed Protein](../docs/_static/superimposed_80_residue_protein.png)" ] + }, + { + "cell_type": "markdown", + "id": "c439c90d", + "metadata": {}, + "source": [] } ], "metadata": { diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index 59e6df1..3bb4851 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -27,6 +27,7 @@ from mpnn.utils.inference import ( ) from mpnn.utils.weights import load_legacy_weights +from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS from foundry.metrics.metric import MetricManager from foundry.utils.ddp import RankedLogger @@ -49,11 +50,15 @@ class MPNNInferenceEngine: ): # Store raw configuration self.model_type = model_type - self.checkpoint_path = checkpoint_path self.is_legacy_weights = is_legacy_weights self.out_directory = out_directory 
self.write_fasta = write_fasta self.write_structures = write_structures + + # allow null for checkpoint path when foundry-installed + # TODO: Currently this assumes the model type is the key in the registered path. Rework needed + self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \ + if not checkpoint_path else checkpoint_path # Determine the device. if device is not None: diff --git a/models/mpnn/src/mpnn/utils/inference.py b/models/mpnn/src/mpnn/utils/inference.py index 6dc352f..9ba3904 100644 --- a/models/mpnn/src/mpnn/utils/inference.py +++ b/models/mpnn/src/mpnn/utils/inference.py @@ -36,8 +36,8 @@ MPNN_GLOBAL_INFERENCE_DEFAULTS: dict[str, Any] = { # Top-level Config JSON "config_json": None, # Model Type and Weights - "model_type": None, "checkpoint_path": None, + "model_type": None, "is_legacy_weights": None, # Output controls "out_directory": None, diff --git a/models/rfd3/README.md b/models/rfd3/README.md index 117667c..76c8694 100644 --- a/models/rfd3/README.md +++ b/models/rfd3/README.md @@ -105,13 +105,6 @@ The output directory will automatically be created. For full details on how to specify inputs, see the [input specification documentation](./docs/input.md). You can also see `models/rfd3/configs/inference_engine/rfdiffusion3.yaml`. -## Training (w & w/o WandB): #TODO make sure correct - -To launch a training run, use: -``` -uv run python models/rfd3/src/rfd3/train.py experiment=pretrain -``` -See the paths [configs](/models/rfd3/configs/paths/) to customize the paths where data is read from and where logs are written. There is also a wandb config that can be enabled if you want to log training through wandb. ### Install HBPLUS for training with hydrogen bond conditioning: @@ -120,10 +113,34 @@ See the paths [configs](/models/rfd3/configs/paths/) to customize the paths wher 2. Follow the installation instruction here: https://www.ebi.ac.uk/thornton-srv/software/HBPLUS/install.html 3. 
Update `HBPLUS_PATH` in `foundry/.env` file with the path to your `hbplus` executable. +## Training (w & w/o WandB): #TODO make sure correct +**Launching:** To launch distributed training on slurm, we recommend the following setup: +``` +EFFECTIVE_BATCH_SIZE=16 +DEVICES_PER_NODE= #INSERT NUMBER OF DEVICES PER NODE +NNODES= # INSERT NUMBER OF NODES +GRAD_ACCUM_STEPS=$((EFFECTIVE_BATCH_SIZE / (DEVICES_PER_NODE * NNODES))) +uv run python models/rfd3/src/rfd3/train.py \ + experiment=$SLURM_JOB_NAME \ + trainer.devices_per_node=$DEVICES_PER_NODE \ + trainer.num_nodes=$SLURM_NNODES \ + trainer.grad_accum_steps=$GRAD_ACCUM_STEPS +``` +Notably, fabric must receive `devices_per_node` and the number of nodes (`num_nodes`) you're training on. + +**Dataset Paths:** See the paths [configs](/models/rfd3/configs/paths/) to customize the paths where data is read from and where logs are written. There is also a wandb config that can be enabled if you want to log training through wandb. + +**Hydra configs and experiments:** In the example above, the `experiment` argument is a hydra-native argument. For RFD3, it will look for config overrides in `/models/rfd3/configs/experiment/<experiment>.yaml` and apply them on top of the base configs. + +**Conditioning during training:** RFD3 is trained on a multitude of conditioning tasks, and does so by randomly 'creating problems' for it to solve during training. For example, for a random training example it gets a random set of tokens to be 'motif tokens', then subsets those to whether specific atoms should be fixed, and further subsets the information to whether, say, sequence, coordinates or the sequence index should be fixed. It's pretty complicated to evaluate and it's more of an art than a science how this was put together; which means there's likely some optimization that further work can do! + +In `models/rfd3/configs/datasets/design_base.yaml` are the shared configs for all datasets under `global_transform_args`. 
The dials that control the conditioning described above go under `training_conditions`, where for example `tipatom` - a specific preset conditioning sampler which more frequently fixes few tokens with few atoms - and others can be found. + +**Training with WandB:** We strongly recommend tracking your runs via wandb. To use it, simply have your WANDB_API_KEY set. For more details see [here](https://wandb.ai) ## Citation -If you use this code or data in your work, please cite: +If you use this code or data in your work, please consider citing: ```bibtex @article {butcher2025_rfdiffusion3, diff --git a/models/rfd3/src/rfd3/engine.py b/models/rfd3/src/rfd3/engine.py index 84d3dec..a277f22 100644 --- a/models/rfd3/src/rfd3/engine.py +++ b/models/rfd3/src/rfd3/engine.py @@ -15,6 +15,7 @@ from toolz import merge_with from foundry.common import exists from foundry.inference_engines.base import BaseInferenceEngine +from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS from foundry.utils.alignment import weighted_rigid_align from foundry.utils.ddp import RankedLogger from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS @@ -37,7 +38,7 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True) @dataclass(kw_only=True) class RFD3InferenceConfig: - ckpt_path: str = "/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt" + ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation diffusion_batch_size: int = 16 # RFD3 specific diff --git a/src/foundry/inference_engines/base.py b/src/foundry/inference_engines/base.py index 79f5f81..cbfe17c 100644 --- a/src/foundry/inference_engines/base.py +++ b/src/foundry/inference_engines/base.py @@ -1,4 +1,5 @@ import logging +import os from os import PathLike from pathlib import Path from typing import Any, Dict @@ -9,6 +10,7 @@ from biotite.structure import AtomArray from lightning.fabric import seed_everything from omegaconf import OmegaConf +from 
foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS from foundry.utils.ddp import RankedLogger, set_accelerator_based_on_availability from foundry.utils.logging import ( configure_minimal_inference_logging, @@ -65,7 +67,18 @@ class BaseInferenceEngine: self.trainer = None self.pipeline = None self.verbose = verbose - self.ckpt_path = ckpt_path + + # Resolve checkpoint path + if '.' not in str(ckpt_path): + # Assume registered model + name = str(ckpt_path) + assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints' + ckpt = REGISTERED_CHECKPOINTS[name] + + ckpt_path = ckpt.get_default_path() + ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path))) + assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path) + self.ckpt_path = Path(ckpt_path).resolve() # Set random seed (only if seed is not None) if seed is not None: diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py new file mode 100644 index 0000000..92736f5 --- /dev/null +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -0,0 +1,66 @@ +'''Management of checkpoints''' +import os +from dataclasses import dataclass +from pathlib import Path + + +def get_default_checkpoint_dir() -> Path: + """Get the default checkpoint directory. + + Priority: + 1. FOUNDRY_CHECKPOINTS_DIR environment variable + 2. 
~/.foundry/checkpoints + """ + if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get("FOUNDRY_CHECKPOINTS_DIR"): + return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute() + return Path.home() / ".foundry" / "checkpoints" + +@dataclass +class RegisteredCheckpoint: + url: str + filename: str + description: str + sha256: None = None # Optional: add checksum for verification + + def get_default_path(self): + return get_default_checkpoint_dir() / self.filename + + +REGISTERED_CHECKPOINTS = { + "rfd3": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt", + filename = "rfd3_latest.ckpt", + description = "RFdiffusion3 checkpoint", + ), + "rf3": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt", + filename= "rf3_foundry_01_24_latest_remapped.ckpt", + description= "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)", + ), + "proteinmpnn": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt", + filename= "proteinmpnn_v_48_020.pt", + description= "ProteinMPNN checkpoint", + ), + "ligandmpnn": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt", + filename= "ligandmpnn_v_32_010_25.pt", + description= "LigandMPNN checkpoint", + ), + # Other models + "rf3_preprint_921": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt", + filename = "rf3_foundry_09_21_preprint_remapped.ckpt", + description = "RF3 preprint checkpoint trained with data until 9/2021", + ), + "rf3_preprint_124": RegisteredCheckpoint( + url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt", + filename = "rf3_foundry_01_24_preprint_remapped.ckpt", + description= "RF3 preprint checkpoint trained with data until 1/2024", + ), + "solublempnn": RegisteredCheckpoint( + url = 
"https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt", + filename= "solublempnn_v_48_020.pt", + description= "SolubleMPNN checkpoint" + ) +} diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index 2aec51a..6a8c756 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -1,12 +1,13 @@ """CLI for foundry model checkpoint installation and management.""" import hashlib -import os from pathlib import Path from typing import Optional from urllib.request import urlopen +import rootutils import typer +from dotenv import find_dotenv, load_dotenv, set_key from rich.console import Console from rich.progress import ( BarColumn, @@ -18,69 +19,17 @@ from rich.progress import ( TransferSpeedColumn, ) +from foundry.inference_engines.checkpoint_registry import ( + REGISTERED_CHECKPOINTS, + get_default_checkpoint_dir, +) + +rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) +load_dotenv(override=True) + app = typer.Typer(help="Foundry model checkpoint installation utilities") console = Console() -# Checkpoint URLs and metadata -# TODO: Replace these with your actual checkpoint URLs -CHECKPOINTS = { - "rfd3": { - "url": "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt", - "filename": "rfd3_latest.ckpt", - "sha256": None, # Optional: add checksum for verification - "description": "RFdiffusion3 checkpoint", - }, - "rf3_preprint_921": { - "url": "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt", - "filename": "rf3_foundry_09_21_preprint_remapped.ckpt", - "sha256": None, - "description": "RF3 preprint checkpoint trained with data until 9/2021", - }, - "rf3_preprint_124": { - "url": "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt", - "filename": "rf3_foundry_01_24_preprint_remapped.ckpt", - "sha256": None, - "description": "RF3 preprint checkpoint trained with data until 1/2024", - }, - 
"rf3": { - "url": "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt", - "filename": "rf3_foundry_01_24_latest_remapped.ckpt", - "sha256": None, - "description": "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)", - }, - "proteinmpnn": { - "url": "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt", - "filename": "proteinmpnn_v_48_020.pt", - "sha256": None, - "description": "ProteinMPNN checkpoint", - }, - "ligandmpnn": { - "url": "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt", - "filename": "ligandmpnn_v_32_010_25.pt", - "sha256": None, - "description": "LigandMPNN checkpoint", - }, - "solublempnn": { - "url": "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt", - "filename": "solublempnn_v_48_020.pt", - "sha256": None, - "description": "SolubleMPNN checkpoint", - } -} - - -def get_default_checkpoint_dir() -> Path: - """Get the default checkpoint directory. - - Priority: - 1. FOUNDRY_CHECKPOINTS_DIR environment variable - 2. ~/.foundry/checkpoints - """ - if "FOUNDRY_CHECKPOINTS_DIR" in os.environ: - return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]) - return Path.home() / ".foundry" / "checkpoints" - - def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None: """Download a file with progress bar and optional hash verification. 
@@ -143,12 +92,12 @@ def install_model( checkpoint_dir: Directory to save checkpoints force: Overwrite existing checkpoint if it exists """ - if model_name not in CHECKPOINTS: + if model_name not in REGISTERED_CHECKPOINTS: console.print(f"[red]Error:[/red] Unknown model '{model_name}'") - console.print(f"Available models: {', '.join(CHECKPOINTS.keys())}") + console.print(f"Available models: {', '.join(REGISTERED_CHECKPOINTS.keys())}") raise typer.Exit(1) - checkpoint_info = CHECKPOINTS[model_name] + checkpoint_info = REGISTERED_CHECKPOINTS[model_name] dest_path = checkpoint_dir / checkpoint_info["filename"] # Check if already exists @@ -210,7 +159,7 @@ def install( # Expand 'all' to all available models if "all" in models: - models_to_install = list(CHECKPOINTS.keys()) + models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3'] else: models_to_install = models @@ -219,6 +168,16 @@ def install( install_model(model_name, checkpoint_dir, force) console.print() + set_key( + dotenv_path=find_dotenv(), + key_to_set='FOUNDRY_CHECKPOINTS_DIR', + value_to_set=str(checkpoint_dir), + export = False, + ) + console.print( + f"Set checkpoint installation directory to: {checkpoint_dir}" + ) + console.print("[bold green]Installation complete![/bold green]") @@ -226,7 +185,7 @@ def install( def list_models(): """List available model checkpoints.""" console.print("[bold]Available models:[/bold]\n") - for name, info in CHECKPOINTS.items(): + for name, info in REGISTERED_CHECKPOINTS.items(): console.print(f" [cyan]{name:8}[/cyan] - {info['description']}")