From 592de5b48859dc84fd4d83986c7a1fc796cff89b Mon Sep 17 00:00:00 2001 From: ncorley Date: Wed, 1 Oct 2025 23:52:01 -0700 Subject: [PATCH] fix: working rf3 --- README.md | 63 +- RESTRUCTURING_PLAN.md | 919 ------------------ models/rf3/CONTAINER.md | 89 ++ models/rf3/README.md | 63 +- .../callbacks/dump_validation_structures.yaml | 2 +- .../configs/callbacks/metrics_logging.yaml | 4 +- .../datasets/pdb_and_distillation.yaml | 2 +- models/rf3/configs/datasets/pdb_only.yaml | 2 +- .../datasets/train/disorder_distillation.yaml | 7 +- .../datasets/train/domain_distillation.yaml | 7 +- .../datasets/train/monomer_distillation.yaml | 7 +- .../train/na_complex_distillation.yaml | 7 +- .../train/pdb/af3_weighted_sampling.yaml | 2 +- .../rf3/configs/datasets/train/pdb/base.yaml | 6 +- .../configs/datasets/train/pdb/plinder.yaml | 4 +- .../datasets/train/pdb/train_interface.yaml | 2 +- .../datasets/train/pdb/train_pn_unit.yaml | 2 +- .../train/rna_monomer_distillation.yaml | 7 +- .../rf3/configs/datasets/val/af3_ab_set.yaml | 5 +- .../configs/datasets/val/af3_validation.yaml | 5 +- models/rf3/configs/datasets/val/base.yaml | 5 +- .../configs/datasets/val/runs_and_poses.yaml | 5 +- .../configs/experiment/pretrained/rf3.yaml | 2 +- .../pretrained/rf3_with_confidence.yaml | 2 +- .../experiment/quick-rf3-with-confidence.yaml | 2 +- models/rf3/configs/experiment/quick-rf3.yaml | 2 +- models/rf3/configs/inference_engine/rf3.yaml | 10 +- .../rf3/configs/model/components/rf3_net.yaml | 2 +- .../rf3_net_with_confidence_head.yaml | 2 +- .../configs/model/rf3_with_confidence.yaml | 2 +- models/rf3/configs/model/schedulers/af3.yaml | 2 +- models/rf3/configs/paths/data/default.yaml | 23 +- .../trainer/loss/losses/confidence_loss.yaml | 2 +- .../trainer/loss/losses/diffusion_loss.yaml | 2 +- .../trainer/loss/losses/distogram_loss.yaml | 2 +- .../trainer/metrics/structure_prediction.yaml | 14 +- models/rf3/configs/trainer/rf3.yaml | 2 +- .../configs/trainer/rf3_with_confidence.yaml | 8 +- 
models/rf3/pyproject.toml | 67 ++ models/rf3/rf3-dev.def | 61 ++ models/rf3/src/rf3/__init__.py | 3 + models/rf3/src/rf3/_version.py | 34 + .../callbacks/dump_validation_structures.py | 4 +- .../rf3/src/rf3/callbacks/metrics_logging.py | 6 +- models/rf3/src/rf3/cli.py | 9 +- .../rf3/src/rf3/data/ground_truth_template.py | 2 +- models/rf3/src/rf3/data/paired_msa.py | 6 +- .../diffusion_samplers/inference_sampler.py | 2 +- models/rf3/src/rf3/inference.py | 11 +- .../rf3/src/rf3/inference_engines/__init__.py | 5 + models/rf3/src/rf3/inference_engines/rf3.py | 9 +- models/rf3/src/rf3/kinematics.py | 36 - .../rf3/src/rf3}/metrics/chiral.py | 2 +- models/rf3/src/rf3/metrics/clashing_chains.py | 2 +- models/rf3/src/rf3/metrics/distogram.py | 4 +- models/rf3/src/rf3/metrics/lddt.py | 10 +- models/rf3/src/rf3/metrics/metadata.py | 2 +- models/rf3/src/rf3/metrics/predicted_error.py | 2 +- .../rf3/src/rf3}/metrics/rasa.py | 2 +- .../rf3/src/rf3/metrics/selected_distances.py | 2 +- .../src/rf3/model/layers/Attention_module.py | 2 +- .../layers/FusedTriangleMultiplication.py | 2 +- .../model/layers/af3_diffusion_transformer.py | 2 +- models/rf3/src/rf3/train.py | 19 +- models/rf3/src/rf3/trainers/rf3.py | 10 +- models/rf3/src/rf3/training/checkpoint.py | 31 +- models/rf3/src/rf3/utils/datasets.py | 4 +- models/rf3/src/rf3/utils/io.py | 2 +- models/rf3/src/rf3/validate.py | 18 +- models/rf3/tests/conftest.py | 7 + models/rf3/tests/data/5vht_from_file.cif | 4 +- models/rf3/tests/data/5vht_from_json.json | 4 +- .../rf3/tests/test_chiral_metrics.py | 2 +- pyproject.toml | 24 +- src/modelhub/__init__.py | 2 +- src/modelhub/callbacks/__init__.py | 5 + src/modelhub/callbacks/health_logging.py | 2 +- src/modelhub/callbacks/timing_logging.py | 4 +- src/modelhub/callbacks/train_logging.py | 10 +- src/modelhub/inference_engines/base.py | 15 - src/modelhub/metrics/__init__.py | 12 + src/modelhub/metrics/{base.py => metric.py} | 2 +- src/modelhub/trainers/fabric.py | 8 +- 
src/modelhub/utils/instantiators.py | 2 +- src/modelhub/utils/logging.py | 2 +- src/modelhub/utils/torch.py | 4 +- src/modelhub/utils/weights.py | 2 +- tests/test_torch_utils.py | 2 +- tests/test_weight_loading.py | 2 +- 89 files changed, 554 insertions(+), 1226 deletions(-) delete mode 100644 RESTRUCTURING_PLAN.md create mode 100644 models/rf3/CONTAINER.md create mode 100644 models/rf3/pyproject.toml create mode 100644 models/rf3/rf3-dev.def create mode 100644 models/rf3/src/rf3/__init__.py create mode 100644 models/rf3/src/rf3/_version.py create mode 100644 models/rf3/src/rf3/inference_engines/__init__.py rename {src/modelhub => models/rf3/src/rf3}/metrics/chiral.py (99%) rename {src/modelhub => models/rf3/src/rf3}/metrics/rasa.py (99%) rename tests/test_metrics.py => models/rf3/tests/test_chiral_metrics.py (97%) create mode 100644 src/modelhub/callbacks/__init__.py delete mode 100644 src/modelhub/inference_engines/base.py create mode 100644 src/modelhub/metrics/__init__.py rename src/modelhub/metrics/{base.py => metric.py} (99%) diff --git a/README.md b/README.md index 747f5c7..d8e9bc4 100644 --- a/README.md +++ b/README.md @@ -24,7 +24,7 @@ For more information, please see our preprint, [Accelerating Biomolecular Modeli > [!TIP] -> Complete inference instructions for RF3 are provided [here](src/modelhub/inference_engines/README.md). +> Complete inference instructions for RF3 are provided [here](models/rf3/README.md). ### RF3 Quick Start - Installation & Usage @@ -32,7 +32,7 @@ Follow these steps to set up **ModelForge** and run a test prediction with **RF3 --- -#### 1. Install the repository using `uv` +#### 1. Install the source repository and RF3 model using `uv` ```bash git clone https://github.com/RosettaCommons/modelforge.git \ @@ -40,9 +40,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \ && uv python install 3.12 \ && uv venv --python 3.12 \ && source .venv/bin/activate \ - && uv pip install -e . 
+ && uv pip install -e ./models/rf3 ``` +> [!NOTE] +> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency. + #### 2. Download model weights for RF3 ```bash wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt @@ -53,4 +56,56 @@ wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt rf3 fold tests/data/5vht_from_json.json ``` -Details on the exact formatting of the json files are available [here](src/modelhub/inference_engines/README.md). \ No newline at end of file +Details on the exact formatting of the json files are available [here](models/rf3/README.md). + +## Development + +### Package Structure + +ModelForge uses a multi-package architecture: + +- **`modelhub`**: Core package containing shared utilities, training infrastructure, and base classes +- **`models/rf3/`**: RF3 model package with model-specific code and dependencies +- **`models/<model_name>/`**: Additional models can be added as separate packages + +### Installation Options + +#### For Users (Single Model) + +Install only the model you need. Dependencies (including `modelhub`) are automatically installed: + +```bash +# Install RF3 (includes modelhub automatically) +uv pip install -e ./models/rf3 + +# Future: Install other models +# uv pip install -e ./models/other_model +``` + +#### For Core Developers (Multiple Packages) + +Install both `modelhub` and models in editable mode for development: + +```bash +# Install modelhub and RF3 in editable mode +uv pip install -e . -e ./models/rf3 + +# Or install only modelhub (no models) +uv pip install -e . +``` + +This approach allows you to: +- Modify `modelhub` shared utilities and see changes immediately +- Work on specific models without installing all models +- Add new models as independent packages in `models/` + +### Adding New Models + +To add a new model: + +1. Create `models/<model_name>/` directory with its own `pyproject.toml` +2. Add `modelhub` as a dependency +3. Implement model-specific code in `models/<model_name>/src/` +4. 
Users can install with: `uv pip install -e ./models/` + +## Development \ No newline at end of file diff --git a/RESTRUCTURING_PLAN.md b/RESTRUCTURING_PLAN.md deleted file mode 100644 index c24668c..0000000 --- a/RESTRUCTURING_PLAN.md +++ /dev/null @@ -1,919 +0,0 @@ -# ModelForge Repository Restructuring Plan - -**Date**: October 1, 2025 -**Author**: Analysis based on PyTorch Lightning patterns -**Goal**: Reorganize `src/modelhub/` following best practices while keeping `releases/` structure intact - ---- - -## Executive Summary - -This plan addresses three critical issues in the current structure: -1. **Broken imports**: RF3 uses `from metrics.base import Metric` but base classes are in `src/modelhub/` -2. **CLI entry point mismatch**: `pyproject.toml` points to non-existent `modelhub.cli:app` -3. **Package installation confusion**: Only `src/modelhub` is packaged, but RF3 code is in `releases/` -4. **Unclear organization**: Mixed base classes and implementations without clear semantic separation - -**Core Principle**: Follow PyTorch Lightning's pattern - base classes live WITH their implementations, no separate `contrib/` directory. - ---- - -## Current Structure Analysis - -### What Exists Now -``` -modelhub_latent/ -├── pyproject.toml # name = "rf3" (wrong!) -├── src/ -│ └── modelhub/ # Only this is packaged -│ ├── callbacks/ -│ │ ├── base.py # BaseCallback -│ │ ├── health_logging.py # Implementation -│ │ ├── timing_logging.py # Implementation -│ │ └── train_logging.py # Implementation -│ ├── metrics/ -│ │ ├── base.py # Metric + MetricManager -│ │ ├── chiral.py # Implementation -│ │ └── rasa.py # Implementation -│ ├── utils/ -│ ├── trainers/ -│ └── hydra/ -├── releases/ -│ └── rf3/ # NOT packaged! 
-│ ├── src/rf3/ -│ │ ├── callbacks/ -│ │ ├── metrics/ -│ │ ├── model/ -│ │ └── cli.py -│ ├── configs/ -│ └── tests/ -└── tests/ # Shared tests -``` - -### Critical Problems - -#### Problem 1: Import Path Mismatch -**RF3 code uses**: -```python -from metrics.base import Metric # Expects local resolution -from callbacks.base import BaseCallback # Expects local resolution -``` - -**Reality**: -- Base classes are in `src/modelhub/metrics/base.py` -- This only works via sys.path manipulation or broken imports - -**Should be**: -```python -from modelhub.metrics.metric import Metric -from modelhub.callbacks.callback import BaseCallback -``` - -#### Problem 2: CLI Entry Point -**Current `pyproject.toml`**: -```toml -[project.scripts] -rf3 = "modelhub.cli:app" -``` - -**Reality**: There is NO `src/modelhub/cli.py` file! - -**Actual CLI**: Located at `releases/rf3/src/rf3/cli.py` - -#### Problem 3: Package Name Confusion -**Current**: `name = "rf3"` in pyproject.toml - -**Problem**: -- Repository is called "modelforge" -- Users would `pip install rf3` (gets only base framework) -- Can't do `pip install modelforge[rf3]` for selective installation - -#### Problem 4: Packaging Scope -**Current**: `packages = ["src/modelhub"]` - -**Problem**: RF3 code in `releases/rf3/src/rf3/` is NOT included in the package! - ---- - -## Design Goals - -### User Experience Goals -1. `pip install modelforge` → Get base framework only -2. `pip install modelforge[rf3]` → Get base + RF3 -3. `pip install modelforge[all]` → Get all models (future-proof) -4. Development: `pip install -e .[rf3,dev]` - -### Code Organization Goals -1. Follow PyTorch Lightning patterns (familiar to ML researchers) -2. Clear distinction between base classes and implementations -3. Proper namespacing for imports -4. 
Maintain `releases/` structure for development workflow - -### Import Pattern Goals -```python -# Clear, unambiguous imports -from modelforge.callbacks.callback import BaseCallback -from modelforge.metrics.metric import Metric -from modelforge.utils.ddp import RankedLogger - -# RF3-specific imports -from modelforge.models.rf3.model import RF3 -``` - ---- - -## PyTorch Lightning Pattern Analysis - -### Lightning's Structure (Reference) -``` -lightning/pytorch/ -├── core/ # Primary model abstractions -│ ├── module.py # LightningModule (what users inherit from) -│ ├── datamodule.py # LightningDataModule -│ └── hooks.py -├── callbacks/ # Base + implementations together -│ ├── callback.py # Callback base class -│ ├── model_checkpoint.py # ModelCheckpoint implementation -│ ├── early_stopping.py # EarlyStopping implementation -│ └── __init__.py # Exports everything -├── loggers/ # Base + implementations together -│ ├── logger.py # Logger base class -│ ├── wandb.py # WandbLogger implementation -│ └── __init__.py -└── utilities/ # Pure utilities -``` - -### Key Insights from Lightning - -1. **`core/` is for PRIMARY model abstractions** - Things users inherit from to define their model: - - `LightningModule` - "Your model IS A LightningModule" - - `LightningDataModule` - "Your data IS A LightningDataModule" - -2. **Other directories contain base + implementations together**: - - `callbacks/callback.py` (base) + `callbacks/model_checkpoint.py` (impl) - - `loggers/logger.py` (base) + `loggers/wandb.py` (impl) - -3. **NO `contrib/` directory** - That's a Django pattern, not Lightning! - -4. 
**`__init__.py` exports everything** for clean imports - -### When to Use `core/` - -**Use `core/` IF**: -- You have a primary model class users inherit from (like `LightningModule`) -- These classes define "what you ARE building" (semantic role) - -**Don't use `core/` IF**: -- All your base classes are for plugins/extensions (callbacks, metrics) -- You don't have a central model abstraction - -**For ModelForge**: Since there's no primary "ModelBase" class, **we don't need `core/`**. - -### Re-evaluating `inference_engines/` - -**Current situation**: -- `InferenceEngine` is a minimal ABC with only 2 abstract methods (`__init__`, `eval`) -- Only RF3 uses it (no other models yet) -- It's a very thin abstraction that provides minimal value - -**Question**: Do we need this abstraction at all? - -**Arguments for keeping it**: -- Future-proofs for when other models are added -- Provides a consistent interface pattern - -**Arguments for removing it**: -- YAGNI (You Aren't Gonna Need It) - only one implementation exists -- Adds indirection without clear benefit -- The abstraction is so minimal it doesn't enforce meaningful constraints -- Each model's inference is likely to be sufficiently different that a shared interface adds little value - -**Recommendation**: **Remove the `inference_engines/` directory entirely from `src/modelhub/`** -- RF3's inference engine can live solely in `models/rf3/src/rf3/inference_engines/rf3.py` -- Remove the ABC inheritance - just make it a standalone class -- When/if a second model is added, we can evaluate whether a shared abstraction makes sense based on actual commonalities -- This follows YAGNI and keeps the code simpler - ---- - -## Recommended Solution - -### Two-Part Restructuring - -#### Part 1: Fix `src/modelhub/` Structure (Simple, Lightning-style) - -``` -src/ -└── modelhub/ (rename to modelforge?) 
- ├── __init__.py # Export key base classes - │ - ├── callbacks/ - │ ├── __init__.py # Export Callback + all implementations - │ ├── callback.py # BaseCallback (RENAMED from base.py) - │ ├── health_logging.py # HealthLoggingCallback - │ ├── timing_logging.py # TimingLoggingCallback - │ └── train_logging.py # TrainLoggingCallback - │ - ├── metrics/ - │ ├── __init__.py # Export Metric + all implementations - │ ├── metric.py # Metric, MetricManager (RENAMED from base.py) - │ ├── chiral.py # ChiralMetric - │ └── rasa.py # RASAMetric - │ - ├── trainers/ - │ ├── __init__.py - │ └── fabric.py # Fabric trainer wrapper - │ - ├── utils/ # Pure utilities (no base classes) - │ ├── ddp.py - │ ├── logging.py - │ ├── weights.py - │ ├── instantiators.py - │ └── torch.py - │ - └── hydra/ # Hydra utilities - ├── __init__.py - └── resolvers.py -``` - -**Key changes**: -1. Rename `base.py` files to match their content: - - `callbacks/base.py` → `callbacks/callback.py` - - `metrics/base.py` → `metrics/metric.py` -2. **REMOVE** `inference_engines/` entirely (not needed with only one model) -3. Keep implementations WITH base classes (Lightning pattern) -4. No `core/` directory (not needed without primary model abstraction) -5. 
Proper `__init__.py` exports - -#### Part 2: Integrate RF3 into Main Package (For pip install) - -**Decision Point**: Choose ONE of these approaches: - -##### Option A: Move RF3 into src/modelhub/ (Monorepo style) -``` -src/ -└── modelhub/ - ├── callbacks/, metrics/, utils/ (as above) - └── models/ # NEW - └── rf3/ # Moved from releases/rf3/src/rf3/ - ├── __init__.py - ├── model/ - ├── data/ - ├── metrics/ # RF3-specific metrics - ├── callbacks/ # RF3-specific callbacks - ├── inference_engines/ - └── cli.py -``` - -**Pros**: -- Simple pip install: `pip install modelforge[rf3]` -- Single package, single version -- Easy code sharing between models - -**Cons**: -- Changes development workflow (no separate `releases/` for development) -- RF3 code always in source tree even if not installed - -##### Option B: Keep releases/ separate (Development friendly) -``` -modelhub_latent/ -├── src/modelhub/ # Base framework only -├── releases/ -│ └── rf3/ # Stays as is for development -│ └── src/rf3/ -└── pyproject.toml # Use optional dependencies -``` - -**Then in `pyproject.toml`**: -```toml -[project.optional-dependencies] -rf3 = [ - "modelforge-rf3 @ file:///${PROJECT_ROOT}/releases/rf3", - # RF3-specific deps -] -``` - -**Pros**: -- Keeps development workflow unchanged -- Clear separation for releases - -**Cons**: -- More complex build setup -- Requires editable installs for development -- Each release needs own `pyproject.toml` - -##### Option C: Hybrid - Move RF3 but keep releases/ for development (RECOMMENDED) -``` -# For development: work in releases/rf3/ -releases/rf3/ -├── src/rf3/ # Development happens here -├── configs/ # RF3 configs here -└── tests/ # RF3 tests here - -# For installation: symlink or copy during build -src/modelhub/models/rf3/ → symlink to releases/rf3/src/rf3/ - -# Or use build hooks to include releases/ in package -``` - -**Implementation**: Use hatch build hooks to include `releases/rf3/src/rf3/` as `src/modelhub/models/rf3/` - ---- - -## Detailed 
Implementation Plan - -### Phase 1: Rename Package (Breaking Change Decision) - -**Decision needed**: Keep `modelhub` or rename to `modelforge`? - -**Arguments for `modelforge`**: -- Matches repository name -- More descriptive (framework for building models) -- Enables `pip install modelforge[rf3]` - -**Arguments for `modelhub`**: -- Less renaming work -- Existing code already uses it - -**Recommendation**: Rename to `modelforge` for clarity and marketing. - -### Phase 2: Restructure src/modelhub/ (src/modelforge/) - -#### Step 2.1: Rename Base Class Files -```bash -# In src/modelhub/ -git mv callbacks/base.py callbacks/callback.py -git mv metrics/base.py metrics/metric.py -# inference_engines/base.py could stay or rename to inference_engine.py -``` - -#### Step 2.2: Update Imports in Base Files -**In `src/modelhub/callbacks/callback.py`**: -```python -# Update any internal imports if needed -# Mainly just ensure docstrings reference correct module names -``` - -**In `src/modelhub/metrics/metric.py`**: -```python -# Update imports and docstrings -``` - -#### Step 2.3: Create Proper __init__.py Files - -**`src/modelhub/callbacks/__init__.py`**: -```python -"""Callbacks for training customization. - -This module provides both the base callback class and common callback implementations. -""" - -from modelhub.callbacks.callback import BaseCallback -from modelhub.callbacks.health_logging import HealthLoggingCallback -from modelhub.callbacks.timing_logging import TimingLoggingCallback -from modelhub.callbacks.train_logging import TrainLoggingCallback - -__all__ = [ - "BaseCallback", - "HealthLoggingCallback", - "TimingLoggingCallback", - "TrainLoggingCallback", -] -``` - -**`src/modelhub/metrics/__init__.py`**: -```python -"""Metrics for model evaluation. - -This module provides the base metric framework and common metric implementations. 
-""" - -from modelhub.metrics.metric import Metric, MetricManager, instantiate_metric_manager -from modelhub.metrics.chiral import ChiralMetric -from modelhub.metrics.rasa import RASAMetric - -__all__ = [ - "Metric", - "MetricManager", - "instantiate_metric_manager", - "ChiralMetric", - "RASAMetric", -] -``` - -**`src/modelhub/__init__.py`**: -```python -"""ModelForge: Open-source framework for biomolecular modeling. - -This package provides the base framework for building and training biomolecular models. -""" - -# Export key base classes at top level -from modelhub.callbacks.callback import BaseCallback -from modelhub.metrics.metric import Metric, MetricManager - -# Version -from modelhub.version import __version__ - -__all__ = [ - "BaseCallback", - "Metric", - "MetricManager", - "__version__", -] -``` - -### Phase 3: Fix RF3 Imports - -#### Step 3.1: Update All RF3 Import Statements - -**Find all broken imports**: -```bash -cd releases/rf3/ -grep -r "from metrics.base import" src/ -grep -r "from callbacks.base import" src/ -grep -r "from inference_engines.base import" src/ -``` - -**Replace with**: -```python -# OLD (broken) -from metrics.base import Metric -from callbacks.base import BaseCallback - -# NEW (correct) -from modelhub.metrics.metric import Metric, MetricManager -from modelhub.callbacks.callback import BaseCallback -from modelhub.inference_engines.base import InferenceEngine -``` - -**Automated replacement**: -```bash -# In releases/rf3/src/ -find . -name "*.py" -exec sed -i 's/from metrics\.base import/from modelhub.metrics.metric import/g' {} + -find . -name "*.py" -exec sed -i 's/from callbacks\.base import/from modelhub.callbacks.callback import/g' {} + -find . 
-name "*.py" -exec sed -i 's/from inference_engines\.base import/from modelhub.inference_engines.base import/g' {} + -``` - -#### Step 3.2: Update RF3 Utility Imports - -```python -# Also update imports of shared utilities -from modelhub.utils.ddp import RankedLogger -from modelhub.utils.logging import suppress_warnings -from modelhub.utils.weights import load_checkpoint -from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks -``` - -### Phase 4: Fix Build Configuration - -#### Step 4.1: Update pyproject.toml - -**Current**: -```toml -[project] -name = "rf3" -[project.scripts] -rf3 = "modelhub.cli:app" -[tool.hatch.build.targets.wheel] -packages = ["src/modelhub"] -``` - -**New**: -```toml -[project] -name = "modelforge" -description = "Open-source framework for biomolecular structure prediction and design" - -# Minimal dependencies for base framework -dependencies = [ - "torch>=2.2.0,<3", - "lightning>=2.4.0,<2.5", - "hydra-core>=1.3.0,<1.4", - "rootutils>=1.0.7,<1.1", - "environs>=11.0.0,<12", - "wandb>=0.15.10,<1", - "rich>=13.9.4,<14", - "jaxtyping>=0.2.17,<1", - "beartype>=0.18.0,<1", -] - -[project.optional-dependencies] -# RF3 model with its specific dependencies -rf3 = [ - "atomworks==1.0.2", - "einops>=0.8.0,<1", - "einx>=0.1.0,<1", - "opt_einsum>=3.4.0,<4", - "dm-tree>=0.1.6,<1", - "cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'", - "cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'", - "cuequivariance_torch>=0.5.0; sys_platform == 'linux'", -] - -# All models -all = [ - "modelforge[rf3]", -] - -# Development -dev = [ - "ruff==0.8.3", - "pytest>=8.2.0,<9", - # ... 
other dev deps -] - -[project.scripts] -# Main CLI - needs to be created or use RF3 CLI directly -rf3 = "rf3.cli:app" # Direct to RF3 for now - -[tool.hatch.build.targets.wheel] -# Include both modelhub and rf3 (if using monorepo approach) -packages = ["src/modelhub"] -# OR include releases/rf3/src/rf3 via custom build hook - -# For development: force-include releases -[tool.hatch.build.targets.wheel.force-include] -"releases/rf3/src/rf3" = "modelhub/models/rf3" # If using monorepo approach -``` - -#### Step 4.2: Create Main CLI (If Needed) - -**Option 1**: Keep `rf3` CLI pointing directly to RF3 -```toml -[project.scripts] -rf3 = "rf3.cli:app" -``` - -**Option 2**: Create dispatcher CLI (future-proof) - -**`src/modelhub/cli.py`**: -```python -"""Main ModelForge CLI.""" - -import typer - -app = typer.Typer() - -# Import RF3 CLI if available -try: - from rf3.cli import app as rf3_app - # Expose RF3 commands at top level - for command_name, command in rf3_app.registered_commands: - app.command(name=command_name)(command.callback) -except ImportError: - pass # RF3 not installed - -if __name__ == "__main__": - app() -``` - -Then: -```toml -[project.scripts] -modelforge = "modelhub.cli:app" -rf3 = "rf3.cli:app" # Keep for backward compatibility -``` - -### Phase 5: Handle RF3 Package Integration - -**Decision needed**: Choose approach from Phase 2, Part 2. 
- -**Recommended: Option C (Hybrid)** - -Use hatch build hooks to include RF3 from `releases/`: - -**`hatch_build.py`** (in project root): -```python -"""Custom hatch build hook to include RF3 from releases/.""" - -from hatchling.builders.hooks.plugin.interface import BuildHookInterface - -class CustomBuildHook(BuildHookInterface): - def initialize(self, version, build_data): - """Include RF3 from releases/ directory.""" - if self.target_name == "wheel": - # Add releases/rf3/src/rf3 to the wheel as modelhub/models/rf3 - build_data["force_include"] = { - "releases/rf3/src/rf3": "modelhub/models/rf3" - } -``` - -**In `pyproject.toml`**: -```toml -[tool.hatch.build.hooks.custom] -path = "hatch_build.py" -``` - -This way: -- Development happens in `releases/rf3/` -- Installation includes RF3 in the package -- Users can `pip install modelforge[rf3]` - -### Phase 6: Update Tests - -#### Step 6.1: Update Test Imports - -**In `releases/rf3/tests/`**: -```python -# Update conftest.py and test files -from modelhub.metrics.metric import Metric -from modelhub.callbacks.callback import BaseCallback -``` - -**In root `tests/`**: -```python -# These already test shared utilities -from modelhub.utils.weights import load_checkpoint -from modelhub.metrics.metric import MetricManager -``` - -#### Step 6.2: Verify Test Execution - -```bash -# Test shared utilities -pytest tests/ - -# Test RF3 (from releases/rf3/) -cd releases/rf3/ -pytest tests/ -``` - -### Phase 7: Update Documentation - -#### Step 7.1: Update CLAUDE.md - -Update import examples and package structure documentation. 
- -#### Step 7.2: Update README.md - -```markdown -# ModelForge - -## Installation - -```bash -# Base framework only -pip install modelforge - -# With RF3 -pip install modelforge[rf3] - -# Development -pip install -e .[rf3,dev] -``` - -## Usage - -```python -# Import base framework -from modelforge.callbacks import BaseCallback -from modelforge.metrics import Metric - -# Import RF3 -from modelforge.models.rf3 import RF3 -``` -``` - ---- - -## Migration Checklist - -### Pre-Migration -- [ ] Backup current code -- [ ] Create feature branch: `git checkout -b refactor/restructure-package` -- [ ] Document current import patterns for reference - -### Phase 1: Package Rename -- [ ] Decision: Keep `modelhub` or rename to `modelforge`? -- [ ] Update `pyproject.toml` name field -- [ ] Update all imports if renaming - -### Phase 2: Restructure src/ -- [ ] Rename `callbacks/base.py` → `callbacks/callback.py` -- [ ] Rename `metrics/base.py` → `metrics/metric.py` -- [ ] Create/update all `__init__.py` files with proper exports -- [ ] Update internal imports within src/modelhub/ - -### Phase 3: Fix RF3 Imports -- [ ] Find all `from metrics.base` imports in releases/rf3/ -- [ ] Replace with `from modelhub.metrics.metric` -- [ ] Find all `from callbacks.base` imports -- [ ] Replace with `from modelhub.callbacks.callback` -- [ ] Update utility imports to use `modelhub.utils.*` -- [ ] Test that RF3 can import all needed modules - -### Phase 4: Fix Build Config -- [ ] Update `pyproject.toml` project name -- [ ] Split dependencies: core vs rf3-specific -- [ ] Add `[project.optional-dependencies]` for rf3 -- [ ] Fix `[project.scripts]` CLI entry point -- [ ] Update `[tool.hatch.build.targets.wheel]` packages -- [ ] Add build hooks if using hybrid approach - -### Phase 5: RF3 Integration -- [ ] Decide on integration approach (A, B, or C) -- [ ] Implement chosen approach -- [ ] Test `pip install -e .` (base only) -- [ ] Test `pip install -e .[rf3]` (with RF3) - -### Phase 6: Update Tests 
-- [ ] Update test imports -- [ ] Run shared tests: `pytest tests/` -- [ ] Run RF3 tests: `cd releases/rf3 && pytest tests/` -- [ ] Fix any import errors - -### Phase 7: Documentation -- [ ] Update CLAUDE.md -- [ ] Update README.md -- [ ] Update any other documentation -- [ ] Add migration notes for contributors - -### Phase 8: Validation -- [ ] Clean install test: `pip install -e .` -- [ ] RF3 install test: `pip install -e .[rf3]` -- [ ] Test CLI: `rf3 fold inputs=...` -- [ ] Run full test suite -- [ ] Test on clean environment - ---- - -## Testing Strategy - -### Test Scenarios - -#### Scenario 1: Base Install Only -```bash -# Clean environment -python -m venv test-env -source test-env/bin/activate -pip install -e . - -# Should work -python -c "from modelhub.callbacks import BaseCallback; print('OK')" -python -c "from modelhub.metrics import Metric; print('OK')" -python -c "from modelhub.utils.ddp import RankedLogger; print('OK')" - -# Should fail (RF3 not installed) -python -c "from modelhub.models.rf3 import RF3" # ImportError expected -``` - -#### Scenario 2: With RF3 -```bash -# Clean environment -python -m venv test-env -source test-env/bin/activate -pip install -e .[rf3] - -# Should work -python -c "from modelhub.models.rf3 import RF3; print('OK')" -python -c "from rf3.cli import app; print('OK')" - -# Test CLI -rf3 fold inputs='releases/rf3/tests/data/5vht_from_json.json' -``` - -#### Scenario 3: Development Workflow -```bash -# Install in development mode with RF3 -pip install -e .[rf3,dev] - -# Make changes to releases/rf3/src/rf3/model/RF3.py -# Changes should be immediately available - -python -c "from modelhub.models.rf3.model import RF3" # Should see changes - -# Run tests -pytest releases/rf3/tests/ -``` - ---- - -## Risk Assessment - -### High Risk Items -1. **Breaking all existing imports** - Requires updating many files - - Mitigation: Use automated search/replace, test thoroughly - -2. 
**Build configuration complexity** - Hatch build hooks can be tricky - - Mitigation: Start with simple approach, test installation frequently - -3. **Package name change** - Breaking change for any external users - - Mitigation: Coordinate with team, announce breaking change - -### Medium Risk Items -1. **CLI entry point changes** - Users may have scripts using old CLI - - Mitigation: Keep backward-compatible entry points - -2. **Test imports** - Many test files to update - - Mitigation: Automated replacement, run tests incrementally - -### Low Risk Items -1. **Documentation updates** - Time-consuming but low risk -2. **`__init__.py` exports** - Easy to fix if wrong - ---- - -## Rollback Plan - -If migration fails: -1. `git checkout main` - Revert to pre-migration state -2. Keep feature branch for later attempt -3. Document what failed for next iteration - ---- - -## Future Considerations - -### Adding New Models -Once restructured, adding a new model is straightforward: - -``` -releases/ -├── rf3/ # Existing -└── proteinmpnn/ # New model - ├── src/proteinmpnn/ - ├── configs/ - └── tests/ -``` - -Then in `pyproject.toml`: -```toml -[project.optional-dependencies] -proteinmpnn = [ - "proteinmpnn-specific-deps", -] -all = [ - "modelforge[rf3,proteinmpnn]", -] -``` - -### Splitting Into Separate Packages (Future) -If models become too large, can later split: -- `modelforge-core` - Base framework -- `modelforge-rf3` - RF3 model -- `modelforge-proteinmpnn` - ProteinMPNN model - -But this is much more complex and not recommended initially. - ---- - -## Questions to Resolve - -1. **Package name**: Keep `modelhub` or rename to `modelforge`? - - Recommendation: Rename to `modelforge` - -2. **RF3 integration**: Which approach (A, B, or C)? - - Recommendation: Option C (hybrid with build hooks) - -3. **CLI structure**: Single entry point or keep separate? - - Recommendation: Keep `rf3` command separate for now - -4. **Base file naming**: Rename `base.py` to match content? 
- - Recommendation: Yes, rename to `callback.py`, `metric.py` - -5. **Version strategy**: Single version or per-model versions? - - Recommendation: Single version (simpler) - ---- - -## Success Criteria - -- [ ] `pip install modelforge` works -- [ ] `pip install modelforge[rf3]` works -- [ ] All imports are unambiguous and correct -- [ ] All tests pass -- [ ] CLI works: `rf3 fold inputs=...` -- [ ] Structure matches PyTorch Lightning patterns -- [ ] Documentation is updated -- [ ] No more `sys.path` manipulation needed - ---- - -## Timeline Estimate - -- Phase 1 (Rename decision): 1 hour -- Phase 2 (Restructure src/): 2-3 hours -- Phase 3 (Fix RF3 imports): 1-2 hours -- Phase 4 (Build config): 2-3 hours -- Phase 5 (RF3 integration): 2-4 hours -- Phase 6 (Update tests): 1-2 hours -- Phase 7 (Documentation): 1-2 hours -- Phase 8 (Validation): 2-3 hours - -**Total**: 12-20 hours of focused work - ---- - -## Conclusion - -This restructuring will: -1. ✅ Fix all broken imports -2. ✅ Follow industry best practices (Lightning pattern) -3. ✅ Enable proper pip installation -4. ✅ Support selective model installation -5. ✅ Maintain development workflow in `releases/` -6. ✅ Provide clear, unambiguous imports -7. ✅ Scale to future models - -The key is following PyTorch Lightning's pattern: base classes live WITH their implementations, no separate `contrib/` or overly complex directory nesting. diff --git a/models/rf3/CONTAINER.md b/models/rf3/CONTAINER.md new file mode 100644 index 0000000..840d828 --- /dev/null +++ b/models/rf3/CONTAINER.md @@ -0,0 +1,89 @@ +# RF3 Apptainer Containers + +This directory contains two Apptainer definition files for different use cases. + +## Container Options + +### 1. `rf3-standalone.def` - Standalone Container +Contains a complete snapshot of the modelhub repository at build time. Use this for: +- Production deployments +- Reproducible inference with a fixed codebase +- Running on systems where you don't have the repository + +### 2. 
`rf3-dev.def` - Development Container +Contains only Python dependencies (from `requirements.txt`). Use this for: +- Active development and testing +- Working with your local modelhub code +- Debugging and modifying RF3 code + +**Note**: Generate `requirements.txt` first using `uv pip compile pyproject.toml -o requirements.txt` before building this container. + +## Prerequisites + +- Apptainer/Singularity installed +- NVIDIA GPU with CUDA 12.1+ support +- Sufficient disk space (~10GB per container) + +## Building Containers + +Build from the `models/rf3/` directory: + +```bash +cd models/rf3/ + +# Build standalone container (includes modelhub snapshot) +apptainer build rf3-standalone.sif rf3-standalone.def + +# Build development container (dependencies only) +apptainer build rf3-dev.sif rf3-dev.def + +# Or build with sandbox for debugging +apptainer build --sandbox rf3_sandbox/ rf3-standalone.def +``` + +## Using the Standalone Container + +The standalone container has the full repository baked in. + +### Basic Inference + +```bash +# Run inference on a single input +apptainer exec --nv rf3-standalone.sif rf3 fold inputs='input.json' + +# Process CIF/PDB files +apptainer exec --nv rf3-standalone.sif rf3 fold inputs='structure.cif' + +# Batch processing +apptainer exec --nv rf3-standalone.sif rf3 fold inputs='[file1.cif,file2.json,file3.pdb]' +``` + +### With Custom Weights + +```bash +# Mount weights directory and specify checkpoint +apptainer exec --nv \ + --bind /path/to/weights:/weights \ + rf3-standalone.sif \ + rf3 fold inputs='input.json' ckpt_path='/weights/rf3_latest.pt' +``` + +## Using the Development Container + +The development container requires mounting your local modelhub repository. 
+ +### Basic Usage + +```bash +# Run with local modelhub repository +apptainer exec --nv \ + --bind /path/to/modelhub:/opt/modelhub \ + rf3-dev.sif \ + rf3 fold inputs='input.json' + +# Example with actual path +apptainer exec --nv \ + --bind $PWD/../..:/opt/modelhub \ + rf3-dev.sif \ + rf3 fold inputs='input.json' +``` diff --git a/models/rf3/README.md b/models/rf3/README.md index 1a8058f..5d7f3d6 100644 --- a/models/rf3/README.md +++ b/models/rf3/README.md @@ -1,7 +1,7 @@ # Inference with RosettaFold3(RF3)
- Protein-DNA complex prediction + Protein-DNA complex prediction
> [!IMPORTANT] @@ -21,9 +21,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \ && uv python install 3.12 \ && uv venv --python 3.12 \ && source .venv/bin/activate \ - && uv pip install -e . + && uv pip install -e ./models/rf3 ``` +> [!NOTE] +> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency. For development on both packages, use: `uv pip install -e . -e ./models/rf3` + ### B. Download model weights for RF3 ```bash wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt @@ -67,7 +70,7 @@ For this example, the pTM in the `metrics.csv` should be `>0.8` (even without an RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommended. We do not at the moment support pre-paired MSAs (we will pair on-the-fly) or on-the-fly MSA computation, but both are on the roadmap. Please raise an issue if these limitations are critical for your project and we can prioritize accordingly. -📝 **Example JSON configuration** (full example found at `docs/rf3/examples/3en2_from_json_with_msa.json`): +📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/3en2_from_json_with_msa.json`): ```json { @@ -85,7 +88,7 @@ RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommend 🚀 **Run the example:** ```bash -rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json' +rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json' ``` --- @@ -93,12 +96,12 @@ rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json' If performing inference from a prepared `.cif` file, MSAs can also be specified directly as a category within the raw CIF data. We will automatically extract the correct MSA paths during parsing. 
-📝 **Example CIF header** (full example found at `docs/rf3/examples/3en2_from_file.cif`): +📝 **Example CIF header** (full example found at `../../docs/releases/rf3/examples/3en2_from_file.cif`): ```cif data_3EN2 # -_msa_paths_by_chain_id.A docs/rf3/examples/msas/b3a35202064.a3m.gz -_msa_paths_by_chain_id.B docs/rf3/examples/msas/b3a35202064.a3m.gz +_msa_paths_by_chain_id.A ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz +_msa_paths_by_chain_id.B ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz # ``` @@ -112,7 +115,7 @@ rf3 fold inputs='docs/rf3/3en2_from_file.cif > Without an MSA and using default settings, the above examples will trigger "early stopping." This means that if the model determines early on that a correct prediction is unlikely, it will stop computation and only output a `metrics.csv` and `.score` file to save compute resources. You can adjust this behavior using the `early_stopping_plddt_threshold` argument (see below). In our group, we find this argument can save wasted compute on erroneous inputs. > [!TIP] -> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs. +> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs. > [!TIP] > For non-canonical amino acids, most MSA generation algorithms substitute `X` (unknown residue)! Ensure your MSAs adhere to this convention. 
@@ -138,7 +141,7 @@ We will automatically distribute predictions across GPU's if running in a multi- ### 1️⃣ **Single JSON with Multiple Examples** -📝 **Example JSON configuration** (full example found at `docs/rf3/examples/multiple_example_from_json.json`) +📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/multiple_example_from_json.json`) ```json [ @@ -172,7 +175,7 @@ We will automatically distribute predictions across GPU's if running in a multi- 🚀 **Run the example:** ```bash -rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json' +rf3 fold inputs='../../docs/releases/rf3/examples/multiple_examples_from_json.json' ``` --- @@ -182,7 +185,7 @@ rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json' You can specify multiple files/directories using Hydra's list syntax: ```bash -rf3 fold inputs='[docs/rf3/examples/701r_from_file.cif, docs/rf3/examples/701r_from_json.json]' +rf3 fold inputs='[../../docs/releases/rf3/examples/7o1r_from_file.cif, ../../docs/releases/rf3/examples/7o1r_from_json.json]' ``` --- @@ -211,7 +214,7 @@ Complex assemblies containing arbitrary biomolecules can be easily folded if pre 🚀 **Run an example from a prepared CIF file:** ```bash -rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif' +rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif' ``` Such files (including all bonds, covalent modifications, non-canonical amino acids, etc.) can be created either (a) directly from ProteinMPNN/LigandMPNN or other software that generates structural files; or, (b) assembled with [AtomWorks](https://github.com/RosettaCommons/atomworks) or another CIF-processing library. @@ -221,7 +224,7 @@ For convenience, we also support a `json` API analogous to that implemented by A > [!TIP] > **Performance Tip**: For small molecules, a general rule-of-thumb is that performance is best when using `CCD` codes directly, followed by `cif`/`sdf` files, and finally SMILES. 
-📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `docs/rf3/examples/7o1r_from_json.json`): +📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`): ```json [ { @@ -229,7 +232,7 @@ For convenience, we also support a `json` API analogous to that implemented by A "components": [ { "seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH", - "msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz", + "msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz", "chain_id": "A" }, { @@ -242,7 +245,7 @@ For convenience, we also support a `json` API analogous to that implemented by A { // We provide the heme (HEME) via SDF from the CCD; we could have also used a CIF file // We will automatically name the atoms (SDF files do not specify atom names) - "path": "docs/rf3/examples/ligands/HEM.sdf" + "path": "../../docs/releases/rf3/examples/ligands/HEM.sdf" }, { // We provide the imidazole ring (IMD) via SMILES @@ -257,7 +260,7 @@ For convenience, we also support a `json` API analogous to that implemented by A 🚀 **Run the example:** ```bash -rf3 fold inputs='docs/rf3/examples/7o1r_from_json.json' +rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_json.json' ``` **Supported input options:** @@ -278,11 +281,11 @@ As described in the example [Folding with Arbitrary Biomolecules](#folding-with- For example, folding `7o1r`, which contains two N-glycosylations: ```bash -rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif' +rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif' ```

- 7o1r Covalent Modification + 7o1r Covalent Modification

Figure: `7o1r` structure showing N-glycosylation covalent modification prediction with RF3 and ground truth crystal structure. @@ -292,7 +295,7 @@ Such `.cif` files complete with appropriate bonds can be composed with AtomWorks If you would prefer to use the JSON API, bonds can be explicitly given using PyMol-like strings of the form `chain_id/res_name/res_id/atom_name`. You will need to know the specific chain ID, residue name, residue ID, and atom name between the relevant pairs of atoms to unambiguously specify the bond. -📝 **Example JSON configuration with covalent modifcations** (full example found at `docs/rf3/examples/7o1r_from_json.json`): +📝 **Example JSON configuration with covalent modifications** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`): ```json [ { @@ -300,7 +303,7 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM "components": [ { "seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH", - "msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz", + "msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz", "chain_id": "A" }, { @@ -308,10 +311,10 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM }, { // We provide one sugar via a CIF file, with complete control over bonds and atom names (as we use the atom names from the CIF file) - "path": "docs/rf3/examples/ligands/NAG.cif" + "path": "../../docs/releases/rf3/examples/ligands/NAG.cif" }, { - "path": "docs/rf3/examples/ligands/HEM.sdf" + "path": "../../docs/releases/rf3/examples/ligands/HEM.sdf" }, { "smiles": "[nH]1cc[nH+]c1" @@ -351,7 +354,7 @@ from atomworks.io.utils.visualize import view import numpy as np # Load an SDF file into an AtomArray -sdf_path = 
"docs/rf3/examples/ligands/NAG.sdf" +sdf_path = "../../docs/releases/rf3/examples/ligands/NAG.sdf" # Load into an AtomArray # Since SDF file files do not have atom names, we automatically generate them (e.g., C1, C2, C3, etc.) @@ -448,17 +451,17 @@ RF3 uses AtomWorks' flexible `AtomSelectionStack` query syntax for specifying st It is often helpful to template one or multiple polymer chains while allowing the other chain(s) to fold unconstrained. We demonstrate with an nanobody-antigen use case below how to apply templates. -📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `docs/rf3/examples/7xli_template_antigen_and_framework.json`): +📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json`): ```json [ { "name": "7xli_template_antigen", "components": [ { - "path": "docs/rf3/examples/templates/7xli_chain_A.cif" + "path": "../../docs/releases/rf3/examples/templates/7xli_chain_A.cif" }, { - "path": "docs/rf3/examples/templates/7xli_chain_B.cif" + "path": "../../docs/releases/rf3/examples/templates/7xli_chain_B.cif" } ], "template_selection": ["A", "B/*/1-42", "B/*/49-63", "B/*/71-102", "B/*/108-125"] @@ -469,7 +472,7 @@ It is often helpful to template one or multiple polymer chains while allowing th 🚀 **Run the example:** ```bash -rf3 fold inputs='docs/rf3/examples/7xli_template_antigen_and_framework.json' +rf3 fold inputs='../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json' ``` You may also specify templating directly via the CLI using `template_selection="[A, B/*/1-42, ...]"`. @@ -483,14 +486,14 @@ We find that enforcing a particular small molecule conformation has various appl For the moment, the ground truth conformer track is only effective if we want to template the *entire* small molecule. 
Partial templating of small molecules is still possible via the `template_selection` approach. We encourage exploration of both templating techniques to find what combination(s) are most effective for a given problem. Below we provide both, which represents the strongest possible conditioning. -📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `docs/rf3/examples/1eiz_template_ligand_and_protein.json`): +📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `../../docs/releases/rf3/examples/1eiz_template_ligand_and_protein.json`): ```json [ { "name": "9dfn_template_ligand_and_protein", "components": [ { - "path": "docs/rf3/examples/9dfn.cif" + "path": "../../docs/releases/rf3/examples/9dfn.cif" } ], "template_selection": ["A", "C", "D"], @@ -506,7 +509,7 @@ For the moment, the ground truth conformer track is only effective if we want to 🚀 **Run the example:** ```bash -rf3 fold inputs='docs/rf3/examples/8cdz_templating_ligand.json' +rf3 fold inputs='../../docs/releases/rf3/examples/8cdz_templating_ligand.json' ``` You may also specify the ground truth conformer selection directly via the CLI, e.g., using `ground_truth_conformer_selection="[E]"` diff --git a/models/rf3/configs/callbacks/dump_validation_structures.yaml b/models/rf3/configs/callbacks/dump_validation_structures.yaml index b72949c..66ac845 100644 --- a/models/rf3/configs/callbacks/dump_validation_structures.yaml +++ b/models/rf3/configs/callbacks/dump_validation_structures.yaml @@ -1,5 +1,5 @@ dump_validation_structures_callback: - _target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback + _target_: rf3.callbacks.dump_validation_structures.DumpValidationStructuresCallback save_dir: ${paths.output_dir}/val_structures dump_predictions: True one_model_per_file: False diff --git a/models/rf3/configs/callbacks/metrics_logging.yaml 
b/models/rf3/configs/callbacks/metrics_logging.yaml index 34e7e15..7f30cf8 100644 --- a/models/rf3/configs/callbacks/metrics_logging.yaml +++ b/models/rf3/configs/callbacks/metrics_logging.yaml @@ -1,10 +1,10 @@ store_validation_metrics_in_df_callback: - _target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback + _target_: rf3.callbacks.metrics_logging.StoreValidationMetricsInDFCallback save_dir: ${paths.output_dir}/val_metrics metrics_to_save: "all" log_af3_validation_metrics_callback: - _target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback + _target_: rf3.callbacks.metrics_logging.LogAF3ValidationMetricsCallback # Only logs if present in the metric output dictionary # Must be subset of metrics_to_save metrics_to_log: "all" \ No newline at end of file diff --git a/models/rf3/configs/datasets/pdb_and_distillation.yaml b/models/rf3/configs/datasets/pdb_and_distillation.yaml index 5665766..24e3e5b 100644 --- a/models/rf3/configs/datasets/pdb_and_distillation.yaml +++ b/models/rf3/configs/datasets/pdb_and_distillation.yaml @@ -17,7 +17,7 @@ defaults: - _self_ # Dataloading pipeline to use -pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline +pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline # Dataset weighting train: diff --git a/models/rf3/configs/datasets/pdb_only.yaml b/models/rf3/configs/datasets/pdb_only.yaml index c38cc9d..310f387 100644 --- a/models/rf3/configs/datasets/pdb_only.yaml +++ b/models/rf3/configs/datasets/pdb_only.yaml @@ -9,7 +9,7 @@ defaults: - _self_ # Dataloading pipeline to use -pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline +pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline # Dataset weighting train: diff --git a/models/rf3/configs/datasets/train/disorder_distillation.yaml b/models/rf3/configs/datasets/train/disorder_distillation.yaml index 8174b90..70ec14c 100644 --- a/models/rf3/configs/datasets/train/disorder_distillation.yaml 
+++ b/models/rf3/configs/datasets/train/disorder_distillation.yaml @@ -2,7 +2,7 @@ disorder_distillation: dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: null # cif parser arguments @@ -13,19 +13,18 @@ disorder_distillation: # metadata parser dataset_parser: - _target_: datahub.datasets.parsers.GenericDFParser + _target_: atomworks.ml.datasets.parsers.GenericDFParser pn_unit_iid_colnames: null # metadata dataset dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset name: pdb_diso_distillation id_column: example_id data: ${paths.data.disorder_distill_parquet_dir}/disorderDistillation.csv columns_to_load: - example_id - path - return_key: null transform: _target_: ${datasets.pipeline_target} is_inference: False diff --git a/models/rf3/configs/datasets/train/domain_distillation.yaml b/models/rf3/configs/datasets/train/domain_distillation.yaml index 2a31dc5..9929d9f 100644 --- a/models/rf3/configs/datasets/train/domain_distillation.yaml +++ b/models/rf3/configs/datasets/train/domain_distillation.yaml @@ -2,7 +2,7 @@ multidomain_distillation: dataset: - _target_: modelhub.data.paired_msa.MultiInputDatasetWrapper + _target_: rf3.data.paired_msa.MultiInputDatasetWrapper save_failed_examples_to_dir: null # cif parser @@ -14,11 +14,11 @@ multidomain_distillation: # metadata parser dataset_parser: - _target_: modelhub.data.paired_msa.MultidomainDFParser + _target_: rf3.data.paired_msa.MultidomainDFParser # metadata dataset dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset name: multidomain_distillation id_column: example_id data: /projects/ml/datahub/dfs/domain_domain/domain_domain_dataset.DIGS.parquet @@ -26,7 +26,6 @@ multidomain_distillation: - example_id - pdb_path - msa_path - return_key: null transform: _target_: 
${datasets.pipeline_target} is_inference: False diff --git a/models/rf3/configs/datasets/train/monomer_distillation.yaml b/models/rf3/configs/datasets/train/monomer_distillation.yaml index afbd176..4aabf74 100644 --- a/models/rf3/configs/datasets/train/monomer_distillation.yaml +++ b/models/rf3/configs/datasets/train/monomer_distillation.yaml @@ -2,7 +2,7 @@ monomer_distillation: dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: ${paths.data.failed_examples_dir} # cif parser arguments @@ -13,19 +13,18 @@ monomer_distillation: # metadata parser dataset_parser: - _target_: datahub.datasets.parsers.GenericDFParser + _target_: atomworks.ml.datasets.parsers.GenericDFParser pn_unit_iid_colnames: null # metadata dataset dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset name: af2fb_distillation id_column: example_id data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet columns_to_load: - example_id - path - return_key: null transform: _target_: ${datasets.pipeline_target} is_inference: False diff --git a/models/rf3/configs/datasets/train/na_complex_distillation.yaml b/models/rf3/configs/datasets/train/na_complex_distillation.yaml index 59089d3..c5f2811 100644 --- a/models/rf3/configs/datasets/train/na_complex_distillation.yaml +++ b/models/rf3/configs/datasets/train/na_complex_distillation.yaml @@ -2,7 +2,7 @@ na_complex_distillation: dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: null # cif parser @@ -14,19 +14,18 @@ na_complex_distillation: # metadata parser dataset_parser: - _target_: datahub.datasets.parsers.GenericDFParser + _target_: atomworks.ml.datasets.parsers.GenericDFParser pn_unit_iid_colnames: null #[] # metadata dataset 
dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset name: tf_distillation id_column: example_id data: ${paths.data.na_complex_distillation_parquet_dir}/transcriptionFactor_distillation_rf3.newDL.csv columns_to_load: - example_id - path - return_key: null transform: _target_: ${datasets.pipeline_target} is_inference: False diff --git a/models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml b/models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml index 9ab18b6..d90e5e4 100644 --- a/models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml +++ b/models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml @@ -1,5 +1,5 @@ weights: - _target_: datahub.samplers.calculate_weights_for_pdb_dataset_df + _target_: atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df # We do not include beta here, since it is different for interfaces and chains alphas: a_prot: 3.0 # 3 for AF-3 diff --git a/models/rf3/configs/datasets/train/pdb/base.yaml b/models/rf3/configs/datasets/train/pdb/base.yaml index 0f2a6ca..2b362e4 100644 --- a/models/rf3/configs/datasets/train/pdb/base.yaml +++ b/models/rf3/configs/datasets/train/pdb/base.yaml @@ -1,16 +1,14 @@ dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: ${paths.data.failed_examples_dir} cif_parser_args: cache_dir: null load_from_cache: false save_to_cache: false dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset # we will use the example_id as the unique column id_column: example_id - # return all keys (do not subset) - return_key: null transform: # common Transform pipeline components for all PDB datasets _target_: ${datasets.pipeline_target} diff --git a/models/rf3/configs/datasets/train/pdb/plinder.yaml b/models/rf3/configs/datasets/train/pdb/plinder.yaml 
index 830e093..823e3db 100644 --- a/models/rf3/configs/datasets/train/pdb/plinder.yaml +++ b/models/rf3/configs/datasets/train/pdb/plinder.yaml @@ -5,7 +5,7 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.InterfacesDFParser + _target_: atomworks.ml.datasets.parsers.InterfacesDFParser base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb dataset: name: plinder @@ -50,5 +50,5 @@ dataset: crop_spatial_probability: 1.0 weights: - _target_: datahub.samplers.calculate_weights_by_inverse_cluster_size + _target_: atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size cluster_column: pli_qcov__50__weak__component # Need to ablate \ No newline at end of file diff --git a/models/rf3/configs/datasets/train/pdb/train_interface.yaml b/models/rf3/configs/datasets/train/pdb/train_interface.yaml index bc7c00d..abe960a 100644 --- a/models/rf3/configs/datasets/train/pdb/train_interface.yaml +++ b/models/rf3/configs/datasets/train/pdb/train_interface.yaml @@ -4,7 +4,7 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.InterfacesDFParser + _target_: atomworks.ml.datasets.parsers.InterfacesDFParser base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb dataset: name: interface diff --git a/models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml b/models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml index 9c2c578..f30ebb2 100644 --- a/models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml +++ b/models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml @@ -4,7 +4,7 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.PNUnitsDFParser + _target_: atomworks.ml.datasets.parsers.PNUnitsDFParser base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb dataset: name: pn_unit diff --git a/models/rf3/configs/datasets/train/rna_monomer_distillation.yaml b/models/rf3/configs/datasets/train/rna_monomer_distillation.yaml index c11ac06..468b4b9 100644 --- 
a/models/rf3/configs/datasets/train/rna_monomer_distillation.yaml +++ b/models/rf3/configs/datasets/train/rna_monomer_distillation.yaml @@ -2,7 +2,7 @@ rna_monomer_distillation: dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: ${paths.data.failed_examples_dir} # cif parser arguments @@ -13,12 +13,12 @@ rna_monomer_distillation: # metadata parser dataset_parser: - _target_: datahub.datasets.parsers.GenericDFParser + _target_: atomworks.ml.datasets.parsers.GenericDFParser pn_unit_iid_colnames: null # metadata dataset dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset name: rna_monomer_distillation id_column: example_id data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet @@ -31,7 +31,6 @@ rna_monomer_distillation: - overall_pde - overall_pae - return_key: null transform: _target_: ${datasets.pipeline_target} is_inference: False diff --git a/models/rf3/configs/datasets/val/af3_ab_set.yaml b/models/rf3/configs/datasets/val/af3_ab_set.yaml index 38a31cc..fd9045b 100644 --- a/models/rf3/configs/datasets/val/af3_ab_set.yaml +++ b/models/rf3/configs/datasets/val/af3_ab_set.yaml @@ -3,8 +3,9 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.ValidationDFParserLikeAF3 + _target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3 base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset + name: af3_validation data: /net/scratch/rib7/rf3_ab_splits/entry_level_val_df.parquet diff --git a/models/rf3/configs/datasets/val/af3_validation.yaml b/models/rf3/configs/datasets/val/af3_validation.yaml index 176f12f..b748bca 100644 --- a/models/rf3/configs/datasets/val/af3_validation.yaml +++ 
b/models/rf3/configs/datasets/val/af3_validation.yaml @@ -3,8 +3,9 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.ValidationDFParserLikeAF3 + _target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3 base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset + name: af3_validation data: ${paths.data.pdb_data_dir}/entry_level_val_df.parquet diff --git a/models/rf3/configs/datasets/val/base.yaml b/models/rf3/configs/datasets/val/base.yaml index bda42d3..8c15adf 100644 --- a/models/rf3/configs/datasets/val/base.yaml +++ b/models/rf3/configs/datasets/val/base.yaml @@ -1,16 +1,15 @@ dataset: - _target_: datahub.datasets.datasets.StructuralDatasetWrapper + _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper save_failed_examples_to_dir: ${paths.data.failed_examples_dir} cif_parser_args: cache_dir: null load_from_cache: False save_to_cache: False dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset # we will use the example_id as the unique column id_column: example_id # return all keys (do not subset) - return_key: null transform: # common Transform pipeline components for all PDB datasets _target_: ${datasets.pipeline_target} diff --git a/models/rf3/configs/datasets/val/runs_and_poses.yaml b/models/rf3/configs/datasets/val/runs_and_poses.yaml index 01923b6..f3a5d2b 100644 --- a/models/rf3/configs/datasets/val/runs_and_poses.yaml +++ b/models/rf3/configs/datasets/val/runs_and_poses.yaml @@ -3,9 +3,10 @@ defaults: dataset: dataset_parser: - _target_: datahub.datasets.parsers.ValidationDFParserLikeAF3 + _target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3 dataset: - _target_: datahub.datasets.datasets.PandasDataset + _target_: atomworks.ml.datasets.datasets.PandasDataset + name: af3_validation data: 
/projects/ml/datahub/dfs/af3_splits/2024_12_16/runs_n_poses_entry_level_df.parquet filters: - "n_tokens_total < 1000" # Subset to reasonably-sized examples for efficiency diff --git a/models/rf3/configs/experiment/pretrained/rf3.yaml b/models/rf3/configs/experiment/pretrained/rf3.yaml index 7cf8fc7..3855414 100644 --- a/models/rf3/configs/experiment/pretrained/rf3.yaml +++ b/models/rf3/configs/experiment/pretrained/rf3.yaml @@ -5,7 +5,7 @@ name: rf3 defaults: - override /datasets: pdb_and_distillation - override /model: rf3 - - override /trainer: af3 + - override /trainer: rf3 ckpt_config: _target_: modelhub.utils.weights.CheckpointConfig diff --git a/models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml b/models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml index 7957e4d..ee47ba4 100644 --- a/models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml +++ b/models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml @@ -6,7 +6,7 @@ name: rf3-with-confidence defaults: - rf3 - override /model: rf3_with_confidence - - override /trainer: af3_with_confidence + - override /trainer: rf3_with_confidence - _self_ datasets: diff --git a/models/rf3/configs/experiment/quick-rf3-with-confidence.yaml b/models/rf3/configs/experiment/quick-rf3-with-confidence.yaml index bf20aa5..6722ec2 100644 --- a/models/rf3/configs/experiment/quick-rf3-with-confidence.yaml +++ b/models/rf3/configs/experiment/quick-rf3-with-confidence.yaml @@ -8,7 +8,7 @@ name: quick-rf3-with-confidence defaults: - quick-rf3 - override /model: rf3_with_confidence - - override /trainer: af3_with_confidence + - override /trainer: rf3_with_confidence - _self_ datasets: diff --git a/models/rf3/configs/experiment/quick-rf3.yaml b/models/rf3/configs/experiment/quick-rf3.yaml index bad035b..530943c 100644 --- a/models/rf3/configs/experiment/quick-rf3.yaml +++ b/models/rf3/configs/experiment/quick-rf3.yaml @@ -2,7 +2,7 @@ # Experiment that loads a small dataset for quick testing 
-name: quick-af3 +name: quick-rf3 # For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/ defaults: diff --git a/models/rf3/configs/inference_engine/rf3.yaml b/models/rf3/configs/inference_engine/rf3.yaml index fc3d2da..4fd8edb 100644 --- a/models/rf3/configs/inference_engine/rf3.yaml +++ b/models/rf3/configs/inference_engine/rf3.yaml @@ -4,9 +4,9 @@ defaults: - base - _self_ -_target_: modelhub.inference_engines.rf3.RF3InferenceEngine +_target_: rf3.inference_engines.rf3.RF3InferenceEngine -ckpt_path: /software/containers/versions/modelhub_inference/ckpts/rf3_latest.ckpt +ckpt_path: /home/ncorley/rf3-w-conf-newdatecut-ep688-refactored-fixed.ckpt # Transform arguments n_recycles: 10 @@ -36,8 +36,8 @@ annotate_b_factor_with_plddt: true metrics_cfg: # (For confidence outputs) ptm: - _target_: modelhub.metrics.predicted_error.ComputePTM + _target_: rf3.metrics.predicted_error.ComputePTM iptm: - _target_: modelhub.metrics.predicted_error.ComputeIPTM + _target_: rf3.metrics.predicted_error.ComputeIPTM count_clashing_chains: - _target_: modelhub.metrics.clashing_chains.CountClashingChains \ No newline at end of file + _target_: rf3.metrics.clashing_chains.CountClashingChains \ No newline at end of file diff --git a/models/rf3/configs/model/components/rf3_net.yaml b/models/rf3/configs/model/components/rf3_net.yaml index f90bd72..f629ef2 100644 --- a/models/rf3/configs/model/components/rf3_net.yaml +++ b/models/rf3/configs/model/components/rf3_net.yaml @@ -1,5 +1,5 @@ # Model architecture -_target_: modelhub.model.RF3.RF3 +_target_: rf3.model.RF3.RF3 # +---------- Channel dimensions ----------+ c_s: 384 diff --git a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml index 00de2d5..063e6f5 100644 --- a/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml +++ 
b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml @@ -2,7 +2,7 @@ defaults: - rf3_net # Model architecture -_target_: modelhub.model.RF3.RF3WithConfidence +_target_: rf3.model.RF3.RF3WithConfidence # +---------- Mini rollout sampler ----------+ # From the AF-3 main text: diff --git a/models/rf3/configs/model/rf3_with_confidence.yaml b/models/rf3/configs/model/rf3_with_confidence.yaml index 2562a31..0aefbb5 100644 --- a/models/rf3/configs/model/rf3_with_confidence.yaml +++ b/models/rf3/configs/model/rf3_with_confidence.yaml @@ -4,4 +4,4 @@ defaults: - _self_ net: - _target_: modelhub.model.RF3.RF3WithConfidence \ No newline at end of file + _target_: rf3.model.RF3.RF3WithConfidence \ No newline at end of file diff --git a/models/rf3/configs/model/schedulers/af3.yaml b/models/rf3/configs/model/schedulers/af3.yaml index 7c2a65d..868a339 100644 --- a/models/rf3/configs/model/schedulers/af3.yaml +++ b/models/rf3/configs/model/schedulers/af3.yaml @@ -1,5 +1,5 @@ # Learning rate scheduler -_target_: modelhub.training.schedulers.AF3Scheduler +_target_: rf3.training.schedulers.AF3Scheduler base_lr: 1.8e-3 warmup_steps: 1000 decay_factor: 0.95 diff --git a/models/rf3/configs/paths/data/default.yaml b/models/rf3/configs/paths/data/default.yaml index 383c391..401abd8 100644 --- a/models/rf3/configs/paths/data/default.yaml +++ b/models/rf3/configs/paths/data/default.yaml @@ -3,30 +3,37 @@ ######################## # path to directory with training splits -pdb_data_dir: ??? +pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2025_07_13 # fb monomer distillation dataset -monomer_distillation_data_dir: ??? -monomer_distillation_parquet_dir: ??? 
+monomer_distillation_data_dir: /squash/af2_distillation_facebook +monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook + +mgnify_distillation_data_dir: /squash/mgnify_distill_rf3/ +mgnify_distillation_parquet_dir: /home/dimaio/MGnify/ # na complex distill set -na_complex_distillation_data_dir: ??? -na_complex_distillation_parquet_dir: ??? +na_complex_distillation_data_dir: /projects/ml/prot_dna/rf3_newDL +na_complex_distillation_parquet_dir: /projects/ml/prot_dna # disorder distill set -disorder_distill_parquet_dir: ??? +disorder_distill_parquet_dir: /projects/ml/disorder_distill ######################## # MSAs ######################## # path(s) to search for protein MSAs (for PDB datasets) +# e.g., {"dir": "/path/to/msas", "extension": ".a3m.gz", "directory_depth": 2} protein_msa_dirs: - - {"dir": ???, "extension": ".a3m.gz", "directory_depth": 2} + - {"dir": "/projects/msa/hhblits", "extension": ".a3m.gz", "directory_depth": 2} + - {"dir": "/projects/msa/mmseqs_gpu", "extension": ".a3m", "directory_depth": 2} + - {"dir": "/projects/msa/lab", "extension": ".a3m", "directory_depth": 2} # path(s) to search for RNA MSAs +# e.g., {"dir": "/path/to/msas", "extension": ".afa", "directory_depth": 0} rna_msa_dirs: - - {"dir": ???, "extension": ".afa", "directory_depth": 0} + - {"dir": "/projects/msa/rna", "extension": ".afa", "directory_depth": 0} ######################## # Misc diff --git a/models/rf3/configs/trainer/loss/losses/confidence_loss.yaml b/models/rf3/configs/trainer/loss/losses/confidence_loss.yaml index afea09d..3f4baf7 100644 --- a/models/rf3/configs/trainer/loss/losses/confidence_loss.yaml +++ b/models/rf3/configs/trainer/loss/losses/confidence_loss.yaml @@ -1,4 +1,4 @@ -_target_: modelhub.loss.af3_confidence_loss.ConfidenceLoss +_target_: rf3.loss.af3_confidence_loss.ConfidenceLoss weight: 1.0 plddt: diff --git a/models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml 
b/models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml index 8d19e02..1bb6a11 100644 --- a/models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml +++ b/models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml @@ -1,4 +1,4 @@ -_target_: modelhub.loss.af3_losses.DiffusionLoss +_target_: rf3.loss.af3_losses.DiffusionLoss weight: 4.0 sigma_data: ${model.net.diffusion_module.sigma_data} alpha_dna: 5 diff --git a/models/rf3/configs/trainer/loss/losses/distogram_loss.yaml b/models/rf3/configs/trainer/loss/losses/distogram_loss.yaml index 27b7782..ba45eb5 100644 --- a/models/rf3/configs/trainer/loss/losses/distogram_loss.yaml +++ b/models/rf3/configs/trainer/loss/losses/distogram_loss.yaml @@ -1,2 +1,2 @@ -_target_: modelhub.loss.af3_losses.DistogramLoss +_target_: rf3.loss.af3_losses.DistogramLoss weight: 3e-2 \ No newline at end of file diff --git a/models/rf3/configs/trainer/metrics/structure_prediction.yaml b/models/rf3/configs/trainer/metrics/structure_prediction.yaml index e4a62dc..48630af 100644 --- a/models/rf3/configs/trainer/metrics/structure_prediction.yaml +++ b/models/rf3/configs/trainer/metrics/structure_prediction.yaml @@ -1,14 +1,14 @@ by_type_lddt: - _target_: modelhub.metrics.lddt.ByTypeLDDT + _target_: rf3.metrics.lddt.ByTypeLDDT all_atom_lddt: - _target_: modelhub.metrics.lddt.AllAtomLDDT + _target_: rf3.metrics.lddt.AllAtomLDDT distogram: - _target_: modelhub.metrics.distogram.DistogramLoss + _target_: rf3.metrics.distogram.DistogramLoss distogram_comparisons: - _target_: modelhub.metrics.distogram.DistogramComparisons + _target_: rf3.metrics.distogram.DistogramComparisons distogram_entropy: - _target_: modelhub.metrics.distogram.DistogramEntropy + _target_: rf3.metrics.distogram.DistogramEntropy chiral_loss: - _target_: modelhub.metrics.chiral.ChiralLoss + _target_: rf3.metrics.chiral.ChiralLoss unresolved_rasa: - _target_: modelhub.metrics.rasa.UnresolvedRegionRASA + _target_: rf3.metrics.rasa.UnresolvedRegionRASA diff --git 
a/models/rf3/configs/trainer/rf3.yaml b/models/rf3/configs/trainer/rf3.yaml index 1ce40a2..270a6b3 100644 --- a/models/rf3/configs/trainer/rf3.yaml +++ b/models/rf3/configs/trainer/rf3.yaml @@ -3,7 +3,7 @@ defaults: - loss: structure_prediction - metrics: structure_prediction -_target_: modelhub.trainers.rf3.RF3Trainer +_target_: rf3.trainers.rf3.RF3Trainer validate_every_n_epochs: 1 max_epochs: 10_000 n_examples_per_epoch: 24000 diff --git a/models/rf3/configs/trainer/rf3_with_confidence.yaml b/models/rf3/configs/trainer/rf3_with_confidence.yaml index 1e80043..4782719 100644 --- a/models/rf3/configs/trainer/rf3_with_confidence.yaml +++ b/models/rf3/configs/trainer/rf3_with_confidence.yaml @@ -2,12 +2,12 @@ defaults: - af3 - override loss: structure_prediction_with_confidence -_target_: modelhub.trainers.rf3.RF3TrainerWithConfidence +_target_: rf3.trainers.rf3.RF3TrainerWithConfidence metrics: ptm: - _target_: modelhub.metrics.predicted_error.ComputePTM + _target_: rf3.metrics.predicted_error.ComputePTM iptm: - _target_: modelhub.metrics.predicted_error.ComputeIPTM + _target_: rf3.metrics.predicted_error.ComputeIPTM count_clashing_chains: - _target_: modelhub.metrics.clashing_chains.CountClashingChains + _target_: rf3.metrics.clashing_chains.CountClashingChains diff --git a/models/rf3/pyproject.toml b/models/rf3/pyproject.toml new file mode 100644 index 0000000..ec3dd2b --- /dev/null +++ b/models/rf3/pyproject.toml @@ -0,0 +1,67 @@ +[project] +name = "rf3" +dynamic = ["version"] +description = "RosettaFold3: Open-source biomolecular structure prediction for all molecules of life." 
+readme = "README.md" +requires-python = ">= 3.12" +authors = [ + { name = "Institute for Protein Design", email = "contact@ipd.uw.edu" }, +] +license = { file = "../../LICENSE.md" } + +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: POSIX :: Linux", + "Operating System :: MacOS", + "Operating System :: Microsoft :: Windows", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering :: Bio-Informatics", + "License :: OSI Approved :: BSD License", +] + +dependencies = [ + # Core modelhub dependency + "modelhub", + # CLI + "typer>=0.9.0,<1", + # RF3-specific ML dependencies + "einops>=0.8.0,<1", + "einx>=0.1.0,<1", + "opt_einsum>=3.4.0,<4", + "dm-tree>=0.1.6,<1", + # ... kernels (Linux only) + "cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'", + "cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'", + "cuequivariance_torch>=0.5.0; sys_platform == 'linux'", + # ... dataloading + "atomworks==1.0.2", +] + +[project.scripts] +rf3 = "rf3.cli:app" + +# Build settings ---------------------------------------------------------------------- +[build-system] +requires = [ + "hatchling", + "hatch-vcs == 0.4", +] +build-backend = "hatchling.build" + +[tool.hatch.version] +source = "vcs" + +[tool.hatch.version.raw-options] +root = "../.." 
+ +[tool.hatch.build.hooks.vcs] +version-file = "src/rf3/_version.py" + +[tool.hatch.metadata] +allow-direct-references = true + +[tool.hatch.build.targets.wheel] +packages = ["src/rf3"] diff --git a/models/rf3/rf3-dev.def b/models/rf3/rf3-dev.def new file mode 100644 index 0000000..0d8e871 --- /dev/null +++ b/models/rf3/rf3-dev.def @@ -0,0 +1,61 @@ +Bootstrap: docker +From: nvcr.io/nvidia/pytorch:25.04-py3 +IncludeCmd: yes + +# RF3 Development Container - Dependencies only + +%labels + Author Institute for Protein Design + Version 1.0-dev + Description RosettaFold3 - Dependencies only (for development and model training) + +%setup + # Create a directory in the container to bind the host's current working directory + mkdir ${APPTAINER_ROOTFS}/modelhub_host + # ... for mounting `/projects` with --bind + mkdir ${APPTAINER_ROOTFS}/projects + # ... for mounting `/net` with --bind (`/databases` and `/software` are symlinked into it in %post) + mkdir ${APPTAINER_ROOTFS}/net + # ... for mounting `/squash` with --bind + mkdir ${APPTAINER_ROOTFS}/squash + +%files + /etc/localtime + /etc/hosts + pyproject.toml /opt/pyproject.toml + +%post + ## GENERAL SETUP + + # Common symlinks (within container) + ln -s /net/databases /databases + ln -s /net/software /software + ln -s /home /mnt/home + ln -s /projects /mnt/projects + ln -s /net /mnt/net + + ## PYTHON DEPENDENCY INSTALLATION + + # Install uv for fast package installation (10-100x faster than pip) + pip install uv + + # Auto-generate a dependency list + uv pip compile /opt/pyproject.toml -o /opt/requirements.txt + + # Install all Python dependencies using requirements.txt with uv + uv pip install --system --break-system-packages -r /opt/requirements.txt + + +%environment + # (Flag to increase accessible GPU memory) + export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True + + # (Disable NCCL peer-to-peer transport between GPUs, i.e. NVLink / PCIe P2P) + export NCCL_P2P_DISABLE=1 + +%runscript + # NOTE: The %runscript is invoked when the container is run without specifying a different command. 
+ exec python "$@" + +%help + RosettaFold3 (RF3) development container with dependencies only. diff --git a/models/rf3/src/rf3/__init__.py b/models/rf3/src/rf3/__init__.py new file mode 100644 index 0000000..466dc32 --- /dev/null +++ b/models/rf3/src/rf3/__init__.py @@ -0,0 +1,3 @@ +"""RF3 - RosettaFold3 model implementation.""" + +__version__ = "0.1.0" diff --git a/models/rf3/src/rf3/_version.py b/models/rf3/src/rf3/_version.py new file mode 100644 index 0000000..f1bd7c5 --- /dev/null +++ b/models/rf3/src/rf3/_version.py @@ -0,0 +1,34 @@ +# file generated by setuptools-scm +# don't change, don't track in version control + +__all__ = [ + "__version__", + "__version_tuple__", + "version", + "version_tuple", + "__commit_id__", + "commit_id", +] + +TYPE_CHECKING = False +if TYPE_CHECKING: + from typing import Tuple + from typing import Union + + VERSION_TUPLE = Tuple[Union[int, str], ...] + COMMIT_ID = Union[str, None] +else: + VERSION_TUPLE = object + COMMIT_ID = object + +version: str +__version__: str +__version_tuple__: VERSION_TUPLE +version_tuple: VERSION_TUPLE +commit_id: COMMIT_ID +__commit_id__: COMMIT_ID + +__version__ = version = '0.1.dev917+gcbbe4c6a6.d20251001' +__version_tuple__ = version_tuple = (0, 1, 'dev917', 'gcbbe4c6a6.d20251001') + +__commit_id__ = commit_id = None diff --git a/models/rf3/src/rf3/callbacks/dump_validation_structures.py b/models/rf3/src/rf3/callbacks/dump_validation_structures.py index 056258f..0358e06 100644 --- a/models/rf3/src/rf3/callbacks/dump_validation_structures.py +++ b/models/rf3/src/rf3/callbacks/dump_validation_structures.py @@ -1,10 +1,10 @@ from os import PathLike from pathlib import Path -from atomworks.common import parse_example_id +from atomworks.ml.example_id import parse_example_id from beartype.typing import Any -from callbacks.base import BaseCallback +from modelhub.callbacks.callback import BaseCallback from rf3.utils.io import ( build_stack_from_atom_array_and_batched_coords, dump_structures, diff --git 
a/models/rf3/src/rf3/callbacks/metrics_logging.py b/models/rf3/src/rf3/callbacks/metrics_logging.py index add49a3..65e01be 100755 --- a/models/rf3/src/rf3/callbacks/metrics_logging.py +++ b/models/rf3/src/rf3/callbacks/metrics_logging.py @@ -7,9 +7,9 @@ from atomworks.ml.utils import nested_dict from beartype.typing import Any, Literal from omegaconf import ListConfig -from callbacks.base import BaseCallback -from utils.ddp import RankedLogger -from utils.logging import ( +from modelhub.callbacks.callback import BaseCallback +from modelhub.utils.ddp import RankedLogger +from modelhub.utils.logging import ( condense_count_columns_of_grouped_df, print_df_as_table, ) diff --git a/models/rf3/src/rf3/cli.py b/models/rf3/src/rf3/cli.py index 0736076..e78bcd2 100644 --- a/models/rf3/src/rf3/cli.py +++ b/models/rf3/src/rf3/cli.py @@ -1,4 +1,5 @@ import os +from pathlib import Path import typer from hydra import compose, initialize_config_dir @@ -13,9 +14,11 @@ app = typer.Typer() ) def fold(ctx: typer.Context): """Run structure prediction using hydra config overrides or simple input file.""" - config_path = os.path.join( - os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs" - ) + # Find the RF3 configs directory relative to this file + # This file is at: models/rf3/src/rf3/cli.py + # Configs are at: models/rf3/configs/ + rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/ + config_path = str(rf3_package_dir / "configs") # Get all arguments args = ctx.params.get("args", []) + ctx.args diff --git a/models/rf3/src/rf3/data/ground_truth_template.py b/models/rf3/src/rf3/data/ground_truth_template.py index 19ffb6a..65d80e2 100644 --- a/models/rf3/src/rf3/data/ground_truth_template.py +++ b/models/rf3/src/rf3/data/ground_truth_template.py @@ -20,7 +20,7 @@ from biotite.structure import AtomArray from jaxtyping import Bool, Float, Shaped from torch import Tensor -from utils.torch import assert_no_nans +from modelhub.utils.torch import 
assert_no_nans logger = logging.getLogger(__name__) diff --git a/models/rf3/src/rf3/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py index c5716cc..3606def 100644 --- a/models/rf3/src/rf3/data/paired_msa.py +++ b/models/rf3/src/rf3/data/paired_msa.py @@ -125,11 +125,7 @@ class MultiInputDatasetWrapper(StructuralDatasetWrapper): ) raise e - # Return the specified key or the entire data dict (i.e., only "feats" key from the Transform dictionary) - if exists(self.return_key): - return data[self.return_key] - else: - return data + return data class MultidomainDFParser(MetadataRowParser): diff --git a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py index 01dfa14..9190aba 100755 --- a/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py +++ b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py @@ -3,7 +3,7 @@ from beartype.typing import Any, Literal from jaxtyping import Float from rf3.data.rotation_augmentation import centre_random_augmentation -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/inference.py b/models/rf3/src/rf3/inference.py index dbbf65a..ba8e8e1 100755 --- a/models/rf3/src/rf3/inference.py +++ b/models/rf3/src/rf3/inference.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from hydra.utils import instantiate from omegaconf import DictConfig -from utils.logging import suppress_warnings +from modelhub.utils.logging import suppress_warnings load_dotenv(override=True) @@ -18,10 +18,11 @@ load_dotenv(override=True) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT` -_config_path = 
os.path.join( - os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs" -) +# Find the RF3 configs directory relative to this file +# This file is at: models/rf3/src/rf3/inference.py +# Configs are at: models/rf3/configs/ +rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/ +_config_path = str(rf3_package_dir / "configs") @hydra.main( diff --git a/models/rf3/src/rf3/inference_engines/__init__.py b/models/rf3/src/rf3/inference_engines/__init__.py new file mode 100644 index 0000000..0773669 --- /dev/null +++ b/models/rf3/src/rf3/inference_engines/__init__.py @@ -0,0 +1,5 @@ +"""RF3 inference engines.""" + +from rf3.inference_engines.rf3 import RF3InferenceEngine + +__all__ = ["RF3InferenceEngine"] diff --git a/models/rf3/src/rf3/inference_engines/rf3.py b/models/rf3/src/rf3/inference_engines/rf3.py index af4bb7c..eda6076 100644 --- a/models/rf3/src/rf3/inference_engines/rf3.py +++ b/models/rf3/src/rf3/inference_engines/rf3.py @@ -10,12 +10,11 @@ from atomworks.io.transforms.categories import category_to_dict from lightning.fabric import seed_everything from omegaconf import OmegaConf -from inference_engines.base import InferenceEngine from rf3.model.RF3 import ShouldEarlyStopFn from rf3.utils.datasets import ( assemble_distributed_inference_loader_from_list_of_paths, ) -from utils.ddp import RankedLogger, set_accelerator_based_on_availability +from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability from rf3.utils.inference import ( apply_conformer_and_template_selections, build_file_paths_for_prediction, @@ -25,7 +24,7 @@ from rf3.utils.io import ( dump_structures, dump_trajectories, ) -from utils.logging import print_config_tree +from modelhub.utils.logging import print_config_tree from rf3.utils.predicted_error import ( annotate_atom_array_b_factor_with_plddt, compile_af3_confidence_outputs, @@ -55,7 +54,7 @@ def should_early_stop_by_mean_plddt( return fn -class RF3InferenceEngine(InferenceEngine): 
+class RF3InferenceEngine: """Class for inference with RF3. Evaluates a trained RF3 model on a set of spoofed CIFs.""" def __init__( @@ -152,8 +151,6 @@ class RF3InferenceEngine(InferenceEngine): self.cfg.trainer.num_nodes = num_nodes self.cfg.trainer.devices_per_node = devices_per_node self.cfg.trainer.precision = "bf16-mixed" # HACK: Temporary hack until our checkpoint configs are updated - self.cfg.trainer._target_ = "modelhub.trainers.rf3.RF3TrainerWithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking - self.cfg.model.net._target_ = "modelhub.model.RF3.RF3WithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking # We don't want to compute all of the training metrics, since they may error during inference self.cfg.trainer["metrics"] = {} diff --git a/models/rf3/src/rf3/kinematics.py b/models/rf3/src/rf3/kinematics.py index 111faab..eafd310 100644 --- a/models/rf3/src/rf3/kinematics.py +++ b/models/rf3/src/rf3/kinematics.py @@ -1,11 +1,8 @@ # TODO: Many of these functions are unused; we will deprecate and delete # (They are holdovers from previous frameworks) -from itertools import permutations - import numpy as np import torch -from openbabel import openbabel PARAMS = { "DMIN": 1, @@ -355,36 +352,3 @@ def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS): def standardize_dihedral_retain_first(a, b, c, d): isomorphisms = [(a, b, c, d), (a, c, b, d)] return sorted(isomorphisms)[0] - - -def get_chirals(obmol, xyz): - """ - get all quadruples of atoms forming chiral centers and the expected ideal pseudodihedral between them - """ - stereo = openbabel.OBStereoFacade(obmol) - angle = np.arcsin(1 / 3**0.5) - chiral_idx_set = set() - for i in range(obmol.NumAtoms()): - if not stereo.HasTetrahedralStereo(i): - continue - si = stereo.GetTetrahedralStereo(i) - config = si.GetConfig() - - o = config.center - c = config.from_or_towards - i, j, k = list(config.refs) - for a, b, c in permutations((c, i, j, k), 3): - 
chiral_idx_set.add(standardize_dihedral_retain_first(o, a, b, c)) - - chiral_idx = list(chiral_idx_set) - chiral_idx.sort() - chiral_idx = torch.tensor(chiral_idx, dtype=torch.float32) - chiral_idx = chiral_idx[(chiral_idx < obmol.NumAtoms()).all(dim=-1)] - - if chiral_idx.numel() == 0: - return torch.zeros((0, 5)) - - dih = get_dih(*xyz[chiral_idx.long()].split(split_size=1, dim=1))[:, 0] - chirals = torch.nn.functional.pad(chiral_idx, (0, 1), mode="constant", value=angle) - chirals[dih < 0.0, -1] *= -1 - return chirals diff --git a/src/modelhub/metrics/chiral.py b/models/rf3/src/rf3/metrics/chiral.py similarity index 99% rename from src/modelhub/metrics/chiral.py rename to models/rf3/src/rf3/metrics/chiral.py index 260b872..d8d8ce3 100644 --- a/src/modelhub/metrics/chiral.py +++ b/models/rf3/src/rf3/metrics/chiral.py @@ -10,7 +10,7 @@ from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import Bool, Float from rf3.kinematics import get_dih -from metrics.base import Metric +from modelhub.metrics.metric import Metric def calc_chiral_metrics_masked( diff --git a/models/rf3/src/rf3/metrics/clashing_chains.py b/models/rf3/src/rf3/metrics/clashing_chains.py index 4206d25..f68df72 100644 --- a/models/rf3/src/rf3/metrics/clashing_chains.py +++ b/models/rf3/src/rf3/metrics/clashing_chains.py @@ -4,7 +4,7 @@ from typing import Any import torch from biotite.structure import AtomArrayStack -from metrics.base import Metric +from modelhub.metrics.metric import Metric class CountClashingChains(Metric): diff --git a/models/rf3/src/rf3/metrics/distogram.py b/models/rf3/src/rf3/metrics/distogram.py index 38e57a4..81cfd95 100644 --- a/models/rf3/src/rf3/metrics/distogram.py +++ b/models/rf3/src/rf3/metrics/distogram.py @@ -11,8 +11,8 @@ from einops import rearrange, repeat from jaxtyping import Bool, Float from rf3.loss.af3_losses import distogram_loss -from metrics.base import Metric -from utils.torch import assert_no_nans +from modelhub.metrics.metric import 
Metric +from modelhub.utils.torch import assert_no_nans @dataclass diff --git a/models/rf3/src/rf3/metrics/lddt.py b/models/rf3/src/rf3/metrics/lddt.py index f88ea67..89ad0c2 100644 --- a/models/rf3/src/rf3/metrics/lddt.py +++ b/models/rf3/src/rf3/metrics/lddt.py @@ -1,9 +1,7 @@ import numpy as np import torch -from atomworks.ml.transforms.atom_array import ( - AddGlobalTokenIdAnnotation, - ensure_atom_array_stack, -) +from atomworks.io.transforms.atom_array import ensure_atom_array_stack +from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation from atomworks.ml.transforms.atomize import AtomizeByCCDName from atomworks.ml.transforms.base import Compose from atomworks.ml.utils.token import get_token_starts @@ -11,8 +9,8 @@ from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack, stack from jaxtyping import Bool, Float, Int -from metrics.base import Metric -from utils.ddp import RankedLogger +from modelhub.metrics.metric import Metric +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/metrics/metadata.py b/models/rf3/src/rf3/metrics/metadata.py index 90b86eb..91b88d4 100644 --- a/models/rf3/src/rf3/metrics/metadata.py +++ b/models/rf3/src/rf3/metrics/metadata.py @@ -2,7 +2,7 @@ import json from beartype.typing import Any, Literal -from metrics.base import Metric +from modelhub.metrics.metric import Metric class ExtraInfo(Metric): diff --git a/models/rf3/src/rf3/metrics/predicted_error.py b/models/rf3/src/rf3/metrics/predicted_error.py index 824d81d..468e9d7 100644 --- a/models/rf3/src/rf3/metrics/predicted_error.py +++ b/models/rf3/src/rf3/metrics/predicted_error.py @@ -2,7 +2,7 @@ from typing import Any import torch -from metrics.base import Metric +from modelhub.metrics.metric import Metric from rf3.metrics.metric_utils import find_bin_midpoints diff --git a/src/modelhub/metrics/rasa.py b/models/rf3/src/rf3/metrics/rasa.py 
similarity index 99% rename from src/modelhub/metrics/rasa.py rename to models/rf3/src/rf3/metrics/rasa.py index d64b468..ba5e2f3 100644 --- a/src/modelhub/metrics/rasa.py +++ b/models/rf3/src/rf3/metrics/rasa.py @@ -3,7 +3,7 @@ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa from beartype.typing import Any from biotite.structure import AtomArrayStack -from metrics.base import Metric +from modelhub.metrics.metric import Metric class UnresolvedRegionRASA(Metric): diff --git a/models/rf3/src/rf3/metrics/selected_distances.py b/models/rf3/src/rf3/metrics/selected_distances.py index 01ef478..216f075 100644 --- a/models/rf3/src/rf3/metrics/selected_distances.py +++ b/models/rf3/src/rf3/metrics/selected_distances.py @@ -7,7 +7,7 @@ from atomworks.ml.utils.selection import ( from beartype.typing import Any from biotite.structure import AtomArrayStack -from metrics.base import Metric +from modelhub.metrics.metric import Metric class SelectedAtomByAtomDistances(Metric): diff --git a/models/rf3/src/rf3/model/layers/Attention_module.py b/models/rf3/src/rf3/model/layers/Attention_module.py index b983f11..2f79ffd 100644 --- a/models/rf3/src/rf3/model/layers/Attention_module.py +++ b/models/rf3/src/rf3/model/layers/Attention_module.py @@ -6,7 +6,7 @@ import torch.nn.functional as F from einops import rearrange from opt_einsum import contract as einsum -from src import SHOULD_USE_CUEQUIVARIANCE +from modelhub import SHOULD_USE_CUEQUIVARIANCE from rf3.training.checkpoint import activation_checkpointing from rf3.util_module import init_lecun_normal diff --git a/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py index 92c5bb1..e9aaa69 100644 --- a/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py +++ b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from jaxtyping import Float -from src import 
SHOULD_USE_CUEQUIVARIANCE +from modelhub import SHOULD_USE_CUEQUIVARIANCE from rf3.util_module import init_lecun_normal if SHOULD_USE_CUEQUIVARIANCE: diff --git a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py index ba3d5a2..9e5d303 100644 --- a/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py +++ b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py @@ -12,7 +12,7 @@ from rf3.model.layers.layer_utils import ( ) from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage from rf3.training.checkpoint import activation_checkpointing -from utils.torch import device_of +from modelhub.utils.torch import device_of class AtomAttentionEncoderDiffusion(nn.Module): diff --git a/models/rf3/src/rf3/train.py b/models/rf3/src/rf3/train.py index a8b31d9..299ef2d 100755 --- a/models/rf3/src/rf3/train.py +++ b/models/rf3/src/rf3/train.py @@ -8,8 +8,8 @@ import rootutils from dotenv import load_dotenv from omegaconf import DictConfig -from utils.logging import suppress_warnings -from utils.weights import CheckpointConfig +from modelhub.utils.logging import suppress_warnings +from modelhub.utils.weights import CheckpointConfig load_dotenv(override=True) @@ -18,10 +18,7 @@ load_dotenv(override=True) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT` -_config_path = os.path.join( - os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs" -) +_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs") _spawning_process_logger = logging.getLogger(__name__) @@ -43,14 +40,14 @@ def train(cfg: DictConfig) -> None: # Reference: 
https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision torch.set_float32_matmul_precision("medium") - from callbacks.base import BaseCallback # noqa - from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa - from utils.logging import ( + from modelhub.callbacks.callback import BaseCallback # noqa + from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa + from modelhub.utils.logging import ( print_config_tree, log_hyperparameters_with_all_loggers, ) # noqa - from utils.ddp import RankedLogger # noqa - from utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa + from modelhub.utils.ddp import RankedLogger # noqa + from modelhub.utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa from rf3.utils.datasets import ( recursively_instantiate_datasets_and_samplers, assemble_distributed_loader, diff --git a/models/rf3/src/rf3/trainers/rf3.py b/models/rf3/src/rf3/trainers/rf3.py index f375e23..a603e00 100644 --- a/models/rf3/src/rf3/trainers/rf3.py +++ b/models/rf3/src/rf3/trainers/rf3.py @@ -6,20 +6,20 @@ from jaxtyping import Float, Int from lightning_utilities import apply_to_collection from omegaconf import DictConfig -from common import exists +from modelhub.common import exists from rf3.loss.af3_losses import Loss as AF3Loss from rf3.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, ) -from metrics.base import MetricManager +from modelhub.metrics.metric import MetricManager from rf3.model.RF3 import ShouldEarlyStopFn -from trainers.fabric import FabricTrainer +from modelhub.trainers.fabric import FabricTrainer from rf3.training.EMA import EMA -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger from rf3.utils.io import build_stack_from_atom_array_and_batched_coords from rf3.utils.recycling import get_recycle_schedule -from utils.torch import 
assert_no_nans, assert_same_shape +from modelhub.utils.torch import assert_no_nans, assert_same_shape ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/training/checkpoint.py b/models/rf3/src/rf3/training/checkpoint.py index da42cd7..94b185c 100644 --- a/models/rf3/src/rf3/training/checkpoint.py +++ b/models/rf3/src/rf3/training/checkpoint.py @@ -1,8 +1,4 @@ -"""Utilities for gradient checkpointing to reduce memory usage during training. - -Gradient checkpointing (also called activation checkpointing) trades compute for memory -by recomputing intermediate activations during the backward pass instead of storing them. -This enables training larger models or using larger batch sizes within GPU memory constraints. +"""Utilities for gradient checkpointing. References: * `PyTorch Checkpoint Documentation`_ @@ -17,9 +13,8 @@ from torch.utils.checkpoint import checkpoint def create_custom_forward(module, **kwargs): """Create a custom forward function for gradient checkpointing with fixed kwargs. - This helper enables passing keyword arguments to a module when using PyTorch's - checkpoint function, which only accepts positional arguments for the function to - be checkpointed. + Enables passing keyword arguments to a module when using PyTorch's checkpoint function, + which only accepts positional arguments for the function to be checkpointed. Args: module: The callable (typically a nn.Module) to wrap. @@ -28,15 +23,6 @@ def create_custom_forward(module, **kwargs): Returns: A callable that accepts only positional arguments and forwards them along with the fixed kwargs to the original module. 
- - Examples: - Use with PyTorch checkpoint:: - - custom_fn = create_custom_forward(my_module, frame_atom_idxs=frame_idxs) - output = checkpoint(custom_fn, input_tensor, use_reentrant=False) - - See Also: - :py:func:`activation_checkpointing` """ def custom_forward(*inputs): @@ -48,11 +34,6 @@ def create_custom_forward(module, **kwargs): def activation_checkpointing(function): """Decorator to enable gradient checkpointing for a function during training. - When gradients are enabled (training mode), this decorator wraps the function - with PyTorch's checkpoint to save memory by recomputing activations during - the backward pass. During inference (gradients disabled), the function runs - normally without checkpointing overhead. - Args: function: The function to apply gradient checkpointing to. @@ -67,11 +48,7 @@ def activation_checkpointing(function): return self.layer(x, mask) Notes: - Uses ``use_reentrant=False`` for better compatibility with modern PyTorch - features like autograd hooks and higher-order gradients. - - See Also: - :py:func:`create_custom_forward` + Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions. 
""" def wrapper(*args, **kwargs): diff --git a/models/rf3/src/rf3/utils/datasets.py b/models/rf3/src/rf3/utils/datasets.py index 9e3cf4d..80e8242 100755 --- a/models/rf3/src/rf3/utils/datasets.py +++ b/models/rf3/src/rf3/utils/datasets.py @@ -20,8 +20,8 @@ from torch.utils.data import ( ) from torch.utils.data.distributed import DistributedSampler -from hydra.resolvers import register_resolvers -from utils.ddp import RankedLogger +from modelhub.hydra.resolvers import register_resolvers +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) try: diff --git a/models/rf3/src/rf3/utils/io.py b/models/rf3/src/rf3/utils/io.py index c663e49..6d6f275 100644 --- a/models/rf3/src/rf3/utils/io.py +++ b/models/rf3/src/rf3/utils/io.py @@ -9,7 +9,7 @@ from beartype.typing import Literal from biotite.structure import AtomArray, AtomArrayStack, stack from rf3.alignment import weighted_rigid_align -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/models/rf3/src/rf3/validate.py b/models/rf3/src/rf3/validate.py index b06a2a7..d623686 100755 --- a/models/rf3/src/rf3/validate.py +++ b/models/rf3/src/rf3/validate.py @@ -2,13 +2,14 @@ import logging import os +from pathlib import Path import hydra import rootutils from dotenv import load_dotenv from omegaconf import DictConfig -from utils.logging import suppress_warnings +from modelhub.utils.logging import suppress_warnings load_dotenv(override=True) @@ -17,10 +18,7 @@ load_dotenv(override=True) # NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located) rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) -# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT` -_config_path = os.path.join( - os.environ.get("PROJECT_PATH", 
os.environ["PROJECT_ROOT"]), "configs" -) +_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs") _spawning_process_logger = logging.getLogger(__name__) @@ -42,11 +40,11 @@ def validate(cfg: DictConfig) -> None: # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision torch.set_float32_matmul_precision("medium") - from callbacks.base import BaseCallback # noqa - from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa - from utils.logging import print_config_tree # noqa - from utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa - from utils.ddp import is_rank_zero # noqa + from modelhub.callbacks.callback import BaseCallback # noqa + from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa + from modelhub.utils.logging import print_config_tree # noqa + from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa + from modelhub.utils.ddp import is_rank_zero # noqa from rf3.utils.datasets import assemble_val_loader_dict # noqa set_accelerator_based_on_availability(cfg) diff --git a/models/rf3/tests/conftest.py b/models/rf3/tests/conftest.py index dfdc002..ca2d361 100644 --- a/models/rf3/tests/conftest.py +++ b/models/rf3/tests/conftest.py @@ -9,11 +9,18 @@ TEST_DATA_DIR = Path(__file__).resolve().parent / "data" def pytest_configure(config): + import sys + # Set PROJECT_ROOT project_root = rootutils.setup_root( __file__, indicator=".project-root", pythonpath=True ) + # Add models/rf3/src to path so RF3 modules can be imported + rf3_src = project_root / "models" / "rf3" / "src" + if rf3_src.exists() and str(rf3_src) not in sys.path: + sys.path.insert(0, str(rf3_src)) + # Construct path to .env file at project root dotenv_path = project_root / ".env" diff --git a/models/rf3/tests/data/5vht_from_file.cif b/models/rf3/tests/data/5vht_from_file.cif index 
d967154..d13b416 100644 --- a/models/rf3/tests/data/5vht_from_file.cif +++ b/models/rf3/tests/data/5vht_from_file.cif @@ -1,7 +1,7 @@ data_5VHT # -_msa_paths_by_chain_id.A tests/data/msas/5vht_A.a3m -_msa_paths_by_chain_id.B tests/data/msas/5vht_A.a3m +_msa_paths_by_chain_id.A models/rf3/tests/data/msas/5vht_A.a3m +_msa_paths_by_chain_id.B models/rf3/tests/data/msas/5vht_A.a3m # _entry.id 5VHT # diff --git a/models/rf3/tests/data/5vht_from_json.json b/models/rf3/tests/data/5vht_from_json.json index 4471145..24e977e 100644 --- a/models/rf3/tests/data/5vht_from_json.json +++ b/models/rf3/tests/data/5vht_from_json.json @@ -5,12 +5,12 @@ { "seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH", "chain_id": "A", - "msa_path": "tests/data/msas/5vht_A.a3m" + "msa_path": "models/rf3/tests/data/msas/5vht_A.a3m" }, { "seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH", "chain_id": "B", - "msa_path": "tests/data/msas/5vht_A.a3m" + "msa_path": "models/rf3/tests/data/msas/5vht_A.a3m" } ] } diff --git a/tests/test_metrics.py b/models/rf3/tests/test_chiral_metrics.py similarity index 97% rename from tests/test_metrics.py rename to models/rf3/tests/test_chiral_metrics.py index 3c58da5..27e6b75 100644 --- a/tests/test_metrics.py +++ b/models/rf3/tests/test_chiral_metrics.py @@ -3,7 +3,7 @@ from copy import deepcopy import pytest from atomworks.ml.utils.testing import cached_parse -from metrics.chiral import ChiralLoss +from rf3.metrics.chiral import ChiralLoss @pytest.mark.parametrize("pdb_id", ["5ocm", "1ivo"]) diff --git a/pyproject.toml b/pyproject.toml index 501706b..1966da4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] -name = "rf3" +name = "modelhub" dynamic = ["version"] -description = "Open-source biomolecular structure prediction for all molecules of life." 
+description = "Shared utilities and training infrastructure for biomolecular structure prediction models." readme = "README.md" requires-python = ">= 3.12" authors = [ @@ -22,17 +22,6 @@ classifiers = [ "License :: OSI Approved :: BSD License", ] dependencies = [ - # ...ml tools - "torch>=2.2.0,<3", - "lightning>=2.4.0,<2.5", - "einops>=0.8.0,<1", - "einx>=0.1.0,<1", - "opt_einsum>=3.4.0,<4", - "dm-tree>=0.1.6,<1", - # ... kernels (Linux only) - "cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'", - "cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'", - "cuequivariance_torch>=0.5.0; sys_platform == 'linux'", # ... configuration & CLI "rootutils>=1.0.7,<1.1", "hydra-core>=1.3.0,<1.4", @@ -43,9 +32,9 @@ dependencies = [ # ... typing & documentation "jaxtyping>=0.2.17,<1", "beartype>=0.18.0,<1", - - # ... dataloading - "atomworks==1.0.2" + # ... ml tools (core) + "torch>=2.2.0,<3", + "lightning>=2.4.0,<2.5", ] @@ -64,8 +53,7 @@ dev = [ "pytest-cov>=4.1.0,<5", # generate coverage report "pytest-benchmark>=5.0.0,<6", # benchmark tests for speed ] -[project.scripts] -rf3 = "modelhub.cli:app" +# No CLI scripts at root level - models provide their own entry points # Build settings ---------------------------------------------------------------------- [build-system] requires = [ diff --git a/src/modelhub/__init__.py b/src/modelhub/__init__.py index 5b2e701..fe8ddf1 100644 --- a/src/modelhub/__init__.py +++ b/src/modelhub/__init__.py @@ -44,4 +44,4 @@ except ImportError: logger.debug("cuEquivariance unavailable: import failed") # Export for easy access -__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "silence_warnings"] +__all__ = ["SHOULD_USE_CUEQUIVARIANCE"] diff --git a/src/modelhub/callbacks/__init__.py b/src/modelhub/callbacks/__init__.py new file mode 100644 index 0000000..ae64d24 --- /dev/null +++ b/src/modelhub/callbacks/__init__.py @@ -0,0 +1,5 @@ +"""Callbacks for training and validation.""" + +from modelhub.callbacks.callback import BaseCallback + 
+__all__ = ["BaseCallback"] diff --git a/src/modelhub/callbacks/health_logging.py b/src/modelhub/callbacks/health_logging.py index 44e617c..f027b0a 100644 --- a/src/modelhub/callbacks/health_logging.py +++ b/src/modelhub/callbacks/health_logging.py @@ -10,7 +10,7 @@ from jaxtyping import Float, Int from lightning.fabric.utilities.rank_zero import rank_zero_only from torch import Tensor -from callbacks.base import BaseCallback +from modelhub.callbacks.callback import BaseCallback _DEFAULT_STATISTICS = types.MappingProxyType( { diff --git a/src/modelhub/callbacks/timing_logging.py b/src/modelhub/callbacks/timing_logging.py index 4b4ecca..8f2f7e1 100644 --- a/src/modelhub/callbacks/timing_logging.py +++ b/src/modelhub/callbacks/timing_logging.py @@ -1,10 +1,10 @@ import pandas as pd from lightning.fabric.utilities.rank_zero import rank_zero_only - -from callbacks.base import BaseCallback from rf3.utils.logging import print_df_as_table from rf3.utils.torch_utils import Timers +from modelhub.callbacks.callback import BaseCallback + class TimingCallback(BaseCallback): """Fabric callback to print timing metrics.""" diff --git a/src/modelhub/callbacks/train_logging.py b/src/modelhub/callbacks/train_logging.py index 65842e4..a6328cf 100755 --- a/src/modelhub/callbacks/train_logging.py +++ b/src/modelhub/callbacks/train_logging.py @@ -2,26 +2,26 @@ import time from collections import defaultdict import pandas as pd -from atomworks.common import parse_example_id +from atomworks.ml.example_id import parse_example_id from beartype.typing import Any from lightning.fabric.wrappers import ( _FabricOptimizer, ) +from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses from rich.console import Group from rich.panel import Panel from rich.table import Table from torch import nn from torchmetrics.aggregation import MeanMetric -from callbacks.base import BaseCallback -from rf3.utils.ddp import RankedLogger -from rf3.utils.logging import ( +from 
modelhub.callbacks.callback import BaseCallback +from modelhub.utils.ddp import RankedLogger +from modelhub.utils.logging import ( print_df_as_table, print_model_parameters, safe_print, table_from_df, ) -from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses class LogModelParametersCallback(BaseCallback): diff --git a/src/modelhub/inference_engines/base.py b/src/modelhub/inference_engines/base.py deleted file mode 100644 index ea7e69e..0000000 --- a/src/modelhub/inference_engines/base.py +++ /dev/null @@ -1,15 +0,0 @@ -from abc import ABC, abstractmethod -from pathlib import Path - - -class InferenceEngine(ABC): - """Abstract base class for inference pipelines.""" - - @abstractmethod - def __init__(self, **kwargs): - pass - - @abstractmethod - def eval(self, inputs: list[Path]) -> None: - """Run inference on input files.""" - pass diff --git a/src/modelhub/metrics/__init__.py b/src/modelhub/metrics/__init__.py new file mode 100644 index 0000000..68e802d --- /dev/null +++ b/src/modelhub/metrics/__init__.py @@ -0,0 +1,12 @@ +"""Metrics for model evaluation. + +This module provides the base metric framework. 
+""" + +from modelhub.metrics.metric import Metric, MetricInputError, MetricManager + +__all__ = [ + "Metric", + "MetricManager", + "MetricInputError", +] diff --git a/src/modelhub/metrics/base.py b/src/modelhub/metrics/metric.py similarity index 99% rename from src/modelhub/metrics/base.py rename to src/modelhub/metrics/metric.py index 391ee12..2fb8444 100644 --- a/src/modelhub/metrics/base.py +++ b/src/modelhub/metrics/metric.py @@ -8,7 +8,7 @@ from beartype.typing import Any from omegaconf import DictConfig from toolz import keymap -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/trainers/fabric.py b/src/modelhub/trainers/fabric.py index bc90df2..c9f6a5f 100755 --- a/src/modelhub/trainers/fabric.py +++ b/src/modelhub/trainers/fabric.py @@ -25,12 +25,12 @@ from lightning.fabric.wrappers import ( _FabricModule, _FabricOptimizer, ) - -from callbacks.base import BaseCallback from rf3.training.EMA import EMA from rf3.training.schedulers import SchedulerConfig -from utils.ddp import RankedLogger -from utils.weights import ( + +from modelhub.callbacks.callback import BaseCallback +from modelhub.utils.ddp import RankedLogger +from modelhub.utils.weights import ( CheckpointConfig, WeightLoadingConfig, freeze_parameters_with_config, diff --git a/src/modelhub/utils/instantiators.py b/src/modelhub/utils/instantiators.py index 417bc2a..0a13844 100755 --- a/src/modelhub/utils/instantiators.py +++ b/src/modelhub/utils/instantiators.py @@ -2,7 +2,7 @@ import hydra from lightning.fabric.loggers import Logger from omegaconf import DictConfig -from callbacks.base import BaseCallback +from modelhub.callbacks.callback import BaseCallback def _can_be_instantiated(cfg: DictConfig) -> bool: diff --git a/src/modelhub/utils/logging.py b/src/modelhub/utils/logging.py index 0c9adcd..4e37b57 100755 --- a/src/modelhub/utils/logging.py +++ b/src/modelhub/utils/logging.py @@ 
-12,7 +12,7 @@ from rich.table import Table from rich.tree import Tree from torch import nn -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/utils/torch.py b/src/modelhub/utils/torch.py index d37c4e6..09e5954 100755 --- a/src/modelhub/utils/torch.py +++ b/src/modelhub/utils/torch.py @@ -14,8 +14,8 @@ from torch import Tensor from torch._prims_common import DeviceLikeType from torch.types import _dtype -from src import should_check_nans -from common import at_least_one_exists, do_nothing +from modelhub import should_check_nans +from modelhub.common import at_least_one_exists, do_nothing def map_to( diff --git a/src/modelhub/utils/weights.py b/src/modelhub/utils/weights.py index e66ea87..cf20bf9 100644 --- a/src/modelhub/utils/weights.py +++ b/src/modelhub/utils/weights.py @@ -9,7 +9,7 @@ import torch from beartype.typing import Pattern from torch import nn -from utils.ddp import RankedLogger +from modelhub.utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/tests/test_torch_utils.py b/tests/test_torch_utils.py index 3cf2663..faa2467 100644 --- a/tests/test_torch_utils.py +++ b/tests/test_torch_utils.py @@ -4,7 +4,7 @@ import pytest import torch os.environ["NAN_CHECKING"] = "True" -from utils.torch import assert_no_nans, map_to +from modelhub.utils.torch import assert_no_nans, map_to def test_map_to(): diff --git a/tests/test_weight_loading.py b/tests/test_weight_loading.py index e0f00ef..72f2ac7 100644 --- a/tests/test_weight_loading.py +++ b/tests/test_weight_loading.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn # Import your code here -from utils.weights import ( +from modelhub.utils.weights import ( ParameterFreezingConfig, WeightLoadingConfig, WeightLoadingPolicy,