Add non-hydra interface for instantiating model

This commit is contained in:
jbutch
2025-11-23 22:38:39 -08:00
parent 22407d510c
commit 1251f0fcf1
7 changed files with 316 additions and 87 deletions

152
examples/all.ipynb Normal file
View File

@@ -0,0 +1,152 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "819e8193",
"metadata": {},
"outputs": [],
"source": [
"# For now, we still need to manually add atomworks dependency\n",
"import sys;\n",
"sys.path.insert(0, '/home/jbutch/Projects/HT25/af3/rfd3-release/lib/atomworks/src')\n",
"sys.path.append('/home/jbutch/Projects/HT25/af3/rfd3-release/models/rfd3/src')\n",
"sys.path.append('/home/jbutch/Projects/HT25/af3/rfd3-release/src')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2d422b11",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22:35:01 WARNING atomworks.constants: Environment variable CCD_MIRROR_PATH not set. Will not be able to use function requiring this variable. To set it you may:\n",
" (1) add the line 'export VAR_NAME=path/to/variable' to your .bashrc or .zshrc file\n",
" (2) set it in your current shell with 'export VAR_NAME=path/to/variable'\n",
" (3) write it to a .env file in the root of the atomworks.io repository\n",
"22:35:01 WARNING atomworks.constants: Environment variable PDB_MIRROR_PATH not set. Will not be able to use function requiring this variable. To set it you may:\n",
" (1) add the line 'export VAR_NAME=path/to/variable' to your .bashrc or .zshrc file\n",
" (2) set it in your current shell with 'export VAR_NAME=path/to/variable'\n",
" (3) write it to a .env file in the root of the atomworks.io repository\n"
]
}
],
"source": [
"from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine\n",
"\n",
"conf = RFD3InferenceConfig(\n",
" ckpt_path='/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt',\n",
" specification={\n",
" 'length': 2\n",
" },\n",
" diffusion_batch_size=1,\n",
")\n",
"model = RFD3InferenceEngine(**conf)\n",
"outputs = model.run(\n",
" inputs=None,\n",
" out_dir=None,\n",
" n_batches=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "de18ec69",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"22:35:31 INFO modelhub.inference_engines.base: [rank: 0] Seed is None - using external RNG state\n",
"22:35:31 INFO modelhub.inference_engines.base: [rank: 0] Loading checkpoint from /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt...\n",
"22:36:01 ERROR modelhub.utils.ddp: No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?\n",
"22:36:01 INFO modelhub.inference_engines.base: [rank: 0] Building Transform pipeline...\n",
"22:36:01 INFO modelhub.inference_engines.base: [rank: 0] Using settings from validation dataset: unconditional.\n",
"22:36:01 INFO rdkit: Enabling RDKit 2025.03.6 jupyter extensions\n",
"22:36:02 INFO modelhub.inference_engines.base: [rank: 0] Instantiating trainer...\n",
"/home/jbutch/Projects/HT25/af3/rfd3-release/.venv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jbutch/Projects/HT25/af3/rfd3-release/.venv/li ...\n",
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric general_metrics to the validation metrics...\n",
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric backbone_metrics to the validation metrics...\n",
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric sidechain_metrics to the validation metrics...\n",
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric metadata_metrics to the validation metrics...\n",
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric hbond_metrics to the validation metrics...\n",
"22:36:05 INFO modelhub.inference_engines.base: [rank: 0] Setting up model...\n",
"22:36:05 INFO modelhub.trainers.fabric: [rank: 0] Instantiating model...\n",
"22:36:05 WARNING rfd3.model.layers.layer_utils: [rank: 0] Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm.Ensure you're using the correct apptainer\n",
"Error while loading libcue_ops.so: libcuda.so.1: cannot open shared object file: No such file or directory\n",
"22:36:05 INFO rfd3.model.RFD3: [rank: 0] RFD3 initialized with chunked_pll=False\n",
"22:36:06 INFO rfd3.model.inference_sampler: [rank: 0] Initializing ConditionalDiffusionSampler with kind: default\n",
"22:36:06 INFO modelhub.trainers.fabric: [rank: 0] Wrapping model with EMA...\n",
"22:36:06 INFO modelhub.inference_engines.base: [rank: 0] Loading model weights from checkpoint...\n",
"22:36:06 INFO modelhub.trainers.fabric: [rank: 0] Using pre-loaded checkpoint...\n",
"22:36:06 WARNING modelhub.trainers.fabric: [rank: 0] Skipping optimizer loading...\n",
"22:36:06 WARNING modelhub.trainers.fabric: [rank: 0] Skipping scheduler loading...\n",
"22:36:07 INFO modelhub.trainers.fabric: [rank: 0] Loaded checkpoint. Current epoch: 570, global step: 42825\n",
"22:36:07 INFO modelhub.inference_engines.base: [rank: 0] Model loaded and ready for inference.\n",
"22:36:07 INFO rfd3.inference.datasets: [rank: 0] \n",
"+----------------------------------------------+\n",
"Dataset inference-dataset:\n",
" - Found 1 examples:\n",
"_0\n",
"\n",
"+----------------------------------------------+\n",
"\n",
"22:36:07 WARNING rfd3.utils.inference: [rank: 0] No ori_token, infer_ori_strategy, or motif provided. Setting [0,0,0] as origin.\n",
"22:36:07 WARNING atomworks.io: The `extra_fields` argument will be ignored if there is no CIF file input.\n",
"22:36:07 WARNING atomworks.ml: Cached data not found for ALA at /net/tukwila/ncorley/datahub/MACE-OFF23_medium/A/ALA/ALA.pt\n",
"22:36:07 INFO rfd3.transforms.conditioning_base: Indexing all unindexed components\n",
"/home/jbutch/Projects/HT25/af3/rfd3-release/.venv/lib/python3.12/site-packages/torch/nn/functional.py:2920: UserWarning: Mismatch dtype between input and weight: input dtype = c10::BFloat16, weight dtype = float, Cannot dispatch to fused implementation. (Triggered internally at /pytorch/aten/src/ATen/native/layer_norm.cpp:344.)\n",
" return torch.rms_norm(input, normalized_shape, weight, eps)\n",
"/mnt/home/jbutch/Projects/HT25/af3/rfd3-release/models/rfd3/src/rfd3/model/layers/blocks.py:218: UserWarning: index_reduce() is in beta and the API may change at any time. (Triggered internally at /pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1517.)\n",
" .index_reduce(\n"
]
}
],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "7bec1ffd",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "d07ae413",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "rfd3-release",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View File

@@ -28,10 +28,13 @@ def design(ctx: typer.Context):
with initialize_config_dir(config_dir=config_path, version_base="1.3"):
cfg = compose(config_name="inference", overrides=args)
# Lazy import to avoid loading heavy dependencies at CLI startup
from modelhub.utils.logging import suppress_warnings
from rfd3.run_inference import run_inference
run_inference(cfg)
with suppress_warnings(is_inference=True):
run_inference(cfg)
if __name__ == "__main__":

View File

@@ -2,9 +2,10 @@ import json
import logging
import os
import time
from dataclasses import dataclass, field
from os import PathLike
from pathlib import Path
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
import torch
import yaml
@@ -19,6 +20,7 @@ from rfd3.inference.datasets import (
assemble_distributed_inference_loader_from_json,
)
from rfd3.inference.input_parsing import DesignInputSpecification
from rfd3.model.inference_sampler import SampleDiffusionConfig
from rfd3.utils.inference import ensure_input_is_abspath
from rfd3.utils.io import (
CIF_LIKE_EXTENSIONS,
@@ -33,6 +35,52 @@ logging.basicConfig(level=logging.INFO)
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
@dataclass(kw_only=True)
class RFD3InferenceConfig:
ckpt_path: str = "/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt"
diffusion_batch_size: int = 16
# RFD3 specific
skip_existing: bool = False
json_keys_subset: Optional[List[str]] = None
skip_existing: bool = True
specification: Optional[dict] = field(default_factory=dict)
inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
# Saving args
cleanup_guideposts: bool = True
cleanup_virtual_atoms: bool = True
read_sequence_from_sequence_head: bool = True
output_full_json: bool = True
# Prefix to add to all output samples
# Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
# Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
global_prefix: Optional[str] = None
dump_prediction_metadata_json: bool = True
dump_trajectories: bool = False
align_trajectory_structures: bool = False
prevalidate_inputs: bool = True
low_memory_mode: bool = (
False # False for standard mode, True for memory efficient tokenization mode
)
# Other:
num_nodes: int = 1
devices_per_node: int = 1
print_config: bool = False
seed: Optional[int] = None
# For use as mapping:
def keys(self):
return self.__dataclass_fields__.keys()
def __getitem__(self, key):
return getattr(self, key)
class RFD3InferenceEngine(BaseInferenceEngine):
"""Inference engine for RFdiffusion3"""
@@ -234,7 +282,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
)
return design_specifications
return inputs
def _multiply_specifications(
self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None

View File

@@ -162,7 +162,7 @@ class ContigJsonDataset(MolecularDataset):
if not isinstance(spec, DesignInputSpecification):
spec = ensure_input_is_abspath(spec, self.json_path)
spec["cif_parser_args"] = self.cif_parser_args
spec = DesignInputSpecification(**spec)
spec = DesignInputSpecification.safe_init(**spec)
# Create pipeline input
data = spec.to_pipeline_input(example_id=example_id)

View File

@@ -15,78 +15,23 @@ from modelhub.utils.rotation_augmentation import (
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
def centre_random_augment_around_motif(
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
is_motif_atom_with_fixed_coord: torch.Tensor, # (D, L) indices in original coordinates to be kept constant
s_trans: float = 1.0,
center_option: str = "all",
centering_affects_motif: bool = True,
reinsert_motif=True,
):
D, L, _ = X_L.shape
if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
# ... Align original coordinates to the prediction
coords_with_gt_aligned = weighted_rigid_align(
X_L[..., is_motif_atom_with_fixed_coord, :],
coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
)
# ... Insert original coordinates into X_L
X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned
# ... Centering
if torch.any(is_motif_atom_with_fixed_coord):
if center_option == "motif":
center = torch.mean(
X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
) # (D, 1, 3) - COM of motif atoms
elif center_option == "diffuse":
center = torch.mean(
X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
) # (D, 1, 3) - COM of diffused atoms
else:
center = torch.mean(X_L, dim=-2, keepdim=True)
else:
center = torch.mean(X_L, dim=-2, keepdim=True)
# ... Center
if centering_affects_motif:
X_L = X_L - center
else:
X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
)
# ... Random augmentation
R = uniform_random_rotation((D,)).to(X_L.device)
noise = (
torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
) # (D, 1, 3)
X_L = rot_vec_mul(R[:, None], X_L) + noise
return X_L, R
@dataclass(kw_only=True)
class SampleDiffusionWithMotif:
"""Diffusion sampler that supports optional motif alignment."""
class SampleDiffusionConfig:
kind: Literal["default", "symmetry"] = "default"
# Standard EDM args
num_timesteps: int # AF-3: 200
min_t: int # AF-3: 0
max_t: int # AF-3: 1
sigma_data: int # AF-3: 16
s_min: float # AF-3: 4e-4
s_max: int # AF-3: 160
p: int # AF-3: 7
gamma_0: float # AF-3: 0.8
gamma_min: float # AF-3: 1.0
noise_scale: float # AF-3: 1.003
step_scale: float # AF-3: 1.5
solver: Literal["af3"]
num_timesteps: int = 200
min_t: int = 0
max_t: int = 1
sigma_data: int = 16
s_min: float = 4e-4
s_max: int = 160
p: int = 7
gamma_0: float = 0.6
gamma_min: float = 1.0
noise_scale: float = 1.003
step_scale: float = 1.5
solver: Literal["af3"] = "af3"
# RFD3 / design args
center_option: str = "all"
@@ -99,8 +44,10 @@ class SampleDiffusionWithMotif:
use_classifier_free_guidance: bool = False
cfg_scale: float = 2.0
cfg_t_max: float | None = None
use_frame_guidance: bool = False
fg_scale: float = 1.0
class SampleDiffusionWithMotif(SampleDiffusionConfig):
"""Diffusion sampler that supports optional motif alignment."""
def _construct_inference_noise_schedule(
self, device: torch.device, partial_t: float = None
@@ -340,11 +287,6 @@ class SampleDiffusionWithMotif:
# apply CFG
delta_L = delta_L + (self.cfg_scale - 1) * (delta_L - delta_L_ref)
if self.use_frame_guidance:
X_L_ref_frame = outs.get("X_L_ref_frame")
delta_L_ref = (X_noisy_L - X_L_ref_frame) / t_hat
delta_L = delta_L + (self.fg_scale - 1) * (delta_L - delta_L_ref)
if exists(outs.get("sequence_logits_I")):
# Compute confidence
p = torch.softmax(
@@ -636,3 +578,58 @@ class ConditionalDiffusionSampler:
[param.name for param in signature.parameters.values()]
)
return arg_names
def centre_random_augment_around_motif(
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
is_motif_atom_with_fixed_coord: torch.Tensor, # (D, L) indices in original coordinates to be kept constant
s_trans: float = 1.0,
center_option: str = "all",
centering_affects_motif: bool = True,
reinsert_motif=True,
):
D, L, _ = X_L.shape
if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
# ... Align original coordinates to the prediction
coords_with_gt_aligned = weighted_rigid_align(
X_L[..., is_motif_atom_with_fixed_coord, :],
coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
)
# ... Insert original coordinates into X_L
X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned
# ... Centering
if torch.any(is_motif_atom_with_fixed_coord):
if center_option == "motif":
center = torch.mean(
X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
) # (D, 1, 3) - COM of motif atoms
elif center_option == "diffuse":
center = torch.mean(
X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
) # (D, 1, 3) - COM of diffused atoms
else:
center = torch.mean(X_L, dim=-2, keepdim=True)
else:
center = torch.mean(X_L, dim=-2, keepdim=True)
# ... Center
if centering_affects_motif:
X_L = X_L - center
else:
X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
)
# ... Random augmentation
R = uniform_random_rotation((D,)).to(X_L.device)
noise = (
torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
) # (D, 1, 3)
X_L = rot_vec_mul(R[:, None], X_L) + noise
return X_L, R

View File

@@ -5,10 +5,9 @@ import os
import hydra
import rootutils
from dotenv import load_dotenv
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from modelhub.utils.logging import suppress_warnings
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
@@ -20,6 +19,24 @@ load_dotenv(override=True)
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
# def run_inference_without_hydra(
# inputs,
# out_dir,
# n_batches,
# **kwargs
# ) -> None:
# # Create config
# from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
# conf = RFD3InferenceConfig(**kwargs)
# with RFD3InferenceEngine(**conf) as engine:
# return engine.run(
# inputs=inputs,
# out_dir=out_dir,
# n_batches=n_batches
# )
@hydra.main(
config_path=_config_path,
config_name="inference",
@@ -35,11 +52,22 @@ def run_inference(cfg: DictConfig) -> None:
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_params_set}
init_cfg = OmegaConf.create(init_cfg_dict)
inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
# # Run inference
with suppress_warnings(is_inference=True):
inference_engine.run(**run_params)
# Run
init_cfg_dict = {k: v for k, v in init_cfg_dict.items() if k not in ["_target_"]}
init_cfg = RFD3InferenceConfig(**init_cfg_dict)
engine = RFD3InferenceEngine(**init_cfg)
import ipdb
ipdb.set_trace()
engine.run(**run_params)
# inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
# # # Run inference
# with suppress_warnings(is_inference=True):
# inference_engine.run(**run_params)
if __name__ == "__main__":

View File

@@ -44,7 +44,8 @@ dependencies = [
"zstandard",
"pandas",
# "biotite",
"atomworks"
"atomworks",
"ipykernel>=6.31.0",
]