diff --git a/.env b/.env index 18fb77e..4576273 100644 --- a/.env +++ b/.env @@ -6,11 +6,12 @@ # --- Mirrors to RCSB data --- + # The `PDB_MIRROR_PATH` is a path to a local mirror of the PDB database. It's # expected that you use the same saving conventions as the RCSB PDB, which means: # `1a2b` --> /path/to/pdb_mirror/a2/1a2b.cif.gz # To set up a mirror, you can use tha atomworks commandline: `atomworks pdb sync /path/to/mirror` -PDB_MIRROR_PATH= +PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_pdb # The `CCD_MIRROR_PATH` is a path to a local mirror of the CCD database. # It's expected that you use the same saving conventions as the RCSB CCD, which means: @@ -19,10 +20,36 @@ PDB_MIRROR_PATH= # If no mirror is provided, the internal biotite CCD will be used as a fallback. To provide a # custom CCD for a ligand, you can place it in the in the CCD mirror path following the CCDs pattern. # Example: /path/to/ccd_mirror/M/MYLIGAND1/MYLIGAND1.cif -CCD_MIRROR_PATH= +CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2025_07_13_ccd + +# --- Local MSA directories --- +LOCAL_MSA_DIRS=/projects/msa/hhblits,/projects/msa/mmseqs_gpu,/projects/msa/lab,/squash/mgnify_distill_rf3/msas # --- External tools --- + # The `X3DNA_PATH` is a path to the x3dna tool, which is used for DNA structure analysis. # Example: /path/to/x3dna-v2.4 -X3DNA_PATH= \ No newline at end of file +X3DNA_PATH= + +# The `HHFILTER_PATH` is a path to the hhfilter tool from the HH-suite, which is used for +# filtering MSAs to reduce redundancy. +# Example: /path/to/hhsuite/build/bin/hhfilter +HHFILTER_PATH=/net/software/hhsuite/build/bin/hhfilter + +# The `MMSEQS2_PATH` is a path to the mmseqs2 tool, which is used for fast sequence searching. +# Example: /path/to/mmseqs-gpu/bin/mmseqs +MMSEQS2_PATH=/net/software/mmseqs-gpu/bin/mmseqs + +# CollabFold MMseqs2 database paths for GPU and CPU usage. + +# Local access (preferred) +# NOTE: MMseqs2 databases are best stored on local drives of compute nodes for performance +COLABFOLD_LOCAL_DB_PATH_GPU=/local/colabfold/gpu +COLABFOLD_LOCAL_DB_PATH_CPU=/local/databases/colabfold/ + +# Network access (fallback; may cause IO-related issues) +COLABFOLD_NET_DB_PATH_GPU=/net/databases/colabfold/gpu +COLABFOLD_NET_DB_PATH_CPU=/net/databases/colabfold/ + + diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..ce55fbf --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,680 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +### Setup +It is **CRITICAL** that before you run **ANY** commands, you activate the python environment like below. +Otherwise, you will always run into import and package errors. +```bash +# IMPORTANT! ALWAYS ACTIVATE THE PYTHON ENVIRONMENT! +source .venv/bin/activate +``` + +## Coding Practices + +Read CAREFULLY the following coding practices. These are central tenants of the way we code. Whenver you write code, retroactively examine your code and ensure it conforms to the principles oulined below. +(1) Adhere to the do-not-repeat-yourself (DRY) practice. Break out common operations into shared functions or classes. +(2) Prefer functions over classes where possible; the more general the function, the better. Functional programming makes the code more extendable. +(3) Follow the YAGNI principle - "You Ain't Gunna Need It." Don't speculatively build functionality that I do not explicitly ask for. Prioritize code simplicity and brevity above all else. + +## Project Overview + +**ModelForge** is a repository of open-source neural networks for biomolecular structure prediction and design. The flagship model is **RosettaFold3 (RF3)**, a structure prediction network competitive with AlphaFold3. The repository uses a shared training harness and integrates with [AtomWorks](https://github.com/RosettaCommons/atomworks) for biomolecular data processing. + +### Current Status +The repository is undergoing active refactoring. Code is organized into: +- `releases/rf3/`: Stable RF3 model release with complete inference and training + - `src/rf3/`: RF3-specific implementation code + - `configs/`: Hydra configuration files for RF3 + - `tests/`: RF3 test suite +- `src/modelhub/`: Shared utilities and base classes used across all models +- `tests/`: Shared tests at the repository level +- `lib/`: Git submodules including AtomWorks + + +## Key Commands + +### Installation and Setup +```bash +# Clone and install +git clone https://github.com/RosettaCommons/modelforge.git +cd modelforge +uv python install 3.12 +uv venv --python 3.12 +source .venv/bin/activate +uv pip install -e . + +# Download RF3 weights +wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt +``` + +### Development Commands +```bash +# IMPORTANT: Always activate virtual environment first +source .venv/bin/activate + +# Code formatting and linting +make format # Format code using ruff (preferred) +ruff format src tests # Format code directly +ruff check --fix src tests # Lint and fix issues + +# Cleanup +make clean # Delete compiled and cached files +``` + +### Testing +```bash +# IMPORTANT: Always activate virtual environment first +source .venv/bin/activate + +# Run RF3-specific tests (from releases/rf3/) +cd releases/rf3/ +pytest tests/ # All RF3 tests +pytest tests/test_inference_regression.py # Inference regression tests +pytest tests/test_write_confidence.py # Confidence output tests + +# Run shared/root-level tests (from repository root) +cd /path/to/modelhub_latent +pytest tests/ # All shared tests +pytest tests/test_metrics.py # Metric tests +pytest tests/test_weight_loading.py # Weight loading tests +pytest tests/test_torch_utils.py # Torch utility tests + +# Run GPU-dependent tests (requires GPU) +pytest tests/test_inference_regression.py -m gpu + +# Run with verbose output +pytest tests/ -v +``` + +### Inference (RF3) +```bash +# IMPORTANT: Navigate to RF3 release directory +cd releases/rf3/ + +# Basic structure prediction +rf3 fold inputs='tests/data/5vht_from_json.json' + +# With MSA +rf3 fold inputs='../../docs/rf3/examples/3en2_from_json_with_msa.json' + +# Batch processing multiple files +rf3 fold inputs='[file1.cif, file2.json, file3.pdb]' +rf3 fold inputs='path/to/directory' # Process all CIF/PDB/JSON in directory + +# Advanced inference options +rf3 fold inputs='input.json' \ + ckpt_path='/path/to/rf3_latest.pt' \ + out_dir='./predictions' \ + n_recycles=10 \ + diffusion_batch_size=5 \ + num_steps=50 \ + annotate_b_factor_with_plddt=true \ + early_stopping_plddt_threshold=0.5 + +# Templating (fix specific regions during prediction) +rf3 fold inputs='input.cif' \ + template_selection='[A, B/*/1-42, B/*/49-63]' \ + ground_truth_conformer_selection='[C, D]' + +# Alternative: Direct Python invocation (more informative error messages) +python src/rf3/inference.py inputs='tests/data/5vht_from_json.json' +``` + +**Note**: RF3 uses Hydra for configuration, so arguments use `key=value` syntax (not `--key value`). + +**Selection Syntax** (for templating): `CHAIN/RES_NAME/RES_ID/ATOM_NAME` +- Exact: `A/ALA/15/CA` +- Wildcard: `A/*/*/CA` (all CA atoms in chain A) +- Range: `A/*/5-10` (residues 5-10 in chain A) +- Union: `A, B` (chains A and B) + +### Training (RF3) +```bash +# IMPORTANT: Navigate to RF3 release directory +cd releases/rf3/ + +# Train with specific experiment config +python src/rf3/train.py experiment=pretrained/rf3 + +# Override specific parameters +python src/rf3/train.py \ + experiment=pretrained/rf3 \ + trainer.max_steps=100000 \ + seed=42 + +# Resume from checkpoint +python src/rf3/train.py \ + experiment=pretrained/rf3 \ + ckpt_path=/path/to/checkpoint.ckpt + +# Debug mode (quick testing) +python src/rf3/train.py \ + experiment=pretrained/rf3 \ + debug=default +``` + +## Architecture + +### Package Structure +- `releases/rf3/`: Complete RF3 model release (self-contained) + - `src/rf3/`: RF3 implementation package + - `model/`: Neural network architecture (Pairformer, Diffusion Module, auxiliary heads) + - `layers/`: Network components (attention, triangle multiplication, outer product) + - `RF3.py`: Main model class + - `RF3_structure.py`: Diffusion module and structural components + - `data/`: Data pipelines and transformations + - `inference_engines/`: RF3 inference implementations + - `loss/`: Loss functions (diffusion loss, confidence loss, distogram loss) + - `metrics/`: RF3-specific evaluation metrics + - `trainers/`: Training logic using Lightning Fabric + - `training/`: Training utilities (EMA, schedulers, checkpoints) + - `utils/`: RF3-specific utilities + - `callbacks/`: RF3-specific callbacks + - `cli.py`: RF3 CLI entry point + - `inference.py`: Inference script + - `train.py`: Training script + - `validate.py`: Validation script + - `configs/`: Hydra configuration files for RF3 + - `model/`: Model architecture configs + - `trainer/`: Training configs (loss, metrics) + - `datasets/`: Dataset configs + - `inference_engine/`: Inference engine configs + - `callbacks/`, `logger/`, `paths/`, etc. + - `tests/`: RF3 test suite + - `test_inference_regression.py`: Regression tests + - `test_write_confidence.py`: Confidence output tests + - `data/`: Test data and baselines +- `src/modelhub/`: Shared utilities across all models + - `callbacks/`: Shared training callbacks (health logging, timing, base classes) + - `inference_engines/`: Base inference engine interface + - `metrics/`: Shared evaluation metrics (base classes, common metrics) + - `trainers/`: Shared training infrastructure (Fabric wrappers) + - `utils/`: Common utilities (weights, logging, instantiators, DDP, torch utils) + - `hydra/`: Hydra resolvers and utilities +- `tests/`: Repository-level shared tests + - `test_metrics.py`: Shared metric tests + - `test_weight_loading.py`: Weight loading tests + - `test_torch_utils.py`: Torch utility tests +- `lib/`: Git submodules (AtomWorks) +- `docs/`: Documentation and examples + +### Key Concepts +- **RF3 Architecture**: Pairformer (token-level) → Diffusion Module (atom-level) + - Token-level representations: Single (`S`, I×C_s), Pair (`Z`, I×I×C_z) + - Atom-level representations: Single (`Q`, L×C_atom), Pair (`P`, L×L×C_atompair) + - Distogram head for token-level distance predictions +- **Hydra Configuration**: Composable configs with defaults and overrides +- **Lightning Fabric**: Distributed training with DDP support +- **AtomWorks Integration**: Unified data processing for structures, MSAs, templates +- **Input Flexibility**: Supports CIF, PDB, JSON, SMILES, and CCD codes + +### Development Environment +- Python 3.12 required (3.11+ supported) +- Uses `ruff` for linting and formatting (configured in pyproject.toml) +- Testing with `pytest` including GPU-specific tests +- Environment variables: Configure in `.env` file (see `.env` template) + - `PDB_MIRROR_PATH`: Local PDB mirror for training data + - `CCD_MIRROR_PATH`: Chemical Component Dictionary mirror + - `LOCAL_MSA_DIRS`: MSA search directories + - Tool paths: `HHFILTER_PATH`, `MMSEQS2_PATH`, etc. + +### Data Dependencies +The training and inference pipelines expect: +- PDB mirror with mmCIF files in standard RCSB sharding pattern +- CCD mirror for small molecule definitions +- Optional MSA data for protein chains +- Pre-computed metadata as parquet files (for training) + +## RF3 Input Formats + +### JSON Format +```json +{ + "name": "example_name", + "components": [ + { + "seq": "MKTAYIA...", // Protein/NA sequence (supports non-canonical) + "msa_path": "path/to/msa.a3m", // Optional MSA (a3m or fasta) + "chain_id": "A" // Optional chain ID + }, + { + "smiles": "CC(=O)O" // Small molecule via SMILES + }, + { + "ccd_code": "HEM" // Chemical Component Dictionary code + }, + { + "path": "ligand.sdf" // Structure file (SDF/CIF) + } + ], + "bonds": [ // Optional covalent modifications + ["A/ASN/133/ND2", "B/NAG/1/C1"] + ], + "template_selection": ["A/*/1-50"], // Optional token-level templating + "ground_truth_conformer_selection": ["C"] // Optional atom-level templating +} +``` + +### CIF/PDB Files +Standard RCSB format with optional MSA specification in CIF header: +```cif +data_3EN2 +_msa_paths_by_chain_id.A path/to/msa_A.a3m.gz +_msa_paths_by_chain_id.B path/to/msa_B.a3m.gz +``` + +## RF3 Outputs + +Inference produces: +- `{name}_model_{i}.cif.gz`: Predicted structures (gzipped mmCIF, one per diffusion seed) +- `{name}_metrics.csv`: Overall confidence metrics (pTM, ipTM, pLDDT, etc.) +- `{name}.score`: Detailed per-residue confidence scores + +All CIF outputs can be directly opened in PyMol or parsed with AtomWorks. + +## Testing + +RF3 tests are located in `releases/rf3/tests/`, while shared tests are in the root `tests/` directory. + +- **RF3 Regression tests**: Compare inference outputs against frozen baselines + - Location: `releases/rf3/tests/test_inference_regression.py` + - Test data: `releases/rf3/tests/data/` (mini examples and baseline predictions) + - Run from: `releases/rf3/` directory +- **RF3 Confidence tests**: Test confidence output writing + - Location: `releases/rf3/tests/test_write_confidence.py` +- **Shared Unit tests**: Test shared utilities across models + - Location: Root `tests/` directory + - `test_metrics.py`: Shared metrics + - `test_weight_loading.py`: Weight loading utilities + - `test_torch_utils.py`: Torch utility functions +- **GPU tests**: Marked with `@pytest.mark.gpu` decorator +- **Test execution**: Navigate to appropriate directory before running pytest + +## Git Workflow + +- **Main branch**: `trunk` (use for PRs, **not** `main`) +- Current refactoring branch: `refactor/rf3-lab` +- The repository is actively being refactored; expect API changes + +## Common Development Patterns + +### Adding New Metrics +1. **RF3-specific**: Implement in `releases/rf3/src/rf3/metrics/` +2. **Shared across models**: Implement in `src/modelhub/metrics/` +3. Inherit from `BaseMetric` in `src/modelhub/metrics/base.py` +4. Register in appropriate config file (`releases/rf3/configs/trainer/metrics/`) +5. Add tests: + - RF3-specific tests in `releases/rf3/tests/` + - Shared tests in root `tests/test_metrics.py` + +### Adding New Callbacks +1. **RF3-specific**: Implement in `releases/rf3/src/rf3/callbacks/` +2. **Shared across models**: Implement in `src/modelhub/callbacks/` +3. Inherit from `BaseCallback` in `src/modelhub/callbacks/base.py` +4. Register in `releases/rf3/configs/callbacks/` +5. Hook into training loop via callback methods (`on_train_batch_end`, etc.) + +### Modifying Model Architecture +1. Edit modules in `releases/rf3/src/rf3/model/` +2. Update corresponding configs in `releases/rf3/configs/model/` +3. Verify with regression tests from `releases/rf3/`: + ```bash + cd releases/rf3/ + pytest tests/test_inference_regression.py + ``` +4. Check weight loading from root: + ```bash + pytest tests/test_weight_loading.py + ``` + +### Working with Hydra Configs +```bash +# Override nested config values +python script.py model.c_s=512 model.c_z=256 + +# Override list items (must quote) +python script.py inputs='[file1.cif, file2.json]' + +# Change config group defaults +python script.py inference_engine=rf3 trainer=ddp + +# Compose multiple configs +defaults: + - trainer: rf3 + - model: rf3 + - datasets: pdb_and_distillation + - _self_ +``` + +## Performance Tips + +- **Early stopping**: Set `early_stopping_plddt_threshold=0.5` to skip low-confidence predictions (10-20x faster) +- **Batch processing**: Use multiple inputs in single command to amortize startup cost +- **Diffusion steps**: Reduce `num_steps=50` (from default 200) for 2x speedup with minimal quality loss +- **Recycling**: Default `n_recycles=10`; reduce for faster inference + +## Troubleshooting + +### Import Errors +- Ensure virtual environment is activated: `source .venv/bin/activate` +- Check you're in correct directory: + - For RF3 tests/inference/training: `cd releases/rf3/` + - For shared tests: Stay in repository root +- The `rf3` CLI command (via `pyproject.toml`) should work from any directory once installed + +### Missing Data +- Set `PDB_MIRROR_PATH` and `CCD_MIRROR_PATH` in `.env` +- For MSA: Set `LOCAL_MSA_DIRS` or use `raise_if_missing_msa_for_protein_of_length_n=0` to require MSAs + +### CUDA/GPU Issues +- Check GPU availability: `python -c "import torch; print(torch.cuda.is_available())"` +- Set precision: `torch.set_float32_matmul_precision("medium")` + +### Hydra Configuration Errors +- Use `=` not `--` for arguments: `inputs='file.cif'` not `--inputs file.cif` +- Quote lists and strings: `inputs='[file1, file2]'` +- For detailed errors, use Python directly: `python src/modelhub/inference.py ...` + +## Docstring Guidelines + +Follow Google-style docstrings with Sphinx optimization. These comprehensive guidelines ensure consistent, high-quality documentation across the codebase. + +### Primary Goals +- **Concise**: Keep docstrings as short as possible while being clear +- **No redundancy**: Don't repeat function/class names, types when annotated, or obvious behavior +- **Sphinx-first**: Prefer reStructuredText roles/directives that render beautifully +- **Google section headers**: Use standard section names with colons (e.g., `Args:`), not underlined headings + +### When to Include Sections +- **Args**: Include only if non-trivial or non-obvious. Omit types when PEP 484 annotations are present +- **Returns**: + - Omit if the function returns `None` + - Omit if the summary sentence fully describes the return + - Otherwise include, using rST for clarity (including literal blocks when helpful) +- **Yields**: Use instead of Returns for generators +- **Raises**: Include only unusual, explicitly raised exceptions that matter to users +- **Examples**: Strongly encouraged when usage isn't obvious. See "Examples formatting" below +- **References**: Include when citing standards, papers, or external docs; also when adding rST link targets +- **Todo**: Allowed; requires the Sphinx `sphinx.ext.todo` extension +- For classes, put argument documentation in `__init__`, not in the class docstring; the class docstring is a high-level overview + +### General Formatting Rules +- **One-line summary**: Imperative, single sentence, ends with a period +- **Second paragraph**: Optional short elaboration only if genuinely useful +- **Inline code**: Use double backticks for identifiers and literals, e.g., ``"same"`` or ``pathlib.Path`` +- **Cross-references**: Prefer semantic roles (see "Cross-referencing rules" below) +- **Line length**: Wrap naturally; don't force hard 79-char wraps if it harms readability +- **Defaults**: Use "Defaults to X." at end of the parameter description +- **None vs Optional**: Prefer "Defaults to None."; avoid repeating "Optional" if type hints already indicate optionality + +### Section Ordering (Typical) +1. Summary +2. Optional elaboration +3. Args +4. Returns or Yields +5. Raises +6. Notes / Warnings (only if needed) +7. Examples +8. References +9. See Also +10. Todo +11. Attributes (for modules/classes only when helpful; not for `__init__` args) + +### Args Formatting +Use the Google format; types omitted if PEP 484 present. Keep descriptions concise. + +Example with type annotations: +```python +def fn(a: int, path: str | None = None): + """Do X. + + Args: + a: Number of items to process. + path: Optional file path. Defaults to ``None``. + """ +``` + +Without type annotations: +```python +def fn(a, path=None): + """Do X. + + Args: + a (int): Number of items to process. + path (str, optional): Optional file path. Defaults to ``None``. + """ +``` + +### Returns / Yields Formatting +Omit if returning `None` or already clear from the summary. Otherwise, short text; use rST formatting when useful. + +Examples: +```python +def compute() -> dict[str, int]: + """Compute counts. + + Returns: + Mapping of names to counts. + + The ``Returns`` section supports any reStructuredText formatting, + including literal blocks:: + + { + 'param1': 1, + 'param2': 2 + } + """ +``` + +```python +def stream(n: int): + """Yield integers up to ``n``. + + Yields: + int: Next value in ``range(n)``. + """ +``` + +### Raises Formatting +Include only unusual, explicit exceptions that users should anticipate. +```python +Raises: + ValueError: If ``threshold`` is negative. +``` + +### Examples Formatting +- Use the `Examples:` section +- Prefer doctest-style for simple usage with expected output +- Use `.. code-block:: python` for multi-line, setup-heavy, or non-doctest examples +- Precede each code block with a short sentence describing the case +- Keep examples minimal; 1–3 cases is typical. Avoid redundant examples +- Don't mix doctest and code-block in the same example unless clearly justified + +Doctest example: +```python +def square(x: int) -> int: + """Return the square of ``x``. + + Examples: + >>> square(3) + 9 + """ +``` + +Multi-case using code blocks: +```python +def from_selection_str(s: str): + """Create an ``AtomSelection`` from a selection string. + + Examples: + Select CA at chain A, residue 1:: + + AtomSelection.from_selection_str("A/ALA/1/CA") + + Select CB in any chain at any residue:: + + AtomSelection.from_selection_str("*/ALA/*/CB") + """ +``` + +Or with an explicit directive: +```python +Examples: + Basic usage: + + .. code-block:: python + + result = fn("input") + print(result) +``` + +### References Formatting +Use the `References:` section to: +- Cite external resources (standards, papers, docs) +- Define hyperlink targets (label definitions) for clean inline linking in the docstring +- Place external link definitions directly in the `References:` section or right after it + +Examples: +```python +def parse(): + """Parse input. + + References: + * `Google Python Style Guide`_ + * `PEP 484`_ + + .. _Google Python Style Guide: http://google.github.io/styleguide/pyguide.html + .. _PEP 484: https://www.python.org/dev/peps/pep-0484/ + """ +``` + +### Cross-referencing Rules (Use Liberally) +Prefer semantic roles over plain backticks for important code entities: +- **Functions**: `:py:func:`package.module.fn`` (use `~` to shorten) +- **Classes**: `:py:class:`package.module.Class``; methods: `:py:meth:`~package.module.Class.method`` +- **Modules**: `:py:mod:`package.module``; **Attributes**: `:py:attr:`~package.module.Class.attr`` +- Use `:ref:` for intra-doc labels (sections/figures/tables) +- Use `:doc:` to link other documents by path +- Use `:any:` when unsure of the role; Sphinx will try to resolve it +- Prefer cross-references over repeating type info or behavior + +### Notes, Warnings, and See Also +Keep these sparse; only when clarity is improved. Prefer Google sections (`Notes:`, `Warnings:`, `See Also:`). If a visual callout is needed, you may use admonitions (`.. note::`, `.. warning::`) inside the docstring; keep content concise. + +Example: +```python +Notes: + Uses :py:mod:`asyncio` for scheduling. + +See Also: + :py:class:`~package.Scheduler`, :py:func:`~package.schedule_task` +``` + +### Class and __init__ Rules +- **Class docstring**: High-level overview and purpose; avoid argument docs +- **__init__ docstring**: Document constructor parameters in `Args:`; follow the same formatting rules as functions + +```python +class Cache: + """In-memory cache with time-based eviction. + + See Also: + :py:class:`~package.LRUCache` + """ + + def __init__(self, ttl: float, capacity: int = 1024): + """Initialize the cache. + + Args: + ttl: Time-to-live in seconds. + capacity: Max entries. Defaults to ``1024``. + """ +``` + +### Module Docstrings +Short overview of module purpose and key public objects. Optionally an `Attributes:` section for significant module-level constants. + +```python +"""Tools for model inference and orchestration. + +Attributes: + DEFAULT_TIMEOUT (float): Default timeout for network calls. + +Todo: + * Add support for streaming outputs. + * Enable tracing hooks. +""" +``` + +### Style/Consistency Conventions +- Use double backticks for inline code/literals: ``None``, ``"text"``, ``/path`` +- Use present tense and active voice +- Avoid restating obvious types when PEP 484 annotations exist +- Prefer short paragraphs and bulleted lists only when they clarify meaning +- Keep "Examples" and "References" well-formed and visually separated + +### Quick Templates + +Function (concise): +```python +def fn(x: int) -> int: + """Return the square of ``x``. + + Examples: + >>> fn(3) + 9 + """ +``` + +Function (fuller): +```python +def load(path: str, *, strict: bool = False) -> dict: + """Load a configuration from ``path``. + + Args: + path: File path to load. + strict: Validate schema strictly. Defaults to ``False``. + + Returns: + Mapping representing the configuration. + + Raises: + FileNotFoundError: If the file does not exist. + + Examples: + Strict load with error handling: + + .. code-block:: python + + cfg = load("config.yaml", strict=True) + + References: + * `YAML Spec`_ + + .. _YAML Spec: https://yaml.org/spec/ + """ +``` + +Class + __init__: +```python +class Runner: + """Execute tasks with configurable concurrency.""" + + def __init__(self, max_workers: int = 4): + """Initialize the runner. + + Args: + max_workers: Number of worker threads. Defaults to ``4``. + """ +``` + +Generator: +```python +def items(): + """Yield items from the queue. + + Yields: + Item type: Next available item. + """ +``` diff --git a/Makefile b/Makefile index de87d23..7aa13e1 100644 --- a/Makefile +++ b/Makefile @@ -4,6 +4,14 @@ # COMMANDS # ################################################################################# +## Delete all compiled and cached files +clean: + find . -type f -name "*.py[co]" -delete + find . -type d -name "__pycache__" -delete + rm -rf .pytest_cache + rm -rf .ruff_cache + rm -rf .benchmarks + ## Format src directory using black format: ruff format src tests diff --git a/RESTRUCTURING_PLAN.md b/RESTRUCTURING_PLAN.md new file mode 100644 index 0000000..c24668c --- /dev/null +++ b/RESTRUCTURING_PLAN.md @@ -0,0 +1,919 @@ +# ModelForge Repository Restructuring Plan + +**Date**: October 1, 2025 +**Author**: Analysis based on PyTorch Lightning patterns +**Goal**: Reorganize `src/modelhub/` following best practices while keeping `releases/` structure intact + +--- + +## Executive Summary + +This plan addresses three critical issues in the current structure: +1. **Broken imports**: RF3 uses `from metrics.base import Metric` but base classes are in `src/modelhub/` +2. **CLI entry point mismatch**: `pyproject.toml` points to non-existent `modelhub.cli:app` +3. **Package installation confusion**: Only `src/modelhub` is packaged, but RF3 code is in `releases/` +4. **Unclear organization**: Mixed base classes and implementations without clear semantic separation + +**Core Principle**: Follow PyTorch Lightning's pattern - base classes live WITH their implementations, no separate `contrib/` directory. + +--- + +## Current Structure Analysis + +### What Exists Now +``` +modelhub_latent/ +├── pyproject.toml # name = "rf3" (wrong!) +├── src/ +│ └── modelhub/ # Only this is packaged +│ ├── callbacks/ +│ │ ├── base.py # BaseCallback +│ │ ├── health_logging.py # Implementation +│ │ ├── timing_logging.py # Implementation +│ │ └── train_logging.py # Implementation +│ ├── metrics/ +│ │ ├── base.py # Metric + MetricManager +│ │ ├── chiral.py # Implementation +│ │ └── rasa.py # Implementation +│ ├── utils/ +│ ├── trainers/ +│ └── hydra/ +├── releases/ +│ └── rf3/ # NOT packaged! +│ ├── src/rf3/ +│ │ ├── callbacks/ +│ │ ├── metrics/ +│ │ ├── model/ +│ │ └── cli.py +│ ├── configs/ +│ └── tests/ +└── tests/ # Shared tests +``` + +### Critical Problems + +#### Problem 1: Import Path Mismatch +**RF3 code uses**: +```python +from metrics.base import Metric # Expects local resolution +from callbacks.base import BaseCallback # Expects local resolution +``` + +**Reality**: +- Base classes are in `src/modelhub/metrics/base.py` +- This only works via sys.path manipulation or broken imports + +**Should be**: +```python +from modelhub.metrics.metric import Metric +from modelhub.callbacks.callback import BaseCallback +``` + +#### Problem 2: CLI Entry Point +**Current `pyproject.toml`**: +```toml +[project.scripts] +rf3 = "modelhub.cli:app" +``` + +**Reality**: There is NO `src/modelhub/cli.py` file! + +**Actual CLI**: Located at `releases/rf3/src/rf3/cli.py` + +#### Problem 3: Package Name Confusion +**Current**: `name = "rf3"` in pyproject.toml + +**Problem**: +- Repository is called "modelforge" +- Users would `pip install rf3` (gets only base framework) +- Can't do `pip install modelforge[rf3]` for selective installation + +#### Problem 4: Packaging Scope +**Current**: `packages = ["src/modelhub"]` + +**Problem**: RF3 code in `releases/rf3/src/rf3/` is NOT included in the package! + +--- + +## Design Goals + +### User Experience Goals +1. `pip install modelforge` → Get base framework only +2. `pip install modelforge[rf3]` → Get base + RF3 +3. `pip install modelforge[all]` → Get all models (future-proof) +4. Development: `pip install -e .[rf3,dev]` + +### Code Organization Goals +1. Follow PyTorch Lightning patterns (familiar to ML researchers) +2. Clear distinction between base classes and implementations +3. Proper namespacing for imports +4. Maintain `releases/` structure for development workflow + +### Import Pattern Goals +```python +# Clear, unambiguous imports +from modelforge.callbacks.callback import BaseCallback +from modelforge.metrics.metric import Metric +from modelforge.utils.ddp import RankedLogger + +# RF3-specific imports +from modelforge.models.rf3.model import RF3 +``` + +--- + +## PyTorch Lightning Pattern Analysis + +### Lightning's Structure (Reference) +``` +lightning/pytorch/ +├── core/ # Primary model abstractions +│ ├── module.py # LightningModule (what users inherit from) +│ ├── datamodule.py # LightningDataModule +│ └── hooks.py +├── callbacks/ # Base + implementations together +│ ├── callback.py # Callback base class +│ ├── model_checkpoint.py # ModelCheckpoint implementation +│ ├── early_stopping.py # EarlyStopping implementation +│ └── __init__.py # Exports everything +├── loggers/ # Base + implementations together +│ ├── logger.py # Logger base class +│ ├── wandb.py # WandbLogger implementation +│ └── __init__.py +└── utilities/ # Pure utilities +``` + +### Key Insights from Lightning + +1. **`core/` is for PRIMARY model abstractions** - Things users inherit from to define their model: + - `LightningModule` - "Your model IS A LightningModule" + - `LightningDataModule` - "Your data IS A LightningDataModule" + +2. **Other directories contain base + implementations together**: + - `callbacks/callback.py` (base) + `callbacks/model_checkpoint.py` (impl) + - `loggers/logger.py` (base) + `loggers/wandb.py` (impl) + +3. **NO `contrib/` directory** - That's a Django pattern, not Lightning! + +4. **`__init__.py` exports everything** for clean imports + +### When to Use `core/` + +**Use `core/` IF**: +- You have a primary model class users inherit from (like `LightningModule`) +- These classes define "what you ARE building" (semantic role) + +**Don't use `core/` IF**: +- All your base classes are for plugins/extensions (callbacks, metrics) +- You don't have a central model abstraction + +**For ModelForge**: Since there's no primary "ModelBase" class, **we don't need `core/`**. + +### Re-evaluating `inference_engines/` + +**Current situation**: +- `InferenceEngine` is a minimal ABC with only 2 abstract methods (`__init__`, `eval`) +- Only RF3 uses it (no other models yet) +- It's a very thin abstraction that provides minimal value + +**Question**: Do we need this abstraction at all? + +**Arguments for keeping it**: +- Future-proofs for when other models are added +- Provides a consistent interface pattern + +**Arguments for removing it**: +- YAGNI (You Aren't Gonna Need It) - only one implementation exists +- Adds indirection without clear benefit +- The abstraction is so minimal it doesn't enforce meaningful constraints +- Each model's inference is likely to be sufficiently different that a shared interface adds little value + +**Recommendation**: **Remove the `inference_engines/` directory entirely from `src/modelhub/`** +- RF3's inference engine can live solely in `models/rf3/src/rf3/inference_engines/rf3.py` +- Remove the ABC inheritance - just make it a standalone class +- When/if a second model is added, we can evaluate whether a shared abstraction makes sense based on actual commonalities +- This follows YAGNI and keeps the code simpler + +--- + +## Recommended Solution + +### Two-Part Restructuring + +#### Part 1: Fix `src/modelhub/` Structure (Simple, Lightning-style) + +``` +src/ +└── modelhub/ (rename to modelforge?) + ├── __init__.py # Export key base classes + │ + ├── callbacks/ + │ ├── __init__.py # Export Callback + all implementations + │ ├── callback.py # BaseCallback (RENAMED from base.py) + │ ├── health_logging.py # HealthLoggingCallback + │ ├── timing_logging.py # TimingLoggingCallback + │ └── train_logging.py # TrainLoggingCallback + │ + ├── metrics/ + │ ├── __init__.py # Export Metric + all implementations + │ ├── metric.py # Metric, MetricManager (RENAMED from base.py) + │ ├── chiral.py # ChiralMetric + │ └── rasa.py # RASAMetric + │ + ├── trainers/ + │ ├── __init__.py + │ └── fabric.py # Fabric trainer wrapper + │ + ├── utils/ # Pure utilities (no base classes) + │ ├── ddp.py + │ ├── logging.py + │ ├── weights.py + │ ├── instantiators.py + │ └── torch.py + │ + └── hydra/ # Hydra utilities + ├── __init__.py + └── resolvers.py +``` + +**Key changes**: +1. Rename `base.py` files to match their content: + - `callbacks/base.py` → `callbacks/callback.py` + - `metrics/base.py` → `metrics/metric.py` +2. **REMOVE** `inference_engines/` entirely (not needed with only one model) +3. Keep implementations WITH base classes (Lightning pattern) +4. No `core/` directory (not needed without primary model abstraction) +5. Proper `__init__.py` exports + +#### Part 2: Integrate RF3 into Main Package (For pip install) + +**Decision Point**: Choose ONE of these approaches: + +##### Option A: Move RF3 into src/modelhub/ (Monorepo style) +``` +src/ +└── modelhub/ + ├── callbacks/, metrics/, utils/ (as above) + └── models/ # NEW + └── rf3/ # Moved from releases/rf3/src/rf3/ + ├── __init__.py + ├── model/ + ├── data/ + ├── metrics/ # RF3-specific metrics + ├── callbacks/ # RF3-specific callbacks + ├── inference_engines/ + └── cli.py +``` + +**Pros**: +- Simple pip install: `pip install modelforge[rf3]` +- Single package, single version +- Easy code sharing between models + +**Cons**: +- Changes development workflow (no separate `releases/` for development) +- RF3 code always in source tree even if not installed + +##### Option B: Keep releases/ separate (Development friendly) +``` +modelhub_latent/ +├── src/modelhub/ # Base framework only +├── releases/ +│ └── rf3/ # Stays as is for development +│ └── src/rf3/ +└── pyproject.toml # Use optional dependencies +``` + +**Then in `pyproject.toml`**: +```toml +[project.optional-dependencies] +rf3 = [ + "modelforge-rf3 @ file:///${PROJECT_ROOT}/releases/rf3", + # RF3-specific deps +] +``` + +**Pros**: +- Keeps development workflow unchanged +- Clear separation for releases + +**Cons**: +- More complex build setup +- Requires editable installs for development +- Each release needs own `pyproject.toml` + +##### Option C: Hybrid - Move RF3 but keep releases/ for development (RECOMMENDED) +``` +# For development: work in releases/rf3/ +releases/rf3/ +├── src/rf3/ # Development happens here +├── configs/ # RF3 configs here +└── tests/ # RF3 tests here + +# For installation: symlink or copy during build +src/modelhub/models/rf3/ → symlink to releases/rf3/src/rf3/ + +# Or use build hooks to include releases/ in package +``` + +**Implementation**: Use hatch build hooks to include `releases/rf3/src/rf3/` as `src/modelhub/models/rf3/` + +--- + +## Detailed Implementation Plan + +### Phase 1: Rename Package (Breaking Change Decision) + +**Decision needed**: Keep `modelhub` or rename to `modelforge`? + +**Arguments for `modelforge`**: +- Matches repository name +- More descriptive (framework for building models) +- Enables `pip install modelforge[rf3]` + +**Arguments for `modelhub`**: +- Less renaming work +- Existing code already uses it + +**Recommendation**: Rename to `modelforge` for clarity and marketing. + +### Phase 2: Restructure src/modelhub/ (src/modelforge/) + +#### Step 2.1: Rename Base Class Files +```bash +# In src/modelhub/ +git mv callbacks/base.py callbacks/callback.py +git mv metrics/base.py metrics/metric.py +# inference_engines/base.py could stay or rename to inference_engine.py +``` + +#### Step 2.2: Update Imports in Base Files +**In `src/modelhub/callbacks/callback.py`**: +```python +# Update any internal imports if needed +# Mainly just ensure docstrings reference correct module names +``` + +**In `src/modelhub/metrics/metric.py`**: +```python +# Update imports and docstrings +``` + +#### Step 2.3: Create Proper __init__.py Files + +**`src/modelhub/callbacks/__init__.py`**: +```python +"""Callbacks for training customization. + +This module provides both the base callback class and common callback implementations. +""" + +from modelhub.callbacks.callback import BaseCallback +from modelhub.callbacks.health_logging import HealthLoggingCallback +from modelhub.callbacks.timing_logging import TimingLoggingCallback +from modelhub.callbacks.train_logging import TrainLoggingCallback + +__all__ = [ + "BaseCallback", + "HealthLoggingCallback", + "TimingLoggingCallback", + "TrainLoggingCallback", +] +``` + +**`src/modelhub/metrics/__init__.py`**: +```python +"""Metrics for model evaluation. + +This module provides the base metric framework and common metric implementations. +""" + +from modelhub.metrics.metric import Metric, MetricManager, instantiate_metric_manager +from modelhub.metrics.chiral import ChiralMetric +from modelhub.metrics.rasa import RASAMetric + +__all__ = [ + "Metric", + "MetricManager", + "instantiate_metric_manager", + "ChiralMetric", + "RASAMetric", +] +``` + +**`src/modelhub/__init__.py`**: +```python +"""ModelForge: Open-source framework for biomolecular modeling. + +This package provides the base framework for building and training biomolecular models. +""" + +# Export key base classes at top level +from modelhub.callbacks.callback import BaseCallback +from modelhub.metrics.metric import Metric, MetricManager + +# Version +from modelhub.version import __version__ + +__all__ = [ + "BaseCallback", + "Metric", + "MetricManager", + "__version__", +] +``` + +### Phase 3: Fix RF3 Imports + +#### Step 3.1: Update All RF3 Import Statements + +**Find all broken imports**: +```bash +cd releases/rf3/ +grep -r "from metrics.base import" src/ +grep -r "from callbacks.base import" src/ +grep -r "from inference_engines.base import" src/ +``` + +**Replace with**: +```python +# OLD (broken) +from metrics.base import Metric +from callbacks.base import BaseCallback + +# NEW (correct) +from modelhub.metrics.metric import Metric, MetricManager +from modelhub.callbacks.callback import BaseCallback +from modelhub.inference_engines.base import InferenceEngine +``` + +**Automated replacement**: +```bash +# In releases/rf3/src/ +find . -name "*.py" -exec sed -i 's/from metrics\.base import/from modelhub.metrics.metric import/g' {} + +find . -name "*.py" -exec sed -i 's/from callbacks\.base import/from modelhub.callbacks.callback import/g' {} + +find . -name "*.py" -exec sed -i 's/from inference_engines\.base import/from modelhub.inference_engines.base import/g' {} + +``` + +#### Step 3.2: Update RF3 Utility Imports + +```python +# Also update imports of shared utilities +from modelhub.utils.ddp import RankedLogger +from modelhub.utils.logging import suppress_warnings +from modelhub.utils.weights import load_checkpoint +from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks +``` + +### Phase 4: Fix Build Configuration + +#### Step 4.1: Update pyproject.toml + +**Current**: +```toml +[project] +name = "rf3" +[project.scripts] +rf3 = "modelhub.cli:app" +[tool.hatch.build.targets.wheel] +packages = ["src/modelhub"] +``` + +**New**: +```toml +[project] +name = "modelforge" +description = "Open-source framework for biomolecular structure prediction and design" + +# Minimal dependencies for base framework +dependencies = [ + "torch>=2.2.0,<3", + "lightning>=2.4.0,<2.5", + "hydra-core>=1.3.0,<1.4", + "rootutils>=1.0.7,<1.1", + "environs>=11.0.0,<12", + "wandb>=0.15.10,<1", + "rich>=13.9.4,<14", + "jaxtyping>=0.2.17,<1", + "beartype>=0.18.0,<1", +] + +[project.optional-dependencies] +# RF3 model with its specific dependencies +rf3 = [ + "atomworks==1.0.2", + "einops>=0.8.0,<1", + "einx>=0.1.0,<1", + "opt_einsum>=3.4.0,<4", + "dm-tree>=0.1.6,<1", + "cuequivariance_ops_cu12>=0.5.0; sys_platform == 'linux'", + "cuequivariance_ops_torch_cu12>=0.5.0; sys_platform == 'linux'", + "cuequivariance_torch>=0.5.0; sys_platform == 'linux'", +] + +# All models +all = [ + "modelforge[rf3]", +] + +# Development +dev = [ + "ruff==0.8.3", + "pytest>=8.2.0,<9", + # ... other dev deps +] + +[project.scripts] +# Main CLI - needs to be created or use RF3 CLI directly +rf3 = "rf3.cli:app" # Direct to RF3 for now + +[tool.hatch.build.targets.wheel] +# Include both modelhub and rf3 (if using monorepo approach) +packages = ["src/modelhub"] +# OR include releases/rf3/src/rf3 via custom build hook + +# For development: force-include releases +[tool.hatch.build.targets.wheel.force-include] +"releases/rf3/src/rf3" = "modelhub/models/rf3" # If using monorepo approach +``` + +#### Step 4.2: Create Main CLI (If Needed) + +**Option 1**: Keep `rf3` CLI pointing directly to RF3 +```toml +[project.scripts] +rf3 = "rf3.cli:app" +``` + +**Option 2**: Create dispatcher CLI (future-proof) + +**`src/modelhub/cli.py`**: +```python +"""Main ModelForge CLI.""" + +import typer + +app = typer.Typer() + +# Import RF3 CLI if available +try: + from rf3.cli import app as rf3_app + # Expose RF3 commands at top level + for command_name, command in rf3_app.registered_commands: + app.command(name=command_name)(command.callback) +except ImportError: + pass # RF3 not installed + +if __name__ == "__main__": + app() +``` + +Then: +```toml +[project.scripts] +modelforge = "modelhub.cli:app" +rf3 = "rf3.cli:app" # Keep for backward compatibility +``` + +### Phase 5: Handle RF3 Package Integration + +**Decision needed**: Choose approach from Phase 2, Part 2. + +**Recommended: Option C (Hybrid)** + +Use hatch build hooks to include RF3 from `releases/`: + +**`hatch_build.py`** (in project root): +```python +"""Custom hatch build hook to include RF3 from releases/.""" + +from hatchling.builders.hooks.plugin.interface import BuildHookInterface + +class CustomBuildHook(BuildHookInterface): + def initialize(self, version, build_data): + """Include RF3 from releases/ directory.""" + if self.target_name == "wheel": + # Add releases/rf3/src/rf3 to the wheel as modelhub/models/rf3 + build_data["force_include"] = { + "releases/rf3/src/rf3": "modelhub/models/rf3" + } +``` + +**In `pyproject.toml`**: +```toml +[tool.hatch.build.hooks.custom] +path = "hatch_build.py" +``` + +This way: +- Development happens in `releases/rf3/` +- Installation includes RF3 in the package +- Users can `pip install modelforge[rf3]` + +### Phase 6: Update Tests + +#### Step 6.1: Update Test Imports + +**In `releases/rf3/tests/`**: +```python +# Update conftest.py and test files +from modelhub.metrics.metric import Metric +from modelhub.callbacks.callback import BaseCallback +``` + +**In root `tests/`**: +```python +# These already test shared utilities +from modelhub.utils.weights import load_checkpoint +from modelhub.metrics.metric import MetricManager +``` + +#### Step 6.2: Verify Test Execution + +```bash +# Test shared utilities +pytest tests/ + +# Test RF3 (from releases/rf3/) +cd releases/rf3/ +pytest tests/ +``` + +### Phase 7: Update Documentation + +#### Step 7.1: Update CLAUDE.md + +Update import examples and package structure documentation. + +#### Step 7.2: Update README.md + +```markdown +# ModelForge + +## Installation + +```bash +# Base framework only +pip install modelforge + +# With RF3 +pip install modelforge[rf3] + +# Development +pip install -e .[rf3,dev] +``` + +## Usage + +```python +# Import base framework +from modelforge.callbacks import BaseCallback +from modelforge.metrics import Metric + +# Import RF3 +from modelforge.models.rf3 import RF3 +``` +``` + +--- + +## Migration Checklist + +### Pre-Migration +- [ ] Backup current code +- [ ] Create feature branch: `git checkout -b refactor/restructure-package` +- [ ] Document current import patterns for reference + +### Phase 1: Package Rename +- [ ] Decision: Keep `modelhub` or rename to `modelforge`? +- [ ] Update `pyproject.toml` name field +- [ ] Update all imports if renaming + +### Phase 2: Restructure src/ +- [ ] Rename `callbacks/base.py` → `callbacks/callback.py` +- [ ] Rename `metrics/base.py` → `metrics/metric.py` +- [ ] Create/update all `__init__.py` files with proper exports +- [ ] Update internal imports within src/modelhub/ + +### Phase 3: Fix RF3 Imports +- [ ] Find all `from metrics.base` imports in releases/rf3/ +- [ ] Replace with `from modelhub.metrics.metric` +- [ ] Find all `from callbacks.base` imports +- [ ] Replace with `from modelhub.callbacks.callback` +- [ ] Update utility imports to use `modelhub.utils.*` +- [ ] Test that RF3 can import all needed modules + +### Phase 4: Fix Build Config +- [ ] Update `pyproject.toml` project name +- [ ] Split dependencies: core vs rf3-specific +- [ ] Add `[project.optional-dependencies]` for rf3 +- [ ] Fix `[project.scripts]` CLI entry point +- [ ] Update `[tool.hatch.build.targets.wheel]` packages +- [ ] Add build hooks if using hybrid approach + +### Phase 5: RF3 Integration +- [ ] Decide on integration approach (A, B, or C) +- [ ] Implement chosen approach +- [ ] Test `pip install -e .` (base only) +- [ ] Test `pip install -e .[rf3]` (with RF3) + +### Phase 6: Update Tests +- [ ] Update test imports +- [ ] Run shared tests: `pytest tests/` +- [ ] Run RF3 tests: `cd releases/rf3 && pytest tests/` +- [ ] Fix any import errors + +### Phase 7: Documentation +- [ ] Update CLAUDE.md +- [ ] Update README.md +- [ ] Update any other documentation +- [ ] Add migration notes for contributors + +### Phase 8: Validation +- [ ] Clean install test: `pip install -e .` +- [ ] RF3 install test: `pip install -e .[rf3]` +- [ ] Test CLI: `rf3 fold inputs=...` +- [ ] Run full test suite +- [ ] Test on clean environment + +--- + +## Testing Strategy + +### Test Scenarios + +#### Scenario 1: Base Install Only +```bash +# Clean environment +python -m venv test-env +source test-env/bin/activate +pip install -e . + +# Should work +python -c "from modelhub.callbacks import BaseCallback; print('OK')" +python -c "from modelhub.metrics import Metric; print('OK')" +python -c "from modelhub.utils.ddp import RankedLogger; print('OK')" + +# Should fail (RF3 not installed) +python -c "from modelhub.models.rf3 import RF3" # ImportError expected +``` + +#### Scenario 2: With RF3 +```bash +# Clean environment +python -m venv test-env +source test-env/bin/activate +pip install -e .[rf3] + +# Should work +python -c "from modelhub.models.rf3 import RF3; print('OK')" +python -c "from rf3.cli import app; print('OK')" + +# Test CLI +rf3 fold inputs='releases/rf3/tests/data/5vht_from_json.json' +``` + +#### Scenario 3: Development Workflow +```bash +# Install in development mode with RF3 +pip install -e .[rf3,dev] + +# Make changes to releases/rf3/src/rf3/model/RF3.py +# Changes should be immediately available + +python -c "from modelhub.models.rf3.model import RF3" # Should see changes + +# Run tests +pytest releases/rf3/tests/ +``` + +--- + +## Risk Assessment + +### High Risk Items +1. **Breaking all existing imports** - Requires updating many files + - Mitigation: Use automated search/replace, test thoroughly + +2. **Build configuration complexity** - Hatch build hooks can be tricky + - Mitigation: Start with simple approach, test installation frequently + +3. **Package name change** - Breaking change for any external users + - Mitigation: Coordinate with team, announce breaking change + +### Medium Risk Items +1. **CLI entry point changes** - Users may have scripts using old CLI + - Mitigation: Keep backward-compatible entry points + +2. **Test imports** - Many test files to update + - Mitigation: Automated replacement, run tests incrementally + +### Low Risk Items +1. **Documentation updates** - Time-consuming but low risk +2. **`__init__.py` exports** - Easy to fix if wrong + +--- + +## Rollback Plan + +If migration fails: +1. `git checkout main` - Revert to pre-migration state +2. Keep feature branch for later attempt +3. Document what failed for next iteration + +--- + +## Future Considerations + +### Adding New Models +Once restructured, adding a new model is straightforward: + +``` +releases/ +├── rf3/ # Existing +└── proteinmpnn/ # New model + ├── src/proteinmpnn/ + ├── configs/ + └── tests/ +``` + +Then in `pyproject.toml`: +```toml +[project.optional-dependencies] +proteinmpnn = [ + "proteinmpnn-specific-deps", +] +all = [ + "modelforge[rf3,proteinmpnn]", +] +``` + +### Splitting Into Separate Packages (Future) +If models become too large, can later split: +- `modelforge-core` - Base framework +- `modelforge-rf3` - RF3 model +- `modelforge-proteinmpnn` - ProteinMPNN model + +But this is much more complex and not recommended initially. + +--- + +## Questions to Resolve + +1. **Package name**: Keep `modelhub` or rename to `modelforge`? + - Recommendation: Rename to `modelforge` + +2. **RF3 integration**: Which approach (A, B, or C)? + - Recommendation: Option C (hybrid with build hooks) + +3. **CLI structure**: Single entry point or keep separate? + - Recommendation: Keep `rf3` command separate for now + +4. **Base file naming**: Rename `base.py` to match content? + - Recommendation: Yes, rename to `callback.py`, `metric.py` + +5. **Version strategy**: Single version or per-model versions? + - Recommendation: Single version (simpler) + +--- + +## Success Criteria + +- [ ] `pip install modelforge` works +- [ ] `pip install modelforge[rf3]` works +- [ ] All imports are unambiguous and correct +- [ ] All tests pass +- [ ] CLI works: `rf3 fold inputs=...` +- [ ] Structure matches PyTorch Lightning patterns +- [ ] Documentation is updated +- [ ] No more `sys.path` manipulation needed + +--- + +## Timeline Estimate + +- Phase 1 (Rename decision): 1 hour +- Phase 2 (Restructure src/): 2-3 hours +- Phase 3 (Fix RF3 imports): 1-2 hours +- Phase 4 (Build config): 2-3 hours +- Phase 5 (RF3 integration): 2-4 hours +- Phase 6 (Update tests): 1-2 hours +- Phase 7 (Documentation): 1-2 hours +- Phase 8 (Validation): 2-3 hours + +**Total**: 12-20 hours of focused work + +--- + +## Conclusion + +This restructuring will: +1. ✅ Fix all broken imports +2. ✅ Follow industry best practices (Lightning pattern) +3. ✅ Enable proper pip installation +4. ✅ Support selective model installation +5. ✅ Maintain development workflow in `releases/` +6. ✅ Provide clear, unambiguous imports +7. ✅ Scale to future models + +The key is following PyTorch Lightning's pattern: base classes live WITH their implementations, no separate `contrib/` or overly complex directory nesting. diff --git a/docs/rf3/examples/3en2_from_file.cif b/docs/releases/rf3/examples/3en2_from_file.cif similarity index 100% rename from docs/rf3/examples/3en2_from_file.cif rename to docs/releases/rf3/examples/3en2_from_file.cif diff --git a/docs/rf3/examples/3en2_from_json_with_msa.json b/docs/releases/rf3/examples/3en2_from_json_with_msa.json similarity index 100% rename from docs/rf3/examples/3en2_from_json_with_msa.json rename to docs/releases/rf3/examples/3en2_from_json_with_msa.json diff --git a/docs/rf3/examples/5hkn_from_file.cif b/docs/releases/rf3/examples/5hkn_from_file.cif similarity index 100% rename from docs/rf3/examples/5hkn_from_file.cif rename to docs/releases/rf3/examples/5hkn_from_file.cif diff --git a/docs/rf3/examples/7o1r_from_json.json b/docs/releases/rf3/examples/7o1r_from_json.json similarity index 100% rename from docs/rf3/examples/7o1r_from_json.json rename to docs/releases/rf3/examples/7o1r_from_json.json diff --git a/docs/rf3/examples/7xli_template_antigen_and_framework.json b/docs/releases/rf3/examples/7xli_template_antigen_and_framework.json similarity index 100% rename from docs/rf3/examples/7xli_template_antigen_and_framework.json rename to docs/releases/rf3/examples/7xli_template_antigen_and_framework.json diff --git a/docs/rf3/examples/9dfn.cif b/docs/releases/rf3/examples/9dfn.cif similarity index 100% rename from docs/rf3/examples/9dfn.cif rename to docs/releases/rf3/examples/9dfn.cif diff --git a/docs/rf3/examples/9dfn_template_ligand_and_protein.json b/docs/releases/rf3/examples/9dfn_template_ligand_and_protein.json similarity index 100% rename from docs/rf3/examples/9dfn_template_ligand_and_protein.json rename to docs/releases/rf3/examples/9dfn_template_ligand_and_protein.json diff --git a/docs/rf3/examples/ligands/HEM.sdf b/docs/releases/rf3/examples/ligands/HEM.sdf similarity index 100% rename from docs/rf3/examples/ligands/HEM.sdf rename to docs/releases/rf3/examples/ligands/HEM.sdf diff --git a/docs/rf3/examples/ligands/NAG.cif b/docs/releases/rf3/examples/ligands/NAG.cif similarity index 100% rename from docs/rf3/examples/ligands/NAG.cif rename to docs/releases/rf3/examples/ligands/NAG.cif diff --git a/docs/rf3/examples/msas/3en2_A.a3m.gz b/docs/releases/rf3/examples/msas/3en2_A.a3m.gz similarity index 100% rename from docs/rf3/examples/msas/3en2_A.a3m.gz rename to docs/releases/rf3/examples/msas/3en2_A.a3m.gz diff --git a/docs/rf3/examples/msas/7o1r_A.a3m.gz b/docs/releases/rf3/examples/msas/7o1r_A.a3m.gz similarity index 100% rename from docs/rf3/examples/msas/7o1r_A.a3m.gz rename to docs/releases/rf3/examples/msas/7o1r_A.a3m.gz diff --git a/docs/rf3/examples/msas/8cdz_A.a3m.gz b/docs/releases/rf3/examples/msas/8cdz_A.a3m.gz similarity index 100% rename from docs/rf3/examples/msas/8cdz_A.a3m.gz rename to docs/releases/rf3/examples/msas/8cdz_A.a3m.gz diff --git a/docs/rf3/examples/multiple_examples_from_json.json b/docs/releases/rf3/examples/multiple_examples_from_json.json similarity index 100% rename from docs/rf3/examples/multiple_examples_from_json.json rename to docs/releases/rf3/examples/multiple_examples_from_json.json diff --git a/docs/rf3/examples/templates/7xli_chain_A.cif b/docs/releases/rf3/examples/templates/7xli_chain_A.cif similarity index 100% rename from docs/rf3/examples/templates/7xli_chain_A.cif rename to docs/releases/rf3/examples/templates/7xli_chain_A.cif diff --git a/docs/rf3/examples/templates/7xli_chain_B.cif b/docs/releases/rf3/examples/templates/7xli_chain_B.cif similarity index 100% rename from docs/rf3/examples/templates/7xli_chain_B.cif rename to docs/releases/rf3/examples/templates/7xli_chain_B.cif diff --git a/src/modelhub/inference_engines/README.md b/models/rf3/README.md similarity index 99% rename from src/modelhub/inference_engines/README.md rename to models/rf3/README.md index b0f4f37..1a8058f 100644 --- a/src/modelhub/inference_engines/README.md +++ b/models/rf3/README.md @@ -1,7 +1,7 @@ # Inference with RosettaFold3(RF3)
+
-
+
Figure: `7o1r` structure showing N-glycosylation covalent modification prediction with RF3 and ground truth crystal structure. diff --git a/configs/callbacks/default.yaml b/models/rf3/configs/callbacks/default.yaml similarity index 100% rename from configs/callbacks/default.yaml rename to models/rf3/configs/callbacks/default.yaml diff --git a/configs/callbacks/dump_validation_structures.yaml b/models/rf3/configs/callbacks/dump_validation_structures.yaml similarity index 100% rename from configs/callbacks/dump_validation_structures.yaml rename to models/rf3/configs/callbacks/dump_validation_structures.yaml diff --git a/configs/callbacks/metrics_logging.yaml b/models/rf3/configs/callbacks/metrics_logging.yaml similarity index 100% rename from configs/callbacks/metrics_logging.yaml rename to models/rf3/configs/callbacks/metrics_logging.yaml diff --git a/configs/callbacks/train_logging.yaml b/models/rf3/configs/callbacks/train_logging.yaml similarity index 100% rename from configs/callbacks/train_logging.yaml rename to models/rf3/configs/callbacks/train_logging.yaml diff --git a/configs/dataloader/default.yaml b/models/rf3/configs/dataloader/default.yaml similarity index 100% rename from configs/dataloader/default.yaml rename to models/rf3/configs/dataloader/default.yaml diff --git a/configs/datasets/base.yaml b/models/rf3/configs/datasets/base.yaml similarity index 100% rename from configs/datasets/base.yaml rename to models/rf3/configs/datasets/base.yaml diff --git a/configs/datasets/pdb_and_distillation.yaml b/models/rf3/configs/datasets/pdb_and_distillation.yaml similarity index 100% rename from configs/datasets/pdb_and_distillation.yaml rename to models/rf3/configs/datasets/pdb_and_distillation.yaml diff --git a/configs/datasets/pdb_only.yaml b/models/rf3/configs/datasets/pdb_only.yaml similarity index 100% rename from configs/datasets/pdb_only.yaml rename to models/rf3/configs/datasets/pdb_only.yaml diff --git a/configs/datasets/train/disorder_distillation.yaml b/models/rf3/configs/datasets/train/disorder_distillation.yaml similarity index 100% rename from configs/datasets/train/disorder_distillation.yaml rename to models/rf3/configs/datasets/train/disorder_distillation.yaml diff --git a/configs/datasets/train/domain_distillation.yaml b/models/rf3/configs/datasets/train/domain_distillation.yaml similarity index 100% rename from configs/datasets/train/domain_distillation.yaml rename to models/rf3/configs/datasets/train/domain_distillation.yaml diff --git a/configs/datasets/train/monomer_distillation.yaml b/models/rf3/configs/datasets/train/monomer_distillation.yaml similarity index 100% rename from configs/datasets/train/monomer_distillation.yaml rename to models/rf3/configs/datasets/train/monomer_distillation.yaml diff --git a/configs/datasets/train/na_complex_distillation.yaml b/models/rf3/configs/datasets/train/na_complex_distillation.yaml similarity index 100% rename from configs/datasets/train/na_complex_distillation.yaml rename to models/rf3/configs/datasets/train/na_complex_distillation.yaml diff --git a/configs/datasets/train/pdb/af3_weighted_sampling.yaml b/models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml similarity index 100% rename from configs/datasets/train/pdb/af3_weighted_sampling.yaml rename to models/rf3/configs/datasets/train/pdb/af3_weighted_sampling.yaml diff --git a/configs/datasets/train/pdb/base.yaml b/models/rf3/configs/datasets/train/pdb/base.yaml similarity index 100% rename from configs/datasets/train/pdb/base.yaml rename to models/rf3/configs/datasets/train/pdb/base.yaml diff --git a/configs/datasets/train/pdb/plinder.yaml b/models/rf3/configs/datasets/train/pdb/plinder.yaml similarity index 100% rename from configs/datasets/train/pdb/plinder.yaml rename to models/rf3/configs/datasets/train/pdb/plinder.yaml diff --git a/configs/datasets/train/pdb/train_interface.yaml b/models/rf3/configs/datasets/train/pdb/train_interface.yaml similarity index 100% rename from configs/datasets/train/pdb/train_interface.yaml rename to models/rf3/configs/datasets/train/pdb/train_interface.yaml diff --git a/configs/datasets/train/pdb/train_pn_unit.yaml b/models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml similarity index 100% rename from configs/datasets/train/pdb/train_pn_unit.yaml rename to models/rf3/configs/datasets/train/pdb/train_pn_unit.yaml diff --git a/configs/datasets/train/rna_monomer_distillation.yaml b/models/rf3/configs/datasets/train/rna_monomer_distillation.yaml similarity index 100% rename from configs/datasets/train/rna_monomer_distillation.yaml rename to models/rf3/configs/datasets/train/rna_monomer_distillation.yaml diff --git a/configs/datasets/val/af3_ab_set.yaml b/models/rf3/configs/datasets/val/af3_ab_set.yaml similarity index 100% rename from configs/datasets/val/af3_ab_set.yaml rename to models/rf3/configs/datasets/val/af3_ab_set.yaml diff --git a/configs/datasets/val/af3_validation.yaml b/models/rf3/configs/datasets/val/af3_validation.yaml similarity index 100% rename from configs/datasets/val/af3_validation.yaml rename to models/rf3/configs/datasets/val/af3_validation.yaml diff --git a/configs/datasets/val/base.yaml b/models/rf3/configs/datasets/val/base.yaml similarity index 100% rename from configs/datasets/val/base.yaml rename to models/rf3/configs/datasets/val/base.yaml diff --git a/configs/datasets/val/runs_and_poses.yaml b/models/rf3/configs/datasets/val/runs_and_poses.yaml similarity index 100% rename from configs/datasets/val/runs_and_poses.yaml rename to models/rf3/configs/datasets/val/runs_and_poses.yaml diff --git a/configs/debug/default.yaml b/models/rf3/configs/debug/default.yaml similarity index 100% rename from configs/debug/default.yaml rename to models/rf3/configs/debug/default.yaml diff --git a/configs/debug/train_specific_examples.yaml b/models/rf3/configs/debug/train_specific_examples.yaml similarity index 100% rename from configs/debug/train_specific_examples.yaml rename to models/rf3/configs/debug/train_specific_examples.yaml diff --git a/configs/experiment/pretrained/rf3.yaml b/models/rf3/configs/experiment/pretrained/rf3.yaml similarity index 100% rename from configs/experiment/pretrained/rf3.yaml rename to models/rf3/configs/experiment/pretrained/rf3.yaml diff --git a/configs/experiment/pretrained/rf3_with_confidence.yaml b/models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml similarity index 100% rename from configs/experiment/pretrained/rf3_with_confidence.yaml rename to models/rf3/configs/experiment/pretrained/rf3_with_confidence.yaml diff --git a/configs/experiment/quick-rf3-with-confidence.yaml b/models/rf3/configs/experiment/quick-rf3-with-confidence.yaml similarity index 100% rename from configs/experiment/quick-rf3-with-confidence.yaml rename to models/rf3/configs/experiment/quick-rf3-with-confidence.yaml diff --git a/configs/experiment/quick-rf3.yaml b/models/rf3/configs/experiment/quick-rf3.yaml similarity index 100% rename from configs/experiment/quick-rf3.yaml rename to models/rf3/configs/experiment/quick-rf3.yaml diff --git a/configs/hydra/default.yaml b/models/rf3/configs/hydra/default.yaml similarity index 100% rename from configs/hydra/default.yaml rename to models/rf3/configs/hydra/default.yaml diff --git a/configs/hydra/no_logging.yaml b/models/rf3/configs/hydra/no_logging.yaml similarity index 100% rename from configs/hydra/no_logging.yaml rename to models/rf3/configs/hydra/no_logging.yaml diff --git a/configs/inference.yaml b/models/rf3/configs/inference.yaml similarity index 100% rename from configs/inference.yaml rename to models/rf3/configs/inference.yaml diff --git a/configs/inference_engine/base.yaml b/models/rf3/configs/inference_engine/base.yaml similarity index 100% rename from configs/inference_engine/base.yaml rename to models/rf3/configs/inference_engine/base.yaml diff --git a/configs/inference_engine/rf3.yaml b/models/rf3/configs/inference_engine/rf3.yaml similarity index 100% rename from configs/inference_engine/rf3.yaml rename to models/rf3/configs/inference_engine/rf3.yaml diff --git a/configs/logger/csv.yaml b/models/rf3/configs/logger/csv.yaml similarity index 100% rename from configs/logger/csv.yaml rename to models/rf3/configs/logger/csv.yaml diff --git a/configs/logger/default.yaml b/models/rf3/configs/logger/default.yaml similarity index 100% rename from configs/logger/default.yaml rename to models/rf3/configs/logger/default.yaml diff --git a/configs/logger/wandb.yaml b/models/rf3/configs/logger/wandb.yaml similarity index 100% rename from configs/logger/wandb.yaml rename to models/rf3/configs/logger/wandb.yaml diff --git a/configs/model/components/ema.yaml b/models/rf3/configs/model/components/ema.yaml similarity index 100% rename from configs/model/components/ema.yaml rename to models/rf3/configs/model/components/ema.yaml diff --git a/configs/model/components/rf3_net.yaml b/models/rf3/configs/model/components/rf3_net.yaml similarity index 100% rename from configs/model/components/rf3_net.yaml rename to models/rf3/configs/model/components/rf3_net.yaml diff --git a/configs/model/components/rf3_net_with_confidence_head.yaml b/models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml similarity index 100% rename from configs/model/components/rf3_net_with_confidence_head.yaml rename to models/rf3/configs/model/components/rf3_net_with_confidence_head.yaml diff --git a/configs/model/optimizers/adam.yaml b/models/rf3/configs/model/optimizers/adam.yaml similarity index 100% rename from configs/model/optimizers/adam.yaml rename to models/rf3/configs/model/optimizers/adam.yaml diff --git a/configs/model/rf3.yaml b/models/rf3/configs/model/rf3.yaml similarity index 100% rename from configs/model/rf3.yaml rename to models/rf3/configs/model/rf3.yaml diff --git a/configs/model/rf3_with_confidence.yaml b/models/rf3/configs/model/rf3_with_confidence.yaml similarity index 100% rename from configs/model/rf3_with_confidence.yaml rename to models/rf3/configs/model/rf3_with_confidence.yaml diff --git a/configs/model/schedulers/af3.yaml b/models/rf3/configs/model/schedulers/af3.yaml similarity index 100% rename from configs/model/schedulers/af3.yaml rename to models/rf3/configs/model/schedulers/af3.yaml diff --git a/configs/paths/data/default.yaml b/models/rf3/configs/paths/data/default.yaml similarity index 100% rename from configs/paths/data/default.yaml rename to models/rf3/configs/paths/data/default.yaml diff --git a/configs/paths/default.yaml b/models/rf3/configs/paths/default.yaml similarity index 100% rename from configs/paths/default.yaml rename to models/rf3/configs/paths/default.yaml diff --git a/configs/train.yaml b/models/rf3/configs/train.yaml similarity index 100% rename from configs/train.yaml rename to models/rf3/configs/train.yaml diff --git a/configs/trainer/cpu.yaml b/models/rf3/configs/trainer/cpu.yaml similarity index 100% rename from configs/trainer/cpu.yaml rename to models/rf3/configs/trainer/cpu.yaml diff --git a/configs/trainer/ddp.yaml b/models/rf3/configs/trainer/ddp.yaml similarity index 100% rename from configs/trainer/ddp.yaml rename to models/rf3/configs/trainer/ddp.yaml diff --git a/configs/trainer/loss/losses/confidence_loss.yaml b/models/rf3/configs/trainer/loss/losses/confidence_loss.yaml similarity index 100% rename from configs/trainer/loss/losses/confidence_loss.yaml rename to models/rf3/configs/trainer/loss/losses/confidence_loss.yaml diff --git a/configs/trainer/loss/losses/diffusion_loss.yaml b/models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml similarity index 100% rename from configs/trainer/loss/losses/diffusion_loss.yaml rename to models/rf3/configs/trainer/loss/losses/diffusion_loss.yaml diff --git a/configs/trainer/loss/losses/distogram_loss.yaml b/models/rf3/configs/trainer/loss/losses/distogram_loss.yaml similarity index 100% rename from configs/trainer/loss/losses/distogram_loss.yaml rename to models/rf3/configs/trainer/loss/losses/distogram_loss.yaml diff --git a/configs/trainer/loss/structure_prediction.yaml b/models/rf3/configs/trainer/loss/structure_prediction.yaml similarity index 100% rename from configs/trainer/loss/structure_prediction.yaml rename to models/rf3/configs/trainer/loss/structure_prediction.yaml diff --git a/configs/trainer/loss/structure_prediction_with_confidence.yaml b/models/rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml similarity index 100% rename from configs/trainer/loss/structure_prediction_with_confidence.yaml rename to models/rf3/configs/trainer/loss/structure_prediction_with_confidence.yaml diff --git a/configs/trainer/metrics/structure_prediction.yaml b/models/rf3/configs/trainer/metrics/structure_prediction.yaml similarity index 100% rename from configs/trainer/metrics/structure_prediction.yaml rename to models/rf3/configs/trainer/metrics/structure_prediction.yaml diff --git a/configs/trainer/rf3.yaml b/models/rf3/configs/trainer/rf3.yaml similarity index 100% rename from configs/trainer/rf3.yaml rename to models/rf3/configs/trainer/rf3.yaml diff --git a/configs/trainer/rf3_with_confidence.yaml b/models/rf3/configs/trainer/rf3_with_confidence.yaml similarity index 100% rename from configs/trainer/rf3_with_confidence.yaml rename to models/rf3/configs/trainer/rf3_with_confidence.yaml diff --git a/configs/validate.yaml b/models/rf3/configs/validate.yaml similarity index 100% rename from configs/validate.yaml rename to models/rf3/configs/validate.yaml diff --git a/src/modelhub/alignment.py b/models/rf3/src/rf3/alignment.py similarity index 100% rename from src/modelhub/alignment.py rename to models/rf3/src/rf3/alignment.py diff --git a/src/modelhub/callbacks/dump_validation_structures.py b/models/rf3/src/rf3/callbacks/dump_validation_structures.py similarity index 97% rename from src/modelhub/callbacks/dump_validation_structures.py rename to models/rf3/src/rf3/callbacks/dump_validation_structures.py index b00299e..056258f 100644 --- a/src/modelhub/callbacks/dump_validation_structures.py +++ b/models/rf3/src/rf3/callbacks/dump_validation_structures.py @@ -4,8 +4,8 @@ from pathlib import Path from atomworks.common import parse_example_id from beartype.typing import Any -from modelhub.callbacks.base import BaseCallback -from modelhub.utils.io import ( +from callbacks.base import BaseCallback +from rf3.utils.io import ( build_stack_from_atom_array_and_batched_coords, dump_structures, dump_trajectories, diff --git a/src/modelhub/callbacks/metrics_logging.py b/models/rf3/src/rf3/callbacks/metrics_logging.py similarity index 99% rename from src/modelhub/callbacks/metrics_logging.py rename to models/rf3/src/rf3/callbacks/metrics_logging.py index ac9a380..add49a3 100755 --- a/src/modelhub/callbacks/metrics_logging.py +++ b/models/rf3/src/rf3/callbacks/metrics_logging.py @@ -7,9 +7,9 @@ from atomworks.ml.utils import nested_dict from beartype.typing import Any, Literal from omegaconf import ListConfig -from modelhub.callbacks.base import BaseCallback -from modelhub.utils.ddp import RankedLogger -from modelhub.utils.logging import ( +from callbacks.base import BaseCallback +from utils.ddp import RankedLogger +from utils.logging import ( condense_count_columns_of_grouped_df, print_df_as_table, ) diff --git a/src/modelhub/chemical.py b/models/rf3/src/rf3/chemical.py similarity index 100% rename from src/modelhub/chemical.py rename to models/rf3/src/rf3/chemical.py diff --git a/src/modelhub/cli.py b/models/rf3/src/rf3/cli.py similarity index 97% rename from src/modelhub/cli.py rename to models/rf3/src/rf3/cli.py index b22d3be..0736076 100644 --- a/src/modelhub/cli.py +++ b/models/rf3/src/rf3/cli.py @@ -3,7 +3,7 @@ import os import typer from hydra import compose, initialize_config_dir -from modelhub.inference import run_inference +from rf3.inference import run_inference app = typer.Typer() diff --git a/src/modelhub/data/extra_xforms.py b/models/rf3/src/rf3/data/extra_xforms.py similarity index 100% rename from src/modelhub/data/extra_xforms.py rename to models/rf3/src/rf3/data/extra_xforms.py diff --git a/src/modelhub/data/ground_truth_conformer.py b/models/rf3/src/rf3/data/ground_truth_conformer.py similarity index 100% rename from src/modelhub/data/ground_truth_conformer.py rename to models/rf3/src/rf3/data/ground_truth_conformer.py diff --git a/src/modelhub/data/ground_truth_template.py b/models/rf3/src/rf3/data/ground_truth_template.py similarity index 99% rename from src/modelhub/data/ground_truth_template.py rename to models/rf3/src/rf3/data/ground_truth_template.py index 6f775d1..19ffb6a 100644 --- a/src/modelhub/data/ground_truth_template.py +++ b/models/rf3/src/rf3/data/ground_truth_template.py @@ -20,7 +20,7 @@ from biotite.structure import AtomArray from jaxtyping import Bool, Float, Shaped from torch import Tensor -from modelhub.utils.torch_utils import assert_no_nans +from utils.torch import assert_no_nans logger = logging.getLogger(__name__) diff --git a/src/modelhub/data/paired_msa.py b/models/rf3/src/rf3/data/paired_msa.py similarity index 100% rename from src/modelhub/data/paired_msa.py rename to models/rf3/src/rf3/data/paired_msa.py diff --git a/src/modelhub/data/pipeline_utils.py b/models/rf3/src/rf3/data/pipeline_utils.py similarity index 99% rename from src/modelhub/data/pipeline_utils.py rename to models/rf3/src/rf3/data/pipeline_utils.py index 06268cd..18f39b3 100644 --- a/src/modelhub/data/pipeline_utils.py +++ b/models/rf3/src/rf3/data/pipeline_utils.py @@ -6,7 +6,7 @@ from atomworks.ml.transforms._checks import check_atom_array_annotation from atomworks.ml.transforms.crop import compute_local_hash from omegaconf import DictConfig -from modelhub.data.ground_truth_template import ( +from rf3.data.ground_truth_template import ( FeaturizeNoisedGroundTruthAsTemplateDistogram, TokenGroupNoiseScaleSampler, af3_noise_scale_distribution_wrapped, diff --git a/src/modelhub/data/pipelines.py b/models/rf3/src/rf3/data/pipelines.py similarity index 99% rename from src/modelhub/data/pipelines.py rename to models/rf3/src/rf3/data/pipelines.py index 3c7de1f..8157ee2 100644 --- a/src/modelhub/data/pipelines.py +++ b/models/rf3/src/rf3/data/pipelines.py @@ -99,8 +99,8 @@ from atomworks.ml.transforms.rdkit_utils import GetRDKitChiralCenters from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX from omegaconf import DictConfig -from modelhub.data.extra_xforms import CheckForNaNsInInputs -from modelhub.data.pipeline_utils import ( +from rf3.data.extra_xforms import CheckForNaNsInInputs +from rf3.data.pipeline_utils import ( annotate_post_crop_hash, annotate_pre_crop_hash, build_ground_truth_distogram_transform, diff --git a/src/modelhub/data/rotation_augmentation.py b/models/rf3/src/rf3/data/rotation_augmentation.py similarity index 96% rename from src/modelhub/data/rotation_augmentation.py rename to models/rf3/src/rf3/data/rotation_augmentation.py index d206488..765b010 100644 --- a/src/modelhub/data/rotation_augmentation.py +++ b/models/rf3/src/rf3/data/rotation_augmentation.py @@ -2,7 +2,7 @@ import math import torch -from modelhub.flow_matching.rigid_utils import rot_vec_mul +from rf3.flow_matching.rigid_utils import rot_vec_mul def centre(X_L, X_exists_L): diff --git a/src/modelhub/diffusion_samplers/inference_sampler.py b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py similarity index 98% rename from src/modelhub/diffusion_samplers/inference_sampler.py rename to models/rf3/src/rf3/diffusion_samplers/inference_sampler.py index f139781..01dfa14 100755 --- a/src/modelhub/diffusion_samplers/inference_sampler.py +++ b/models/rf3/src/rf3/diffusion_samplers/inference_sampler.py @@ -2,8 +2,8 @@ import torch from beartype.typing import Any, Literal from jaxtyping import Float -from modelhub.data.rotation_augmentation import centre_random_augmentation -from modelhub.utils.ddp import RankedLogger +from rf3.data.rotation_augmentation import centre_random_augmentation +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/flow_matching/rigid_utils.py b/models/rf3/src/rf3/flow_matching/rigid_utils.py similarity index 100% rename from src/modelhub/flow_matching/rigid_utils.py rename to models/rf3/src/rf3/flow_matching/rigid_utils.py diff --git a/src/modelhub/inference.py b/models/rf3/src/rf3/inference.py similarity index 96% rename from src/modelhub/inference.py rename to models/rf3/src/rf3/inference.py index 4c757ed..dbbf65a 100755 --- a/src/modelhub/inference.py +++ b/models/rf3/src/rf3/inference.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv from hydra.utils import instantiate from omegaconf import DictConfig -from modelhub.utils.logging import suppress_warnings +from utils.logging import suppress_warnings load_dotenv(override=True) diff --git a/src/modelhub/inference_engines/rf3.py b/models/rf3/src/rf3/inference_engines/rf3.py similarity index 97% rename from src/modelhub/inference_engines/rf3.py rename to models/rf3/src/rf3/inference_engines/rf3.py index fca5f31..af4bb7c 100644 --- a/src/modelhub/inference_engines/rf3.py +++ b/models/rf3/src/rf3/inference_engines/rf3.py @@ -10,23 +10,23 @@ from atomworks.io.transforms.categories import category_to_dict from lightning.fabric import seed_everything from omegaconf import OmegaConf -from modelhub.inference_engines.base import InferenceEngine -from modelhub.model.RF3 import ShouldEarlyStopFn -from modelhub.utils.datasets import ( +from inference_engines.base import InferenceEngine +from rf3.model.RF3 import ShouldEarlyStopFn +from rf3.utils.datasets import ( assemble_distributed_inference_loader_from_list_of_paths, ) -from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability -from modelhub.utils.inference import ( +from utils.ddp import RankedLogger, set_accelerator_based_on_availability +from rf3.utils.inference import ( apply_conformer_and_template_selections, build_file_paths_for_prediction, ) -from modelhub.utils.io import ( +from rf3.utils.io import ( build_stack_from_atom_array_and_batched_coords, dump_structures, dump_trajectories, ) -from modelhub.utils.logging import print_config_tree -from modelhub.utils.predicted_error import ( +from utils.logging import print_config_tree +from rf3.utils.predicted_error import ( annotate_atom_array_b_factor_with_plddt, compile_af3_confidence_outputs, get_mean_atomwise_plddt, @@ -173,17 +173,17 @@ class RF3InferenceEngine(InferenceEngine): "allowed_chain_types_for_conditioning": None, # Avoid random conditioning "protein_msa_dirs": [ { - "dir": "/projects/msa/hhblits", + "dir": "/projects/msa/hhblits", "extension": ".a3m.gz", "directory_depth": 2, }, { - "dir": "/projects/msa/mmseqs_gpu", + "dir": "/projects/msa/mmseqs_gpu", "extension": ".a3m.gz", "directory_depth": 2, }, { - "dir": "/projects/msa/lab", + "dir": "/projects/msa/lab", "extension": ".a3m.gz", "directory_depth": 2, }, diff --git a/src/modelhub/kinematics.py b/models/rf3/src/rf3/kinematics.py similarity index 100% rename from src/modelhub/kinematics.py rename to models/rf3/src/rf3/kinematics.py diff --git a/src/modelhub/loss/af3_confidence_loss.py b/models/rf3/src/rf3/loss/af3_confidence_loss.py similarity index 99% rename from src/modelhub/loss/af3_confidence_loss.py rename to models/rf3/src/rf3/loss/af3_confidence_loss.py index 187a54e..3c2a906 100644 --- a/src/modelhub/loss/af3_confidence_loss.py +++ b/models/rf3/src/rf3/loss/af3_confidence_loss.py @@ -2,14 +2,14 @@ import torch import torch.nn as nn from scipy.stats import spearmanr -from modelhub.chemical import NFRAMES, NHEAVY, frame_indices +from rf3.chemical import NFRAMES, NHEAVY, frame_indices # TODO: REFACTOR; COPIED FROM RF2AA. WE NEED TO ADD DOCSTRINGS, EXAMPLES, HOPEFULLY TESTS, AND CLEAN UP -from modelhub.metrics.metric_utils import ( +from rf3.metrics.metric_utils import ( compute_mean_over_subsampled_pairs, unbin_logits, ) -from modelhub.utils.frames import ( +from rf3.utils.frames import ( get_frames, mask_unresolved_frames_batched, rigid_from_3_points, diff --git a/src/modelhub/loss/af3_losses.py b/models/rf3/src/rf3/loss/af3_losses.py similarity index 99% rename from src/modelhub/loss/af3_losses.py rename to models/rf3/src/rf3/loss/af3_losses.py index 2dccf36..f567da3 100644 --- a/src/modelhub/loss/af3_losses.py +++ b/models/rf3/src/rf3/loss/af3_losses.py @@ -3,8 +3,8 @@ import numpy as np import torch import torch.nn as nn -from modelhub.alignment import weighted_rigid_align -from modelhub.training.checkpoint import activation_checkpointing +from rf3.alignment import weighted_rigid_align +from rf3.training.checkpoint import activation_checkpointing # resolve residue-level symmetries in native vs pred diff --git a/src/modelhub/loss/loss.py b/models/rf3/src/rf3/loss/loss.py similarity index 100% rename from src/modelhub/loss/loss.py rename to models/rf3/src/rf3/loss/loss.py diff --git a/src/modelhub/metrics/clashing_chains.py b/models/rf3/src/rf3/metrics/clashing_chains.py similarity index 98% rename from src/modelhub/metrics/clashing_chains.py rename to models/rf3/src/rf3/metrics/clashing_chains.py index 83d1dab..4206d25 100644 --- a/src/modelhub/metrics/clashing_chains.py +++ b/models/rf3/src/rf3/metrics/clashing_chains.py @@ -4,7 +4,7 @@ from typing import Any import torch from biotite.structure import AtomArrayStack -from modelhub.metrics.base import Metric +from metrics.base import Metric class CountClashingChains(Metric): diff --git a/src/modelhub/metrics/distogram.py b/models/rf3/src/rf3/metrics/distogram.py similarity index 99% rename from src/modelhub/metrics/distogram.py rename to models/rf3/src/rf3/metrics/distogram.py index 515609d..38e57a4 100644 --- a/src/modelhub/metrics/distogram.py +++ b/models/rf3/src/rf3/metrics/distogram.py @@ -10,9 +10,9 @@ from biotite.structure import AtomArrayStack from einops import rearrange, repeat from jaxtyping import Bool, Float -from modelhub.loss.af3_losses import distogram_loss -from modelhub.metrics.base import Metric -from modelhub.utils.torch_utils import assert_no_nans +from rf3.loss.af3_losses import distogram_loss +from metrics.base import Metric +from utils.torch import assert_no_nans @dataclass diff --git a/src/modelhub/metrics/lddt.py b/models/rf3/src/rf3/metrics/lddt.py similarity index 99% rename from src/modelhub/metrics/lddt.py rename to models/rf3/src/rf3/metrics/lddt.py index 85cb1ef..f88ea67 100644 --- a/src/modelhub/metrics/lddt.py +++ b/models/rf3/src/rf3/metrics/lddt.py @@ -11,8 +11,8 @@ from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack, stack from jaxtyping import Bool, Float, Int -from modelhub.metrics.base import Metric -from modelhub.utils.ddp import RankedLogger +from metrics.base import Metric +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/metrics/metadata.py b/models/rf3/src/rf3/metrics/metadata.py similarity index 97% rename from src/modelhub/metrics/metadata.py rename to models/rf3/src/rf3/metrics/metadata.py index ebbd3b9..90b86eb 100644 --- a/src/modelhub/metrics/metadata.py +++ b/models/rf3/src/rf3/metrics/metadata.py @@ -2,7 +2,7 @@ import json from beartype.typing import Any, Literal -from modelhub.metrics.base import Metric +from metrics.base import Metric class ExtraInfo(Metric): diff --git a/src/modelhub/metrics/metric_utils.py b/models/rf3/src/rf3/metrics/metric_utils.py similarity index 100% rename from src/modelhub/metrics/metric_utils.py rename to models/rf3/src/rf3/metrics/metric_utils.py diff --git a/src/modelhub/metrics/predicted_error.py b/models/rf3/src/rf3/metrics/predicted_error.py similarity index 97% rename from src/modelhub/metrics/predicted_error.py rename to models/rf3/src/rf3/metrics/predicted_error.py index 840b431..824d81d 100644 --- a/src/modelhub/metrics/predicted_error.py +++ b/models/rf3/src/rf3/metrics/predicted_error.py @@ -2,8 +2,8 @@ from typing import Any import torch -from modelhub.metrics.base import Metric -from modelhub.metrics.metric_utils import find_bin_midpoints +from metrics.base import Metric +from rf3.metrics.metric_utils import find_bin_midpoints def compute_ptm( diff --git a/src/modelhub/metrics/selected_distances.py b/models/rf3/src/rf3/metrics/selected_distances.py similarity index 98% rename from src/modelhub/metrics/selected_distances.py rename to models/rf3/src/rf3/metrics/selected_distances.py index 6e54906..01ef478 100644 --- a/src/modelhub/metrics/selected_distances.py +++ b/models/rf3/src/rf3/metrics/selected_distances.py @@ -7,7 +7,7 @@ from atomworks.ml.utils.selection import ( from beartype.typing import Any from biotite.structure import AtomArrayStack -from modelhub.metrics.base import Metric +from metrics.base import Metric class SelectedAtomByAtomDistances(Metric): diff --git a/src/modelhub/model/RF3.py b/models/rf3/src/rf3/model/RF3.py similarity index 98% rename from src/modelhub/model/RF3.py rename to models/rf3/src/rf3/model/RF3.py index d511b12..5aaa6ed 100644 --- a/src/modelhub/model/RF3.py +++ b/models/rf3/src/rf3/model/RF3.py @@ -7,15 +7,15 @@ from beartype.typing import Any, Generator, Protocol from omegaconf import DictConfig from torch import nn -from modelhub.diffusion_samplers.inference_sampler import ( +from rf3.diffusion_samplers.inference_sampler import ( SampleDiffusion, SamplePartialDiffusion, ) -from modelhub.model.layers.pairformer_layers import ( +from rf3.model.layers.pairformer_layers import ( FeatureInitializer, ) -from modelhub.model.RF3_structure import DiffusionModule, DistogramHead, Recycler -from modelhub.training.checkpoint import create_custom_forward +from rf3.model.RF3_structure import DiffusionModule, DistogramHead, Recycler +from rf3.training.checkpoint import create_custom_forward """ Shape Annotation Glossary: @@ -327,7 +327,7 @@ class RF3WithConfidence(RF3): rollouts during inference) """ # (Lazy import) - from modelhub.model.layers.af3_auxiliary_heads import ConfidenceHead # noqa + from rf3.model.layers.af3_auxiliary_heads import ConfidenceHead # noqa super().__init__(**kwargs) diff --git a/src/modelhub/model/RF3_blocks.py b/models/rf3/src/rf3/model/RF3_blocks.py similarity index 97% rename from src/modelhub/model/RF3_blocks.py rename to models/rf3/src/rf3/model/RF3_blocks.py index 797ae53..e274600 100644 --- a/src/modelhub/model/RF3_blocks.py +++ b/models/rf3/src/rf3/model/RF3_blocks.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn import torch.nn.functional as F -from modelhub.training.checkpoint import activation_checkpointing +from rf3.training.checkpoint import activation_checkpointing class MSASubsampleEmbedder(nn.Module): diff --git a/src/modelhub/model/RF3_structure.py b/models/rf3/src/rf3/model/RF3_structure.py similarity index 97% rename from src/modelhub/model/RF3_structure.py rename to models/rf3/src/rf3/model/RF3_structure.py index 9337186..b1b7402 100644 --- a/src/modelhub/model/RF3_structure.py +++ b/models/rf3/src/rf3/model/RF3_structure.py @@ -3,19 +3,19 @@ import logging import torch import torch.nn as nn -from modelhub.model.layers.af3_diffusion_transformer import ( +from rf3.model.layers.af3_diffusion_transformer import ( AtomAttentionEncoderDiffusion, AtomTransformer, DiffusionTransformer, ) -from modelhub.model.layers.layer_utils import Transition, linearNoBias -from modelhub.model.layers.pairformer_layers import ( +from rf3.model.layers.layer_utils import Transition, linearNoBias +from rf3.model.layers.pairformer_layers import ( MSAModule, PairformerBlock, RelativePositionEncoding, RF3TemplateEmbedder, ) -from modelhub.training.checkpoint import activation_checkpointing +from rf3.training.checkpoint import activation_checkpointing logger = logging.getLogger(__name__) diff --git a/src/modelhub/model/layers/Attention_module.py b/models/rf3/src/rf3/model/layers/Attention_module.py similarity index 99% rename from src/modelhub/model/layers/Attention_module.py rename to models/rf3/src/rf3/model/layers/Attention_module.py index 36fae27..b983f11 100644 --- a/src/modelhub/model/layers/Attention_module.py +++ b/models/rf3/src/rf3/model/layers/Attention_module.py @@ -6,9 +6,9 @@ import torch.nn.functional as F from einops import rearrange from opt_einsum import contract as einsum -from modelhub import SHOULD_USE_CUEQUIVARIANCE -from modelhub.training.checkpoint import activation_checkpointing -from modelhub.util_module import init_lecun_normal +from src import SHOULD_USE_CUEQUIVARIANCE +from rf3.training.checkpoint import activation_checkpointing +from rf3.util_module import init_lecun_normal if SHOULD_USE_CUEQUIVARIANCE: import cuequivariance_torch as cuet diff --git a/src/modelhub/model/layers/FusedTriangleMultiplication.py b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py similarity index 98% rename from src/modelhub/model/layers/FusedTriangleMultiplication.py rename to models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py index f2bb9dd..92c5bb1 100644 --- a/src/modelhub/model/layers/FusedTriangleMultiplication.py +++ b/models/rf3/src/rf3/model/layers/FusedTriangleMultiplication.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from jaxtyping import Float -from modelhub import SHOULD_USE_CUEQUIVARIANCE -from modelhub.util_module import init_lecun_normal +from src import SHOULD_USE_CUEQUIVARIANCE +from rf3.util_module import init_lecun_normal if SHOULD_USE_CUEQUIVARIANCE: import cuequivariance_torch as cuet diff --git a/src/modelhub/model/layers/af3_auxiliary_heads.py b/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py similarity index 98% rename from src/modelhub/model/layers/af3_auxiliary_heads.py rename to models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py index b08fd42..9eb7db3 100644 --- a/src/modelhub/model/layers/af3_auxiliary_heads.py +++ b/models/rf3/src/rf3/model/layers/af3_auxiliary_heads.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn import torch.nn.functional as F -import modelhub -from modelhub.model.RF3_structure import PairformerBlock, linearNoBias +import src +from rf3.model.RF3_structure import PairformerBlock, linearNoBias # TODO: Get from RF2AA encoding instead CHEM_DATA_LEGACY = {"NHEAVY": 23, "aa2num": {"UNK": 20, "GLY": 7, "MAS": 21}} @@ -239,7 +239,7 @@ def calc_Cb_distances(X_pred_L, seq, rep_atoms, frame_atom_idxs): & (seq != CHEM_DATA_LEGACY.aa2num["GLY"]) & (seq != CHEM_DATA_LEGACY.aa2num["MAS"]) ) - is_valid_Cb = is_valid_Cb & modelhub.util.is_protein(seq) + is_valid_Cb = is_valid_Cb & src.util.is_protein(seq) b = Ca - N c = C - Ca diff --git a/src/modelhub/model/layers/af3_diffusion_transformer.py b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py similarity index 98% rename from src/modelhub/model/layers/af3_diffusion_transformer.py rename to models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py index 05669f2..ba3d5a2 100644 --- a/src/modelhub/model/layers/af3_diffusion_transformer.py +++ b/models/rf3/src/rf3/model/layers/af3_diffusion_transformer.py @@ -2,17 +2,17 @@ import numpy as np import torch import torch.nn as nn -from modelhub.loss.loss import calc_chiral_grads_flat_impl -from modelhub.model.layers.layer_utils import ( +from rf3.loss.loss import calc_chiral_grads_flat_impl +from rf3.model.layers.layer_utils import ( AdaLN, LinearBiasInit, MultiDimLinear, collapse, linearNoBias, ) -from modelhub.model.layers.mlff import ConformerEmbeddingWeightedAverage -from modelhub.training.checkpoint import activation_checkpointing -from modelhub.utils.torch_utils import device_of +from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage +from rf3.training.checkpoint import activation_checkpointing +from utils.torch import device_of class AtomAttentionEncoderDiffusion(nn.Module): diff --git a/src/modelhub/model/layers/layer_utils.py b/models/rf3/src/rf3/model/layers/layer_utils.py similarity index 98% rename from src/modelhub/model/layers/layer_utils.py rename to models/rf3/src/rf3/model/layers/layer_utils.py index 216ba1b..8d6149f 100644 --- a/src/modelhub/model/layers/layer_utils.py +++ b/models/rf3/src/rf3/model/layers/layer_utils.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from torch.nn.functional import silu -from modelhub.training.checkpoint import activation_checkpointing +from rf3.training.checkpoint import activation_checkpointing linearNoBias = partial(torch.nn.Linear, bias=False) diff --git a/src/modelhub/model/layers/mlff.py b/models/rf3/src/rf3/model/layers/mlff.py similarity index 100% rename from src/modelhub/model/layers/mlff.py rename to models/rf3/src/rf3/model/layers/mlff.py diff --git a/src/modelhub/model/layers/outer_product.py b/models/rf3/src/rf3/model/layers/outer_product.py similarity index 94% rename from src/modelhub/model/layers/outer_product.py rename to models/rf3/src/rf3/model/layers/outer_product.py index b22df2a..4657e07 100644 --- a/src/modelhub/model/layers/outer_product.py +++ b/models/rf3/src/rf3/model/layers/outer_product.py @@ -1,8 +1,8 @@ import torch import torch.nn as nn -from modelhub.training.checkpoint import activation_checkpointing -from modelhub.util_module import init_lecun_normal +from rf3.training.checkpoint import activation_checkpointing +from rf3.util_module import init_lecun_normal class OuterProductMean(nn.Module): diff --git a/src/modelhub/model/layers/pairformer_layers.py b/models/rf3/src/rf3/model/layers/pairformer_layers.py similarity index 98% rename from src/modelhub/model/layers/pairformer_layers.py rename to models/rf3/src/rf3/model/layers/pairformer_layers.py index 91e9f1f..61874ab 100644 --- a/src/modelhub/model/layers/pairformer_layers.py +++ b/models/rf3/src/rf3/model/layers/pairformer_layers.py @@ -2,30 +2,30 @@ import torch from torch import nn from torch.nn.functional import one_hot, relu -from modelhub.data.ground_truth_template import ( +from rf3.data.ground_truth_template import ( af3_noise_scale_to_noise_level, ) -from modelhub.model.layers.af3_diffusion_transformer import AtomTransformer -from modelhub.model.layers.Attention_module import ( +from rf3.model.layers.af3_diffusion_transformer import AtomTransformer +from rf3.model.layers.Attention_module import ( TriangleAttention, ) -from modelhub.model.layers.FusedTriangleMultiplication import ( +from rf3.model.layers.FusedTriangleMultiplication import ( FusedTriangleMultiplication, ) -from modelhub.model.layers.layer_utils import ( +from rf3.model.layers.layer_utils import ( MultiDimLinear, Transition, collapse, create_batch_dimension_if_not_present, linearNoBias, ) -from modelhub.model.layers.mlff import ConformerEmbeddingWeightedAverage -from modelhub.model.layers.outer_product import ( +from rf3.model.layers.mlff import ConformerEmbeddingWeightedAverage +from rf3.model.layers.outer_product import ( OuterProductMean_AF3, ) -from modelhub.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder -from modelhub.training.checkpoint import activation_checkpointing -from modelhub.util_module import Dropout +from rf3.model.RF3_blocks import MSAPairWeightedAverage, MSASubsampleEmbedder +from rf3.training.checkpoint import activation_checkpointing +from rf3.util_module import Dropout class AtomAttentionEncoderPairformer(nn.Module): diff --git a/src/modelhub/model/layers/structure_bias.py b/models/rf3/src/rf3/model/layers/structure_bias.py similarity index 96% rename from src/modelhub/model/layers/structure_bias.py rename to models/rf3/src/rf3/model/layers/structure_bias.py index 76b1749..5043ef6 100644 --- a/src/modelhub/model/layers/structure_bias.py +++ b/models/rf3/src/rf3/model/layers/structure_bias.py @@ -2,7 +2,7 @@ import torch import torch.nn as nn from opt_einsum import contract as einsum -from modelhub.util_module import init_lecun_normal, rbf +from rf3.util_module import init_lecun_normal, rbf class StructureBias(torch.nn.Module): diff --git a/src/modelhub/scoring.py b/models/rf3/src/rf3/scoring.py similarity index 100% rename from src/modelhub/scoring.py rename to models/rf3/src/rf3/scoring.py diff --git a/src/modelhub/symmetry/resolve.py b/models/rf3/src/rf3/symmetry/resolve.py similarity index 99% rename from src/modelhub/symmetry/resolve.py rename to models/rf3/src/rf3/symmetry/resolve.py index 07089d0..3b18a62 100644 --- a/src/modelhub/symmetry/resolve.py +++ b/models/rf3/src/rf3/symmetry/resolve.py @@ -15,7 +15,7 @@ from atomworks.ml.transforms.symmetry import FindAutomorphismsWithNetworkX from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import Bool, Float, Int -from modelhub.loss.af3_losses import ( +from rf3.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, ) diff --git a/src/modelhub/train.py b/models/rf3/src/rf3/train.py similarity index 93% rename from src/modelhub/train.py rename to models/rf3/src/rf3/train.py index dc96d66..a8b31d9 100755 --- a/src/modelhub/train.py +++ b/models/rf3/src/rf3/train.py @@ -8,8 +8,8 @@ import rootutils from dotenv import load_dotenv from omegaconf import DictConfig -from modelhub.utils.logging import suppress_warnings -from modelhub.utils.weights import CheckpointConfig +from utils.logging import suppress_warnings +from utils.weights import CheckpointConfig load_dotenv(override=True) @@ -43,15 +43,15 @@ def train(cfg: DictConfig) -> None: # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision torch.set_float32_matmul_precision("medium") - from modelhub.callbacks.base import BaseCallback # noqa - from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa - from modelhub.utils.logging import ( + from callbacks.base import BaseCallback # noqa + from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa + from utils.logging import ( print_config_tree, log_hyperparameters_with_all_loggers, ) # noqa - from modelhub.utils.ddp import RankedLogger # noqa - from modelhub.utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa - from modelhub.utils.datasets import ( + from utils.ddp import RankedLogger # noqa + from utils.ddp import is_rank_zero, set_accelerator_based_on_availability # noqa + from rf3.utils.datasets import ( recursively_instantiate_datasets_and_samplers, assemble_distributed_loader, subset_dataset_to_example_ids, diff --git a/src/modelhub/trainers/rf3.py b/models/rf3/src/rf3/trainers/rf3.py similarity index 97% rename from src/modelhub/trainers/rf3.py rename to models/rf3/src/rf3/trainers/rf3.py index d5bac48..f375e23 100644 --- a/src/modelhub/trainers/rf3.py +++ b/models/rf3/src/rf3/trainers/rf3.py @@ -6,20 +6,20 @@ from jaxtyping import Float, Int from lightning_utilities import apply_to_collection from omegaconf import DictConfig -from modelhub.common import exists -from modelhub.loss.af3_losses import Loss as AF3Loss -from modelhub.loss.af3_losses import ( +from common import exists +from rf3.loss.af3_losses import Loss as AF3Loss +from rf3.loss.af3_losses import ( ResidueSymmetryResolution, SubunitSymmetryResolution, ) -from modelhub.metrics.base import MetricManager -from modelhub.model.RF3 import ShouldEarlyStopFn -from modelhub.trainers.fabric import FabricTrainer -from modelhub.training.EMA import EMA -from modelhub.utils.ddp import RankedLogger -from modelhub.utils.io import build_stack_from_atom_array_and_batched_coords -from modelhub.utils.recycling import get_recycle_schedule -from modelhub.utils.torch_utils import assert_no_nans, assert_same_shape +from metrics.base import MetricManager +from rf3.model.RF3 import ShouldEarlyStopFn +from trainers.fabric import FabricTrainer +from rf3.training.EMA import EMA +from utils.ddp import RankedLogger +from rf3.utils.io import build_stack_from_atom_array_and_batched_coords +from rf3.utils.recycling import get_recycle_schedule +from utils.torch import assert_no_nans, assert_same_shape ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/training/EMA.py b/models/rf3/src/rf3/training/EMA.py similarity index 100% rename from src/modelhub/training/EMA.py rename to models/rf3/src/rf3/training/EMA.py diff --git a/models/rf3/src/rf3/training/checkpoint.py b/models/rf3/src/rf3/training/checkpoint.py new file mode 100644 index 0000000..da42cd7 --- /dev/null +++ b/models/rf3/src/rf3/training/checkpoint.py @@ -0,0 +1,84 @@ +"""Utilities for gradient checkpointing to reduce memory usage during training. + +Gradient checkpointing (also called activation checkpointing) trades compute for memory +by recomputing intermediate activations during the backward pass instead of storing them. +This enables training larger models or using larger batch sizes within GPU memory constraints. + +References: + * `PyTorch Checkpoint Documentation`_ + + .. _PyTorch Checkpoint Documentation: https://pytorch.org/docs/stable/checkpoint.html +""" + +import torch +from torch.utils.checkpoint import checkpoint + + +def create_custom_forward(module, **kwargs): + """Create a custom forward function for gradient checkpointing with fixed kwargs. + + This helper enables passing keyword arguments to a module when using PyTorch's + checkpoint function, which only accepts positional arguments for the function to + be checkpointed. + + Args: + module: The callable (typically a nn.Module) to wrap. + **kwargs: Keyword arguments to pass to the module during forward. + + Returns: + A callable that accepts only positional arguments and forwards them along + with the fixed kwargs to the original module. + + Examples: + Use with PyTorch checkpoint:: + + custom_fn = create_custom_forward(my_module, frame_atom_idxs=frame_idxs) + output = checkpoint(custom_fn, input_tensor, use_reentrant=False) + + See Also: + :py:func:`activation_checkpointing` + """ + + def custom_forward(*inputs): + return module(*inputs, **kwargs) + + return custom_forward + + +def activation_checkpointing(function): + """Decorator to enable gradient checkpointing for a function during training. + + When gradients are enabled (training mode), this decorator wraps the function + with PyTorch's checkpoint to save memory by recomputing activations during + the backward pass. During inference (gradients disabled), the function runs + normally without checkpointing overhead. + + Args: + function: The function to apply gradient checkpointing to. + + Returns: + Wrapped function that conditionally applies checkpointing based on gradient state. + + Examples: + Apply to a forward pass method:: + + @activation_checkpointing + def forward(self, x, mask=None): + return self.layer(x, mask) + + Notes: + Uses ``use_reentrant=False`` for better compatibility with modern PyTorch + features like autograd hooks and higher-order gradients. + + See Also: + :py:func:`create_custom_forward` + """ + + def wrapper(*args, **kwargs): + if torch.is_grad_enabled(): + return checkpoint( + create_custom_forward(function, **kwargs), *args, use_reentrant=False + ) + return function(*args, **kwargs) + + return wrapper diff --git a/src/modelhub/training/schedulers.py b/models/rf3/src/rf3/training/schedulers.py similarity index 100% rename from src/modelhub/training/schedulers.py rename to models/rf3/src/rf3/training/schedulers.py diff --git a/src/modelhub/util_module.py b/models/rf3/src/rf3/util_module.py similarity index 100% rename from src/modelhub/util_module.py rename to models/rf3/src/rf3/util_module.py diff --git a/src/modelhub/utils/datasets.py b/models/rf3/src/rf3/utils/datasets.py similarity index 99% rename from src/modelhub/utils/datasets.py rename to models/rf3/src/rf3/utils/datasets.py index 3cf5e46..9e3cf4d 100755 --- a/src/modelhub/utils/datasets.py +++ b/models/rf3/src/rf3/utils/datasets.py @@ -20,8 +20,8 @@ from torch.utils.data import ( ) from torch.utils.data.distributed import DistributedSampler -from modelhub.resolvers import register_resolvers -from modelhub.utils.ddp import RankedLogger +from hydra.resolvers import register_resolvers +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) try: diff --git a/src/modelhub/utils/frames.py b/models/rf3/src/rf3/utils/frames.py similarity index 98% rename from src/modelhub/utils/frames.py rename to models/rf3/src/rf3/utils/frames.py index 36509df..f707c89 100644 --- a/src/modelhub/utils/frames.py +++ b/models/rf3/src/rf3/utils/frames.py @@ -2,7 +2,7 @@ import torch -from modelhub.chemical import NFRAMES, NNAPROTAAS, costgtNA +from rf3.chemical import NFRAMES, NNAPROTAAS, costgtNA def is_atom(seq): diff --git a/src/modelhub/utils/inference.py b/models/rf3/src/rf3/utils/inference.py similarity index 99% rename from src/modelhub/utils/inference.py rename to models/rf3/src/rf3/utils/inference.py index dc7b9fa..0776294 100644 --- a/src/modelhub/utils/inference.py +++ b/models/rf3/src/rf3/utils/inference.py @@ -16,7 +16,7 @@ from atomworks.io.utils.io_utils import to_cif_file from atomworks.io.utils.selection import AtomSelectionStack from biotite.structure import AtomArray -from modelhub.utils.io import ( +from rf3.utils.io import ( CIF_LIKE_EXTENSIONS, DICTIONARY_LIKE_EXTENSIONS, create_example_id_extractor, diff --git a/src/modelhub/utils/io.py b/models/rf3/src/rf3/utils/io.py similarity index 98% rename from src/modelhub/utils/io.py rename to models/rf3/src/rf3/utils/io.py index d0e219b..c663e49 100644 --- a/src/modelhub/utils/io.py +++ b/models/rf3/src/rf3/utils/io.py @@ -8,8 +8,8 @@ from atomworks.io.utils.io_utils import to_cif_file from beartype.typing import Literal from biotite.structure import AtomArray, AtomArrayStack, stack -from modelhub.alignment import weighted_rigid_align -from modelhub.utils.ddp import RankedLogger +from rf3.alignment import weighted_rigid_align +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/utils/loss.py b/models/rf3/src/rf3/utils/loss.py similarity index 100% rename from src/modelhub/utils/loss.py rename to models/rf3/src/rf3/utils/loss.py diff --git a/src/modelhub/utils/predicted_error.py b/models/rf3/src/rf3/utils/predicted_error.py similarity index 99% rename from src/modelhub/utils/predicted_error.py rename to models/rf3/src/rf3/utils/predicted_error.py index f244d21..4ef5010 100644 --- a/src/modelhub/utils/predicted_error.py +++ b/models/rf3/src/rf3/utils/predicted_error.py @@ -10,8 +10,8 @@ from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack from omegaconf import DictConfig -from modelhub.chemical import NHEAVY -from modelhub.metrics.metric_utils import ( +from rf3.chemical import NHEAVY +from rf3.metrics.metric_utils import ( compute_mean_over_subsampled_pairs, compute_min_over_subsampled_pairs, create_chainwise_masks_1d, diff --git a/src/modelhub/utils/recycling.py b/models/rf3/src/rf3/utils/recycling.py similarity index 100% rename from src/modelhub/utils/recycling.py rename to models/rf3/src/rf3/utils/recycling.py diff --git a/src/modelhub/validate.py b/models/rf3/src/rf3/validate.py similarity index 91% rename from src/modelhub/validate.py rename to models/rf3/src/rf3/validate.py index b3ec994..b06a2a7 100755 --- a/src/modelhub/validate.py +++ b/models/rf3/src/rf3/validate.py @@ -8,7 +8,7 @@ import rootutils from dotenv import load_dotenv from omegaconf import DictConfig -from modelhub.utils.logging import suppress_warnings +from utils.logging import suppress_warnings load_dotenv(override=True) @@ -42,12 +42,12 @@ def validate(cfg: DictConfig) -> None: # Reference: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision torch.set_float32_matmul_precision("medium") - from modelhub.callbacks.base import BaseCallback # noqa - from modelhub.utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa - from modelhub.utils.logging import print_config_tree # noqa - from modelhub.utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa - from modelhub.utils.ddp import is_rank_zero # noqa - from modelhub.utils.datasets import assemble_val_loader_dict # noqa + from callbacks.base import BaseCallback # noqa + from utils.instantiators import instantiate_loggers, instantiate_callbacks # noqa + from utils.logging import print_config_tree # noqa + from utils.ddp import RankedLogger, set_accelerator_based_on_availability # noqa + from utils.ddp import is_rank_zero # noqa + from rf3.utils.datasets import assemble_val_loader_dict # noqa set_accelerator_based_on_availability(cfg) diff --git a/tests/.gitkeep b/models/rf3/tests/.gitkeep similarity index 100% rename from tests/.gitkeep rename to models/rf3/tests/.gitkeep diff --git a/tests/conftest.py b/models/rf3/tests/conftest.py similarity index 66% rename from tests/conftest.py rename to models/rf3/tests/conftest.py index fea1fe4..dfdc002 100644 --- a/tests/conftest.py +++ b/models/rf3/tests/conftest.py @@ -1,4 +1,3 @@ -import os from pathlib import Path import pytest @@ -10,14 +9,16 @@ TEST_DATA_DIR = Path(__file__).resolve().parent / "data" def pytest_configure(config): - # Get the directory where conftest.py is located - current_dir = os.path.dirname(os.path.abspath(__file__)) + # Set PROJECT_ROOT + project_root = rootutils.setup_root( + __file__, indicator=".project-root", pythonpath=True + ) - # Construct path to .env file in the parent directory - dotenv_path = os.path.join(current_dir, "..", ".env") + # Construct path to .env file at project root + dotenv_path = project_root / ".env" # Check if the .env file exists - if not os.path.exists(dotenv_path): + if not dotenv_path.exists(): raise pytest.UsageError( f"ERROR: Required .env file not found at {dotenv_path}. " f"Please create this file with the necessary environment variables." @@ -26,9 +27,6 @@ def pytest_configure(config): # Load the environment variables load_dotenv(dotenv_path) - # Set PROJECT_ROOT - rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) - @pytest.fixture(scope="session") def gpu(): diff --git a/tests/data/5vht_from_file.cif b/models/rf3/tests/data/5vht_from_file.cif similarity index 100% rename from tests/data/5vht_from_file.cif rename to models/rf3/tests/data/5vht_from_file.cif diff --git a/tests/data/5vht_from_json.json b/models/rf3/tests/data/5vht_from_json.json similarity index 100% rename from tests/data/5vht_from_json.json rename to models/rf3/tests/data/5vht_from_json.json diff --git a/tests/data/example_from_pdb_with_inter_chain_bond.pdb b/models/rf3/tests/data/example_from_pdb_with_inter_chain_bond.pdb similarity index 100% rename from tests/data/example_from_pdb_with_inter_chain_bond.pdb rename to models/rf3/tests/data/example_from_pdb_with_inter_chain_bond.pdb diff --git a/tests/data/example_pdb_with_clashing_ligand_name.pdb b/models/rf3/tests/data/example_pdb_with_clashing_ligand_name.pdb similarity index 100% rename from tests/data/example_pdb_with_clashing_ligand_name.pdb rename to models/rf3/tests/data/example_pdb_with_clashing_ligand_name.pdb diff --git a/tests/data/example_with_ncaa.json b/models/rf3/tests/data/example_with_ncaa.json similarity index 100% rename from tests/data/example_with_ncaa.json rename to models/rf3/tests/data/example_with_ncaa.json diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score new file mode 100644 index 0000000..d28892f --- /dev/null +++ b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score @@ -0,0 +1,16 @@ +example_id,chain_chainwise,chainwise_plddt,chainwise_pde,chainwise_pae,overall_plddt,overall_pde,overall_pae,batch_idx,chain_i_interface,chain_j_interface,pae_interface,pde_interface,min_pae_interface,min_pde_interface +5vht_from_file,A_1,0.8494526147842407,0.7609825134277344,4.044447898864746,0.8531264662742615,0.8598654270172119,4.3676862716674805,0,,,,,, +5vht_from_file,B_1,0.8568002581596375,0.7851151823997498,4.099449634552002,0.8531264662742615,0.8598654270172119,4.3676862716674805,0,,,,,, +5vht_from_file,A_1,0.8471965789794922,0.7723569273948669,4.057911396026611,0.8510856628417969,0.8723170161247253,4.338787078857422,1,,,,,, +5vht_from_file,B_1,0.8549749851226807,0.7815006375312805,4.067974090576172,0.8510856628417969,0.8723170161247253,4.338787078857422,1,,,,,, +5vht_from_file,A_1,0.8491722345352173,0.7198135256767273,3.956014633178711,0.8529133796691895,0.8150432705879211,4.268957614898682,2,,,,,, +5vht_from_file,B_1,0.8566544651985168,0.7406660318374634,4.006967067718506,0.8529133796691895,0.8150432705879211,4.268957614898682,2,,,,,, +5vht_from_file,A_1,0.8499325513839722,0.7343265414237976,4.005363941192627,0.8535234928131104,0.8327312469482422,4.315551280975342,3,,,,,, +5vht_from_file,B_1,0.8571144938468933,0.754400908946991,4.048102378845215,0.8535234928131104,0.8327312469482422,4.315551280975342,3,,,,,, +5vht_from_file,A_1,0.8489580750465393,0.7557268142700195,4.078780651092529,0.8529950976371765,0.8573365211486816,4.347229957580566,4,,,,,, +5vht_from_file,B_1,0.8570321202278137,0.7667781114578247,4.044546127319336,0.8529950976371765,0.8573365211486816,4.347229957580566,4,,,,,, +5vht_from_file,,,,,0.8531264662742615,0.8598654270172119,4.3676862716674805,0,A_1,B_1,4.663423538208008,0.946681797504425,0.5509309768676758,0.30699265003204346 +5vht_from_file,,,,,0.8510856628417969,0.8723170161247253,4.338787078857422,1,A_1,B_1,4.614631175994873,0.967705249786377,0.5410354137420654,0.30483248829841614 +5vht_from_file,,,,,0.8529133796691895,0.8150432705879211,4.268957614898682,2,A_1,B_1,4.556424617767334,0.8998467326164246,0.5328854918479919,0.3051001727581024 +5vht_from_file,,,,,0.8535234928131104,0.8327312469482422,4.315551280975342,3,A_1,B_1,4.604368209838867,0.9210987687110901,0.5463621616363525,0.3068302571773529 +5vht_from_file,,,,,0.8529950976371765,0.8573365211486816,4.347229957580566,4,A_1,B_1,4.632795810699463,0.9534204006195068,0.5457666516304016,0.3067321479320526 diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv similarity index 75% rename from tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv rename to models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv index 7750ae9..23fb0d4 100644 --- a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv +++ b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_metrics.csv @@ -1,2 +1,2 @@ example_id,ptm.ptm_0,ptm.ptm_1,ptm.ptm_2,ptm.ptm_3,ptm.ptm_4,iptm.iptm_0,iptm.iptm_1,iptm.iptm_2,iptm.iptm_3,iptm.iptm_4,iptm.iptm_protein_protein_0,iptm.iptm_protein_protein_1,iptm.iptm_protein_protein_2,iptm.iptm_protein_protein_3,iptm.iptm_protein_protein_4,iptm.iptm_protein_ligand_0,iptm.iptm_protein_ligand_1,iptm.iptm_protein_ligand_2,iptm.iptm_protein_ligand_3,iptm.iptm_protein_ligand_4,iptm.iptm_ligand_ligand_0,iptm.iptm_ligand_ligand_1,iptm.iptm_ligand_ligand_2,iptm.iptm_ligand_ligand_3,iptm.iptm_ligand_ligand_4,count_clashing_chains.has_clash_0,count_clashing_chains.has_clash_1,count_clashing_chains.has_clash_2,count_clashing_chains.has_clash_3,count_clashing_chains.has_clash_4 -5vht_from_file,0.9285362,0.92923117,0.930522,0.92750067,0.927776,0.9287108,0.92952603,0.93126476,0.92829627,0.9283343,0.9287108,0.92952603,0.93126476,0.92829627,0.9283343,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0 +5vht_from_file,0.92653567,0.92595327,0.9294402,0.92781174,0.9264588,0.9272439,0.926157,0.9301085,0.9286177,0.92686117,0.9272439,0.926157,0.9301085,0.9286177,0.92686117,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0,0,0,0,0 diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz new file mode 100644 index 0000000..71e84a4 Binary files /dev/null and b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz differ diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz new file mode 100644 index 0000000..e73bbb8 Binary files /dev/null and b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz differ diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz new file mode 100644 index 0000000..0343d9d Binary files /dev/null and b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz differ diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz new file mode 100644 index 0000000..f08e493 Binary files /dev/null and b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz differ diff --git a/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz new file mode 100644 index 0000000..8a61b7f Binary files /dev/null and b/models/rf3/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz differ diff --git a/tests/data/msas/5vht_A.a3m b/models/rf3/tests/data/msas/5vht_A.a3m similarity index 100% rename from tests/data/msas/5vht_A.a3m rename to models/rf3/tests/data/msas/5vht_A.a3m diff --git a/tests/data/multiple_examples_from_json.json b/models/rf3/tests/data/multiple_examples_from_json.json similarity index 100% rename from tests/data/multiple_examples_from_json.json rename to models/rf3/tests/data/multiple_examples_from_json.json diff --git a/tests/data/ncaa/create_cif_with_ligand_as_ncaa.ipynb b/models/rf3/tests/data/ncaa/create_cif_with_ligand_as_ncaa.ipynb similarity index 100% rename from tests/data/ncaa/create_cif_with_ligand_as_ncaa.ipynb rename to models/rf3/tests/data/ncaa/create_cif_with_ligand_as_ncaa.ipynb diff --git a/tests/data/ncaa/ligand_as_ncaa.cif b/models/rf3/tests/data/ncaa/ligand_as_ncaa.cif similarity index 100% rename from tests/data/ncaa/ligand_as_ncaa.cif rename to models/rf3/tests/data/ncaa/ligand_as_ncaa.cif diff --git a/tests/data/ncaa/penicillin_ts2_as_ncaa.cif b/models/rf3/tests/data/ncaa/penicillin_ts2_as_ncaa.cif similarity index 100% rename from tests/data/ncaa/penicillin_ts2_as_ncaa.cif rename to models/rf3/tests/data/ncaa/penicillin_ts2_as_ncaa.cif diff --git a/tests/data/nested_examples/example_from_json.json b/models/rf3/tests/data/nested_examples/example_from_json.json similarity index 100% rename from tests/data/nested_examples/example_from_json.json rename to models/rf3/tests/data/nested_examples/example_from_json.json diff --git a/tests/data/nested_examples/example_from_pdb_with_inter_chain_bonds.and.dots.pdb b/models/rf3/tests/data/nested_examples/example_from_pdb_with_inter_chain_bonds.and.dots.pdb similarity index 100% rename from tests/data/nested_examples/example_from_pdb_with_inter_chain_bonds.and.dots.pdb rename to models/rf3/tests/data/nested_examples/example_from_pdb_with_inter_chain_bonds.and.dots.pdb diff --git a/tests/test_inference_regression.py b/models/rf3/tests/test_inference_regression.py similarity index 88% rename from tests/test_inference_regression.py rename to models/rf3/tests/test_inference_regression.py index b70ed0a..bdab0d0 100755 --- a/tests/test_inference_regression.py +++ b/models/rf3/tests/test_inference_regression.py @@ -6,7 +6,6 @@ from pathlib import Path import numpy as np import pandas as pd import pytest -import torch from atomworks.ml.utils.rng import ( create_rng_state_from_seeds, rng_state, @@ -45,7 +44,7 @@ def compare_csv_files( max_diff = diff.max() assert ( max_diff <= tolerance - ), f"Numerical difference {max_diff} exceeds tolerance {tolerance} in column {col} of {predicted_file.name}. Predicted: {predicted_df[col]}, Baseline: {baseline_df[col]}" + ), f"Numerical difference {max_diff} exceeds tolerance {tolerance} in column {col} of {predicted_file.name}" else: # Exact comparison for non-numeric assert predicted_df[col].equals( @@ -55,8 +54,6 @@ def compare_csv_files( @pytest.mark.gpu def test_inference_regression(): - print("GPU available: ", torch.cuda.is_available()) - # inputs = "/home/ncorley/projects/modelhub_dev/tests/data/5vht_from_file.cif" inputs = TEST_DATA_DIR / "5vht_from_file.cif" data_dir = TEST_DATA_DIR / "inference_regression_tests" / "5vht_from_file" @@ -69,7 +66,7 @@ def test_inference_regression(): cfg = compose( config_name="inference", overrides=[ - "inference_engine=af3", + "inference_engine=rf3", f"inputs={inputs}", "annotate_b_factor_with_plddt=true", "one_model_per_file=false", @@ -87,7 +84,7 @@ def test_inference_regression(): # (CSV files with confidence outputs) for file in data_dir.glob("*.csv"): predicted_file = Path(temp_dir) / file.name - compare_csv_files(predicted_file, file, tolerance=1e-3) + compare_csv_files(predicted_file, file, tolerance=2e-3) if __name__ == "__main__": diff --git a/tests/test_write_confidence.py b/models/rf3/tests/test_write_confidence.py similarity index 95% rename from tests/test_write_confidence.py rename to models/rf3/tests/test_write_confidence.py index 3a1cad5..a82aa3d 100644 --- a/tests/test_write_confidence.py +++ b/models/rf3/tests/test_write_confidence.py @@ -4,12 +4,12 @@ import torch from lightning.fabric import seed_everything from omegaconf import DictConfig -from modelhub.chemical import NHEAVY, heavyatom_mask -from modelhub.metrics.metric_utils import ( +from rf3.chemical import NHEAVY, heavyatom_mask +from rf3.metrics.metric_utils import ( find_bin_midpoints, unbin_logits, ) -from modelhub.utils.predicted_error import compile_af3_confidence_outputs +from rf3.utils.predicted_error import compile_af3_confidence_outputs def test_compile_af3_confidence_outputs(): @@ -80,7 +80,12 @@ def test_compile_af3_confidence_outputs(): "chain_j_interface", "pae_interface", "pde_interface", + "min_pae_interface", + "min_pde_interface", ] + + print(df.columns.tolist()) + print(target_columns) assert df.columns.tolist() == target_columns, "Dataframe columns not set correctly" assert df.shape == ( num_batches * (num_interfaces + num_chains), diff --git a/src/modelhub/callbacks/base.py b/src/modelhub/callbacks/callback.py similarity index 100% rename from src/modelhub/callbacks/base.py rename to src/modelhub/callbacks/callback.py diff --git a/src/modelhub/callbacks/health_logging.py b/src/modelhub/callbacks/health_logging.py index 2732b50..44e617c 100644 --- a/src/modelhub/callbacks/health_logging.py +++ b/src/modelhub/callbacks/health_logging.py @@ -10,7 +10,7 @@ from jaxtyping import Float, Int from lightning.fabric.utilities.rank_zero import rank_zero_only from torch import Tensor -from modelhub.callbacks.base import BaseCallback +from callbacks.base import BaseCallback _DEFAULT_STATISTICS = types.MappingProxyType( { diff --git a/src/modelhub/callbacks/timing_logging.py b/src/modelhub/callbacks/timing_logging.py index 39bf360..4b4ecca 100644 --- a/src/modelhub/callbacks/timing_logging.py +++ b/src/modelhub/callbacks/timing_logging.py @@ -1,9 +1,9 @@ import pandas as pd from lightning.fabric.utilities.rank_zero import rank_zero_only -from modelhub.callbacks.base import BaseCallback -from modelhub.utils.logging import print_df_as_table -from modelhub.utils.torch_utils import Timers +from callbacks.base import BaseCallback +from rf3.utils.logging import print_df_as_table +from rf3.utils.torch_utils import Timers class TimingCallback(BaseCallback): diff --git a/src/modelhub/callbacks/train_logging.py b/src/modelhub/callbacks/train_logging.py index b480a0a..65842e4 100755 --- a/src/modelhub/callbacks/train_logging.py +++ b/src/modelhub/callbacks/train_logging.py @@ -13,15 +13,15 @@ from rich.table import Table from torch import nn from torchmetrics.aggregation import MeanMetric -from modelhub.callbacks.base import BaseCallback -from modelhub.utils.ddp import RankedLogger -from modelhub.utils.logging import ( +from callbacks.base import BaseCallback +from rf3.utils.ddp import RankedLogger +from rf3.utils.logging import ( print_df_as_table, print_model_parameters, safe_print, table_from_df, ) -from modelhub.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses +from rf3.utils.loss import convert_batched_losses_to_list_of_dicts, mean_losses class LogModelParametersCallback(BaseCallback): diff --git a/src/modelhub/resolvers.py b/src/modelhub/hydra/resolvers.py similarity index 98% rename from src/modelhub/resolvers.py rename to src/modelhub/hydra/resolvers.py index 0b65ae2..99a8237 100644 --- a/src/modelhub/resolvers.py +++ b/src/modelhub/hydra/resolvers.py @@ -10,7 +10,7 @@ from atomworks.enums import ChainType, ChainTypeInfo from beartype.typing import Any from omegaconf import OmegaConf -from .common import run_once +from ..common import run_once # (Custom resolvers) diff --git a/src/modelhub/metrics/base.py b/src/modelhub/metrics/base.py index 2fb8444..391ee12 100644 --- a/src/modelhub/metrics/base.py +++ b/src/modelhub/metrics/base.py @@ -8,7 +8,7 @@ from beartype.typing import Any from omegaconf import DictConfig from toolz import keymap -from modelhub.utils.ddp import RankedLogger +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/metrics/chiral.py b/src/modelhub/metrics/chiral.py index 7d500dc..260b872 100644 --- a/src/modelhub/metrics/chiral.py +++ b/src/modelhub/metrics/chiral.py @@ -1,16 +1,16 @@ import torch +from atomworks.io.transforms.atom_array import ensure_atom_array_stack from atomworks.ml.transforms.af3_reference_molecule import ( get_af3_reference_molecule_features, ) -from atomworks.ml.transforms.atom_array import ensure_atom_array_stack from atomworks.ml.transforms.chirals import add_af3_chiral_features from atomworks.ml.transforms.rdkit_utils import get_rdkit_chiral_centers from beartype.typing import Any from biotite.structure import AtomArray, AtomArrayStack from jaxtyping import Bool, Float -from modelhub.kinematics import get_dih -from modelhub.metrics.base import Metric +from rf3.kinematics import get_dih +from metrics.base import Metric def calc_chiral_metrics_masked( diff --git a/src/modelhub/metrics/rasa.py b/src/modelhub/metrics/rasa.py index deb11f7..d64b468 100644 --- a/src/modelhub/metrics/rasa.py +++ b/src/modelhub/metrics/rasa.py @@ -3,7 +3,7 @@ from atomworks.ml.transforms.sasa import calculate_atomwise_rasa from beartype.typing import Any from biotite.structure import AtomArrayStack -from modelhub.metrics.base import Metric +from metrics.base import Metric class UnresolvedRegionRASA(Metric): diff --git a/src/modelhub/trainers/fabric.py b/src/modelhub/trainers/fabric.py index 891a532..bc90df2 100755 --- a/src/modelhub/trainers/fabric.py +++ b/src/modelhub/trainers/fabric.py @@ -26,11 +26,11 @@ from lightning.fabric.wrappers import ( _FabricOptimizer, ) -from modelhub.callbacks.base import BaseCallback -from modelhub.training.EMA import EMA -from modelhub.training.schedulers import SchedulerConfig -from modelhub.utils.ddp import RankedLogger -from modelhub.utils.weights import ( +from callbacks.base import BaseCallback +from rf3.training.EMA import EMA +from rf3.training.schedulers import SchedulerConfig +from utils.ddp import RankedLogger +from utils.weights import ( CheckpointConfig, WeightLoadingConfig, freeze_parameters_with_config, diff --git a/src/modelhub/training/checkpoint.py b/src/modelhub/training/checkpoint.py deleted file mode 100644 index ad71ff8..0000000 --- a/src/modelhub/training/checkpoint.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch -from torch.utils.checkpoint import checkpoint - - -# for gradient checkpointing -def create_custom_forward(module, **kwargs): - def custom_forward(*inputs): - return module(*inputs, **kwargs) - - return custom_forward - - -def activation_checkpointing(function): - def wrapper(*args, **kwargs): - if torch.is_grad_enabled(): - return checkpoint( - create_custom_forward(function, **kwargs), *args, use_reentrant=False - ) - return function(*args, **kwargs) - - return wrapper diff --git a/src/modelhub/utils/instantiators.py b/src/modelhub/utils/instantiators.py index 836bc85..417bc2a 100755 --- a/src/modelhub/utils/instantiators.py +++ b/src/modelhub/utils/instantiators.py @@ -2,7 +2,7 @@ import hydra from lightning.fabric.loggers import Logger from omegaconf import DictConfig -from modelhub.callbacks.base import BaseCallback +from callbacks.base import BaseCallback def _can_be_instantiated(cfg: DictConfig) -> bool: diff --git a/src/modelhub/utils/logging.py b/src/modelhub/utils/logging.py index 4e37b57..0c9adcd 100755 --- a/src/modelhub/utils/logging.py +++ b/src/modelhub/utils/logging.py @@ -12,7 +12,7 @@ from rich.table import Table from rich.tree import Tree from torch import nn -from modelhub.utils.ddp import RankedLogger +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/src/modelhub/utils/torch_utils.py b/src/modelhub/utils/torch.py similarity index 99% rename from src/modelhub/utils/torch_utils.py rename to src/modelhub/utils/torch.py index 09e5954..d37c4e6 100755 --- a/src/modelhub/utils/torch_utils.py +++ b/src/modelhub/utils/torch.py @@ -14,8 +14,8 @@ from torch import Tensor from torch._prims_common import DeviceLikeType from torch.types import _dtype -from modelhub import should_check_nans -from modelhub.common import at_least_one_exists, do_nothing +from src import should_check_nans +from common import at_least_one_exists, do_nothing def map_to( diff --git a/src/modelhub/utils/weights.py b/src/modelhub/utils/weights.py index cf20bf9..e66ea87 100644 --- a/src/modelhub/utils/weights.py +++ b/src/modelhub/utils/weights.py @@ -9,7 +9,7 @@ import torch from beartype.typing import Pattern from torch import nn -from modelhub.utils.ddp import RankedLogger +from utils.ddp import RankedLogger ranked_logger = RankedLogger(__name__, rank_zero_only=True) diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score deleted file mode 100644 index cf19329..0000000 --- a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file.score +++ /dev/null @@ -1,16 +0,0 @@ -example_id,chain_chainwise,chainwise_plddt,chainwise_pde,chainwise_pae,overall_plddt,overall_pde,overall_pae,batch_idx,chain_i_interface,chain_j_interface,pae_interface,pde_interface -5vht_from_file,A_1,0.8487251996994019,0.7481884360313416,4.096249103546143,0.8523056507110596,0.8404603004455566,4.370947360992432,0,,,, -5vht_from_file,B_1,0.8558861613273621,0.7651486396789551,4.091543197631836,0.8523056507110596,0.8404603004455566,4.370947360992432,0,,,, -5vht_from_file,A_1,0.8485860824584961,0.7361811399459839,4.076277732849121,0.8519705533981323,0.8272480964660645,4.359431266784668,1,,,, -5vht_from_file,B_1,0.8553550243377686,0.7556437849998474,4.100306510925293,0.8519705533981323,0.8272480964660645,4.359431266784668,1,,,, -5vht_from_file,A_1,0.8494381904602051,0.7172074913978577,3.9760985374450684,0.8525967597961426,0.8058767318725586,4.274764060974121,2,,,, -5vht_from_file,B_1,0.8557553291320801,0.742344081401825,4.033787250518799,0.8525967597961426,0.8058767318725586,4.274764060974121,2,,,, -5vht_from_file,A_1,0.8491919636726379,0.7664837837219238,4.102567672729492,0.8524054288864136,0.8643093109130859,4.416274070739746,3,,,, -5vht_from_file,B_1,0.855618953704834,0.7962877154350281,4.166482925415039,0.8524054288864136,0.8643093109130859,4.416274070739746,3,,,, -5vht_from_file,A_1,0.8481411337852478,0.7494730949401855,4.083292007446289,0.8515472412109375,0.8493087887763977,4.38142728805542,4,,,, -5vht_from_file,B_1,0.8549532294273376,0.7740440964698792,4.134703159332275,0.8515472412109375,0.8493087887763977,4.38142728805542,4,,,, -5vht_from_file,,,,,0.8523056507110596,0.8404603004455566,4.370947360992432,0,A_1,B_1,4.647997856140137,0.9242519736289978 -5vht_from_file,,,,,0.8519705533981323,0.8272480964660645,4.359431266784668,1,A_1,B_1,4.630569934844971,0.9085838198661804 -5vht_from_file,,,,,0.8525967597961426,0.8058767318725586,4.274764060974121,2,A_1,B_1,4.544585227966309,0.8819775581359863 -5vht_from_file,,,,,0.8524054288864136,0.8643093109130859,4.416274070739746,3,A_1,B_1,4.698022842407227,0.9472327828407288 -5vht_from_file,,,,,0.8515472412109375,0.8493087887763977,4.38142728805542,4,A_1,B_1,4.653857231140137,0.9368589520454407 diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz deleted file mode 100644 index 2f84aa1..0000000 Binary files a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_0.cif.gz and /dev/null differ diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz deleted file mode 100644 index dc1bc1c..0000000 Binary files a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_1.cif.gz and /dev/null differ diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz deleted file mode 100644 index f274c71..0000000 Binary files a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_2.cif.gz and /dev/null differ diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz deleted file mode 100644 index db30962..0000000 Binary files a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_3.cif.gz and /dev/null differ diff --git a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz b/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz deleted file mode 100644 index 933a442..0000000 Binary files a/tests/data/inference_regression_tests/5vht_from_file/5vht_from_file_model_4.cif.gz and /dev/null differ diff --git a/tests/test_ground_truth_template.py b/tests/test_ground_truth_template.py deleted file mode 100644 index 1caa8c5..0000000 --- a/tests/test_ground_truth_template.py +++ /dev/null @@ -1,235 +0,0 @@ -from functools import partial - -import numpy as np -import pytest -import torch -from cifutils.constants import ( - STANDARD_AA, - STANDARD_DNA, - STANDARD_RNA, -) -from cifutils.enums import ChainType -from datahub.transforms.atomize import AtomizeByCCDName -from datahub.transforms.base import Compose -from datahub.utils.rng import create_rng_state_from_seeds, rng_state -from datahub.utils.testing import cached_parse -from datahub.utils.token import get_af3_token_center_masks -from jaxtyping import Float -from torch import Tensor - -from modelhub.data.ground_truth_template import ( - DEFAULT_DISTOGRAM_BINS, - FeaturizeNoisedGroundTruthAsTemplateDistogram, - TokenGroupNoiseScaleSampler, - af3_noise_scale_distribution_wrapped, - af3_noise_scale_to_noise_level, -) -from modelhub.utils.torch_utils import assert_no_nans, assert_same_shape - -TEST_CASES = ["6wtf", "5ocm"] - - -def calc_distogram(coords: Float[Tensor, "n 3"]) -> Float[Tensor, "n n"]: - return torch.cdist(coords, coords, p=2, compute_mode="donot_use_mm_for_euclid_dist") - - -def test_default_distogram_bins(): - assert len(DEFAULT_DISTOGRAM_BINS) == 63 - assert ( - torch.bucketize(torch.linspace(0, 22, 230), DEFAULT_DISTOGRAM_BINS).max() == 63 - ) - assert ( - torch.bucketize(torch.linspace(0, 22, 200), DEFAULT_DISTOGRAM_BINS).min() == 0 - ) - - -@pytest.fixture -def setup_data_and_pipeline(): - def _setup(pdb_id): - data = cached_parse(pdb_id) - pipe = Compose( - [ - AtomizeByCCDName( - atomize_by_default=True, - res_names_to_ignore=STANDARD_AA + STANDARD_RNA + STANDARD_DNA, - move_atomized_part_to_end=False, - validate_atomize=False, - ), - ], - track_rng_state=False, - ) - out = pipe(data) - return out - - return _setup - - -@pytest.mark.parametrize("pdb_id", TEST_CASES) -@pytest.mark.parametrize( - "transform_args", - [ - { - # Protein-only, high noise - "noise_scale_distribution": lambda size: torch.ones(size) * 10.0, - "allowed_chain_types": [ChainType.POLYPEPTIDE_L], - }, - { - # Protein and small molecules, no noise, no inter-molecule masking - # (should be the same as the ground truth distogram) - "noise_scale_distribution": lambda size: torch.zeros(size), - "allowed_chain_types": [ - *ChainType.get_polymers(), - *ChainType.get_non_polymers(), - ], - "p_provide_inter_molecule_distances": 1.0, - }, - { - # All chain types supported, but non-polymers have low noise, and polymers have high noise - "noise_scale_distribution": TokenGroupNoiseScaleSampler( - mask_and_sampling_fns=( - ( - lambda arr: np.isin(arr.chain_type, ChainType.get_polymers()), - partial( - af3_noise_scale_distribution_wrapped, - upper_noise_level=af3_noise_scale_to_noise_level( - 16.0 - ).item(), - ), - ), - ( - lambda arr: np.isin( - arr.chain_type, ChainType.get_non_polymers() - ), - partial( - af3_noise_scale_distribution_wrapped, - upper_noise_level=af3_noise_scale_to_noise_level( - 2.0 - ).item(), - ), - ), - ) - ), - "allowed_chain_types": [ - *ChainType.get_polymers(), - *ChainType.get_non_polymers(), - ], - }, - ], -) -def test_distogram_featurization( - pdb_id: str, transform_args: dict, setup_data_and_pipeline -): - out = setup_data_and_pipeline(pdb_id) - transform = FeaturizeNoisedGroundTruthAsTemplateDistogram(**transform_args) - - with rng_state(create_rng_state_from_seeds(12345)): - atom_array = out["atom_array"] - out["is_unconditional"] = False - - # Build ground-truth distogram - token_center_mask = get_af3_token_center_masks(atom_array) - token_center_atom_array = atom_array[token_center_mask] - token_coord = torch.as_tensor(token_center_atom_array.coord) - - distogram = calc_distogram(token_coord) - ground_truth_distogram_bins = torch.bucketize(distogram, DEFAULT_DISTOGRAM_BINS) - - # Featurize with the given arguments - pipeline_output = transform(out)["feats"] - has_distogram_condition = pipeline_output["has_distogram_condition"] - - output = torch.argmax(pipeline_output["distogram_condition"], dim=-1) - - assert has_distogram_condition.any(), "No distogram conditions found!" - - # Uncomment the code below to visualize the distogram and has_distogram_condition - # _, axes = plt.subplots(1, 2, figsize=(12, 6)) - # cmap_output = plt.cm.get_cmap('RdYlGn_r') - # axes[0].imshow(output, cmap=cmap_output, interpolation='none') - # axes[1].imshow(has_distogram_condition, cmap='gray', interpolation='none') - # plt.savefig('distogram_visualization.png', bbox_inches='tight') - - assert_same_shape(output, ground_truth_distogram_bins) - assert_no_nans(output) - - # Mask of inter-molecule distances - is_inter_molecule = ( - token_center_atom_array.molecule_id[:, None] - != token_center_atom_array.molecule_id - ) - - noise_sum = ( - transform_args["noise_scale_distribution"](atom_array).sum() - if isinstance( - transform_args["noise_scale_distribution"], TokenGroupNoiseScaleSampler - ) - else transform_args["noise_scale_distribution"](1).sum() - ) - - if noise_sum == 0: - if transform_args.get("p_provide_inter_molecule_distances", 0.0) == 0.0: - # ... except for inter-molecule distances - assert (output[is_inter_molecule] == len(DEFAULT_DISTOGRAM_BINS)).all() - assert not has_distogram_condition[is_inter_molecule].any() - else: - # (all distances) - assert ( - output[has_distogram_condition] - == ground_truth_distogram_bins[has_distogram_condition] - ).all(), "Unnoised output should match distogram bins" - assert ( - output[~has_distogram_condition] == len(DEFAULT_DISTOGRAM_BINS) - ).all(), "Values without condition should be max distance bin" - - # All values without distogram condition should be maximum distance bin - assert ( - output[~has_distogram_condition] == len(DEFAULT_DISTOGRAM_BINS) - ).all(), "All values without distogram condition should be the same" - else: - # Noised output should be different from the distogram bins - assert not ( - output == ground_truth_distogram_bins - ).all(), "Noised output should be different from the distogram bins" - - # Check that all values with distogram condition have been noised... - assert ( - output[has_distogram_condition] - != ground_truth_distogram_bins[has_distogram_condition] - ).any(), "Values with distogram condition should be noised" - # ... and that no values without conditions have been noised - assert ( - output[~has_distogram_condition] == len(DEFAULT_DISTOGRAM_BINS) - ).all(), "Values without distogram condition should be max distance bin" - - # Check chain type conditions - if "supported_chain_types" in transform_args: - tokens_with_supported_chain_types = np.isin( - token_center_atom_array.chain_type, - transform_args["supported_chain_types"], - ) - tokens_with_supported_chain_types_l_l = ( - tokens_with_supported_chain_types[:, None] - & tokens_with_supported_chain_types - ) - assert ( - has_distogram_condition & tokens_with_supported_chain_types_l_l - ).any(), "Supported chain types should be noised" - assert not ( - has_distogram_condition & ~tokens_with_supported_chain_types_l_l - ).any(), "Unsupported chain types should not be noised" - - # Inter-molecule distances should be the maximum distance bin and have no distogram condition - if transform_args.get("p_provide_inter_molecule_distances", 0.0) == 0.0: - assert (output[is_inter_molecule] == len(DEFAULT_DISTOGRAM_BINS)).all() - assert not has_distogram_condition[is_inter_molecule].any() - - # Sanity check: we should have at least 5 unique values in the output - unique_values = torch.unique(output) - # Assert we have more than 30 unique values in the noised output - assert ( - len(unique_values) > 30 - ), "We should have more than 30 unique values in the noised output" - - -if __name__ == "__main__": - pytest.main(["-v", __file__]) diff --git a/tests/test_inference_pipelines.py b/tests/test_inference_pipelines.py deleted file mode 100644 index e79130b..0000000 --- a/tests/test_inference_pipelines.py +++ /dev/null @@ -1,97 +0,0 @@ -import tempfile -from os import PathLike -from pathlib import Path - -import hydra -import numpy as np -import pytest -from atomworks.io import parse -from hydra import compose, initialize - -from modelhub.utils.inference import ( - apply_conformer_and_template_selections, - build_file_paths_for_prediction, -) - -current_file_directory = Path(__file__).parent - - -@pytest.mark.parametrize( - "file_path", - [ - "data/nested_examples", - "data/multiple_examples_from_json.json", - ], -) -def test_build_file_paths_for_prediction(file_path: PathLike, tmp_path: Path): - """Use the inference pipeline to build and parse inputs for prediction.""" - file_path = current_file_directory / Path(file_path) - - # Call the function with the file path and temporary directory - paths = build_file_paths_for_prediction(file_path, tmp_path) - - # Iterate over the returned paths and parse them, ensuring the the outputs are reasonable - for path in paths: - output = parse(path) - assert output is not None - assert len(output["assemblies"]["1"][0]) > 0 - - -@pytest.mark.parametrize( - "inference_engine", - ["af3"], -) -@pytest.mark.parametrize( - "inputs", - ["tests/data/5vht_from_file.cif"], -) -@pytest.mark.parametrize("template_selection", ["A"]) -@pytest.mark.parametrize("ground_truth_conformer_selection", ["*/PBF"]) -@pytest.mark.slow -@pytest.mark.skip(reason="TEST STILL BROKEN") -def test_inference_engine( - inference_engine: Path, - inputs: PathLike, - template_selection: str, - ground_truth_conformer_selection: str, -): - # TODO: TEST STILL BROKEN - with initialize(config_path="../configs"): - cfg = compose( - config_name="inference", - overrides=[ - f"inference_engine={inference_engine}", - f"inputs={inputs}", - ], - ) - - with tempfile.TemporaryDirectory() as temp_dir: - temp_dir = Path(temp_dir) - temp_dir.mkdir(parents=True, exist_ok=True) - - inference_engine = hydra.utils.instantiate( - cfg, temp_dir=temp_dir, _convert_="partial" - ) - out = inference_engine.parse_from_path(inputs) - atom_array = ( - out["assemblies"]["1"][0] if "assemblies" in out else out["asym_unit"][0] - ) - assert atom_array is not None - - atom_array_untemplated = apply_conformer_and_template_selections(atom_array) - assert ( - "is_input_file_templated" in atom_array_untemplated.get_annotation_categories() - ) - assert np.sum(atom_array_untemplated.get_annotation("is_input_file_templated")) == 0 - - atom_array_templated = apply_conformer_and_template_selections( - atom_array, - template_selection=template_selection, - ground_truth_conformer_selection=ground_truth_conformer_selection, - ) - assert "is_input_file_templated" in atom_array_templated.get_annotation_categories() - assert np.sum(atom_array_templated.get_annotation("is_input_file_templated")) > 0 - - # TODO: Make this actually test the ground truth conformer policy; make template selection actuall work; - # also dont rely on the is_input_file_templated annotation instead just handle the case where it doesnt exist - # correctly diff --git a/tests/test_metrics.py b/tests/test_metrics.py index fd66adb..3c58da5 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,12 +1,12 @@ from copy import deepcopy import pytest -from datahub.utils.testing import cached_parse +from atomworks.ml.utils.testing import cached_parse -from modelhub.metrics.chiral import ChiralLoss +from metrics.chiral import ChiralLoss -@pytest.mark.parametrize("pdb_id", ["5ocm", "6wtf"]) +@pytest.mark.parametrize("pdb_id", ["5ocm", "1ivo"]) def test_chiral_metrics(pdb_id: str): # ... get the AtomArray ground_truth_atom_array = cached_parse(pdb_id, hydrogen_policy="remove")[ diff --git a/tests/test_torch_utils.py b/tests/test_torch_utils.py index 5d86732..3cf2663 100644 --- a/tests/test_torch_utils.py +++ b/tests/test_torch_utils.py @@ -4,7 +4,7 @@ import pytest import torch os.environ["NAN_CHECKING"] = "True" -from modelhub.utils.torch_utils import assert_no_nans, map_to +from utils.torch import assert_no_nans, map_to def test_map_to(): diff --git a/tests/test_weight_loading.py b/tests/test_weight_loading.py index 72f2ac7..e0f00ef 100644 --- a/tests/test_weight_loading.py +++ b/tests/test_weight_loading.py @@ -3,7 +3,7 @@ import torch import torch.nn as nn # Import your code here -from modelhub.utils.weights import ( +from utils.weights import ( ParameterFreezingConfig, WeightLoadingConfig, WeightLoadingPolicy,