mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
Add non-hydra interface for instantiating model
This commit is contained in:
152
examples/all.ipynb
Normal file
152
examples/all.ipynb
Normal file
@@ -0,0 +1,152 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "819e8193",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# For now, we still need to manually add atomworks dependency\n",
|
||||
"import sys;\n",
|
||||
"sys.path.insert(0, '/home/jbutch/Projects/HT25/af3/rfd3-release/lib/atomworks/src')\n",
|
||||
"sys.path.append('/home/jbutch/Projects/HT25/af3/rfd3-release/models/rfd3/src')\n",
|
||||
"sys.path.append('/home/jbutch/Projects/HT25/af3/rfd3-release/src')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "2d422b11",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"22:35:01 WARNING atomworks.constants: Environment variable CCD_MIRROR_PATH not set. Will not be able to use function requiring this variable. To set it you may:\n",
|
||||
" (1) add the line 'export VAR_NAME=path/to/variable' to your .bashrc or .zshrc file\n",
|
||||
" (2) set it in your current shell with 'export VAR_NAME=path/to/variable'\n",
|
||||
" (3) write it to a .env file in the root of the atomworks.io repository\n",
|
||||
"22:35:01 WARNING atomworks.constants: Environment variable PDB_MIRROR_PATH not set. Will not be able to use function requiring this variable. To set it you may:\n",
|
||||
" (1) add the line 'export VAR_NAME=path/to/variable' to your .bashrc or .zshrc file\n",
|
||||
" (2) set it in your current shell with 'export VAR_NAME=path/to/variable'\n",
|
||||
" (3) write it to a .env file in the root of the atomworks.io repository\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine\n",
|
||||
"\n",
|
||||
"conf = RFD3InferenceConfig(\n",
|
||||
" ckpt_path='/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt',\n",
|
||||
" specification={\n",
|
||||
" 'length': 2\n",
|
||||
" },\n",
|
||||
" diffusion_batch_size=1,\n",
|
||||
")\n",
|
||||
"model = RFD3InferenceEngine(**conf)\n",
|
||||
"outputs = model.run(\n",
|
||||
" inputs=None,\n",
|
||||
" out_dir=None,\n",
|
||||
" n_batches=1,\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "de18ec69",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"22:35:31 INFO modelhub.inference_engines.base: [rank: 0] Seed is None - using external RNG state\n",
|
||||
"22:35:31 INFO modelhub.inference_engines.base: [rank: 0] Loading checkpoint from /projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt...\n",
|
||||
"22:36:01 ERROR modelhub.utils.ddp: No GPUs available - Setting accelerator to 'cpu'. Are you sure you are using the correct configs?\n",
|
||||
"22:36:01 INFO modelhub.inference_engines.base: [rank: 0] Building Transform pipeline...\n",
|
||||
"22:36:01 INFO modelhub.inference_engines.base: [rank: 0] Using settings from validation dataset: unconditional.\n",
|
||||
"22:36:01 INFO rdkit: Enabling RDKit 2025.03.6 jupyter extensions\n",
|
||||
"22:36:02 INFO modelhub.inference_engines.base: [rank: 0] Instantiating trainer...\n",
|
||||
"/home/jbutch/Projects/HT25/af3/rfd3-release/.venv/lib/python3.12/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/jbutch/Projects/HT25/af3/rfd3-release/.venv/li ...\n",
|
||||
"Using bfloat16 Automatic Mixed Precision (AMP)\n",
|
||||
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric general_metrics to the validation metrics...\n",
|
||||
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric backbone_metrics to the validation metrics...\n",
|
||||
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric sidechain_metrics to the validation metrics...\n",
|
||||
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric metadata_metrics to the validation metrics...\n",
|
||||
"22:36:05 INFO modelhub.metrics.metric: [rank: 0] Adding metric hbond_metrics to the validation metrics...\n",
|
||||
"22:36:05 INFO modelhub.inference_engines.base: [rank: 0] Setting up model...\n",
|
||||
"22:36:05 INFO modelhub.trainers.fabric: [rank: 0] Instantiating model...\n",
|
||||
"22:36:05 WARNING rfd3.model.layers.layer_utils: [rank: 0] Using nn.RMSNorm instead of apex.normalization.fused_layer_norm.FusedRMSNorm.Ensure you're using the correct apptainer\n",
|
||||
"Error while loading libcue_ops.so: libcuda.so.1: cannot open shared object file: No such file or directory\n",
|
||||
"22:36:05 INFO rfd3.model.RFD3: [rank: 0] RFD3 initialized with chunked_pll=False\n",
|
||||
"22:36:06 INFO rfd3.model.inference_sampler: [rank: 0] Initializing ConditionalDiffusionSampler with kind: default\n",
|
||||
"22:36:06 INFO modelhub.trainers.fabric: [rank: 0] Wrapping model with EMA...\n",
|
||||
"22:36:06 INFO modelhub.inference_engines.base: [rank: 0] Loading model weights from checkpoint...\n",
|
||||
"22:36:06 INFO modelhub.trainers.fabric: [rank: 0] Using pre-loaded checkpoint...\n",
|
||||
"22:36:06 WARNING modelhub.trainers.fabric: [rank: 0] Skipping optimizer loading...\n",
|
||||
"22:36:06 WARNING modelhub.trainers.fabric: [rank: 0] Skipping scheduler loading...\n",
|
||||
"22:36:07 INFO modelhub.trainers.fabric: [rank: 0] Loaded checkpoint. Current epoch: 570, global step: 42825\n",
|
||||
"22:36:07 INFO modelhub.inference_engines.base: [rank: 0] Model loaded and ready for inference.\n",
|
||||
"22:36:07 INFO rfd3.inference.datasets: [rank: 0] \n",
|
||||
"+----------------------------------------------+\n",
|
||||
"Dataset inference-dataset:\n",
|
||||
" - Found 1 examples:\n",
|
||||
"_0\n",
|
||||
"\n",
|
||||
"+----------------------------------------------+\n",
|
||||
"\n",
|
||||
"22:36:07 WARNING rfd3.utils.inference: [rank: 0] No ori_token, infer_ori_strategy, or motif provided. Setting [0,0,0] as origin.\n",
|
||||
"22:36:07 WARNING atomworks.io: The `extra_fields` argument will be ignored if there is no CIF file input.\n",
|
||||
"22:36:07 WARNING atomworks.ml: Cached data not found for ALA at /net/tukwila/ncorley/datahub/MACE-OFF23_medium/A/ALA/ALA.pt\n",
|
||||
"22:36:07 INFO rfd3.transforms.conditioning_base: Indexing all unindexed components\n",
|
||||
"/home/jbutch/Projects/HT25/af3/rfd3-release/.venv/lib/python3.12/site-packages/torch/nn/functional.py:2920: UserWarning: Mismatch dtype between input and weight: input dtype = c10::BFloat16, weight dtype = float, Cannot dispatch to fused implementation. (Triggered internally at /pytorch/aten/src/ATen/native/layer_norm.cpp:344.)\n",
|
||||
" return torch.rms_norm(input, normalized_shape, weight, eps)\n",
|
||||
"/mnt/home/jbutch/Projects/HT25/af3/rfd3-release/models/rfd3/src/rfd3/model/layers/blocks.py:218: UserWarning: index_reduce() is in beta and the API may change at any time. (Triggered internally at /pytorch/aten/src/ATen/native/TensorAdvancedIndexing.cpp:1517.)\n",
|
||||
" .index_reduce(\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "7bec1ffd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "d07ae413",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "rfd3-release",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -28,10 +28,13 @@ def design(ctx: typer.Context):
|
||||
|
||||
with initialize_config_dir(config_dir=config_path, version_base="1.3"):
|
||||
cfg = compose(config_name="inference", overrides=args)
|
||||
|
||||
# Lazy import to avoid loading heavy dependencies at CLI startup
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from rfd3.run_inference import run_inference
|
||||
|
||||
run_inference(cfg)
|
||||
with suppress_warnings(is_inference=True):
|
||||
run_inference(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,9 +2,10 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
@@ -19,6 +20,7 @@ from rfd3.inference.datasets import (
|
||||
assemble_distributed_inference_loader_from_json,
|
||||
)
|
||||
from rfd3.inference.input_parsing import DesignInputSpecification
|
||||
from rfd3.model.inference_sampler import SampleDiffusionConfig
|
||||
from rfd3.utils.inference import ensure_input_is_abspath
|
||||
from rfd3.utils.io import (
|
||||
CIF_LIKE_EXTENSIONS,
|
||||
@@ -33,6 +35,52 @@ logging.basicConfig(level=logging.INFO)
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class RFD3InferenceConfig:
|
||||
ckpt_path: str = "/projects/ml/aa_design/models/rfd3_latest_cleaned.ckpt"
|
||||
diffusion_batch_size: int = 16
|
||||
|
||||
# RFD3 specific
|
||||
skip_existing: bool = False
|
||||
json_keys_subset: Optional[List[str]] = None
|
||||
skip_existing: bool = True
|
||||
specification: Optional[dict] = field(default_factory=dict)
|
||||
inference_sampler: SampleDiffusionConfig | dict = field(default_factory=dict)
|
||||
|
||||
# Saving args
|
||||
cleanup_guideposts: bool = True
|
||||
cleanup_virtual_atoms: bool = True
|
||||
read_sequence_from_sequence_head: bool = True
|
||||
output_full_json: bool = True
|
||||
|
||||
# Prefix to add to all output samples
|
||||
# Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
|
||||
# Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
|
||||
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
|
||||
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
|
||||
global_prefix: Optional[str] = None
|
||||
dump_prediction_metadata_json: bool = True
|
||||
dump_trajectories: bool = False
|
||||
align_trajectory_structures: bool = False
|
||||
prevalidate_inputs: bool = True
|
||||
low_memory_mode: bool = (
|
||||
False # False for standard mode, True for memory efficient tokenization mode
|
||||
)
|
||||
|
||||
# Other:
|
||||
num_nodes: int = 1
|
||||
devices_per_node: int = 1
|
||||
print_config: bool = False
|
||||
seed: Optional[int] = None
|
||||
|
||||
# For use as mapping:
|
||||
def keys(self):
|
||||
return self.__dataclass_fields__.keys()
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class RFD3InferenceEngine(BaseInferenceEngine):
|
||||
"""Inference engine for RFdiffusion3"""
|
||||
|
||||
@@ -234,7 +282,7 @@ class RFD3InferenceEngine(BaseInferenceEngine):
|
||||
f"Invalid input type: {type(inputs)}. Expected JSON/YAML file paths, AtomArray, or DesignInputSpecification.\nInput: {inputs}"
|
||||
)
|
||||
|
||||
return design_specifications
|
||||
return inputs
|
||||
|
||||
def _multiply_specifications(
|
||||
self, inputs: Dict[str, dict | DesignInputSpecification], n_batches=None
|
||||
|
||||
@@ -162,7 +162,7 @@ class ContigJsonDataset(MolecularDataset):
|
||||
if not isinstance(spec, DesignInputSpecification):
|
||||
spec = ensure_input_is_abspath(spec, self.json_path)
|
||||
spec["cif_parser_args"] = self.cif_parser_args
|
||||
spec = DesignInputSpecification(**spec)
|
||||
spec = DesignInputSpecification.safe_init(**spec)
|
||||
|
||||
# Create pipeline input
|
||||
data = spec.to_pipeline_input(example_id=example_id)
|
||||
|
||||
@@ -15,78 +15,23 @@ from modelhub.utils.rotation_augmentation import (
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
def centre_random_augment_around_motif(
|
||||
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
|
||||
is_motif_atom_with_fixed_coord: torch.Tensor, # (D, L) indices in original coordinates to be kept constant
|
||||
s_trans: float = 1.0,
|
||||
center_option: str = "all",
|
||||
centering_affects_motif: bool = True,
|
||||
reinsert_motif=True,
|
||||
):
|
||||
D, L, _ = X_L.shape
|
||||
|
||||
if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
|
||||
# ... Align original coordinates to the prediction
|
||||
coords_with_gt_aligned = weighted_rigid_align(
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :],
|
||||
coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
|
||||
)
|
||||
|
||||
# ... Insert original coordinates into X_L
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned
|
||||
|
||||
# ... Centering
|
||||
if torch.any(is_motif_atom_with_fixed_coord):
|
||||
if center_option == "motif":
|
||||
center = torch.mean(
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
|
||||
) # (D, 1, 3) - COM of motif atoms
|
||||
elif center_option == "diffuse":
|
||||
center = torch.mean(
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
|
||||
) # (D, 1, 3) - COM of diffused atoms
|
||||
|
||||
else:
|
||||
center = torch.mean(X_L, dim=-2, keepdim=True)
|
||||
else:
|
||||
center = torch.mean(X_L, dim=-2, keepdim=True)
|
||||
|
||||
# ... Center
|
||||
if centering_affects_motif:
|
||||
X_L = X_L - center
|
||||
else:
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
|
||||
)
|
||||
|
||||
# ... Random augmentation
|
||||
R = uniform_random_rotation((D,)).to(X_L.device)
|
||||
noise = (
|
||||
torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
|
||||
) # (D, 1, 3)
|
||||
X_L = rot_vec_mul(R[:, None], X_L) + noise
|
||||
|
||||
return X_L, R
|
||||
|
||||
|
||||
@dataclass(kw_only=True)
|
||||
class SampleDiffusionWithMotif:
|
||||
"""Diffusion sampler that supports optional motif alignment."""
|
||||
class SampleDiffusionConfig:
|
||||
kind: Literal["default", "symmetry"] = "default"
|
||||
|
||||
# Standard EDM args
|
||||
num_timesteps: int # AF-3: 200
|
||||
min_t: int # AF-3: 0
|
||||
max_t: int # AF-3: 1
|
||||
sigma_data: int # AF-3: 16
|
||||
s_min: float # AF-3: 4e-4
|
||||
s_max: int # AF-3: 160
|
||||
p: int # AF-3: 7
|
||||
gamma_0: float # AF-3: 0.8
|
||||
gamma_min: float # AF-3: 1.0
|
||||
noise_scale: float # AF-3: 1.003
|
||||
step_scale: float # AF-3: 1.5
|
||||
solver: Literal["af3"]
|
||||
num_timesteps: int = 200
|
||||
min_t: int = 0
|
||||
max_t: int = 1
|
||||
sigma_data: int = 16
|
||||
s_min: float = 4e-4
|
||||
s_max: int = 160
|
||||
p: int = 7
|
||||
gamma_0: float = 0.6
|
||||
gamma_min: float = 1.0
|
||||
noise_scale: float = 1.003
|
||||
step_scale: float = 1.5
|
||||
solver: Literal["af3"] = "af3"
|
||||
|
||||
# RFD3 / design args
|
||||
center_option: str = "all"
|
||||
@@ -99,8 +44,10 @@ class SampleDiffusionWithMotif:
|
||||
use_classifier_free_guidance: bool = False
|
||||
cfg_scale: float = 2.0
|
||||
cfg_t_max: float | None = None
|
||||
use_frame_guidance: bool = False
|
||||
fg_scale: float = 1.0
|
||||
|
||||
|
||||
class SampleDiffusionWithMotif(SampleDiffusionConfig):
|
||||
"""Diffusion sampler that supports optional motif alignment."""
|
||||
|
||||
def _construct_inference_noise_schedule(
|
||||
self, device: torch.device, partial_t: float = None
|
||||
@@ -340,11 +287,6 @@ class SampleDiffusionWithMotif:
|
||||
# apply CFG
|
||||
delta_L = delta_L + (self.cfg_scale - 1) * (delta_L - delta_L_ref)
|
||||
|
||||
if self.use_frame_guidance:
|
||||
X_L_ref_frame = outs.get("X_L_ref_frame")
|
||||
delta_L_ref = (X_noisy_L - X_L_ref_frame) / t_hat
|
||||
delta_L = delta_L + (self.fg_scale - 1) * (delta_L - delta_L_ref)
|
||||
|
||||
if exists(outs.get("sequence_logits_I")):
|
||||
# Compute confidence
|
||||
p = torch.softmax(
|
||||
@@ -636,3 +578,58 @@ class ConditionalDiffusionSampler:
|
||||
[param.name for param in signature.parameters.values()]
|
||||
)
|
||||
return arg_names
|
||||
|
||||
|
||||
def centre_random_augment_around_motif(
|
||||
X_L: torch.Tensor, # (D, L, 3) noisy diffused coordinates
|
||||
coord_atom_lvl_to_be_noised: torch.Tensor, # (D, L, 3) original coordinates
|
||||
is_motif_atom_with_fixed_coord: torch.Tensor, # (D, L) indices in original coordinates to be kept constant
|
||||
s_trans: float = 1.0,
|
||||
center_option: str = "all",
|
||||
centering_affects_motif: bool = True,
|
||||
reinsert_motif=True,
|
||||
):
|
||||
D, L, _ = X_L.shape
|
||||
|
||||
if reinsert_motif and torch.any(is_motif_atom_with_fixed_coord):
|
||||
# ... Align original coordinates to the prediction
|
||||
coords_with_gt_aligned = weighted_rigid_align(
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :],
|
||||
coord_atom_lvl_to_be_noised[..., is_motif_atom_with_fixed_coord, :],
|
||||
)
|
||||
|
||||
# ... Insert original coordinates into X_L
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :] = coords_with_gt_aligned
|
||||
|
||||
# ... Centering
|
||||
if torch.any(is_motif_atom_with_fixed_coord):
|
||||
if center_option == "motif":
|
||||
center = torch.mean(
|
||||
X_L[..., is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
|
||||
) # (D, 1, 3) - COM of motif atoms
|
||||
elif center_option == "diffuse":
|
||||
center = torch.mean(
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :], dim=-2, keepdim=True
|
||||
) # (D, 1, 3) - COM of diffused atoms
|
||||
|
||||
else:
|
||||
center = torch.mean(X_L, dim=-2, keepdim=True)
|
||||
else:
|
||||
center = torch.mean(X_L, dim=-2, keepdim=True)
|
||||
|
||||
# ... Center
|
||||
if centering_affects_motif:
|
||||
X_L = X_L - center
|
||||
else:
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :] = (
|
||||
X_L[..., ~is_motif_atom_with_fixed_coord, :] - center
|
||||
)
|
||||
|
||||
# ... Random augmentation
|
||||
R = uniform_random_rotation((D,)).to(X_L.device)
|
||||
noise = (
|
||||
torch.normal(mean=0, std=1, size=(D, 1, 3), device=X_L.device) * s_trans
|
||||
) # (D, 1, 3)
|
||||
X_L = rot_vec_mul(R[:, None], X_L) + noise
|
||||
|
||||
return X_L, R
|
||||
|
||||
@@ -5,10 +5,9 @@ import os
|
||||
import hydra
|
||||
import rootutils
|
||||
from dotenv import load_dotenv
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig, OmegaConf
|
||||
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
|
||||
|
||||
# Setup root dir and environment variables (more info: https://github.com/ashleve/rootutils)
|
||||
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
||||
@@ -20,6 +19,24 @@ load_dotenv(override=True)
|
||||
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rfd3/configs")
|
||||
|
||||
|
||||
# def run_inference_without_hydra(
|
||||
# inputs,
|
||||
# out_dir,
|
||||
# n_batches,
|
||||
# **kwargs
|
||||
# ) -> None:
|
||||
|
||||
# # Create config
|
||||
# from rfd3.engine import RFD3InferenceConfig, RFD3InferenceEngine
|
||||
# conf = RFD3InferenceConfig(**kwargs)
|
||||
# with RFD3InferenceEngine(**conf) as engine:
|
||||
# return engine.run(
|
||||
# inputs=inputs,
|
||||
# out_dir=out_dir,
|
||||
# n_batches=n_batches
|
||||
# )
|
||||
|
||||
|
||||
@hydra.main(
|
||||
config_path=_config_path,
|
||||
config_name="inference",
|
||||
@@ -35,11 +52,22 @@ def run_inference(cfg: DictConfig) -> None:
|
||||
cfg_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
init_cfg_dict = {k: v for k, v in cfg_dict.items() if k not in run_params_set}
|
||||
init_cfg = OmegaConf.create(init_cfg_dict)
|
||||
inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
|
||||
|
||||
# # Run inference
|
||||
with suppress_warnings(is_inference=True):
|
||||
inference_engine.run(**run_params)
|
||||
# Run
|
||||
init_cfg_dict = {k: v for k, v in init_cfg_dict.items() if k not in ["_target_"]}
|
||||
init_cfg = RFD3InferenceConfig(**init_cfg_dict)
|
||||
engine = RFD3InferenceEngine(**init_cfg)
|
||||
|
||||
import ipdb
|
||||
|
||||
ipdb.set_trace()
|
||||
engine.run(**run_params)
|
||||
|
||||
# inference_engine = instantiate(init_cfg, _convert_="partial", _recursive_=False)
|
||||
|
||||
# # # Run inference
|
||||
# with suppress_warnings(is_inference=True):
|
||||
# inference_engine.run(**run_params)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -44,7 +44,8 @@ dependencies = [
|
||||
"zstandard",
|
||||
"pandas",
|
||||
# "biotite",
|
||||
"atomworks"
|
||||
"atomworks",
|
||||
"ipykernel>=6.31.0",
|
||||
]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user