mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix: working rf3
This commit is contained in:
63
README.md
63
README.md
@@ -24,7 +24,7 @@ For more information, please see our preprint, [Accelerating Biomolecular Modeli
|
||||
</div>
|
||||
|
||||
> [!TIP]
|
||||
> Complete inference instructions for RF3 are provided [here](src/modelhub/inference_engines/README.md).
|
||||
> Complete inference instructions for RF3 are provided [here](models/rf3/README.md).
|
||||
|
||||
### RF3 Quick Start - Installation & Usage
|
||||
|
||||
@@ -32,7 +32,7 @@ Follow these steps to set up **ModelForge** and run a test prediction with **RF3
|
||||
|
||||
---
|
||||
|
||||
#### 1. Install the repository using `uv`
|
||||
#### 1. Install the source repository and RF3 model using `uv`
|
||||
|
||||
```bash
|
||||
git clone https://github.com/RosettaCommons/modelforge.git \
|
||||
@@ -40,9 +40,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \
|
||||
&& uv python install 3.12 \
|
||||
&& uv venv --python 3.12 \
|
||||
&& source .venv/bin/activate \
|
||||
&& uv pip install -e .
|
||||
&& uv pip install -e ./models/rf3
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency.
|
||||
|
||||
#### 2. Download model weights for RF3
|
||||
```bash
|
||||
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
|
||||
@@ -53,4 +56,56 @@ wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
|
||||
rf3 fold tests/data/5vht_from_json.json
|
||||
```
|
||||
|
||||
Details on the exact formatting of the json files are available [here](src/modelhub/inference_engines/README.md).
|
||||
Details on the exact formatting of the json files are available [here](models/rf3/README.md).
|
||||
|
||||
## Development
|
||||
|
||||
### Package Structure
|
||||
|
||||
ModelForge uses a multi-package architecture:
|
||||
|
||||
- **`modelhub`**: Core package containing shared utilities, training infrastructure, and base classes
|
||||
- **`models/rf3/`**: RF3 model package with model-specific code and dependencies
|
||||
- **`models/<future>/`**: Additional models can be added as separate packages
|
||||
|
||||
### Installation Options
|
||||
|
||||
#### For Users (Single Model)
|
||||
|
||||
Install only the model you need. Dependencies (including `modelhub`) are automatically installed:
|
||||
|
||||
```bash
|
||||
# Install RF3 (includes modelhub automatically)
|
||||
uv pip install -e ./models/rf3
|
||||
|
||||
# Future: Install other models
|
||||
# uv pip install -e ./models/other_model
|
||||
```
|
||||
|
||||
#### For Core Developers (Multiple Packages)
|
||||
|
||||
Install both `modelhub` and models in editable mode for development:
|
||||
|
||||
```bash
|
||||
# Install modelhub and RF3 in editable mode
|
||||
uv pip install -e . -e ./models/rf3
|
||||
|
||||
# Or install only modelhub (no models)
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
This approach allows you to:
|
||||
- Modify `modelhub` shared utilities and see changes immediately
|
||||
- Work on specific models without installing all models
|
||||
- Add new models as independent packages in `models/`
|
||||
|
||||
### Adding New Models
|
||||
|
||||
To add a new model:
|
||||
|
||||
1. Create `models/<model_name>/` directory with its own `pyproject.toml`
|
||||
2. Add `modelhub` as a dependency
|
||||
3. Implement model-specific code in `models/<model_name>/src/`
|
||||
4. Users can install with: `uv pip install -e ./models/<model_name>`
|
||||
|
||||
## Development
|
||||
@@ -1,919 +0,0 @@
|
||||
# ModelForge Repository Restructuring Plan
|
||||
|
||||
**Date**: October 1, 2025
|
||||
**Author**: Analysis based on PyTorch Lightning patterns
|
||||
**Goal**: Reorganize `src/modelhub/` following best practices while keeping `releases/` structure intact
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This plan addresses three critical issues in the current structure:
|
||||
1. **Broken imports**: RF3 uses `from metrics.base import Metric` but base classes are in `src/modelhub/`
|
||||
2. **CLI entry point mismatch**: `pyproject.toml` points to non-existent `modelhub.cli:app`
|
||||
3. **Package installation confusion**: Only `src/modelhub` is packaged, but RF3 code is in `releases/`
|
||||
4. **Unclear organization**: Mixed base classes and implementations without clear semantic separation
|
||||
|
||||
**Core Principle**: Follow PyTorch Lightning's pattern - base classes live WITH their implementations, no separate `contrib/` directory.
|
||||
|
||||
---
|
||||
|
||||
## Current Structure Analysis
|
||||
|
||||
### What Exists Now
|
||||
```
|
||||
modelhub_latent/
|
||||
├── pyproject.toml # name = "rf3" (wrong!)
|
||||
├── src/
|
||||
│ └── modelhub/ # Only this is packaged
|
||||
│ ├── callbacks/
|
||||
│ │ ├── base.py # BaseCallback
|
||||
│ │ ├── health_logging.py # Implementation
|
||||
│ │ ├── timing_logging.py # Implementation
|
||||
│ │ └── train_logging.py # Implementation
|
||||
│ ├── metrics/
|
||||
│ │ ├── base.py # Metric + MetricManager
|
||||
│ │ ├── chiral.py # Implementation
|
||||
│ │ └── rasa.py # Implementation
|
||||
│ ├── utils/
|
||||
│ ├── trainers/
|
||||
│ └── hydra/
|
||||
├── releases/
|
||||
│ └── rf3/ # NOT packaged!
|
||||
│ ├── src/rf3/
|
||||
│ │ ├── callbacks/
|
||||
│ │ ├── metrics/
|
||||
│ │ ├── model/
|
||||
│ │ └── cli.py
|
||||
│ ├── configs/
|
||||
│ └── tests/
|
||||
└── tests/ # Shared tests
|
||||
```
|
||||
|
||||
### Critical Problems
|
||||
|
||||
#### Problem 1: Import Path Mismatch
|
||||
**RF3 code uses**:
|
||||
```python
|
||||
from metrics.base import Metric # Expects local resolution
|
||||
from callbacks.base import BaseCallback # Expects local resolution
|
||||
```
|
||||
|
||||
**Reality**:
|
||||
- Base classes are in `src/modelhub/metrics/base.py`
|
||||
- This only works via sys.path manipulation or broken imports
|
||||
|
||||
**Should be**:
|
||||
```python
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
```
|
||||
|
||||
#### Problem 2: CLI Entry Point
|
||||
**Current `pyproject.toml`**:
|
||||
```toml
|
||||
[project.scripts]
|
||||
rf3 = "modelhub.cli:app"
|
||||
```
|
||||
|
||||
**Reality**: There is NO `src/modelhub/cli.py` file!
|
||||
|
||||
**Actual CLI**: Located at `releases/rf3/src/rf3/cli.py`
|
||||
|
||||
#### Problem 3: Package Name Confusion
|
||||
**Current**: `name = "rf3"` in pyproject.toml
|
||||
|
||||
**Problem**:
|
||||
- Repository is called "modelforge"
|
||||
- Users would `pip install rf3` (gets only base framework)
|
||||
- Can't do `pip install modelforge[rf3]` for selective installation
|
||||
|
||||
#### Problem 4: Packaging Scope
|
||||
**Current**: `packages = ["src/modelhub"]`
|
||||
|
||||
**Problem**: RF3 code in `releases/rf3/src/rf3/` is NOT included in the package!
|
||||
|
||||
---
|
||||
|
||||
## Design Goals
|
||||
|
||||
### User Experience Goals
|
||||
1. `pip install modelforge` → Get base framework only
|
||||
2. `pip install modelforge[rf3]` → Get base + RF3
|
||||
3. `pip install modelforge[all]` → Get all models (future-proof)
|
||||
4. Development: `pip install -e .[rf3,dev]`
|
||||
|
||||
### Code Organization Goals
|
||||
1. Follow PyTorch Lightning patterns (familiar to ML researchers)
|
||||
2. Clear distinction between base classes and implementations
|
||||
3. Proper namespacing for imports
|
||||
4. Maintain `releases/` structure for development workflow
|
||||
|
||||
### Import Pattern Goals
|
||||
```python
|
||||
# Clear, unambiguous imports
|
||||
from modelforge.callbacks.callback import BaseCallback
|
||||
from modelforge.metrics.metric import Metric
|
||||
from modelforge.utils.ddp import RankedLogger
|
||||
|
||||
# RF3-specific imports
|
||||
from modelforge.models.rf3.model import RF3
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## PyTorch Lightning Pattern Analysis
|
||||
|
||||
### Lightning's Structure (Reference)
|
||||
```
|
||||
lightning/pytorch/
|
||||
├── core/ # Primary model abstractions
|
||||
│ ├── module.py # LightningModule (what users inherit from)
|
||||
│ ├── datamodule.py # LightningDataModule
|
||||
│ └── hooks.py
|
||||
├── callbacks/ # Base + implementations together
|
||||
│ ├── callback.py # Callback base class
|
||||
│ ├── model_checkpoint.py # ModelCheckpoint implementation
|
||||
│ ├── early_stopping.py # EarlyStopping implementation
|
||||
│ └── __init__.py # Exports everything
|
||||
├── loggers/ # Base + implementations together
|
||||
│ ├── logger.py # Logger base class
|
||||
│ ├── wandb.py # WandbLogger implementation
|
||||
│ └── __init__.py
|
||||
└── utilities/ # Pure utilities
|
||||
```
|
||||
|
||||
### Key Insights from Lightning
|
||||
|
||||
1. **`core/` is for PRIMARY model abstractions** - Things users inherit from to define their model:
|
||||
- `LightningModule` - "Your model IS A LightningModule"
|
||||
- `LightningDataModule` - "Your data IS A LightningDataModule"
|
||||
|
||||
2. **Other directories contain base + implementations together**:
|
||||
- `callbacks/callback.py` (base) + `callbacks/model_checkpoint.py` (impl)
|
||||
- `loggers/logger.py` (base) + `loggers/wandb.py` (impl)
|
||||
|
||||
3. **NO `contrib/` directory** - That's a Django pattern, not Lightning!
|
||||
|
||||
4. **`__init__.py` exports everything** for clean imports
|
||||
|
||||
### When to Use `core/`
|
||||
|
||||
**Use `core/` IF**:
|
||||
- You have a primary model class users inherit from (like `LightningModule`)
|
||||
- These classes define "what you ARE building" (semantic role)
|
||||
|
||||
**Don't use `core/` IF**:
|
||||
- All your base classes are for plugins/extensions (callbacks, metrics)
|
||||
- You don't have a central model abstraction
|
||||
|
||||
**For ModelForge**: Since there's no primary "ModelBase" class, **we don't need `core/`**.
|
||||
|
||||
### Re-evaluating `inference_engines/`
|
||||
|
||||
**Current situation**:
|
||||
- `InferenceEngine` is a minimal ABC with only 2 abstract methods (`__init__`, `eval`)
|
||||
- Only RF3 uses it (no other models yet)
|
||||
- It's a very thin abstraction that provides minimal value
|
||||
|
||||
**Question**: Do we need this abstraction at all?
|
||||
|
||||
**Arguments for keeping it**:
|
||||
- Future-proofs for when other models are added
|
||||
- Provides a consistent interface pattern
|
||||
|
||||
**Arguments for removing it**:
|
||||
- YAGNI (You Aren't Gonna Need It) - only one implementation exists
|
||||
- Adds indirection without clear benefit
|
||||
- The abstraction is so minimal it doesn't enforce meaningful constraints
|
||||
- Each model's inference is likely to be sufficiently different that a shared interface adds little value
|
||||
|
||||
**Recommendation**: **Remove the `inference_engines/` directory entirely from `src/modelhub/`**
|
||||
- RF3's inference engine can live solely in `models/rf3/src/rf3/inference_engines/rf3.py`
|
||||
- Remove the ABC inheritance - just make it a standalone class
|
||||
- When/if a second model is added, we can evaluate whether a shared abstraction makes sense based on actual commonalities
|
||||
- This follows YAGNI and keeps the code simpler
|
||||
|
||||
---
|
||||
|
||||
## Recommended Solution
|
||||
|
||||
### Two-Part Restructuring
|
||||
|
||||
#### Part 1: Fix `src/modelhub/` Structure (Simple, Lightning-style)
|
||||
|
||||
```
|
||||
src/
|
||||
└── modelhub/ (rename to modelforge?)
|
||||
├── __init__.py # Export key base classes
|
||||
│
|
||||
├── callbacks/
|
||||
│ ├── __init__.py # Export Callback + all implementations
|
||||
│ ├── callback.py # BaseCallback (RENAMED from base.py)
|
||||
│ ├── health_logging.py # HealthLoggingCallback
|
||||
│ ├── timing_logging.py # TimingLoggingCallback
|
||||
│ └── train_logging.py # TrainLoggingCallback
|
||||
│
|
||||
├── metrics/
|
||||
│ ├── __init__.py # Export Metric + all implementations
|
||||
│ ├── metric.py # Metric, MetricManager (RENAMED from base.py)
|
||||
│ ├── chiral.py # ChiralMetric
|
||||
│ └── rasa.py # RASAMetric
|
||||
│
|
||||
├── trainers/
|
||||
│ ├── __init__.py
|
||||
│ └── fabric.py # Fabric trainer wrapper
|
||||
│
|
||||
├── utils/ # Pure utilities (no base classes)
|
||||
│ ├── ddp.py
|
||||
│ ├── logging.py
|
||||
│ ├── weights.py
|
||||
│ ├── instantiators.py
|
||||
│ └── torch.py
|
||||
│
|
||||
└── hydra/ # Hydra utilities
|
||||
├── __init__.py
|
||||
└── resolvers.py
|
||||
```
|
||||
|
||||
**Key changes**:
|
||||
1. Rename `base.py` files to match their content:
|
||||
- `callbacks/base.py` → `callbacks/callback.py`
|
||||
- `metrics/base.py` → `metrics/metric.py`
|
||||
2. **REMOVE** `inference_engines/` entirely (not needed with only one model)
|
||||
3. Keep implementations WITH base classes (Lightning pattern)
|
||||
4. No `core/` directory (not needed without primary model abstraction)
|
||||
5. Proper `__init__.py` exports
|
||||
|
||||
#### Part 2: Integrate RF3 into Main Package (For pip install)
|
||||
|
||||
**Decision Point**: Choose ONE of these approaches:
|
||||
|
||||
##### Option A: Move RF3 into src/modelhub/ (Monorepo style)
|
||||
```
|
||||
src/
|
||||
└── modelhub/
|
||||
├── callbacks/, metrics/, utils/ (as above)
|
||||
└── models/ # NEW
|
||||
└── rf3/ # Moved from releases/rf3/src/rf3/
|
||||
├── __init__.py
|
||||
├── model/
|
||||
├── data/
|
||||
├── metrics/ # RF3-specific metrics
|
||||
├── callbacks/ # RF3-specific callbacks
|
||||
├── inference_engines/
|
||||
└── cli.py
|
||||
```
|
||||
|
||||
**Pros**:
|
||||
- Simple pip install: `pip install modelforge[rf3]`
|
||||
- Single package, single version
|
||||
- Easy code sharing between models
|
||||
|
||||
**Cons**:
|
||||
- Changes development workflow (no separate `releases/` for development)
|
||||
- RF3 code always in source tree even if not installed
|
||||
|
||||
##### Option B: Keep releases/ separate (Development friendly)
|
||||
```
|
||||
modelhub_latent/
|
||||
├── src/modelhub/ # Base framework only
|
||||
├── releases/
|
||||
│ └── rf3/ # Stays as is for development
|
||||
│ └── src/rf3/
|
||||
└── pyproject.toml # Use optional dependencies
|
||||
```
|
||||
|
||||
**Then in `pyproject.toml`**:
|
||||
```toml
|
||||
[project.optional-dependencies]
|
||||
rf3 = [
|
||||
"modelforge-rf3 @ file:///${PROJECT_ROOT}/releases/rf3",
|
||||
# RF3-specific deps
|
||||
]
|
||||
```
|
||||
|
||||
**Pros**:
|
||||
- Keeps development workflow unchanged
|
||||
- Clear separation for releases
|
||||
|
||||
**Cons**:
|
||||
- More complex build setup
|
||||
- Requires editable installs for development
|
||||
- Each release needs own `pyproject.toml`
|
||||
|
||||
##### Option C: Hybrid - Move RF3 but keep releases/ for development (RECOMMENDED)
|
||||
```
|
||||
# For development: work in releases/rf3/
|
||||
releases/rf3/
|
||||
├── src/rf3/ # Development happens here
|
||||
├── configs/ # RF3 configs here
|
||||
└── tests/ # RF3 tests here
|
||||
|
||||
# For installation: symlink or copy during build
|
||||
src/modelhub/models/rf3/ → symlink to releases/rf3/src/rf3/
|
||||
|
||||
# Or use build hooks to include releases/ in package
|
||||
```
|
||||
|
||||
**Implementation**: Use hatch build hooks to include `releases/rf3/src/rf3/` as `src/modelhub/models/rf3/`
|
||||
|
||||
---
|
||||
|
||||
## Detailed Implementation Plan
|
||||
|
||||
### Phase 1: Rename Package (Breaking Change Decision)
|
||||
|
||||
**Decision needed**: Keep `modelhub` or rename to `modelforge`?
|
||||
|
||||
**Arguments for `modelforge`**:
|
||||
- Matches repository name
|
||||
- More descriptive (framework for building models)
|
||||
- Enables `pip install modelforge[rf3]`
|
||||
|
||||
**Arguments for `modelhub`**:
|
||||
- Less renaming work
|
||||
- Existing code already uses it
|
||||
|
||||
**Recommendation**: Rename to `modelforge` for clarity and marketing.
|
||||
|
||||
### Phase 2: Restructure src/modelhub/ (src/modelforge/)
|
||||
|
||||
#### Step 2.1: Rename Base Class Files
|
||||
```bash
|
||||
# In src/modelhub/
|
||||
git mv callbacks/base.py callbacks/callback.py
|
||||
git mv metrics/base.py metrics/metric.py
|
||||
# inference_engines/base.py could stay or rename to inference_engine.py
|
||||
```
|
||||
|
||||
#### Step 2.2: Update Imports in Base Files
|
||||
**In `src/modelhub/callbacks/callback.py`**:
|
||||
```python
|
||||
# Update any internal imports if needed
|
||||
# Mainly just ensure docstrings reference correct module names
|
||||
```
|
||||
|
||||
**In `src/modelhub/metrics/metric.py`**:
|
||||
```python
|
||||
# Update imports and docstrings
|
||||
```
|
||||
|
||||
#### Step 2.3: Create Proper __init__.py Files
|
||||
|
||||
**`src/modelhub/callbacks/__init__.py`**:
|
||||
```python
|
||||
"""Callbacks for training customization.
|
||||
|
||||
This module provides both the base callback class and common callback implementations.
|
||||
"""
|
||||
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.callbacks.health_logging import HealthLoggingCallback
|
||||
from modelhub.callbacks.timing_logging import TimingLoggingCallback
|
||||
from modelhub.callbacks.train_logging import TrainLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"BaseCallback",
|
||||
"HealthLoggingCallback",
|
||||
"TimingLoggingCallback",
|
||||
"TrainLoggingCallback",
|
||||
]
|
||||
```
|
||||
|
||||
**`src/modelhub/metrics/__init__.py`**:
|
||||
```python
|
||||
"""Metrics for model evaluation.
|
||||
|
||||
This module provides the base metric framework and common metric implementations.
|
||||
"""
|
||||
|
||||
from modelhub.metrics.metric import Metric, MetricManager, instantiate_metric_manager
|
||||
from modelhub.metrics.chiral import ChiralMetric
|
||||
from modelhub.metrics.rasa import RASAMetric
|
||||
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"MetricManager",
|
||||
"instantiate_metric_manager",
|
||||
"ChiralMetric",
|
||||
"RASAMetric",
|
||||
]
|
||||
```
|
||||
|
||||
**`src/modelhub/__init__.py`**:
|
||||
```python
|
||||
"""ModelForge: Open-source framework for biomolecular modeling.
|
||||
|
||||
This package provides the base framework for building and training biomolecular models.
|
||||
"""
|
||||
|
||||
# Export key base classes at top level
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.metrics.metric import Metric, MetricManager
|
||||
|
||||
# Version
|
||||
from modelhub.version import __version__
|
||||
|
||||
__all__ = [
|
||||
"BaseCallback",
|
||||
"Metric",
|
||||
"MetricManager",
|
||||
"__version__",
|
||||
]
|
||||
```
|
||||
|
||||
### Phase 3: Fix RF3 Imports
|
||||
|
||||
#### Step 3.1: Update All RF3 Import Statements
|
||||
|
||||
**Find all broken imports**:
|
||||
```bash
|
||||
cd releases/rf3/
|
||||
grep -r "from metrics.base import" src/
|
||||
grep -r "from callbacks.base import" src/
|
||||
grep -r "from inference_engines.base import" src/
|
||||
```
|
||||
|
||||
**Replace with**:
|
||||
```python
|
||||
# OLD (broken)
|
||||
from metrics.base import Metric
|
||||
from callbacks.base import BaseCallback
|
||||
|
||||
# NEW (correct)
|
||||
from modelhub.metrics.metric import Metric, MetricManager
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.inference_engines.base import InferenceEngine
|
||||
```
|
||||
|
||||
**Automated replacement**:
|
||||
```bash
|
||||
# In releases/rf3/src/
|
||||
find . -name "*.py" -exec sed -i 's/from metrics\.base import/from modelhub.metrics.metric import/g' {} +
|
||||
find . -name "*.py" -exec sed -i 's/from callbacks\.base import/from modelhub.callbacks.callback import/g' {} +
|
||||
find . -name "*.py" -exec sed -i 's/from inference_engines\.base import/from modelhub.inference_engines.base import/g' {} +
|
||||
```
|
||||
|
||||
#### Step 3.2: Update RF3 Utility Imports
|
||||
|
||||
```python
|
||||
# Also update imports of shared utilities
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from modelhub.utils.weights import load_checkpoint
|
||||
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks
|
||||
```
|
||||
|
||||
### Phase 4: Fix Build Configuration
|
||||
|
||||
#### Step 4.1: Update pyproject.toml
|
||||
|
||||
**Current**:
|
||||
```toml
|
||||
[project]
|
||||
name = "rf3"
|
||||
[project.scripts]
|
||||
rf3 = "modelhub.cli:app"
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/modelhub"]
|
||||
```
|
||||
|
||||
**New**:
|
||||
```toml
|
||||
[project]
|
||||
name = "modelforge"
|
||||
description = "Open-source framework for biomolecular structure prediction and design"
|
||||
|
||||
# Minimal dependencies for base framework
|
||||
dependencies = [
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.4.0,<2.5",
|
||||
"hydra-core>=1.3.0,<1.4",
|
||||
"rootutils>=1.0.7,<1.1",
|
||||
"environs>=11.0.0,<12",
|
||||
"wandb>=0.15.10,<1",
|
||||
"rich>=13.9.4,<14",
|
||||
"jaxtyping>=0.2.17,<1",
|
||||
"beartype>=0.18.0,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# RF3 model with its specific dependencies
|
||||
rf3 = [
|
||||
"atomworks==1.0.2",
|
||||
"einops>=0.8.0,<1",
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# All models
|
||||
all = [
|
||||
"modelforge[rf3]",
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = [
|
||||
"ruff==0.8.3",
|
||||
"pytest>=8.2.0,<9",
|
||||
# ... other dev deps
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
# Main CLI - needs to be created or use RF3 CLI directly
|
||||
rf3 = "rf3.cli:app" # Direct to RF3 for now
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
# Include both modelhub and rf3 (if using monorepo approach)
|
||||
packages = ["src/modelhub"]
|
||||
# OR include releases/rf3/src/rf3 via custom build hook
|
||||
|
||||
# For development: force-include releases
|
||||
[tool.hatch.build.targets.wheel.force-include]
|
||||
"releases/rf3/src/rf3" = "modelhub/models/rf3" # If using monorepo approach
|
||||
```
|
||||
|
||||
#### Step 4.2: Create Main CLI (If Needed)
|
||||
|
||||
**Option 1**: Keep `rf3` CLI pointing directly to RF3
|
||||
```toml
|
||||
[project.scripts]
|
||||
rf3 = "rf3.cli:app"
|
||||
```
|
||||
|
||||
**Option 2**: Create dispatcher CLI (future-proof)
|
||||
|
||||
**`src/modelhub/cli.py`**:
|
||||
```python
|
||||
"""Main ModelForge CLI."""
|
||||
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
# Import RF3 CLI if available
|
||||
try:
|
||||
from rf3.cli import app as rf3_app
|
||||
# Expose RF3 commands at top level
|
||||
for command_name, command in rf3_app.registered_commands:
|
||||
app.command(name=command_name)(command.callback)
|
||||
except ImportError:
|
||||
pass # RF3 not installed
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
```
|
||||
|
||||
Then:
|
||||
```toml
|
||||
[project.scripts]
|
||||
modelforge = "modelhub.cli:app"
|
||||
rf3 = "rf3.cli:app" # Keep for backward compatibility
|
||||
```
|
||||
|
||||
### Phase 5: Handle RF3 Package Integration
|
||||
|
||||
**Decision needed**: Choose approach from Phase 2, Part 2.
|
||||
|
||||
**Recommended: Option C (Hybrid)**
|
||||
|
||||
Use hatch build hooks to include RF3 from `releases/`:
|
||||
|
||||
**`hatch_build.py`** (in project root):
|
||||
```python
|
||||
"""Custom hatch build hook to include RF3 from releases/."""
|
||||
|
||||
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
|
||||
|
||||
class CustomBuildHook(BuildHookInterface):
|
||||
def initialize(self, version, build_data):
|
||||
"""Include RF3 from releases/ directory."""
|
||||
if self.target_name == "wheel":
|
||||
# Add releases/rf3/src/rf3 to the wheel as modelhub/models/rf3
|
||||
build_data["force_include"] = {
|
||||
"releases/rf3/src/rf3": "modelhub/models/rf3"
|
||||
}
|
||||
```
|
||||
|
||||
**In `pyproject.toml`**:
|
||||
```toml
|
||||
[tool.hatch.build.hooks.custom]
|
||||
path = "hatch_build.py"
|
||||
```
|
||||
|
||||
This way:
|
||||
- Development happens in `releases/rf3/`
|
||||
- Installation includes RF3 in the package
|
||||
- Users can `pip install modelforge[rf3]`
|
||||
|
||||
### Phase 6: Update Tests
|
||||
|
||||
#### Step 6.1: Update Test Imports
|
||||
|
||||
**In `releases/rf3/tests/`**:
|
||||
```python
|
||||
# Update conftest.py and test files
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
```
|
||||
|
||||
**In root `tests/`**:
|
||||
```python
|
||||
# These already test shared utilities
|
||||
from modelhub.utils.weights import load_checkpoint
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
```
|
||||
|
||||
#### Step 6.2: Verify Test Execution
|
||||
|
||||
```bash
|
||||
# Test shared utilities
|
||||
pytest tests/
|
||||
|
||||
# Test RF3 (from releases/rf3/)
|
||||
cd releases/rf3/
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
### Phase 7: Update Documentation
|
||||
|
||||
#### Step 7.1: Update CLAUDE.md
|
||||
|
||||
Update import examples and package structure documentation.
|
||||
|
||||
#### Step 7.2: Update README.md
|
||||
|
||||
```markdown
|
||||
# ModelForge
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base framework only
|
||||
pip install modelforge
|
||||
|
||||
# With RF3
|
||||
pip install modelforge[rf3]
|
||||
|
||||
# Development
|
||||
pip install -e .[rf3,dev]
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Import base framework
|
||||
from modelforge.callbacks import BaseCallback
|
||||
from modelforge.metrics import Metric
|
||||
|
||||
# Import RF3
|
||||
from modelforge.models.rf3 import RF3
|
||||
```
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Migration Checklist
|
||||
|
||||
### Pre-Migration
|
||||
- [ ] Backup current code
|
||||
- [ ] Create feature branch: `git checkout -b refactor/restructure-package`
|
||||
- [ ] Document current import patterns for reference
|
||||
|
||||
### Phase 1: Package Rename
|
||||
- [ ] Decision: Keep `modelhub` or rename to `modelforge`?
|
||||
- [ ] Update `pyproject.toml` name field
|
||||
- [ ] Update all imports if renaming
|
||||
|
||||
### Phase 2: Restructure src/
|
||||
- [ ] Rename `callbacks/base.py` → `callbacks/callback.py`
|
||||
- [ ] Rename `metrics/base.py` → `metrics/metric.py`
|
||||
- [ ] Create/update all `__init__.py` files with proper exports
|
||||
- [ ] Update internal imports within src/modelhub/
|
||||
|
||||
### Phase 3: Fix RF3 Imports
|
||||
- [ ] Find all `from metrics.base` imports in releases/rf3/
|
||||
- [ ] Replace with `from modelhub.metrics.metric`
|
||||
- [ ] Find all `from callbacks.base` imports
|
||||
- [ ] Replace with `from modelhub.callbacks.callback`
|
||||
- [ ] Update utility imports to use `modelhub.utils.*`
|
||||
- [ ] Test that RF3 can import all needed modules
|
||||
|
||||
### Phase 4: Fix Build Config
|
||||
- [ ] Update `pyproject.toml` project name
|
||||
- [ ] Split dependencies: core vs rf3-specific
|
||||
- [ ] Add `[project.optional-dependencies]` for rf3
|
||||
- [ ] Fix `[project.scripts]` CLI entry point
|
||||
- [ ] Update `[tool.hatch.build.targets.wheel]` packages
|
||||
- [ ] Add build hooks if using hybrid approach
|
||||
|
||||
### Phase 5: RF3 Integration
|
||||
- [ ] Decide on integration approach (A, B, or C)
|
||||
- [ ] Implement chosen approach
|
||||
- [ ] Test `pip install -e .` (base only)
|
||||
- [ ] Test `pip install -e .[rf3]` (with RF3)
|
||||
|
||||
### Phase 6: Update Tests
|
||||
- [ ] Update test imports
|
||||
- [ ] Run shared tests: `pytest tests/`
|
||||
- [ ] Run RF3 tests: `cd releases/rf3 && pytest tests/`
|
||||
- [ ] Fix any import errors
|
||||
|
||||
### Phase 7: Documentation
|
||||
- [ ] Update CLAUDE.md
|
||||
- [ ] Update README.md
|
||||
- [ ] Update any other documentation
|
||||
- [ ] Add migration notes for contributors
|
||||
|
||||
### Phase 8: Validation
|
||||
- [ ] Clean install test: `pip install -e .`
|
||||
- [ ] RF3 install test: `pip install -e .[rf3]`
|
||||
- [ ] Test CLI: `rf3 fold inputs=...`
|
||||
- [ ] Run full test suite
|
||||
- [ ] Test on clean environment
|
||||
|
||||
---
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Test Scenarios
|
||||
|
||||
#### Scenario 1: Base Install Only
|
||||
```bash
|
||||
# Clean environment
|
||||
python -m venv test-env
|
||||
source test-env/bin/activate
|
||||
pip install -e .
|
||||
|
||||
# Should work
|
||||
python -c "from modelhub.callbacks import BaseCallback; print('OK')"
|
||||
python -c "from modelhub.metrics import Metric; print('OK')"
|
||||
python -c "from modelhub.utils.ddp import RankedLogger; print('OK')"
|
||||
|
||||
# Should fail (RF3 not installed)
|
||||
python -c "from modelhub.models.rf3 import RF3" # ImportError expected
|
||||
```
|
||||
|
||||
#### Scenario 2: With RF3
|
||||
```bash
|
||||
# Clean environment
|
||||
python -m venv test-env
|
||||
source test-env/bin/activate
|
||||
pip install -e .[rf3]
|
||||
|
||||
# Should work
|
||||
python -c "from modelhub.models.rf3 import RF3; print('OK')"
|
||||
python -c "from rf3.cli import app; print('OK')"
|
||||
|
||||
# Test CLI
|
||||
rf3 fold inputs='releases/rf3/tests/data/5vht_from_json.json'
|
||||
```
|
||||
|
||||
#### Scenario 3: Development Workflow
|
||||
```bash
|
||||
# Install in development mode with RF3
|
||||
pip install -e .[rf3,dev]
|
||||
|
||||
# Make changes to releases/rf3/src/rf3/model/RF3.py
|
||||
# Changes should be immediately available
|
||||
|
||||
python -c "from modelhub.models.rf3.model import RF3" # Should see changes
|
||||
|
||||
# Run tests
|
||||
pytest releases/rf3/tests/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Risk Assessment
|
||||
|
||||
### High Risk Items
|
||||
1. **Breaking all existing imports** - Requires updating many files
|
||||
- Mitigation: Use automated search/replace, test thoroughly
|
||||
|
||||
2. **Build configuration complexity** - Hatch build hooks can be tricky
|
||||
- Mitigation: Start with simple approach, test installation frequently
|
||||
|
||||
3. **Package name change** - Breaking change for any external users
|
||||
- Mitigation: Coordinate with team, announce breaking change
|
||||
|
||||
### Medium Risk Items
|
||||
1. **CLI entry point changes** - Users may have scripts using old CLI
|
||||
- Mitigation: Keep backward-compatible entry points
|
||||
|
||||
2. **Test imports** - Many test files to update
|
||||
- Mitigation: Automated replacement, run tests incrementally
|
||||
|
||||
### Low Risk Items
|
||||
1. **Documentation updates** - Time-consuming but low risk
|
||||
2. **`__init__.py` exports** - Easy to fix if wrong
|
||||
|
||||
---
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If migration fails:
|
||||
1. `git checkout main` - Revert to pre-migration state
|
||||
2. Keep feature branch for later attempt
|
||||
3. Document what failed for next iteration
|
||||
|
||||
---
|
||||
|
||||
## Future Considerations
|
||||
|
||||
### Adding New Models
|
||||
Once restructured, adding a new model is straightforward:
|
||||
|
||||
```
|
||||
releases/
|
||||
├── rf3/ # Existing
|
||||
└── proteinmpnn/ # New model
|
||||
├── src/proteinmpnn/
|
||||
├── configs/
|
||||
└── tests/
|
||||
```
|
||||
|
||||
Then in `pyproject.toml`:
|
||||
```toml
|
||||
[project.optional-dependencies]
|
||||
proteinmpnn = [
|
||||
"proteinmpnn-specific-deps",
|
||||
]
|
||||
all = [
|
||||
"modelforge[rf3,proteinmpnn]",
|
||||
]
|
||||
```
|
||||
|
||||
### Splitting Into Separate Packages (Future)
|
||||
If models become too large, can later split:
|
||||
- `modelforge-core` - Base framework
|
||||
- `modelforge-rf3` - RF3 model
|
||||
- `modelforge-proteinmpnn` - ProteinMPNN model
|
||||
|
||||
But this is much more complex and not recommended initially.
|
||||
|
||||
---
|
||||
|
||||
## Questions to Resolve
|
||||
|
||||
1. **Package name**: Keep `modelhub` or rename to `modelforge`?
|
||||
- Recommendation: Rename to `modelforge`
|
||||
|
||||
2. **RF3 integration**: Which approach (A, B, or C)?
|
||||
- Recommendation: Option C (hybrid with build hooks)
|
||||
|
||||
3. **CLI structure**: Single entry point or keep separate?
|
||||
- Recommendation: Keep `rf3` command separate for now
|
||||
|
||||
4. **Base file naming**: Rename `base.py` to match content?
|
||||
- Recommendation: Yes, rename to `callback.py`, `metric.py`
|
||||
|
||||
5. **Version strategy**: Single version or per-model versions?
|
||||
- Recommendation: Single version (simpler)
|
||||
|
||||
---
|
||||
|
||||
## Success Criteria
|
||||
|
||||
- [ ] `pip install modelforge` works
|
||||
- [ ] `pip install modelforge[rf3]` works
|
||||
- [ ] All imports are unambiguous and correct
|
||||
- [ ] All tests pass
|
||||
- [ ] CLI works: `rf3 fold inputs=...`
|
||||
- [ ] Structure matches PyTorch Lightning patterns
|
||||
- [ ] Documentation is updated
|
||||
- [ ] No more `sys.path` manipulation needed
|
||||
|
||||
---
|
||||
|
||||
## Timeline Estimate
|
||||
|
||||
- Phase 1 (Rename decision): 1 hour
|
||||
- Phase 2 (Restructure src/): 2-3 hours
|
||||
- Phase 3 (Fix RF3 imports): 1-2 hours
|
||||
- Phase 4 (Build config): 2-3 hours
|
||||
- Phase 5 (RF3 integration): 2-4 hours
|
||||
- Phase 6 (Update tests): 1-2 hours
|
||||
- Phase 7 (Documentation): 1-2 hours
|
||||
- Phase 8 (Validation): 2-3 hours
|
||||
|
||||
**Total**: 12-20 hours of focused work
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
This restructuring will:
|
||||
1. ✅ Fix all broken imports
|
||||
2. ✅ Follow industry best practices (Lightning pattern)
|
||||
3. ✅ Enable proper pip installation
|
||||
4. ✅ Support selective model installation
|
||||
5. ✅ Maintain development workflow in `releases/`
|
||||
6. ✅ Provide clear, unambiguous imports
|
||||
7. ✅ Scale to future models
|
||||
|
||||
The key is following PyTorch Lightning's pattern: base classes live WITH their implementations, no separate `contrib/` or overly complex directory nesting.
|
||||
89
models/rf3/CONTAINER.md
Normal file
89
models/rf3/CONTAINER.md
Normal file
@@ -0,0 +1,89 @@
|
||||
# RF3 Apptainer Containers
|
||||
|
||||
This directory contains two Apptainer definition files for different use cases.
|
||||
|
||||
## Container Options
|
||||
|
||||
### 1. `rf3-standalone.def` - Standalone Container
|
||||
Contains a complete snapshot of the modelhub repository at build time. Use this for:
|
||||
- Production deployments
|
||||
- Reproducible inference with a fixed codebase
|
||||
- Running on systems where you don't have the repository
|
||||
|
||||
### 2. `rf3-dev.def` - Development Container
|
||||
Contains only Python dependencies (from `requirements.txt`). Use this for:
|
||||
- Active development and testing
|
||||
- Working with your local modelhub code
|
||||
- Debugging and modifying RF3 code
|
||||
|
||||
**Note**: Generate `requirements.txt` first using `uv pip compile pyproject.toml -o requirements.txt` before building this container.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Apptainer/Singularity installed
|
||||
- NVIDIA GPU with CUDA 12.1+ support
|
||||
- Sufficient disk space (~10GB per container)
|
||||
|
||||
## Building Containers
|
||||
|
||||
Build from the `models/rf3/` directory:
|
||||
|
||||
```bash
|
||||
cd models/rf3/
|
||||
|
||||
# Build standalone container (includes modelhub snapshot)
|
||||
apptainer build rf3-standalone.sif rf3-standalone.def
|
||||
|
||||
# Build development container (dependencies only)
|
||||
apptainer build rf3-dev.sif rf3-dev.def
|
||||
|
||||
# Or build with sandbox for debugging
|
||||
apptainer build --sandbox rf3_sandbox/ rf3-standalone.def
|
||||
```
|
||||
|
||||
## Using the Standalone Container
|
||||
|
||||
The standalone container has the full repository baked in.
|
||||
|
||||
### Basic Inference
|
||||
|
||||
```bash
|
||||
# Run inference on a single input
|
||||
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='input.json'
|
||||
|
||||
# Process CIF/PDB files
|
||||
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='structure.cif'
|
||||
|
||||
# Batch processing
|
||||
apptainer exec --nv rf3-standalone.sif rf3 fold inputs='[file1.cif,file2.json,file3.pdb]'
|
||||
```
|
||||
|
||||
### With Custom Weights
|
||||
|
||||
```bash
|
||||
# Mount weights directory and specify checkpoint
|
||||
apptainer exec --nv \
|
||||
--bind /path/to/weights:/weights \
|
||||
rf3-standalone.sif \
|
||||
rf3 fold inputs='input.json' ckpt_path='/weights/rf3_latest.pt'
|
||||
```
|
||||
|
||||
## Using the Development Container
|
||||
|
||||
The development container requires mounting your local modelhub repository.
|
||||
|
||||
### Basic Usage
|
||||
|
||||
```bash
|
||||
# Run with local modelhub repository
|
||||
apptainer exec --nv \
|
||||
--bind /path/to/modelhub:/opt/modelhub \
|
||||
rf3-dev.sif \
|
||||
rf3 fold inputs='input.json'
|
||||
|
||||
# Example with actual path
|
||||
apptainer exec --nv \
|
||||
--bind $PWD/../..:/opt/modelhub \
|
||||
rf3-dev.sif \
|
||||
rf3 fold inputs='input.json'
|
||||
```
|
||||
@@ -1,7 +1,7 @@
|
||||
# Inference with RosettaFold3(RF3)
|
||||
|
||||
<div align="center">
|
||||
<img src="./docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
|
||||
<img src="../../docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
|
||||
</div>
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -21,9 +21,12 @@ git clone https://github.com/RosettaCommons/modelforge.git \
|
||||
&& uv python install 3.12 \
|
||||
&& uv venv --python 3.12 \
|
||||
&& source .venv/bin/activate \
|
||||
&& uv pip install -e .
|
||||
&& uv pip install -e ./models/rf3
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Installing `rf3` automatically installs `modelhub` (shared utilities) as a dependency. For development on both packages, use: `uv pip install -e . -e ./models/rf3`
|
||||
|
||||
### B. Download model weights for RF3
|
||||
```bash
|
||||
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
|
||||
@@ -67,7 +70,7 @@ For this example, the pTM in the `metrics.csv` should be `>0.8` (even without an
|
||||
|
||||
RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommended. We do not at the moment support pre-paired MSAs (we will pair on-the-fly) or on-the-fly MSA computation, but both are on the roadmap. Please raise an issue if these limitations are critical for your project and we can prioritize accordingly.
|
||||
|
||||
📝 **Example JSON configuration** (full example found at `docs/rf3/examples/3en2_from_json_with_msa.json`):
|
||||
📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/3en2_from_json_with_msa.json`):
|
||||
|
||||
```json
|
||||
{
|
||||
@@ -85,7 +88,7 @@ RF3 supports `.a3m` and `.fasta` files as input MSA formats; `.a3m` is recommend
|
||||
🚀 **Run the example:**
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json'
|
||||
```
|
||||
|
||||
---
|
||||
@@ -93,12 +96,12 @@ rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json'
|
||||
If performing inference from a prepared `.cif` file, MSAs can also be specified directly as a category within the raw CIF data.
|
||||
We will automatically extract the correct MSA paths during parsing.
|
||||
|
||||
📝 **Example CIF header** (full example found at `docs/rf3/examples/3en2_from_file.cif`):
|
||||
📝 **Example CIF header** (full example found at `../../docs/releases/rf3/examples/3en2_from_file.cif`):
|
||||
```cif
|
||||
data_3EN2
|
||||
#
|
||||
_msa_paths_by_chain_id.A docs/rf3/examples/msas/b3a35202064.a3m.gz
|
||||
_msa_paths_by_chain_id.B docs/rf3/examples/msas/b3a35202064.a3m.gz
|
||||
_msa_paths_by_chain_id.A ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz
|
||||
_msa_paths_by_chain_id.B ../../docs/releases/rf3/examples/msas/b3a35202064.a3m.gz
|
||||
#
|
||||
```
|
||||
|
||||
@@ -112,7 +115,7 @@ rf3 fold inputs='docs/rf3/3en2_from_file.cif
|
||||
> Without an MSA and using default settings, the above examples will trigger "early stopping." This means that if the model determines early on that a correct prediction is unlikely, it will stop computation and only output a `metrics.csv` and `.score` file to save compute resources. You can adjust this behavior using the `early_stopping_plddt_threshold` argument (see below). In our group, we find this argument can save wasted compute on erroneous inputs.
|
||||
|
||||
> [!TIP]
|
||||
> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='docs/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs.
|
||||
> To ensure that a provided MSA is loaded correctly, you may use the `raise_if_missing_msa_for_protein_of_length_n` command-line argument. For example, `rf3 fold inputs='../../docs/releases/rf3/examples/3en2_from_json_with_msa.json' raise_if_missing_msa_for_protein_of_length_n=10` would raise an error if there were any proteins >=10 residues without compatible MSAs.
|
||||
|
||||
> [!TIP]
|
||||
> For non-canonical amino acids, most MSA generation algorithms substitute `X` (unknown residue)! Ensure your MSAs adhere to this convention.
|
||||
@@ -138,7 +141,7 @@ We will automatically distribute predictions across GPU's if running in a multi-
|
||||
|
||||
### 1️⃣ **Single JSON with Multiple Examples**
|
||||
|
||||
📝 **Example JSON configuration** (full example found at `docs/rf3/examples/multiple_example_from_json.json`)
|
||||
📝 **Example JSON configuration** (full example found at `../../docs/releases/rf3/examples/multiple_example_from_json.json`)
|
||||
|
||||
```json
|
||||
[
|
||||
@@ -172,7 +175,7 @@ We will automatically distribute predictions across GPU's if running in a multi-
|
||||
🚀 **Run the example:**
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/multiple_examples_from_json.json'
|
||||
```
|
||||
|
||||
---
|
||||
@@ -182,7 +185,7 @@ rf3 fold inputs='docs/rf3/examples/multiple_examples_from_json.json'
|
||||
You can specify multiple files/directories using Hydra's list syntax:
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='[docs/rf3/examples/701r_from_file.cif, docs/rf3/examples/701r_from_json.json]'
|
||||
rf3 fold inputs='[../../docs/releases/rf3/examples/701r_from_file.cif, ../../docs/releases/rf3/examples/701r_from_json.json]'
|
||||
```
|
||||
|
||||
---
|
||||
@@ -211,7 +214,7 @@ Complex assemblies containing arbitrary biomolecules can be easily folded if pre
|
||||
|
||||
🚀 **Run an example from a prepared CIF file:**
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif'
|
||||
```
|
||||
|
||||
Such files (including all bonds, covalent modifications, non-canonical amino acids, etc.) can be created either (a) directly from ProteinMPNN/LigandMPNN or other software that generates structural files; or, (b) assembled with [AtomWorks](https://github.com/RosettaCommons/atomworks) or another CIF-processing library.
|
||||
@@ -221,7 +224,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
|
||||
> [!TIP]
|
||||
> **Performance Tip**: For small molecules, a general rule-of-thumb is that performance is best when using `CCD` codes directly, followed by `cif`/`sdf` files, and finally SMILES.
|
||||
|
||||
📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `docs/rf3/examples/7o1r_from_json.json`):
|
||||
📝 **Example JSON configuration with arbitrary biomolecules** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`):
|
||||
```json
|
||||
[
|
||||
{
|
||||
@@ -229,7 +232,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
|
||||
"components": [
|
||||
{
|
||||
"seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH",
|
||||
"msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz",
|
||||
"msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz",
|
||||
"chain_id": "A"
|
||||
},
|
||||
{
|
||||
@@ -242,7 +245,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
|
||||
{
|
||||
// We provide the heme (HEME) via SDF from the CCD; we could have also used a CIF file
|
||||
// We will automatically name the atoms (SDF files do not specify atom names)
|
||||
"path": "docs/rf3/examples/ligands/HEM.sdf"
|
||||
"path": "../../docs/releases/rf3/examples/ligands/HEM.sdf"
|
||||
},
|
||||
{
|
||||
// We provide the imidazole ring (IMD) via SMILES
|
||||
@@ -257,7 +260,7 @@ For convenience, we also support a `json` API analogous to that implemented by A
|
||||
🚀 **Run the example:**
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/7o1r_from_json.json'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_json.json'
|
||||
```
|
||||
|
||||
**Supported input options:**
|
||||
@@ -278,11 +281,11 @@ As described in the example [Folding with Arbitrary Biomolecules](#folding-with-
|
||||
|
||||
For example, folding `7o1r`, which contains two N-glycosylations:
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/7o1r_from_file.cif'
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="./docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
|
||||
<img src="../../docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<em>Figure: `7o1r` structure showing N-glycosylation covalent modification prediction with RF3 and ground truth crystal structure.</em>
|
||||
@@ -292,7 +295,7 @@ Such `.cif` files complete with appropriate bonds can be composed with AtomWorks
|
||||
|
||||
If you would prefer to use the JSON API, bonds can be explicitly given using PyMol-like strings of the form `chain_id/res_name/res_id/atom_name`. You will need to know the specific chain ID, residue name, residue ID, and atom name between the relevant pairs of atoms to unambiguously specify the bond.
|
||||
|
||||
📝 **Example JSON configuration with covalent modifcations** (full example found at `docs/rf3/examples/7o1r_from_json.json`):
|
||||
📝 **Example JSON configuration with covalent modifcations** (full example found at `../../docs/releases/rf3/examples/7o1r_from_json.json`):
|
||||
```json
|
||||
[
|
||||
{
|
||||
@@ -300,7 +303,7 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM
|
||||
"components": [
|
||||
{
|
||||
"seq": "MKSLSFSLALGFGSTLVYSAPSPSSGWQAPGPNDVRAPCPMLNTLANHGFLPHDGKGITVNKTIDALGSALNIDANLSTLLFGFAATTNPQPNATFFDLDHLSRHNILEHDASLSRQDSYFGPADVFNEAVFNQTKSFWTGDIIDVQMAANARIVRLLTSNLTNPEYSLSDLGSAFSIGESAAYIGILGDKKSATVPKSWVEYLFENERLPYELGFKRPNDPFTTDDLGDLSTQIINAQHFPQSPGKVEKRGDTRCPYGYH",
|
||||
"msa_path": "docs/rf3/examples/msas/7o1r_A.a3m.gz",
|
||||
"msa_path": "../../docs/releases/rf3/examples/msas/7o1r_A.a3m.gz",
|
||||
"chain_id": "A"
|
||||
},
|
||||
{
|
||||
@@ -308,10 +311,10 @@ If you would prefer to use the JSON API, bonds can be explicitly given using PyM
|
||||
},
|
||||
{
|
||||
// We provide one sugar via a CIF file, with complete control over bonds and atom names (as we use the atom names from the CIF file)
|
||||
"path": "docs/rf3/examples/ligands/NAG.cif"
|
||||
"path": "../../docs/releases/rf3/examples/ligands/NAG.cif"
|
||||
},
|
||||
{
|
||||
"path": "docs/rf3/examples/ligands/HEM.sdf"
|
||||
"path": "../../docs/releases/rf3/examples/ligands/HEM.sdf"
|
||||
},
|
||||
{
|
||||
"smiles": "[nH]1cc[nH+]c1"
|
||||
@@ -351,7 +354,7 @@ from atomworks.io.utils.visualize import view
|
||||
import numpy as np
|
||||
|
||||
# Load an SDF file into an AtomArray
|
||||
sdf_path = "docs/rf3/examples/ligands/NAG.sdf"
|
||||
sdf_path = "../../docs/releases/rf3/examples/ligands/NAG.sdf"
|
||||
|
||||
# Load into an AtomArray
|
||||
# Since SDF file files do not have atom names, we automatically generate them (e.g., C1, C2, C3, etc.)
|
||||
@@ -448,17 +451,17 @@ RF3 uses AtomWorks' flexible `AtomSelectionStack` query syntax for specifying st
|
||||
|
||||
It is often helpful to template one or multiple polymer chains while allowing the other chain(s) to fold unconstrained. We demonstrate with an nanobody-antigen use case below how to apply templates.
|
||||
|
||||
📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `docs/rf3/examples/7xli_template_antigen_and_framework.json`):
|
||||
📝 **Example JSON configuration templating the antigen and the nanobody framework** (full example found at `../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json`):
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "7xli_template_antigen",
|
||||
"components": [
|
||||
{
|
||||
"path": "docs/rf3/examples/templates/7xli_chain_A.cif"
|
||||
"path": "../../docs/releases/rf3/examples/templates/7xli_chain_A.cif"
|
||||
},
|
||||
{
|
||||
"path": "docs/rf3/examples/templates/7xli_chain_B.cif"
|
||||
"path": "../../docs/releases/rf3/examples/templates/7xli_chain_B.cif"
|
||||
}
|
||||
],
|
||||
"template_selection": ["A", "B/*/1-42", "B/*/49-63", "B/*/71-102", "B/*/108-125"]
|
||||
@@ -469,7 +472,7 @@ It is often helpful to template one or multiple polymer chains while allowing th
|
||||
🚀 **Run the example:**
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/7xli_template_antigen_and_framework.json'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/7xli_template_antigen_and_framework.json'
|
||||
```
|
||||
|
||||
You may also specify templating directly via the CLI using `template_selection="[A, B/*/1-42, ...]"`.
|
||||
@@ -483,14 +486,14 @@ We find that enforcing a particular small molecule conformation has various appl
|
||||
|
||||
For the moment, the ground truth conformer track is only effective if we want to template the *entire* small molecule. Partial templating of small molecules is still possible via the `template_selection` approach. We encourage exploration of both templating techniques to find what combination(s) are most effective for a given problem. Below we provide both, which represents the strongest possible conditioning.
|
||||
|
||||
📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `docs/rf3/examples/1eiz_template_ligand_and_protein.json`):
|
||||
📝 **Example JSON configuration templating a small molecule and the corresponding protein** (full example found at `../../docs/releases/rf3/examples/1eiz_template_ligand_and_protein.json`):
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "9dfn_template_ligand_and_protein",
|
||||
"components": [
|
||||
{
|
||||
"path": "docs/rf3/examples/9dfn.cif"
|
||||
"path": "../../docs/releases/rf3/examples/9dfn.cif"
|
||||
}
|
||||
],
|
||||
"template_selection": ["A", "C", "D"],
|
||||
@@ -506,7 +509,7 @@ For the moment, the ground truth conformer track is only effective if we want to
|
||||
🚀 **Run the example:**
|
||||
|
||||
```bash
|
||||
rf3 fold inputs='docs/rf3/examples/8cdz_templating_ligand.json'
|
||||
rf3 fold inputs='../../docs/releases/rf3/examples/8cdz_templating_ligand.json'
|
||||
```
|
||||
|
||||
You may also specify the ground truth conformer selection directly via the CLI, e.g., using `ground_truth_conformer_selection="[E]"`
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
dump_validation_structures_callback:
|
||||
_target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback
|
||||
_target_: rf3.callbacks.dump_validation_structures.DumpValidationStructuresCallback
|
||||
save_dir: ${paths.output_dir}/val_structures
|
||||
dump_predictions: True
|
||||
one_model_per_file: False
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
store_validation_metrics_in_df_callback:
|
||||
_target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
|
||||
_target_: rf3.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
|
||||
save_dir: ${paths.output_dir}/val_metrics
|
||||
metrics_to_save: "all"
|
||||
|
||||
log_af3_validation_metrics_callback:
|
||||
_target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
|
||||
_target_: rf3.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
|
||||
# Only logs if present in the metric output dictionary
|
||||
# Must be subset of metrics_to_save
|
||||
metrics_to_log: "all"
|
||||
@@ -17,7 +17,7 @@ defaults:
|
||||
- _self_
|
||||
|
||||
# Dataloading pipeline to use
|
||||
pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline
|
||||
pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline
|
||||
|
||||
# Dataset weighting
|
||||
train:
|
||||
|
||||
@@ -9,7 +9,7 @@ defaults:
|
||||
- _self_
|
||||
|
||||
# Dataloading pipeline to use
|
||||
pipeline_target: modelhub.data.pipelines.build_af3_transform_pipeline
|
||||
pipeline_target: rf3.data.pipelines.build_af3_transform_pipeline
|
||||
|
||||
# Dataset weighting
|
||||
train:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
disorder_distillation:
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: null
|
||||
|
||||
# cif parser arguments
|
||||
@@ -13,19 +13,18 @@ disorder_distillation:
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.GenericDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.GenericDFParser
|
||||
pn_unit_iid_colnames: null
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: pdb_diso_distillation
|
||||
id_column: example_id
|
||||
data: ${paths.data.disorder_distill_parquet_dir}/disorderDistillation.csv
|
||||
columns_to_load:
|
||||
- example_id
|
||||
- path
|
||||
return_key: null
|
||||
transform:
|
||||
_target_: ${datasets.pipeline_target}
|
||||
is_inference: False
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
multidomain_distillation:
|
||||
dataset:
|
||||
_target_: modelhub.data.paired_msa.MultiInputDatasetWrapper
|
||||
_target_: rf3.data.paired_msa.MultiInputDatasetWrapper
|
||||
save_failed_examples_to_dir: null
|
||||
|
||||
# cif parser
|
||||
@@ -14,11 +14,11 @@ multidomain_distillation:
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: modelhub.data.paired_msa.MultidomainDFParser
|
||||
_target_: rf3.data.paired_msa.MultidomainDFParser
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: multidomain_distillation
|
||||
id_column: example_id
|
||||
data: /projects/ml/datahub/dfs/domain_domain/domain_domain_dataset.DIGS.parquet
|
||||
@@ -26,7 +26,6 @@ multidomain_distillation:
|
||||
- example_id
|
||||
- pdb_path
|
||||
- msa_path
|
||||
return_key: null
|
||||
transform:
|
||||
_target_: ${datasets.pipeline_target}
|
||||
is_inference: False
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
monomer_distillation:
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||
|
||||
# cif parser arguments
|
||||
@@ -13,19 +13,18 @@ monomer_distillation:
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.GenericDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.GenericDFParser
|
||||
pn_unit_iid_colnames: null
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: af2fb_distillation
|
||||
id_column: example_id
|
||||
data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet
|
||||
columns_to_load:
|
||||
- example_id
|
||||
- path
|
||||
return_key: null
|
||||
transform:
|
||||
_target_: ${datasets.pipeline_target}
|
||||
is_inference: False
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
na_complex_distillation:
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: null
|
||||
|
||||
# cif parser
|
||||
@@ -14,19 +14,18 @@ na_complex_distillation:
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.GenericDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.GenericDFParser
|
||||
pn_unit_iid_colnames: null #[]
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: tf_distillation
|
||||
id_column: example_id
|
||||
data: ${paths.data.na_complex_distillation_parquet_dir}/transcriptionFactor_distillation_rf3.newDL.csv
|
||||
columns_to_load:
|
||||
- example_id
|
||||
- path
|
||||
return_key: null
|
||||
transform:
|
||||
_target_: ${datasets.pipeline_target}
|
||||
is_inference: False
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
weights:
|
||||
_target_: datahub.samplers.calculate_weights_for_pdb_dataset_df
|
||||
_target_: atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df
|
||||
# We do not include beta here, since it is different for interfaces and chains
|
||||
alphas:
|
||||
a_prot: 3.0 # 3 for AF-3
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||
cif_parser_args:
|
||||
cache_dir: null
|
||||
load_from_cache: false
|
||||
save_to_cache: false
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
# we will use the example_id as the unique column
|
||||
id_column: example_id
|
||||
# return all keys (do not subset)
|
||||
return_key: null
|
||||
transform:
|
||||
# common Transform pipeline components for all PDB datasets
|
||||
_target_: ${datasets.pipeline_target}
|
||||
|
||||
@@ -5,7 +5,7 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.InterfacesDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
|
||||
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
dataset:
|
||||
name: plinder
|
||||
@@ -50,5 +50,5 @@ dataset:
|
||||
crop_spatial_probability: 1.0
|
||||
|
||||
weights:
|
||||
_target_: datahub.samplers.calculate_weights_by_inverse_cluster_size
|
||||
_target_: atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size
|
||||
cluster_column: pli_qcov__50__weak__component # Need to ablate
|
||||
@@ -4,7 +4,7 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.InterfacesDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
|
||||
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
dataset:
|
||||
name: interface
|
||||
|
||||
@@ -4,7 +4,7 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.PNUnitsDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.PNUnitsDFParser
|
||||
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
dataset:
|
||||
name: pn_unit
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
rna_monomer_distillation:
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||
|
||||
# cif parser arguments
|
||||
@@ -13,12 +13,12 @@ rna_monomer_distillation:
|
||||
|
||||
# metadata parser
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.GenericDFParser
|
||||
_target_: atomworks.ml.datasets.parsers.GenericDFParser
|
||||
pn_unit_iid_colnames: null
|
||||
|
||||
# metadata dataset
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: rna_monomer_distillation
|
||||
id_column: example_id
|
||||
data: /projects/ml/afavor/rna_distillation/rna_distillation_filtered_df.parquet
|
||||
@@ -31,7 +31,6 @@ rna_monomer_distillation:
|
||||
- overall_pde
|
||||
- overall_pae
|
||||
|
||||
return_key: null
|
||||
transform:
|
||||
_target_: ${datasets.pipeline_target}
|
||||
is_inference: False
|
||||
|
||||
@@ -3,8 +3,9 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
|
||||
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
|
||||
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: af3_validation
|
||||
data: /net/scratch/rib7/rf3_ab_splits/entry_level_val_df.parquet
|
||||
|
||||
@@ -3,8 +3,9 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
|
||||
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
|
||||
base_dir: /projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: af3_validation
|
||||
data: ${paths.data.pdb_data_dir}/entry_level_val_df.parquet
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
|
||||
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||
cif_parser_args:
|
||||
cache_dir: null
|
||||
load_from_cache: False
|
||||
save_to_cache: False
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
# we will use the example_id as the unique column
|
||||
id_column: example_id
|
||||
# return all keys (do not subset)
|
||||
return_key: null
|
||||
transform:
|
||||
# common Transform pipeline components for all PDB datasets
|
||||
_target_: ${datasets.pipeline_target}
|
||||
|
||||
@@ -3,9 +3,10 @@ defaults:
|
||||
|
||||
dataset:
|
||||
dataset_parser:
|
||||
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
|
||||
_target_: atomworks.ml.datasets.parsers.ValidationDFParserLikeAF3
|
||||
dataset:
|
||||
_target_: datahub.datasets.datasets.PandasDataset
|
||||
_target_: atomworks.ml.datasets.datasets.PandasDataset
|
||||
name: af3_validation
|
||||
data: /projects/ml/datahub/dfs/af3_splits/2024_12_16/runs_n_poses_entry_level_df.parquet
|
||||
filters:
|
||||
- "n_tokens_total < 1000" # Subset to reasonably-sized examples for efficiency
|
||||
|
||||
@@ -5,7 +5,7 @@ name: rf3
|
||||
defaults:
|
||||
- override /datasets: pdb_and_distillation
|
||||
- override /model: rf3
|
||||
- override /trainer: af3
|
||||
- override /trainer: rf3
|
||||
|
||||
ckpt_config:
|
||||
_target_: modelhub.utils.weights.CheckpointConfig
|
||||
|
||||
@@ -6,7 +6,7 @@ name: rf3-with-confidence
|
||||
defaults:
|
||||
- rf3
|
||||
- override /model: rf3_with_confidence
|
||||
- override /trainer: af3_with_confidence
|
||||
- override /trainer: rf3_with_confidence
|
||||
- _self_
|
||||
|
||||
datasets:
|
||||
|
||||
@@ -8,7 +8,7 @@ name: quick-rf3-with-confidence
|
||||
defaults:
|
||||
- quick-rf3
|
||||
- override /model: rf3_with_confidence
|
||||
- override /trainer: af3_with_confidence
|
||||
- override /trainer: rf3_with_confidence
|
||||
- _self_
|
||||
|
||||
datasets:
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# Experiment that loads a small dataset for quick testing
|
||||
|
||||
name: quick-af3
|
||||
name: quick-rf3
|
||||
|
||||
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||
defaults:
|
||||
|
||||
@@ -4,9 +4,9 @@ defaults:
|
||||
- base
|
||||
- _self_
|
||||
|
||||
_target_: modelhub.inference_engines.rf3.RF3InferenceEngine
|
||||
_target_: rf3.inference_engines.rf3.RF3InferenceEngine
|
||||
|
||||
ckpt_path: /software/containers/versions/modelhub_inference/ckpts/rf3_latest.ckpt
|
||||
ckpt_path: /home/ncorley/rf3-w-conf-newdatecut-ep688-refactored-fixed.ckpt
|
||||
|
||||
# Transform arguments
|
||||
n_recycles: 10
|
||||
@@ -36,8 +36,8 @@ annotate_b_factor_with_plddt: true
|
||||
metrics_cfg:
|
||||
# (For confidence outputs)
|
||||
ptm:
|
||||
_target_: modelhub.metrics.predicted_error.ComputePTM
|
||||
_target_: rf3.metrics.predicted_error.ComputePTM
|
||||
iptm:
|
||||
_target_: modelhub.metrics.predicted_error.ComputeIPTM
|
||||
_target_: rf3.metrics.predicted_error.ComputeIPTM
|
||||
count_clashing_chains:
|
||||
_target_: modelhub.metrics.clashing_chains.CountClashingChains
|
||||
_target_: rf3.metrics.clashing_chains.CountClashingChains
|
||||
@@ -1,5 +1,5 @@
|
||||
# Model architecture
|
||||
_target_: modelhub.model.RF3.RF3
|
||||
_target_: rf3.model.RF3.RF3
|
||||
|
||||
# +---------- Channel dimensions ----------+
|
||||
c_s: 384
|
||||
|
||||
@@ -2,7 +2,7 @@ defaults:
|
||||
- rf3_net
|
||||
|
||||
# Model architecture
|
||||
_target_: modelhub.model.RF3.RF3WithConfidence
|
||||
_target_: rf3.model.RF3.RF3WithConfidence
|
||||
|
||||
# +---------- Mini rollout sampler ----------+
|
||||
# From the AF-3 main text:
|
||||
|
||||
@@ -4,4 +4,4 @@ defaults:
|
||||
- _self_
|
||||
|
||||
net:
|
||||
_target_: modelhub.model.RF3.RF3WithConfidence
|
||||
_target_: rf3.model.RF3.RF3WithConfidence
|
||||
@@ -1,5 +1,5 @@
|
||||
# Learning rate scheduler
|
||||
_target_: modelhub.training.schedulers.AF3Scheduler
|
||||
_target_: rf3.training.schedulers.AF3Scheduler
|
||||
base_lr: 1.8e-3
|
||||
warmup_steps: 1000
|
||||
decay_factor: 0.95
|
||||
|
||||
@@ -3,30 +3,37 @@
|
||||
########################
|
||||
|
||||
# path to directory with training splits
|
||||
pdb_data_dir: ???
|
||||
pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2025_07_13
|
||||
|
||||
# fb monomer distillation dataset
|
||||
monomer_distillation_data_dir: ???
|
||||
monomer_distillation_parquet_dir: ???
|
||||
monomer_distillation_data_dir: /squash/af2_distillation_facebook
|
||||
monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
|
||||
|
||||
mgnify_distillation_data_dir: /squash/mgnify_distill_rf3/
|
||||
mgnify_distillation_parquet_dir: /home/dimaio/MGnify/
|
||||
|
||||
# na complex distill set
|
||||
na_complex_distillation_data_dir: ???
|
||||
na_complex_distillation_parquet_dir: ???
|
||||
na_complex_distillation_data_dir: /projects/ml/prot_dna/rf3_newDL
|
||||
na_complex_distillation_parquet_dir: /projects/ml/prot_dna
|
||||
|
||||
# disorder distill set
|
||||
disorder_distill_parquet_dir: ???
|
||||
disorder_distill_parquet_dir: /projects/ml/disorder_distill
|
||||
|
||||
########################
|
||||
# MSAs
|
||||
########################
|
||||
|
||||
# path(s) to search for protein MSAs (for PDB datasets)
|
||||
# e.g., {"dir": "/path/to/msas", "extension": ".a3m.gz", "directory_depth": 2}
|
||||
protein_msa_dirs:
|
||||
- {"dir": ???, "extension": ".a3m.gz", "directory_depth": 2}
|
||||
- {"dir": "/projects/msa/hhblits", "extension": ".a3m.gz", "directory_depth": 2}
|
||||
- {"dir": "/projects/msa/mmseqs_gpu", "extension": ".a3m", "directory_depth": 2}
|
||||
- {"dir": "/projects/msa/lab", "extension": ".a3m", "directory_depth": 2}
|
||||
|
||||
# path(s) to search for RNA MSAs
|
||||
# e.g., {"dir": "/path/to/msas", "extension": ".afa", "directory_depth": 0}
|
||||
rna_msa_dirs:
|
||||
- {"dir": ???, "extension": ".afa", "directory_depth": 0}
|
||||
- {"dir": "/projects/msa/rna", "extension": ".afa", "directory_depth": 0}
|
||||
|
||||
########################
|
||||
# Misc
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
_target_: modelhub.loss.af3_confidence_loss.ConfidenceLoss
|
||||
_target_: rf3.loss.af3_confidence_loss.ConfidenceLoss
|
||||
weight: 1.0
|
||||
|
||||
plddt:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
_target_: modelhub.loss.af3_losses.DiffusionLoss
|
||||
_target_: rf3.loss.af3_losses.DiffusionLoss
|
||||
weight: 4.0
|
||||
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||
alpha_dna: 5
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
_target_: modelhub.loss.af3_losses.DistogramLoss
|
||||
_target_: rf3.loss.af3_losses.DistogramLoss
|
||||
weight: 3e-2
|
||||
@@ -1,14 +1,14 @@
|
||||
by_type_lddt:
|
||||
_target_: modelhub.metrics.lddt.ByTypeLDDT
|
||||
_target_: rf3.metrics.lddt.ByTypeLDDT
|
||||
all_atom_lddt:
|
||||
_target_: modelhub.metrics.lddt.AllAtomLDDT
|
||||
_target_: rf3.metrics.lddt.AllAtomLDDT
|
||||
distogram:
|
||||
_target_: modelhub.metrics.distogram.DistogramLoss
|
||||
_target_: rf3.metrics.distogram.DistogramLoss
|
||||
distogram_comparisons:
|
||||
_target_: modelhub.metrics.distogram.DistogramComparisons
|
||||
_target_: rf3.metrics.distogram.DistogramComparisons
|
||||
distogram_entropy:
|
||||
_target_: modelhub.metrics.distogram.DistogramEntropy
|
||||
_target_: rf3.metrics.distogram.DistogramEntropy
|
||||
chiral_loss:
|
||||
_target_: modelhub.metrics.chiral.ChiralLoss
|
||||
_target_: rf3.metrics.chiral.ChiralLoss
|
||||
unresolved_rasa:
|
||||
_target_: modelhub.metrics.rasa.UnresolvedRegionRASA
|
||||
_target_: rf3.metrics.rasa.UnresolvedRegionRASA
|
||||
|
||||
@@ -3,7 +3,7 @@ defaults:
|
||||
- loss: structure_prediction
|
||||
- metrics: structure_prediction
|
||||
|
||||
_target_: modelhub.trainers.rf3.RF3Trainer
|
||||
_target_: rf3.trainers.rf3.RF3Trainer
|
||||
validate_every_n_epochs: 1
|
||||
max_epochs: 10_000
|
||||
n_examples_per_epoch: 24000
|
||||
|
||||
@@ -2,12 +2,12 @@ defaults:
|
||||
- af3
|
||||
- override loss: structure_prediction_with_confidence
|
||||
|
||||
_target_: modelhub.trainers.rf3.RF3TrainerWithConfidence
|
||||
_target_: rf3.trainers.rf3.RF3TrainerWithConfidence
|
||||
|
||||
metrics:
|
||||
ptm:
|
||||
_target_: modelhub.metrics.predicted_error.ComputePTM
|
||||
_target_: rf3.metrics.predicted_error.ComputePTM
|
||||
iptm:
|
||||
_target_: modelhub.metrics.predicted_error.ComputeIPTM
|
||||
_target_: rf3.metrics.predicted_error.ComputeIPTM
|
||||
count_clashing_chains:
|
||||
_target_: modelhub.metrics.clashing_chains.CountClashingChains
|
||||
_target_: rf3.metrics.clashing_chains.CountClashingChains
|
||||
|
||||
67
models/rf3/pyproject.toml
Normal file
67
models/rf3/pyproject.toml
Normal file
@@ -0,0 +1,67 @@
|
||||
[project]
|
||||
name = "rf3"
|
||||
dynamic = ["version"]
|
||||
description = "RosettaFold3: Open-source biomolecular structure prediction for all molecules of life."
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
authors = [
|
||||
{ name = "Institute for Protein Design", email = "contact@ipd.uw.edu" },
|
||||
]
|
||||
license = { file = "../../LICENSE.md" }
|
||||
|
||||
classifiers = [
|
||||
"Intended Audience :: Developers",
|
||||
"Intended Audience :: Science/Research",
|
||||
"Natural Language :: English",
|
||||
"Operating System :: POSIX :: Linux",
|
||||
"Operating System :: MacOS",
|
||||
"Operating System :: Microsoft :: Windows",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: Implementation :: CPython",
|
||||
"Topic :: Scientific/Engineering :: Bio-Informatics",
|
||||
"License :: OSI Approved :: BSD License",
|
||||
]
|
||||
|
||||
dependencies = [
|
||||
# Core modelhub dependency
|
||||
"modelhub",
|
||||
# CLI
|
||||
"typer>=0.9.0,<1",
|
||||
# RF3-specific ML dependencies
|
||||
"einops>=0.8.0,<1",
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
# ... kernels (Linux only)
|
||||
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
|
||||
# ... dataloading
|
||||
"atomworks==1.0.2",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
rf3 = "rf3.cli:app"
|
||||
|
||||
# Build settings ----------------------------------------------------------------------
|
||||
[build-system]
|
||||
requires = [
|
||||
"hatchling",
|
||||
"hatch-vcs == 0.4",
|
||||
]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "vcs"
|
||||
|
||||
[tool.hatch.version.raw-options]
|
||||
root = "../.."
|
||||
|
||||
[tool.hatch.build.hooks.vcs]
|
||||
version-file = "src/rf3/_version.py"
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/rf3"]
|
||||
61
models/rf3/rf3-dev.def
Normal file
61
models/rf3/rf3-dev.def
Normal file
@@ -0,0 +1,61 @@
|
||||
Bootstrap: docker
|
||||
From: nvcr.io/nvidia/pytorch:25.04-py3
|
||||
IncludeCmd: yes
|
||||
|
||||
# RF3 Development Container - Dependencies only
|
||||
|
||||
%labels
|
||||
Author Institute for Protein Design
|
||||
Version 1.0-dev
|
||||
Description RosettaFold3 - Dependencies only (for development and model training)
|
||||
|
||||
%setup
|
||||
# Create a directory in the container to bind the host's current working directory
|
||||
mkdir ${APPTAINER_ROOTFS}/modelhub_host
|
||||
# ... for mounting `/projects` with --bind
|
||||
mkdir ${APPTAINER_ROOTFS}/projects
|
||||
# ... for mounting `/databases` with --bind
|
||||
mkdir ${APPTAINER_ROOTFS}/net
|
||||
# ... for mounting `/squash` with --bind
|
||||
mkdir ${APPTAINER_ROOTFS}/squash
|
||||
|
||||
%files
|
||||
/etc/localtime
|
||||
/etc/hosts
|
||||
pyproject.toml /opt/pyproject.toml
|
||||
|
||||
%post
|
||||
## GENERAL SETUP
|
||||
|
||||
# Common symlinks (within container)
|
||||
ln -s /net/databases /databases
|
||||
ln -s /net/software /software
|
||||
ln -s /home /mnt/home
|
||||
ln -s /projects /mnt/projects
|
||||
ln -s /net /mnt/net
|
||||
|
||||
## PYTHON DEPENDENCY INSTALLATION
|
||||
|
||||
# Install uv for fast package installation (10-100x faster than pip)
|
||||
pip install uv
|
||||
|
||||
# Auto-generate a dependency list
|
||||
uv pip compile /opt/pyproject.toml -o /opt/requirements.txt
|
||||
|
||||
# Install all Python dependencies using requirements.txt with uv
|
||||
uv pip install --system --break-system-packages -r /opt/requirements.txt
|
||||
|
||||
|
||||
%environment
|
||||
# (Flag to increase accessible GPU memory)
|
||||
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||
|
||||
# (Turn off NVLink)
|
||||
export NCCL_P2P_DISABLE=1
|
||||
|
||||
%runscript
|
||||
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
|
||||
exec python "$@"
|
||||
|
||||
%help
|
||||
RosettaFold3 (RF3) development container with dependencies only.
|
||||
3
models/rf3/src/rf3/__init__.py
Normal file
3
models/rf3/src/rf3/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
"""RF3 - RosettaFold3 model implementation."""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
34
models/rf3/src/rf3/_version.py
Normal file
34
models/rf3/src/rf3/_version.py
Normal file
@@ -0,0 +1,34 @@
|
||||
# file generated by setuptools-scm
|
||||
# don't change, don't track in version control
|
||||
|
||||
__all__ = [
|
||||
"__version__",
|
||||
"__version_tuple__",
|
||||
"version",
|
||||
"version_tuple",
|
||||
"__commit_id__",
|
||||
"commit_id",
|
||||
]
|
||||
|
||||
TYPE_CHECKING = False
|
||||
if TYPE_CHECKING:
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
VERSION_TUPLE = Tuple[Union[int, str], ...]
|
||||
COMMIT_ID = Union[str, None]
|
||||
else:
|
||||
VERSION_TUPLE = object
|
||||
COMMIT_ID = object
|
||||
|
||||
version: str
|
||||
__version__: str
|
||||
__version_tuple__: VERSION_TUPLE
|
||||
version_tuple: VERSION_TUPLE
|
||||
commit_id: COMMIT_ID
|
||||
__commit_id__: COMMIT_ID
|
||||
|
||||
__version__ = version = '0.1.dev917+gcbbe4c6a6.d20251001'
|
||||
__version_tuple__ = version_tuple = (0, 1, 'dev917', 'gcbbe4c6a6.d20251001')
|
||||
|
||||
__commit_id__ = commit_id = None
|
||||
@@ -1,10 +1,10 @@
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
|
||||
from atomworks.common import parse_example_id
|
||||
from atomworks.ml.example_id import parse_example_id
|
||||
from beartype.typing import Any
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from rf3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
dump_structures,
|
||||
|
||||
@@ -7,9 +7,9 @@ from atomworks.ml.utils import nested_dict
|
||||
from beartype.typing import Any, Literal
|
||||
from omegaconf import ListConfig
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from utils.ddp import RankedLogger
|
||||
from utils.logging import (
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import (
|
||||
condense_count_columns_of_grouped_df,
|
||||
print_df_as_table,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import typer
|
||||
from hydra import compose, initialize_config_dir
|
||||
@@ -13,9 +14,11 @@ app = typer.Typer()
|
||||
)
|
||||
def fold(ctx: typer.Context):
|
||||
"""Run structure prediction using hydra config overrides or simple input file."""
|
||||
config_path = os.path.join(
|
||||
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
|
||||
)
|
||||
# Find the RF3 configs directory relative to this file
|
||||
# This file is at: models/rf3/src/rf3/cli.py
|
||||
# Configs are at: models/rf3/configs/
|
||||
rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
|
||||
config_path = str(rf3_package_dir / "configs")
|
||||
|
||||
# Get all arguments
|
||||
args = ctx.params.get("args", []) + ctx.args
|
||||
|
||||
@@ -20,7 +20,7 @@ from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float, Shaped
|
||||
from torch import Tensor
|
||||
|
||||
from utils.torch import assert_no_nans
|
||||
from modelhub.utils.torch import assert_no_nans
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -125,11 +125,7 @@ class MultiInputDatasetWrapper(StructuralDatasetWrapper):
|
||||
)
|
||||
raise e
|
||||
|
||||
# Return the specified key or the entire data dict (i.e., only "feats" key from the Transform dictionary)
|
||||
if exists(self.return_key):
|
||||
return data[self.return_key]
|
||||
else:
|
||||
return data
|
||||
return data
|
||||
|
||||
|
||||
class MultidomainDFParser(MetadataRowParser):
|
||||
|
||||
@@ -3,7 +3,7 @@ from beartype.typing import Any, Literal
|
||||
from jaxtyping import Float
|
||||
|
||||
from rf3.data.rotation_augmentation import centre_random_augmentation
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ from dotenv import load_dotenv
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from utils.logging import suppress_warnings
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -18,10 +18,11 @@ load_dotenv(override=True)
|
||||
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
|
||||
_config_path = os.path.join(
|
||||
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
|
||||
)
|
||||
# Find the RF3 configs directory relative to this file
|
||||
# This file is at: models/rf3/src/rf3/inference.py
|
||||
# Configs are at: models/rf3/configs/
|
||||
rf3_package_dir = Path(__file__).parent.parent.parent # Go up to models/rf3/
|
||||
_config_path = str(rf3_package_dir / "configs")
|
||||
|
||||
|
||||
@hydra.main(
|
||||
|
||||
5
models/rf3/src/rf3/inference_engines/__init__.py
Normal file
5
models/rf3/src/rf3/inference_engines/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""RF3 inference engines."""
|
||||
|
||||
from rf3.inference_engines.rf3 import RF3InferenceEngine
|
||||
|
||||
__all__ = ["RF3InferenceEngine"]
|
||||
@@ -10,12 +10,11 @@ from atomworks.io.transforms.categories import category_to_dict
|
||||
from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from inference_engines.base import InferenceEngine
|
||||
from rf3.model.RF3 import ShouldEarlyStopFn
|
||||
from rf3.utils.datasets import (
|
||||
assemble_distributed_inference_loader_from_list_of_paths,
|
||||
)
|
||||
from utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from rf3.utils.inference import (
|
||||
apply_conformer_and_template_selections,
|
||||
build_file_paths_for_prediction,
|
||||
@@ -25,7 +24,7 @@ from rf3.utils.io import (
|
||||
dump_structures,
|
||||
dump_trajectories,
|
||||
)
|
||||
from utils.logging import print_config_tree
|
||||
from modelhub.utils.logging import print_config_tree
|
||||
from rf3.utils.predicted_error import (
|
||||
annotate_atom_array_b_factor_with_plddt,
|
||||
compile_af3_confidence_outputs,
|
||||
@@ -55,7 +54,7 @@ def should_early_stop_by_mean_plddt(
|
||||
return fn
|
||||
|
||||
|
||||
class RF3InferenceEngine(InferenceEngine):
|
||||
class RF3InferenceEngine:
|
||||
"""Class for inference with RF3. Evaluates a trained RF3 model on a set of spoofed CIFs."""
|
||||
|
||||
def __init__(
|
||||
@@ -152,8 +151,6 @@ class RF3InferenceEngine(InferenceEngine):
|
||||
self.cfg.trainer.num_nodes = num_nodes
|
||||
self.cfg.trainer.devices_per_node = devices_per_node
|
||||
self.cfg.trainer.precision = "bf16-mixed" # HACK: Temporary hack until our checkpoint configs are updated
|
||||
self.cfg.trainer._target_ = "modelhub.trainers.rf3.RF3TrainerWithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking
|
||||
self.cfg.model.net._target_ = "modelhub.model.RF3.RF3WithConfidence" # HACK: Enables inference with 9/21 checkpoint for benchmarking
|
||||
|
||||
# We don't want to compute all of the training metrics, since they may error during inference
|
||||
self.cfg.trainer["metrics"] = {}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
# TODO: Many of these functions are unused; we will deprecate and delete
|
||||
# (They are holdovers from previous frameworks)
|
||||
|
||||
from itertools import permutations
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from openbabel import openbabel
|
||||
|
||||
PARAMS = {
|
||||
"DMIN": 1,
|
||||
@@ -355,36 +352,3 @@ def c6d_to_bins(c6d, same_chain, negative=False, params=PARAMS):
|
||||
def standardize_dihedral_retain_first(a, b, c, d):
|
||||
isomorphisms = [(a, b, c, d), (a, c, b, d)]
|
||||
return sorted(isomorphisms)[0]
|
||||
|
||||
|
||||
def get_chirals(obmol, xyz):
|
||||
"""
|
||||
get all quadruples of atoms forming chiral centers and the expected ideal pseudodihedral between them
|
||||
"""
|
||||
stereo = openbabel.OBStereoFacade(obmol)
|
||||
angle = np.arcsin(1 / 3**0.5)
|
||||
chiral_idx_set = set()
|
||||
for i in range(obmol.NumAtoms()):
|
||||
if not stereo.HasTetrahedralStereo(i):
|
||||
continue
|
||||
si = stereo.GetTetrahedralStereo(i)
|
||||
config = si.GetConfig()
|
||||
|
||||
o = config.center
|
||||
c = config.from_or_towards
|
||||
i, j, k = list(config.refs)
|
||||
for a, b, c in permutations((c, i, j, k), 3):
|
||||
chiral_idx_set.add(standardize_dihedral_retain_first(o, a, b, c))
|
||||
|
||||
chiral_idx = list(chiral_idx_set)
|
||||
chiral_idx.sort()
|
||||
chiral_idx = torch.tensor(chiral_idx, dtype=torch.float32)
|
||||
chiral_idx = chiral_idx[(chiral_idx < obmol.NumAtoms()).all(dim=-1)]
|
||||
|
||||
if chiral_idx.numel() == 0:
|
||||
return torch.zeros((0, 5))
|
||||
|
||||
dih = get_dih(*xyz[chiral_idx.long()].split(split_size=1, dim=1))[:, 0]
|
||||
chirals = torch.nn.functional.pad(chiral_idx, (0, 1), mode="constant", value=angle)
|
||||
chirals[dih < 0.0, -1] *= -1
|
||||
return chirals
|
||||
|
||||
@@ -10,7 +10,7 @@ from biotite.structure import AtomArray, AtomArrayStack
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from rf3.kinematics import get_dih
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
def calc_chiral_metrics_masked(
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
import torch
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
class CountClashingChains(Metric):
|
||||
|
||||
@@ -11,8 +11,8 @@ from einops import rearrange, repeat
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from rf3.loss.af3_losses import distogram_loss
|
||||
from metrics.base import Metric
|
||||
from utils.torch import assert_no_nans
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.utils.torch import assert_no_nans
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from atomworks.ml.transforms.atom_array import (
|
||||
AddGlobalTokenIdAnnotation,
|
||||
ensure_atom_array_stack,
|
||||
)
|
||||
from atomworks.io.transforms.atom_array import ensure_atom_array_stack
|
||||
from atomworks.ml.transforms.atom_array import AddGlobalTokenIdAnnotation
|
||||
from atomworks.ml.transforms.atomize import AtomizeByCCDName
|
||||
from atomworks.ml.transforms.base import Compose
|
||||
from atomworks.ml.utils.token import get_token_starts
|
||||
@@ -11,8 +9,8 @@ from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from metrics.base import Metric
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
|
||||
from beartype.typing import Any, Literal
|
||||
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
class ExtraInfo(Metric):
|
||||
|
||||
@@ -2,7 +2,7 @@ from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
from rf3.metrics.metric_utils import find_bin_midpoints
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
class UnresolvedRegionRASA(Metric):
|
||||
@@ -7,7 +7,7 @@ from atomworks.ml.utils.selection import (
|
||||
from beartype.typing import Any
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from metrics.base import Metric
|
||||
from modelhub.metrics.metric import Metric
|
||||
|
||||
|
||||
class SelectedAtomByAtomDistances(Metric):
|
||||
|
||||
@@ -6,7 +6,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
from opt_einsum import contract as einsum
|
||||
|
||||
from src import SHOULD_USE_CUEQUIVARIANCE
|
||||
from modelhub import SHOULD_USE_CUEQUIVARIANCE
|
||||
from rf3.training.checkpoint import activation_checkpointing
|
||||
from rf3.util_module import init_lecun_normal
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from jaxtyping import Float
|
||||
|
||||
from src import SHOULD_USE_CUEQUIVARIANCE
|
||||
from modelhub import SHOULD_USE_CUEQUIVARIANCE
|
||||
from rf3.util_module import init_lecun_normal
|
||||
|
||||
if SHOULD_USE_CUEQUIVARIANCE:
|
||||
|
||||
@@ -12,7 +12,7 @@ from rf3.model.layers.layer_utils import (
|
||||
)
|
||||
from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage
|
||||
from rf3.training.checkpoint import activation_checkpointing
|
||||
from utils.torch import device_of
|
||||
from modelhub.utils.torch import device_of
|
||||
|
||||
|
||||
class AtomAttentionEncoderDiffusion(nn.Module):
|
||||
|
||||
@@ -8,8 +8,8 @@ import rootutils
|
||||
from dotenv import load_dotenv
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from utils.logging import suppress_warnings
|
||||
from utils.weights import CheckpointConfig
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from modelhub.utils.weights import CheckpointConfig
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -18,10 +18,7 @@ load_dotenv(override=True)
|
||||
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
|
||||
_config_path = os.path.join(
|
||||
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
|
||||
)
|
||||
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
|
||||
|
||||
_spawning_process_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -43,14 +40,14 @@ def train(cfg: DictConfig) -> None:
|
||||
# Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
||||
from callbacks.base import BaseCallback # noqa
|
||||
from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
|
||||
from utils.logging import (
|
||||
from modelhub.callbacks.callback import BaseCallback # noqa
|
||||
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
|
||||
from modelhub.utils.logging import (
|
||||
print_config_tree,
|
||||
log_hyperparameters_with_all_loggers,
|
||||
) # noqa
|
||||
from utils.ddp import RankedLogger # noqa
|
||||
from utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa
|
||||
from modelhub.utils.ddp import RankedLogger # noqa
|
||||
from modelhub.utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa
|
||||
from rf3.utils.datasets import (
|
||||
recursively_instantiate_datasets_and_samplers,
|
||||
assemble_distributed_loader,
|
||||
|
||||
@@ -6,20 +6,20 @@ from jaxtyping import Float, Int
|
||||
from lightning_utilities import apply_to_collection
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from common import exists
|
||||
from modelhub.common import exists
|
||||
from rf3.loss.af3_losses import Loss as AF3Loss
|
||||
from rf3.loss.af3_losses import (
|
||||
ResidueSymmetryResolution,
|
||||
SubunitSymmetryResolution,
|
||||
)
|
||||
from metrics.base import MetricManager
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
from rf3.model.RF3 import ShouldEarlyStopFn
|
||||
from trainers.fabric import FabricTrainer
|
||||
from modelhub.trainers.fabric import FabricTrainer
|
||||
from rf3.training.EMA import EMA
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from rf3.utils.io import build_stack_from_atom_array_and_batched_coords
|
||||
from rf3.utils.recycling import get_recycle_schedule
|
||||
from utils.torch import assert_no_nans, assert_same_shape
|
||||
from modelhub.utils.torch import assert_no_nans, assert_same_shape
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -1,8 +1,4 @@
|
||||
"""Utilities for gradient checkpointing to reduce memory usage during training.
|
||||
|
||||
Gradient checkpointing (also called activation checkpointing) trades compute for memory
|
||||
by recomputing intermediate activations during the backward pass instead of storing them.
|
||||
This enables training larger models or using larger batch sizes within GPU memory constraints.
|
||||
"""Utilities for gradient checkpointing.
|
||||
|
||||
References:
|
||||
* `PyTorch Checkpoint Documentation`_
|
||||
@@ -17,9 +13,8 @@ from torch.utils.checkpoint import checkpoint
|
||||
def create_custom_forward(module, **kwargs):
|
||||
"""Create a custom forward function for gradient checkpointing with fixed kwargs.
|
||||
|
||||
This helper enables passing keyword arguments to a module when using PyTorch's
|
||||
checkpoint function, which only accepts positional arguments for the function to
|
||||
be checkpointed.
|
||||
Enables passing keyword arguments to a module when using PyTorch's checkpoint function,
|
||||
which only accepts positional arguments for the function to be checkpointed.
|
||||
|
||||
Args:
|
||||
module: The callable (typically a nn.Module) to wrap.
|
||||
@@ -28,15 +23,6 @@ def create_custom_forward(module, **kwargs):
|
||||
Returns:
|
||||
A callable that accepts only positional arguments and forwards them along
|
||||
with the fixed kwargs to the original module.
|
||||
|
||||
Examples:
|
||||
Use with PyTorch checkpoint::
|
||||
|
||||
custom_fn = create_custom_forward(my_module, frame_atom_idxs=frame_idxs)
|
||||
output = checkpoint(custom_fn, input_tensor, use_reentrant=False)
|
||||
|
||||
See Also:
|
||||
:py:func:`activation_checkpointing`
|
||||
"""
|
||||
|
||||
def custom_forward(*inputs):
|
||||
@@ -48,11 +34,6 @@ def create_custom_forward(module, **kwargs):
|
||||
def activation_checkpointing(function):
|
||||
"""Decorator to enable gradient checkpointing for a function during training.
|
||||
|
||||
When gradients are enabled (training mode), this decorator wraps the function
|
||||
with PyTorch's checkpoint to save memory by recomputing activations during
|
||||
the backward pass. During inference (gradients disabled), the function runs
|
||||
normally without checkpointing overhead.
|
||||
|
||||
Args:
|
||||
function: The function to apply gradient checkpointing to.
|
||||
|
||||
@@ -67,11 +48,7 @@ def activation_checkpointing(function):
|
||||
return self.layer(x, mask)
|
||||
|
||||
Notes:
|
||||
Uses ``use_reentrant=False`` for better compatibility with modern PyTorch
|
||||
features like autograd hooks and higher-order gradients.
|
||||
|
||||
See Also:
|
||||
:py:func:`create_custom_forward`
|
||||
Uses ``use_reentrant=False`` for compatibility with recent PyTorch versions.
|
||||
"""
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
|
||||
@@ -20,8 +20,8 @@ from torch.utils.data import (
|
||||
)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from hydra.resolvers import register_resolvers
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.hydra.resolvers import register_resolvers
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
try:
|
||||
|
||||
@@ -9,7 +9,7 @@ from beartype.typing import Literal
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
|
||||
from rf3.alignment import weighted_rigid_align
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -2,13 +2,14 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import hydra
|
||||
import rootutils
|
||||
from dotenv import load_dotenv
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from utils.logging import suppress_warnings
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -17,10 +18,7 @@ load_dotenv(override=True)
|
||||
# NOTE: Sets the `PROJECT_ROOT` environment variable to the root directory of the project (where `.project-root` is located)
|
||||
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
||||
|
||||
# If the user has set `PROJECT_PATH`, use it to build the config path; otherwise, fall back to `PROJECT_ROOT`
|
||||
_config_path = os.path.join(
|
||||
os.environ.get("PROJECT_PATH", os.environ["PROJECT_ROOT"]), "configs"
|
||||
)
|
||||
_config_path = os.path.join(os.environ["PROJECT_ROOT"], "models/rf3/configs")
|
||||
|
||||
_spawning_process_logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -42,11 +40,11 @@ def validate(cfg: DictConfig) -> None:
|
||||
# Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
|
||||
torch.set_float32_matmul_precision("medium")
|
||||
|
||||
from callbacks.base import BaseCallback # noqa
|
||||
from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
|
||||
from utils.logging import print_config_tree # noqa
|
||||
from utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa
|
||||
from utils.ddp import is_rank_zero # noqa
|
||||
from modelhub.callbacks.callback import BaseCallback # noqa
|
||||
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa
|
||||
from modelhub.utils.logging import print_config_tree # noqa
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa
|
||||
from modelhub.utils.ddp import is_rank_zero # noqa
|
||||
from rf3.utils.datasets import assemble_val_loader_dict # noqa
|
||||
|
||||
set_accelerator_based_on_availability(cfg)
|
||||
|
||||
@@ -9,11 +9,18 @@ TEST_DATA_DIR = Path(__file__).resolve().parent / "data"
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
import sys
|
||||
|
||||
# Set PROJECT_ROOT
|
||||
project_root = rootutils.setup_root(
|
||||
__file__, indicator=".project-root", pythonpath=True
|
||||
)
|
||||
|
||||
# Add models/rf3/src to path so RF3 modules can be imported
|
||||
rf3_src = project_root / "models" / "rf3" / "src"
|
||||
if rf3_src.exists() and str(rf3_src) not in sys.path:
|
||||
sys.path.insert(0, str(rf3_src))
|
||||
|
||||
# Construct path to .env file at project root
|
||||
dotenv_path = project_root / ".env"
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
data_5VHT
|
||||
#
|
||||
_msa_paths_by_chain_id.A tests/data/msas/5vht_A.a3m
|
||||
_msa_paths_by_chain_id.B tests/data/msas/5vht_A.a3m
|
||||
_msa_paths_by_chain_id.A models/rf3/tests/data/msas/5vht_A.a3m
|
||||
_msa_paths_by_chain_id.B models/rf3/tests/data/msas/5vht_A.a3m
|
||||
#
|
||||
_entry.id 5VHT
|
||||
#
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
{
|
||||
"seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH",
|
||||
"chain_id": "A",
|
||||
"msa_path": "tests/data/msas/5vht_A.a3m"
|
||||
"msa_path": "models/rf3/tests/data/msas/5vht_A.a3m"
|
||||
},
|
||||
{
|
||||
"seq": "MTSENPLLALREKISALDEKLLALFAERRELAVEVGKAKLLSHRPVRDIDRERDLLERLITLGKAHHLDAH(PBF)ITRTFQLGIEYSVLTQQALLEHHHHHH",
|
||||
"chain_id": "B",
|
||||
"msa_path": "tests/data/msas/5vht_A.a3m"
|
||||
"msa_path": "models/rf3/tests/data/msas/5vht_A.a3m"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ from copy import deepcopy
|
||||
import pytest
|
||||
from atomworks.ml.utils.testing import cached_parse
|
||||
|
||||
from metrics.chiral import ChiralLoss
|
||||
from rf3.metrics.chiral import ChiralLoss
|
||||
|
||||
|
||||
@pytest.mark.parametrize("pdb_id", ["5ocm", "1ivo"])
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "rf3"
|
||||
name = "modelhub"
|
||||
dynamic = ["version"]
|
||||
description = "Open-source biomolecular structure prediction for all molecules of life."
|
||||
description = "Shared utilities and training infrastructure for biomolecular structure prediction models."
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.12"
|
||||
authors = [
|
||||
@@ -22,17 +22,6 @@ classifiers = [
|
||||
"License :: OSI Approved :: BSD License",
|
||||
]
|
||||
dependencies = [
|
||||
# ...ml tools
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.4.0,<2.5",
|
||||
"einops>=0.8.0,<1",
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
# ... kernels (Linux only)
|
||||
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
|
||||
# ... configuration & CLI
|
||||
"rootutils>=1.0.7,<1.1",
|
||||
"hydra-core>=1.3.0,<1.4",
|
||||
@@ -43,9 +32,9 @@ dependencies = [
|
||||
# ... typing & documentation
|
||||
"jaxtyping>=0.2.17,<1",
|
||||
"beartype>=0.18.0,<1",
|
||||
|
||||
# ... dataloading
|
||||
"atomworks==1.0.2"
|
||||
# ... ml tools (core)
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.4.0,<2.5",
|
||||
]
|
||||
|
||||
|
||||
@@ -64,8 +53,7 @@ dev = [
|
||||
"pytest-cov>=4.1.0,<5", # generate coverage report
|
||||
"pytest-benchmark>=5.0.0,<6", # benchmark tests for speed
|
||||
]
|
||||
[project.scripts]
|
||||
rf3 = "modelhub.cli:app"
|
||||
# No CLI scripts at root level - models provide their own entry points
|
||||
# Build settings ----------------------------------------------------------------------
|
||||
[build-system]
|
||||
requires = [
|
||||
|
||||
@@ -44,4 +44,4 @@ except ImportError:
|
||||
logger.debug("cuEquivariance unavailable: import failed")
|
||||
|
||||
# Export for easy access
|
||||
__all__ = ["SHOULD_USE_CUEQUIVARIANCE", "silence_warnings"]
|
||||
__all__ = ["SHOULD_USE_CUEQUIVARIANCE"]
|
||||
|
||||
5
src/modelhub/callbacks/__init__.py
Normal file
5
src/modelhub/callbacks/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Callbacks for training and validation."""
|
||||
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
|
||||
__all__ = ["BaseCallback"]
|
||||
@@ -10,7 +10,7 @@ from jaxtyping import Float, Int
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
||||
from torch import Tensor
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
|
||||
_DEFAULT_STATISTICS = types.MappingProxyType(
|
||||
{
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import pandas as pd
|
||||
from lightning.fabric.utilities.rank_zero import rank_zero_only
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from rf3.utils.logging import print_df_as_table
|
||||
from rf3.utils.torch_utils import Timers
|
||||
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
|
||||
|
||||
class TimingCallback(BaseCallback):
|
||||
"""Fabric callback to print timing metrics."""
|
||||
|
||||
@@ -2,26 +2,26 @@ import time
|
||||
from collections import defaultdict
|
||||
|
||||
import pandas as pd
|
||||
from atomworks.common import parse_example_id
|
||||
from atomworks.ml.example_id import parse_example_id
|
||||
from beartype.typing import Any
|
||||
from lightning.fabric.wrappers import (
|
||||
_FabricOptimizer,
|
||||
)
|
||||
from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
|
||||
from rich.console import Group
|
||||
from rich.panel import Panel
|
||||
from rich.table import Table
|
||||
from torch import nn
|
||||
from torchmetrics.aggregation import MeanMetric
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from rf3.utils.ddp import RankedLogger
|
||||
from rf3.utils.logging import (
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import (
|
||||
print_df_as_table,
|
||||
print_model_parameters,
|
||||
safe_print,
|
||||
table_from_df,
|
||||
)
|
||||
from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses
|
||||
|
||||
|
||||
class LogModelParametersCallback(BaseCallback):
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class InferenceEngine(ABC):
|
||||
"""Abstract base class for inference pipelines."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, inputs: list[Path]) -> None:
|
||||
"""Run inference on input files."""
|
||||
pass
|
||||
12
src/modelhub/metrics/__init__.py
Normal file
12
src/modelhub/metrics/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""Metrics for model evaluation.
|
||||
|
||||
This module provides the base metric framework.
|
||||
"""
|
||||
|
||||
from modelhub.metrics.metric import Metric, MetricInputError, MetricManager
|
||||
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"MetricManager",
|
||||
"MetricInputError",
|
||||
]
|
||||
@@ -8,7 +8,7 @@ from beartype.typing import Any
|
||||
from omegaconf import DictConfig
|
||||
from toolz import keymap
|
||||
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -25,12 +25,12 @@ from lightning.fabric.wrappers import (
|
||||
_FabricModule,
|
||||
_FabricOptimizer,
|
||||
)
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from rf3.training.EMA import EMA
|
||||
from rf3.training.schedulers import SchedulerConfig
|
||||
from utils.ddp import RankedLogger
|
||||
from utils.weights import (
|
||||
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.weights import (
|
||||
CheckpointConfig,
|
||||
WeightLoadingConfig,
|
||||
freeze_parameters_with_config,
|
||||
|
||||
@@ -2,7 +2,7 @@ import hydra
|
||||
from lightning.fabric.loggers import Logger
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from callbacks.base import BaseCallback
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
|
||||
|
||||
def _can_be_instantiated(cfg: DictConfig) -> bool:
|
||||
|
||||
@@ -12,7 +12,7 @@ from rich.table import Table
|
||||
from rich.tree import Tree
|
||||
from torch import nn
|
||||
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -14,8 +14,8 @@ from torch import Tensor
|
||||
from torch._prims_common import DeviceLikeType
|
||||
from torch.types import _dtype
|
||||
|
||||
from src import should_check_nans
|
||||
from common import at_least_one_exists, do_nothing
|
||||
from modelhub import should_check_nans
|
||||
from modelhub.common import at_least_one_exists, do_nothing
|
||||
|
||||
|
||||
def map_to(
|
||||
|
||||
@@ -9,7 +9,7 @@ import torch
|
||||
from beartype.typing import Pattern
|
||||
from torch import nn
|
||||
|
||||
from utils.ddp import RankedLogger
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import pytest
|
||||
import torch
|
||||
|
||||
os.environ["NAN_CHECKING"] = "True"
|
||||
from utils.torch import assert_no_nans, map_to
|
||||
from modelhub.utils.torch import assert_no_nans, map_to
|
||||
|
||||
|
||||
def test_map_to():
|
||||
|
||||
@@ -3,7 +3,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
|
||||
# Import your code here
|
||||
from utils.weights import (
|
||||
from modelhub.utils.weights import (
|
||||
ParameterFreezingConfig,
|
||||
WeightLoadingConfig,
|
||||
WeightLoadingPolicy,
|
||||
|
||||
Reference in New Issue
Block a user