refactor: new modelhub (#109)

* Initial commit of chiral changes

Initial checkin of chiral feature code

Add chiral metric

* Update the way chiral features are incorporated into the model

Move initialization to new func

use default pytorch reset parameters

fix initialization for chirals

config

rename argument of confidence head

fix initialization for chirals

* refactor: src nest, rename rf2aa to modelhub

* refactor: initial commit without projects

* Initial commit of chiral changes

* Initial checkin of chiral feature code

* Add chiral metric

* Remove option for double residual connection.  Add kq_norm oiptions to base (20250125) config.

* Restoring flag

* config

* rename argument of confidence head

* Update the way chiral features are incorporated into the model

* config

* rename argument of confidence head

* Update the way chiral features are incorporated into the model

* Initial commit of chiral changes

Initial checkin of chiral feature code

Add chiral metric

* Update the way chiral features are incorporated into the model

Move initialization to new func

use default pytorch reset parameters

fix initialization for chirals

config

rename argument of confidence head

fix initialization for chirals

* refactor: new modelhub

---------

Co-authored-by: fdimaio <dimaio@uw.edu>
Co-authored-by: HaotianZhangAI4Science <haotianzhang@zju.edu.cn>
This commit is contained in:
Nathaniel Corley
2025-04-08 13:33:17 -07:00
committed by GitHub
parent 1fd848a861
commit 5a492032d5
356 changed files with 12882 additions and 14815 deletions

6
.env
View File

@@ -1,5 +1,9 @@
# +--------+ Cifutils +--------+
CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd
PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb
# +--------+ Datahub +--------+
# (Distillation)
AF2FB_PATH=/squash/af2_distillation_facebook AF2FB_PATH=/squash/af2_distillation_facebook
# (PadDNA TRansform)
X3DNA=/projects/ml/prot_dna/x3dna-v2.4

237
.gitignore vendored
View File

@@ -1,23 +1,220 @@
valid_remapped # Base .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore
lig_test
dataset.pkl # Byte-compiled / optimized / DLL files
run_digs.sh
*.pdb
.vscode
slurm_logs/
**/output/
**/outputs/
*/notebooks/
*/models/
__pycache__/ __pycache__/
*/run_scripts/ */__pycache__/
unit_tests/ *.py[cod]
ruff.toml *$py.class
*/scratch/
*/wandb/ # C extensions
rf2aa/dataset_20240318.pkl *.so
*.csv
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebooks (unless explicitly not ignored)
.ipynb_checkpoints
**/.ipynb
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# VS Code
.vscode
.history/
# Slurm
*/slurm_logs/
*.err *.err
*.log *.log
*.json
data/run.sh # Ruff
ruff.toml
.ruff_cache
# Development
dev.py
# Pytest
*.benchmarks/
# Images
*.png
*.pdf
*.svg
*.jpg
*.jpeg
*.gif
*.bmp
*.tiff
# Versioning
**/version.py
# W&B
wandb/
# Hydra
.hydra/
# Outputs
**/outputs/
# Logs
logs/
# Other
*.sif
*.out
# Misc
**/notebooks/
**/models/
**/run_scripts/
**/scratch/

View File

@@ -1,31 +0,0 @@
# This file is a template, and might need editing before it works on your project.
# This is a sample GitLab CI/CD configuration file that should run without any modifications.
# It demonstrates a basic 3 stage CI/CD pipeline. Instead of real tests or scripts,
# it uses echo commands to simulate the pipeline execution.
#
# A pipeline is composed of independent jobs that run scripts, grouped into stages.
# Stages run in sequential order, but jobs within stages run in parallel.
#
# For more information, see: https://docs.gitlab.com/ee/ci/yaml/index.html#stages
#
# You can copy and paste this template into a new `.gitlab-ci.yml` file.
# You should not add this template to an existing `.gitlab-ci.yml` file by using the `include:` keyword.
#
# To contribute improvements to CI/CD templates, please follow the Development guide at:
# https://docs.gitlab.com/ee/development/cicd/templates.html
# This specific template is located at:
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Getting-Started.gitlab-ci.yml
stages: # List of stages for jobs, and their order of execution
- test
unit-test-job: # This job runs in the test stage.
stage: test # It only starts when the job in the build stage completes successfully.
rules:
- if: '$CI_PIPELINE_SOURCE == "merge_request_event"'
script:
- echo "Running unit tests"
- git submodule update --init
- cd rf2aa
- srun -p gpu --gres=gpu:a4000:1 --cpus-per-task=4 --mem=32G bash ../ci/run_tests.sh

View File

@@ -24,25 +24,72 @@ format:
ruff format . ruff format .
ruff check --fix . ruff check --fix .
## Create a new conda environment _github_token_error:
@echo "==============================================================================="; \
echo "Error: Environment variables GITHUB_USER and GITHUB_TOKEN must be set."; \
echo ""; \
echo "You need to set the environment variables GITHUB_USER and GITHUB_TOKEN."; \
echo "You can create a personal access token on GitHub at:"; \
echo " https://github.com/settings/tokens"; \
echo ""; \
echo "For more info see: https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic"; \
echo ""; \
echo "To expose these variables, you can use:"; \
echo "export GITHUB_USER=<github-username>"; \
echo "export GITHUB_TOKEN=<github-token>"; \
echo ""; \
echo "It is recommended that you set these tokens in your .bashrc or .zshrc file for future use."; \
echo "===============================================================================";
exit 1;
_check_conda:
@echo "... checking if conda/mamba is installed"
@command -v $(CONDA_BINARY) >/dev/null 2>&1 || { \
echo "Error: Conda/mamba is not installed or not found in PATH" >&2; \
exit 1; \
}
@echo "... found conda executable: $(CONDA_BINARY)"
_check_tokens:
@echo "... checking if GITHUB_USER and GITHUB_TOKEN are set"
@if [ -z "$(GITHUB_USER)" ] || [ -z "$(GITHUB_TOKEN)" ]; then \
$(MAKE) _github_token_error; \
fi
@echo "... found GITHUB_USER ($(GITHUB_USER)) and GITHUB_TOKEN."
## Create a new conda environment and install modelhub
env: env:
$(CONDA_BINARY) env create -n modelhub --file environment.yaml @echo "Creating modelhub conda environment: modelhub"
conda init
conda activate modelhub @$(MAKE) --no-print-directory _check_tokens
pip install -e ".[dev]" @$(MAKE) --no-print-directory _check_conda
@$(CONDA_BINARY) env create -n modelhub --file environment.yaml
@conda init
@conda activate modelhub
@pip install -e ".[dev]"
@python -m biotite.setup_ccd
## Install modelhub locally into the current environment ## Install modelhub locally into the current environment
install: install:
# Install the conda requirements in the current activated environment # Install the conda requirements in the current activated environment
$(CONDA_BINARY) env update --file environment.yaml $(CONDA_BINARY) env update --file environment.yaml
# Install the pip requirements in the current activated environment # Install the pip requirements in the current activated environment
pip install -e ".[dev]" @pip install -e ".[dev]"
@python -m biotite.setup_ccd
## Build the apptainer image ## Build the apptainer image
apptainer: base_apptainer:
$(eval DATE := $(shell date +%Y-%m-%d)) $(eval DATE := $(shell date +%Y-%m-%d))
$(eval IMAGE := modelhub_$(DATE).sif) bash ./scripts/build_base_apptainer.sh
bash ./scripts/build_apptainer.sh
# Set INSTALL_PROJECT to true to install modelhub within the apptainer (much slower)
# e.g., `make INSTALL_PROJECT=true freeze_apptainer` or `make freeze_apptainer INSTALL_PROJECT=true`
INSTALL_PROJECT ?= false
freeze_apptainer:
$(eval DATE := $(shell date +%Y-%m-%d))
bash ./scripts/freeze_apptainer.sh $(INSTALL_PROJECT)
## Run pytest and generate coverage report ## Run pytest and generate coverage report
test: test:

219
README.md
View File

@@ -1,51 +1,210 @@
RoseTTAFold All-Atom # Modelhub
--------------------
This repository contains the code to training and running inference on - [Modelhub](#modelhub)
RoseTTAFold All-Atom (RFAA), a neural network that can predict the structures - [Background](#background)
of proteins in complex with DNA, RNA, and/or small molecule ligands. - [Division of code between Modelhub, Datahub, and Cifutils](#division-of-code-between-modelhub-datahub-and-cifutils)
- [Cifutils](#cifutils)
- [Datahub](#datahub)
- [Training, Validation, and Inference](#training-validation-and-inference)
- [Training and Validation](#training-and-validation)
- [Inference](#inference)
- [Setup](#setup)
- [Apptainers](#apptainers)
- [Base Apptainer](#base-apptainer)
- [Frozen Apptainer](#frozen-apptainer)
- [Shebang](#shebang)
- [General Use](#general-use)
- [Debugging](#debugging)
`rf2aa/` contains the model and training code. ## Background
`data/` contains code used to curate the training data from the PDB.
This repository constitutes the base for deep-learning method development at the Institute for Protein Design.
## Contributing to RFAA It is symbiotic with two other Institute for Protein Design repositories:
- [cifutils](https://github.com/baker-laboratory/cifutils), which manages input parsing and data cleaning
- [datahub](https://github.com/baker-laboratory/datahub), which manages input featurization and holds our composable `Transform` components
### Set Up Within this ontology, `modelhub` contains the *architectures*, *training* code, and *inference* endpoints.
## Division of code between Modelhub, Datahub, and Cifutils
Across our codebases, we balance the need to develop quickly with the need to write code that we can continue to maintain and that is easy to understand. We below lay out some thoughts on what code should live where.
We enforce a strict dependency flow of `modelhub` -> (depends on) `datahub` -> (depends on) `cifutils`; it would be a circular anti-pattern to thus import any `datahub` or `modelhub` functions from within `cifutils`.
### Cifutils
[cifutils](https://github.com/baker-laboratory/cifutils) is the most static of our three codebases. Basic parsing functionality, RDKit and other molecular toolkit utilities, and `AtomArray` quality-of-life tools live in this repository.
Examples of `cifutils` functions are:
- All functions related to **parsing structural files from source**; e.g., keeping/removing hydrogens, resolving occupancy, etc.
- Utility functions to manipulate `AtomArrays`, the core API of the `biotite` library, upon which we heavily rely
- Utility functions for common bioinformatics software, such as `RDKit`, that interface with `AtomArrays`
As a foundational library for the Institute for Protein Design, `cifutils` functions most like an open-source codebase. We must keep the code easy-to-understand and easy-to-maintain, both now and into the future. As such, `cifutils`:
- Maintains the **highest code quality standard**, requiring well-documented, easy-to-maintain code with adequate test coverage (we aim for **>85%** coverage)
- **Strictly versions** to minimize breaking changes with downstream repositories
You should write code in `cifutils` if:
- You are are writing **core** `AtomArray`-level level functionality that will be broadly useful, not only to those at the Institute for Protein Design but possibly the wider bioinformatics community (i.e., without dependencies, or even knowledge of, `datahub` or `modelhub`)
- You are willing to spend some additional time to ensure the code is **scalable, well-tested, and maintainable**
Quick-and-dirty experiments that require modifying `cifutils` can be performed by submoduling or cloning the repository and exporting a local path.
### Datahub
[datahub](https://github.com/baker-laboratory/datahub) manages data loading, preprocessing, and featurization pipelines for structure-dependent deep-learning models. We offer three core components: a `Transforms` library, a set of `Preprocessing` scripts, and `Datasets`.
- **Transforms**: A series of composable classes that take as input a dictionary containing sequence- and structure-based data (in the form of an `AtomArray`) and perform arbitrary operations, analogous to TorchVision's [approach](https://pytorch.org/vision/main/transforms.html) for computer vision
- **Preprocessing**: Scripts and functions for common data cleaning and preparation tasks, including specialized pipelines for frequent use cases (e.g., antibodies, clash detection, cleaning PDB data, etc.). Many of these *scripts* output `parquet` files stored to disk that are sampled from at train-time, while the *functions* are called by the scripts to clean, label, or filter the data (e.g., `has_clash()`, etc.)
- **Datasets**: The base `Datasets` and `Sampler` classes used for training, imported by `modelhub`
`datahub` is less static than `cifutils`; however, it still must operate as a stand-alone library that others can continue to build around and upon, even without `modelhub`. We strive to maintain `datahub` like an open-source software project such that others in the lab can easily understand, and build upon, our base components. We focus on **maintainable** and **flexible** code - if a particular `Transform` is bespoke or non-generalizable (at least initially), then the `/projects` folder within `Modelhub` may be a more appropriate place for initial development.
You should write code in `datahub` if:
- You are writing flexible, generic *pre-processing scripts* or *functions* that others in the lab have expressed interest in using (vs. a single-purpose pipeline or feature to test a hypothesis)
- **Example that should live in `datahub`**: You are writing a pre-processing pipeline to label all beta barrels in the PDB. Your scripts, written in a functional manner, may be a good candidate for `datahub/scripts/preprocessing`, so long as you are willing to write them generally and include tests. Similarly, if a single function may be generalizable but the pipeline is bespoke, that single function (with a test) could still be included as a stand-alone element in `datahub`, e.g.,
```python
atom_array_has_beta_barrel(atom_array: AtomArray) -> bool
```
- **Example that should live in `modelhub/projects`**: You have pulled together a script that loads PDB files, includes manual annotations, and saves out to CIF. Such a script may be appropriate for the specific use case but is unlikely to generalize across other use cases.
- You are writing `Transforms` that generalize to additional use cases beyond the current project
- **Example that should live in `datahub`**: Any `Transform` that adds a useful annotation to an `AtomArray` (e.g., annotationg pocket residues, hydrogen bonds, SASA, etc.)
- **Example that should live in `datahub`**: A `Transform` that pads DNA with generated B-form structure, as is done in AF-3; such a `Transform` may be applicable to both structure prediction and design, when proven effective
- **Example that should live in `modelhub/projects`**: A `Transform` that aggregates and/or concatenates features for a bespoke model pipeline
- You are willing to spend some additional time to ensure the code is scalable, well-tested, and maintainable. Otherwise the `projects` folder of `modelhub` may be a more appropriate place in the interim
## Training, Validation, and Inference
> If you are developing at the IPD, our `shebang` executables will take care of identifying and executing with the most up-do-date apptainer. If you are not at the IPD, you will need to ensure you have the appropriate apptainer. See below for details.
NOTE: For Training, Validation, and Inference, we make heavy use of [Hydra](https://hydra.cc/) for configuration management.
Before running any of the below commands, you will need to ensure `datahub` and `cifutils` are in your `PYTHONPATH`. E.g.,
``` ```
git clone https://git.ipd.uw.edu/jue/RF2-allatom.git export PYTHONPATH="/home/<USER>/projects/datahub/src:/home/<USER>/projects/cifutils/src"
cd RF2-allatom
``` ```
If you are on digs, the S3nv.sif apptainer has all the relevant packages. To get started coding: ### Training and Validation
For Training and Validation, when you execute `train.py` or `validate.py`, you will need to provide an *experiment* Hydra config. Experiments are a Hydra best-practice pattern to enable us to maintain multiple configurations; see more in the [Hydra documentaion](https://hydra.cc/docs/patterns/configuring_experiments/)
and in the `configs/experiment` sub-directory.
For example, to test AF-3 training without confidence, run:
``` ```
export PYTHONPATH="../RF2-allatom" ./src/modelhub/train.py experiment=quick-af3 debug=default
``` ```
First, run the test suite: **Explanation:**
- `./src/modelhub/train.py` — we execute our `train.py` like a bash executable, which triggers the `shebang` code to find the correct apptainer. It's equivalent to `apptainer exec --nv /path/to/apptainer python ./src/modelhub/train.py`
- `experiment=quick-af3` — we identify the experiment we want to use for training; in this case, `quick-af3`, which can be viewed at `configs/experiment/quick-af3.yaml`. This experiment is a simple test config for AF-3 that loads and runs more rapidly that the full training config
- `debug=default` - a setting letter Hydra know we are debugging; when we debug, we perform some automatic time-savings like setting a small diffusion batch size and crop size. You could remove this line if you don't want those options. You can explore more about various `debug` options in `config/debug`
For validation only, run the following:
``` ```
apptainer exec --nv /software/containers/versions/SE3nv/SE3nv-20240415.sif pytest tests/ ./src/modelhub/validate.py experiment=quick-af3 debug=default
``` ```
If all the tests pass, you have a stable version of the code.
### Running model training Note that since we use `hydra`, you could specify additional setup arguments using the command line. For example, by default, we `prevalidate` - running validation at the beginning of training so we develop a baseline and catch any errors (especially out-of-memory errors) before training for a full epoch. If you don't want that behavior, you could override in-line:
We use a package called hydra to configure different training runs of the model. Config files for different training runs can be found in `rf2aa/config/train`. The base trainable version is `rf2aa/config/train/rf2aa.yaml`, to run training with this version, run:
``` ```
/software/containers/versions/SE3nv/SE3nv-20240415.sif trainer_new.py --config-name rf2aa ./src/modelhub/train.py experiment=quick-af3 debug=default trainer.prevalidate=false
``` ```
These tests are most often run on a4000s on digs. If you have a separate installation of cifutils in your home directory, this can potentially break the tests.
If you make changes in the code, they should NOT break backwards compatibility, e.g. there should be a flag in the yaml files that would make it as if your changes were never committed. You can view the flattened Hydra configuration to determine how to best override or add additional arguments by:
- Running training or validation and viewing the pretty-printed file, which looks like:
![alt text](assets/example_config.png)
- Adding `--cfg job` to your launch command, which prints the config for the application and then exits
### Inference
To support multiple models and multiple projects, we build an `InferenceEngine` for each use case. For end-users the details of the `InferenceEngine` are not necessary; the appropriate engine can be specified with with `inference_engine` argument.
For example, to run the latest AF-3 model with confidence, we can execute (if `cifutils` and `datahub` are in the `PYTHONPATH`):
```
./src/modelhub/inference.py inference_engine=af3 inputs='./tests/data/example_with_ncaa.json'
```
We can then modify the command by adding/removing arguments with Hydra to our liking; for example, to dump diffusion trajectories and only include one model per CIF file:
```
./src/modelhub/inference.py inference_engine=af3 inputs='./tests/data/example_with_ncaa.json' dump_trajectories=true one_model_per_file=true
```
More details can be found in the [inference README](src/modelhub/inference_engines/README.md)
## Setup
> If you are developing at the IPD, then our `shebang` executables will handle the Apptainer dependencies; no need to run the commands below. See the `shebang` section below.
### Apptainers
To accelerate development and better contain dependencies, we offer two apptainers:
- `base_apptainer`: Contains all of the development dependencies, pre-compiled DeepSpeed, but *NOT* `cifutils` or `datahub`. The rationale for the base apptainer is that you expose these libraries via your PYTHONPATH/PATH to allow you to develop & pull updates for these libraries without having to re-build any apptainer.
- `freeze_apptainer`: Takes the `base_apptainer` as its image, and adds versioned `cifutils`, `datahub`, and (optionally) pip-installs `modelhub` as well (useful for releasing self-contained inference code). The rationale for these apptainers is to provide designers with a stable environment to tackle design problems in.
#### Base Apptainer
To make the base apptainer, run:
```
make base_apptainer
```
from the project root. This container will **not** contain `cifutils` or `datahub`; those paths must be exported explicitly during development (e.g., the paths to their respective submodules or clones elsewhere).
Building this apptainer pre-compiles DeepSpeed, among other actions, and is slow. You **should not** need to re-build this apptainer often; changes to `datahub` and `cifutils` can be addressed much more efficiently through the `freeze_apptainer` command specified below.
> NOTE: Since we pre-compile CUDA-specific DeepSpeed, you must run `make base_apptainer` on a GPU node
> NOTE: You will need to adjust the IPD-speciifc paths to frozen copies of the PDB and the CCD
#### Frozen Apptainer
To make a container that contains `cifutils` and `datahub`, but not `modelhub`, run:
```
make freeze_apptainer
```
This will use the `base_apptainer` pointed to by the `shebang` symlink as a base. Note that by default the versions of `cifutils` and `datahub` are fixed; update the `freeze_apptainer.spec` file to adjust the version numbers and/or add dependencies.
To make a container that contains `modelhub`, `datahub`, and `cifutils` (e.g., for production usage across the lab), run
```
make INSTALL_PROJECT=true freeze_apptainer
```
> NOTE: Since we build from the `base_apptainer` image, which contains pre-compiled DeepSpeed, `make freeze_apptainer` does NOT need to be run from a GPU
### Shebang
#### General Use
We use `shebang` to help manage and version apptainers. Namely:
- The shebang lines (`#!/bin/bash` ...) at the top of entry point scripts like `train.py` redirect the system to to `scripts/shebang/modelhub_exec.sh`
- The script `modelhub_exec.sh` in turn identifies the correct Apptainer and executes your command
- Apptainers are symlinks in `scripts/shebang` to elsewhere on the DIGS (where they are versioned); thus, when we update apptainers, we must also update the symlink. This allows us to track which apptainers to use for a given branch of the code at any given time (provided you update the symlinks for your branch when you switch out which apptainer you run with!)
For example, to launch a dummy training run, one could type (after adding `cifutils` and `datahub` to your `PYTHONPATH`):
```
cd src/modelhub
./train.py experiment=none-00-dummy
```
> You may need to adjust the permissions on `train.py` (e.g., `chmod +x train.py`) in order to execute the file like a script.
#### Debugging
We also support VSCode-native debugging with Apptainers. To debug:
1. Update your `launch.json` to include `Python: Attach`; for example, add the configuration:
```
{
"name": "Python: Attach",
"type": "debugpy",
"request": "attach",
"connect": {
"host": "localhost",
"port": 2345
}
}
```
2. Add any interactive debug breakpoints in VSCode
3. Set the `DEBUG_PORT` to `2345`, and then execute your script with `shebang` like normal. That is:
```
export DEBUG_PORT=2345
./train.py experiment=none-00-dummy
```
4. When prompted in the termal, launch the VSCode debug session (shortcut: `F5`)
Happy debugging!
### Contributing to model code
Generally, we follow software engineering practices of:
1. Not duplicating functionality that is already in the code
2. Keeping functions as short as possible, and splitting complicated functions into multiple functions
3. Using object oriented programming, which means subclassing already existing classes when possible.
4. Writing tests for our code and sending small functional PRs for review.
5. Maintaining code stability and not breaking backwards compatibility for users using the package.
To write new blocks in RF, you can go to the rf2aa/model directory and add the new block into the simulator_blocks.py file (and be sure to add a relevant name in the blocks_factory dictionary). These names can be referenced in hydra configs: see rf2aa.yaml for an example with any keyword arguments necessary to initialize the block.

View File

@@ -3,7 +3,7 @@ From: ubuntu:24.04
IncludeCmd: yes IncludeCmd: yes
# NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd` # NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd`
# To build this apptainer, use: # To build this apptainer, use:
# apptainer build --bind $PWD:/modelhub_host modelhub_apptainer.sif apptainer.spec # make apptainer
%setup %setup
# Create a directory in the container to bind the host's current working directory # Create a directory in the container to bind the host's current working directory
@@ -15,8 +15,9 @@ IncludeCmd: yes
# ... for mounting `/squash` with --bind # ... for mounting `/squash` with --bind
mkdir ${APPTAINER_ROOTFS}/squash mkdir ${APPTAINER_ROOTFS}/squash
%files %files
/etc/localtime
/etc/hosts
environment.yaml /opt/environment.yaml environment.yaml /opt/environment.yaml
%post %post
@@ -47,7 +48,10 @@ IncludeCmd: yes
apt-get clean apt-get clean
# Clone CUTLASS (for DeepSpeed) # Clone CUTLASS (for DeepSpeed)
git clone https://github.com/NVIDIA/cutlass /opt/cutlass git clone https://github.com/NVIDIA/cutlass.git /opt/cutlass
# Clone DeepSpeed (so we can pre-install the wheel)
git clone --branch v0.16.2 https://github.com/deepspeedai/DeepSpeed.git /opt/deepspeed
## ENVIRONMENT CREATION & DEPENDENCY INSTALLATION ## ENVIRONMENT CREATION & DEPENDENCY INSTALLATION
# Download miniconda # Download miniconda
@@ -66,15 +70,27 @@ IncludeCmd: yes
# Add conda environment to PATH # Add conda environment to PATH
export PATH=/usr/envs/modelhub-apptainer/bin:$PATH export PATH=/usr/envs/modelhub-apptainer/bin:$PATH
echo "Proceeding with DeepSpeed reinstallation."
## PRE-COMPILE DEEPSPEED FROM WHEEL
# (Overwrite deepspeed installation from the `environment.yaml`)
pip uninstall deepspeed -y # Avoid interactive prompts
# (Flags for building the Evoformer attention)
export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9"
export DS_BUILD_EVOFORMER_ATTN=1
export CUTLASS_PATH=/opt/cutlass/
# Reinstall DeepSpeed, pre-compiling the evoformer attentino kernel
pip wheel /opt/deepspeed -w /opt/deepspeed
pip install /opt/deepspeed/deepspeed-0.16.2+b344c04d-cp311-cp311-linux_x86_64.whl
# Run the biotite setup command # Run the biotite setup command
# (Temporary measure until we switch to released Biotite version) # (Temporary measure until we switch to released Biotite version)
. /usr/etc/profile.d/conda.sh . /usr/etc/profile.d/conda.sh
conda activate modelhub-apptainer conda activate modelhub-apptainer
python -m biotite.setup_ccd python -m biotite.setup_ccd
# deepspeed
pip install deepspeed==0.15.1
# clean up files to reduce size # clean up files to reduce size
# ... remove conda # ... remove conda
mamba clean -a -y mamba clean -a -y
@@ -93,7 +109,7 @@ IncludeCmd: yes
%runscript %runscript
# NOTE: The %runscript is invoked when the container is run without specifying a different command. # NOTE: The %runscript is invoked when the container is run without specifying a different command.
exec "$@" exec python "$@"
%help %help
modelhub environment for running modelhub independently and for development modelhub environment for running modelhub independently and for development

View File

@@ -1,4 +0,0 @@
APP=/software/containers/versions/rf_diffusion_aa/24-05-21/rf_diffusion_aa.sif
PYTHONPATH=.. $APP -mpytest --benchmark-skip --ignore tests/test_semantics.py --durations=10 tests

View File

@@ -0,0 +1,5 @@
defaults:
- train_logging
- metrics_logging
- dump_validation_structures
- _self_

View File

@@ -0,0 +1,6 @@
dump_validation_structures_callback:
_target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback
save_dir: ${paths.output_dir}/val_structures
dump_predictions: False
one_model_per_file: False
dump_trajectories: False

View File

@@ -0,0 +1,14 @@
store_validation_metrics_in_df_callback:
_target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
save_dir: ${paths.output_dir}/val_metrics
metrics_to_save: "all"
log_af3_validation_metrics_callback:
_target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
metrics_to_log:
# Only logs if present in the metric output dictionary
# Must be subset of metrics_to_save
- by_type_lddt
- all_atom_lddt
- distogram_loss
- distogram_comparisons

View File

@@ -0,0 +1,16 @@
log_af3_training_losses_callback:
_target_: modelhub.callbacks.train_logging.LogAF3TrainingLossesCallback
log_every_n: 10
log_full_batch_losses: true
log_learning_rate_callback:
_target_: modelhub.callbacks.train_logging.LogLearningRateCallback
log_every_n: 10
log_model_parameters_callback:
_target_: modelhub.callbacks.train_logging.LogModelParametersCallback
log_dataset_sampling_ratios_callback:
_target_: modelhub.callbacks.train_logging.LogDatasetSamplingRatiosCallback

View File

@@ -0,0 +1,15 @@
train:
dataloader_params:
# These parameters will be unpacked as kwargs for the DataLoader
batch_size: 1
num_workers: 2
prefetch_factor: 3
n_fallback_retries: 4
val:
dataloader_params:
# These parameters will be unpacked as kwargs for the DataLoader
batch_size: 1
num_workers: 2
prefetch_factor: 3
n_fallback_retries: 0 # Disable fallback retries for validation

21
configs/datasets/af3.yaml Normal file
View File

@@ -0,0 +1,21 @@
# AF3 dataset configuration with monomer distillation
defaults:
- base
# The @ symbol specifies the tree under which the item will be attached to the config
- train/pdb/af3_train_interface@train.pdb.sub_datasets.interface
- train/pdb/af3_train_pn_unit@train.pdb.sub_datasets.pn_unit
- train:
- monomer_distillation
- val/af3_validation@val.af3_validation
- _self_
# Dataloading pipeline to use
pipeline_target: datahub.pipelines.af3.build_af3_transform_pipeline
# Dataset weighting
train:
pdb:
probability: 0.5
monomer_distillation:
probability: 0.5

View File

@@ -0,0 +1,12 @@
# Base Transform defaults
diffusion_batch_size_train: 48
diffusion_batch_size_inference: 5
n_recycles_train: 4
n_recycles_validation: 10
n_msa: 1024
crop_size: 384
max_atoms_in_crop: 5000
key_to_balance: n_tokens_total

View File

@@ -0,0 +1,38 @@
monomer_distillation:
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
# cif parser arguments
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
# metadata parser
dataset_parser:
_target_: datahub.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
# metadata dataset
dataset:
_target_: datahub.datasets.datasets.PandasDataset
name: af2fb_distillation
id_column: example_id
data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet
columns_to_load:
- example_id
- path
return_key: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False
protein_msa_dirs: [{"dir": "${paths.data.monomer_distillation_data_dir}/msa", "extension": ".a3m", "directory_depth": 2}]
rna_msa_dirs: []
n_recycles: ${datasets.n_recycles_train}
crop_size: ${datasets.crop_size}
n_msa: ${datasets.n_msa}
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
crop_contiguous_probability: 0.25
crop_spatial_probability: 0.75

View File

@@ -0,0 +1,45 @@
defaults:
- base
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.InterfacesDFParser
dataset:
name: interface
data: ${paths.data.pdb_data_dir}/interfaces_df_train.parquet
filters:
# filters common across all PDB datasets
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
# interface specific columns
- pn_unit_1_iid
- pn_unit_2_iid
- pn_unit_1_non_polymer_res_names
- pn_unit_2_non_polymer_res_names
- is_inter_molecule
- all_pn_unit_iids_after_processing
- involves_loi
transform:
# interface-specific Transform pipeline parameters
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0

View File

@@ -0,0 +1,41 @@
defaults:
- base
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.PNUnitsDFParser
dataset:
name: pn_unit
data: ${paths.data.pdb_data_dir}/pn_units_df_train.parquet
filters:
# filters common across all PDB datasets
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# pn_unit specific filters
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
- total_num_atoms_in_unprocessed_assembly
# pn_unit specific columns
- q_pn_unit_iid
- q_pn_unit_non_polymer_res_names
- all_pn_unit_iids_after_processing
- q_pn_unit_is_loi
transform:
# pn_unit-specific Transform pipeline parameters
crop_contiguous_probability: 0.3333333333333333
crop_spatial_probability: 0.6666666666666667

View File

@@ -0,0 +1,33 @@
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
cif_parser_args:
cache_dir: null
load_from_cache: false
save_to_cache: false
dataset:
_target_: datahub.datasets.datasets.PandasDataset
# we will use the example_id as the unique column
id_column: example_id
# return all keys (do not subset)
return_key: null
transform:
# common Transform pipeline components for all PDB datasets
_target_: ${datasets.pipeline_target}
is_inference: False
protein_msa_dirs: ${paths.data.protein_msa_dirs}
rna_msa_dirs: ${paths.data.rna_msa_dirs}
n_recycles: ${datasets.n_recycles_train}
crop_size: ${datasets.crop_size}
n_msa: ${datasets.n_msa}
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
weights:
_target_: datahub.samplers.calculate_weights_for_pdb_dataset_df
beta: 0.5
alphas:
a_prot: 3.0 # 3 for AF-3
a_nuc: 0.0 # 3 for AF-3
a_ligand: 1.0 # 1 for AF-3
a_loi: 5.0 # 5 for AF-3

View File

@@ -0,0 +1,12 @@
defaults:
- base
dataset:
dataset_parser:
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
dataset:
_target_: datahub.datasets.datasets.PandasDataset
data: ${paths.data.pdb_data_dir}/entry_level_val_df.parquet
filters:
# NOTE: We exclude these examples from validation because they produce an error upon RDKit small molecule processing that causes a data loading fallback
- example_id not in ["{['validation']}{7erc}{1}{[]}", "{['validation']}{7qbs}{1}{[]}", "{['validation']}{7z0n}{1}{[]}"]

View File

@@ -0,0 +1,26 @@
dataset:
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
dataset:
_target_: datahub.datasets.datasets.PandasDataset
# we will use the example_id as the unique column
id_column: example_id
# return all keys (do not subset)
return_key: null
transform:
# common Transform pipeline components for all PDB datasets
_target_: ${datasets.pipeline_target}
is_inference: True
protein_msa_dirs: ${paths.data.protein_msa_dirs}
rna_msa_dirs: ${paths.data.rna_msa_dirs}
n_recycles: ${datasets.n_recycles_validation}
crop_size: null # do not crop for inference
n_msa: ${datasets.n_msa}
diffusion_batch_size: ${datasets.diffusion_batch_size_inference}
max_atoms_in_crop: null # do not crop for inference
return_atom_array: True # return atom array for inference
key_to_balance: ${datasets.key_to_balance}

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- override /logger: null
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use the below to also set hydra loggers to 'DEBUG'
verbose: True
# Print example ID before forward pass
callbacks:
print_example_id_before_forward_pass:
_target_: modelhub.callbacks.train_logging.PrintExampleIDBeforeForwardPassCallback
dataloader:
train:
dataloader_params:
batch_size: 1
num_workers: 0 # debuggers don't like multiprocessing -- work on main thread
pin_memory: False # disable gpu memory pin
prefetch_factor: null # must be null for num_workers=0
n_fallback_retries: 0 # disable fallback retries for debugging
val:
dataloader_params:
batch_size: 1
num_workers: 0
pin_memory: False
prefetch_factor: null # must be null for num_workers=0
datasets:
crop_size: 100 # set small crop size for quick debugging
diffusion_batch_size_train: 1
diffusion_batch_size_inference: 1
n_recycles_train: 1
n_recycles_validation: 1
n_msa: 128
key_to_balance: null # otherwise big examples will be processed first
trainer:
devices_per_node: 1
limit_train_batches: 1
limit_val_batches: 1
validate_every_n_epochs: 1
# Set tags to help identify debugging runs
tags:
- debug

View File

@@ -0,0 +1,21 @@
# @package _global_
# See: https://hydra.cc/docs/patterns/configuring_experiments/
# to execute this experiment run:
# python train.py +debug=train_single_example [any other arguments]
defaults:
- default
- gpu
datasets:
# you can add specific example IDs here to load a subset of the dataset (training)
subset_to_example_ids:
- "{['pdb', 'pn_units']}{3px1}{1}{['A_3']}"
val: null
tags:
- debug
- train
- specific-examples

View File

@@ -0,0 +1,83 @@
# @package _global_
name: af3-elements-as-ligand-atom-names
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:
- override /trainer: af3
- override /datasets: af3
- override /model: af3
tags:
# list of tags to add to the run ( & on wandb to easily find & filter runs)
- atom-names
- experiment
project: af3
ckpt_path: /projects/ml/modelhub/inference/weights_with_no_confidence_2025_1_21_new_modelhub.ckpt
model:
lr_scheduler:
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
datasets:
train:
pdb:
probability: 1.0
sub_datasets:
interface:
dataset:
transform:
use_element_for_atom_names_of_atomized_tokens: True
dataset:
# Same as AF-3 training, but limit to protein-ligand interfaces
filters:
# (from before)
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
- >
~(pn_unit_1_non_polymer_res_names.notnull() and
pn_unit_1_non_polymer_res_names.str.contains(
'${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}',
regex=True))
- >
~(pn_unit_2_non_polymer_res_names.notnull() and
pn_unit_2_non_polymer_res_names.str.contains(
'${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}',
regex=True))
- "is_inter_molecule"
# only protein-ligand interfaces
- "(n_prot == 1 and n_nuc == 0 and n_ligand == 1)"
pn_unit:
dataset:
transform:
use_element_for_atom_names_of_atomized_tokens: True
dataset:
# Same as AF-3 training, but limit to protein-ligand interfaces
filters:
# (from before)
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
# only proteins or ligands
- "(n_prot == 1 or n_ligand == 1)"
# Datasets set to null are ignored
monomer_distillation: null
val:
af3_validation:
dataset:
transform:
use_element_for_atom_names_of_atomized_tokens: True
dataset:
filters:
- "n_tokens_total < 400"
- "interfaces_to_score.str.contains('protein-ligand')"
# Exclude example where RDKit errors
- example_id not in ["{['validation']}{7qbs}{1}{[]}"]

View File

@@ -0,0 +1,30 @@
# @package _global_
name: af3
defaults:
- override /datasets: af3
- override /model: af3
- override /trainer: af3
tags:
- af3
- fine-tune
project: af3
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt
model:
lr_scheduler:
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
# Protein-ligand only for speed
val:
af3_validation:
dataset:
dataset:
filters:
# Only score examples with protein-ligand interfaces
- "interfaces_to_score.str.contains('protein-ligand')"
- example_id not in ["{['validation']}{7qbs}{1}{[]}"]

View File

@@ -0,0 +1,37 @@
# @package _global_
name: af3-new-msas-pdb-only
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:
- override /trainer: af3
- override /datasets: af3
- override /model: af3
tags:
# list of tags to add to the run ( & on wandb to easily find & filter runs)
- msas
- experiment
project: af3
paths:
data:
protein_msa_dirs:
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
datasets:
train:
pdb:
# We must adjust the probability, since we set the monomer distillation dataset to null
probability: 1.0
# Datasets set to null are ignored
monomer_distillation: null
val:
af3_validation:
dataset:
dataset:
filters:
- "n_tokens_total < 400"

View File

@@ -0,0 +1,13 @@
# @package _global_
name: af3-old-msas-pdb-only
defaults:
- af3-new-msas-pdb-only
paths:
data:
protein_msa_dirs:
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}

View File

@@ -0,0 +1,15 @@
# @package _global_
# NOTE: Dummy experiment that you can use to just run the code
# . For actual experiments, please create a new experiment config from copying the template 'user-XX-template.yaml'
name: none-00-dummy
tags:
# list of tags to add to the run ( & on wandb to easily find & filter runs)
- experiment
- dummy
project: test

View File

@@ -0,0 +1,15 @@
# @package _global_
name: af3
defaults:
- override /datasets: af3
- override /model: af3
- override /trainer: af3
tags:
- af3
project: af3
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt

View File

@@ -0,0 +1,15 @@
# @package _global_
name: af3-with-confidence
defaults:
- override /datasets: af3
- override /model: af3_with_confidence
- override /trainer: af3_with_confidence
tags:
- af3
project: af3
ckpt_path: /projects/ml/modelhub/inference/weights_with_confidence_2025_2_27_new_modelhub.ckpt

View File

@@ -0,0 +1,14 @@
# @package _global_
# Experiment that loads a small dataset for quick testing
name: quick-af3-with-confidence
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:
- quick-af3
- override /model: af3_with_confidence
- override /trainer: af3_with_confidence
- _self_
ckpt_path: /projects/ml/modelhub/inference/weights_with_confidence_2025_2_27_new_modelhub.ckpt

View File

@@ -0,0 +1,50 @@
# @package _global_
# Experiment that loads a small dataset for quick testing
name: quick-af3
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
defaults:
- override /trainer: af3
- override /datasets: af3
- override /model: af3
tags:
# list of tags to add to the run ( & on wandb to easily find & filter runs)
- quick
project: test
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt
datasets:
train:
pdb:
# We must adjust the probability, since we set the monomer distillation dataset to null
probability: 1.0
sub_datasets:
interface:
dataset:
dataset:
# A small dataframe that loads quickly
data: /projects/ml/datahub/dfs/pdb/test_dfs/interfaces_df.parquet
filters:
- "num_polymer_pn_units <= 2"
- "cluster.notnull()"
pn_unit:
dataset:
dataset:
# A small dataframe that loads quickly
data: /projects/ml/datahub/dfs/pdb/test_dfs/pn_units_df.parquet
filters:
- "num_polymer_pn_units <= 2"
- "cluster.notnull()"
# Datasets set to null are ignored
monomer_distillation: null
val:
af3_validation:
dataset:
dataset:
filters:
- "n_tokens_total < 200"

View File

@@ -0,0 +1,18 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging (requires `colorlog` to be installed)
# defaults:
# - override hydra_logging: colorlog
# - override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/${name}/${now:%Y-%m-%d}_${now:%H-%M}
# ... this is where the log file is written (i.e. the programs output)
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/experiment.log

View File

@@ -0,0 +1,7 @@
defaults:
- override job_logging: disabled
- override hydra_logging: disabled
output_subdir: null
run:
dir: .

7
configs/inference.yaml Normal file
View File

@@ -0,0 +1,7 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
defaults:
- inference_engine: ???
- _self_

View File

@@ -0,0 +1,22 @@
# @package _global_
defaults:
- base
- _self_
_target_: modelhub.inference_engines.af3.AF3InferenceEngine
ckpt_path: /net/tukwila/ncorley/modelhub/inference/modelhub_latest.ckpt
n_recycles: 10
diffusion_batch_size: 5
residue_renaming_dict: null
num_steps: 50
solver: "af3"
print_config: true
seed: null
skip_existing: true
dump_predictions: true
dump_trajectories: false
one_model_per_file: false

View File

@@ -0,0 +1,10 @@
# @package _global_
defaults:
- /hydra: no_logging
ckpt_path: ???
inputs: ???
out_dir: ./
num_nodes: 1
devices_per_node: 1

6
configs/logger/csv.yaml Normal file
View File

@@ -0,0 +1,6 @@
# https://lightning.ai/docs/fabric/latest/api/generated/lightning.fabric.loggers.CSVLogger.html#lightning.fabric.loggers.CSVLogger
csv:
_target_: lightning.fabric.loggers.CSVLogger
root_dir: ${paths.output_dir}
flush_logs_every_n_steps: 1

View File

@@ -0,0 +1,3 @@
defaults:
- wandb
- csv

14
configs/logger/wandb.yaml Normal file
View File

@@ -0,0 +1,14 @@
# https://wandb.ai
wandb:
_target_: wandb.integration.lightning.fabric.WandbLogger
save_dir: ${paths.output_dir}
offline: False
id: null # pass correct id (along with checkpoint path, and resume='allow' or 'must') to resume a run
anonymous: null # enable anonymous logging
project: ${project}
prefix: "" # a string to put at the beginning of metric keys
log_model: False # do not upload model checkpoints
tags: ${tags}
# (Default resume to "never" to avoid accidentally resuming runs; we want to be explicit about resuming)
resume: never # never, allow, or must (see: https://docs.wandb.ai/guides/runs/resuming/)

7
configs/model/af3.yaml Normal file
View File

@@ -0,0 +1,7 @@
defaults:
- optimizers/adam@optimizer
- schedulers/af3@lr_scheduler
- components/ema@ema
- components/af3_net@net

View File

@@ -0,0 +1,5 @@
defaults:
- af3
- components/af3_net_with_confidence_head@net

View File

@@ -0,0 +1,177 @@
# Model architecture
_target_: modelhub.model.AF3.AF3
# +---------- Channel dimensions ----------+
c_s: 384
c_z: 128
c_atom: 128
c_atompair: 16
c_s_inputs: 449 # TODO: What is this?
# +---------- Feature embedding ----------+
feature_initializer:
# InputFeatureEmbedder
input_feature_embedder:
features:
- restype
- profile
- deletion_mean
atom_attention_encoder:
c_token: 384
c_atom_1d_features: 389
c_tokenpair: ${model.net.c_z}
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
atom_transformer:
n_queries: 32
n_keys: 128
l_max: 40_000 # does not matter
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: true
kq_norm: false
# RelativePositionEncoding
relative_position_encoding:
r_max: 32
s_max: 2
# +---------- Recycler ----------+
recycler:
# Pairformer
n_pairformer_blocks: 48
pairformer_block:
p_drop: 0.25
triangle_multiplication:
d_hidden: 128
triangle_attention:
n_head: 4
d_hidden: 32
attention_pair_bias:
n_head: 16
# TemplateEmbedder
template_embedder:
n_block: 2
raw_template_dim: 108
c: 64
p_drop: 0.25
# MSA module
msa_module:
n_block: 4
c_m: 64
p_drop_msa: 0.15
p_drop_pair: 0.25
msa_subsample_embedder:
num_sequences: 1024
dim_raw_msa: 34
c_s_inputs: ${model.net.c_s_inputs}
c_msa_embed: ${model.net.recycler.msa_module.c_m}
outer_product:
c_msa_embed: ${model.net.recycler.msa_module.c_m}
c_outer_product: 32
c_out: ${model.net.c_z}
msa_pair_weighted_averaging:
n_heads: 8
c_weighted_average: 32
c_msa_embed: ${model.net.recycler.msa_module.c_m}
c_z: ${model.net.c_z}
separate_gate_for_every_channel: true
msa_transition:
n: 4
c: ${model.net.recycler.msa_module.c_m}
triangle_multiplication_outgoing:
d_pair: ${model.net.c_z}
d_hidden: 128
bias: True
triangle_multiplication_incoming:
d_pair: ${model.net.c_z}
d_hidden: 128
bias: True
triangle_attention_starting:
d_pair: ${model.net.c_z}
n_head: 4
d_hidden: 32
p_drop: 0.0 # This does not do anything: TODO: Remove
triangle_attention_ending:
d_pair: ${model.net.c_z}
n_head: 4
d_hidden: 32
p_drop: 0.0 # This does not do anything; TODO: Remove
pair_transition:
n: 4
c: ${model.net.c_z}
# +---------- Diffusion module ----------+
diffusion_module:
sigma_data: 16
c_token: 768
f_pred: edm
diffusion_conditioning:
c_s_inputs: ${model.net.c_s_inputs}
c_t_embed: 256
relative_position_encoding:
r_max: 32
s_max: 2
atom_attention_encoder:
c_tokenpair: ${model.net.c_z}
c_atom_1d_features: 389
atom_1d_features:
- ref_pos
- ref_charge
- ref_mask
- ref_element
- ref_atom_name_chars
atom_transformer:
n_queries: 32
n_keys: 128
l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: true
kq_norm: false
broadcast_trunk_feats_on_1dim_old: false
use_chiral_features: true
diffusion_transformer:
n_block: 24
diffusion_transformer_block:
n_head: 16
no_residual_connection_between_attention_and_transition: true
kq_norm: true
atom_attention_decoder:
atom_transformer:
n_queries: 32
n_keys: 128
l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
diffusion_transformer:
n_block: 3
diffusion_transformer_block:
n_head: 4
no_residual_connection_between_attention_and_transition: true
kq_norm: false
distogram_head:
bins: 65
# +---------- Inference sampler ----------+
inference_sampler:
solver: "af3"
num_timesteps: 200
min_t: 0
max_t: 1
sigma_data: ${model.net.diffusion_module.sigma_data}
s_min: 4e-4
s_max: 160
p: 7
gamma_0: 0.8
gamma_min: 1.0
noise_scale: 1.003
step_scale: 1.5

View File

@@ -0,0 +1,45 @@
defaults:
- af3_net
# Model architecture
_target_: modelhub.model.AF3.AF3WithConfidence
# +---------- Mini rollout sampler ----------+
# From the AF-3 main text:
# > ...To remedy this, we developed a diffusion rollout procedure for the full-structure prediction generation during training (using a larger step size than normal)
# They do not further elaborate on how they adjusted the step size during diffusion rollout, but this may be a fruitful area of exploration moving forwards
mini_rollout_sampler:
solver: "af3"
num_timesteps: 20 # 20 timesteps for the mini-rollout (vs. 200 for the full rollout during inference)
min_t: 0
max_t: 1
sigma_data: ${model.net.diffusion_module.sigma_data}
s_min: 4e-4
s_max: 160
p: 7
gamma_0: 0.8
gamma_min: 1.0
noise_scale: 1.003
step_scale: 1.5
# +---------- Confidence head architecture ----------+
confidence_head:
c_s: ${model.net.c_s}
c_z: ${model.net.c_z}
n_pairformer_layers: 4
pairformer:
p_drop: 0.25
triangle_multiplication:
d_hidden: 128
triangle_attention:
n_head: 4
d_hidden: 32
attention_pair_bias:
n_head: 16
n_bins_pae: 64
n_bins_pde: 64
n_bins_plddt: 50
n_bins_exp_resolved: 2
use_Cb_distances: False
use_af3_style_binning_and_final_layer_norms: True
symmetrize_Cb_logits: True

View File

@@ -0,0 +1 @@
decay: 0.999 # From AF-3

View File

@@ -0,0 +1,5 @@
# Optimizer
_target_: torch.optim.Adam
lr: 0 # Will be set by the scheduler (starts at 0, increasing to `base_lr`)
betas: [0.9, 0.95]
eps: 1.0e-8

View File

@@ -0,0 +1,6 @@
# Learning rate scheduler
_target_: modelhub.training.schedulers.AF3Scheduler
base_lr: 1.8e-3
warmup_steps: 1000
decay_factor: 0.95
decay_steps: 50000

View File

@@ -0,0 +1,23 @@
# path to directory with training splits
pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2024_12_16/
# fb monomer distillation dataset
monomer_distillation_data_dir: /squash/af2_distillation_facebook/
monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
# path(s) to search for protein MSAs (for PDB datasets)
protein_msa_dirs:
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
- {"dir": "/net/scratch/mkazman/msa/validate_no_leak_taxid", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/net/scratch/mkazman/msa/missing_antibody_msas", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/net/scratch/mkazman/msa/post_training_cutoff_msas/processed_nested", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/net/scratch/mkazman/msa/post_training_cutoff_msas/extra_seqs_processed_nested", "extension": ".a3m.gz", "directory_depth": 2}
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}
# path(s) to search for RNA MSAs
rna_msa_dirs:
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_rna_msas", "extension": ".afa", "directory_depth": 0}
# path to save examples that fail during the Transform pipeline (null = do not save)
failed_examples_dir: null

View File

@@ -0,0 +1,21 @@
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
defaults:
- _self_
- data: default
# path to root directory (requires the `PROJECT_ROOT` environment variable to be set)
#  NOTE: This variable is auto-set upon loading via `rootutils`
root_dir: ${oc.env:PROJECT_ROOT}
# where to store data (checkpoints, logs, etc.) of all experiments in general
# (this influences the output_dir in the hydra/default.yaml config)
# change this to e.g. /scratch if you are running larger experiments with lots lof logs, checkpoints, etc.
log_dir: ${.root_dir}/logs/
# path to output directory for this specific run, created dynamically by hydra
# path generation pattern is specified in `configs/hydra/default.yaml`
# use it to store all files generated during the run, like ckpts and metrics
output_dir: ${hydra:runtime.output_dir}
# path to working directory (auto-generated by hydra)
work_dir: ${hydra:runtime.cwd}

42
configs/train.yaml Normal file
View File

@@ -0,0 +1,42 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
defaults:
- callbacks: default
- logger: csv
- trainer: ???
- paths: default
- datasets: ???
- dataloader: default
- hydra: default
- model: ???
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
- _self_
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: ???
# debug configs to add onto any experiment for quickly testing or debugging code
- debug: null
# DO NOT set these here. Set them in the relevant experiment config file.
# ... these are just here to ensure users always specify these fields in their experiment configs.
name: ???
tags: ???
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
# here.
# ... task name (determines the output directory path)
task_name: "train"
project: ??? # required for W&B logging
seed: 1
# Provide checkpoint path to resume training from a checkpoint
# NOTE: If using W&B, must also set the `id` and `resume` fields in the `logger/wandb` config
ckpt_path: null

20
configs/trainer/af3.yaml Normal file
View File

@@ -0,0 +1,20 @@
defaults:
- ddp
- loss: structure_prediction
- metrics: structure_prediction
_target_: modelhub.trainers.af3.AF3Trainer
validate_every_n_epochs: 1
max_epochs: 10_000
n_examples_per_epoch: 24000
prevalidate: True
# We must pre-specify the number of recycles during training so we can pre-sample recycles per batch consistently for each GPU
n_recycles_train: ${datasets.n_recycles_train}
clip_grad_max_norm: 10.0
output_dir: ${paths.output_dir}
checkpoint_every_n_epochs: 1
# precision: bf16-mixed # Mixed precision training with bfloat16 (currently does not work)

View File

@@ -0,0 +1,5 @@
defaults:
- af3
- override loss: structure_prediction_with_confidence
_target_: modelhub.trainers.af3.AF3TrainerWithConfidence

6
configs/trainer/cpu.yaml Normal file
View File

@@ -0,0 +1,6 @@
defaults:
- af3
accelerator: cpu
devices_per_node: 1
num_nodes: 1

5
configs/trainer/ddp.yaml Normal file
View File

@@ -0,0 +1,5 @@
strategy: ddp
accelerator: gpu
devices_per_node: 1
num_nodes: 1

View File

@@ -0,0 +1,29 @@
_target_: modelhub.loss.af3_confidence_loss.ConfidenceLoss
weight: 1.0
plddt:
weight: 1.0
n_bins: 50
max_value: 1.0
pae:
weight: 1.0
n_bins: 64
max_value: 32
pde:
weight: 1.0
n_bins: 64
max_value: 32
exp_resolved:
weight: 1.0
n_bins: 2
max_value: 1
# Adds to loss_dict true and predicted average plddt, pae, and pde per batch, also info about the spread and correlation of those values within a batch
log_statistics: True
rank_loss:
use_listnet_loss: False
weight: 0.0

View File

@@ -0,0 +1,9 @@
_target_: modelhub.loss.af3_losses.DiffusionLoss
weight: 4.0
sigma_data: ${model.net.diffusion_module.sigma_data}
alpha_dna: 5
alpha_rna: 5
alpha_ligand: 10
edm_lambda: True
se3_invariant_loss: True
clamp_diffusion_loss: False

View File

@@ -0,0 +1,2 @@
_target_: modelhub.loss.af3_losses.DistogramLoss
weight: 3e-2

View File

@@ -0,0 +1,4 @@
defaults:
# Note that the SmoothedLDDTLoss is included within the DiffusionLoss
- losses/diffusion_loss@diffusion_loss
- losses/distogram_loss@distogram_loss

View File

@@ -0,0 +1,2 @@
defaults:
- losses/confidence_loss@confidence_loss

View File

@@ -0,0 +1,8 @@
by_type_lddt:
_target_: modelhub.metrics.lddt.ByTypeLDDT
all_atom_lddt:
_target_: modelhub.metrics.lddt.AllAtomLDDT
distogram:
_target_: modelhub.metrics.distogram.DistogramLoss
distogram_comparisons:
_target_: modelhub.metrics.distogram.DistogramComparisons

49
configs/validate.yaml Normal file
View File

@@ -0,0 +1,49 @@
# @package _global_
# ^ The "package" determines where the content of the config is placed in the output config
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
defaults:
- callbacks: default
- logger: csv
- trainer: ???
- paths: default
- datasets: ???
- dataloader: default
- hydra: default
- model: ???
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
- _self_
# experiment configs allow for version control of specific hyperparameters
# e.g. best hyperparameters for given model and datamodule
- experiment: ???
# debug configs to add onto any experiment for quickly testing or debugging code
- debug: null
# DO NOT set these here. Set them in the relevant experiment config file.
# ... these are just here to ensure users always specify these fields in their experiment configs.
name: ???
tags: ???
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
# here.
# ... task name (determines the output directory path)
task_name: "validate"
project: ??? # required for W&B logging
seed: 1
# Dump CIF files for validation structures
callbacks:
dump_validation_structures_callback:
dump_predictions: True
one_model_per_file: False
dump_trajectories: False
# passing checkpoint path required for validation
# DO NOT set here; set in the experiment config file
ckpt_path: ???

View File

@@ -6,53 +6,80 @@ channels:
- conda-forge - conda-forge
- defaults - defaults
dependencies: dependencies:
# Core dependencies
- pip
- python=3.11 - python=3.11
- cuda - cuda
- pytorch=2.4 - pytorch=2.4
- pytorch-cuda=12.4 - pytorch-cuda=12.4
- pytorch-scatter>=2.1.0,<3 - pytorch-scatter>=2.1.0,<3
- lightning>=2.4.0,<2.5 - lightning>=2.4.0,<2.5
- pandas>=1.4.2,<2.3 # Small molecule libraries
- numpy>=1.25.0,<2.1
- scipy>=1.13.1,<2
- cytoolz>=0.12.3,<1
- biopython>=1.83,<2
- fire>=0.6.0,<1
- ruff>=0.6.2
- pytest-dotenv>=0.5.2,<1
- pytest-cov>=4.1.0,<5
- rdkit>=2024.3.5 - rdkit>=2024.3.5
- openbabel=3.1.1 - openbabel=3.1.1
- pip
- pip: - pip:
- biotite>=1.1.0,<1.2 # Project-related dependencies
- seaborn>=0.13.0,<1 # ... generic tools
- loguru>=0.7.0,<1 - GitPython>=3.0.0,<4 # GitPython is a Python library used to interact with Git repositories
- beartype>=0.18.0,<1 - cython>=3.0.0,<4 # Cython compiler for C extensions
- cytoolz>=0.12.3,<1 # Cython-optimized tools for itertools and functional programming
- assertpy>=1.1.0,<2 # Assertions library
- tqdm>=4.65.0,<5 # Fast, extensible progress bar for loops and more
- rootutils>=1.0.7,<1.1 # Setting up the project root paths
- dm-tree>=0.1.6,<1 # Tree data structure from DeepMind
- deepdiff>=8.0.0,<9 # Deep difference and search of any Python object
# ... configuration & CLI
- fire>=0.6.0,<1 # Better argument parsing than argparse
- hydra-core>=1.3.0,<1.4 # Config management framework
- environs>=11.0.0,<12
# ... linear algebra, maths & ml
- numpy>=1.25.0,<2
- scipy>=1.13.1,<2
- einops>=0.8.0,<1 - einops>=0.8.0,<1
- einx>=0.1.0,<1 - einx>=0.1.0,<1
- debugpy>=1.8.5,<2
- cython>=3.0.0,<4
- pytest>=8.2.0,<9
- assertpy>=1.1.0,<2
- pre-commit>=3.7.1
- tqdm>=4.65.0,<5
- py3Dmol>=2.2.1,<3
- pyarrow>=17.0.0
- fastparquet>=2024.5.0
- ipykernel>=6.29.4,<7
- jaxtyping>=0.2.17,<1
- hydra-core>=1.3.0,<1.4
- wandb>=0.15.10,<1
- environs>=11.0.0,<12
- rootutils>=1.0.7,<1.1
- opt_einsum>=3.4.0,<4 - opt_einsum>=3.4.0,<4
- rich>=13.9.4,<14 - deepspeed>=0.15.1 # will be uninstalled by the apptainer's `spec` file, if pre-compiling
- msgpack>=1.1.0,<2 # ... data tools
- pymol-remote>=0.1.0 - pandas>=2.2,<2.3 # Data manipulation and analysis
- deepspeed>=0.15.1 - pyarrow==17.0.0 # Columnar data format for efficient data storage and processing
- fastparquet==2024.5.0 # Fast Parquet file format implementation
- seaborn>=0.13.0,<1
# ... bioinformatics
- biopython>=1.83,<2 # Collection of Python modules for bioinformatics
- py3Dmol>=2.2.1,<3 # Python wrapper for 3Dmol.js
- pymol-remote>=0.0.5 # Remote access to PyMOL from Python (has no dependencies)
- git+https://github.com/biotite-dev/biotite.git@fab175e7ba4608d9613f092ad4e080661c6cc816 - git+https://github.com/biotite-dev/biotite.git@fab175e7ba4608d9613f092ad4e080661c6cc816
- GitPython>=3.0.0,<4 # Git library for Python - hydride==1.2.3 #biotite supported hydrogen addition
# ... logging
- wandb>=0.15.10,<1
- rich>=13.9.4,<14
# NOTE: After navigating to the datahub / cifutils directories, you can install the local package in editable mode with: # Formatting & linting (only needed for development)
- ruff==0.8.3 # python linter & formatter
- pre-commit==3.7.1 # pre-commit hooks for formatting & linting
# Debugger & interactive tools (only needed for development)
- debugpy>=1.8.5,<2 # debugger for python
- ipykernel>=6.29.4,<7 # ipython kernel for jupyter
- icecream>=2.0.0,<3 # print debugging
- pymol-remote>=0.1.0 # Remote access to PyMOL from Python (has no dependencies)
- ipdb>=0.13.9 # IPython debugger
# Pytest plugins (only needed for development)
- pytest>=8.2.0,<9 # testing framework
- pytest-testmon>=2.1.1,<3 # run only tests related to changed code
- pytest-xdist>=3.6.1,<4 # run tests in parallel
- pytest-dotenv>=0.5.2,<1 # load environment variables from .env file
- pytest-cov>=4.1.0,<5 # generate coverage report
- pytest-benchmark>=5.0.0,<6 # benchmark tests for speed
# Typing & documentation (only needed for development)
- jaxtyping>=0.2.17,<1
- beartype>=0.18.0,<1
# NOTE: After navigating to the datahub / cifutils / modelhub directories, you can install the local package in editable mode with:
# pip install -e . # pip install -e .
# NOTE: By default, DeepSpeed just-in-time compiles, which may take 3-4 minutes when first running the code on a new machine.
# It may be possible to pre-compile DeepSpeed within a `conda` environment; see: https://www.deepspeed.ai/tutorials/advanced-install/
# By default, the apptainers will have DeepSpeed pre-compiled, so when performance is a concern, it is recommended to use the apptainers.

115
freeze_apptainer.spec Normal file
View File

@@ -0,0 +1,115 @@
Bootstrap: localimage
From: ./scripts/shebang/modelhub.sif
IncludeCmd: yes
# NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd`
%setup
# NOTE: This is executed on the host, not the container
# Ensure the token environment variables are set
set +x # ... supress bash output to avoid printing the tokens in the output
for var in GITHUB_USER GITHUB_TOKEN; do
if [ -z "$(eval echo \$$var)" ]; then
set -x
echo "ERROR: $var is not set. Please create a personal access token at"
echo " - GitHub: https://github.com/settings/tokens"
echo "Then set the following environment variables:"
echo " - GITHUB_USER"
echo " - GITHUB_TOKEN"
exit 1
fi
done
set -x
# Create temporary `secrets.txt` file from host's environment variables in the container
# (which are otherwise not available in the %post section)
echo "Creating temporary secrets.txt file with access tokens in the container"
set +x
touch ${APPTAINER_ROOTFS}/secrets.txt
echo "GITHUB_USER=${GITHUB_USER}" >> ${APPTAINER_ROOTFS}/secrets.txt
echo "GITHUB_TOKEN=${GITHUB_TOKEN}" >> ${APPTAINER_ROOTFS}/secrets.txt
set -x
# Conditionally copy the project files based on the INSTALL_PROJECT environment variable
if [ ${INSTALL_PROJECT} = "true" ]; then
echo "Copying project files into the container..."
mkdir -p ${APPTAINER_ROOTFS}/opt/modelhub
rsync -av ./ ${APPTAINER_ROOTFS}/opt/modelhub/
else
echo "Skipping copying of project files."
fi
%post
# get os name
echo "Running on OS name $(lsb_release -i | awk '{ print $3 }')"
# get os version
echo "... in OS version $(lsb_release -r | awk '{ print $2 }')"
## SECRETS FILE
# Deal with secrets file
# ... verify that the secrets file is present on the container
if [ ! -e /secrets.txt ]; then
echo "ERROR: secrets.txt is not present on the container"
exit 1
fi
# ... temporarily set the access token environment variables
# from the secrets file
echo "Exporting access tokens from secrets.txt"
set +x
export GITHUB_USER=$(grep GITHUB_USER /secrets.txt | cut -d '=' -f2)
export GITHUB_TOKEN=$(grep GITHUB_TOKEN /secrets.txt | cut -d '=' -f2)
set -x
# ... remove secrets file
rm secrets.txt
# ... verify that the secrets file is not present on the container
if [ -e /secrets.txt ]; then
echo "ERROR: secrets.txt is still present on the container"
exit 1
else
echo "Verified that secrets.txt is not present on the container"
fi
# ... verify that the access token environment variables are set
set +x
for var in GITHUB_USER GITHUB_TOKEN; do
if [ -z "$(eval echo \$$var)" ]; then
echo "ERROR: $var is not set"
exit 1
fi
done
set -x
echo "Verified that access tokens are set"
# Install additional libraries
# Cifutils
pip install git+https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/baker-laboratory/cifutils.git@v2.15.0
# Datahub
pip install git+https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/baker-laboratory/datahub.git@v3.14.1
# Modelhub (maybe)
if [ -d "/opt/modelhub" ]; then
echo "Installing the project from /opt/modelhub..."
pip install /opt/modelhub
else
echo "Skipping project installation. /opt/modelhub does not exist."
fi
## CLEANUP
# Unset the access token environment variables to avoid possibly
# leaking them in the container
unset GITHUB_USER
unset GITHUB_TOKEN
# ... verify that the access token environment variables are unset
set +x
for var in GITHUB_USER GITHUB_TOKEN; do
if [ -n "$(eval echo \$$var)" ]; then
set -x
echo "ERROR: $var is still set"
exit 1
fi
done
set -x
echo "Verified that access tokens are unset."
%runscript
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
exec python "$@"

547
notebooks/plot.ipynb Normal file
View File

@@ -0,0 +1,547 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Plotting AF3 Results with CSV Logger"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook provides examples of how to parse the results of the `CSVLogger` for both training and validation."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Imports for this notebook\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"from pathlib import Path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Path to the log folder\n",
"LOG_PATH = Path(\"/path/to/logs\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Workflows to plot and visualize validation metrics"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot Results for Most Recent Epoch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"val_df = pd.read_csv(LOG_PATH / \"val_metrics/validation_output_all_epochs.csv\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_validation_results_by_type(\n",
" df: pd.DataFrame,\n",
" ignore_zeros: bool = False,\n",
") -> None:\n",
" \"\"\"Visualize metrics across all datasets.\n",
"\n",
" NOTE: Ensure that you first subset the DataFrame to only include the desired epoch.\n",
" \n",
" Args:\n",
" df: Combined DataFrame containing metrics data\n",
" ignore_zeros: Whether to treat zero values as missing data\n",
" \"\"\"\n",
" # (Copy the DataFrame to avoid modifying the original)\n",
" _df = df.copy()\n",
"\n",
" # ... subset to only include the desired columns\n",
" _df = _df[[\"dataset\", \"by_type_lddt.type\", \"by_type_lddt.best_of_1_lddt\", \"by_type_lddt.best_of_5_lddt\"]]\n",
" _df = _df.dropna()\n",
"\n",
" if ignore_zeros:\n",
" _df = _df.replace(0, pd.NA)\n",
"\n",
" # Prepare data\n",
" melted = pd.melt(_df,\n",
" id_vars=[\"dataset\", \"by_type_lddt.type\"],\n",
" value_vars=[\"by_type_lddt.best_of_1_lddt\", \"by_type_lddt.best_of_5_lddt\"],\n",
" var_name='metric',\n",
" value_name='lddt')\n",
" \n",
" # Create visualization\n",
" sns.set(style=\"whitegrid\", font_scale=1.1)\n",
" plt.figure(figsize=(15, 8))\n",
" \n",
" g = sns.catplot(\n",
" data=melted,\n",
" x='by_type_lddt.type',\n",
" y='lddt',\n",
" hue='metric',\n",
" col='dataset',\n",
" kind='bar',\n",
" estimator='mean', # Explicitly set to mean aggregation\n",
" ci=None, # Disable confidence intervals\n",
" height=6,\n",
" aspect=2,\n",
" sharey=False,\n",
" legend_out=False\n",
" )\n",
"\n",
" # Annotate bars with values\n",
" for ax in g.axes.flat:\n",
" for p in ax.patches:\n",
" ax.annotate(f\"{p.get_height():.2f}\",\n",
" (p.get_x() + p.get_width() / 2., p.get_height()),\n",
" ha='center', va='center',\n",
" fontsize=10,\n",
" color='black',\n",
" xytext=(0, 7),\n",
" textcoords='offset points')\n",
" \n",
" ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')\n",
" ax.set_xlabel('')\n",
" ax.set_ylabel('LDDT')\n",
"\n",
" plt.suptitle(f'Model Performance Comparison', y=1.02)\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"current_epoch_df = val_df[val_df[\"epoch\"] == val_df[\"epoch\"].max()]\n",
"plot_validation_results_by_type(current_epoch_df)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot Validation Curves"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_metric_trend(\n",
" df: pd.DataFrame,\n",
" metric_type: str,\n",
" dataset: str | None = None,\n",
" ignore_zeros: bool = False,\n",
" last_n_epochs: int | None = None\n",
") -> None:\n",
" \"\"\"Plot best-of-1 vs best-of-5 trends across epochs for a specific metric type.\n",
" \n",
" Args:\n",
" df: Combined metrics DataFrame\n",
" metric_type: The 'type' to plot (e.g., 'protein-ligand')\n",
" dataset_filter: Optional specific dataset to filter\n",
" ignore_zeros: Whether to exclude zero values\n",
" last_n_epochs: Optional number of most recent epochs to plot\n",
" \"\"\"\n",
" # Filter data\n",
" filtered = df[df['by_type_lddt.type'] == metric_type]\n",
" \n",
" if dataset:\n",
" filtered = filtered[filtered['dataset'] == dataset]\n",
" \n",
" if ignore_zeros:\n",
" filtered = filtered.replace(0, pd.NA).dropna(\n",
" subset=['by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']\n",
" )\n",
"\n",
" if filtered.empty:\n",
" raise ValueError(f\"No data found for {metric_type} in dataset {dataset or 'any'}\")\n",
"\n",
" if last_n_epochs:\n",
" max_epoch = filtered['epoch'].max()\n",
" filtered = filtered[filtered['epoch'] > (max_epoch - last_n_epochs)]\n",
"\n",
" # Aggregate by epoch\n",
" trend_data = filtered.groupby('epoch').agg({\n",
" 'by_type_lddt.best_of_1_lddt': 'mean',\n",
" 'by_type_lddt.best_of_5_lddt': 'mean'\n",
" }).reset_index()\n",
"\n",
" # Create plot\n",
" plt.figure(figsize=(12, 6))\n",
" sns.set_style(\"whitegrid\")\n",
"\n",
" # Plot lines with markers\n",
" sns.lineplot(\n",
" data=trend_data,\n",
" x='epoch',\n",
" y='by_type_lddt.best_of_1_lddt',\n",
" color='#1f77b4',\n",
" label='Best of 1',\n",
" marker='o',\n",
" markersize=8,\n",
" linewidth=2\n",
" )\n",
" \n",
" sns.lineplot(\n",
" data=trend_data,\n",
" x='epoch',\n",
" y='by_type_lddt.best_of_5_lddt',\n",
" color='#ff7f0e',\n",
" label='Best of 5',\n",
" marker='s',\n",
" markersize=8,\n",
" linewidth=2\n",
" )\n",
"\n",
" # Style plot\n",
" plt.title(f\"{metric_type} LDDT Trends\\nDataset: {dataset or 'All'}\")\n",
" plt.xlabel(\"Epoch\")\n",
" plt.ylabel(\"Average LDDT\")\n",
" plt.legend(title=\"Strategy\")\n",
" plt.grid(alpha=0.3)\n",
"\n",
" # Add padding to autoscaled y-axis\n",
" plt.ylim(top=min(1.0, plt.ylim()[1] * 1.05)) # Cap at 1.0 if near upper bound\n",
" plt.ylim(bottom=max(0.0, plt.ylim()[0] * 0.95)) # Floor at 0.0 if near lower bound\n",
"\n",
" # Set x-axis to show only whole numbers\n",
" plt.xticks(ticks=trend_data['epoch'], labels=trend_data['epoch'].astype(int))\n",
" \n",
" # Add final value annotations\n",
" last_values = trend_data.iloc[-1]\n",
" plt.text(\n",
" 0.95, 0.15,\n",
" f\"Final Bo1: {last_values['by_type_lddt.best_of_1_lddt']:.2f}\\nFinal Bo5: {last_values['by_type_lddt.best_of_5_lddt']:.2f}\",\n",
" ha='right',\n",
" va='bottom',\n",
" transform=plt.gca().transAxes,\n",
" bbox=dict(facecolor='white', alpha=0.8)\n",
" )\n",
"\n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_metric_trend(val_df, 'protein-ligand', dataset=\"af3_validation\", ignore_zeros=False, last_n_epochs=4)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize Outliers"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Identify and visualize outliers"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"structures_path = f\"{LOG_PATH}/val_structures\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def get_worst_examples(\n",
" df: pd.DataFrame, \n",
" metric_type: str, \n",
" dataset: str = None, \n",
" epoch: int = None\n",
") -> list:\n",
" \"\"\"Return example IDs sorted by worst performance for a metric type at specific epoch.\"\"\"\n",
" filtered = df[df['by_type_lddt.type'] == metric_type]\n",
" \n",
" if dataset:\n",
" filtered = filtered[filtered['dataset'] == dataset]\n",
" \n",
" # Use latest epoch if none specified\n",
" target_epoch = epoch if epoch is not None else filtered['epoch'].max()\n",
" filtered = filtered[filtered['epoch'] == target_epoch]\n",
" \n",
" return (\n",
" filtered[['example_id', 'by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']]\n",
" .assign(worst_score=lambda x: x[['by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']].min(axis=1))\n",
" .sort_values('worst_score')\n",
" ['example_id']\n",
" .tolist()\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datahub.common import parse_example_id\n",
"from cifutils.utils.visualize import view\n",
"from cifutils import parse\n",
"\n",
"dataset = \"af3-validation\"\n",
"metric = \"protein-ligand\"\n",
"latest_epoch = val_df['epoch'].max()\n",
"\n",
"worst_protein_ligand_examples = get_worst_examples(val_df, metric, dataset=\"af3_validation\", epoch=latest_epoch)\n",
"\n",
"# Visualize the worst example\n",
"parsed_id = parse_example_id(worst_protein_ligand_examples[0])\n",
"\n",
"# Find the worst example in the structures directory\n",
"structure_path_for_epoch = Path(structures_path) / f\"epoch_{latest_epoch}\" / dataset\n",
"\n",
"if structure_path_for_epoch.exists():\n",
" example_path = next(structure_path_for_epoch.glob(f\"*{parsed_id['pdb_id']}_{parsed_id['assembly_id']}*\"))\n",
"\n",
" # ... and visualize\n",
" atom_array = parse(example_path)\n",
" view(atom_array[\"assemblies\"][\"1\"][0])\n",
"else:\n",
" print(f\"No structure found for {parsed_id}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"csv_path = f\"{LOG_PATH}/lightning_logs/version_0/metrics.csv\"\n",
"_df = pd.read_csv(csv_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Curves"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Subset to the relevant columns\n",
"mean_cols = [\"train/batch_mean/diffusion_loss\", \"train/batch_mean/smoothed_lddt_loss\", \"train/batch_mean/total_loss\", \"train/batch_mean/distogram_loss\"]\n",
"per_structure_cols = [\"train/per_structure/t\", \"train/per_structure/diffusion_loss\", \"train/per_structure/smoothed_lddt_loss\"]\n",
"train_df = _df[mean_cols + per_structure_cols + [\"step\", \"train/learning_rate\"]]\n",
"\n",
"# Remove rows with all NaN values except for the 'step' column\n",
"check_cols = [col for col in train_df.columns if col != 'step']\n",
"train_df = train_df.dropna(how='all', subset=check_cols)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_training_metrics(train_df: pd.DataFrame) -> None:\n",
" \"\"\"Plot all training metrics from a DataFrame.\"\"\"\n",
" \n",
" processed = (\n",
" train_df\n",
" .groupby('step', as_index=False)\n",
" .mean()\n",
" .melt(id_vars='step', var_name='metric')\n",
" )\n",
" \n",
" # Create visualization\n",
" plt.figure(figsize=(12, 8))\n",
" sns.set_style(\"whitegrid\")\n",
" \n",
" g = sns.FacetGrid(\n",
" processed,\n",
" col='metric',\n",
" col_wrap=3,\n",
" height=4,\n",
" aspect=1.5,\n",
" sharey=False\n",
" )\n",
" \n",
" g.map(sns.lineplot, 'step', 'value', color='#2ca02c')\n",
" g.set_titles(\"{col_name}\")\n",
" g.set_axis_labels(\"Training Step\", \"Value\")\n",
" \n",
" # Special handling for learning rate\n",
" if 'learning_rate' in processed['metric'].unique():\n",
" g.axes[-1].set_yscale('log')\n",
" \n",
" plt.tight_layout()\n",
" plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"mean_df = train_df[mean_cols + [\"step\"]].copy()\n",
"mean_df = mean_df.dropna(subset=mean_cols)\n",
"plot_training_metrics(train_df[mean_cols + [\"step\", \"train/learning_rate\"]])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training Loss by T"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def plot_loss_scatter_by_t(\n",
" train_df: pd.DataFrame,\n",
" loss_column: str,\n",
" t_column: str = 'train/per_structure/t',\n",
" n_steps: int = 1000\n",
") -> None:\n",
" \"\"\"Plot loss values as a scatter plot against train/t values for the most recent N steps with a log-scaled x-axis.\n",
" \n",
" Args:\n",
" train_df: DataFrame containing training metrics\n",
" loss_column: Name of loss column to plot (e.g., 'train/total_loss')\n",
" t_column: Name of the column representing 't' values\n",
" n_steps: Number of recent training steps to analyze (default: 1000)\n",
" \"\"\"\n",
" train_df = train_df.copy()\n",
"\n",
" assert loss_column in train_df.columns, f\"Loss column '{loss_column}' not found in DataFrame\"\n",
"\n",
" # Get most recent steps\n",
" unique_steps = train_df['step'].dropna().unique()\n",
" \n",
" # Get actual number of available steps\n",
" n_steps = min(n_steps, len(unique_steps))\n",
" latest_steps = np.sort(unique_steps)[-n_steps:]\n",
" \n",
" # Filter recent data\n",
" recent_data = train_df[train_df['step'].isin(latest_steps)]\n",
"\n",
" # Subset to relevant columns and remove rows with NaN values\n",
" recent_data = recent_data[[t_column, loss_column]]\n",
" \n",
" # Create scatter plot\n",
" plt.figure(figsize=(10, 6)) # Fixed figure size\n",
" sns.set_style(\"whitegrid\")\n",
" \n",
" sns.scatterplot(\n",
" data=recent_data,\n",
" x=t_column,\n",
" y=loss_column,\n",
" color='#2ca02c',\n",
" alpha=0.6\n",
" )\n",
" \n",
" plt.xscale('log') # Set x-axis to logarithmic scale\n",
" plt.title(f\"{loss_column} vs {t_column} (Log Scale)\\n(Last {n_steps} Steps)\")\n",
" plt.xlabel(t_column)\n",
" plt.ylabel(loss_column)\n",
" \n",
" plt.tight_layout()\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_loss_scatter_by_t(train_df, 'train/per_structure/smoothed_lddt_loss', n_steps=1000)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "modelhub",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -20,19 +20,19 @@ build-backend = "hatchling.build"
source = "vcs" source = "vcs"
[tool.hatch.build.hooks.vcs] [tool.hatch.build.hooks.vcs]
version-file = "rf2aa/version.py" version-file = "src/modelhub/version.py"
[tool.hatch.metadata] [tool.hatch.metadata]
allow-direct-references = true allow-direct-references = true
[tool.hatch.build.targets.wheel] [tool.hatch.build.targets.wheel]
packages = ["rf2aa"] packages = ["src/modelhub"]
# Formatting & linting settings ------------------------------------------------------- # Formatting & linting settings -------------------------------------------------------
[tool.ruff] [tool.ruff]
line-length = 88 line-length = 88
indent-width = 4 indent-width = 4
target-version = "py311" target-version = "py310"
exclude = [ exclude = [
".bzr", ".bzr",
".direnv", ".direnv",

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

View File

@@ -1,479 +0,0 @@
import json
import logging
from collections import defaultdict
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tree
from icecream import ic
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from scipy.stats import norm
from rf2aa import pymol, pymol_tools
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.debug import pretty_describe_dict
from rf2aa.loss.af3_losses import Loss
from rf2aa.pymol import cmd
from rf2aa.util import writepdb
logger = logging.getLogger(__name__)
def flatten_dictionary(dictionary, parent_key="", separator="."):
flattened_dict = {}
for key, value in dictionary.items():
new_key = f"{parent_key}{separator}{key}" if parent_key else key
if isinstance(value, dict):
flattened_dict.update(flatten_dictionary(value, new_key, separator))
else:
flattened_dict[new_key] = value
return flattened_dict
class LogMetrics(Callback):
def __init__(self, config):
super().__init__()
self.config = config
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx: int,
) -> None:
logger.debug("on_train_batch_end outputs:\n" + pretty_describe_dict(outputs))
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
o = {}
stratifications = defaultdict(list)
for metric in [diffusion_losses, lddt_metrics]:
metric_d, stratification_keys = metric(self.config, outputs)
stratifications[stratification_keys].extend(metric_d.keys())
o.update(metric_d)
o["t"] = outputs["t"]
o["t_quantile_4"] = get_t_quantiles(
outputs["t"], self.config.loss.sigma_data, 4
)
df = pd.DataFrame.from_dict(o)
df = df.reindex(sorted(df.columns), axis=1)
ic(o)
(D,) = outputs["t"].shape
df["batch_idx"] = batch_idx
df["data_idx"] = np.arange(D)
df["global_step"] = trainer.global_step
trainer.logger.log_df(df, stratifications=stratifications)
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
o = {}
for metric in [lddt_metrics, lddt_metrics_null, diffusion_losses]:
metric_d, stratification_keys = metric(self.config, outputs)
o.update(metric_d)
df = pd.DataFrame.from_dict(o)
df = df.reindex(sorted(df.columns), axis=1)
ic(o)
df["batch_idx"] = batch_idx
df["global_step"] = trainer.global_step
trainer.logger.log_df(df, stratifications={})
return super().on_validation_batch_end(
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
)
def lddt_metrics(config, outputs):
# compute distances between ground truth atoms
ground_truth_distances = torch.cdist(outputs["X_gt_L"], outputs["X_gt_L"])
# compute distances between predicted atoms
predicted_distances = torch.cdist(outputs["X_L"], outputs["X_L"])
# compute LDDT score for each pair of distances
difference_distances = torch.abs(ground_truth_distances - predicted_distances)
lddt_matrix = torch.zeros_like(difference_distances)
lddt_matrix = (
0.25 * (difference_distances < 4.0)
+ 0.25 * (difference_distances < 2.0)
+ 0.25 * (difference_distances < 1.0)
+ 0.25 * (difference_distances < 0.5)
)
# remove unresolved atoms, atoms within same residue
is_real_atom = ChemData().heavyatom_mask.to(outputs["seq"].device)[outputs["seq"]]
is_resolved_atom_L = outputs["crd_mask_I"][is_real_atom]
is_unresolved_distance_LL = (
is_resolved_atom_L[..., None] & is_resolved_atom_L[None, ...]
)
in_same_residue_LL = (
outputs["f"]["tok_idx"][:, None] == outputs["f"]["tok_idx"][None, :]
)
lddt_values = {}
for mask, mask_type in get_lddt_masks(outputs):
mask = mask & is_unresolved_distance_LL & ~in_same_residue_LL
lddt = torch.div(lddt_matrix[:, mask].sum(dim=(-1)), mask.sum(dim=(-1, -2)))
lddt_values[f"lddt_{mask_type}"] = lddt
return lddt_values, ("t_quantile_4",)
def lddt_metrics_null(config, outputs):
# compute distances between ground truth atoms
ground_truth_distances = torch.cdist(outputs["X_gt_L"], outputs["X_gt_L"])
# compute distances between predicted atoms
t = outputs["t"]
X_noisy_L = outputs["X_noisy_L"]
sigma_data = 16
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[..., None, None] * X_noisy_L
predicted_distances = torch.cdist(null_pred, null_pred)
# compute LDDT score for each pair of distances
difference_distances = torch.abs(ground_truth_distances - predicted_distances)
lddt_matrix = torch.zeros_like(difference_distances)
lddt_matrix = (
0.25 * (difference_distances < 4.0)
+ 0.25 * (difference_distances < 2.0)
+ 0.25 * (difference_distances < 1.0)
+ 0.25 * (difference_distances < 0.5)
)
# remove unresolved atoms, atoms within same residue
is_real_atom = ChemData().heavyatom_mask[outputs["seq"]]
is_resolved_atom_L = outputs["crd_mask_I"][is_real_atom]
is_unresolved_distance_LL = (
is_resolved_atom_L[..., None] & is_resolved_atom_L[None, ...]
)
in_same_residue_LL = (
outputs["f"]["tok_idx"][:, None] == outputs["f"]["tok_idx"][None, :]
)
lddt_values = {}
for mask, mask_type in get_lddt_masks(outputs):
mask = mask & is_unresolved_distance_LL & ~in_same_residue_LL
lddt = torch.div(lddt_matrix[:, mask].sum(dim=(-1)), mask.sum(dim=(-1, -2)))
lddt_values[f"lddt_{mask_type}_null"] = lddt
return lddt_values, ("t_quantile_4",)
def get_lddt_masks(outputs):
D, L = outputs["X_L"].shape[:2]
tok_idx = outputs["f"]["tok_idx"]
is_protein_L = outputs["f"]["is_protein"][tok_idx]
is_dna_L = outputs["f"]["is_dna"][tok_idx]
is_rna_L = outputs["f"]["is_rna"][tok_idx]
is_ligand_L = outputs["f"]["is_ligand"][tok_idx]
asym_id_L = outputs["f"]["asym_id"][tok_idx]
same_chain_LL = asym_id_L[:, None] == asym_id_L[None, :]
for mask_type in [
"all",
"protein_intra",
"protein_inter",
"ligand_intra",
"ligand_inter",
]:
if mask_type == "all":
mask = torch.ones((L, L), dtype=torch.bool, device=outputs["X_L"].device)
elif mask_type == "protein_intra":
mask = is_protein_L[:, None] & is_protein_L[None, :]
mask *= same_chain_LL
elif mask_type == "protein_inter":
mask = is_protein_L[:, None] & is_protein_L[None, :]
mask *= ~same_chain_LL
elif mask_type == "ligand_intra":
mask = is_ligand_L[:, None] & is_ligand_L[None, :]
mask *= same_chain_LL
elif mask_type == "ligand_inter":
mask = is_ligand_L[:, None] & is_ligand_L[None, :]
mask *= ~same_chain_LL
elif mask_type == "protein_ligand_inter":
mask = is_protein_L[:, None] & is_ligand_L[None, :]
yield (mask, mask_type)
def diffusion_losses(config, outputs):
loss = Loss(**config.loss)
loss_dict_by_type = {}
t = outputs["t"]
X_noisy_L = outputs["X_noisy_L"]
sigma_data = 16
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[..., None, None] * X_noisy_L
sigma_gt = torch.var(outputs["X_gt_L"], dim=(1, 2)) ** 0.5
for input_type, X_L in (
("pred", outputs["X_L"]),
# ('input', outputs['X_noisy_L']),
("true", outputs["X_gt_L"]),
("null_pred", null_pred),
):
l_total, _, loss_dict_batched = loss(
outputs["f"],
X_L,
outputs["X_gt_L"],
outputs["t"],
outputs["seq"],
outputs["crd_mask_I"],
)
# loss_dict_by_type[input_type] = loss_dict_batched
loss_dict_batched_prefixed = {
f"{k}.{input_type}": v for k, v in loss_dict_batched.items()
}
loss_dict_by_type.update(loss_dict_batched_prefixed)
# Correcting for EDM : AF3 lambda conversion
edm_corr = (t + loss.sigma_data) ** 2 / (t * loss.sigma_data) ** 2
loss_dict_batched_edm = {k: v * edm_corr for k, v in loss_dict_batched.items()}
loss_dict_batched_prefixed_edm = {
f"{k}_edm.{input_type}": v for k, v in loss_dict_batched_edm.items()
}
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
# Correcting for Var(gt) != sigma_data
expected_loss_gt = (
1
/ (loss.sigma_data**2 + t**2)
* (loss.sigma_data**2 + t**2 * sigma_gt**2 / loss.sigma_data**2)
)
loss_dict_batched_edm_gt_corr = {
k: edm_corr * v / expected_loss_gt for k, v in loss_dict_batched.items()
}
loss_dict_batched_prefixed_edm = {
f"{k}_edm_gt_corr.{input_type}": v
for k, v in loss_dict_batched_edm_gt_corr.items()
}
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
o = flatten_dictionary(loss_dict_by_type)
o["pred_over_null_pred"] = o["diffusion_loss.pred"] / o["diffusion_loss.null_pred"]
o["pred_over_null_pred_norm"] = (
o["diffusion_loss_edm_gt_corr.pred"] / o["diffusion_loss_edm_gt_corr.null_pred"]
)
return o, ("t_quantile_4",)
def get_normal_quantiles(n):
# Generate n evenly spaced probabilities between 0 and 1
probabilities = np.linspace(0, 1, n)
# Use the percent point function (inverse CDF) of the standard normal distribution
return norm.ppf(probabilities)
def get_t_quantiles(t, sigma_data, n):
bins = sigma_data * np.exp(-1.2 + 1.5 * get_normal_quantiles(n + 1))
t_binned_list = []
for t in t:
t_bin = np.digitize(t, bins) - 1
bin_start = bins[t_bin]
bin_end = bins[t_bin + 1]
t_binned = f"t=[{bin_start:.2f},{bin_end:.2f})"
t_binned_list.append(t_binned)
return t_binned_list
class NetworkOutputGradSanityCheck(Callback):
def __init__(self, call_n_times=0, *args, **kwargs):
super().__init__(*args, **kwargs)
self.call_n_times = call_n_times
self.call_count = 0
def on_after_backward(self, trainer, pl_module):
if self.call_count < self.call_n_times:
self.call_count += 1
r_projection_weight = pl_module.model.model.diffusion_module.atom_attention_decoder.to_r_update[
1
].weight
ic(
torch.linalg.norm(r_projection_weight)
if r_projection_weight is not None
else None,
torch.linalg.norm(r_projection_weight.grad)
if r_projection_weight.grad is not None
else None,
)
class MonitorActivations(Callback):
def make_hook(self, label):
def hook(module, args, kwargs, output):
activation_metrics = {
f"{label}:inter_batch_cosine_similarity": F.cosine_similarity(
torch.flatten(output[0]),
torch.flatten(output[1]),
dim=0,
),
f"{label}:intra_batch_cosine_similarity_to_elem_0": F.cosine_similarity(
output[0][0:1],
output[0],
).mean(),
}
self.log_dict(activation_metrics)
return hook
def setup(self, trainer, pl_module, stage):
self.pl_module = pl_module
self.trainer = trainer
pl_module.model.model.diffusion_module.atom_attention_decoder.register_forward_hook(
self.make_hook(
"diffusion_module.atom_attention_decoder",
),
with_kwargs=True,
)
class FindUnusedParameters(Callback):
def __init__(self, only_once=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.only_once = only_once
self.called = False
def on_after_backward(self, trainer, pl_module):
if self.called and self.only_once:
return
self.called = True
# Calculate unused parameters after each batch
unused_params = [
name for name, param in pl_module.named_parameters() if param.grad is None
]
# Log unused parameters
logging.info(
f"global_step={pl_module.global_step}: parameters with no gradient: {json.dumps(unused_params, indent=4)}"
)
if unused_params:
raise Exception("storp")
class WriteToPymol(Callback):
def __init__(self, only_once=True, *args, **kwargs):
super().__init__(*args, **kwargs)
self.only_once = only_once
self.called = False
pymol.init("http://chesaw.dhcp.ipd:9123")
def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx: int,
) -> None:
if self.called and self.only_once:
return
self.called = True
pymol_tools.clear()
predicted = outputs
logger.info("predicted:\n" + pretty_describe_dict(predicted))
ic(predicted["loss"])
D = predicted["X_L"].shape[0]
max_to_show = 16
grid_slot = 1
cmd.set("grid_mode", 1)
for i in range(min(D, max_to_show)):
X_gt_L = predicted["X_gt_L"][i]
X_L = predicted["X_L"][i]
X_noisy_L = predicted["X_noisy_L"][i]
t = predicted["t"][i]
label = pymol_tools.show_pymol(
pymol_tools.to_atom37(X_noisy_L, predicted["crd_mask_I"]),
predicted["seq"],
predicted["bond_feats"],
label=f"input_{i}_t_{t.item():.2f}",
)
cmd.set("grid_slot", grid_slot, label)
cmd.color("yellow", label)
label = pymol_tools.show_pymol(
pymol_tools.to_atom37(X_L, predicted["crd_mask_I"]),
predicted["seq"],
predicted["bond_feats"],
label=f"pred_{i}_t_{t.item():.2f}",
)
cmd.set("grid_slot", grid_slot, label)
cmd.color("green", label)
label = pymol_tools.show_pymol(
pymol_tools.to_atom37(X_gt_L, predicted["crd_mask_I"]),
predicted["seq"],
predicted["bond_feats"],
label=f"gt_{i}",
)
cmd.set("grid_slot", grid_slot, label)
cmd.color("blue", label)
grid_slot += 1
cmd.show_as("licorice", "all")
cmd.alter("name CA", "vdw=2.0")
# cmd.set('sphere_transparency', 0.0)
cmd.show("spheres", "name CA")
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
class WritePDB(Callback):
def on_validation_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
):
seq = batch["seq"][0][0]
is_real_atom = ChemData().heavyatom_mask.to(seq.device)[seq]
X_L = outputs["X_L"]
X_gt_L = outputs["X_gt_L"]
atom_mask = outputs["crd_mask_I"]
bond_feats = batch["bond_feats"]
X_I = torch.full(
(X_L.shape[0], atom_mask.shape[0], ChemData().NTOTAL, 3), np.nan
).to(X_L.device)
X_I[..., is_real_atom, :] = X_L
X_gt_I = torch.full(
(X_gt_L.shape[0], atom_mask.shape[0], ChemData().NTOTAL, 3), np.nan
).to(X_gt_L.device)
X_gt_I[..., atom_mask, :] = X_gt_L
pdb_path = f"tmp/true_{batch_idx}.pdb"
writepdb(
pdb_path,
X_gt_I[0],
seq.long(),
bond_feats=bond_feats,
)
for i in range(X_L.shape[0]):
pdb_path = f"tmp/pred_{batch_idx}_{i}.pdb"
writepdb(
pdb_path,
X_I[i],
seq.long(),
bond_feats=bond_feats,
)
return super().on_validation_batch_end(
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
)
class DebugGrads(Callback):
def on_after_backward(self, trainer, pl_module):
grad_dict = {}
for name, param in pl_module.named_parameters():
if param.grad is not None and "pairformer" in name:
grad_dict[name] = param.grad.clone().detach()
ic(
name,
torch.linalg.norm(param.grad),
torch.linalg.norm(param),
)
torch.save(grad_dict, "grad_dict_unbatched.pt")

File diff suppressed because it is too large Load Diff

View File

@@ -1,534 +0,0 @@
import os
import argparse
import json
import logging
import pickle
import tempfile
from collections.abc import Mapping
from os import PathLike
from pathlib import Path
import hydra
import numpy as np
import torch
import yaml
from biotite.structure import AtomArray, AtomArrayStack, stack
from cifutils import parse
from cifutils.tools.inference import (
build_msa_paths_by_chain_id_from_component_list,
components_to_atom_array,
)
from cifutils.utils.io_utils import to_cif_file
from datahub.encoding_definitions import AF3SequenceEncoding
import omegaconf
from omegaconf import OmegaConf
from rf2aa.metrics.predicted_error import WriteAF3Confidence
from rf2aa.trainer_base import trainer_factory
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Define the sequence encoding; needed to decode the restypes when saving to CIF
encoding = AF3SequenceEncoding()
def build_stack_from_atom_array_and_batched_coords(
coords: np.ndarray,
atom_array: AtomArray,
annotations_to_keep: list[str] = [
"chain_id",
"transformation_id",
"res_id",
"res_name",
"element",
"atom_name",
],
) -> AtomArrayStack:
"""Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.
Additionally, handles the case where the AtomArray contains multiple transformations and we must adjust the chain_id.
Args:
coords (np.array): The coordinates to be assigned to the AtomArrayStack. Must have shape (nbatch, n_atoms, 3).
atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,)
"""
# (Diffusion batch size will become the number of models)
n_batch = coords.shape[0]
# Remove unwanted annotations
for annotation in atom_array.get_annotation_categories():
if annotation not in annotations_to_keep:
atom_array.del_annotation(annotation)
# Build the stack and assign the coordinates
atom_array_stack = stack([atom_array for _ in range(n_batch)])
atom_array_stack.coord = coords
# Adjust chain_id if there are multiple transformations
# (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
if (
"transformation_id" in atom_array.get_annotation_categories()
and len(np.unique(atom_array_stack.transformation_id)) > 1
):
atom_array_stack.chain_id = (
atom_array_stack.chain_id + atom_array_stack.transformation_id
)
return atom_array_stack
def _spoof_cif_from_dictionary(item: dict, temp_dir: PathLike) -> Path:
"""Unpacks a dictionary to create a CIF file from its components.
Args:
item (dict): A dictionary containing 'name' and 'components', optionally 'bonds'.
temp_dir (Path): Path to the temporary directory for storing CIF files.
Returns:
Path: The path to the created CIF file, saved in the temporary directory.
Raises:
NotImplementedError: If 'bonds' is present in the dictionary.
ValueError: If 'name' or 'components' are missing from the dictionary.
"""
# Validate the dictionary structure ("name" and "components" are required, "bonds" is optional)
assert (
"name" in item and "components" in item
), "The input dictionary must contain 'name' and 'components' keys."
# Build components
atom_array, component_list = components_to_atom_array(
item["components"], return_components=True, bonds=item.get("bonds", None)
)
msa_paths_by_chain_id = build_msa_paths_by_chain_id_from_component_list(
component_list
)
# Create a temporary CIF file from the JSON data
cif_path = Path(temp_dir) / f"{item['name']}.cif"
save_path = to_cif_file(
atom_array,
cif_path,
extra_categories={"msa_paths_by_chain_id": msa_paths_by_chain_id}
if msa_paths_by_chain_id
else None,
file_type="cif", # Not zipped for efficiency (as it's a temporary directory anyways)
)
return Path(save_path)
def _build_file_paths_for_prediction(inputs: list, temp_dir: PathLike) -> list[Path]:
"""Prepare files for prediction based on the input paths.
Input paths may be dictionary-like format (e.g., JSON, YAML, Pickle), CIF/PDB files, or directories containing these files.
Processes directories to find supported file types and converts dictionary-like formats to CIF files.
Args:
inputs (list): List of input paths (JSON, YAML, Pickle, or CIF/PDB).
temp_dir (Path): Path to the temporary directory for storing CIF files.
Returns:
list[Path]: List of file paths for prediction.
"""
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
# Collect all files from inputs, handling directories and individual files
paths_to_raw_input_files = []
for input_path in inputs:
if Path(input_path).is_dir():
paths_to_raw_input_files.extend(
_find_files(
input_path, DICTIONARY_LIKE_EXTENSIONS | CIF_LIKE_EXTENSIONS
)
)
else:
paths_to_raw_input_files.append(Path(input_path))
paths_to_cif_like_files = []
for path in paths_to_raw_input_files:
#concatenated_suffix = "".join(path.suffixes)
concatenated_suffix = path.suffixes[-1]
if concatenated_suffix in DICTIONARY_LIKE_EXTENSIONS:
# Spoof CIF files from dictionary-like formats
with open(path, "rb" if path.suffix == ".pkl" else "r") as file:
# Load data based on file extension
if path.suffix == ".json":
data = json.load(file)
elif path.suffix in {".yaml", ".yml"}:
raise NotImplementedError("YAML files are not yet supported.")
elif path.suffix == ".pkl":
data = pickle.load(file)
if isinstance(data, dict):
data = [
data
] # Convert single dictionary to list for uniform processing
for item in data:
paths_to_cif_like_files.append(
_spoof_cif_from_dictionary(item, temp_dir)
)
elif concatenated_suffix in CIF_LIKE_EXTENSIONS:
# Directly use CIF-like files
paths_to_cif_like_files.append(path)
else:
raise ValueError(
f"Unsupported file extension: {path.suffix} (path: {path}; paths: {paths_to_raw_input_files})."
)
return paths_to_cif_like_files
def _find_files(path: PathLike, supported_file_types: list) -> list[Path]:
"""Recursively find all files with the given extensions in the specified path.
Args:
path (PathLike): Path to the directory containing the files.
supported_file_types (list): List of supported file extensions.
Returns:
list[Path]: List of files with the given extensions.
"""
files_with_supported_types = []
path = Path(path)
# Check if the path is a directory
if path.is_dir():
# Search for files with each supported extension
for file_type in supported_file_types:
files_with_supported_types.extend(path.glob(f"*{file_type}"))
elif path.is_file() and path.suffix in supported_file_types:
# If it's a file and has a supported extension, add to the list
files_with_supported_types.append(path)
return files_with_supported_types
def _update_nested_dictconfig(d: Mapping, u: Mapping, depth: int = 0) -> Mapping:
"""Recursive function to overwrite contents of one nested omegaconf dictconfig with another.
Args:
d: dictionary of dictconfigs whose contents will be overwritten
u: dictionary of items which will overwrite or add to values in d
depth: depth of recursion: a positive integer:
-used to keep the outermost layer of the config as a dict instead of DictConfig.
-set to 1 or higher to return only DictConfig.
Returns:
d updated to contain values in u
"""
d = dict(d)
u = dict(u)
for k, v in u.items():
if isinstance(v, Mapping):
d[k] = _update_nested_dictconfig(d.get(k, {}), v, depth=depth + 1)
else:
d[k] = v
if depth == 0:
return d
else:
return omegaconf.dictconfig.DictConfig(d)
class EvaluateAF3:
"""Class for inference with AF3. Evaluates a trained AF3 model on a set of spoofed CIFs."""
def __init__(
self,
checkpoint_path: PathLike,
cif_out_dir: PathLike,
n_recycles: int,
diffusion_batch_size: int,
config_override_path: PathLike | None = None,
residue_renaming_dict: dict | None = None,
temp_dir: PathLike | None = None,
num_steps: int = 200,
solver: str = "af3",
overwrite: bool = False
):
"""Initialize the evaluator.
Args:
checkpoint_path (PathLike): Path to the checkpoint file, e.g., /path/to/checkpoint.pt.
cif_out_dir (PathLike): Directory to save the output (predicted) CIF files.
config_override_path (PathLike): Path to a yaml file with config options to override those in the checkpoint file.
world_size (int): Number of GPUs to use for evaluation.
n_recycles (int): Number of recycles for AF3. The default is 10.
diffusion_batch_size (int): Diffusion batch size for AF3. Each predicted structure will be saved as a separate model within the same CIF file.
residue_renaming_dict (dict): Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}.
temp_dir (PathLike): Temporary directory to store intermediate files. The default is None.
num_steps (int): Number of steps for sampling of the diffusion model. The default is 200; we see reasonable results with 50 steps.
solver (str): Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'. The default is 'af3'.
"""
# Load the checkpoint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
# Load the config
self.config = OmegaConf.create(checkpoint["training_config"])
if config_override_path is not None:
with open(config_override_path, 'r') as fs:
config_override_dict = yaml.load(fs, yaml.FullLoader)
self.config = _update_nested_dictconfig(self.config, config_override_dict)
self.config = OmegaConf.create(self.config)
# Make sure we aren't using the version with a bug in plddt
if (
self.config.experiment.name
== "rf2aa-af3-repro-rollout_nmw_from_scratch_af3_style_no_cb_normal_crop_cont_3"
):
raise ValueError(
"These weights are outdated and the plddt metric may be inaccurate. Please update to the latest available weights."
)
# Sampler sets diffusion batch size based on the following, not strictly on batch size in vaildation transform
self.config.dataset_params["diffusion_batch_size_valid"] = diffusion_batch_size
self.config.af3_inference["num_steps"] = num_steps
self.config.af3_inference["solver"] = solver
# Load the AF-3 trainer
self.trainer = trainer_factory[self.config.experiment.trainer](
config=self.config
)
self.trainer.checkpoint = checkpoint
# Set the output directory for the CIF files (e.g., predicted structures)
self.cif_out_dir = Path(cif_out_dir) if cif_out_dir else Path("./")
# Model parameters
self.n_recycles = n_recycles
self.diffusion_batch_size = diffusion_batch_size
if "confidence_loss" in self.config.loss:
self.confidence_writer = WriteAF3Confidence(
**self.config.loss.confidence_loss
)
else:
self.confidence_writer = None
# Rename residues
self.residue_renaming_dict = residue_renaming_dict
self.temp_dir = Path(temp_dir)
self.overwrite = overwrite
def construct_pipeline(self):
"""Construct the AF3 inference pipeline."""
self.config.dataset_params.val.interface.transform.n_recycles = self.n_recycles
self.config.dataset_params.val.interface.transform.diffusion_batch_size = (
self.diffusion_batch_size
)
self.config.dataset_params.val.interface.transform.return_atom_array = (
True # Required for `to_cif`
)
assert (
self.config.dataset_params.val.interface.transform.n_recycles
== self.n_recycles
), "Number of recycles not set correctly."
assert (
self.config.dataset_params.val.interface.transform.diffusion_batch_size
== self.diffusion_batch_size
), "Diffusion batch size not set correctly."
pipeline = hydra.utils.instantiate(
self.config.dataset_params.val.interface.transform
)
return pipeline
def eval(self, files: list[PathLike]):
"""Evaluate the model on a set of spoofed CIF files.
Args:
files (list[PathLike]): List of paths to spoofed CIF files or directories containing spoofed CIF files.
Coordinates must be present but may contain NaN values. If a directory is provided,
all files with the extensions .cif, .pdb, .bcif, .cif.gz, .pdb.gz, .bcif.gz will be processed.
"""
# Construct the model and load the checkpoint
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
self.trainer.construct_model(device=gpu, inference=True)
self.trainer.load_model()
# Set the model to evaluation mode
self.trainer.model.eval()
logger.info("Building Transform pipeline...")
# Construct the AF3 inference pipeline
pipeline = self.construct_pipeline()
logger.info(f"Found {len(files)} structures to predict: {files}.")
for structure in files:
# ... parse into an AtomArray (`parse` handles all valid formats)
logger.info(f"Parsing from path: {structure}")
#example_id = structure.name.split(".")[0]
example_id = ".".join(structure.name.split(".")[:-1])
# optionally, skip if output already exists
cif_output_path = example_id + '.cif'
cif_output_path = self.cif_out_dir / cif_output_path
if os.path.exists(cif_output_path) and not self.overwrite:
logger.info(f"Existing output for {example_id} found at {cif_output_path}. Skipping this example. Set --overwrite to not skip examples with existing output")
continue
# If we're renaming residues, we do a brute-force replacement in the CIF file
if self.residue_renaming_dict:
logger.info(
f"Renaming residues in {structure} with brute-force find and replace: {self.residue_renaming_dict}"
)
with open(structure, "r") as f:
content = f.read()
for old_res, new_res in self.residue_renaming_dict.items():
content = content.replace(old_res, new_res)
structure = Path(self.temp_dir / structure.name)
with open(structure, "w") as f:
f.write(content)
out = parse(structure, remove_hydrogens=True)
# ... get the atom array and set NaN coordinates to random
atom_array = (
out["assemblies"]["1"][0]
if "assemblies" in out
else out["asym_unit"][0]
)
# HACK: Set NaN coordinates to random values to avoid unexpected behavior in the pipeline
atom_array.coord[np.isnan(atom_array.coord)] = np.random.rand(
*atom_array.coord[np.isnan(atom_array.coord)].shape
)
# ... assemble the pipeline input in a format compatible with the DataHub pipeline
pipeline_input = {
"example_id": example_id,
"atom_array": atom_array,
"chain_info": out["chain_info"],
}
# ... run dataloading and featurization
pipeline_output = pipeline(pipeline_input)
# Model inference
with torch.no_grad():
outputs = self.trainer.sampler.sample(
[pipeline_output],
n_cycle=self.n_recycles,
use_amp=self.config.training_params.use_amp,
)
# Override the AtomArray with the predited coordinates
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
outputs["X_L"].cpu().numpy(), pipeline_output["atom_array"]
)
# Write the atom array to a CIF file
# NOTE: To make the secondary structure appear, run `dss` in PyMol (see: https://biology.stackexchange.com/questions/70143/can-pymol-show-cartoon-secondary-structure-for-a-pdb-of-multiple-frames)
out_path = to_cif_file(
atom_array_stack, self.cif_out_dir / example_id, file_type="cif"
)
logger.info(f"Prediction for {example_id} written to {out_path}.")
if "confidence" in outputs:
loss_input = {
"example_id": example_id,
"is_real_atom": pipeline_output["confidence_feats"]["is_real_atom"],
}
logger.info(f"Writing {example_id}.score to {self.cif_out_dir}")
df = self.confidence_writer(None, outputs, loss_input)
df.to_csv(self.cif_out_dir / f"{example_id}.score", index=False)
logger.info(
f"Confidence metrics for {example_id}.cif written to {self.cif_out_dir / example_id}.score."
)
def main():
parser = argparse.ArgumentParser(description="Evaluate AF3 using specified paths.")
parser.add_argument(
"inputs",
nargs="+",
help="List of paths to supported file types or directories of of supported files.",
)
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the checkpoint file"
)
parser.add_argument(
"--cif_out_dir", type=str, required=True, help="Directory for output CIF files"
)
parser.add_argument(
"--config_override_path",
type=str,
required=False,
help="Path to a yaml file with configs to override those in the checkpoint file",
)
parser.add_argument(
"--n_recycles", type=int, default=10, help="Number of recycles for AF3"
)
parser.add_argument(
"--diffusion_batch_size",
type=int,
default=5,
help="Diffusion batch size for AF3",
)
parser.add_argument(
"--rename_residues",
type=str,
default="",
help="Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}",
)
parser.add_argument(
"--num_steps",
type=int,
default=200,
help="Number of steps for sampling of the diffusion model",
)
parser.add_argument(
"--solver",
type=str,
default="af3",
help="Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'.",
)
parser.add_argument(
"--overwrite",
default=False,
action="store_true",
help="Overwrite existing .cif outputs with new runs.",
)
args = parser.parse_args()
with tempfile.TemporaryDirectory() as temp_dir:
temp_dir = Path(temp_dir)
temp_dir.mkdir(parents=True, exist_ok=True)
# Prepare inputs based on the file types
file_paths_for_prediction = _build_file_paths_for_prediction(
args.inputs, temp_dir
)
# Rename residues if necessary (e.g., for MPNN outputs that have ligand names that clash with the CCD)
residue_renaming_dict = (
json.loads(args.rename_residues) if args.rename_residues else {}
)
# Construct the evaluator
evaluator = EvaluateAF3(
checkpoint_path=args.checkpoint_path,
cif_out_dir=args.cif_out_dir,
config_override_path=args.config_override_path,
n_recycles=args.n_recycles,
diffusion_batch_size=args.diffusion_batch_size,
residue_renaming_dict=residue_renaming_dict,
temp_dir=temp_dir,
num_steps=args.num_steps,
solver=args.solver,
overwrite=args.overwrite
)
# Launch the evaluation
evaluator.eval(files=file_paths_for_prediction)
if __name__ == "__main__":
main()

View File

@@ -1,28 +0,0 @@
import torch
import torch.nn as nn
from rf2aa.loss.af3_losses import distogram_loss
from rf2aa.metrics.metrics_base import Metric
class DistogramLoss(Metric):
def __init__(self):
super().__init__()
self.cce_loss = nn.CrossEntropyLoss(reduction="none")
def __call__(self, network_input, network_output, loss_input):
pred_distogram = network_output["distogram"]
X_rep_atoms_I = loss_input["X_rep_atoms_I"]
crd_mask_rep_atoms_I = loss_input["crd_mask_rep_atoms_I"]
loss = distogram_loss(
pred_distogram, X_rep_atoms_I, crd_mask_rep_atoms_I, self.cce_loss
)
return {"distogram_loss": loss.detach().item()}
class SaveDistograms(Metric):
def __call__(self, network_input, network_output, loss_input):
pred_distogram = network_output["distogram"]
example_id = loss_input["example_id"]
torch.save(pred_distogram, f"distograms/{example_id}.pt")
return {"distogram_saved": True}

View File

@@ -1,288 +0,0 @@
import torch
import torch.nn as nn
from rf2aa.metrics.metrics_base import Metric
def calc_lddt(
X_L,
X_gt_L,
crd_mask_L,
tok_idx,
pairs_to_score=None,
distance_cutoff=15.0,
use_amp=True,
eps=1e-6,
):
"""
X_L: predicted coordinates (D, L, 3)
X_gt_L: ground truth coordinates (D, L, 3)
crd_mask_L: mask of coordinates (D, L,)
tok_idx: token index of each atom (L,)
pairs_to_score: pairs to score (L, L) | None
"""
D, L = X_L.shape[:2]
if pairs_to_score is None:
pairs_to_score = torch.ones((L, L), dtype=torch.bool).triu(0).to(X_L.device)
else:
assert pairs_to_score.shape == (L, L)
pairs_to_score = pairs_to_score.triu(0).to(X_L.device)
first_index, second_index = torch.nonzero(pairs_to_score, as_tuple=True)
lddt = []
for d in range(D):
ground_truth_distances = torch.linalg.norm(
X_gt_L[d, first_index] - X_gt_L[d, second_index], dim=-1
)
pair_mask = torch.logical_and(
ground_truth_distances > 0, ground_truth_distances < distance_cutoff
)
# only score pairs that are resolved in the ground truth
pair_mask *= crd_mask_L[d, first_index] * crd_mask_L[d, second_index]
# don't score pairs that are in the same token
pair_mask *= tok_idx[first_index] != tok_idx[second_index]
valid_pairs = pair_mask.nonzero(as_tuple=True)
pair_mask = pair_mask[valid_pairs].to(X_L.dtype)
ground_truth_distances = ground_truth_distances[valid_pairs]
first_index, second_index = first_index[valid_pairs], second_index[valid_pairs]
predicted_distances = torch.linalg.norm(
X_L[d, first_index] - X_L[d, second_index], dim=-1
)
delta_distances = torch.abs(predicted_distances - ground_truth_distances + eps)
del predicted_distances, ground_truth_distances
lddt.append(
0.25
* (
torch.sum((delta_distances < 4.0) * pair_mask)
+ torch.sum((delta_distances < 2.0) * pair_mask)
+ torch.sum((delta_distances < 1.0) * pair_mask)
+ torch.sum((delta_distances < 0.5) * pair_mask)
)
/ (torch.sum(pair_mask) + eps)
)
return torch.tensor(lddt)
class InterfaceLDDT(Metric):
def __call__(self, network_input, network_output, loss_input):
interface_lddt = {"interface_lddt_first": [], "interface_lddt_best": []}
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
for chain_i, chain_j, interface_type in loss_input["interfaces_to_score"]:
# get tokens in chain_i and chain_j
chain_i_tokens = chain_iid_token_lvl == chain_i
chain_j_tokens = chain_iid_token_lvl == chain_j
# convert the token level to the atom level
chain_i_atoms = chain_i_tokens[tok_idx]
chain_j_atoms = chain_j_tokens[tok_idx]
# compute the intersection of chain_i and chain_j
chain_ij_atoms = torch.einsum(
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
).to(network_output["X_L"].device)
# symmetrize
chain_ij_atoms = chain_ij_atoms | chain_ij_atoms.T
# compute lddt using the pairs_to_score from the intersection
lddt = calc_lddt(
network_output["X_L"],
loss_input["X_gt_L"],
loss_input["crd_mask_L"],
torch.tensor(tok_idx).to(network_output["X_L"].device),
pairs_to_score=chain_ij_atoms,
distance_cutoff=30.0,
)
interface_lddt["interface_lddt_first"].append(lddt[0].item())
interface_lddt["interface_lddt_best"].append(lddt.max().item())
return interface_lddt
class ConfidenceInterfaceLDDT(Metric):
def __call__(self, network_input, network_output, loss_input):
interface_lddt = {
"interface_lddt_first": [],
"interface_lddt_best": [],
"interface_lddt_pae": [],
"interface_lddt_pde": [],
"interface_lddt_plddt": [],
"interface_lddt_af3_style_ipae": [],
"interface_lddt_af3_style_lig_ipae": [],
}
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
for chain_i, chain_j, interface_type in loss_input["interfaces_to_score"]:
# get tokens in chain_i and chain_j
chain_i_tokens = chain_iid_token_lvl == chain_i
chain_j_tokens = chain_iid_token_lvl == chain_j
# convert the token level to the atom level
chain_i_atoms = chain_i_tokens[tok_idx]
chain_j_atoms = chain_j_tokens[tok_idx]
# compute the intersection of chain_i and chain_j
chain_ij_atoms = torch.einsum(
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
).to(network_output["X_L"].device)
# compute lddt using the pairs_to_score from the intersection
lddt = calc_lddt(
network_output["X_L"],
loss_input["X_gt_L"],
loss_input["crd_mask_L"],
torch.tensor(tok_idx).to(network_output["X_L"].device),
pairs_to_score=chain_ij_atoms,
distance_cutoff=30.0,
)
pae_idx = loss_input["pae_idx"]
pde_idx = loss_input["pde_idx"]
plddt_idx = loss_input["plddt_idx"]
af3_style_ipae_idx = loss_input["best_interface_idx"][
f"{chain_i}-{chain_j}"
]
interface_lddt["interface_lddt_first"].append(lddt[0].item())
interface_lddt["interface_lddt_best"].append(lddt.max().item())
interface_lddt["interface_lddt_pae"].append(lddt[pae_idx].item())
interface_lddt["interface_lddt_pde"].append(lddt[pde_idx].item())
interface_lddt["interface_lddt_plddt"].append(lddt[plddt_idx].item())
interface_lddt["interface_lddt_af3_style_ipae"].append(
lddt[af3_style_ipae_idx].item()
)
interface_lddt["interface_lddt_af3_style_lig_ipae"].append(
lddt[loss_input["best_lig_ipae_idx"][f"{chain_i}-{chain_j}"]].item()
)
return interface_lddt
class ConfidenceChainLDDT(Metric):
def __call__(self, network_input, network_output, loss_input):
chain_lddt = {
"chain_lddt_first": [],
"chain_lddt_best": [],
"chain_lddt_pae": [],
"chain_lddt_pde": [],
"chain_lddt_plddt": [],
"chain_lddt_af3_style_chain": [],
"chain_lddt_af3_style_single_chain": [],
}
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
for chain_i, chain_type in loss_input["pn_units_to_score"]:
# print(chain_type)
# get tokens in chain_i and chain_j
chain_i_tokens = chain_iid_token_lvl == chain_i
chain_j_tokens = chain_iid_token_lvl == chain_i
# convert the token level to the atom level
chain_i_atoms = chain_i_tokens[tok_idx]
chain_j_atoms = chain_j_tokens[tok_idx]
# compute the intersection of chain_i and chain_j
chain_ij_atoms = torch.einsum(
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
).to(network_output["X_L"].device)
# compute lddt using the pairs_to_score from the intersection
lddt = calc_lddt(
network_output["X_L"],
loss_input["X_gt_L"],
loss_input["crd_mask_L"],
torch.tensor(tok_idx).to(network_output["X_L"].device),
pairs_to_score=chain_ij_atoms,
)
chain_lddt["chain_lddt_first"].append(lddt[0].item())
chain_lddt["chain_lddt_best"].append(lddt.max().item())
chain_lddt["chain_lddt_pae"].append(lddt[loss_input["pae_idx"]].item())
chain_lddt["chain_lddt_pde"].append(lddt[loss_input["pde_idx"]].item())
chain_lddt["chain_lddt_plddt"].append(lddt[loss_input["plddt_idx"]].item())
chain_lddt["chain_lddt_af3_style_chain"].append(
lddt[loss_input["best_chain_to_all_idx"][chain_i]].item()
)
chain_lddt["chain_lddt_af3_style_single_chain"].append(
lddt[loss_input["best_chain_to_self_idx"][chain_i]].item()
)
return chain_lddt
class LigRMSD(Metric):
# TODO: move these to a separate file, here for backwards compatibility with configs
def __call__(self, network_input, network_output, loss_input):
raise NotImplementedError()
class InterfacePocketLigandRMSD(Metric):
# TODO: move these to a separate file, here for backwards compatibility with configs
"""
Compute the Ligand RMSD for each interface in the interfaces_to_score list.
The ligand RMSD is computed only for interface protein-ligand chains.
Given a chain pair (chain_i, chain_j) and the interface type, the RMSD is computed as follows:
- if the interface_type is protein_ligand: continue
- Rigid align the GT coordinates of onto the predicted coordinates using only the CA atoms within 10A of the ligand in chain_i or chain_j
- Compute the RMSD between the aligned GT coordinates and the predicted coordinates of the ligand atoms
Note: if the interface is not between a protein-ligand pair, the RMSD is set to -1
"""
def __call__(self, network_input, network_output, loss_input):
raise NotImplementedError()
class ChainLDDT(Metric):
def __call__(self, network_input, network_output, loss_input):
chain_lddt = {"chain_lddt_first": [], "chain_lddt_best": []}
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
for chain_i, chain_type in loss_input["pn_units_to_score"]:
# get tokens in chain_i and chain_j
chain_i_tokens = chain_iid_token_lvl == chain_i
chain_j_tokens = chain_iid_token_lvl == chain_i
# convert the token level to the atom level
chain_i_atoms = chain_i_tokens[tok_idx]
chain_j_atoms = chain_j_tokens[tok_idx]
# compute the intersection of chain_i and chain_j
chain_ij_atoms = torch.einsum(
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
).to(network_output["X_L"].device)
# compute lddt using the pairs_to_score from the intersection
lddt = calc_lddt(
network_output["X_L"],
loss_input["X_gt_L"],
loss_input["crd_mask_L"],
torch.tensor(tok_idx).to(network_output["X_L"].device),
pairs_to_score=chain_ij_atoms,
)
chain_lddt["chain_lddt_first"].append(lddt[0].item())
chain_lddt["chain_lddt_best"].append(lddt.max().item())
return chain_lddt
class LDDTByDiffusionStep(Metric):
def __call__(self, network_input, network_output, loss_input):
lddt_by_step = {"lddt_by_step": []}
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
for i, X_L in enumerate(network_output["X_denoised_L_traj"]):
lddt = calc_lddt(
X_L,
loss_input["X_gt_L"],
loss_input["crd_mask_L"],
torch.tensor(tok_idx).to(network_output["X_L"].device),
)
lddt_by_step["lddt_by_step"].append(lddt)
return lddt_by_step
class SmoothedLDDT(nn.Module):
def __call__(self, network_input, network_output, loss_input):
raise NotImplementedError()

View File

@@ -1,42 +0,0 @@
import hydra
import torch.nn as nn
# class Metric:
# def __call__(self, rf_output, loss_calc_items) -> float:
# raise NotImplementedError("base class")
class MetricManager(nn.Module):
"""
Similar syntax to LossManager, but for metrics
"""
def __init__(self, **metrics):
super().__init__()
self.to_compute = []
for metric_name, metric in metrics.items():
metric_fn = hydra.utils.instantiate(metric)
print(f"Adding metric {metric_name} to the validation metrics")
self.to_compute.append(metric_fn)
def forward(
self,
network_input,
network_output,
loss_input,
):
loss_dict = {}
for loss_fn in self.to_compute:
loss_dict_ = loss_fn(network_input, network_output, loss_input)
loss_dict.update(loss_dict_)
return loss_dict
class Metric:
def __call__(self, network_input, network_output, loss_input) -> float:
raise NotImplementedError("base class")
class AddExampleID(Metric):
def __call__(self, network_input, network_output, loss_input):
return {"example_id": loss_input["example_id"]}

View File

@@ -1,19 +0,0 @@
from typing import Dict
from rf2aa.metrics.predicted_error import PAE, PLDDT
class MetricManager:
def __init__(self, config) -> None:
self.config = config
self.metrics = {metric: metrics_factory[metric] for metric in config.metrics}
def __call__(self, rf_outputs, loss_calc_items) -> Dict:
metrics_dict = {}
for metric_name, metric in self.metrics:
metric_value = metric(rf_outputs, loss_calc_items)
metrics_dict[metric_name] = metric_value
return metrics_dict
metrics_factory = {"mean_pae": PAE(), "mean_plddt": PLDDT()}

View File

@@ -1,359 +0,0 @@
from itertools import combinations
from typing import Any
import numpy as np
import pandas as pd
import torch
import tree
from rf2aa.chemical import ChemicalData as ChemData
from rf2aa.metrics.metric_utils import (
compute_mean_over_subsampled_pairs,
create_chainwise_masks_1d,
create_chainwise_masks_2d,
create_interface_masks_2d,
spread_batch_into_dictionary,
unbin_logits,
)
from rf2aa.metrics.metrics_base import Metric
class WriteAF3Confidence(Metric):
"""
Given some config setups of pae, plddt, and pde, computes aggregate metrics for the model's confidence predictions
TO be used at inference time for users to know how confident their predictions are.
"""
def __init__(self, pae, plddt, pde, **kwargs):
super().__init__()
self.pae = pae
self.plddt = plddt
self.pde = pde
def __call__(self, network_input, network_output, loss_input) -> Any:
plddt_logit_stack = network_output["confidence"]["plddt_logits"]
pae_logits = network_output["confidence"]["pae_logits"]
pde_logits = network_output["confidence"]["pde_logits"]
ch_label = network_output["confidence"]["chain_iid_token_lvl"]
is_real_atom = network_output["confidence"]["is_real_atom"]
# reorder the input tensors to be in (B, n_bins, ...) format for unbinning
plddt = unbin_logits(
plddt_logit_stack.reshape(
-1,
plddt_logit_stack.shape[1],
ChemData().NHEAVY,
self.plddt.n_bins,
).permute(0, 3, 1, 2).float(),
self.plddt.max_value,
self.plddt.n_bins,
)
pae = unbin_logits(
pae_logits.permute(0, 3, 1, 2).float(), self.pae.max_value, self.pae.n_bins
)
pde = unbin_logits(
pde_logits.permute(0, 3, 1, 2).float(), self.pde.max_value, self.pde.n_bins
)
pae_interface = {}
pde_interface = {}
for interface, pairs_to_score in create_interface_masks_2d(
ch_label, device=pae.device
).items():
pae_interface[interface] = spread_batch_into_dictionary(
compute_mean_over_subsampled_pairs(pae, pairs_to_score)
)
pde_interface[interface] = spread_batch_into_dictionary(
compute_mean_over_subsampled_pairs(pde, pairs_to_score)
)
pae_chainwise = {}
pde_chainwise = {}
for chain, pairs_to_score in create_chainwise_masks_2d(
ch_label, device=pae.device
).items():
pae_chainwise[chain] = spread_batch_into_dictionary(
compute_mean_over_subsampled_pairs(pae, pairs_to_score)
)
pde_chainwise[chain] = spread_batch_into_dictionary(
compute_mean_over_subsampled_pairs(pde, pairs_to_score)
)
plddt_chainwise = {}
for chain, residue_atom_indices_to_score in create_chainwise_masks_1d(
ch_label, device=is_real_atom.device
).items():
chain_is_real_atom = (
is_real_atom[..., : ChemData().NHEAVY]
* residue_atom_indices_to_score[:, None]
)
plddt_chainwise[chain] = spread_batch_into_dictionary(
compute_mean_over_subsampled_pairs(plddt, chain_is_real_atom)
)
confidence_data = {
"example_id": loss_input["example_id"],
"mean_plddt": spread_batch_into_dictionary(plddt.mean(dim=(-1, -2))),
"mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
"mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
"chain_wise_mean_plddt": plddt_chainwise,
"chain_wise_mean_pae": pae_chainwise,
"chain_wise_mean_pde": pde_chainwise,
"interface_wise_mean_pae": pae_interface,
"interface_wise_mean_pde": pde_interface,
}
num_batches = plddt.shape[0]
chains = np.unique(ch_label)
num_chains = len(chains)
chain_pairs = list(combinations(chains, 2))
# TODO: refactor to remove for loops
rows = []
for batch_idx in range(num_batches):
for chain_id in range(num_chains):
chain = chains[chain_id]
row = {
"example_id": confidence_data["example_id"],
"chain_chainwise": chain,
"chainwise_plddt": confidence_data["chain_wise_mean_plddt"][chain][
batch_idx
],
"chainwise_pde": confidence_data["chain_wise_mean_pde"][chain][
batch_idx
],
"chainwise_pae": confidence_data["chain_wise_mean_pae"][chain][
batch_idx
],
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
"overall_pde": confidence_data["mean_pde"][batch_idx],
"overall_pae": confidence_data["mean_pae"][batch_idx],
"batch_idx": batch_idx,
}
rows.append(row)
for interface in chain_pairs:
chain_i, chain_j = interface
row = {
"example_id": confidence_data["example_id"],
"chain_i_interface": chain_i,
"chain_j_interface": chain_j,
"pae_interface": confidence_data["interface_wise_mean_pae"][
interface
][batch_idx],
"pde_interface": confidence_data["interface_wise_mean_pde"][
interface
][batch_idx],
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
"overall_pde": confidence_data["mean_pde"][batch_idx],
"overall_pae": confidence_data["mean_pae"][batch_idx],
"batch_idx": batch_idx,
}
rows.append(row)
return pd.DataFrame(rows)
class GetConfidenceIndices(Metric):
def __call__(self, network_input, network_output, loss_input):
# AF3's ranking metrics work like this, but using ptm instead of ipae:
confidence_loss = loss_input["confidence_loss"]
del loss_input["confidence_loss"]
ch_label = loss_input["chain_iid_token_lvl"]
scored_chains, interfaces, interface_chains = select_scored_units(loss_input)
chain_to_all_masks = create_chain_to_all_masks(ch_label, scored_chains)
chain_to_self_masks = create_chain_to_self_masks(ch_label, scored_chains)
interface_masks, lig_chains = create_interface_masks(
ch_label, interfaces, loss_input["is_ligand"]
)
# map everything to gpu
gpu = network_output["confidence"]["plddt_logits"].device
chain_to_all_masks = tree.map_structure(
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_all_masks
)
chain_to_self_masks = tree.map_structure(
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_self_masks
)
interface_masks = tree.map_structure(
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, interface_masks
)
confidence = network_output["confidence"]
plddt_logits = confidence["plddt_logits"]
# Reshape logits to B, K, L, NHEAVY
is_real_atom = network_output["confidence"]["is_real_atom"]
plddt_logits = plddt_logits.reshape(
-1, plddt_logits.shape[1], ChemData().NHEAVY, confidence_loss.plddt.n_bins
).permute(0, 3, 1, 2).float()
# Reshape the pae and pde logits to B, K, L, L
pae_logits = confidence["pae_logits"].permute(0, 3, 1, 2).float()
pde_logits = confidence["pde_logits"].permute(0, 3, 1, 2).float()
pae_logits_unbinned = unbin_logits(
pae_logits, confidence_loss.pae.max_value, confidence_loss.pae.n_bins
)
plddt_logits_unbinned = unbin_logits(
plddt_logits, confidence_loss.plddt.max_value, confidence_loss.plddt.n_bins
)
pde_logits_unbinned = unbin_logits(
pde_logits, confidence_loss.pde.max_value, confidence_loss.pde.n_bins
)
complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
complex_plddt = (
plddt_logits_unbinned * is_real_atom[..., : ChemData().NHEAVY]
).sum(dim=(1, 2)) / is_real_atom[..., : ChemData().NHEAVY].sum()
loss_input["pae_idx"] = torch.argmin(complex_pae)
loss_input["pde_idx"] = torch.argmin(complex_pde)
loss_input["plddt_idx"] = torch.argmax(complex_plddt)
chain_to_self_paes = get_masked_error_per_chain(
scored_chains, chain_to_self_masks, pae_logits_unbinned
)
chain_to_all_paes = get_masked_error_per_chain(
scored_chains, chain_to_all_masks, pae_logits_unbinned
)
interface_chain_paes = get_masked_error_per_chain(
interface_chains, interface_masks, pae_logits_unbinned
)
# average over both interfaces
average_interface_paes = get_average_error_per_interface(
interfaces, lig_chains, interface_chain_paes
)
loss_input["best_chain_to_all_idx"] = get_lowest_error_indices(
chain_to_all_paes
)
loss_input["best_chain_to_self_idx"] = get_lowest_error_indices(
chain_to_self_paes
)
loss_input["best_interface_idx"] = get_lowest_error_indices(
average_interface_paes
)
# for ligands, we don't average the error
loss_input["best_lig_ipae_idx"] = get_lowest_error_ligand_indices(
interface_chain_paes, interfaces, lig_chains
)
return loss_input
def select_scored_units(loss_input):
scored_chains = []
interfaces = []
interface_chains = []
for k in loss_input["interfaces_to_score"]:
interfaces.append(f"{k[0]}-{k[1]}")
interface_chains.append(k[0])
interface_chains.append(k[1])
for k in loss_input["pn_units_to_score"]:
scored_chains.append(k[0])
return scored_chains, interfaces, interface_chains
def create_chain_to_all_masks(ch_label, chains_to_score):
unique_chains = np.unique(ch_label)
I = len(ch_label)
chain_to_all_masks = {}
for chain in unique_chains:
if chain in chains_to_score:
indices = torch.from_numpy((ch_label == chain))
mask = indices.unsqueeze(0) | indices.unsqueeze(1)
# set the diagonal to false
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
chain_to_all_masks[chain] = mask
return chain_to_all_masks
def create_chain_to_self_masks(ch_label, chains_to_score):
unique_chains = np.unique(ch_label)
I = len(ch_label)
chain_to_self_masks = {}
for chain in unique_chains:
if chain in chains_to_score:
indices = torch.from_numpy((ch_label == chain))
mask = indices.unsqueeze(0) & indices.unsqueeze(1)
# set the diagonal to false
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
chain_to_self_masks[chain] = mask
return chain_to_self_masks
def create_interface_masks(ch_label, interfaces, is_ligand):
interface_masks = {}
interface_chains = []
ligand_chains = []
for interface in interfaces:
interface_chains.append(interface.split("-")[0])
interface_chains.append(interface.split("-")[1])
interface_chains = set(interface_chains)
for chain in interface_chains:
chain_indices = torch.from_numpy((ch_label == chain))
to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1)
to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1)
interface_mask = to_all & ~to_self
interface_masks[chain] = interface_mask
if torch.all(is_ligand[chain_indices]):
ligand_chains.append(chain)
return interface_masks, ligand_chains
def get_masked_error_per_chain(chains, masks, unbinned_logits):
error = {}
for chain in chains:
mask = masks[chain]
chain_error = compute_mean_over_subsampled_pairs(unbinned_logits, mask)
error[chain] = chain_error
return error
def get_average_error_per_interface(interfaces, lig_chains, interface_errors):
average_error = {}
for interface in interfaces:
chain_a = interface.split("-")[0]
chain_b = interface.split("-")[1]
average_error[interface] = (
interface_errors[chain_a] + interface_errors[chain_b]
) / 2
return average_error
def get_lowest_error_indices(errors):
lowest_error_indices = {}
for k, v in errors.items():
lowest_error_indices[k] = torch.argmin(v)
return lowest_error_indices
def get_lowest_error_ligand_indices(errors, interfaces, lig_chains):
# ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface
lowest_error_indices = {}
for interface in interfaces:
chain_a = interface.split("-")[0]
chain_b = interface.split("-")[1]
if chain_a in lig_chains or chain_b in lig_chains:
if chain_a in lig_chains:
lig_chain = chain_a
elif chain_b in lig_chains:
lig_chain = chain_b
lowest_error_indices[interface] = torch.argmin(errors[lig_chain])
else:
# assign a random value to avoid key errors downstream; sorting ligand interfaces
# from other types is handles in analysis
lowest_error_indices[interface] = 0
return lowest_error_indices

View File

@@ -1,18 +0,0 @@
import torch.nn as nn
from rf2aa.model.AF3_structure import AtomAttentionDecoder, AtomAttentionEncoder
class NonEquivariantAtomEncoder(nn.Module):
def __init__(self, block_params):
super().__init__()
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
self.model = AtomAttentionEncoder(**block_params)
class NonEquivariantAtomDecoder(nn.Module):
def __init__(self, block_params):
super().__init__()
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
self.model = AtomAttentionDecoder(**block_params)

View File

@@ -1,23 +0,0 @@
import importlib
import sys
from typing import Any
def resolve_import(path: str) -> Any:
"""
Import a module from a string path.
If the module is not already imported, we dynamically import
with `importlib.import_module` and return the module object.
Args:
path (str): The path to the module.
Example usage with Hydra, assuming the module `rf2aa.setup` exists within the PYTHONPATH:
```yaml
# config.yaml
setup: ${resolve_import:rf2aa.setup}
```
"""
namespace, name = path.rsplit(".", maxsplit=1)
importlib.import_module(namespace)
return sys.modules[namespace].__dict__[name]

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,183 +0,0 @@
asd
OpenBabel10022416543D
86 91 0 0 1 0 0 0 0 0999 V2000
0.0000 0.0000 0.0000 Ru 0 0 0 0 0 0 0 0 0 0 0 0
1.9811 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
2.8984 0.0000 0.9976 N 0 3 0 0 0 4 0 0 0 0 0 0
4.2283 -0.4311 0.5225 C 0 0 1 0 0 0 0 0 0 0 0 0
4.1294 -0.1510 -0.9789 C 0 0 2 0 0 0 0 0 0 0 0 0
2.6689 -0.1856 -1.1704 N 0 0 0 0 0 0 0 0 0 0 0 0
4.5104 0.8502 -1.2196 H 0 0 0 0 0 0 0 0 0 0 0 0
4.3148 -1.5116 0.6738 H 0 0 0 0 0 0 0 0 0 0 0 0
2.0926 0.0177 -2.4711 C 0 0 0 0 0 0 0 0 0 0 0 0
2.0551 1.3215 -3.0085 C 0 0 0 0 0 0 0 0 0 0 0 0
1.4395 1.5092 -4.2457 C 0 0 0 0 0 0 0 0 0 0 0 0
0.8688 0.4467 -4.9577 C 0 0 0 0 0 0 0 0 0 0 0 0
0.9728 -0.8382 -4.4242 C 0 0 0 0 0 0 0 0 0 0 0 0
1.5980 -1.0817 -3.1942 C 0 0 0 0 0 0 0 0 0 0 0 0
1.3946 2.5134 -4.6607 H 0 0 0 0 0 0 0 0 0 0 0 0
0.5711 -1.6818 -4.9808 H 0 0 0 0 0 0 0 0 0 0 0 0
1.7839 -2.5051 -2.7336 C 0 0 0 0 0 0 0 0 0 0 0 0
2.6142 -2.9688 -3.2824 H 0 0 0 0 0 0 0 0 0 0 0 0
1.9963 -2.5662 -1.6680 H 0 0 0 0 0 0 0 0 0 0 0 0
0.8879 -3.0998 -2.9345 H 0 0 0 0 0 0 0 0 0 0 0 0
2.6908 2.4918 -2.3015 C 0 0 0 0 0 0 0 0 0 0 0 0
2.2457 3.4308 -2.6402 H 0 0 0 0 0 0 0 0 0 0 0 0
3.7660 2.5403 -2.5179 H 0 0 0 0 0 0 0 0 0 0 0 0
2.5630 2.4326 -1.2198 H 0 0 0 0 0 0 0 0 0 0 0 0
0.1546 0.6923 -6.2641 C 0 0 0 0 0 0 0 0 0 0 0 0
0.0871 -0.2224 -6.8610 H 0 0 0 0 0 0 0 0 0 0 0 0
0.6644 1.4570 -6.8598 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.8688 1.0481 -6.0884 H 0 0 0 0 0 0 0 0 0 0 0 0
2.6482 0.2211 2.3858 C 0 0 0 0 0 0 0 0 0 0 0 0
2.4836 1.5443 2.8329 C 0 0 0 0 0 0 0 0 0 0 0 0
2.2054 1.7521 4.1866 C 0 0 0 0 0 0 0 0 0 0 0 0
2.0596 2.7690 4.5422 H 0 0 0 0 0 0 0 0 0 0 0 0
2.2790 -0.6127 4.6093 C 0 0 0 0 0 0 0 0 0 0 0 0
2.5615 -0.8708 3.2641 C 0 0 0 0 0 0 0 0 0 0 0 0
2.1903 -1.4511 5.2957 H 0 0 0 0 0 0 0 0 0 0 0 0
2.0965 0.6881 5.0892 C 0 0 0 0 0 0 0 0 0 0 0 0
2.7358 -2.2864 2.7757 C 0 0 0 0 0 0 0 0 0 0 0 0
2.3651 -2.9970 3.5199 H 0 0 0 0 0 0 0 0 0 0 0 0
3.7928 -2.5228 2.5978 H 0 0 0 0 0 0 0 0 0 0 0 0
2.1924 -2.4512 1.8403 H 0 0 0 0 0 0 0 0 0 0 0 0
3.4096 2.6945 1.2403 H 0 0 0 0 0 0 0 0 0 0 0 0
2.5155 2.7062 1.8725 C 0 0 0 0 0 0 0 0 0 0 0 0
2.4962 3.6544 2.4168 H 0 0 0 0 0 0 0 0 0 0 0 0
1.6520 2.6776 1.1976 H 0 0 0 0 0 0 0 0 0 0 0 0
1.8132 0.9407 6.5497 C 0 0 0 0 0 0 0 0 0 0 0 0
1.1857 1.8271 6.6864 H 0 0 0 0 0 0 0 0 0 0 0 0
2.7445 1.1113 7.1053 H 0 0 0 0 0 0 0 0 0 0 0 0
1.3095 0.0856 7.0112 H 0 0 0 0 0 0 0 0 0 0 0 0
-0.3034 2.2463 -0.7911 Cl 0 0 0 0 0 0 0 0 0 0 0 0
-0.0311 -2.4021 0.1684 Cl 0 0 0 0 0 0 0 0 0 0 0 0
-0.3847 0.3239 1.7667 C 0 0 0 0 0 3 0 0 0 0 0 0
0.3401 0.5495 2.5448 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.7569 0.2734 2.2121 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.7733 0.0216 1.2582 C 0 0 0 0 0 0 0 0 0 0 0 0
-4.1105 -0.0598 1.6390 C 0 0 0 0 0 0 0 0 0 0 0 0
-4.4360 0.1136 2.9889 C 0 0 0 0 0 0 0 0 0 0 0 0
-3.4526 0.3669 3.9517 C 0 0 0 0 0 0 0 0 0 0 0 0
-2.1198 0.4461 3.5620 C 0 0 0 0 0 0 0 0 0 0 0 0
-1.3353 0.6371 4.2895 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.7295 0.4995 4.9924 H 0 0 0 0 0 0 0 0 0 0 0 0
-5.4785 0.0501 3.2858 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.8949 -0.2514 0.9184 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.2963 -0.1404 -0.0112 O 0 0 0 0 0 0 0 0 0 0 0 0
5.5671 -0.2126 2.2083 H 0 0 0 0 0 0 0 0 0 0 0 0
5.1577 1.3111 1.4179 H 0 0 0 0 0 0 0 0 0 0 0 0
6.6571 0.2637 0.4773 N 0 0 0 0 0 0 0 0 0 0 0 0
5.3944 0.2611 1.2403 C 0 0 0 0 0 0 0 0 0 0 0 0
6.8286 1.1116 -0.0564 H 0 0 0 0 0 0 0 0 0 0 0 0
7.1294 -1.0762 -0.3820 S 0 0 0 0 0 0 0 0 0 0 0 0
6.3097 -0.9086 -1.8483 N 0 0 0 0 0 0 0 0 0 0 0 0
8.5424 -0.9362 -0.7279 O 0 0 0 0 0 0 0 0 0 0 0 0
6.6677 -2.2360 0.3905 O 0 0 0 0 0 0 0 0 0 0 0 0
4.5395 -1.0515 -2.8887 H 0 0 0 0 0 0 0 0 0 0 0 0
4.6166 -2.1893 -1.5280 H 0 0 0 0 0 0 0 0 0 0 0 0
4.8612 -1.1705 -1.8517 C 0 0 0 0 0 0 0 0 0 0 0 0
6.8079 -1.4416 -2.5601 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.2077 -0.2475 -1.1550 C 0 0 0 0 0 0 0 0 0 0 0 0
-3.9743 -0.9748 -0.8710 H 0 0 0 0 0 0 0 0 0 0 0 0
-2.3957 -0.8002 -2.3143 C 0 0 0 0 0 0 0 0 0 0 0 0
-1.6290 -0.0874 -2.6345 H 0 0 0 0 0 0 0 0 0 0 0 0
-1.9096 -1.7391 -2.0383 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.0669 -0.9839 -3.1595 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.8219 1.1126 -1.4561 C 0 0 0 0 0 0 0 0 0 0 0 0
-4.5239 1.0169 -2.2909 H 0 0 0 0 0 0 0 0 0 0 0 0
-4.3670 1.5131 -0.5970 H 0 0 0 0 0 0 0 0 0 0 0 0
-3.0331 1.8184 -1.7297 H 0 0 0 0 0 0 0 0 0 0 0 0
1 50 1 0 0 0 0
1 51 1 0 0 0 0
2 1 1 0 0 0 0
2 3 2 0 0 0 0
3 29 1 0 0 0 0
4 8 1 1 0 0 0
4 3 1 0 0 0 0
4 67 1 0 0 0 0
5 7 1 6 0 0 0
5 4 1 0 0 0 0
6 5 1 0 0 0 0
6 2 1 0 0 0 0
9 6 1 0 0 0 0
10 9 2 0 0 0 0
10 21 1 0 0 0 0
11 10 1 0 0 0 0
12 13 1 0 0 0 0
12 11 2 0 0 0 0
13 14 2 0 0 0 0
14 17 1 0 0 0 0
14 9 1 0 0 0 0
15 11 1 0 0 0 0
16 13 1 0 0 0 0
17 19 1 0 0 0 0
18 17 1 0 0 0 0
20 17 1 0 0 0 0
21 24 1 0 0 0 0
22 21 1 0 0 0 0
23 21 1 0 0 0 0
25 28 1 0 0 0 0
25 12 1 0 0 0 0
26 25 1 0 0 0 0
27 25 1 0 0 0 0
29 30 2 0 0 0 0
29 34 1 0 0 0 0
30 31 1 0 0 0 0
31 32 1 0 0 0 0
31 36 2 0 0 0 0
33 36 1 0 0 0 0
33 35 1 0 0 0 0
34 33 2 0 0 0 0
36 45 1 0 0 0 0
37 34 1 0 0 0 0
37 38 1 0 0 0 0
39 37 1 0 0 0 0
40 37 1 0 0 0 0
41 42 1 0 0 0 0
42 43 1 0 0 0 0
42 30 1 0 0 0 0
44 42 1 0 0 0 0
45 46 1 0 0 0 0
45 48 1 0 0 0 0
45 47 1 0 0 0 0
49 1 1 0 0 0 0
51 53 1 0 0 0 0
51 52 1 0 0 0 0
53 58 2 0 0 0 0
54 55 2 0 0 0 0
54 53 1 0 0 0 0
55 56 1 0 0 0 0
56 61 1 0 0 0 0
56 57 2 0 0 0 0
57 60 1 0 0 0 0
58 57 1 0 0 0 0
58 59 1 0 0 0 0
62 55 1 0 0 0 0
63 54 1 0 0 0 0
63 1 1 0 0 0 0
66 67 1 0 0 0 0
67 65 1 0 0 0 0
67 64 1 0 0 0 0
68 66 1 0 0 0 0
69 72 2 0 0 0 0
69 66 1 0 0 0 0
70 69 1 0 0 0 0
71 69 2 0 0 0 0
73 75 1 0 0 0 0
75 70 1 0 0 0 0
75 74 1 0 0 0 0
75 5 1 0 0 0 0
76 70 1 0 0 0 0
77 78 1 0 0 0 0
77 63 1 0 0 0 0
79 81 1 0 0 0 0
79 77 1 0 0 0 0
80 79 1 0 0 0 0
82 79 1 0 0 0 0
83 77 1 0 0 0 0
83 85 1 0 0 0 0
84 83 1 0 0 0 0
86 83 1 0 0 0 0
M END
$$$$

View File

@@ -1,144 +0,0 @@
import os
from functools import partial
import hydra
import pandas as pd
import torch
import torch.multiprocessing as mp
from omegaconf import OmegaConf
from torch.nn.parallel import DistributedDataParallel as DDP
from rf2aa.chemical import initialize_chemdata
from rf2aa.data.compose_dataset import compose_posebusters
from rf2aa.data.dataloader_adaptor import get_loss_calc_items
from rf2aa.trainer_new import trainer_factory
from rf2aa.util import writepdb
class PoseBustersBenchmark:
def __init__(self, config):
# config file logic for validation, low->high prio:
# 1) use default parameters in config/train/base.yml
# 2) use parameters saved in model
# 3) use specific params in config/inference
default_config_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "config/train/base.yaml"
)
base_config = OmegaConf.load(default_config_path)
tmp_data = torch.load(config.eval_params.checkpoint_path, mmap=True)
if "training_config" in tmp_data:
train_config = tmp_data["training_config"]
self.config = OmegaConf.merge(base_config, train_config, config)
else:
self.config = OmegaConf.merge(base_config, config)
tmp_data = None
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
if self.config.experiment.output_dir is not None:
self.output_dir = self.config.experiment.output_dir
else:
self.output_dir = "output/"
if not os.path.exists(self.output_dir):
os.makedirs(self.output_dir)
self.trainer = trainer_factory[self.config.experiment.trainer](
config=self.config
)
def construct_dataset(self, rank, world_size):
# fd initialize chemical data based on input arguments
# this needs to be initialized first
init = partial(initialize_chemdata, self.config)
init()
return compose_posebusters(init, self.config.loader_params, rank, world_size)
def launch_distributed_eval(self):
world_size = torch.cuda.device_count()
if "MASTER_ADDR" not in os.environ:
os.environ["MASTER_ADDR"] = (
"127.0.0.1" # multinode requires this set in submit script
)
if "MASTER_PORT" not in os.environ:
os.environ["MASTER_PORT"] = "%d" % self.config.ddp_params.port
world_size = torch.cuda.device_count()
if world_size == 0:
print("Error! No GPUs found!")
elif world_size == 1:
# No need for multiple processes with 1 GPU
self.evaluate_model(0, world_size)
else:
mp.spawn(
self.evaluate_model, args=(world_size,), nprocs=world_size, join=True
)
def evaluate_model(self, rank, world_size):
gpu = self.trainer.init_process_group(rank, world_size)
benchmark_loader = self.construct_dataset(rank, world_size)
# move global information to device
self.trainer.move_constants_to_device(gpu)
self.trainer.construct_model(device=gpu)
self.trainer.model = DDP(
self.trainer.model,
device_ids=[gpu],
find_unused_parameters=False,
broadcast_buffers=False,
)
self.trainer.load_checkpoint(rank)
self.trainer.load_model()
self.trainer.model.eval()
records = []
for inputs in benchmark_loader:
item = inputs[-1]
with torch.no_grad():
loss, loss_dict, outputs = self.trainer.train_step(
inputs,
self.config.loader_params.maxcycle,
nograds=True,
return_outputs=True,
)
loss_dict["CHAINID"] = item["CHAINID"][0]
for k, v in loss_dict.items():
if torch.is_tensor(v):
loss_dict[k] = v.item()
records.append(loss_dict)
df = pd.DataFrame(records)
df.to_csv(
f"{self.output_dir}/{self.config.experiment.name}_{rank}_posebusters.csv"
)
torch.cuda.empty_cache()
true_crds = inputs[5]
seq, _, idx_pdb, bond_feats, _, _ = get_loss_calc_items(inputs, device=gpu)
pred_crds, alphas, pred_lddts = outputs[5], outputs[6], outputs[8]
_, pred_allatom = self.trainer.xyz_converter.compute_all_atom(
seq[:, 0], pred_crds[-1], alphas[-1]
)
writepdb(
f"{self.output_dir}/{item['CHAINID'][0]}_nat.pdb",
true_crds[:, 0],
seq[:, 0].long(),
bond_feats=bond_feats,
)
writepdb(
f"{self.output_dir}/{item['CHAINID'][0]}_pred.pdb",
pred_allatom[0],
seq[:, 0].long(),
bond_feats=bond_feats,
)
@hydra.main(version_base=None, config_path="config/inference")
def main(config):
benchmarker = PoseBustersBenchmark(config=config)
benchmarker.launch_distributed_eval()
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,33 @@
#!/bin/bash
# This script builds a datahub apptainer container.
set -e # Exit on error
echo "Running from $PWD"
# Check if apptainer/singularity is available
APPTAINER_BINARY=$(command -v apptainer || command -v singularity)
if [ -z "$APPTAINER_BINARY" ]; then
echo "Error: Neither apptainer nor singularity found in PATH"
exit 1
fi
echo "Using apptainer at: $APPTAINER_BINARY"
# Generate the image name with today's date
DATE=$(date +%Y-%m-%d)
IMAGE_NAME="modelhub_${DATE}.sif"
echo "Building apptainer image: $IMAGE_NAME"
# Build Phase
echo
echo "=== Starting Build Phase ==="
echo "Running: $APPTAINER_BINARY build --notest '$IMAGE_NAME' base_apptainer.spec"
echo "----------------------------------------"
$APPTAINER_BINARY build \
--nv \
--notest \
"$IMAGE_NAME" base_apptainer.spec
echo "----------------------------------------"
echo
echo "=== Build Complete ==="
echo "Container is available at: $PWD/$IMAGE_NAME"

50
scripts/freeze_apptainer.sh Executable file
View File

@@ -0,0 +1,50 @@
#!/bin/bash
# This script freezes CIFUtils, Datahub, and Modelhub versions within an existing apptainer.
set -e # Exit on error
echo "Running from $PWD"
SCRIPT_PATH=$(realpath $0)
SCRIPT_DIR=$(dirname $SCRIPT_PATH)
# Check if apptainer/singularity is available
APPTAINER_BINARY=$(command -v apptainer || command -v singularity)
if [ -z "$APPTAINER_BINARY" ]; then
echo "Error: Neither apptainer nor singularity found in PATH"
exit 1
fi
echo "Using apptainer at: $APPTAINER_BINARY"
# This is the default apptainer that you can build from 'make apptainer'
echo "... looking for a local apptainer image at '$SCRIPT_DIR/modelhub.sif'"
SIF_PATH="$SCRIPT_DIR/modelhub.sif"
SIF_PATH=$(readlink -f "$SCRIPT_DIR/modelhub.sif" )
echo "Base SIF path to build from: $SIF_PATH"
# Generate the image name with today's date
DATE=$(date +%Y-%m-%d)
IMAGE_NAME="frozen_modelhub_${DATE}.sif"
echo "Building apptainer from image with frozen dependencies: $IMAGE_NAME"
# Check if INSTALL_PROJECT is set to true and set the image name accordingly
if ${INSTALL_PROJECT}; then
echo "Modelhub WILL be installed in the apptainer! Ensure that this is intentional."
IMAGE_NAME="frozen_modelhub_datahub_cifutils_${DATE}.sif"
else
IMAGE_NAME="frozen_datahub_cifutils_${DATE}.sif"
fi
# Build Phase
echo
echo "=== Starting Build Phase ==="
echo "Running: $APPTAINER_BINARY build --notest '$IMAGE_NAME' freeze_apptainer.spec"
echo "----------------------------------------"
INSTALL_PROJECT=$INSTALL_PROJECT $APPTAINER_BINARY build \
--nv \
--notest \
"$IMAGE_NAME" freeze_apptainer.spec
echo "----------------------------------------"
echo
echo "=== Build Complete ==="
echo "Container is available at: $PWD/$IMAGE_NAME"

View File

@@ -0,0 +1,7 @@
This directory contains scripts that are not to be run directly by the user.
They are [SHEBANG scripts](https://en.wikipedia.org/wiki/Shebang_(Unix)) that are used to run the appropriate apptainer container.
For example, the script `modelhub_exec.sh` is used to run the modelhub apptainer container with the latest apptainer image
stored locally or at the IPD.
The shebang lines (`#!/bin/bash` ...) at the top of entry point scripts like `train.py` redirect the system to here to find the correct apptainer container.

View File

@@ -0,0 +1 @@
/projects/ml/modelhub/apptainer/modelhub_2025-03-19.sif

151
scripts/shebang/modelhub_exec.sh Executable file
View File

@@ -0,0 +1,151 @@
#!/usr/bin/bash
###################
# You can add the path to this file as the shebang line in your python script.
# Then by default, the python script will be executed with the python interpreter
# in the SIF_PATH container. Here, we launch the container with nvidia gpu and slurm support.
#
# Example shebang: #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/scripts/shebang/modelhub_exec.sh" "$0" "$@"'
###################
# Let the user know this script is setting things up behind the scene
SCRIPT_PATH=$(realpath $0)
SCRIPT_DIR=$(dirname $SCRIPT_PATH)
echo '################## Start shebang info ##################'
echo "The file $SCRIPT_PATH is being run as a shebang executable.
It will...
1. Add the 'modelhub' and 'src/modelhub' repo directories to your PYTHONPATH.
2. Run your python script from the right container, which contains all dependencies.
3. Launch the container with slurm and nvidia gpu support."
# Extract the path to the Python script from the arguments
PYTHON_SCRIPT=$(realpath "$1")
shift
# Find repository root by looking for .project-root file
find_repo_root() {
local current_dir="$1"
while [ "$current_dir" != "/" ]; do
if [ -f "$current_dir/.project-root" ]; then
echo "$current_dir"
return 0
fi
current_dir="$(dirname "$current_dir")"
done
return 1
}
echo
echo "Searching for repository root directory..."
REPO_ROOT=$(find_repo_root "$(dirname "$PYTHON_SCRIPT")")
if [ -z "$REPO_ROOT" ]; then
echo "Error: Could not find .project-root file in any parent directory"
exit 1
else
echo "... found repository root at '$REPO_ROOT'"
fi
# Function to add a directory to PYTHONPATH if it's not already included
add_to_pythonpath() {
local dir_path="$1"
if [[ ":$PYTHONPATH:" != *":$dir_path:"* ]]; then
export PYTHONPATH="$dir_path:$PYTHONPATH"
echo "Added '$dir_path' to PYTHONPATH."
else
echo "'$dir_path' is already in PYTHONPATH."
fi
}
# Add the src directory to PYTHONPATH if not already present
echo
echo "Checking and adding 'src' directory to PYTHONPATH..."
SRC_PATH="$REPO_ROOT/src"
add_to_pythonpath "$SRC_PATH"
# Add modelhub to PYTHONPATH if not already present
echo
echo "Checking and adding 'modelhub' directory to PYTHONPATH..."
MODELHUB_PATH="$SRC_PATH/modelhub"
add_to_pythonpath "$MODELHUB_PATH"
# Load the .env file environment variables from the repo root
echo
echo "Attempting to load environment variables from .env file:"
if [ -f "$REPO_ROOT/.env" ]; then
echo "... loading environment variables from '$REPO_ROOT/.env'"
export $(cat "$REPO_ROOT/.env" | grep -v '#' | xargs)
else
echo " Warning: No .env file found at repository root ($REPO_ROOT)"
fi
# check if we are at the IPD
IPD_FILE="/software/containers/versions/rf_diffusion_aa/ipd.txt"
SIF_PATH=""
echo
echo "Fetching the appropriate apptainer image..."
if [ -z "$APPTAINER_NAME" ]; then
if [ -n "$PROJECT_PATH" ]; then
# Attempt to find any .sif file in the PROJECT_PATH/scripts/shebang directory
SIF_DIR="$PROJECT_PATH/scripts/shebang"
SIF_FILE=$(find "$SIF_DIR" -maxdepth 1 -name "*.sif" -print -quit)
if [ -n "$SIF_FILE" ]; then
SIF_PATH="$SIF_FILE"
fi
fi
# If SIF_PATH is still empty, use the default SIF
if [ -z "$SIF_PATH" ]; then
SIF_NAME="modelhub.sif"
SIF_PATH="$SCRIPT_DIR/$SIF_NAME"
fi
echo "... looking for a local apptainer image at '$SIF_PATH'"
# Check if the SIF file exists
if [ ! -f "$SIF_PATH" ]; then
echo "... apptainer not found. To run with your own apptainer image, you can build it with 'make apptainer' and place it here: '$SIF_PATH'"
echo "Attempting to run $PYTHON_SCRIPT with $(which python)"
fi
else
echo "Already running inside container $APPTAINER_NAME. Executing $PYTHON_SCRIPT with $(which python) in the existing container."
fi
# Function to print debug=mode warning
print_debug_warning() {
echo
echo "###############################################################################"
echo "# #"
echo "# ⚠️ WARNING ⚠️ #"
echo "# RUNNING WITH DEBUGPY ON PORT $DEBUG_PORT #"
echo "# DON'T FORGET TO ATTACH A DEBUGGER #"
echo "# #"
echo "###############################################################################"
echo
}
if [ -n "$DEBUG_PORT" ]; then
print_debug_warning
python_cmd="python -m debugpy --listen $DEBUG_PORT --wait-for-client"
else
python_cmd="python"
echo
fi
if [ ! -z $SIF_PATH ]; then
echo "Running $PYTHON_SCRIPT with apptainer: $SIF_PATH."
echo '################## End shebang info ####################'
echo
/usr/bin/apptainer exec --nv --slurm \
--bind "$REPO_ROOT:$REPO_ROOT" \
--env PYTHONPATH="\$PYTHONPATH:$PYTHONPATH" \
$SIF_PATH $python_cmd "$PYTHON_SCRIPT" "$@"
else
echo "Running $PYTHON_SCRIPT with python: $(which python)"
echo '################## End shebang info ####################'
echo
$python_cmd "$PYTHON_SCRIPT" "$@"
fi

78
scripts/slurm/launch.sh Normal file
View File

@@ -0,0 +1,78 @@
#!/bin/bash
#SBATCH -p gpu-train
#SBATCH --nodes 2
#SBATCH --gres=gpu:l40:8
#SBATCH --ntasks-per-node 8
#SBATCH --mem=512g
#SBATCH -t 7-00:00:00
#SBATCH -J af3-old-msas-pdb-only-experimental
#SBATCH -o slurm_logs/%x_%j.out
#SBATCH -e slurm_logs/%x_%j.err
#SBATCH --no-kill=off
### Excluded Nodes:
### To call this script run: `sbatch launch.sh` from this directory
### For reference, see the Lightning Fabric + SLURM guide: https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html
# (In case we're still running in debug mode)
unset DEBUG_PORT
unset PROJECT_PATH
# (SLURM setup, ensuring we have a unique port per job, and setting the master address to Rank 0)
export MASTER_PORT=$((1024 + RANDOM % 64512))
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
### Set custom paths
# WARNING: You will need to update these paths to match your local setup
# ... cifutils and datahub
export PYTHONPATH="/home/ncorley/projects/datahub/src:/home/ncorley/projects/cifutils/src:/home/ncorley/projects/modelhub/src"
# ... project path (if not using root src/modelhub)
export PROJECT_PATH="/home/ncorley/projects/modelhub/projects/rfscore"
# ... cache directory for Triton kernels (e.g., DeepSpeed4Science fused kernels)
export TRITON_CACHE_DIR="/home/ncorley/.triton" # Change this to a directory with write permissions
### Environment flags
# Debugging flags (optional)
export NCCL_DEBUG=INFO # NCCL internal debugging
export PYTHONFAULTHANDLER=1 # Catches Python core dumps (e.g., segmentation faults)
# Expand CUDA memory
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
# Turn off NVLink (L40 do not have NVLink)
export NCCL_P2P_DISABLE=1
# OPENMP and OPENBLAS optimizations
# https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#utilize-openmp
# NOTE: Must be optimized per-system; see: https://github.com/pytorch/pytorch/blob/65e6194aeb3269a182cfe2c05c122159da12770f/torch/distributed/run.py#L596-L608
export OMP_NUM_THREADS=4
export OPENBLAS_NUM_THREADS=4
#######################################################################################################
### WARNING: The command below is just an example. It will fail if you don't update the experiment ###
### config in the command below. Please adapt according to your target experiment ###
#######################################################################################################
### Set the effective batch size
EFFECTIVE_BATCH_SIZE=16
### Compose the training script
DEVICES_PER_NODE=${SLURM_NTASKS_PER_NODE:-8} # Default to 8 if not set
echo "Running on $SLURM_NNODES nodes with $DEVICES_PER_NODE tasks per node"
### Calculate grad_accum_steps
GRAD_ACCUM_STEPS=$((EFFECTIVE_BATCH_SIZE / (DEVICES_PER_NODE * SLURM_NNODES)))
echo "Grad Accumulation Steps: $GRAD_ACCUM_STEPS"
command="srun --kill-on-bad-exit ../../src/modelhub/train.py \
experiment=$SLURM_JOB_NAME \
++trainer.devices_per_node=$DEVICES_PER_NODE \
++trainer.num_nodes=$SLURM_NNODES \
++trainer.grad_accum_steps=$GRAD_ACCUM_STEPS"
echo -e "command\t$command"
# Let 'er rip
$command

Some files were not shown because too many files have changed in this diff Show More