mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
chore: ruff, fix checkpoint code
This commit is contained in:
@@ -54,11 +54,18 @@ class MPNNInferenceEngine:
|
||||
self.out_directory = out_directory
|
||||
self.write_fasta = write_fasta
|
||||
self.write_structures = write_structures
|
||||
|
||||
|
||||
# allow null for checkpoint path when foundry-installed
|
||||
# TODO: Currently this assumes the model type is the key in the registered path. Rework needed
|
||||
self.checkpoint_path = str(REGISTERED_CHECKPOINTS[self.model_type.replace('_', '')].get_default_path()) \
|
||||
if not checkpoint_path else checkpoint_path
|
||||
self.checkpoint_path = (
|
||||
str(
|
||||
REGISTERED_CHECKPOINTS[
|
||||
self.model_type.replace("_", "")
|
||||
].get_default_path()
|
||||
)
|
||||
if not checkpoint_path
|
||||
else checkpoint_path
|
||||
)
|
||||
|
||||
# Determine the device.
|
||||
if device is not None:
|
||||
|
||||
@@ -15,7 +15,6 @@ from toolz import merge_with
|
||||
|
||||
from foundry.common import exists
|
||||
from foundry.inference_engines.base import BaseInferenceEngine
|
||||
from foundry.inference_engines.checkpoint_registry import REGISTERED_CHECKPOINTS
|
||||
from foundry.utils.alignment import weighted_rigid_align
|
||||
from foundry.utils.ddp import RankedLogger
|
||||
from rfd3.constants import SAVED_CONDITIONING_ANNOTATIONS
|
||||
@@ -38,7 +37,9 @@ ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class RFD3InferenceConfig:
|
||||
ckpt_path: str | Path = 'rfd3' # Defaults to foundry installation upon instantiation
|
||||
ckpt_path: str | Path = (
|
||||
"rfd3" # Defaults to foundry installation upon instantiation
|
||||
)
|
||||
diffusion_batch_size: int = 16
|
||||
|
||||
# RFD3 specific
|
||||
|
||||
@@ -69,15 +69,25 @@ class BaseInferenceEngine:
|
||||
self.verbose = verbose
|
||||
|
||||
# Resolve checkpoint path
|
||||
if '.' not in str(ckpt_path):
|
||||
if "." not in str(ckpt_path):
|
||||
# Assume registered model
|
||||
name = str(ckpt_path)
|
||||
assert name in REGISTERED_CHECKPOINTS, 'Checkpoint provided not and not in registered checkpoints'
|
||||
assert (
|
||||
name in REGISTERED_CHECKPOINTS
|
||||
), "Checkpoint provided not and not in registered checkpoints"
|
||||
ckpt = REGISTERED_CHECKPOINTS[name]
|
||||
|
||||
|
||||
ckpt_path = ckpt.get_default_path()
|
||||
ranked_logger.info("Using checkpoint from default installation directory, got: {}".format(str(ckpt_path)))
|
||||
assert os.path.exists(ckpt_path), 'Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}'.format(name, ckpt_path)
|
||||
ranked_logger.info(
|
||||
"Using checkpoint from default installation directory, got: {}".format(
|
||||
str(ckpt_path)
|
||||
)
|
||||
)
|
||||
assert os.path.exists(
|
||||
ckpt_path
|
||||
), "Invalid checkpoint: {}. And could not find checkpoint in default installation location: {}".format(
|
||||
name, ckpt_path
|
||||
)
|
||||
self.ckpt_path = Path(ckpt_path).resolve()
|
||||
|
||||
# Set random seed (only if seed is not None)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
'''Management of checkpoints'''
|
||||
"""Management of checkpoints"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -11,10 +12,13 @@ def get_default_checkpoint_dir() -> Path:
|
||||
1. FOUNDRY_CHECKPOINTS_DIR environment variable
|
||||
2. ~/.foundry/checkpoints
|
||||
"""
|
||||
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get("FOUNDRY_CHECKPOINTS_DIR"):
|
||||
if "FOUNDRY_CHECKPOINTS_DIR" in os.environ and os.environ.get(
|
||||
"FOUNDRY_CHECKPOINTS_DIR"
|
||||
):
|
||||
return Path(os.environ["FOUNDRY_CHECKPOINTS_DIR"]).absolute()
|
||||
return Path.home() / ".foundry" / "checkpoints"
|
||||
|
||||
|
||||
@dataclass
|
||||
class RegisteredCheckpoint:
|
||||
url: str
|
||||
@@ -28,39 +32,39 @@ class RegisteredCheckpoint:
|
||||
|
||||
REGISTERED_CHECKPOINTS = {
|
||||
"rfd3": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
|
||||
filename = "rfd3_latest.ckpt",
|
||||
description = "RFdiffusion3 checkpoint",
|
||||
url="https://files.ipd.uw.edu/pub/rfd3/rfd3_foundry_2025_12_01_remapped.ckpt",
|
||||
filename="rfd3_latest.ckpt",
|
||||
description="RFdiffusion3 checkpoint",
|
||||
),
|
||||
"rf3": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
|
||||
filename= "rf3_foundry_01_24_latest_remapped.ckpt",
|
||||
description= "latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
|
||||
"rf3": RegisteredCheckpoint(
|
||||
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_latest_remapped.ckpt",
|
||||
filename="rf3_foundry_01_24_latest_remapped.ckpt",
|
||||
description="latest RF3 checkpoint trained with data until 1/2024 (expect best performance)",
|
||||
),
|
||||
"proteinmpnn": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
|
||||
filename= "proteinmpnn_v_48_020.pt",
|
||||
description= "ProteinMPNN checkpoint",
|
||||
"proteinmpnn": RegisteredCheckpoint(
|
||||
url="https://files.ipd.uw.edu/pub/ligandmpnn/proteinmpnn_v_48_020.pt",
|
||||
filename="proteinmpnn_v_48_020.pt",
|
||||
description="ProteinMPNN checkpoint",
|
||||
),
|
||||
"ligandmpnn": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
|
||||
filename= "ligandmpnn_v_32_010_25.pt",
|
||||
description= "LigandMPNN checkpoint",
|
||||
url="https://files.ipd.uw.edu/pub/ligandmpnn/ligandmpnn_v_32_010_25.pt",
|
||||
filename="ligandmpnn_v_32_010_25.pt",
|
||||
description="LigandMPNN checkpoint",
|
||||
),
|
||||
# Other models
|
||||
"rf3_preprint_921": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
|
||||
filename = "rf3_foundry_09_21_preprint_remapped.ckpt",
|
||||
description = "RF3 preprint checkpoint trained with data until 9/2021",
|
||||
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_09_21_preprint_remapped.ckpt",
|
||||
filename="rf3_foundry_09_21_preprint_remapped.ckpt",
|
||||
description="RF3 preprint checkpoint trained with data until 9/2021",
|
||||
),
|
||||
"rf3_preprint_124": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
|
||||
filename = "rf3_foundry_01_24_preprint_remapped.ckpt",
|
||||
description= "RF3 preprint checkpoint trained with data until 1/2024",
|
||||
url="https://files.ipd.uw.edu/pub/rf3/rf3_foundry_01_24_preprint_remapped.ckpt",
|
||||
filename="rf3_foundry_01_24_preprint_remapped.ckpt",
|
||||
description="RF3 preprint checkpoint trained with data until 1/2024",
|
||||
),
|
||||
"solublempnn": RegisteredCheckpoint(
|
||||
url="https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
|
||||
filename="solublempnn_v_48_020.pt",
|
||||
description="SolubleMPNN checkpoint",
|
||||
),
|
||||
"solublempnn": RegisteredCheckpoint(
|
||||
url = "https://files.ipd.uw.edu/pub/ligandmpnn/solublempnn_v_48_020.pt",
|
||||
filename= "solublempnn_v_48_020.pt",
|
||||
description= "SolubleMPNN checkpoint"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.request import urlopen
|
||||
|
||||
import rootutils
|
||||
import typer
|
||||
from dotenv import find_dotenv, load_dotenv, set_key
|
||||
from rich.console import Console
|
||||
@@ -29,6 +28,7 @@ load_dotenv(override=True)
|
||||
app = typer.Typer(help="Foundry model checkpoint installation utilities")
|
||||
console = Console()
|
||||
|
||||
|
||||
def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> None:
|
||||
"""Download a file with progress bar and optional hash verification.
|
||||
|
||||
@@ -81,9 +81,7 @@ def download_file(url: str, dest: Path, verify_hash: Optional[str] = None) -> No
|
||||
console.print("[green]✓[/green] Hash verification passed")
|
||||
|
||||
|
||||
def install_model(
|
||||
model_name: str, checkpoint_dir: Path, force: bool = False
|
||||
) -> None:
|
||||
def install_model(model_name: str, checkpoint_dir: Path, force: bool = False) -> None:
|
||||
"""Install a single model checkpoint.
|
||||
|
||||
Args:
|
||||
@@ -112,9 +110,7 @@ def install_model(
|
||||
)
|
||||
|
||||
try:
|
||||
download_file(
|
||||
checkpoint_info.url, dest_path, checkpoint_info.sha256
|
||||
)
|
||||
download_file(checkpoint_info.url, dest_path, checkpoint_info.sha256)
|
||||
console.print(
|
||||
f"[green]✓[/green] Successfully installed {model_name} to {dest_path}"
|
||||
)
|
||||
@@ -158,7 +154,7 @@ def install(
|
||||
|
||||
# Expand 'all' to all available models
|
||||
if "all" in models:
|
||||
models_to_install = ['rfd3', 'proteinmpnn', 'ligandmpnn', 'rf3']
|
||||
models_to_install = ["rfd3", "proteinmpnn", "ligandmpnn", "rf3"]
|
||||
else:
|
||||
models_to_install = models
|
||||
|
||||
@@ -167,15 +163,16 @@ def install(
|
||||
install_model(model_name, checkpoint_dir, force)
|
||||
console.print()
|
||||
|
||||
set_key(
|
||||
dotenv_path=find_dotenv(),
|
||||
key_to_set='FOUNDRY_CHECKPOINTS_DIR',
|
||||
value_to_set=str(checkpoint_dir),
|
||||
export = False,
|
||||
)
|
||||
console.print(
|
||||
f"Set checkpoint installation directory to: {checkpoint_dir}"
|
||||
)
|
||||
# Try to persist checkpoint dir to .env (optional, may not exist in Colab etc.)
|
||||
dotenv_path = find_dotenv()
|
||||
if dotenv_path:
|
||||
set_key(
|
||||
dotenv_path=dotenv_path,
|
||||
key_to_set="FOUNDRY_CHECKPOINTS_DIR",
|
||||
value_to_set=str(checkpoint_dir),
|
||||
export=False,
|
||||
)
|
||||
console.print(f"Saved FOUNDRY_CHECKPOINTS_DIR to {dotenv_path}")
|
||||
|
||||
console.print("[bold green]Installation complete![/bold green]")
|
||||
|
||||
@@ -209,9 +206,7 @@ def show(
|
||||
|
||||
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
|
||||
if not checkpoint_files:
|
||||
console.print(
|
||||
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
|
||||
)
|
||||
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
console.print(f"[bold]Installed checkpoints in {checkpoint_dir}:[/bold]\n")
|
||||
@@ -247,9 +242,7 @@ def clean(
|
||||
# List files to delete
|
||||
checkpoint_files = list(checkpoint_dir.glob("*.ckpt"))
|
||||
if not checkpoint_files:
|
||||
console.print(
|
||||
f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]"
|
||||
)
|
||||
console.print(f"[yellow]No checkpoint files found in {checkpoint_dir}[/yellow]")
|
||||
raise typer.Exit(0)
|
||||
|
||||
console.print("[bold]Files to delete:[/bold]")
|
||||
|
||||
Reference in New Issue
Block a user