mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
wip: save before letting Claude rip
This commit is contained in:
33
.env
33
.env
@@ -6,11 +6,12 @@
|
||||
|
||||
|
||||
# --- Mirrors to RCSB data ---
|
||||
|
||||
# The `PDB_MIRROR_PATH` is a path to a local mirror of the PDB database. It's
|
||||
# expected that you use the same saving conventions as the RCSB PDB, which means:
|
||||
# `1a2b` --> /path/to/pdb_mirror/a2/1a2b.cif.gz
|
||||
# To set up a mirror, you can use tha atomworks commandline: `atomworks pdb sync /path/to/mirror`
|
||||
PDB_MIRROR_PATH=
|
||||
PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_pdb
|
||||
|
||||
# The `CCD_MIRROR_PATH` is a path to a local mirror of the CCD database.
|
||||
# It's expected that you use the same saving conventions as the RCSB CCD, which means:
|
||||
@@ -19,10 +20,36 @@ PDB_MIRROR_PATH=
|
||||
# If no mirror is provided, the internal biotite CCD will be used as a fallback. To provide a
|
||||
# custom CCD for a ligand, you can place it in the in the CCD mirror path following the CCDs pattern.
|
||||
# Example: /path/to/ccd_mirror/M/MYLIGAND1/MYLIGAND1.cif
|
||||
CCD_MIRROR_PATH=
|
||||
CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_ccd
|
||||
|
||||
# --- Local MSA directories ---
|
||||
LOCAL_MSA_DIRS=/projects/msa/hhblits,/projects/msa/mmseqs_gpu,/projects/msa/lab,/squash/mgnify_distill_rf3/msas
|
||||
|
||||
|
||||
# --- External tools ---
|
||||
|
||||
# The `X3DNA_PATH` is a path to the x3dna tool, which is used for DNA structure analysis.
|
||||
# Example: /path/to/x3dna-v2.4
|
||||
X3DNA_PATH=
|
||||
X3DNA_PATH=
|
||||
|
||||
# The `HHFILTER_PATH` is a path to the hhfilter tool from the HH-suite, which is used for
|
||||
# filtering MSAs to reduce redundancy.
|
||||
# Example: /path/to/hhsuite/build/bin/hhfilter
|
||||
HHFILTER_PATH=/net/software/hhsuite/build/bin/hhfilter
|
||||
|
||||
# The `MMSEQS2_PATH` is a path to the mmseqs2 tool, which is used for fast sequence searching.
|
||||
# Example: /path/to/mmseqs-gpu/bin/mmseqs
|
||||
MMSEQS2_PATH=/net/software/mmseqs-gpu/bin/mmseqs
|
||||
|
||||
# CollabFold MMseqs2 database paths for GPU and CPU usage.
|
||||
|
||||
# Local access (preferred)
|
||||
# NOTE: MMseqs2 databases are best stored on local drives of compute nodes for performance
|
||||
COLABFOLD_LOCAL_DB_PATH_GPU=/local/colabfold/gpu
|
||||
COLABFOLD_LOCAL_DB_PATH_CPU=/local/databases/colabfold/
|
||||
|
||||
# Network access (fallback; may cause IO-related issues)
|
||||
COLABFOLD_NET_DB_PATH_GPU=/net/databases/colabfold/gpu
|
||||
COLABFOLD_NET_DB_PATH_CPU=/net/databases/colabfold/
|
||||
|
||||
|
||||
|
||||
680
CLAUDE.md
Normal file
680
CLAUDE.md
Normal file
@@ -0,0 +1,680 @@
|
||||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
### Setup
|
||||
It is **CRITICAL** that before you run **ANY** commands, you activate the python environment like below.
|
||||
Otherwise, you will always run into import and package errors.
|
||||
```bash
|
||||
# IMPORTANT! ALWAYS ACTIVATE THE PYTHON ENVIRONMENT!
|
||||
source .venv/bin/activate
|
||||
```
|
||||
|
||||
## Coding Practices
|
||||
|
||||
Read CAREFULLY the following coding practices. These are central tenants of the way we code. Whenver you write code, retroactively examine your code and ensure it conforms to the principles oulined below.
|
||||
(1) Adhere to the do-not-repeat-yourself (DRY) practice. Break out common operations into shared functions or classes.
|
||||
(2) Prefer functions over classes where possible; the more general the function, the better. Functional programming makes the code more extendable.
|
||||
(3) Follow the YAGNI principle - "You Ain't Gunna Need It." Don't speculatively build functionality that I do not explicitly ask for. Prioritize code simplicity and brevity above all else.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**ModelForge** is a repository of open-source neural networks for biomolecular structure prediction and design. The flagship model is **RosettaFold3 (RF3)**, a structure prediction network competitive with AlphaFold3. The repository uses a shared training harness and integrates with [AtomWorks](https://github.com/RosettaCommons/atomworks) for biomolecular data processing.
|
||||
|
||||
### Current Status
|
||||
The repository is undergoing active refactoring. Code is organized into:
|
||||
- `releases/rf3/`: Stable RF3 model release with complete inference and training
|
||||
- `src/rf3/`: RF3-specific implementation code
|
||||
- `configs/`: Hydra configuration files for RF3
|
||||
- `tests/`: RF3 test suite
|
||||
- `src/modelhub/`: Shared utilities and base classes used across all models
|
||||
- `tests/`: Shared tests at the repository level
|
||||
- `lib/`: Git submodules including AtomWorks
|
||||
|
||||
|
||||
## Key Commands
|
||||
|
||||
### Installation and Setup
|
||||
```bash
|
||||
# Clone and install
|
||||
git clone https://github.com/RosettaCommons/modelforge.git
|
||||
cd modelforge
|
||||
uv python install 3.12
|
||||
uv venv --python 3.12
|
||||
source .venv/bin/activate
|
||||
uv pip install -e .
|
||||
|
||||
# Download RF3 weights
|
||||
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
|
||||
```
|
||||
|
||||
### Development Commands
|
||||
```bash
|
||||
# IMPORTANT: Always activate virtual environment first
|
||||
source .venv/bin/activate
|
||||
|
||||
# Code formatting and linting
|
||||
make format # Format code using ruff (preferred)
|
||||
ruff format src tests # Format code directly
|
||||
ruff check --fix src tests # Lint and fix issues
|
||||
|
||||
# Cleanup
|
||||
make clean # Delete compiled and cached files
|
||||
```
|
||||
|
||||
### Testing
|
||||
```bash
|
||||
# IMPORTANT: Always activate virtual environment first
|
||||
source .venv/bin/activate
|
||||
|
||||
# Run RF3-specific tests (from releases/rf3/)
|
||||
cd releases/rf3/
|
||||
pytest tests/ # All RF3 tests
|
||||
pytest tests/test_inference_regression.py # Inference regression tests
|
||||
pytest tests/test_write_confidence.py # Confidence output tests
|
||||
|
||||
# Run shared/root-level tests (from repository root)
|
||||
cd /path/to/modelhub_latent
|
||||
pytest tests/ # All shared tests
|
||||
pytest tests/test_metrics.py # Metric tests
|
||||
pytest tests/test_weight_loading.py # Weight loading tests
|
||||
pytest tests/test_torch_utils.py # Torch utility tests
|
||||
|
||||
# Run GPU-dependent tests (requires GPU)
|
||||
pytest tests/test_inference_regression.py -m gpu
|
||||
|
||||
# Run with verbose output
|
||||
pytest tests/ -v
|
||||
```
|
||||
|
||||
### Inference (RF3)
|
||||
```bash
|
||||
# IMPORTANT: Navigate to RF3 release directory
|
||||
cd releases/rf3/
|
||||
|
||||
# Basic structure prediction
|
||||
rf3 fold inputs='tests/data/5vht_from_json.json'
|
||||
|
||||
# With MSA
|
||||
rf3 fold inputs='../../docs/rf3/examples/3en2_from_json_with_msa.json'
|
||||
|
||||
# Batch processing multiple files
|
||||
rf3 fold inputs='[file1.cif, file2.json, file3.pdb]'
|
||||
rf3 fold inputs='path/to/directory' # Process all CIF/PDB/JSON in directory
|
||||
|
||||
# Advanced inference options
|
||||
rf3 fold inputs='input.json' \
|
||||
ckpt_path='/path/to/rf3_latest.pt' \
|
||||
out_dir='./predictions' \
|
||||
n_recycles=10 \
|
||||
diffusion_batch_size=5 \
|
||||
num_steps=50 \
|
||||
annotate_b_factor_with_plddt=true \
|
||||
early_stopping_plddt_threshold=0.5
|
||||
|
||||
# Templating (fix specific regions during prediction)
|
||||
rf3 fold inputs='input.cif' \
|
||||
template_selection='[A, B/*/1-42, B/*/49-63]' \
|
||||
ground_truth_conformer_selection='[C, D]'
|
||||
|
||||
# Alternative: Direct Python invocation (more informative error messages)
|
||||
python src/rf3/inference.py inputs='tests/data/5vht_from_json.json'
|
||||
```
|
||||
|
||||
**Note**: RF3 uses Hydra for configuration, so arguments use `key=value` syntax (not `--key value`).
|
||||
|
||||
**Selection Syntax** (for templating): `CHAIN/RES_NAME/RES_ID/ATOM_NAME`
|
||||
- Exact: `A/ALA/15/CA`
|
||||
- Wildcard: `A/*/*/CA` (all CA atoms in chain A)
|
||||
- Range: `A/*/5-10` (residues 5-10 in chain A)
|
||||
- Union: `A, B` (chains A and B)
|
||||
|
||||
### Training (RF3)
|
||||
```bash
|
||||
# IMPORTANT: Navigate to RF3 release directory
|
||||
cd releases/rf3/
|
||||
|
||||
# Train with specific experiment config
|
||||
python src/rf3/train.py experiment=pretrained/rf3
|
||||
|
||||
# Override specific parameters
|
||||
python src/rf3/train.py \
|
||||
experiment=pretrained/rf3 \
|
||||
trainer.max_steps=100000 \
|
||||
seed=42
|
||||
|
||||
# Resume from checkpoint
|
||||
python src/rf3/train.py \
|
||||
experiment=pretrained/rf3 \
|
||||
ckpt_path=/path/to/checkpoint.ckpt
|
||||
|
||||
# Debug mode (quick testing)
|
||||
python src/rf3/train.py \
|
||||
experiment=pretrained/rf3 \
|
||||
debug=default
|
||||
```
|
||||
|
||||
## Architecture
|
||||
|
||||
### Package Structure
|
||||
- `releases/rf3/`: Complete RF3 model release (self-contained)
|
||||
- `src/rf3/`: RF3 implementation package
|
||||
- `model/`: Neural network architecture (Pairformer, Diffusion Module, auxiliary heads)
|
||||
- `layers/`: Network components (attention, triangle multiplication, outer product)
|
||||
- `RF3.py`: Main model class
|
||||
- `RF3_structure.py`: Diffusion module and structural components
|
||||
- `data/`: Data pipelines and transformations
|
||||
- `inference_engines/`: RF3 inference implementations
|
||||
- `loss/`: Loss functions (diffusion loss, confidence loss, distogram loss)
|
||||
- `metrics/`: RF3-specific evaluation metrics
|
||||
- `trainers/`: Training logic using Lightning Fabric
|
||||
- `training/`: Training utilities (EMA, schedulers, checkpoints)
|
||||
- `utils/`: RF3-specific utilities
|
||||
- `callbacks/`: RF3-specific callbacks
|
||||
- `cli.py`: RF3 CLI entry point
|
||||
- `inference.py`: Inference script
|
||||
- `train.py`: Training script
|
||||
- `validate.py`: Validation script
|
||||
- `configs/`: Hydra configuration files for RF3
|
||||
- `model/`: Model architecture configs
|
||||
- `trainer/`: Training configs (loss, metrics)
|
||||
- `datasets/`: Dataset configs
|
||||
- `inference_engine/`: Inference engine configs
|
||||
- `callbacks/`, `logger/`, `paths/`, etc.
|
||||
- `tests/`: RF3 test suite
|
||||
- `test_inference_regression.py`: Regression tests
|
||||
- `test_write_confidence.py`: Confidence output tests
|
||||
- `data/`: Test data and baselines
|
||||
- `src/modelhub/`: Shared utilities across all models
|
||||
- `callbacks/`: Shared training callbacks (health logging, timing, base classes)
|
||||
- `inference_engines/`: Base inference engine interface
|
||||
- `metrics/`: Shared evaluation metrics (base classes, common metrics)
|
||||
- `trainers/`: Shared training infrastructure (Fabric wrappers)
|
||||
- `utils/`: Common utilities (weights, logging, instantiators, DDP, torch utils)
|
||||
- `hydra/`: Hydra resolvers and utilities
|
||||
- `tests/`: Repository-level shared tests
|
||||
- `test_metrics.py`: Shared metric tests
|
||||
- `test_weight_loading.py`: Weight loading tests
|
||||
- `test_torch_utils.py`: Torch utility tests
|
||||
- `lib/`: Git submodules (AtomWorks)
|
||||
- `docs/`: Documentation and examples
|
||||
|
||||
### Key Concepts
|
||||
- **RF3 Architecture**: Pairformer (token-level) → Diffusion Module (atom-level)
|
||||
- Token-level representations: Single (`S`, I×C_s), Pair (`Z`, I×I×C_z)
|
||||
- Atom-level representations: Single (`Q`, L×C_atom), Pair (`P`, L×L×C_atompair)
|
||||
- Distogram head for token-level distance predictions
|
||||
- **Hydra Configuration**: Composable configs with defaults and overrides
|
||||
- **Lightning Fabric**: Distributed training with DDP support
|
||||
- **AtomWorks Integration**: Unified data processing for structures, MSAs, templates
|
||||
- **Input Flexibility**: Supports CIF, PDB, JSON, SMILES, and CCD codes
|
||||
|
||||
### Development Environment
|
||||
- Python 3.12 required (3.11+ supported)
|
||||
- Uses `ruff` for linting and formatting (configured in pyproject.toml)
|
||||
- Testing with `pytest` including GPU-specific tests
|
||||
- Environment variables: Configure in `.env` file (see `.env` template)
|
||||
- `PDB_MIRROR_PATH`: Local PDB mirror for training data
|
||||
- `CCD_MIRROR_PATH`: Chemical Component Dictionary mirror
|
||||
- `LOCAL_MSA_DIRS`: MSA search directories
|
||||
- Tool paths: `HHFILTER_PATH`, `MMSEQS2_PATH`, etc.
|
||||
|
||||
### Data Dependencies
|
||||
The training and inference pipelines expect:
|
||||
- PDB mirror with mmCIF files in standard RCSB sharding pattern
|
||||
- CCD mirror for small molecule definitions
|
||||
- Optional MSA data for protein chains
|
||||
- Pre-computed metadata as parquet files (for training)
|
||||
|
||||
## RF3 Input Formats
|
||||
|
||||
### JSON Format
|
||||
```json
|
||||
{
|
||||
"name": "example_name",
|
||||
"components": [
|
||||
{
|
||||
"seq": "MKTAYIA...", // Protein/NA sequence (supports non-canonical)
|
||||
"msa_path": "path/to/msa.a3m", // Optional MSA (a3m or fasta)
|
||||
"chain_id": "A" // Optional chain ID
|
||||
},
|
||||
{
|
||||
"smiles": "CC(=O)O" // Small molecule via SMILES
|
||||
},
|
||||
{
|
||||
"ccd_code": "HEM" // Chemical Component Dictionary code
|
||||
},
|
||||
{
|
||||
"path": "ligand.sdf" // Structure file (SDF/CIF)
|
||||
}
|
||||
],
|
||||
"bonds": [ // Optional covalent modifications
|
||||
["A/ASN/133/ND2", "B/NAG/1/C1"]
|
||||
],
|
||||
"template_selection": ["A/*/1-50"], // Optional token-level templating
|
||||
"ground_truth_conformer_selection": ["C"] // Optional atom-level templating
|
||||
}
|
||||
```
|
||||
|
||||
### CIF/PDB Files
|
||||
Standard RCSB format with optional MSA specification in CIF header:
|
||||
```cif
|
||||
data_3EN2
|
||||
_msa_paths_by_chain_id.A path/to/msa_A.a3m.gz
|
||||
_msa_paths_by_chain_id.B path/to/msa_B.a3m.gz
|
||||
```
|
||||
|
||||
## RF3 Outputs
|
||||
|
||||
Inference produces:
|
||||
- `{name}_model_{i}.cif.gz`: Predicted structures (gzipped mmCIF, one per diffusion seed)
|
||||
- `{name}_metrics.csv`: Overall confidence metrics (pTM, ipTM, pLDDT, etc.)
|
||||
- `{name}.score`: Detailed per-residue confidence scores
|
||||
|
||||
All CIF outputs can be directly opened in PyMol or parsed with AtomWorks.
|
||||
|
||||
## Testing
|
||||
|
||||
RF3 tests are located in `releases/rf3/tests/`, while shared tests are in the root `tests/` directory.
|
||||
|
||||
- **RF3 Regression tests**: Compare inference outputs against frozen baselines
|
||||
- Location: `releases/rf3/tests/test_inference_regression.py`
|
||||
- Test data: `releases/rf3/tests/data/` (mini examples and baseline predictions)
|
||||
- Run from: `releases/rf3/` directory
|
||||
- **RF3 Confidence tests**: Test confidence output writing
|
||||
- Location: `releases/rf3/tests/test_write_confidence.py`
|
||||
- **Shared Unit tests**: Test shared utilities across models
|
||||
- Location: Root `tests/` directory
|
||||
- `test_metrics.py`: Shared metrics
|
||||
- `test_weight_loading.py`: Weight loading utilities
|
||||
- `test_torch_utils.py`: Torch utility functions
|
||||
- **GPU tests**: Marked with `@pytest.mark.gpu` decorator
|
||||
- **Test execution**: Navigate to appropriate directory before running pytest
|
||||
|
||||
## Git Workflow
|
||||
|
||||
- **Main branch**: `trunk` (use for PRs, **not** `main`)
|
||||
- Current refactoring branch: `refactor/rf3-lab`
|
||||
- The repository is actively being refactored; expect API changes
|
||||
|
||||
## Common Development Patterns
|
||||
|
||||
### Adding New Metrics
|
||||
1. **RF3-specific**: Implement in `releases/rf3/src/rf3/metrics/`
|
||||
2. **Shared across models**: Implement in `src/modelhub/metrics/`
|
||||
3. Inherit from `BaseMetric` in `src/modelhub/metrics/base.py`
|
||||
4. Register in appropriate config file (`releases/rf3/configs/trainer/metrics/`)
|
||||
5. Add tests:
|
||||
- RF3-specific tests in `releases/rf3/tests/`
|
||||
- Shared tests in root `tests/test_metrics.py`
|
||||
|
||||
### Adding New Callbacks
|
||||
1. **RF3-specific**: Implement in `releases/rf3/src/rf3/callbacks/`
|
||||
2. **Shared across models**: Implement in `src/modelhub/callbacks/`
|
||||
3. Inherit from `BaseCallback` in `src/modelhub/callbacks/base.py`
|
||||
4. Register in `releases/rf3/configs/callbacks/`
|
||||
5. Hook into training loop via callback methods (`on_train_batch_end`, etc.)
|
||||
|
||||
### Modifying Model Architecture
|
||||
1. Edit modules in `releases/rf3/src/rf3/model/`
|
||||
2. Update corresponding configs in `releases/rf3/configs/model/`
|
||||
3. Verify with regression tests from `releases/rf3/`:
|
||||
```bash
|
||||
cd releases/rf3/
|
||||
pytest tests/test_inference_regression.py
|
||||
```
|
||||
4. Check weight loading from root:
|
||||
```bash
|
||||
pytest tests/test_weight_loading.py
|
||||
```
|
||||
|
||||
### Working with Hydra Configs
|
||||
```bash
|
||||
# Override nested config values
|
||||
python script.py model.c_s=512 model.c_z=256
|
||||
|
||||
# Override list items (must quote)
|
||||
python script.py inputs='[file1.cif, file2.json]'
|
||||
|
||||
# Change config group defaults
|
||||
python script.py inference_engine=rf3 trainer=ddp
|
||||
|
||||
# Compose multiple configs
|
||||
defaults:
|
||||
- trainer: rf3
|
||||
- model: rf3
|
||||
- datasets: pdb_and_distillation
|
||||
- _self_
|
||||
```
|
||||
|
||||
## Performance Tips
|
||||
|
||||
- **Early stopping**: Set `early_stopping_plddt_threshold=0.5` to skip low-confidence predictions (10-20x faster)
|
||||
- **Batch processing**: Use multiple inputs in single command to amortize startup cost
|
||||
- **Diffusion steps**: Reduce `num_steps=50` (from default 200) for 2x speedup with minimal quality loss
|
||||
- **Recycling**: Default `n_recycles=10`; reduce for faster inference
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Import Errors
|
||||
- Ensure virtual environment is activated: `source .venv/bin/activate`
|
||||
- Check you're in correct directory:
|
||||
- For RF3 tests/inference/training: `cd releases/rf3/`
|
||||
- For shared tests: Stay in repository root
|
||||
- The `rf3` CLI command (via `pyproject.toml`) should work from any directory once installed
|
||||
|
||||
### Missing Data
|
||||
- Set `PDB_MIRROR_PATH` and `CCD_MIRROR_PATH` in `.env`
|
||||
- For MSA: Set `LOCAL_MSA_DIRS` or use `raise_if_missing_msa_for_protein_of_length_n=0` to require MSAs
|
||||
|
||||
### CUDA/GPU Issues
|
||||
- Check GPU availability: `python -c "import torch; print(torch.cuda.is_available())"`
|
||||
- Set precision: `torch.set_float32_matmul_precision("medium")`
|
||||
|
||||
### Hydra Configuration Errors
|
||||
- Use `=` not `--` for arguments: `inputs='file.cif'` not `--inputs file.cif`
|
||||
- Quote lists and strings: `inputs='[file1, file2]'`
|
||||
- For detailed errors, use Python directly: `python src/modelhub/inference.py ...`
|
||||
|
||||
## Docstring Guidelines
|
||||
|
||||
Follow Google-style docstrings with Sphinx optimization. These comprehensive guidelines ensure consistent, high-quality documentation across the codebase.
|
||||
|
||||
### Primary Goals
|
||||
- **Concise**: Keep docstrings as short as possible while being clear
|
||||
- **No redundancy**: Don't repeat function/class names, types when annotated, or obvious behavior
|
||||
- **Sphinx-first**: Prefer reStructuredText roles/directives that render beautifully
|
||||
- **Google section headers**: Use standard section names with colons (e.g., `Args:`), not underlined headings
|
||||
|
||||
### When to Include Sections
|
||||
- **Args**: Include only if non-trivial or non-obvious. Omit types when PEP 484 annotations are present
|
||||
- **Returns**:
|
||||
- Omit if the function returns `None`
|
||||
- Omit if the summary sentence fully describes the return
|
||||
- Otherwise include, using rST for clarity (including literal blocks when helpful)
|
||||
- **Yields**: Use instead of Returns for generators
|
||||
- **Raises**: Include only unusual, explicitly raised exceptions that matter to users
|
||||
- **Examples**: Strongly encouraged when usage isn't obvious. See "Examples formatting" below
|
||||
- **References**: Include when citing standards, papers, or external docs; also when adding rST link targets
|
||||
- **Todo**: Allowed; requires the Sphinx `sphinx.ext.todo` extension
|
||||
- For classes, put argument documentation in `__init__`, not in the class docstring; the class docstring is a high-level overview
|
||||
|
||||
### General Formatting Rules
|
||||
- **One-line summary**: Imperative, single sentence, ends with a period
|
||||
- **Second paragraph**: Optional short elaboration only if genuinely useful
|
||||
- **Inline code**: Use double backticks for identifiers and literals, e.g., ``"same"`` or ``pathlib.Path``
|
||||
- **Cross-references**: Prefer semantic roles (see "Cross-referencing rules" below)
|
||||
- **Line length**: Wrap naturally; don't force hard 79-char wraps if it harms readability
|
||||
- **Defaults**: Use "Defaults to X." at end of the parameter description
|
||||
- **None vs Optional**: Prefer "Defaults to None."; avoid repeating "Optional" if type hints already indicate optionality
|
||||
|
||||
### Section Ordering (Typical)
|
||||
1. Summary
|
||||
2. Optional elaboration
|
||||
3. Args
|
||||
4. Returns or Yields
|
||||
5. Raises
|
||||
6. Notes / Warnings (only if needed)
|
||||
7. Examples
|
||||
8. References
|
||||
9. See Also
|
||||
10. Todo
|
||||
11. Attributes (for modules/classes only when helpful; not for `__init__` args)
|
||||
|
||||
### Args Formatting
|
||||
Use the Google format; types omitted if PEP 484 present. Keep descriptions concise.
|
||||
|
||||
Example with type annotations:
|
||||
```python
|
||||
def fn(a: int, path: str | None = None):
|
||||
"""Do X.
|
||||
|
||||
Args:
|
||||
a: Number of items to process.
|
||||
path: Optional file path. Defaults to ``None``.
|
||||
"""
|
||||
```
|
||||
|
||||
Without type annotations:
|
||||
```python
|
||||
def fn(a, path=None):
|
||||
"""Do X.
|
||||
|
||||
Args:
|
||||
a (int): Number of items to process.
|
||||
path (str, optional): Optional file path. Defaults to ``None``.
|
||||
"""
|
||||
```
|
||||
|
||||
### Returns / Yields Formatting
|
||||
Omit if returning `None` or already clear from the summary. Otherwise, short text; use rST formatting when useful.
|
||||
|
||||
Examples:
|
||||
```python
|
||||
def compute() -> dict[str, int]:
|
||||
"""Compute counts.
|
||||
|
||||
Returns:
|
||||
Mapping of names to counts.
|
||||
|
||||
The ``Returns`` section supports any reStructuredText formatting,
|
||||
including literal blocks::
|
||||
|
||||
{
|
||||
'param1': 1,
|
||||
'param2': 2
|
||||
}
|
||||
"""
|
||||
```
|
||||
|
||||
```python
|
||||
def stream(n: int):
|
||||
"""Yield integers up to ``n``.
|
||||
|
||||
Yields:
|
||||
int: Next value in ``range(n)``.
|
||||
"""
|
||||
```
|
||||
|
||||
### Raises Formatting
|
||||
Include only unusual, explicit exceptions that users should anticipate.
|
||||
```python
|
||||
Raises:
|
||||
ValueError: If ``threshold`` is negative.
|
||||
```
|
||||
|
||||
### Examples Formatting
|
||||
- Use the `Examples:` section
|
||||
- Prefer doctest-style for simple usage with expected output
|
||||
- Use `.. code-block:: python` for multi-line, setup-heavy, or non-doctest examples
|
||||
- Precede each code block with a short sentence describing the case
|
||||
- Keep examples minimal; 1–3 cases is typical. Avoid redundant examples
|
||||
- Don't mix doctest and code-block in the same example unless clearly justified
|
||||
|
||||
Doctest example:
|
||||
```python
|
||||
def square(x: int) -> int:
|
||||
"""Return the square of ``x``.
|
||||
|
||||
Examples:
|
||||
>>> square(3)
|
||||
9
|
||||
"""
|
||||
```
|
||||
|
||||
Multi-case using code blocks:
|
||||
```python
|
||||
def from_selection_str(s: str):
|
||||
"""Create an ``AtomSelection`` from a selection string.
|
||||
|
||||
Examples:
|
||||
Select CA at chain A, residue 1::
|
||||
|
||||
AtomSelection.from_selection_str("A/ALA/1/CA")
|
||||
|
||||
Select CB in any chain at any residue::
|
||||
|
||||
AtomSelection.from_selection_str("*/ALA/*/CB")
|
||||
"""
|
||||
```
|
||||
|
||||
Or with an explicit directive:
|
||||
```python
|
||||
Examples:
|
||||
Basic usage:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
result = fn("input")
|
||||
print(result)
|
||||
```
|
||||
|
||||
### References Formatting
|
||||
Use the `References:` section to:
|
||||
- Cite external resources (standards, papers, docs)
|
||||
- Define hyperlink targets (label definitions) for clean inline linking in the docstring
|
||||
- Place external link definitions directly in the `References:` section or right after it
|
||||
|
||||
Examples:
|
||||
```python
|
||||
def parse():
|
||||
"""Parse input.
|
||||
|
||||
References:
|
||||
* `Google Python Style Guide`_
|
||||
* `PEP 484`_
|
||||
|
||||
.. _Google Python Style Guide: http://google.github.io/styleguide/pyguide.html
|
||||
.. _PEP 484: https://www.python.org/dev/peps/pep-0484/
|
||||
"""
|
||||
```
|
||||
|
||||
### Cross-referencing Rules (Use Liberally)
|
||||
Prefer semantic roles over plain backticks for important code entities:
|
||||
- **Functions**: `:py:func:`package.module.fn`` (use `~` to shorten)
|
||||
- **Classes**: `:py:class:`package.module.Class``; methods: `:py:meth:`~package.module.Class.method``
|
||||
- **Modules**: `:py:mod:`package.module``; **Attributes**: `:py:attr:`~package.module.Class.attr``
|
||||
- Use `:ref:` for intra-doc labels (sections/figures/tables)
|
||||
- Use `:doc:` to link other documents by path
|
||||
- Use `:any:` when unsure of the role; Sphinx will try to resolve it
|
||||
- Prefer cross-references over repeating type info or behavior
|
||||
|
||||
### Notes, Warnings, and See Also
|
||||
Keep these sparse; only when clarity is improved. Prefer Google sections (`Notes:`, `Warnings:`, `See Also:`). If a visual callout is needed, you may use admonitions (`.. note::`, `.. warning::`) inside the docstring; keep content concise.
|
||||
|
||||
Example:
|
||||
```python
|
||||
Notes:
|
||||
Uses :py:mod:`asyncio` for scheduling.
|
||||
|
||||
See Also:
|
||||
:py:class:`~package.Scheduler`, :py:func:`~package.schedule_task`
|
||||
```
|
||||
|
||||
### Class and __init__ Rules
|
||||
- **Class docstring**: High-level overview and purpose; avoid argument docs
|
||||
- **__init__ docstring**: Document constructor parameters in `Args:`; follow the same formatting rules as functions
|
||||
|
||||
```python
|
||||
class Cache:
|
||||
"""In-memory cache with time-based eviction.
|
||||
|
||||
See Also:
|
||||
:py:class:`~package.LRUCache`
|
||||
"""
|
||||
|
||||
def __init__(self, ttl: float, capacity: int = 1024):
|
||||
"""Initialize the cache.
|
||||
|
||||
Args:
|
||||
ttl: Time-to-live in seconds.
|
||||
capacity: Max entries. Defaults to ``1024``.
|
||||
"""
|
||||
```
|
||||
|
||||
### Module Docstrings
|
||||
Short overview of module purpose and key public objects. Optionally an `Attributes:` section for significant module-level constants.
|
||||
|
||||
```python
|
||||
"""Tools for model inference and orchestration.
|
||||
|
||||
Attributes:
|
||||
DEFAULT_TIMEOUT (float): Default timeout for network calls.
|
||||
|
||||
Todo:
|
||||
* Add support for streaming outputs.
|
||||
* Enable tracing hooks.
|
||||
"""
|
||||
```
|
||||
|
||||
### Style/Consistency Conventions
|
||||
- Use double backticks for inline code/literals: ``None``, ``"text"``, ``/path``
|
||||
- Use present tense and active voice
|
||||
- Avoid restating obvious types when PEP 484 annotations exist
|
||||
- Prefer short paragraphs and bulleted lists only when they clarify meaning
|
||||
- Keep "Examples" and "References" well-formed and visually separated
|
||||
|
||||
### Quick Templates
|
||||
|
||||
Function (concise):
|
||||
```python
|
||||
def fn(x: int) -> int:
|
||||
"""Return the square of ``x``.
|
||||
|
||||
Examples:
|
||||
>>> fn(3)
|
||||
9
|
||||
"""
|
||||
```
|
||||
|
||||
Function (fuller):
|
||||
```python
|
||||
def load(path: str, *, strict: bool = False) -> dict:
|
||||
"""Load a configuration from ``path``.
|
||||
|
||||
Args:
|
||||
path: File path to load.
|
||||
strict: Validate schema strictly. Defaults to ``False``.
|
||||
|
||||
Returns:
|
||||
Mapping representing the configuration.
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist.
|
||||
|
||||
Examples:
|
||||
Strict load with error handling:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
cfg = load("config.yaml", strict=True)
|
||||
|
||||
References:
|
||||
* `YAML Spec`_
|
||||
|
||||
.. _YAML Spec: https://yaml.org/spec/
|
||||
"""
|
||||
```
|
||||
|
||||
Class + __init__:
|
||||
```python
|
||||
class Runner:
|
||||
"""Execute tasks with configurable concurrency."""
|
||||
|
||||
def __init__(self, max_workers: int = 4):
|
||||
"""Initialize the runner.
|
||||
|
||||
Args:
|
||||
max_workers: Number of worker threads. Defaults to ``4``.
|
||||
"""
|
||||
```
|
||||
|
||||
Generator:
|
||||
```python
|
||||
def items():
|
||||
"""Yield items from the queue.
|
||||
|
||||
Yields:
|
||||
Item type: Next available item.
|
||||
"""
|
||||
```
|
||||
8
Makefile
8
Makefile
@@ -4,6 +4,14 @@
|
||||
# COMMANDS #
|
||||
#################################################################################
|
||||
|
||||
## Delete all compiled and cached files
|
||||
clean:
|
||||
find . -type f -name "*.py[co]" -delete
|
||||
find . -type d -name "__pycache__" -delete
|
||||
rm -rf .pytest_cache
|
||||
rm -rf .ruff_cache
|
||||
rm -rf .benchmarks
|
||||
|
||||
## Format src directory using black
|
||||
format:
|
||||
ruff format src tests
|
||||
|
||||
919
RESTRUCTURING_PLAN.md
Normal file
919
RESTRUCTURING_PLAN.md
Normal file
@@ -0,0 +1,919 @@
|
||||
# ModelForge Repository Restructuring Plan
|
||||
|
||||
**Date**: October 1, 2025
|
||||
**Author**: Analysis based on PyTorch Lightning patterns
|
||||
**Goal**: Reorganize `src/modelhub/` following best practices while keeping `releases/` structure intact
|
||||
|
||||
---
|
||||
|
||||
## Executive Summary
|
||||
|
||||
This plan addresses three critical issues in the current structure:
|
||||
1. **Broken imports**: RF3 uses `from metrics.base import Metric` but base classes are in `src/modelhub/`
|
||||
2. **CLI entry point mismatch**: `pyproject.toml` points to non-existent `modelhub.cli:app`
|
||||
3. **Package installation confusion**: Only `src/modelhub` is packaged, but RF3 code is in `releases/`
|
||||
4. **Unclear organization**: Mixed base classes and implementations without clear semantic separation
|
||||
|
||||
**Core Principle**: Follow PyTorch Lightning's pattern - base classes live WITH their implementations, no separate `contrib/` directory.
|
||||
|
||||
---
|
||||
|
||||
## Current Structure Analysis
|
||||
|
||||
### What Exists Now
|
||||
```
|
||||
modelhub_latent/
|
||||
├── pyproject.toml # name = "rf3" (wrong!)
|
||||
├── src/
|
||||
│ └── modelhub/ # Only this is packaged
|
||||
│ ├── callbacks/
|
||||
│ │ ├── base.py # BaseCallback
|
||||
│ │ ├── health_logging.py # Implementation
|
||||
│ │ ├── timing_logging.py # Implementation
|
||||
│ │ └── train_logging.py # Implementation
|
||||
│ ├── metrics/
|
||||
│ │ ├── base.py # Metric + MetricManager
|
||||
│ │ ├── chiral.py # Implementation
|
||||
│ │ └── rasa.py # Implementation
|
||||
│ ├── utils/
|
||||
│ ├── trainers/
|
||||
│ └── hydra/
|
||||
├── releases/
|
||||
│ └── rf3/ # NOT packaged!
|
||||
│ ├── src/rf3/
|
||||
│ │ ├── callbacks/
|
||||
│ │ ├── metrics/
|
||||
│ │ ├── model/
|
||||
│ │ └── cli.py
|
||||
│ ├── configs/
|
||||
│ └── tests/
|
||||
└── tests/ # Shared tests
|
||||
```
|
||||
|
||||
### Critical Problems
|
||||
|
||||
#### Problem 1: Import Path Mismatch
|
||||
**RF3 code uses**:
|
||||
```python
|
||||
from metrics.base import Metric # Expects local resolution
|
||||
from callbacks.base import BaseCallback # Expects local resolution
|
||||
```
|
||||
|
||||
**Reality**:
|
||||
- Base classes are in `src/modelhub/metrics/base.py`
|
||||
- This only works via sys.path manipulation or broken imports
|
||||
|
||||
**Should be**:
|
||||
```python
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
```
|
||||
|
||||
#### Problem 2: CLI Entry Point
|
||||
**Current `pyproject.toml`**:
|
||||
```toml
|
||||
[project.scripts]
|
||||
rf3 = "modelhub.cli:app"
|
||||
```
|
||||
|
||||
**Reality**: There is NO `src/modelhub/cli.py` file!
|
||||
|
||||
**Actual CLI**: Located at `releases/rf3/src/rf3/cli.py`
|
||||
|
||||
#### Problem 3: Package Name Confusion
|
||||
**Current**: `name = "rf3"` in pyproject.toml
|
||||
|
||||
**Problem**:
|
||||
- Repository is called "modelforge"
|
||||
- Users would `pip install rf3` (gets only base framework)
|
||||
- Can't do `pip install modelforge[rf3]` for selective installation
|
||||
|
||||
#### Problem 4: Packaging Scope
|
||||
**Current**: `packages = ["src/modelhub"]`
|
||||
|
||||
**Problem**: RF3 code in `releases/rf3/src/rf3/` is NOT included in the package!
|
||||
|
||||
---
|
||||
|
||||
## Design Goals
|
||||
|
||||
### User Experience Goals
|
||||
1. `pip install modelforge` → Get base framework only
|
||||
2. `pip install modelforge[rf3]` → Get base + RF3
|
||||
3. `pip install modelforge[all]` → Get all models (future-proof)
|
||||
4. Development: `pip install -e .[rf3,dev]`
|
||||
|
||||
### Code Organization Goals
|
||||
1. Follow PyTorch Lightning patterns (familiar to ML researchers)
|
||||
2. Clear distinction between base classes and implementations
|
||||
3. Proper namespacing for imports
|
||||
4. Maintain `releases/` structure for development workflow
|
||||
|
||||
### Import Pattern Goals
|
||||
```python
|
||||
# Clear, unambiguous imports
|
||||
from modelforge.callbacks.callback import BaseCallback
|
||||
from modelforge.metrics.metric import Metric
|
||||
from modelforge.utils.ddp import RankedLogger
|
||||
|
||||
# RF3-specific imports
|
||||
from modelforge.models.rf3.model import RF3
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## PyTorch Lightning Pattern Analysis
|
||||
|
||||
### Lightning's Structure (Reference)
|
||||
```
|
||||
lightning/pytorch/
|
||||
├── core/ # Primary model abstractions
|
||||
│ ├── module.py # LightningModule (what users inherit from)
|
||||
│ ├── datamodule.py # LightningDataModule
|
||||
│ └── hooks.py
|
||||
├── callbacks/ # Base + implementations together
|
||||
│ ├── callback.py # Callback base class
|
||||
│ ├── model_checkpoint.py # ModelCheckpoint implementation
|
||||
│ ├── early_stopping.py # EarlyStopping implementation
|
||||
│ └── __init__.py # Exports everything
|
||||
├── loggers/ # Base + implementations together
|
||||
│ ├── logger.py # Logger base class
|
||||
│ ├── wandb.py # WandbLogger implementation
|
||||
│ └── __init__.py
|
||||
└── utilities/ # Pure utilities
|
||||
```
|
||||
|
||||
### Key Insights from Lightning
|
||||
|
||||
1. **`core/` is for PRIMARY model abstractions** - Things users inherit from to define their model:
|
||||
- `LightningModule` - "Your model IS A LightningModule"
|
||||
- `LightningDataModule` - "Your data IS A LightningDataModule"
|
||||
|
||||
2. **Other directories contain base + implementations together**:
|
||||
- `callbacks/callback.py` (base) + `callbacks/model_checkpoint.py` (impl)
|
||||
- `loggers/logger.py` (base) + `loggers/wandb.py` (impl)
|
||||
|
||||
3. **NO `contrib/` directory** - That's a Django pattern, not Lightning!
|
||||
|
||||
4. **`__init__.py` exports everything** for clean imports
|
||||
|
||||
### When to Use `core/`
|
||||
|
||||
**Use `core/` IF**:
|
||||
- You have a primary model class users inherit from (like `LightningModule`)
|
||||
- These classes define "what you ARE building" (semantic role)
|
||||
|
||||
**Don't use `core/` IF**:
|
||||
- All your base classes are for plugins/extensions (callbacks, metrics)
|
||||
- You don't have a central model abstraction
|
||||
|
||||
**For ModelForge**: Since there's no primary "ModelBase" class, **we don't need `core/`**.
|
||||
|
||||
### Re-evaluating `inference_engines/`
|
||||
|
||||
**Current situation**:
|
||||
- `InferenceEngine` is a minimal ABC with only 2 abstract methods (`__init__`, `eval`)
|
||||
- Only RF3 uses it (no other models yet)
|
||||
- It's a very thin abstraction that provides minimal value
|
||||
|
||||
**Question**: Do we need this abstraction at all?
|
||||
|
||||
**Arguments for keeping it**:
|
||||
- Future-proofs for when other models are added
|
||||
- Provides a consistent interface pattern
|
||||
|
||||
**Arguments for removing it**:
|
||||
- YAGNI (You Aren't Gonna Need It) - only one implementation exists
|
||||
- Adds indirection without clear benefit
|
||||
- The abstraction is so minimal it doesn't enforce meaningful constraints
|
||||
- Each model's inference is likely to be sufficiently different that a shared interface adds little value
|
||||
|
||||
**Recommendation**: **Remove the `inference_engines/` directory entirely from `src/modelhub/`**
|
||||
- RF3's inference engine can live solely in `models/rf3/src/rf3/inference_engines/rf3.py`
|
||||
- Remove the ABC inheritance - just make it a standalone class
|
||||
- When/if a second model is added, we can evaluate whether a shared abstraction makes sense based on actual commonalities
|
||||
- This follows YAGNI and keeps the code simpler
|
||||
|
||||
---
|
||||
|
||||
## Recommended Solution
|
||||
|
||||
### Two-Part Restructuring
|
||||
|
||||
#### Part 1: Fix `src/modelhub/` Structure (Simple, Lightning-style)
|
||||
|
||||
```
|
||||
src/
|
||||
└── modelhub/ (rename to modelforge?)
|
||||
├── __init__.py # Export key base classes
|
||||
│
|
||||
├── callbacks/
|
||||
│ ├── __init__.py # Export Callback + all implementations
|
||||
│ ├── callback.py # BaseCallback (RENAMED from base.py)
|
||||
│ ├── health_logging.py # HealthLoggingCallback
|
||||
│ ├── timing_logging.py # TimingLoggingCallback
|
||||
│ └── train_logging.py # TrainLoggingCallback
|
||||
│
|
||||
├── metrics/
|
||||
│ ├── __init__.py # Export Metric + all implementations
|
||||
│ ├── metric.py # Metric, MetricManager (RENAMED from base.py)
|
||||
│ ├── chiral.py # ChiralMetric
|
||||
│ └── rasa.py # RASAMetric
|
||||
│
|
||||
├── trainers/
|
||||
│ ├── __init__.py
|
||||
│ └── fabric.py # Fabric trainer wrapper
|
||||
│
|
||||
├── utils/ # Pure utilities (no base classes)
|
||||
│ ├── ddp.py
|
||||
│ ├── logging.py
|
||||
│ ├── weights.py
|
||||
│ ├── instantiators.py
|
||||
│ └── torch.py
|
||||
│
|
||||
└── hydra/ # Hydra utilities
|
||||
├── __init__.py
|
||||
└── resolvers.py
|
||||
```
|
||||
|
||||
**Key changes**:
|
||||
1. Rename `base.py` files to match their content:
|
||||
- `callbacks/base.py` → `callbacks/callback.py`
|
||||
- `metrics/base.py` → `metrics/metric.py`
|
||||
2. **REMOVE** `inference_engines/` entirely (not needed with only one model)
|
||||
3. Keep implementations WITH base classes (Lightning pattern)
|
||||
4. No `core/` directory (not needed without primary model abstraction)
|
||||
5. Proper `__init__.py` exports
|
||||
|
||||
#### Part 2: Integrate RF3 into Main Package (For pip install)
|
||||
|
||||
**Decision Point**: Choose ONE of these approaches:
|
||||
|
||||
##### Option A: Move RF3 into src/modelhub/ (Monorepo style)
|
||||
```
|
||||
src/
|
||||
└── modelhub/
|
||||
├── callbacks/, metrics/, utils/ (as above)
|
||||
└── models/ # NEW
|
||||
└── rf3/ # Moved from releases/rf3/src/rf3/
|
||||
├── __init__.py
|
||||
├── model/
|
||||
├── data/
|
||||
├── metrics/ # RF3-specific metrics
|
||||
├── callbacks/ # RF3-specific callbacks
|
||||
├── inference_engines/
|
||||
└── cli.py
|
||||
```
|
||||
|
||||
**Pros**:
|
||||
- Simple pip install: `pip install modelforge[rf3]`
|
||||
- Single package, single version
|
||||
- Easy code sharing between models
|
||||
|
||||
**Cons**:
|
||||
- Changes development workflow (no separate `releases/` for development)
|
||||
- RF3 code always in source tree even if not installed
|
||||
|
||||
##### Option B: Keep releases/ separate (Development friendly)
|
||||
```
|
||||
modelhub_latent/
|
||||
├── src/modelhub/ # Base framework only
|
||||
├── releases/
|
||||
│ └── rf3/ # Stays as is for development
|
||||
│ └── src/rf3/
|
||||
└── pyproject.toml # Use optional dependencies
|
||||
```
|
||||
|
||||
**Then in `pyproject.toml`**:
|
||||
```toml
|
||||
[project.optional-dependencies]
|
||||
rf3 = [
|
||||
"modelforge-rf3 @ file:///${PROJECT_ROOT}/releases/rf3",
|
||||
# RF3-specific deps
|
||||
]
|
||||
```
|
||||
|
||||
**Pros**:
|
||||
- Keeps development workflow unchanged
|
||||
- Clear separation for releases
|
||||
|
||||
**Cons**:
|
||||
- More complex build setup
|
||||
- Requires editable installs for development
|
||||
- Each release needs own `pyproject.toml`
|
||||
|
||||
##### Option C: Hybrid - Move RF3 but keep releases/ for development (RECOMMENDED)
|
||||
```
|
||||
# For development: work in releases/rf3/
|
||||
releases/rf3/
|
||||
├── src/rf3/ # Development happens here
|
||||
├── configs/ # RF3 configs here
|
||||
└── tests/ # RF3 tests here
|
||||
|
||||
# For installation: symlink or copy during build
|
||||
src/modelhub/models/rf3/ → symlink to releases/rf3/src/rf3/
|
||||
|
||||
# Or use build hooks to include releases/ in package
|
||||
```
|
||||
|
||||
**Implementation**: Use hatch build hooks to include `releases/rf3/src/rf3/` as `src/modelhub/models/rf3/`
|
||||
|
||||
---
|
||||
|
||||
## Detailed Implementation Plan
|
||||
|
||||
### Phase 1: Rename Package (Breaking Change Decision)
|
||||
|
||||
**Decision needed**: Keep `modelhub` or rename to `modelforge`?
|
||||
|
||||
**Arguments for `modelforge`**:
|
||||
- Matches repository name
|
||||
- More descriptive (framework for building models)
|
||||
- Enables `pip install modelforge[rf3]`
|
||||
|
||||
**Arguments for `modelhub`**:
|
||||
- Less renaming work
|
||||
- Existing code already uses it
|
||||
|
||||
**Recommendation**: Rename to `modelforge` for clarity and marketing.
|
||||
|
||||
### Phase 2: Restructure src/modelhub/ (src/modelforge/)
|
||||
|
||||
#### Step 2.1: Rename Base Class Files
|
||||
```bash
|
||||
# In src/modelhub/
|
||||
git mv callbacks/base.py callbacks/callback.py
|
||||
git mv metrics/base.py metrics/metric.py
|
||||
# inference_engines/base.py could stay or rename to inference_engine.py
|
||||
```
|
||||
|
||||
#### Step 2.2: Update Imports in Base Files
|
||||
**In `src/modelhub/callbacks/callback.py`**:
|
||||
```python
|
||||
# Update any internal imports if needed
|
||||
# Mainly just ensure docstrings reference correct module names
|
||||
```
|
||||
|
||||
**In `src/modelhub/metrics/metric.py`**:
|
||||
```python
|
||||
# Update imports and docstrings
|
||||
```
|
||||
|
||||
#### Step 2.3: Create Proper __init__.py Files
|
||||
|
||||
**`src/modelhub/callbacks/__init__.py`**:
|
||||
```python
|
||||
"""Callbacks for training customization.
|
||||
|
||||
This module provides both the base callback class and common callback implementations.
|
||||
"""
|
||||
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.callbacks.health_logging import HealthLoggingCallback
|
||||
from modelhub.callbacks.timing_logging import TimingLoggingCallback
|
||||
from modelhub.callbacks.train_logging import TrainLoggingCallback
|
||||
|
||||
__all__ = [
|
||||
"BaseCallback",
|
||||
"HealthLoggingCallback",
|
||||
"TimingLoggingCallback",
|
||||
"TrainLoggingCallback",
|
||||
]
|
||||
```
|
||||
|
||||
**`src/modelhub/metrics/__init__.py`**:
|
||||
```python
|
||||
"""Metrics for model evaluation.
|
||||
|
||||
This module provides the base metric framework and common metric implementations.
|
||||
"""
|
||||
|
||||
from modelhub.metrics.metric import Metric, MetricManager, instantiate_metric_manager
|
||||
from modelhub.metrics.chiral import ChiralMetric
|
||||
from modelhub.metrics.rasa import RASAMetric
|
||||
|
||||
__all__ = [
|
||||
"Metric",
|
||||
"MetricManager",
|
||||
"instantiate_metric_manager",
|
||||
"ChiralMetric",
|
||||
"RASAMetric",
|
||||
]
|
||||
```
|
||||
|
||||
**`src/modelhub/__init__.py`**:
|
||||
```python
|
||||
"""ModelForge: Open-source framework for biomolecular modeling.
|
||||
|
||||
This package provides the base framework for building and training biomolecular models.
|
||||
"""
|
||||
|
||||
# Export key base classes at top level
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.metrics.metric import Metric, MetricManager
|
||||
|
||||
# Version
|
||||
from modelhub.version import __version__
|
||||
|
||||
__all__ = [
|
||||
"BaseCallback",
|
||||
"Metric",
|
||||
"MetricManager",
|
||||
"__version__",
|
||||
]
|
||||
```
|
||||
|
||||
### Phase 3: Fix RF3 Imports
|
||||
|
||||
#### Step 3.1: Update All RF3 Import Statements
|
||||
|
||||
**Find all broken imports**:
|
||||
```bash
|
||||
cd releases/rf3/
|
||||
grep -r "from metrics.base import" src/
|
||||
grep -r "from callbacks.base import" src/
|
||||
grep -r "from inference_engines.base import" src/
|
||||
```
|
||||
|
||||
**Replace with**:
|
||||
```python
|
||||
# OLD (broken)
|
||||
from metrics.base import Metric
|
||||
from callbacks.base import BaseCallback
|
||||
|
||||
# NEW (correct)
|
||||
from modelhub.metrics.metric import Metric, MetricManager
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
from modelhub.inference_engines.base import InferenceEngine
|
||||
```
|
||||
|
||||
**Automated replacement**:
|
||||
```bash
|
||||
# In releases/rf3/src/
|
||||
find . -name "*.py" -exec sed -i 's/from metrics\.base import/from modelhub.metrics.metric import/g' {} +
|
||||
find . -name "*.py" -exec sed -i 's/from callbacks\.base import/from modelhub.callbacks.callback import/g' {} +
|
||||
find . -name "*.py" -exec sed -i 's/from inference_engines\.base import/from modelhub.inference_engines.base import/g' {} +
|
||||
```
|
||||
|
||||
#### Step 3.2: Update RF3 Utility Imports
|
||||
|
||||
```python
|
||||
# Also update imports of shared utilities
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from modelhub.utils.weights import load_checkpoint
|
||||
from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks
|
||||
```
|
||||
|
||||
### Phase 4: Fix Build Configuration
|
||||
|
||||
#### Step 4.1: Update pyproject.toml
|
||||
|
||||
**Current**:
|
||||
```toml
|
||||
[project]
|
||||
name = "rf3"
|
||||
[project.scripts]
|
||||
rf3 = "modelhub.cli:app"
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/modelhub"]
|
||||
```
|
||||
|
||||
**New**:
|
||||
```toml
|
||||
[project]
|
||||
name = "modelforge"
|
||||
description = "Open-source framework for biomolecular structure prediction and design"
|
||||
|
||||
# Minimal dependencies for base framework
|
||||
dependencies = [
|
||||
"torch>=2.2.0,<3",
|
||||
"lightning>=2.4.0,<2.5",
|
||||
"hydra-core>=1.3.0,<1.4",
|
||||
"rootutils>=1.0.7,<1.1",
|
||||
"environs>=11.0.0,<12",
|
||||
"wandb>=0.15.10,<1",
|
||||
"rich>=13.9.4,<14",
|
||||
"jaxtyping>=0.2.17,<1",
|
||||
"beartype>=0.18.0,<1",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
# RF3 model with its specific dependencies
|
||||
rf3 = [
|
||||
"atomworks==1.0.2",
|
||||
"einops>=0.8.0,<1",
|
||||
"einx>=0.1.0,<1",
|
||||
"opt_einsum>=3.4.0,<4",
|
||||
"dm-tree>=0.1.6,<1",
|
||||
"cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'",
|
||||
"cuequivariance_torch>=0.5.0; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
# All models
|
||||
all = [
|
||||
"modelforge[rf3]",
|
||||
]
|
||||
|
||||
# Development
|
||||
dev = [
|
||||
"ruff==0.8.3",
|
||||
"pytest>=8.2.0,<9",
|
||||
# ... other dev deps
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
# Main CLI - needs to be created or use RF3 CLI directly
|
||||
rf3 = "rf3.cli:app" # Direct to RF3 for now
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
# Include both modelhub and rf3 (if using monorepo approach)
|
||||
packages = ["src/modelhub"]
|
||||
# OR include releases/rf3/src/rf3 via custom build hook
|
||||
|
||||
# For development: force-include releases
|
||||
[tool.hatch.build.targets.wheel.force-include]
|
||||
"releases/rf3/src/rf3" = "modelhub/models/rf3" # If using monorepo approach
|
||||
```
|
||||
|
||||
#### Step 4.2: Create Main CLI (If Needed)
|
||||
|
||||
**Option 1**: Keep `rf3` CLI pointing directly to RF3
|
||||
```toml
|
||||
[project.scripts]
|
||||
rf3 = "rf3.cli:app"
|
||||
```
|
||||
|
||||
**Option 2**: Create dispatcher CLI (future-proof)
|
||||
|
||||
**`src/modelhub/cli.py`**:
|
||||
```python
|
||||
"""Main ModelForge CLI."""
|
||||
|
||||
import typer
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
# Import RF3 CLI if available
|
||||
try:
|
||||
from rf3.cli import app as rf3_app
|
||||
# Expose RF3 commands at top level
|
||||
for command_name, command in rf3_app.registered_commands:
|
||||
app.command(name=command_name)(command.callback)
|
||||
except ImportError:
|
||||
pass # RF3 not installed
|
||||
|
||||
if __name__ == "__main__":
|
||||
app()
|
||||
```
|
||||
|
||||
Then:
|
||||
```toml
|
||||
[project.scripts]
|
||||
modelforge = "modelhub.cli:app"
|
||||
rf3 = "rf3.cli:app" # Keep for backward compatibility
|
||||
```
|
||||
|
||||
### Phase 5: Handle RF3 Package Integration
|
||||
|
||||
**Decision needed**: Choose approach from Phase 2, Part 2.
|
||||
|
||||
**Recommended: Option C (Hybrid)**
|
||||
|
||||
Use hatch build hooks to include RF3 from `releases/`:
|
||||
|
||||
**`hatch_build.py`** (in project root):
|
||||
```python
|
||||
"""Custom hatch build hook to include RF3 from releases/."""
|
||||
|
||||
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
|
||||
|
||||
class CustomBuildHook(BuildHookInterface):
|
||||
def initialize(self, version, build_data):
|
||||
"""Include RF3 from releases/ directory."""
|
||||
if self.target_name == "wheel":
|
||||
# Add releases/rf3/src/rf3 to the wheel as modelhub/models/rf3
|
||||
build_data["force_include"] = {
|
||||
"releases/rf3/src/rf3": "modelhub/models/rf3"
|
||||
}
|
||||
```
|
||||
|
||||
**In `pyproject.toml`**:
|
||||
```toml
|
||||
[tool.hatch.build.hooks.custom]
|
||||
path = "hatch_build.py"
|
||||
```
|
||||
|
||||
This way:
|
||||
- Development happens in `releases/rf3/`
|
||||
- Installation includes RF3 in the package
|
||||
- Users can `pip install modelforge[rf3]`
|
||||
|
||||
### Phase 6: Update Tests
|
||||
|
||||
#### Step 6.1: Update Test Imports
|
||||
|
||||
**In `releases/rf3/tests/`**:
|
||||
```python
|
||||
# Update conftest.py and test files
|
||||
from modelhub.metrics.metric import Metric
|
||||
from modelhub.callbacks.callback import BaseCallback
|
||||
```
|
||||
|
||||
**In root `tests/`**:
|
||||
```python
|
||||
# These already test shared utilities
|
||||
from modelhub.utils.weights import load_checkpoint
|
||||
from modelhub.metrics.metric import MetricManager
|
||||
```
|
||||
|
||||
#### Step 6.2: Verify Test Execution
|
||||
|
||||
```bash
|
||||
# Test shared utilities
|
||||
pytest tests/
|
||||
|
||||
# Test RF3 (from releases/rf3/)
|
||||
cd releases/rf3/
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
### Phase 7: Update Documentation
|
||||
|
||||
#### Step 7.1: Update CLAUDE.md
|
||||
|
||||
Update import examples and package structure documentation.
|
||||
|
||||
#### Step 7.2: Update README.md
|
||||
|
||||
```markdown
|
||||
# ModelForge
|
||||
|
||||
## Installation
|
||||
|
||||
```bash
|
||||
# Base framework only
|
||||
pip install modelforge
|
||||
|
||||
# With RF3
|
||||
pip install modelforge[rf3]
|
||||
|
||||
# Development
|
||||
pip install -e .[rf3,dev]
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```python
|
||||
# Import base framework
|
||||
from modelforge.callbacks import BaseCallback
|
||||
from modelforge.metrics import Metric
|
||||
|
||||
# Import RF3
|
||||
from modelforge.models.rf3 import RF3
|
||||
```
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Migration Checklist
|
||||
|
||||
### Pre-Migration
|
||||
- [ ] Backup current code
|
||||
- [ ] Create feature branch: `git checkout -b refactor/restructure-package`
|
||||
- [ ] Document current import patterns for reference
|
||||
|
||||
### Phase 1: Package Rename
|
||||
- [ ] Decision: Keep `modelhub` or rename to `modelforge`?
|
||||
- [ ] Update `pyproject.toml` name field
|
||||
- [ ] Update all imports if renaming
|
||||
|
||||
### Phase 2: Restructure src/
|
||||
- [ ] Rename `callbacks/base.py` → `callbacks/callback.py`
|
||||
- [ ] Rename `metrics/base.py` → `metrics/metric.py`
|
||||
- [ ] Create/update all `__init__.py` files with proper exports
|
||||
- [ ] Update internal imports within src/modelhub/
|
||||
|
||||
### Phase 3: Fix RF3 Imports
|
||||
- [ ] Find all `from metrics.base` imports in releases/rf3/
|
||||
- [ ] Replace with `from modelhub.metrics.metric`
|
||||
- [ ] Find all `from callbacks.base` imports
|
||||
- [ ] Replace with `from modelhub.callbacks.callback`
|
||||
- [ ] Update utility imports to use `modelhub.utils.*`
|
||||
- [ ] Test that RF3 can import all needed modules
|
||||
|
||||
### Phase 4: Fix Build Config
|
||||
- [ ] Update `pyproject.toml` project name
|
||||
- [ ] Split dependencies: core vs rf3-specific
|
||||
- [ ] Add `[project.optional-dependencies]` for rf3
|
||||
- [ ] Fix `[project.scripts]` CLI entry point
|
||||
- [ ] Update `[tool.hatch.build.targets.wheel]` packages
|
||||
- [ ] Add build hooks if using hybrid approach
|
||||
|
||||
### Phase 5: RF3 Integration
|
||||
- [ ] Decide on integration approach (A, B, or C)
|
||||
- [ ] Implement chosen approach
|
||||
- [ ] Test `pip install -e .` (base only)
|
||||
- [ ] Test `pip install -e .[rf3]` (with RF3)
|
||||
|
||||
### Phase 6: Update Tests
|
||||
- [ ] Update test imports
|
||||
- [ ] Run shared tests: `pytest tests/`
|
||||
- [ ] Run RF3 tests: `cd releases/rf3 && pytest tests/`
|
||||
- [ ] Fix any import errors
|
||||
|
||||
### Phase 7: Documentation
|
||||
- [ ] Update CLAUDE.md
|
||||
- [ ] Update README.md
|
||||
- [ ] Update any other documentation
|
||||
- [ ] Add migration notes for contributors
|
||||
|
||||
### Phase 8: Validation
|
||||
- [ ] Clean install test: `pip install -e .`
|
||||
- [ ] RF3 install test: `pip install -e .[rf3]`
|
||||
- [ ] Test CLI: `rf3 fold inputs=...`
|
||||
- [ ] Run full test suite
|
||||
- [ ] Test on clean environment
|
||||
|
||||
---
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
### Test Scenarios
|
||||
|
||||
#### Scenario 1: Base Install Only
|
||||
```bash
|
||||
# Clean environment
|
||||
python -m venv test-env
|
||||
source test-env/bin/activate
|
||||
pip install -e .
|
||||
|
||||
# Should work
|
||||
python -c "from modelhub.callbacks import BaseCallback; print('OK')"
|
||||
python -c "from modelhub.metrics import Metric; print('OK')"
|
||||
python -c "from modelhub.utils.ddp import RankedLogger; print('OK')"
|
||||
|
||||
# Should fail (RF3 not installed)
|
||||
python -c "from modelhub.models.rf3 import RF3" # ImportError expected
|
||||
```
|
||||
|
||||
#### Scenario 2: With RF3
|
||||
```bash
|
||||
# Clean environment
|
||||
python -m venv test-env
|
||||
source test-env/bin/activate
|
||||
pip install -e .[rf3]
|
||||
|
||||
# Should work
|
||||
python -c "from modelhub.models.rf3 import RF3; print('OK')"
|
||||
python -c "from rf3.cli import app; print('OK')"
|
||||
|
||||
# Test CLI
|
||||
rf3 fold inputs='releases/rf3/tests/data/5vht_from_json.json'
|
||||
```
|
||||
|
||||
#### Scenario 3: Development Workflow
|
||||
```bash
|
||||
# Install in development mode with RF3
|
||||
pip install -e .[rf3,dev]
|
||||
|
||||
# Make changes to releases/rf3/src/rf3/model/RF3.py
|
||||
# Changes should be immediately available
|
||||
|
||||
python -c "from modelhub.models.rf3.model import RF3" # Should see changes
|
||||
|
||||
# Run tests
|
||||
pytest releases/rf3/tests/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Risk Assessment
|
||||
|
||||
### High Risk Items
|
||||
1. **Breaking all existing imports** - Requires updating many files
|
||||
- Mitigation: Use automated search/replace, test thoroughly
|
||||
|
||||
2. **Build configuration complexity** - Hatch build hooks can be tricky
|
||||
- Mitigation: Start with simple approach, test installation frequently
|
||||
|
||||
3. **Package name change** - Breaking change for any external users
|
||||
- Mitigation: Coordinate with team, announce breaking change
|
||||
|
||||
### Medium Risk Items
|
||||
1. **CLI entry point changes** - Users may have scripts using old CLI
|
||||
- Mitigation: Keep backward-compatible entry points
|
||||
|
||||
2. **Test imports** - Many test files to update
|
||||
- Mitigation: Automated replacement, run tests incrementally
|
||||
|
||||
### Low Risk Items
|
||||
1. **Documentation updates** - Time-consuming but low risk
|
||||
2. **`__init__.py` exports** - Easy to fix if wrong
|
||||
|
||||
---
|
||||
|
||||
## Rollback Plan
|
||||
|
||||
If migration fails:
|
||||
1. `git checkout main` - Revert to pre-migration state
|
||||
2. Keep feature branch for later attempt
|
||||
3. Document what failed for next iteration
|
||||
|
||||
---
|
||||
|
||||
## Future Considerations
|
||||
|
||||
### Adding New Models
|
||||
Once restructured, adding a new model is straightforward:
|
||||
|
||||
```
|
||||
releases/
|
||||
├── rf3/ # Existing
|
||||
└── proteinmpnn/ # New model
|
||||
├── src/proteinmpnn/
|
||||
├── configs/
|
||||
└── tests/
|
||||
```
|
||||
|
||||
Then in `pyproject.toml`:
|
||||
```toml
|
||||
[project.optional-dependencies]
|
||||
proteinmpnn = [
|
||||
"proteinmpnn-specific-deps",
|
||||
]
|
||||
all = [
|
||||
"modelforge[rf3,proteinmpnn]",
|
||||
]
|
||||
```
|
||||
|
||||
### Splitting Into Separate Packages (Future)
|
||||
If models become too large, can later split:
|
||||
- `modelforge-core` - Base framework
|
||||
- `modelforge-rf3` - RF3 model
|
||||
- `modelforge-proteinmpnn` - ProteinMPNN model
|
||||
|
||||
But this is much more complex and not recommended initially.
|
||||
|
||||
---
|
||||
|
||||
## Questions to Resolve
|
||||
|
||||
1. **Package name**: Keep `modelhub` or rename to `modelforge`?
|
||||
- Recommendation: Rename to `modelforge`
|
||||
|
||||
2. **RF3 integration**: Which approach (A, B, or C)?
|
||||
- Recommendation: Option C (hybrid with build hooks)
|
||||
|
||||
3. **CLI structure**: Single entry point or keep separate?
|
||||
- Recommendation: Keep `rf3` command separate for now
|
||||
|
||||
4. **Base file naming**: Rename `base.py` to match content?
|
||||
- Recommendation: Yes, rename to `callback.py`, `metric.py`
|
||||
|
||||
5. **Version strategy**: Single version or per-model versions?
|
||||
- Recommendation: Single version (simpler)
|
||||
|
||||
---
|
||||
|
||||
## Success Criteria
|
||||
|
||||
- [ ] `pip install modelforge` works
|
||||
- [ ] `pip install modelforge[rf3]` works
|
||||
- [ ] All imports are unambiguous and correct
|
||||
- [ ] All tests pass
|
||||
- [ ] CLI works: `rf3 fold inputs=...`
|
||||
- [ ] Structure matches PyTorch Lightning patterns
|
||||
- [ ] Documentation is updated
|
||||
- [ ] No more `sys.path` manipulation needed
|
||||
|
||||
---
|
||||
|
||||
## Timeline Estimate
|
||||
|
||||
- Phase 1 (Rename decision): 1 hour
|
||||
- Phase 2 (Restructure src/): 2-3 hours
|
||||
- Phase 3 (Fix RF3 imports): 1-2 hours
|
||||
- Phase 4 (Build config): 2-3 hours
|
||||
- Phase 5 (RF3 integration): 2-4 hours
|
||||
- Phase 6 (Update tests): 1-2 hours
|
||||
- Phase 7 (Documentation): 1-2 hours
|
||||
- Phase 8 (Validation): 2-3 hours
|
||||
|
||||
**Total**: 12-20 hours of focused work
|
||||
|
||||
---
|
||||
|
||||
## Conclusion
|
||||
|
||||
This restructuring will:
|
||||
1. ✅ Fix all broken imports
|
||||
2. ✅ Follow industry best practices (Lightning pattern)
|
||||
3. ✅ Enable proper pip installation
|
||||
4. ✅ Support selective model installation
|
||||
5. ✅ Maintain development workflow in `releases/`
|
||||
6. ✅ Provide clear, unambiguous imports
|
||||
7. ✅ Scale to future models
|
||||
|
||||
The key is following PyTorch Lightning's pattern: base classes live WITH their implementations, no separate `contrib/` or overly complex directory nesting.
|
||||
@@ -1,7 +1,7 @@
|
||||
# Inference with RosettaFold3(RF3)
|
||||
|
||||
<div align="center">
|
||||
<img src="../../../docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
|
||||
<img src="./docs/_static/prot_dna.png" alt="Protein-DNA complex prediction" width="400">
|
||||
</div>
|
||||
|
||||
> [!IMPORTANT]
|
||||
@@ -282,7 +282,7 @@ rf3 fold inputs='docs/rf3/examples/7o1r_from_file.cif'
|
||||
```
|
||||
|
||||
<p align="center">
|
||||
<img src="../../../docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
|
||||
<img src="./docs/_static/7o1r_covalent_modification.png" alt="7o1r Covalent Modification" width="25%"/>
|
||||
</p>
|
||||
<p align="center">
|
||||
<em>Figure: `7o1r` structure showing N-glycosylation covalent modification prediction with RF3 and ground truth crystal structure.</em>
|
||||
@@ -4,8 +4,8 @@ from pathlib import Path
|
||||
from atomworks.common import parse_example_id
|
||||
from beartype.typing import Any
|
||||
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.io import (
|
||||
from callbacks.base import BaseCallback
|
||||
from rf3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
dump_structures,
|
||||
dump_trajectories,
|
||||
@@ -7,9 +7,9 @@ from atomworks.ml.utils import nested_dict
|
||||
from beartype.typing import Any, Literal
|
||||
from omegaconf import ListConfig
|
||||
|
||||
from modelhub.callbacks.base import BaseCallback
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from modelhub.utils.logging import (
|
||||
from callbacks.base import BaseCallback
|
||||
from utils.ddp import RankedLogger
|
||||
from utils.logging import (
|
||||
condense_count_columns_of_grouped_df,
|
||||
print_df_as_table,
|
||||
)
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import typer
|
||||
from hydra import compose, initialize_config_dir
|
||||
|
||||
from modelhub.inference import run_inference
|
||||
from rf3.inference import run_inference
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
@@ -20,7 +20,7 @@ from biotite.structure import AtomArray
|
||||
from jaxtyping import Bool, Float, Shaped
|
||||
from torch import Tensor
|
||||
|
||||
from modelhub.utils.torch_utils import assert_no_nans
|
||||
from utils.torch import assert_no_nans
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -6,7 +6,7 @@ from atomworks.ml.transforms._checks import check_atom_array_annotation
|
||||
from atomworks.ml.transforms.crop import compute_local_hash
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from modelhub.data.ground_truth_template import (
|
||||
from rf3.data.ground_truth_template import (
|
||||
FeaturizeNoisedGroundTruthAsTemplateDistogram,
|
||||
TokenGroupNoiseScaleSampler,
|
||||
af3_noise_scale_distribution_wrapped,
|
||||
@@ -99,8 +99,8 @@ from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters
|
||||
from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from modelhub.data.extra_xforms import CheckForNaNsInInputs
|
||||
from modelhub.data.pipeline_utils import (
|
||||
from rf3.data.extra_xforms import CheckForNaNsInInputs
|
||||
from rf3.data.pipeline_utils import (
|
||||
annotate_post_crop_hash,
|
||||
annotate_pre_crop_hash,
|
||||
build_ground_truth_distogram_transform,
|
||||
@@ -2,7 +2,7 @@ import math
|
||||
|
||||
import torch
|
||||
|
||||
from modelhub.flow_matching.rigid_utils import rot_vec_mul
|
||||
from rf3.flow_matching.rigid_utils import rot_vec_mul
|
||||
|
||||
|
||||
def centre(X_L, X_exists_L):
|
||||
@@ -2,8 +2,8 @@ import torch
|
||||
from beartype.typing import Any, Literal
|
||||
from jaxtyping import Float
|
||||
|
||||
from modelhub.data.rotation_augmentation import centre_random_augmentation
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from rf3.data.rotation_augmentation import centre_random_augmentation
|
||||
from utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
@@ -10,7 +10,7 @@ from dotenv import load_dotenv
|
||||
from hydra.utils import instantiate
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from modelhub.utils.logging import suppress_warnings
|
||||
from utils.logging import suppress_warnings
|
||||
|
||||
load_dotenv(override=True)
|
||||
|
||||
@@ -10,23 +10,23 @@ from atomworks.io.transforms.categories import category_to_dict
|
||||
from lightning.fabric import seed_everything
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from modelhub.inference_engines.base import InferenceEngine
|
||||
from modelhub.model.RF3 import ShouldEarlyStopFn
|
||||
from modelhub.utils.datasets import (
|
||||
from inference_engines.base import InferenceEngine
|
||||
from rf3.model.RF3 import ShouldEarlyStopFn
|
||||
from rf3.utils.datasets import (
|
||||
assemble_distributed_inference_loader_from_list_of_paths,
|
||||
)
|
||||
from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from modelhub.utils.inference import (
|
||||
from utils.ddp import RankedLogger, set_accelerator_based_on_availability
|
||||
from rf3.utils.inference import (
|
||||
apply_conformer_and_template_selections,
|
||||
build_file_paths_for_prediction,
|
||||
)
|
||||
from modelhub.utils.io import (
|
||||
from rf3.utils.io import (
|
||||
build_stack_from_atom_array_and_batched_coords,
|
||||
dump_structures,
|
||||
dump_trajectories,
|
||||
)
|
||||
from modelhub.utils.logging import print_config_tree
|
||||
from modelhub.utils.predicted_error import (
|
||||
from utils.logging import print_config_tree
|
||||
from rf3.utils.predicted_error import (
|
||||
annotate_atom_array_b_factor_with_plddt,
|
||||
compile_af3_confidence_outputs,
|
||||
get_mean_atomwise_plddt,
|
||||
@@ -173,17 +173,17 @@ class RF3InferenceEngine(InferenceEngine):
|
||||
"allowed_chain_types_for_conditioning": None, # Avoid random conditioning
|
||||
"protein_msa_dirs": [
|
||||
{
|
||||
"dir": "/projects/msa/hhblits",
|
||||
"dir": "/projects/msa/hhblits",
|
||||
"extension": ".a3m.gz",
|
||||
"directory_depth": 2,
|
||||
},
|
||||
{
|
||||
"dir": "/projects/msa/mmseqs_gpu",
|
||||
"dir": "/projects/msa/mmseqs_gpu",
|
||||
"extension": ".a3m.gz",
|
||||
"directory_depth": 2,
|
||||
},
|
||||
{
|
||||
"dir": "/projects/msa/lab",
|
||||
"dir": "/projects/msa/lab",
|
||||
"extension": ".a3m.gz",
|
||||
"directory_depth": 2,
|
||||
},
|
||||
@@ -2,14 +2,14 @@ import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import spearmanr
|
||||
|
||||
from modelhub.chemical import NFRAMES, NHEAVY, frame_indices
|
||||
from rf3.chemical import NFRAMES, NHEAVY, frame_indices
|
||||
|
||||
# TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP
|
||||
from modelhub.metrics.metric_utils import (
|
||||
from rf3.metrics.metric_utils import (
|
||||
compute_mean_over_subsampled_pairs,
|
||||
unbin_logits,
|
||||
)
|
||||
from modelhub.utils.frames import (
|
||||
from rf3.utils.frames import (
|
||||
get_frames,
|
||||
mask_unresolved_frames_batched,
|
||||
rigid_from_3_points,
|
||||
@@ -3,8 +3,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelhub.alignment import weighted_rigid_align
|
||||
from modelhub.training.checkpoint import activation_checkpointing
|
||||
from rf3.alignment import weighted_rigid_align
|
||||
from rf3.training.checkpoint import activation_checkpointing
|
||||
|
||||
|
||||
# resolve residue-level symmetries in native vs pred
|
||||
@@ -4,7 +4,7 @@ from typing import Any
|
||||
import torch
|
||||
from biotite.structure import AtomArrayStack
|
||||
|
||||
from modelhub.metrics.base import Metric
|
||||
from metrics.base import Metric
|
||||
|
||||
|
||||
class CountClashingChains(Metric):
|
||||
@@ -10,9 +10,9 @@ from biotite.structure import AtomArrayStack
|
||||
from einops import rearrange, repeat
|
||||
from jaxtyping import Bool, Float
|
||||
|
||||
from modelhub.loss.af3_losses import distogram_loss
|
||||
from modelhub.metrics.base import Metric
|
||||
from modelhub.utils.torch_utils import assert_no_nans
|
||||
from rf3.loss.af3_losses import distogram_loss
|
||||
from metrics.base import Metric
|
||||
from utils.torch import assert_no_nans
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -11,8 +11,8 @@ from beartype.typing import Any
|
||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
||||
from jaxtyping import Bool, Float, Int
|
||||
|
||||
from modelhub.metrics.base import Metric
|
||||
from modelhub.utils.ddp import RankedLogger
|
||||
from metrics.base import Metric
|
||||
from utils.ddp import RankedLogger
|
||||
|
||||
ranked_logger = RankedLogger(__name__, rank_zero_only=True)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user