From 1bf46afa213eb45d5876174e31801613a99b4c16 Mon Sep 17 00:00:00 2001 From: ncorley Date: Wed, 3 Dec 2025 01:48:46 -0800 Subject: [PATCH] chore: ruff, fix checkpoint code --- .../mpnn/src/mpnn/inference_engines/mpnn.py | 13 ++++- models/rfd3/src/rfd3/engine.py | 5 +- src/foundry/inference_engines/base.py | 20 +++++-- .../inference_engines/checkpoint_registry.py | 58 ++++++++++--------- src/foundry_cli/download_checkpoints.py | 39 +++++-------- 5 files changed, 75 insertions(+), 60 deletions(-) diff --git a/models/mpnn/src/mpnn/inference_engines/mpnn.py b/models/mpnn/src/mpnn/inference_engines/mpnn.py index 3bb4851..037b626 100644 --- a/models/mpnn/src/mpnn/inference_engines/mpnn.py +++ b/models/mpnn/src/mpnn/inference_engines/mpnn.py @@ -54,11 +54,18 @@ class MPNNInferenceEngine: self.out_directory = out_directory self.write_fasta = write_fasta self.write_structures = write_structures - + # allow null for checkpoint path when foundry-installed # TODO: Currently this assumes the model type is the key in the registered path. Rework needed - self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \ - if not checkpoint_path else checkpoint_path + self.checkpoint_path = ( + str( + REGISTERED_CHECKPOINTS[ + self.model_type.replace("_", "") + ].get_default_path() + ) + if not checkpoint_path + else checkpoint_path + ) # Determine the device. if device is not None: diff --git a/models/rfd3/src/rfd3/engine.py b/models/rfd3/src/rfd3/engine.py index a277f22..8a0daa9 100644 --- a/models/rfd3/src/rfd3/engine.py +++ b/models/rfd3/src/rfd3/engine.py @@ -15,7 +15,6 @@ from toolz import merge_with from foundry.common import exists from foundry.inference_engines.base import BaseInferenceEngine -from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS from foundry.utils.alignment import weighted_rigid_align from foundry.utils.ddp import RankedLogger from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS @@ -38,7 +37,9 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True) @dataclass(kw_only=True) class RFD3InferenceConfig: - ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation + ckpt_path: str | Path = ( + "rfd3" # Defaults to foundry installation upon instantiation + ) diffusion_batch_size: int = 16 # RFD3 specific diff --git a/src/foundry/inference_engines/base.py b/src/foundry/inference_engines/base.py index cbfe17c..7954807 100644 --- a/src/foundry/inference_engines/base.py +++ b/src/foundry/inference_engines/base.py @@ -69,15 +69,25 @@ class BaseInferenceEngine: self.verbose = verbose # Resolve checkpoint path - if '.' not in str(ckpt_path): + if "." not in str(ckpt_path): # Assume registered model name = str(ckpt_path) - assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints' + assert ( + name in REGISTERED_CHECKPOINTS + ), "Checkpoint provided not and not in registered checkpoints" ckpt = REGISTERED_CHECKPOINTS[name] - + ckpt_path = ckpt.get_default_path() - ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path))) - assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path) + ranked_logger.info( + "Using checkpoint from default installation directory, got: {}".format( + str(ckpt_path) + ) + ) + assert os.path.exists( + ckpt_path + ), "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format( + name, ckpt_path + ) self.ckpt_path = Path(ckpt_path).resolve() # Set random seed (only if seed is not None) diff --git a/src/foundry/inference_engines/checkpoint_registry.py b/src/foundry/inference_engines/checkpoint_registry.py index 92736f5..50f69e4 100644 --- a/src/foundry/inference_engines/checkpoint_registry.py +++ b/src/foundry/inference_engines/checkpoint_registry.py @@ -1,4 +1,5 @@ -'''Management of checkpoints''' +"""Management of checkpoints""" + import os from dataclasses import dataclass from pathlib import Path @@ -11,10 +12,13 @@ def get_default_checkpoint_dir() -> Path: 1. FOUNDRY_CHECKPOINTS_DIR environment variable 2. ~/.foundry/checkpoints """ - if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get("FOUNDRY_CHECKPOINTS_DIR"): + if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get( + "FOUNDRY_CHECKPOINTS_DIR" + ): return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute() return Path.home() / ".foundry" / "checkpoints" + @dataclass class RegisteredCheckpoint: url: str @@ -28,39 +32,39 @@ class RegisteredCheckpoint: REGISTERED_CHECKPOINTS = { "rfd3": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt", - filename = "rfd3_latest.ckpt", - description = "RFdiffusion3 checkpoint", + url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt", + filename="rfd3_latest.ckpt", + description="RFdiffusion3 checkpoint", ), - "rf3": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt", - filename= "rf3_foundry_01_24_latest_remapped.ckpt", - description= "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)", + "rf3": RegisteredCheckpoint( + url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt", + filename="rf3_foundry_01_24_latest_remapped.ckpt", + description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)", ), - "proteinmpnn": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt", - filename= "proteinmpnn_v_48_020.pt", - description= "ProteinMPNN checkpoint", + "proteinmpnn": RegisteredCheckpoint( + url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt", + filename="proteinmpnn_v_48_020.pt", + description="ProteinMPNN checkpoint", ), "ligandmpnn": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt", - filename= "ligandmpnn_v_32_010_25.pt", - description= "LigandMPNN checkpoint", + url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt", + filename="ligandmpnn_v_32_010_25.pt", + description="LigandMPNN checkpoint", ), # Other models "rf3_preprint_921": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt", - filename = "rf3_foundry_09_21_preprint_remapped.ckpt", - description = "RF3 preprint checkpoint trained with data until 9/2021", + url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt", + filename="rf3_foundry_09_21_preprint_remapped.ckpt", + description="RF3 preprint checkpoint trained with data until 9/2021", ), "rf3_preprint_124": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt", - filename = "rf3_foundry_01_24_preprint_remapped.ckpt", - description= "RF3 preprint checkpoint trained with data until 1/2024", + url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt", + filename="rf3_foundry_01_24_preprint_remapped.ckpt", + description="RF3 preprint checkpoint trained with data until 1/2024", + ), + "solublempnn": RegisteredCheckpoint( + url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt", + filename="solublempnn_v_48_020.pt", + description="SolubleMPNN checkpoint", ), - "solublempnn": RegisteredCheckpoint( - url = "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt", - filename= "solublempnn_v_48_020.pt", - description= "SolubleMPNN checkpoint" - ) } diff --git a/src/foundry_cli/download_checkpoints.py b/src/foundry_cli/download_checkpoints.py index 175bdcf..0ccad7e 100644 --- a/src/foundry_cli/download_checkpoints.py +++ b/src/foundry_cli/download_checkpoints.py @@ -5,7 +5,6 @@ from pathlib import Path from typing import Optional from urllib.request import urlopen -import rootutils import typer from dotenv import find_dotenv, load_dotenv, set_key from rich.console import Console @@ -29,6 +28,7 @@ load_dotenv(override=True) app = typer.Typer(help="Foundry model checkpoint installation utilities") console = Console() + def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None: """Download a file with progress bar and optional hash verification. @@ -81,9 +81,7 @@ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> No console.print("[green]✓[/green] Hash verification passed") -def install_model( - model_name: str, checkpoint_dir: Path, force: bool = False -) -> None: +def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) -> None: """Install a single model checkpoint. Args: @@ -112,9 +110,7 @@ def install_model( ) try: - download_file( - checkpoint_info.url, dest_path, checkpoint_info.sha256 - ) + download_file(checkpoint_info.url, dest_path, checkpoint_info.sha256) console.print( f"[green]✓[/green] Successfully installed {model_name} to {dest_path}" ) @@ -158,7 +154,7 @@ def install( # Expand 'all' to all available models if "all" in models: - models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3'] + models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"] else: models_to_install = models @@ -167,15 +163,16 @@ def install( install_model(model_name, checkpoint_dir, force) console.print() - set_key( - dotenv_path=find_dotenv(), - key_to_set='FOUNDRY_CHECKPOINTS_DIR', - value_to_set=str(checkpoint_dir), - export = False, - ) - console.print( - f"Set checkpoint installation directory to: {checkpoint_dir}" - ) + # Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.) + dotenv_path = find_dotenv() + if dotenv_path: + set_key( + dotenv_path=dotenv_path, + key_to_set="FOUNDRY_CHECKPOINTS_DIR", + value_to_set=str(checkpoint_dir), + export=False, + ) + console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}") console.print("[bold green]Installation complete![/bold green]") @@ -209,9 +206,7 @@ def show( checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) if not checkpoint_files: - console.print( - f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]" - ) + console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") raise typer.Exit(0) console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n") @@ -247,9 +242,7 @@ def clean( # List files to delete checkpoint_files = list(checkpoint_dir.glob("*.ckpt")) if not checkpoint_files: - console.print( - f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]" - ) + console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]") raise typer.Exit(0) console.print("[bold]Files to delete:[/bold]")