fix: working rf3

ncorley
2025-10-01 23:52:01 -07:00
parent cbbe4c6a6d
commit 592de5b488
89 changed files with 554 additions and 1226 deletions


@@ -24,7 +24,7 @@ For more information, please see our preprint, [Accelerating Biomolecular Modeli
</div>
> [!TIP]
> Complete inference instructions for RF3 are provided [here](src/modelhub/inference_engines/README.md).
> Complete inference instructions for RF3 are provided [here](models/rf3/README.md).
### RF3 Quick Start - Installation & Usage
@@ -32,7 +32,7 @@ Follow these steps to set up **ModelForge** and run a test prediction with **RF3
---
#### 1. Install the repository using `uv`
#### 1. Install the source repository and RF3 model using `uv`
```bash
git clone https://github.com/RosettaCommons/modelforge.git \
@@ -40,9 +40,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \
&& uv python install 3.12 \
&& uv venv --python 3.12 \
&& source .venv/bin/activate \
&& uv pip install -e .
&& uv pip install -e ./models/rf3
```
> [!NOTE]
> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency.
#### 2. Download model weights for RF3
```bash
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
@@ -53,4 +56,56 @@ wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
rf3 fold tests/data/5vht_from_json.json
```
Details on the exact formatting of the json files are available [here](src/modelhub/inference_engines/README.md).
Details on the exact formatting of the json files are available [here](models/rf3/README.md).
## Development
### Package Structure
ModelForge uses a multi-package architecture:
- **`modelhub`**: Core package containing shared utilities, training infrastructure, and base classes
- **`models/rf3/`**: RF3 model package with model-specific code and dependencies
- **`models/<future>/`**: Additional models can be added as separate packages
### Installation Options
#### For Users (Single Model)
Install only the model you need. Dependencies (including `modelhub`) are automatically installed:
```bash
# Install RF3 (includes modelhub automatically)
uv pip install -e ./models/rf3
# Future: Install other models
# uv pip install -e ./models/other_model
```
#### For Core Developers (Multiple Packages)
Install both `modelhub` and models in editable mode for development:
```bash
# Install modelhub and RF3 in editable mode
uv pip install -e . -e ./models/rf3
# Or install only modelhub (no models)
uv pip install -e .
```
This approach allows you to:
- Modify `modelhub` shared utilities and see changes immediately
- Work on specific models without installing all models
- Add new models as independent packages in `models/`
### Adding New Models
To add a new model:
1. Create `models/<model_name>/` directory with its own `pyproject.toml`
2. Add `modelhub` as a dependency
3. Implement model-specific code in `models/<model_name>/src/`
4. Users can install with: `uv pip install -e ./models/<model_name>`
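The four steps above can be sketched as a small scaffolding script. This is a minimal sketch, not something shipped with the repository: `scaffold_model`, the `hatchling` build backend, the `0.1.0` version, and the `src/<model_name>/` package layout are all assumptions for illustration. How `modelhub` is resolved (index, path, or workspace source) is not shown in this README, so it is left as a bare requirement.

```python
from pathlib import Path

# Hypothetical template; adjust the backend and metadata to match models/rf3.
PYPROJECT_TEMPLATE = """\
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "{name}"
version = "0.1.0"
dependencies = [
    "modelhub",  # shared utilities (step 2)
]

[tool.hatch.build.targets.wheel]
packages = ["src/{name}"]
"""


def scaffold_model(repo_root: Path, name: str) -> Path:
    """Create models/<name>/ with a pyproject.toml and an empty package (hypothetical helper)."""
    model_dir = repo_root / "models" / name
    package_dir = model_dir / "src" / name
    # exist_ok=False so an existing model directory is never overwritten.
    package_dir.mkdir(parents=True, exist_ok=False)
    (package_dir / "__init__.py").write_text("")
    (model_dir / "pyproject.toml").write_text(PYPROJECT_TEMPLATE.format(name=name))
    return model_dir
```

After running `scaffold_model(Path("."), "my_model")`, the install command from step 4 would apply to `./models/my_model`.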
## Development


@@ -1,919 +0,0 @@
# ModelForge Repository Restructuring Plan
**Date**: October 1, 2025
**Author**: Analysis based on PyTorch Lightning patterns
**Goal**: Reorganize `src/modelhub/` following best practices while keeping `releases/` structure intact
---
## Executive Summary
This plan addresses four critical issues in the current structure:
1. **Broken imports**: RF3 uses `from metrics.base import Metric` but base classes are in `src/modelhub/`
2. **CLI entry point mismatch**: `pyproject.toml` points to non-existent `modelhub.cli:app`
3. **Package installation confusion**: Only `src/modelhub` is packaged, but RF3 code is in `releases/`
4. **Unclear organization**: Mixed base classes and implementations without clear semantic separation
**Core Principle**: Follow PyTorch Lightning's pattern - base classes live WITH their implementations, no separate `contrib/` directory.
---
## Current Structure Analysis
### What Exists Now
```
modelhub_latent/
├── pyproject.toml # name = "rf3" (wrong!)
├── src/
│ └── modelhub/ # Only this is packaged
│ ├── callbacks/
│ │ ├── base.py # BaseCallback
│ │ ├── health_logging.py # Implementation
│ │ ├── timing_logging.py # Implementation
│ │ └── train_logging.py # Implementation
│ ├── metrics/
│ │ ├── base.py # Metric + MetricManager
│ │ ├── chiral.py # Implementation
│ │ └── rasa.py # Implementation
│ ├── utils/
│ ├── trainers/
│ └── hydra/
├── releases/
│ └── rf3/ # NOT packaged!
│ ├── src/rf3/
│ │ ├── callbacks/
│ │ ├── metrics/
│ │ ├── model/
│ │ └── cli.py
│ ├── configs/
│ └── tests/
└── tests/ # Shared tests
```
### Critical Problems
#### Problem 1: Import Path Mismatch
**RF3 code uses**:
```python
from metrics.base import Metric # Expects local resolution
from callbacks.base import BaseCallback # Expects local resolution
```
**Reality**:
- Base classes are in `src/modelhub/metrics/base.py`
- This only works via sys.path manipulation or broken imports
**Should be**:
```python
from modelhub.metrics.metric import Metric
from modelhub.callbacks.callback import BaseCallback
```
#### Problem 2: CLI Entry Point
**Current `pyproject.toml`**:
```toml
[project.scripts]
rf3 = "modelhub.cli:app"
```
**Reality**: There is NO `src/modelhub/cli.py` file!
**Actual CLI**: Located at `releases/rf3/src/rf3/cli.py`
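A mismatch like this can be caught before packaging by checking that each `module:attribute` target actually imports. The helper below is a minimal sketch; `entry_point_resolves` is a hypothetical name, and it only handles a single attribute after the colon.

```python
import importlib


def entry_point_resolves(target: str) -> bool:
    """Return True if a `module:attribute` entry-point target can be imported."""
    module_name, _, attr = target.partition(":")
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers ModuleNotFoundError as well, e.g. a missing modelhub/cli.py.
        return False
    return hasattr(module, attr) if attr else True


# The two targets discussed in this plan.
for target in ("modelhub.cli:app", "rf3.cli:app"):
    status = "ok" if entry_point_resolves(target) else "missing"
    print(f"{target}: {status}")
```

Note that importing a module runs its top-level code, so this check is best run in the same environment the package is installed into.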
#### Problem 3: Package Name Confusion
**Current**: `name = "rf3"` in pyproject.toml
**Problem**:
- Repository is called "modelforge"
- Users would `pip install rf3` (gets only base framework)
- Can't do `pip install modelforge[rf3]` for selective installation
#### Problem 4: Packaging Scope
**Current**: `packages = ["src/modelhub"]`
**Problem**: RF3 code in `releases/rf3/src/rf3/` is NOT included in the package!
---
## Design Goals
### User Experience Goals
1. `pip install modelforge` → Get base framework only
2. `pip install modelforge[rf3]` → Get base + RF3
3. `pip install modelforge[all]` → Get all models (future-proof)
4. Development: `pip install -e .[rf3,dev]`
### Code Organization Goals
1. Follow PyTorch Lightning patterns (familiar to ML researchers)
2. Clear distinction between base classes and implementations
3. Proper namespacing for imports
4. Maintain `releases/` structure for development workflow
### Import Pattern Goals
```python
# Clear, unambiguous imports
from modelforge.callbacks.callback import BaseCallback
from modelforge.metrics.metric import Metric
from modelforge.utils.ddp import RankedLogger
# RF3-specific imports
from modelforge.models.rf3.model import RF3
```
---
## PyTorch Lightning Pattern Analysis
### Lightning's Structure (Reference)
```
lightning/pytorch/
├── core/ # Primary model abstractions
│ ├── module.py # LightningModule (what users inherit from)
│ ├── datamodule.py # LightningDataModule
│ └── hooks.py
├── callbacks/ # Base + implementations together
│ ├── callback.py # Callback base class
│ ├── model_checkpoint.py # ModelCheckpoint implementation
│ ├── early_stopping.py # EarlyStopping implementation
│ └── __init__.py # Exports everything
├── loggers/ # Base + implementations together
│ ├── logger.py # Logger base class
│ ├── wandb.py # WandbLogger implementation
│ └── __init__.py
└── utilities/ # Pure utilities
```
### Key Insights from Lightning
1. **`core/` is for PRIMARY model abstractions** - Things users inherit from to define their model:
- `LightningModule` - "Your model IS A LightningModule"
- `LightningDataModule` - "Your data IS A LightningDataModule"
2. **Other directories contain base + implementations together**:
- `callbacks/callback.py` (base) + `callbacks/model_checkpoint.py` (impl)
- `loggers/logger.py` (base) + `loggers/wandb.py` (impl)
3. **NO `contrib/` directory** - That's a Django pattern, not Lightning!
4. **`__init__.py` exports everything** for clean imports
### When to Use `core/`
**Use `core/` IF**:
- You have a primary model class users inherit from (like `LightningModule`)
- These classes define "what you ARE building" (semantic role)
**Don't use `core/` IF**:
- All your base classes are for plugins/extensions (callbacks, metrics)
- You don't have a central model abstraction
**For ModelForge**: Since there's no primary "ModelBase" class, **we don't need `core/`**.
### Re-evaluating `inference_engines/`
**Current situation**:
- `InferenceEngine` is a minimal ABC with only 2 abstract methods (`__init__`, `eval`)
- Only RF3 uses it (no other models yet)
- It's a very thin abstraction that provides minimal value
**Question**: Do we need this abstraction at all?
**Arguments for keeping it**:
- Future-proofs for when other models are added
- Provides a consistent interface pattern
**Arguments for removing it**:
- YAGNI (You Aren't Gonna Need It) - only one implementation exists
- Adds indirection without clear benefit
- The abstraction is so minimal it doesn't enforce meaningful constraints
- Each model's inference is likely to be sufficiently different that a shared interface adds little value
**Recommendation**: **Remove the `inference_engines/` directory entirely from `src/modelhub/`**
- RF3's inference engine can live solely in `models/rf3/src/rf3/inference_engines/rf3.py`
- Remove the ABC inheritance - just make it a standalone class
- When/if a second model is added, we can evaluate whether a shared abstraction makes sense based on actual commonalities
- This follows YAGNI and keeps the code simpler
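To make the trade-off concrete, here is a minimal sketch of the two shapes being compared. The method signatures, the `ckpt_path` and `inputs` parameters, and the class name `StandaloneRF3Engine` are assumptions for illustration; only the fact that the ABC declares `__init__` and `eval` comes from the analysis above.

```python
from abc import ABC, abstractmethod


class InferenceEngine(ABC):
    """Thin ABC as described above: two abstract methods, no shared behavior."""

    @abstractmethod
    def __init__(self, ckpt_path: str) -> None: ...

    @abstractmethod
    def eval(self, inputs: list[str]) -> dict[str, str]: ...


class StandaloneRF3Engine:
    """Same surface without the ABC; callers only rely on `eval`."""

    def __init__(self, ckpt_path: str) -> None:
        self.ckpt_path = ckpt_path

    def eval(self, inputs: list[str]) -> dict[str, str]:
        # Placeholder body: the real engine would run the model here.
        return {path: f"predicted with {self.ckpt_path}" for path in inputs}
```

Callers see the same interface either way, which is why dropping the base class costs little while only one implementation exists.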
---
## Recommended Solution
### Two-Part Restructuring
#### Part 1: Fix `src/modelhub/` Structure (Simple, Lightning-style)
```
src/
└── modelhub/ (rename to modelforge?)
├── __init__.py # Export key base classes
├── callbacks/
│ ├── __init__.py # Export Callback + all implementations
│ ├── callback.py # BaseCallback (RENAMED from base.py)
│ ├── health_logging.py # HealthLoggingCallback
│ ├── timing_logging.py # TimingLoggingCallback
│ └── train_logging.py # TrainLoggingCallback
├── metrics/
│ ├── __init__.py # Export Metric + all implementations
│ ├── metric.py # Metric, MetricManager (RENAMED from base.py)
│ ├── chiral.py # ChiralMetric
│ └── rasa.py # RASAMetric
├── trainers/
│ ├── __init__.py
│ └── fabric.py # Fabric trainer wrapper
├── utils/ # Pure utilities (no base classes)
│ ├── ddp.py
│ ├── logging.py
│ ├── weights.py
│ ├── instantiators.py
│ └── torch.py
└── hydra/ # Hydra utilities
├── __init__.py
└── resolvers.py
```
**Key changes**:
1. Rename `base.py` files to match their content:
- `callbacks/base.py` → `callbacks/callback.py`
- `metrics/base.py` → `metrics/metric.py`
2. **REMOVE** `inference_engines/` entirely (not needed with only one model)
3. Keep implementations WITH base classes (Lightning pattern)
4. No `core/` directory (not needed without primary model abstraction)
5. Proper `__init__.py` exports
#### Part 2: Integrate RF3 into Main Package (For pip install)
**Decision Point**: Choose ONE of these approaches:
##### Option A: Move RF3 into src/modelhub/ (Monorepo style)
```
src/
└── modelhub/
├── callbacks/, metrics/, utils/ (as above)
└── models/ # NEW
└── rf3/ # Moved from releases/rf3/src/rf3/
├── __init__.py
├── model/
├── data/
├── metrics/ # RF3-specific metrics
├── callbacks/ # RF3-specific callbacks
├── inference_engines/
└── cli.py
```
**Pros**:
- Simple pip install: `pip install modelforge[rf3]`
- Single package, single version
- Easy code sharing between models
**Cons**:
- Changes development workflow (no separate `releases/` for development)
- RF3 code always in source tree even if not installed
##### Option B: Keep releases/ separate (Development friendly)
```
modelhub_latent/
├── src/modelhub/ # Base framework only
├── releases/
│ └── rf3/ # Stays as is for development
│ └── src/rf3/
└── pyproject.toml # Use optional dependencies
```
**Then in `pyproject.toml`**:
```toml
[project.optional-dependencies]
rf3 = [
"modelforge-rf3 @ {root:uri}/releases/rf3", # Hatch context formatting; also requires allow-direct-references = true under [tool.hatch.metadata]
# RF3-specific deps
]
```
**Pros**:
- Keeps development workflow unchanged
- Clear separation for releases
**Cons**:
- More complex build setup
- Requires editable installs for development
- Each release needs own `pyproject.toml`
##### Option C: Hybrid - Move RF3 but keep releases/ for development (RECOMMENDED)
```
# For development: work in releases/rf3/
releases/rf3/
├── src/rf3/ # Development happens here
├── configs/ # RF3 configs here
└── tests/ # RF3 tests here
# For installation: symlink or copy during build
src/modelhub/models/rf3/ → symlink to releases/rf3/src/rf3/
# Or use build hooks to include releases/ in package
```
**Implementation**: Use hatch build hooks to include `releases/rf3/src/rf3/` as `src/modelhub/models/rf3/`
---
## Detailed Implementation Plan
### Phase 1: Rename Package (Breaking Change Decision)
**Decision needed**: Keep `modelhub` or rename to `modelforge`?
**Arguments for `modelforge`**:
- Matches repository name
- More descriptive (framework for building models)
- Enables `pip install modelforge[rf3]`
**Arguments for `modelhub`**:
- Less renaming work
- Existing code already uses it
**Recommendation**: Rename to `modelforge` for clarity and marketing.
### Phase 2: Restructure src/modelhub/ (src/modelforge/)
#### Step 2.1: Rename Base Class Files
```bash
# In src/modelhub/
git mv callbacks/base.py callbacks/callback.py
git mv metrics/base.py metrics/metric.py
# inference_engines/ is removed entirely (see "Key changes" above), so it has no base.py to rename
```
#### Step 2.2: Update Imports in Base Files
**In `src/modelhub/callbacks/callback.py`**:
```python
# Update any internal imports if needed
# Mainly just ensure docstrings reference correct module names
```
**In `src/modelhub/metrics/metric.py`**:
```python
# Update imports and docstrings
```
#### Step 2.3: Create Proper __init__.py Files
**`src/modelhub/callbacks/__init__.py`**:
```python
"""Callbacks for training customization.
This module provides both the base callback class and common callback implementations.
"""
from modelhub.callbacks.callback import BaseCallback
from modelhub.callbacks.health_logging import HealthLoggingCallback
from modelhub.callbacks.timing_logging import TimingLoggingCallback
from modelhub.callbacks.train_logging import TrainLoggingCallback
__all__ = [
"BaseCallback",
"HealthLoggingCallback",
"TimingLoggingCallback",
"TrainLoggingCallback",
]
```
**`src/modelhub/metrics/__init__.py`**:
```python
"""Metrics for model evaluation.
This module provides the base metric framework and common metric implementations.
"""
from modelhub.metrics.metric import Metric, MetricManager, instantiate_metric_manager
from modelhub.metrics.chiral import ChiralMetric
from modelhub.metrics.rasa import RASAMetric
__all__ = [
"Metric",
"MetricManager",
"instantiate_metric_manager",
"ChiralMetric",
"RASAMetric",
]
```
**`src/modelhub/__init__.py`**:
```python
"""ModelForge: Open-source framework for biomolecular modeling.
This package provides the base framework for building and training biomolecular models.
"""
# Export key base classes at top level
from modelhub.callbacks.callback import BaseCallback
from modelhub.metrics.metric import Metric, MetricManager
# Version
from modelhub.version import __version__
__all__ = [
"BaseCallback",
"Metric",
"MetricManager",
"__version__",
]
```
### Phase 3: Fix RF3 Imports
#### Step 3.1: Update All RF3 Import Statements
**Find all broken imports**:
```bash
cd releases/rf3/
grep -r "from metrics.base import" src/
grep -r "from callbacks.base import" src/
grep -r "from inference_engines.base import" src/
```
**Replace with**:
```python
# OLD (broken)
from metrics.base import Metric
from callbacks.base import BaseCallback
# NEW (correct)
from modelhub.metrics.metric import Metric, MetricManager
from modelhub.callbacks.callback import BaseCallback
from modelhub.inference_engines.base import InferenceEngine
```
**Automated replacement**:
```bash
# In releases/rf3/src/
find . -name "*.py" -exec sed -i 's/from metrics\.base import/from modelhub.metrics.metric import/g' {} +
find . -name "*.py" -exec sed -i 's/from callbacks\.base import/from modelhub.callbacks.callback import/g' {} +
find . -name "*.py" -exec sed -i 's/from inference_engines\.base import/from modelhub.inference_engines.base import/g' {} +
```
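The `sed -i` form above follows GNU sed; BSD/macOS sed requires a suffix argument after `-i`. A portable alternative is a short Python script with a dry-run mode. This is a sketch under assumptions: `rewrite_imports` and `rewrite_tree` are hypothetical helpers, and the mapping simply mirrors the three sed commands (drop the `inference_engines` entry if that directory is removed as recommended earlier).

```python
import re
from pathlib import Path

# Bare module path -> fully qualified path, mirroring the sed commands above.
REWRITES = {
    "metrics.base": "modelhub.metrics.metric",
    "callbacks.base": "modelhub.callbacks.callback",
    "inference_engines.base": "modelhub.inference_engines.base",
}

# Match only at the start of a (possibly indented) line.
PATTERN = re.compile(
    r"^(\s*from\s+)(" + "|".join(re.escape(k) for k in REWRITES) + r")(\s+import\b)",
    flags=re.MULTILINE,
)


def rewrite_imports(source: str) -> str:
    """Rewrite bare `from X.base import` statements to their modelhub paths."""
    return PATTERN.sub(lambda m: m.group(1) + REWRITES[m.group(2)] + m.group(3), source)


def rewrite_tree(root: Path, dry_run: bool = True) -> list[Path]:
    """Return the files that would change; write them only when dry_run is False."""
    changed = []
    for path in sorted(root.rglob("*.py")):
        original = path.read_text()
        updated = rewrite_imports(original)
        if updated != original:
            changed.append(path)
            if not dry_run:
                path.write_text(updated)
    return changed
```

Running `rewrite_tree(Path("releases/rf3/src"))` first lists the affected files without touching them; passing `dry_run=False` applies the changes.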
#### Step 3.2: Update RF3 Utility Imports
```python
# Also update imports of shared utilities
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.logging import suppress_warnings
from modelhub.utils.weights import load_checkpoint
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks
```
### Phase 4: Fix Build Configuration
#### Step 4.1: Update pyproject.toml
**Current**:
```toml
[project]
name = "rf3"
[project.scripts]
rf3 = "modelhub.cli:app"
[tool.hatch.build.targets.wheel]
packages = ["src/modelhub"]
```
**New**:
```toml
[project]
name = "modelforge"
description = "Open-source framework for biomolecular structure prediction and design"
# Minimal dependencies for base framework
dependencies = [
"torch>=2.2.0,<3",
"lightning>=2.4.0,<2.5",
"hydra-core>=1.3.0,<1.4",
"rootutils>=1.0.7,<1.1",
"environs>=11.0.0,<12",
"wandb>=0.15.10,<1",
"rich>=13.9.4,<14",
"jaxtyping>=0.2.17,<1",
"beartype>=0.18.0,<1",
]
[project.optional-dependencies]
# RF3 model with its specific dependencies
rf3 = [
"atomworks==1.0.2",
"einops>=0.8.0,<1",
"einx>=0.1.0,<1",
"opt_einsum>=3.4.0,<4",
"dm-tree>=0.1.6,<1",
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
]
# All models
all = [
"modelforge[rf3]",
]
# Development
dev = [
"ruff==0.8.3",
"pytest>=8.2.0,<9",
# ... other dev deps
]
[project.scripts]
# Main CLI - needs to be created or use RF3 CLI directly
rf3 = "rf3.cli:app" # Direct to RF3 for now
[tool.hatch.build.targets.wheel]
# Include both modelhub and rf3 (if using monorepo approach)
packages = ["src/modelhub"]
# OR include releases/rf3/src/rf3 via custom build hook
# For development: force-include releases
[tool.hatch.build.targets.wheel.force-include]
"releases/rf3/src/rf3" = "modelhub/models/rf3" # If using monorepo approach
```
#### Step 4.2: Create Main CLI (If Needed)
**Option 1**: Keep `rf3` CLI pointing directly to RF3
```toml
[project.scripts]
rf3 = "rf3.cli:app"
```
**Option 2**: Create dispatcher CLI (future-proof)
**`src/modelhub/cli.py`**:
```python
"""Main ModelForge CLI."""
import typer
app = typer.Typer()
# Import RF3 CLI if available
try:
    from rf3.cli import app as rf3_app

    # Expose RF3 commands at top level.
    # `registered_commands` is a list of CommandInfo objects, not (name, command) pairs.
    for command in rf3_app.registered_commands:
        app.command(name=command.name)(command.callback)
except ImportError:
    pass  # RF3 not installed

if __name__ == "__main__":
    app()
```
Then:
```toml
[project.scripts]
modelforge = "modelhub.cli:app"
rf3 = "rf3.cli:app" # Keep for backward compatibility
```
### Phase 5: Handle RF3 Package Integration
**Decision needed**: Choose one of the approaches (A, B, or C) from Part 2 of the Recommended Solution.
**Recommended: Option C (Hybrid)**
Use hatch build hooks to include RF3 from `releases/`:
**`hatch_build.py`** (in project root):
```python
"""Custom hatch build hook to include RF3 from releases/."""
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
class CustomBuildHook(BuildHookInterface):
    def initialize(self, version, build_data):
        """Include RF3 from releases/ directory."""
        if self.target_name == "wheel":
            # Add releases/rf3/src/rf3 to the wheel as modelhub/models/rf3
            build_data["force_include"] = {
                "releases/rf3/src/rf3": "modelhub/models/rf3"
            }
```
**In `pyproject.toml`**:
```toml
[tool.hatch.build.hooks.custom]
path = "hatch_build.py"
```
This way:
- Development happens in `releases/rf3/`
- Installation includes RF3 in the package
- Users can `pip install modelforge[rf3]`
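Whether the hook did its job can be checked by inspecting the built wheel, which is a zip archive. The following is a minimal sketch; `wheel_has_package` is a hypothetical helper and the wheel filename in the usage note is a placeholder.

```python
import zipfile


def wheel_has_package(wheel_path, package_prefix: str) -> bool:
    """Return True if any file in the wheel lives under `package_prefix`."""
    prefix = package_prefix.rstrip("/") + "/"
    # zipfile accepts a filesystem path or an open binary file object.
    with zipfile.ZipFile(wheel_path) as wheel:
        return any(name.startswith(prefix) for name in wheel.namelist())
```

For example, after building, `wheel_has_package("dist/<wheel file>.whl", "modelhub/models/rf3")` should return `True` if the force-include mapping was applied.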
### Phase 6: Update Tests
#### Step 6.1: Update Test Imports
**In `releases/rf3/tests/`**:
```python
# Update conftest.py and test files
from modelhub.metrics.metric import Metric
from modelhub.callbacks.callback import BaseCallback
```
**In root `tests/`**:
```python
# These already test shared utilities
from modelhub.utils.weights import load_checkpoint
from modelhub.metrics.metric import MetricManager
```
#### Step 6.2: Verify Test Execution
```bash
# Test shared utilities
pytest tests/
# Test RF3 (from releases/rf3/)
cd releases/rf3/
pytest tests/
```
### Phase 7: Update Documentation
#### Step 7.1: Update CLAUDE.md
Update import examples and package structure documentation.
#### Step 7.2: Update README.md
````markdown
# ModelForge
## Installation
```bash
# Base framework only
pip install modelforge
# With RF3
pip install modelforge[rf3]
# Development
pip install -e .[rf3,dev]
```
## Usage
```python
# Import base framework
from modelforge.callbacks import BaseCallback
from modelforge.metrics import Metric
# Import RF3
from modelforge.models.rf3 import RF3
```
````
---
## Migration Checklist
### Pre-Migration
- [ ] Backup current code
- [ ] Create feature branch: `git checkout -b refactor/restructure-package`
- [ ] Document current import patterns for reference
### Phase 1: Package Rename
- [ ] Decision: Keep `modelhub` or rename to `modelforge`?
- [ ] Update `pyproject.toml` name field
- [ ] Update all imports if renaming
### Phase 2: Restructure src/
- [ ] Rename `callbacks/base.py` → `callbacks/callback.py`
- [ ] Rename `metrics/base.py` → `metrics/metric.py`
- [ ] Create/update all `__init__.py` files with proper exports
- [ ] Update internal imports within src/modelhub/
### Phase 3: Fix RF3 Imports
- [ ] Find all `from metrics.base` imports in releases/rf3/
- [ ] Replace with `from modelhub.metrics.metric`
- [ ] Find all `from callbacks.base` imports
- [ ] Replace with `from modelhub.callbacks.callback`
- [ ] Update utility imports to use `modelhub.utils.*`
- [ ] Test that RF3 can import all needed modules
### Phase 4: Fix Build Config
- [ ] Update `pyproject.toml` project name
- [ ] Split dependencies: core vs rf3-specific
- [ ] Add `[project.optional-dependencies]` for rf3
- [ ] Fix `[project.scripts]` CLI entry point
- [ ] Update `[tool.hatch.build.targets.wheel]` packages
- [ ] Add build hooks if using hybrid approach
### Phase 5: RF3 Integration
- [ ] Decide on integration approach (A, B, or C)
- [ ] Implement chosen approach
- [ ] Test `pip install -e .` (base only)
- [ ] Test `pip install -e .[rf3]` (with RF3)
### Phase 6: Update Tests
- [ ] Update test imports
- [ ] Run shared tests: `pytest tests/`
- [ ] Run RF3 tests: `cd releases/rf3 && pytest tests/`
- [ ] Fix any import errors
### Phase 7: Documentation
- [ ] Update CLAUDE.md
- [ ] Update README.md
- [ ] Update any other documentation
- [ ] Add migration notes for contributors
### Phase 8: Validation
- [ ] Clean install test: `pip install -e .`
- [ ] RF3 install test: `pip install -e .[rf3]`
- [ ] Test CLI: `rf3 fold inputs=...`
- [ ] Run full test suite
- [ ] Test on clean environment
---
## Testing Strategy
### Test Scenarios
#### Scenario 1: Base Install Only
```bash
# Clean environment
python -m venv test-env
source test-env/bin/activate
pip install -e .
# Should work
python -c "from modelhub.callbacks import BaseCallback; print('OK')"
python -c "from modelhub.metrics import Metric; print('OK')"
python -c "from modelhub.utils.ddp import RankedLogger; print('OK')"
# Should fail (RF3 not installed)
python -c "from modelhub.models.rf3 import RF3" # ImportError expected
```
#### Scenario 2: With RF3
```bash
# Clean environment
python -m venv test-env
source test-env/bin/activate
pip install -e .[rf3]
# Should work
python -c "from modelhub.models.rf3 import RF3; print('OK')"
python -c "from rf3.cli import app; print('OK')"
# Test CLI
rf3 fold inputs='releases/rf3/tests/data/5vht_from_json.json'
```
#### Scenario 3: Development Workflow
```bash
# Install in development mode with RF3
pip install -e .[rf3,dev]
# Make changes to releases/rf3/src/rf3/model/RF3.py
# Changes should be immediately available
python -c "from modelhub.models.rf3.model import RF3" # Should see changes
# Run tests
pytest releases/rf3/tests/
```
---
## Risk Assessment
### High Risk Items
1. **Breaking all existing imports** - Requires updating many files
- Mitigation: Use automated search/replace, test thoroughly
2. **Build configuration complexity** - Hatch build hooks can be tricky
- Mitigation: Start with simple approach, test installation frequently
3. **Package name change** - Breaking change for any external users
- Mitigation: Coordinate with team, announce breaking change
### Medium Risk Items
1. **CLI entry point changes** - Users may have scripts using old CLI
- Mitigation: Keep backward-compatible entry points
2. **Test imports** - Many test files to update
- Mitigation: Automated replacement, run tests incrementally
### Low Risk Items
1. **Documentation updates** - Time-consuming but low risk
2. **`__init__.py` exports** - Easy to fix if wrong
---
## Rollback Plan
If migration fails:
1. `git checkout main` - Revert to pre-migration state
2. Keep feature branch for later attempt
3. Document what failed for next iteration
---
## Future Considerations
### Adding New Models
Once restructured, adding a new model is straightforward:
```
releases/
├── rf3/ # Existing
└── proteinmpnn/ # New model
├── src/proteinmpnn/
├── configs/
└── tests/
```
Then in `pyproject.toml`:
```toml
[project.optional-dependencies]
proteinmpnn = [
"proteinmpnn-specific-deps",
]
all = [
"modelforge[rf3,proteinmpnn]",
]
```
### Splitting Into Separate Packages (Future)
If models become too large, can later split:
- `modelforge-core` - Base framework
- `modelforge-rf3` - RF3 model
- `modelforge-proteinmpnn` - ProteinMPNN model
But this is much more complex and not recommended initially.
---
## Questions to Resolve
1. **Package name**: Keep `modelhub` or rename to `modelforge`?
- Recommendation: Rename to `modelforge`
2. **RF3 integration**: Which approach (A, B, or C)?
- Recommendation: Option C (hybrid with build hooks)
3. **CLI structure**: Single entry point or keep separate?
- Recommendation: Keep `rf3` command separate for now
4. **Base file naming**: Rename `base.py` to match content?
- Recommendation: Yes, rename to `callback.py`, `metric.py`
5. **Version strategy**: Single version or per-model versions?
- Recommendation: Single version (simpler)
---
## Success Criteria
- [ ] `pip install modelforge` works
- [ ] `pip install modelforge[rf3]` works
- [ ] All imports are unambiguous and correct
- [ ] All tests pass
- [ ] CLI works: `rf3 fold inputs=...`
- [ ] Structure matches PyTorch Lightning patterns
- [ ] Documentation is updated
- [ ] No more `sys.path` manipulation needed
---
## Timeline Estimate
- Phase 1 (Rename decision): 1 hour
- Phase 2 (Restructure src/): 2-3 hours
- Phase 3 (Fix RF3 imports): 1-2 hours
- Phase 4 (Build config): 2-3 hours
- Phase 5 (RF3 integration): 2-4 hours
- Phase 6 (Update tests): 1-2 hours
- Phase 7 (Documentation): 1-2 hours
- Phase 8 (Validation): 2-3 hours
**Total**: 12-20 hours of focused work
---
## Conclusion
This restructuring will:
1. ✅ Fix all broken imports
2. ✅ Follow industry best practices (Lightning pattern)
3. ✅ Enable proper pip installation
4. ✅ Support selective model installation
5. ✅ Maintain development workflow in `releases/`
6. ✅ Provide clear, unambiguous imports
7. ✅ Scale to future models
The key is following PyTorch Lightning's pattern: base classes live WITH their implementations, no separate `contrib/` or overly complex directory nesting.

89
models/rf3/CONTAINER.md Normal file

@@ -0,0 +1,89 @@
# RF3 Apptainer Containers
This directory contains two Apptainer definition files for different use cases.
## Container Options
### 1. `rf3-standalone.def` - Standalone Container
Contains a complete snapshot of the modelhub repository at build time. Use this for:
- Production deployments
- Reproducible inference with a fixed codebase
- Running on systems where you don't have the repository
### 2. `rf3-dev.def` - Development Container
Contains only Python dependencies (from `requirements.txt`). Use this for:
- Active development and testing
- Working with your local modelhub code
- Debugging and modifying RF3 code
**Note**: Generate `requirements.txt` first using `uv pip compile pyproject.toml -o requirements.txt` before building this container.
## Prerequisites
- Apptainer/Singularity installed
- NVIDIA GPU with CUDA 12.1+ support
- Sufficient disk space (~10GB per container)
## Building Containers
Build from the `models/rf3/` directory:
```bash
cd models/rf3/
# Build standalone container (includes modelhub snapshot)
apptainer build rf3-standalone.sif rf3-standalone.def
# Build development container (dependencies only)
apptainer build rf3-dev.sif rf3-dev.def
# Or build with sandbox for debugging
apptainer build --sandbox rf3_sandbox/ rf3-standalone.def
```
## Using the Standalone Container
The standalone container has the full repository baked in.
### Basic Inference
```bash
# Run inference on a single input
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='input.json'
# Process CIF/PDB files
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='structure.cif'
# Batch processing
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='[file1.cif,file2.json,file3.pdb]'
```
### With Custom Weights
```bash
# Mount weights directory and specify checkpoint
apptainer exec --nv \
--bind /path/to/weights:/weights \
rf3-standalone.sif \
rf3 fold inputs='input.json' ckpt_path='/weights/rf3_latest.pt'
```
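When launching many runs from a script, the same invocation can be assembled programmatically. This is a minimal sketch: `build_rf3_command` is a hypothetical helper, while `--nv`, `--bind`, and the `inputs=` and `ckpt_path=` arguments are taken from the commands shown above.

```python
import shlex
from typing import Optional


def build_rf3_command(
    image: str,
    inputs: str,
    binds: Optional[dict] = None,
    ckpt_path: Optional[str] = None,
) -> list:
    """Assemble an `apptainer exec` argument list for `rf3 fold` (hypothetical helper)."""
    cmd = ["apptainer", "exec", "--nv"]
    # Each bind maps a host path to a path inside the container.
    for host_path, container_path in (binds or {}).items():
        cmd += ["--bind", f"{host_path}:{container_path}"]
    cmd += [image, "rf3", "fold", f"inputs={inputs}"]
    if ckpt_path is not None:
        cmd.append(f"ckpt_path={ckpt_path}")
    return cmd


cmd = build_rf3_command(
    "rf3-standalone.sif",
    "input.json",
    binds={"/path/to/weights": "/weights"},
    ckpt_path="/weights/rf3_latest.pt",
)
print(shlex.join(cmd))  # shell-quoted form, useful for logging
```

The returned list can be passed directly to `subprocess.run(cmd, check=True)`; because no shell is involved, the single quotes around `inputs=` values in the shell examples are not needed.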
## Using the Development Container
The development container requires mounting your local modelhub repository.
### Basic Usage
```bash
# Run with local modelhub repository
apptainer exec --nv \
--bind /path/to/modelhub:/opt/modelhub \
rf3-dev.sif \
rf3 fold inputs='input.json'
# Example with actual path
apptainer exec --nv \
--bind $PWD/../..:/opt/modelhub \
rf3-dev.sif \
rf3 fold inputs='input.json'
```


@@ -1,7 +1,7 @@
# Inference with RosettaFold3 (RF3)
<div align="center">
<img src="./docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
<img src="../../docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
</div>
> [!IMPORTANT]
@@ -21,9 +21,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \
&& uv python install 3.12 \
&& uv venv --python 3.12 \
&& source .venv/bin/activate \
&& uv pip install -e .
&& uv pip install -e ./models/rf3
```
> [!NOTE]
> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency. For development on both packages, use: `uv pip install -e . -e ./models/rf3`
### B. Download model weights for RF3
```bash
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
@@ -67,7 +70,7 @@ For this example, the pTM in the `metrics.csv` should be `>0.8` (even without an
RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommended. We do not currently support pre-paired MSAs (MSAs are paired on the fly) or on-the-fly MSA computation, but both are on the roadmap. Please raise an issue if these limitations are critical for your project and we can prioritize accordingly.
📝 **Example JSON configuration** (full example found at `docs/rf3/examples/3en2_from_json_with_msa.json`):
📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/3en2_from_json_with_msa.json`):
```json
{
@@ -85,7 +88,7 @@ RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommend
🚀 **Run the example:**
```bash
rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json'
rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json'
```
---
@@ -93,12 +96,12 @@ rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json'
If performing inference from a prepared `.cif` file, MSAs can also be specified directly as a category within the raw CIF data.
We will automatically extract the correct MSA paths during parsing.
📝 **Example CIF header** (full example found at `docs/rf3/examples/3en2_from_file.cif`):
📝 **Example CIF header** (full example found at `../../docs/releases/rf3/examples/3en2_from_file.cif`):
```cif
data_3EN2
#
_msa_paths_by_chain_id.A docs/rf3/examples/msas/b3a35202064.a3m.gz
_msa_paths_by_chain_id.B docs/rf3/examples/msas/b3a35202064.a3m.gz
_msa_paths_by_chain_id.A ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz
_msa_paths_by_chain_id.B ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz
#
```
@@ -112,7 +115,7 @@ rf3 fold inputs='docs/rf3/3en2_from_file.cif
> Without an MSA and using default settings, the above examples will trigger "early stopping." This means that if the model determines early on that a correct prediction is unlikely, it will stop computation and only output a `metrics.csv` and `.score` file to save compute resources. You can adjust this behavior using the `early_stopping_plddt_threshold` argument (see below). In our group, we find this argument can save wasted compute on erroneous inputs.
> [!TIP]
> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs.
> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs.
> [!TIP]
> For non-canonical amino acids, most MSA generation algorithms substitute `X` (unknown residue)! Ensure your MSAs adhere to this convention.
@@ -138,7 +141,7 @@ We will automatically distribute predictions across GPU's if running in a multi-
### 1⃣ **Single JSON with Multiple Examples**
📝 **Example JSON configuration** (full example found at `docs/rf3/examples/multiple_example_from_json.json`)
📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/multiple_example_from_json.json`)
```json
[
@@ -172,7 +175,7 @@ We will automatically distribute predictions across GPU's if running in a multi-
🚀 **Run the example:**
```bash
rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json'
rf3 fold inputs='../../docs/releases/rf3/examples/multiple_examples_from_json.json'
```
---
@@ -182,7 +185,7 @@ rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json'
You can specify multiple files/directories using Hydra's list syntax:
```bash
rf3 fold inputs='[docs/rf3/examples/7o1r_from_file.cif, docs/rf3/examples/7o1r_from_json.json]'
rf3 fold inputs='[../../docs/releases/rf3/examples/7o1r_from_file.cif, ../../docs/releases/rf3/examples/7o1r_from_json.json]'
```
---
@@ -211,7 +214,7 @@ Complex assemblies containing arbitrary biomolecules can be easily folded if pre
🚀 **Run an example from a prepared CIF file:**
```bash
rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif'
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif'
```
Such files (including all bonds, covalent modifications, non-canonical amino acids, etc.) can be created either (a) directly from ProteinMPNN/LigandMPNN or other software that generates structural files; or, (b) assembled with [AtomWorks](https://github.com/RosettaCommons/atomworks) or another CIF-processing library.
@@ -221,7 +224,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
> [!TIP]
> **Performance Tip**: For small molecules, a general rule-of-thumb is that performance is best when using `CCD` codes directly, followed by `cif`/`sdf` files, and finally SMILES.
📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `docs/rf3/examples/7o1r_from_json.json`):
📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`):
```json
[
{
@@ -229,7 +232,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
"components": [
{
"seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH",
"msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz",
"msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz",
"chain_id": "A"
},
{
@@ -242,7 +245,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
{
// We provide the heme (HEME) via SDF from the CCD; we could have also used a CIF file
// We will automatically name the atoms (SDF files do not specify atom names)
"path": "docs/rf3/examples/ligands/HEM.sdf"
"path": "../../docs/releases/rf3/examples/ligands/HEM.sdf"
},
{
// We provide the imidazole ring (IMD) via SMILES
@@ -257,7 +260,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
🚀 **Run the example:**
```bash
rf3 fold inputs='docs/rf3/examples/7o1r_from_json.json'
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_json.json'
```
**Supported input options:**
@@ -278,11 +281,11 @@ As described in the example [Folding with Arbitrary Biomolecules](#folding-with-
For example, folding `7o1r`, which contains two N-glycosylations:
```bash
rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif'
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif'
```
<p align="center">
<img src="./docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
<img src="../../docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
</p>
<p align="center">
<em>Figure: `7o1r` structure showing N-glycosylation covalent modification prediction with RF3 and ground truth crystal structure.</em>
@@ -292,7 +295,7 @@ Such `.cif` files complete with appropriate bonds can be composed with AtomWorks
If you would prefer to use the JSON API, bonds can be explicitly given using PyMol-like strings of the form `chain_id/res_name/res_id/atom_name`. You will need to know the specific chain ID, residue name, residue ID, and atom name between the relevant pairs of atoms to unambiguously specify the bond.
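To make the `chain_id/res_name/res_id/atom_name` convention concrete, the following is a minimal Python sketch that splits such a string into its four fields. `AtomSpec` and `parse_atom_spec` are hypothetical names used for illustration only; they are not part of RF3 or AtomWorks, and the atom identifiers in the example are assumed values rather than ones taken from the `7o1r` files.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AtomSpec:
    """One end of a bond (hypothetical helper, not an RF3 or AtomWorks API)."""

    chain_id: str
    res_name: str
    res_id: int
    atom_name: str


def parse_atom_spec(spec: str) -> AtomSpec:
    """Split 'chain_id/res_name/res_id/atom_name' into typed fields."""
    parts = spec.strip().split("/")
    if len(parts) != 4 or not all(parts):
        raise ValueError(
            f"expected chain_id/res_name/res_id/atom_name, got {spec!r}"
        )
    chain_id, res_name, res_id, atom_name = parts
    return AtomSpec(chain_id, res_name, int(res_id), atom_name)


# Illustrative values only: an asparagine side-chain nitrogen and a sugar carbon.
donor = parse_atom_spec("A/ASN/84/ND2")
acceptor = parse_atom_spec("B/NAG/1/C1")
print(donor, acceptor)
```

Checking all four fields up front is a way to catch an ambiguous bond specification before inference starts, which is the point of requiring the full chain, residue, and atom identifiers.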
📝 **Example JSON configuration with covalent modifications** (full example found at `docs/rf3/examples/7o1r_from_json.json`):
📝 **Example JSON configuration with covalent modifications** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`):
```json
[
{
@@ -300,7 +303,7 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM
"components": [
{
"seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH",
"msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz",
"msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz",
"chain_id": "A"
},
{
@@ -308,10 +311,10 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM
},
{
// We provide one sugar via a CIF file, with complete control over bonds and atom names (as we use the atom names from the CIF file)
"path": "docs/rf3/examples/ligands/NAG.cif"
"path": "../../docs/releases/rf3/examples/ligands/NAG.cif"
},
{
"path": "docs/rf3/examples/ligands/HEM.sdf"
"path": "../../docs/releases/rf3/examples/ligands/HEM.sdf"
},
{
"smiles": "[nH]1cc[nH+]c1"
@@ -351,7 +354,7 @@ from atomworks.io.utils.visualize import view
import numpy as np
# Load an SDF file into an AtomArray
sdf_path = "docs/rf3/examples/ligands/NAG.sdf"
sdf_path = "../../docs/releases/rf3/examples/ligands/NAG.sdf"
# Load into an AtomArray
# Since SDF files do not have atom names, we automatically generate them (e.g., C1, C2, C3, etc.)
@@ -448,17 +451,17 @@ RF3 uses AtomWorks' flexible `AtomSelectionStack` query syntax for specifying st
It is often helpful to template one or multiple polymer chains while allowing the other chain(s) to fold unconstrained. Below, we demonstrate how to apply templates with a nanobody-antigen use case.
📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `docs/rf3/examples/7xli_template_antigen_and_framework.json`):
📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json`):
```json
[
{
"name": "7xli_template_antigen",
"components": [
{
"path": "docs/rf3/examples/templates/7xli_chain_A.cif"
"path": "../../docs/releases/rf3/examples/templates/7xli_chain_A.cif"
},
{
"path": "docs/rf3/examples/templates/7xli_chain_B.cif"
"path": "../../docs/releases/rf3/examples/templates/7xli_chain_B.cif"
}
],
"template_selection": ["A", "B/*/1-42", "B/*/49-63", "B/*/71-102", "B/*/108-125"]
@@ -469,7 +472,7 @@ It is often helpful to template one or multiple polymer chains while allowing th
🚀 **Run the example:**
```bash
rf3 fold inputs='docs/rf3/examples/7xli_template_antigen_and_framework.json'
rf3 fold inputs='../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json'
```
You may also specify templating directly via the CLI using `template_selection="[A, B/*/1-42, ...]"`.
@@ -483,14 +486,14 @@ We find that enforcing a particular small molecule conformation has various appl
For the moment, the ground truth conformer track is only effective if we want to template the *entire* small molecule. Partial templating of small molecules is still possible via the `template_selection` approach. We encourage exploration of both templating techniques to find what combination(s) are most effective for a given problem. Below we provide both, which represents the strongest possible conditioning.
📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `docs/rf3/examples/1eiz_template_ligand_and_protein.json`):
📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `../../docs/releases/rf3/examples/1eiz_template_ligand_and_protein.json`):
```json
[
{
"name": "9dfn_template_ligand_and_protein",
"components": [
{
"path": "docs/rf3/examples/9dfn.cif"
"path": "../../docs/releases/rf3/examples/9dfn.cif"
}
],
"template_selection": ["A", "C", "D"],
@@ -506,7 +509,7 @@ For the moment, the ground truth conformer track is only effective if we want to
🚀 **Run the example:**
```bash
rf3 fold inputs='docs/rf3/examples/8cdz_templating_ligand.json'
rf3 fold inputs='../../docs/releases/rf3/examples/8cdz_templating_ligand.json'
```
You may also specify the ground truth conformer selection directly via the CLI, e.g., using `ground_truth_conformer_selection="[E]"`


@@ -1,5 +1,5 @@
dump_validation_structures_callback:
_target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback
_target_: rf3.callbacks.dump_validation_structures.DumpValidationStructuresCallback
save_dir: ${paths.output_dir}/val_structures
dump_predictions: True
one_model_per_file: False
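
The `_target_` values renamed throughout these configs are dotted import paths that Hydra resolves to a class when the config is instantiated, so a rename from `modelhub.…` to `rf3.…` only works if the module is importable under the new name. The following stdlib-only Python sketch approximates that lookup; it is not Hydra's actual implementation, `resolve_target` is a hypothetical helper, and a standard-library class stands in for the RF3 callback.

```python
import importlib


def resolve_target(dotted_path: str):
    """Import the module part of a dotted path and return the named attribute."""
    module_name, _, attr_name = dotted_path.rpartition(".")
    if not module_name:
        raise ValueError(f"not a dotted path: {dotted_path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# Stand-in target from the standard library. With the rf3 package installed,
# one would instead pass a path such as
# "rf3.callbacks.dump_validation_structures.DumpValidationStructuresCallback".
cls = resolve_target("collections.OrderedDict")
print(cls.__name__)
```

A stale `_target_` that still points at a moved module would fail at the `import_module` step with `ModuleNotFoundError`, which is a quick way to spot configs missed during a refactor like this one.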


@@ -1,10 +1,10 @@
store_validation_metrics_in_df_callback:
_target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
_target_: rf3.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
save_dir: ${paths.output_dir}/val_metrics
metrics_to_save: "all"
log_af3_validation_metrics_callback:
_target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
_target_: rf3.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
# Only logs if present in the metric output dictionary
# Must be subset of metrics_to_save
metrics_to_log: "all"


@@ -17,7 +17,7 @@ defaults:
- _self_
# Dataloading pipeline to use
pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline
pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline
# Dataset weighting
train:


@@ -9,7 +9,7 @@ defaults:
- _self_
# Dataloading pipeline to use
pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline
pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline
# Dataset weighting
train:


@@ -2,7 +2,7 @@
disorder_distillation:
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: null
# cif parser arguments
@@ -13,19 +13,18 @@ disorder_distillation:
# metadata parser
dataset_parser:
_target_: datahub.datasets.parsers.GenericDFParser
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: pdb_diso_distillation
id_column: example_id
data: ${paths.data.disorder_distill_parquet_dir}/disorderDistillation.csv
columns_to_load:
- example_id
- path
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -2,7 +2,7 @@
multidomain_distillation:
dataset:
_target_: modelhub.data.paired_msa.MultiInputDatasetWrapper
_target_: rf3.data.paired_msa.MultiInputDatasetWrapper
save_failed_examples_to_dir: null
# cif parser
@@ -14,11 +14,11 @@ multidomain_distillation:
# metadata parser
dataset_parser:
_target_: modelhub.data.paired_msa.MultidomainDFParser
_target_: rf3.data.paired_msa.MultidomainDFParser
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: multidomain_distillation
id_column: example_id
data: /projects/ml/datahub/dfs/domain_domain/domain_domain_dataset.DIGS.parquet
@@ -26,7 +26,6 @@ multidomain_distillation:
- example_id
- pdb_path
- msa_path
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -2,7 +2,7 @@
monomer_distillation:
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
# cif parser arguments
@@ -13,19 +13,18 @@ monomer_distillation:
# metadata parser
dataset_parser:
_target_: datahub.datasets.parsers.GenericDFParser
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af2fb_distillation
id_column: example_id
data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet
columns_to_load:
- example_id
- path
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -2,7 +2,7 @@
na_complex_distillation:
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: null
# cif parser
@@ -14,19 +14,18 @@ na_complex_distillation:
# metadata parser
dataset_parser:
_target_: datahub.datasets.parsers.GenericDFParser
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null #[]
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: tf_distillation
id_column: example_id
data: ${paths.data.na_complex_distillation_parquet_dir}/transcriptionFactor_distillation_rf3.newDL.csv
columns_to_load:
- example_id
- path
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -1,5 +1,5 @@
weights:
_target_: datahub.samplers.calculate_weights_for_pdb_dataset_df
_target_: atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df
# We do not include beta here, since it is different for interfaces and chains
alphas:
a_prot: 3.0 # 3 for AF-3


@@ -1,16 +1,14 @@
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
cif_parser_args:
cache_dir: null
load_from_cache: false
save_to_cache: false
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
# we will use the example_id as the unique column
id_column: example_id
# return all keys (do not subset)
return_key: null
transform:
# common Transform pipeline components for all PDB datasets
_target_: ${datasets.pipeline_target}


@@ -5,7 +5,7 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.InterfacesDFParser
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
dataset:
name: plinder
@@ -50,5 +50,5 @@ dataset:
crop_spatial_probability: 1.0
weights:
_target_: datahub.samplers.calculate_weights_by_inverse_cluster_size
_target_: atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size
cluster_column: pli_qcov__50__weak__component # Need to ablate


@@ -4,7 +4,7 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.InterfacesDFParser
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
dataset:
name: interface


@@ -4,7 +4,7 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.PNUnitsDFParser
_target_: atomworks.ml.datasets.parsers.PNUnitsDFParser
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
dataset:
name: pn_unit


@@ -2,7 +2,7 @@
rna_monomer_distillation:
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
# cif parser arguments
@@ -13,12 +13,12 @@ rna_monomer_distillation:
# metadata parser
dataset_parser:
_target_: datahub.datasets.parsers.GenericDFParser
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: rna_monomer_distillation
id_column: example_id
data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet
@@ -31,7 +31,6 @@ rna_monomer_distillation:
- overall_pde
- overall_pae
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -3,8 +3,9 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af3_validation
data: /net/scratch/rib7/rf3_ab_splits/entry_level_val_df.parquet


@@ -3,8 +3,9 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af3_validation
data: ${paths.data.pdb_data_dir}/entry_level_val_df.parquet


@@ -1,16 +1,15 @@
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
# we will use the example_id as the unique column
id_column: example_id
# return all keys (do not subset)
return_key: null
transform:
# common Transform pipeline components for all PDB datasets
_target_: ${datasets.pipeline_target}


@@ -3,9 +3,10 @@ defaults:
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
dataset:
_target_: datahub.datasets.datasets.PandasDataset
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af3_validation
data: /projects/ml/datahub/dfs/af3_splits/2024_12_16/runs_n_poses_entry_level_df.parquet
filters:
- "n_tokens_total < 1000" # Subset to reasonably-sized examples for efficiency


@@ -5,7 +5,7 @@ name: rf3
defaults:
- override /datasets: pdb_and_distillation
- override /model: rf3
- override /trainer: af3
- override /trainer: rf3
ckpt_config:
_target_: modelhub.utils.weights.CheckpointConfig


@@ -6,7 +6,7 @@ name: rf3-with-confidence
defaults:
- rf3
- override /model: rf3_with_confidence
- override /trainer: af3_with_confidence
- override /trainer: rf3_with_confidence
- _self_
datasets:


@@ -8,7 +8,7 @@ name: quick-rf3-with-confidence
defaults:
- quick-rf3
- override /model: rf3_with_confidence
- override /trainer: af3_with_confidence
- override /trainer: rf3_with_confidence
- _self_
datasets:


@@ -2,7 +2,7 @@
# Experiment that loads a small dataset for quick testing
name: quick-af3
name: quick-rf3
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:


@@ -4,9 +4,9 @@ defaults:
- base
- _self_
_target_: modelhub.inference_engines.rf3.RF3InferenceEngine
_target_: rf3.inference_engines.rf3.RF3InferenceEngine
ckpt_path: /software/containers/versions/modelhub_inference/ckpts/rf3_latest.ckpt
ckpt_path: /home/ncorley/rf3-w-conf-newdatecut-ep688-refactored-fixed.ckpt
# Transform arguments
n_recycles: 10
@@ -36,8 +36,8 @@ annotate_b_factor_with_plddt: true
metrics_cfg:
# (For confidence outputs)
ptm:
_target_: modelhub.metrics.predicted_error.ComputePTM
_target_: rf3.metrics.predicted_error.ComputePTM
iptm:
_target_: modelhub.metrics.predicted_error.ComputeIPTM
_target_: rf3.metrics.predicted_error.ComputeIPTM
count_clashing_chains:
_target_: modelhub.metrics.clashing_chains.CountClashingChains
_target_: rf3.metrics.clashing_chains.CountClashingChains


@@ -1,5 +1,5 @@
# Model architecture
_target_: modelhub.model.RF3.RF3
_target_: rf3.model.RF3.RF3
# +---------- Channel dimensions ----------+
c_s: 384


@@ -2,7 +2,7 @@ defaults:
- rf3_net
# Model architecture
_target_: modelhub.model.RF3.RF3WithConfidence
_target_: rf3.model.RF3.RF3WithConfidence
# +---------- Mini rollout sampler ----------+
# From the AF-3 main text:


@@ -4,4 +4,4 @@ defaults:
- _self_
net:
_target_: modelhub.model.RF3.RF3WithConfidence
_target_: rf3.model.RF3.RF3WithConfidence


@@ -1,5 +1,5 @@
# Learning rate scheduler
_target_: modelhub.training.schedulers.AF3Scheduler
_target_: rf3.training.schedulers.AF3Scheduler
base_lr: 1.8e-3
warmup_steps: 1000
decay_factor: 0.95


@@ -3,30 +3,37 @@
########################
# path to directory with training splits
pdb_data_dir: ???
pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2025_07_13
# fb monomer distillation dataset
monomer_distillation_data_dir: ???
monomer_distillation_parquet_dir: ???
monomer_distillation_data_dir: /squash/af2_distillation_facebook
monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
mgnify_distillation_data_dir: /squash/mgnify_distill_rf3/
mgnify_distillation_parquet_dir: /home/dimaio/MGnify/
# na complex distill set
na_complex_distillation_data_dir: ???
na_complex_distillation_parquet_dir: ???
na_complex_distillation_data_dir: /projects/ml/prot_dna/rf3_newDL
na_complex_distillation_parquet_dir: /projects/ml/prot_dna
# disorder distill set
disorder_distill_parquet_dir: ???
disorder_distill_parquet_dir: /projects/ml/disorder_distill
########################
# MSAs
########################
# path(s) to search for protein MSAs (for PDB datasets)
# e.g., {"dir": "/path/to/msas", "extension": ".a3m.gz", "directory_depth": 2}
protein_msa_dirs:
- {"dir": ???, "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/hhblits", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/mmseqs_gpu", "extension": ".a3m", "directory_depth": 2}
- {"dir": "/projects/msa/lab", "extension": ".a3m", "directory_depth": 2}
# path(s) to search for RNA MSAs
# e.g., {"dir": "/path/to/msas", "extension": ".afa", "directory_depth": 0}
rna_msa_dirs:
- {"dir": ???, "extension": ".afa", "directory_depth": 0}
- {"dir": "/projects/msa/rna", "extension": ".afa", "directory_depth": 0}
########################
# Misc


@@ -1,4 +1,4 @@
_target_: modelhub.loss.af3_confidence_loss.ConfidenceLoss
_target_: rf3.loss.af3_confidence_loss.ConfidenceLoss
weight: 1.0
plddt:


@@ -1,4 +1,4 @@
_target_: modelhub.loss.af3_losses.DiffusionLoss
_target_: rf3.loss.af3_losses.DiffusionLoss
weight: 4.0
sigma_data: ${model.net.diffusion_module.sigma_data}
alpha_dna: 5


@@ -1,2 +1,2 @@
_target_: modelhub.loss.af3_losses.DistogramLoss
_target_: rf3.loss.af3_losses.DistogramLoss
weight: 3e-2


@@ -1,14 +1,14 @@
by_type_lddt:
_target_: modelhub.metrics.lddt.ByTypeLDDT
_target_: rf3.metrics.lddt.ByTypeLDDT
all_atom_lddt:
_target_: modelhub.metrics.lddt.AllAtomLDDT
_target_: rf3.metrics.lddt.AllAtomLDDT
distogram:
_target_: modelhub.metrics.distogram.DistogramLoss
_target_: rf3.metrics.distogram.DistogramLoss
distogram_comparisons:
_target_: modelhub.metrics.distogram.DistogramComparisons
_target_: rf3.metrics.distogram.DistogramComparisons
distogram_entropy:
_target_: modelhub.metrics.distogram.DistogramEntropy
_target_: rf3.metrics.distogram.DistogramEntropy
chiral_loss:
_target_: modelhub.metrics.chiral.ChiralLoss
_target_: rf3.metrics.chiral.ChiralLoss
unresolved_rasa:
_target_: modelhub.metrics.rasa.UnresolvedRegionRASA
_target_: rf3.metrics.rasa.UnresolvedRegionRASA


@@ -3,7 +3,7 @@ defaults:
- loss: structure_prediction
- metrics: structure_prediction
_target_: modelhub.trainers.rf3.RF3Trainer
_target_: rf3.trainers.rf3.RF3Trainer
validate_every_n_epochs: 1
max_epochs: 10_000
n_examples_per_epoch: 24000


@@ -2,12 +2,12 @@ defaults:
- af3
- override loss: structure_prediction_with_confidence
_target_: modelhub.trainers.rf3.RF3TrainerWithConfidence
_target_: rf3.trainers.rf3.RF3TrainerWithConfidence
metrics:
ptm:
_target_: modelhub.metrics.predicted_error.ComputePTM
_target_: rf3.metrics.predicted_error.ComputePTM
iptm:
_target_: modelhub.metrics.predicted_error.ComputeIPTM
_target_: rf3.metrics.predicted_error.ComputeIPTM
count_clashing_chains:
_target_: modelhub.metrics.clashing_chains.CountClashingChains
_target_: rf3.metrics.clashing_chains.CountClashingChains

67
models/rf3/pyproject.toml Normal file

@@ -0,0 +1,67 @@
[project]
name = "rf3"
dynamic = ["version"]
description = "RosettaFold3: Open-source biomolecular structure prediction for all molecules of life."
readme = "README.md"
requires-python = ">= 3.12"
authors = [
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
]
license = { file = "../../LICENSE.md" }
classifiers = [
"Intended Audience :: Developers",
"Intended Audience :: Science/Research",
"Natural Language :: English",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: Implementation :: CPython",
"Topic :: Scientific/Engineering :: Bio-Informatics",
"License :: OSI Approved :: BSD License",
]
dependencies = [
# Core modelhub dependency
"modelhub",
# CLI
"typer>=0.9.0,<1",
# RF3-specific ML dependencies
"einops>=0.8.0,<1",
"einx>=0.1.0,<1",
"opt_einsum>=3.4.0,<4",
"dm-tree>=0.1.6,<1",
# ... kernels (Linux only)
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
# ... dataloading
"atomworks==1.0.2",
]
[project.scripts]
rf3 = "rf3.cli:app"
# Build settings ----------------------------------------------------------------------
[build-system]
requires = [
"hatchling",
"hatch-vcs == 0.4",
]
build-backend = "hatchling.build"
[tool.hatch.version]
source = "vcs"
[tool.hatch.version.raw-options]
root = "../.."
[tool.hatch.build.hooks.vcs]
version-file = "src/rf3/_version.py"
[tool.hatch.metadata]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["src/rf3"]

61
models/rf3/rf3-dev.def Normal file

@@ -0,0 +1,61 @@
Bootstrap: docker
From: nvcr.io/nvidia/pytorch:25.04-py3
IncludeCmd: yes
# RF3 Development Container - Dependencies only
%labels
Author Institute for Protein Design
Version 1.0-dev
Description RosettaFold3 - Dependencies only (for development and model training)
%setup
# Create a directory in the container to bind the host's current working directory
mkdir ${APPTAINER_ROOTFS}/modelhub_host
# ... for mounting `/projects` with --bind
mkdir ${APPTAINER_ROOTFS}/projects
# ... for mounting `/net` (which contains `/databases`) with --bind
mkdir ${APPTAINER_ROOTFS}/net
# ... for mounting `/squash` with --bind
mkdir ${APPTAINER_ROOTFS}/squash
%files
/etc/localtime
/etc/hosts
pyproject.toml /opt/pyproject.toml
%post
## GENERAL SETUP
# Common symlinks (within container)
ln -s /net/databases /databases
ln -s /net/software /software
ln -s /home /mnt/home
ln -s /projects /mnt/projects
ln -s /net /mnt/net
## PYTHON DEPENDENCY INSTALLATION
# Install uv for fast package installation (10-100x faster than pip)
pip install uv
# Auto-generate a dependency list
uv pip compile /opt/pyproject.toml -o /opt/requirements.txt
# Install all Python dependencies using requirements.txt with uv
uv pip install --system --break-system-packages -r /opt/requirements.txt
%environment
# (Flag to increase accessible GPU memory)
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# (Turn off NVLink)
export NCCL_P2P_DISABLE=1
%runscript
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
exec python "$@"
%help
RosettaFold3 (RF3) development container with dependencies only.


@@ -0,0 +1,3 @@
"""RF3 - RosettaFold3 model implementation."""
__version__ = "0.1.0"


@@ -0,0 +1,34 @@
# file generated by setuptools-scm
# don't change, don't track in version control
__all__ = [
"__version__",
"__version_tuple__",
"version",
"version_tuple",
"__commit_id__",
"commit_id",
]
TYPE_CHECKING = False
if TYPE_CHECKING:
from typing import Tuple
from typing import Union
VERSION_TUPLE = Tuple[Union[int, str], ...]
COMMIT_ID = Union[str, None]
else:
VERSION_TUPLE = object
COMMIT_ID = object
version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE
commit_id: COMMIT_ID
__commit_id__: COMMIT_ID
__version__ = version = '0.1.dev917+gcbbe4c6a6.d20251001'
__version_tuple__ = version_tuple = (0, 1, 'dev917', 'gcbbe4c6a6.d20251001')
__commit_id__ = commit_id = None


@@ -1,10 +1,10 @@
from os import PathLike
from pathlib import Path
from atomworks.common import parse_example_id
from atomworks.ml.example_id import parse_example_id
from beartype.typing import Any
from callbacks.base import BaseCallback
from modelhub.callbacks.callback import BaseCallback
from rf3.utils.io import (
build_stack_from_atom_array_and_batched_coords,
dump_structures,


@@ -7,9 +7,9 @@ from atomworks.ml.utils import nested_dict
from beartype.typing import Any, Literal
from omegaconf import ListConfig
from callbacks.base import BaseCallback
from utils.ddp import RankedLogger
from utils.logging import (
from modelhub.callbacks.callback import BaseCallback
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.logging import (
condense_count_columns_of_grouped_df,
print_df_as_table,
)


@@ -1,4 +1,5 @@
import os
from pathlib import Path
import typer
from hydra import compose, initialize_config_dir
@@ -13,9 +14,11 @@ app = typer.Typer()
)
def fold(ctx: typer.Context):
"""Run structure prediction using hydra config overrides or simple input file."""
config_path = os.path.join(
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
)
# Find the RF3 configs directory relative to this file
# This file is at: models/rf3/src/rf3/cli.py
# Configs are at: models/rf3/configs/
rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
config_path = str(rf3_package_dir / "configs")
# Get all arguments
args = ctx.params.get("args", []) + ctx.args

View File

@@ -20,7 +20,7 @@ from biotite.structure import AtomArray
from jaxtyping import Bool, Float, Shaped
from torch import Tensor
from utils.torch import assert_no_nans
from modelhub.utils.torch import assert_no_nans
logger = logging.getLogger(__name__)

View File

@@ -125,11 +125,7 @@ class MultiInputDatasetWrapper(StructuralDatasetWrapper):
)
raise e
# Return the specified key or the entire data dict (i.e., only "feats" key from the Transform dictionary)
if exists(self.return_key):
return data[self.return_key]
else:
return data
return data
class MultidomainDFParser(MetadataRowParser):

View File

@@ -3,7 +3,7 @@ from beartype.typing import Any, Literal
from jaxtyping import Float
from rf3.data.rotation_augmentation import centre_random_augmentation
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -10,7 +10,7 @@ from dotenv import load_dotenv
from hydra.utils import instantiate
from omegaconf import DictConfig
from utils.logging import suppress_warnings
from modelhub.utils.logging import suppress_warnings
load_dotenv(override=True)
@@ -18,10 +18,11 @@ load_dotenv(override=True)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
_config_path = os.path.join(
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
)
# Find the RF3 configs directory relative to this file
# This file is at: models/rf3/src/rf3/inference.py
# Configs are at: models/rf3/configs/
rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
_config_path = str(rf3_package_dir / "configs")
@hydra.main(

View File

@@ -0,0 +1,5 @@
"""RF3 inference engines."""
from rf3.inference_engines.rf3 import RF3InferenceEngine
__all__ = ["RF3InferenceEngine"]

View File

@@ -10,12 +10,11 @@ from atomworks.io.transforms.categories import category_to_dict
from lightning.fabric import seed_everything
from omegaconf import OmegaConf
from inference_engines.base import InferenceEngine
from rf3.model.RF3 import ShouldEarlyStopFn
from rf3.utils.datasets import (
assemble_distributed_inference_loader_from_list_of_paths,
)
from utils.ddp import RankedLogger, set_accelerator_based_on_availability
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
from rf3.utils.inference import (
apply_conformer_and_template_selections,
build_file_paths_for_prediction,
@@ -25,7 +24,7 @@ from rf3.utils.io import (
dump_structures,
dump_trajectories,
)
from utils.logging import print_config_tree
from modelhub.utils.logging import print_config_tree
from rf3.utils.predicted_error import (
annotate_atom_array_b_factor_with_plddt,
compile_af3_confidence_outputs,
@@ -55,7 +54,7 @@ def should_early_stop_by_mean_plddt(
return fn
class RF3InferenceEngine(InferenceEngine):
class RF3InferenceEngine:
"""Class for inference with RF3. Evaluates a trained RF3 model on a set of spoofed CIFs."""
def __init__(
@@ -152,8 +151,6 @@ class RF3InferenceEngine(InferenceEngine):
self.cfg.trainer.num_nodes = num_nodes
self.cfg.trainer.devices_per_node = devices_per_node
self.cfg.trainer.precision = "bf16-mixed" # HACK: Temporary hack until our checkpoint configs are updated
self.cfg.trainer._target_ = "modelhub.trainers.rf3.RF3TrainerWithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking
self.cfg.model.net._target_ = "modelhub.model.RF3.RF3WithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking
# We don't want to compute all of the training metrics, since they may error during inference
self.cfg.trainer["metrics"] = {}

View File

@@ -1,11 +1,8 @@
# TODO: Many of these functions are unused; we will deprecate and delete
# (They are holdovers from previous frameworks)
from itertools import permutations
import numpy as np
import torch
from openbabel import openbabel
PARAMS = {
"DMIN": 1,
@@ -355,36 +352,3 @@ def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
def standardize_dihedral_retain_first(a, b, c, d):
isomorphisms = [(a, b, c, d), (a, c, b, d)]
return sorted(isomorphisms)[0]
def get_chirals(obmol, xyz):
"""
get all quadruples of atoms forming chiral centers and the expected ideal pseudodihedral between them
"""
stereo = openbabel.OBStereoFacade(obmol)
angle = np.arcsin(1 / 3**0.5)
chiral_idx_set = set()
for i in range(obmol.NumAtoms()):
if not stereo.HasTetrahedralStereo(i):
continue
si = stereo.GetTetrahedralStereo(i)
config = si.GetConfig()
o = config.center
c = config.from_or_towards
i, j, k = list(config.refs)
for a, b, c in permutations((c, i, j, k), 3):
chiral_idx_set.add(standardize_dihedral_retain_first(o, a, b, c))
chiral_idx = list(chiral_idx_set)
chiral_idx.sort()
chiral_idx = torch.tensor(chiral_idx, dtype=torch.float32)
chiral_idx = chiral_idx[(chiral_idx < obmol.NumAtoms()).all(dim=-1)]
if chiral_idx.numel() == 0:
return torch.zeros((0, 5))
dih = get_dih(*xyz[chiral_idx.long()].split(split_size=1, dim=1))[:, 0]
chirals = torch.nn.functional.pad(chiral_idx, (0, 1), mode="constant", value=angle)
chirals[dih < 0.0, -1] *= -1
return chirals

View File

@@ -10,7 +10,7 @@ from biotite.structure import AtomArray, AtomArrayStack
from jaxtyping import Bool, Float
from rf3.kinematics import get_dih
from metrics.base import Metric
from modelhub.metrics.metric import Metric
def calc_chiral_metrics_masked(

View File

@@ -4,7 +4,7 @@ from typing import Any
import torch
from biotite.structure import AtomArrayStack
from metrics.base import Metric
from modelhub.metrics.metric import Metric
class CountClashingChains(Metric):

View File

@@ -11,8 +11,8 @@ from einops import rearrange, repeat
from jaxtyping import Bool, Float
from rf3.loss.af3_losses import distogram_loss
from metrics.base import Metric
from utils.torch import assert_no_nans
from modelhub.metrics.metric import Metric
from modelhub.utils.torch import assert_no_nans
@dataclass

View File

@@ -1,9 +1,7 @@
import numpy as np
import torch
from atomworks.ml.transforms.atom_array import (
AddGlobalTokenIdAnnotation,
ensure_atom_array_stack,
)
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
from atomworks.ml.transforms.atomize import AtomizeByCCDName
from atomworks.ml.transforms.base import Compose
from atomworks.ml.utils.token import get_token_starts
@@ -11,8 +9,8 @@ from beartype.typing import Any
from biotite.structure import AtomArray, AtomArrayStack, stack
from jaxtyping import Bool, Float, Int
from metrics.base import Metric
from utils.ddp import RankedLogger
from modelhub.metrics.metric import Metric
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -2,7 +2,7 @@ import json
from beartype.typing import Any, Literal
from metrics.base import Metric
from modelhub.metrics.metric import Metric
class ExtraInfo(Metric):

View File

@@ -2,7 +2,7 @@ from typing import Any
import torch
from metrics.base import Metric
from modelhub.metrics.metric import Metric
from rf3.metrics.metric_utils import find_bin_midpoints

View File

@@ -3,7 +3,7 @@ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
from beartype.typing import Any
from biotite.structure import AtomArrayStack
from metrics.base import Metric
from modelhub.metrics.metric import Metric
class UnresolvedRegionRASA(Metric):

View File

@@ -7,7 +7,7 @@ from atomworks.ml.utils.selection import (
from beartype.typing import Any
from biotite.structure import AtomArrayStack
from metrics.base import Metric
from modelhub.metrics.metric import Metric
class SelectedAtomByAtomDistances(Metric):

View File

@@ -6,7 +6,7 @@ import torch.nn.functional as F
from einops import rearrange
from opt_einsum import contract as einsum
from src import SHOULD_USE_CUEQUIVARIANCE
from modelhub import SHOULD_USE_CUEQUIVARIANCE
from rf3.training.checkpoint import activation_checkpointing
from rf3.util_module import init_lecun_normal

View File

@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
from jaxtyping import Float
from src import SHOULD_USE_CUEQUIVARIANCE
from modelhub import SHOULD_USE_CUEQUIVARIANCE
from rf3.util_module import init_lecun_normal
if SHOULD_USE_CUEQUIVARIANCE:

View File

@@ -12,7 +12,7 @@ from rf3.model.layers.layer_utils import (
)
from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage
from rf3.training.checkpoint import activation_checkpointing
from utils.torch import device_of
from modelhub.utils.torch import device_of
class AtomAttentionEncoderDiffusion(nn.Module):

View File

@@ -8,8 +8,8 @@ import rootutils
from dotenv import load_dotenv
from omegaconf import DictConfig
from utils.logging import suppress_warnings
from utils.weights import CheckpointConfig
from modelhub.utils.logging import suppress_warnings
from modelhub.utils.weights import CheckpointConfig
load_dotenv(override=True)
@@ -18,10 +18,7 @@ load_dotenv(override=True)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
_config_path = os.path.join(
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
)
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
_spawning_process_logger = logging.getLogger(__name__)
@@ -43,14 +40,14 @@ def train(cfg: DictConfig) -> None:
# Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
torch.set_float32_matmul_precision("medium")
from callbacks.base import BaseCallback # noqa
from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
from utils.logging import (
from modelhub.callbacks.callback import BaseCallback # noqa
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
from modelhub.utils.logging import (
print_config_tree,
log_hyperparameters_with_all_loggers,
) # noqa
from utils.ddp import RankedLogger # noqa
from utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa
from modelhub.utils.ddp import RankedLogger # noqa
from modelhub.utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa
from rf3.utils.datasets import (
recursively_instantiate_datasets_and_samplers,
assemble_distributed_loader,

View File

@@ -6,20 +6,20 @@ from jaxtyping import Float, Int
from lightning_utilities import apply_to_collection
from omegaconf import DictConfig
from common import exists
from modelhub.common import exists
from rf3.loss.af3_losses import Loss as AF3Loss
from rf3.loss.af3_losses import (
ResidueSymmetryResolution,
SubunitSymmetryResolution,
)
from metrics.base import MetricManager
from modelhub.metrics.metric import MetricManager
from rf3.model.RF3 import ShouldEarlyStopFn
from trainers.fabric import FabricTrainer
from modelhub.trainers.fabric import FabricTrainer
from rf3.training.EMA import EMA
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
from rf3.utils.recycling import get_recycle_schedule
from utils.torch import assert_no_nans, assert_same_shape
from modelhub.utils.torch import assert_no_nans, assert_same_shape
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -1,8 +1,4 @@
"""Utilities for gradient checkpointing to reduce memory usage during training.
Gradient checkpointing (also called activation checkpointing) trades compute for memory
by recomputing intermediate activations during the backward pass instead of storing them.
This enables training larger models or using larger batch sizes within GPU memory constraints.
"""Utilities for gradient checkpointing.
References:
* `PyTorch Checkpoint Documentation`_
@@ -17,9 +13,8 @@ from torch.utils.checkpoint import checkpoint
def create_custom_forward(module, **kwargs):
"""Create a custom forward function for gradient checkpointing with fixed kwargs.
This helper enables passing keyword arguments to a module when using PyTorch's
checkpoint function, which only accepts positional arguments for the function to
be checkpointed.
Enables passing keyword arguments to a module when using PyTorch's checkpoint function,
which only accepts positional arguments for the function to be checkpointed.
Args:
module: The callable (typically a nn.Module) to wrap.
@@ -28,15 +23,6 @@ def create_custom_forward(module, **kwargs):
Returns:
A callable that accepts only positional arguments and forwards them along
with the fixed kwargs to the original module.
Examples:
Use with PyTorch checkpoint::
custom_fn = create_custom_forward(my_module, frame_atom_idxs=frame_idxs)
output = checkpoint(custom_fn, input_tensor, use_reentrant=False)
See Also:
:py:func:`activation_checkpointing`
"""
def custom_forward(*inputs):
@@ -48,11 +34,6 @@ def create_custom_forward(module, **kwargs):
def activation_checkpointing(function):
"""Decorator to enable gradient checkpointing for a function during training.
When gradients are enabled (training mode), this decorator wraps the function
with PyTorch's checkpoint to save memory by recomputing activations during
the backward pass. During inference (gradients disabled), the function runs
normally without checkpointing overhead.
Args:
function: The function to apply gradient checkpointing to.
@@ -67,11 +48,7 @@ def activation_checkpointing(function):
return self.layer(x, mask)
Notes:
Uses ``use_reentrant=False`` for better compatibility with modern PyTorch
features like autograd hooks and higher-order gradients.
See Also:
:py:func:`create_custom_forward`
Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions.
"""
def wrapper(*args, **kwargs):
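
Review note: this hunk drops the docstring example for `create_custom_forward`. As a stand-in, here is a minimal sketch of the keyword-binding pattern the docstring describes. It uses a plain callable instead of an `nn.Module` and does not call `torch.utils.checkpoint`. The function body is inferred from the docstring rather than copied from the repository, and `scale` is made up for illustration.

```python
def create_custom_forward(module, **kwargs):
    """Wrap `module` so it can be called with positional arguments only."""

    def custom_forward(*inputs):
        # Forward the positional inputs together with the fixed keyword arguments
        return module(*inputs, **kwargs)

    return custom_forward


def scale(x, factor=1):
    # Stand-in for a module's forward pass (illustrative only)
    return x * factor


double = create_custom_forward(scale, factor=2)
print(double(21))  # 42
```

With PyTorch, the wrapped callable would be passed as `checkpoint(custom_fn, input_tensor, use_reentrant=False)`, as in the removed docstring example.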

View File

@@ -20,8 +20,8 @@ from torch.utils.data import (
)
from torch.utils.data.distributed import DistributedSampler
from hydra.resolvers import register_resolvers
from utils.ddp import RankedLogger
from modelhub.hydra.resolvers import register_resolvers
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
try:

View File

@@ -9,7 +9,7 @@ from beartype.typing import Literal
from biotite.structure import AtomArray, AtomArrayStack, stack
from rf3.alignment import weighted_rigid_align
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -2,13 +2,14 @@
import logging
import os
from pathlib import Path
import hydra
import rootutils
from dotenv import load_dotenv
from omegaconf import DictConfig
from utils.logging import suppress_warnings
from modelhub.utils.logging import suppress_warnings
load_dotenv(override=True)
@@ -17,10 +18,7 @@ load_dotenv(override=True)
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
_config_path = os.path.join(
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
)
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
_spawning_process_logger = logging.getLogger(__name__)
@@ -42,11 +40,11 @@ def validate(cfg: DictConfig) -> None:
# Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
torch.set_float32_matmul_precision("medium")
from callbacks.base import BaseCallback # noqa
from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
from utils.logging import print_config_tree # noqa
from utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa
from utils.ddp import is_rank_zero # noqa
from modelhub.callbacks.callback import BaseCallback # noqa
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
from modelhub.utils.logging import print_config_tree # noqa
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa
from modelhub.utils.ddp import is_rank_zero # noqa
from rf3.utils.datasets import assemble_val_loader_dict # noqa
set_accelerator_based_on_availability(cfg)

View File

@@ -9,11 +9,18 @@ TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
def pytest_configure(config):
import sys
# Set PROJECT_ROOT
project_root = rootutils.setup_root(
__file__, indicator=".project-root", pythonpath=True
)
# Add models/rf3/src to path so RF3 modules can be imported
rf3_src = project_root / "models" / "rf3" / "src"
if rf3_src.exists() and str(rf3_src) not in sys.path:
sys.path.insert(0, str(rf3_src))
# Construct path to .env file at project root
dotenv_path = project_root / ".env"

View File

@@ -1,7 +1,7 @@
data_5VHT
#
_msa_paths_by_chain_id.A tests/data/msas/5vht_A.a3m
_msa_paths_by_chain_id.B tests/data/msas/5vht_A.a3m
_msa_paths_by_chain_id.A models/rf3/tests/data/msas/5vht_A.a3m
_msa_paths_by_chain_id.B models/rf3/tests/data/msas/5vht_A.a3m
#
_entry.id 5VHT
#

View File

@@ -5,12 +5,12 @@
{
"seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH",
"chain_id": "A",
"msa_path": "tests/data/msas/5vht_A.a3m"
"msa_path": "models/rf3/tests/data/msas/5vht_A.a3m"
},
{
"seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH",
"chain_id": "B",
"msa_path": "tests/data/msas/5vht_A.a3m"
"msa_path": "models/rf3/tests/data/msas/5vht_A.a3m"
}
]
}

View File

@@ -3,7 +3,7 @@ from copy import deepcopy
import pytest
from atomworks.ml.utils.testing import cached_parse
from metrics.chiral import ChiralLoss
from rf3.metrics.chiral import ChiralLoss
@pytest.mark.parametrize("pdb_id", ["5ocm", "1ivo"])

View File

@@ -1,7 +1,7 @@
[project]
name = "rf3"
name = "modelhub"
dynamic = ["version"]
description = "Open-source biomolecular structure prediction for all molecules of life."
description = "Shared utilities and training infrastructure for biomolecular structure prediction models."
readme = "README.md"
requires-python = ">= 3.12"
authors = [
@@ -22,17 +22,6 @@ classifiers = [
"License :: OSI Approved :: BSD License",
]
dependencies = [
# ...ml tools
"torch>=2.2.0,<3",
"lightning>=2.4.0,<2.5",
"einops>=0.8.0,<1",
"einx>=0.1.0,<1",
"opt_einsum>=3.4.0,<4",
"dm-tree>=0.1.6,<1",
# ... kernels (Linux only)
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
# ... configuration & CLI
"rootutils>=1.0.7,<1.1",
"hydra-core>=1.3.0,<1.4",
@@ -43,9 +32,9 @@ dependencies = [
# ... typing & documentation
"jaxtyping>=0.2.17,<1",
"beartype>=0.18.0,<1",
# ... dataloading
"atomworks==1.0.2"
# ... ml tools (core)
"torch>=2.2.0,<3",
"lightning>=2.4.0,<2.5",
]
@@ -64,8 +53,7 @@ dev = [
"pytest-cov>=4.1.0,<5", # generate coverage report
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
]
[project.scripts]
rf3 = "modelhub.cli:app"
# No CLI scripts at root level - models provide their own entry points
# Build settings ----------------------------------------------------------------------
[build-system]
requires = [

View File

@@ -44,4 +44,4 @@ except ImportError:
logger.debug("cuEquivariance unavailable: import failed")
# Export for easy access
__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "silence_warnings"]
__all__ = ["SHOULD_USE_CUEQUIVARIANCE"]

View File

@@ -0,0 +1,5 @@
"""Callbacks for training and validation."""
from modelhub.callbacks.callback import BaseCallback
__all__ = ["BaseCallback"]

View File

@@ -10,7 +10,7 @@ from jaxtyping import Float, Int
from lightning.fabric.utilities.rank_zero import rank_zero_only
from torch import Tensor
from callbacks.base import BaseCallback
from modelhub.callbacks.callback import BaseCallback
_DEFAULT_STATISTICS = types.MappingProxyType(
{

View File

@@ -1,10 +1,10 @@
import pandas as pd
from lightning.fabric.utilities.rank_zero import rank_zero_only
from callbacks.base import BaseCallback
from rf3.utils.logging import print_df_as_table
from rf3.utils.torch_utils import Timers
from modelhub.callbacks.callback import BaseCallback
class TimingCallback(BaseCallback):
"""Fabric callback to print timing metrics."""

View File

@@ -2,26 +2,26 @@ import time
from collections import defaultdict
import pandas as pd
from atomworks.common import parse_example_id
from atomworks.ml.example_id import parse_example_id
from beartype.typing import Any
from lightning.fabric.wrappers import (
_FabricOptimizer,
)
from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
from rich.console import Group
from rich.panel import Panel
from rich.table import Table
from torch import nn
from torchmetrics.aggregation import MeanMetric
from callbacks.base import BaseCallback
from rf3.utils.ddp import RankedLogger
from rf3.utils.logging import (
from modelhub.callbacks.callback import BaseCallback
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.logging import (
print_df_as_table,
print_model_parameters,
safe_print,
table_from_df,
)
from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
class LogModelParametersCallback(BaseCallback):

View File

@@ -1,15 +0,0 @@
from abc import ABC, abstractmethod
from pathlib import Path
class InferenceEngine(ABC):
"""Abstract base class for inference pipelines."""
@abstractmethod
def __init__(self, **kwargs):
pass
@abstractmethod
def eval(self, inputs: list[Path]) -> None:
"""Run inference on input files."""
pass

View File

@@ -0,0 +1,12 @@
"""Metrics for model evaluation.
This module provides the base metric framework.
"""
from modelhub.metrics.metric import Metric, MetricInputError, MetricManager
__all__ = [
"Metric",
"MetricManager",
"MetricInputError",
]

View File

@@ -8,7 +8,7 @@ from beartype.typing import Any
from omegaconf import DictConfig
from toolz import keymap
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -25,12 +25,12 @@ from lightning.fabric.wrappers import (
_FabricModule,
_FabricOptimizer,
)
from callbacks.base import BaseCallback
from rf3.training.EMA import EMA
from rf3.training.schedulers import SchedulerConfig
from utils.ddp import RankedLogger
from utils.weights import (
from modelhub.callbacks.callback import BaseCallback
from modelhub.utils.ddp import RankedLogger
from modelhub.utils.weights import (
CheckpointConfig,
WeightLoadingConfig,
freeze_parameters_with_config,

View File

@@ -2,7 +2,7 @@ import hydra
from lightning.fabric.loggers import Logger
from omegaconf import DictConfig
from callbacks.base import BaseCallback
from modelhub.callbacks.callback import BaseCallback
def _can_be_instantiated(cfg: DictConfig) -> bool:

View File

@@ -12,7 +12,7 @@ from rich.table import Table
from rich.tree import Tree
from torch import nn
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -14,8 +14,8 @@ from torch import Tensor
from torch._prims_common import DeviceLikeType
from torch.types import _dtype
from src import should_check_nans
from common import at_least_one_exists, do_nothing
from modelhub import should_check_nans
from modelhub.common import at_least_one_exists, do_nothing
def map_to(

View File

@@ -9,7 +9,7 @@ import torch
from beartype.typing import Pattern
from torch import nn
from utils.ddp import RankedLogger
from modelhub.utils.ddp import RankedLogger
ranked_logger = RankedLogger(__name__, rank_zero_only=True)

View File

@@ -4,7 +4,7 @@ import pytest
import torch
os.environ["NAN_CHECKING"] = "True"
from utils.torch import assert_no_nans, map_to
from modelhub.utils.torch import assert_no_nans, map_to
def test_map_to():

View File

@@ -3,7 +3,7 @@ import torch
import torch.nn as nn
# Import your code here
from utils.weights import (
from modelhub.utils.weights import (
ParameterFreezingConfig,
WeightLoadingConfig,
WeightLoadingPolicy,