chore: ruff, fix checkpoint code

This commit is contained in:
ncorley
2025-12-03 01:48:46 -08:00
parent d14203b264
commit 1bf46afa21
5 changed files with 75 additions and 60 deletions

View File

@@ -54,11 +54,18 @@ class MPNNInferenceEngine:
self.out_directory = out_directory
self.write_fasta = write_fasta
self.write_structures = write_structures
# Allow a null checkpoint path: fall back to the registered checkpoint's default (foundry-installed) path.
# TODO: This assumes the model type, with underscores removed, is a key in REGISTERED_CHECKPOINTS. Rework needed.
self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \
if not checkpoint_path else checkpoint_path
self.checkpoint_path = (
str(
REGISTERED_CHECKPOINTS[
self.model_type.replace("_", "")
].get_default_path()
)
if not checkpoint_path
else checkpoint_path
)
# Determine the device.
if device is not None:

View File

@@ -15,7 +15,6 @@ from toolz import merge_with
from foundry.common import exists
from foundry.inference_engines.base import BaseInferenceEngine
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
from foundry.utils.alignment import weighted_rigid_align
from foundry.utils.ddp import RankedLogger
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
@@ -38,7 +37,9 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@dataclass(kw_only=True)
class RFD3InferenceConfig:
ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation
ckpt_path: str | Path = (
"rfd3" # Defaults to foundry installation upon instantiation
)
diffusion_batch_size: int = 16
# RFD3 specific

View File

@@ -69,15 +69,25 @@ class BaseInferenceEngine:
self.verbose = verbose
# Resolve checkpoint path
if '.' not in str(ckpt_path):
if "." not in str(ckpt_path):
# Assume registered model
name = str(ckpt_path)
assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints'
assert (
name in REGISTERED_CHECKPOINTS
), "Checkpoint provided not and not in registered checkpoints"
ckpt = REGISTERED_CHECKPOINTS[name]
ckpt_path = ckpt.get_default_path()
ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path)))
assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path)
ranked_logger.info(
"Using checkpoint from default installation directory, got: {}".format(
str(ckpt_path)
)
)
assert os.path.exists(
ckpt_path
), "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format(
name, ckpt_path
)
self.ckpt_path = Path(ckpt_path).resolve()
# Set random seed (only if seed is not None)

View File

@@ -1,4 +1,5 @@
'''Management of checkpoints'''
"""Management of checkpoints"""
import os
from dataclasses import dataclass
from pathlib import Path
@@ -11,10 +12,13 @@ def get_default_checkpoint_dir() -> Path:
1. FOUNDRY_CHECKPOINTS_DIR environment variable
2. ~/.foundry/checkpoints
"""
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get("FOUNDRY_CHECKPOINTS_DIR"):
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
"FOUNDRY_CHECKPOINTS_DIR"
):
return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
return Path.home() / ".foundry" / "checkpoints"
@dataclass
class RegisteredCheckpoint:
url: str
@@ -28,39 +32,39 @@ class RegisteredCheckpoint:
REGISTERED_CHECKPOINTS = {
"rfd3": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
filename = "rfd3_latest.ckpt",
description = "RFdiffusion3 checkpoint",
url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
filename="rfd3_latest.ckpt",
description="RFdiffusion3 checkpoint",
),
"rf3": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
filename= "rf3_foundry_01_24_latest_remapped.ckpt",
description= "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
"rf3": RegisteredCheckpoint(
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
filename="rf3_foundry_01_24_latest_remapped.ckpt",
description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
),
"proteinmpnn": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
filename= "proteinmpnn_v_48_020.pt",
description= "ProteinMPNN checkpoint",
"proteinmpnn": RegisteredCheckpoint(
url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
filename="proteinmpnn_v_48_020.pt",
description="ProteinMPNN checkpoint",
),
"ligandmpnn": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
filename= "ligandmpnn_v_32_010_25.pt",
description= "LigandMPNN checkpoint",
url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
filename="ligandmpnn_v_32_010_25.pt",
description="LigandMPNN checkpoint",
),
# Other models
"rf3_preprint_921": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
filename = "rf3_foundry_09_21_preprint_remapped.ckpt",
description = "RF3 preprint checkpoint trained with data until 9/2021",
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
filename="rf3_foundry_09_21_preprint_remapped.ckpt",
description="RF3 preprint checkpoint trained with data until 9/2021",
),
"rf3_preprint_124": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
filename = "rf3_foundry_01_24_preprint_remapped.ckpt",
description= "RF3 preprint checkpoint trained with data until 1/2024",
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
filename="rf3_foundry_01_24_preprint_remapped.ckpt",
description="RF3 preprint checkpoint trained with data until 1/2024",
),
"solublempnn": RegisteredCheckpoint(
url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
filename="solublempnn_v_48_020.pt",
description="SolubleMPNN checkpoint",
),
"solublempnn": RegisteredCheckpoint(
url = "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
filename= "solublempnn_v_48_020.pt",
description= "SolubleMPNN checkpoint"
)
}

View File

@@ -5,7 +5,6 @@ from pathlib import Path
from typing import Optional
from urllib.request import urlopen
import rootutils
import typer
from dotenv import find_dotenv, load_dotenv, set_key
from rich.console import Console
@@ -29,6 +28,7 @@ load_dotenv(override=True)
app = typer.Typer(help="Foundry model checkpoint installation utilities")
console = Console()
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
"""Download a file with progress bar and optional hash verification.
@@ -81,9 +81,7 @@ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> No
console.print("[green]✓[/green] Hash verification passed")
def install_model(
model_name: str, checkpoint_dir: Path, force: bool = False
) -> None:
def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) -> None:
"""Install a single model checkpoint.
Args:
@@ -112,9 +110,7 @@ def install_model(
)
try:
download_file(
checkpoint_info.url, dest_path, checkpoint_info.sha256
)
download_file(checkpoint_info.url, dest_path, checkpoint_info.sha256)
console.print(
f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
)
@@ -158,7 +154,7 @@ def install(
# Expand 'all' to all available models
if "all" in models:
models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3']
models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
else:
models_to_install = models
@@ -167,15 +163,16 @@ def install(
install_model(model_name, checkpoint_dir, force)
console.print()
set_key(
dotenv_path=find_dotenv(),
key_to_set='FOUNDRY_CHECKPOINTS_DIR',
value_to_set=str(checkpoint_dir),
export = False,
)
console.print(
f"Set checkpoint installation directory to: {checkpoint_dir}"
)
# Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
dotenv_path = find_dotenv()
if dotenv_path:
set_key(
dotenv_path=dotenv_path,
key_to_set="FOUNDRY_CHECKPOINTS_DIR",
value_to_set=str(checkpoint_dir),
export=False,
)
console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
console.print("[bold green]Installation complete![/bold green]")
@@ -209,9 +206,7 @@ def show(
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
if not checkpoint_files:
console.print(
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
)
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
raise typer.Exit(0)
console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
@@ -247,9 +242,7 @@ def clean(
# List files to delete
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
if not checkpoint_files:
console.print(
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
)
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
raise typer.Exit(0)
console.print("[bold]Files to delete:[/bold]")