mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
refactor: new modelhub (#109)
* Initial commit of chiral changes Initial checkin of chiral feature code Add chiral metric * Update the way chiral features are incorporated into the model Move initialization to new func use default pytorch reset parameters fix initialization for chirals config rename argument of confidence head fix initialization for chirals * refactor: src nest, rename rf2aa to modelhub * refactor: initial commit without projects * Initial commit of chiral changes * Initial checkin of chiral feature code * Add chiral metric * Remove option for double residual connection. Add kq_norm oiptions to base (20250125) config. * Restoring flag * config * rename argument of confidence head * Update the way chiral features are incorporated into the model * config * rename argument of confidence head * Update the way chiral features are incorporated into the model * Initial commit of chiral changes Initial checkin of chiral feature code Add chiral metric * Update the way chiral features are incorporated into the model Move initialization to new func use default pytorch reset parameters fix initialization for chirals config rename argument of confidence head fix initialization for chirals * refactor: new modelhub --------- Co-authored-by: fdimaio <dimaio@uw.edu> Co-authored-by: HaotianZhangAI4Science <haotianzhang@zju.edu.cn>
This commit is contained in:
6
.env
6
.env
@@ -1,5 +1,9 @@
|
|||||||
|
# +--------+ Cifutils +--------+
|
||||||
CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd
|
CCD_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_11_ccd
|
||||||
|
|
||||||
PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb
|
PDB_MIRROR_PATH=/projects/ml/frozen_pdb_copies/2024_12_01_pdb
|
||||||
|
|
||||||
|
# +--------+ Datahub +--------+
|
||||||
|
# (Distillation)
|
||||||
AF2FB_PATH=/squash/af2_distillation_facebook
|
AF2FB_PATH=/squash/af2_distillation_facebook
|
||||||
|
# (PadDNA TRansform)
|
||||||
|
X3DNA=/projects/ml/prot_dna/x3dna-v2.4
|
||||||
|
|||||||
237
.gitignore
vendored
237
.gitignore
vendored
@@ -1,23 +1,220 @@
|
|||||||
valid_remapped
|
# Base .gitignore from https://github.com/github/gitignore/blob/main/Python.gitignore
|
||||||
lig_test
|
|
||||||
dataset.pkl
|
# Byte-compiled / optimized / DLL files
|
||||||
run_digs.sh
|
|
||||||
*.pdb
|
|
||||||
.vscode
|
|
||||||
slurm_logs/
|
|
||||||
**/output/
|
|
||||||
**/outputs/
|
|
||||||
*/notebooks/
|
|
||||||
*/models/
|
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*/run_scripts/
|
*/__pycache__/
|
||||||
unit_tests/
|
*.py[cod]
|
||||||
ruff.toml
|
*$py.class
|
||||||
*/scratch/
|
|
||||||
*/wandb/
|
# C extensions
|
||||||
rf2aa/dataset_20240318.pkl
|
*.so
|
||||||
*.csv
|
|
||||||
|
# Distribution / packaging
|
||||||
|
.Python
|
||||||
|
build/
|
||||||
|
develop-eggs/
|
||||||
|
dist/
|
||||||
|
downloads/
|
||||||
|
eggs/
|
||||||
|
.eggs/
|
||||||
|
lib/
|
||||||
|
lib64/
|
||||||
|
parts/
|
||||||
|
sdist/
|
||||||
|
var/
|
||||||
|
wheels/
|
||||||
|
share/python-wheels/
|
||||||
|
*.egg-info/
|
||||||
|
.installed.cfg
|
||||||
|
*.egg
|
||||||
|
MANIFEST
|
||||||
|
|
||||||
|
# PyInstaller
|
||||||
|
# Usually these files are written by a python script from a template
|
||||||
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
|
*.manifest
|
||||||
|
|
||||||
|
# Installer logs
|
||||||
|
pip-log.txt
|
||||||
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
|
# Unit test / coverage reports
|
||||||
|
htmlcov/
|
||||||
|
.tox/
|
||||||
|
.nox/
|
||||||
|
.coverage
|
||||||
|
.coverage.*
|
||||||
|
.cache
|
||||||
|
nosetests.xml
|
||||||
|
coverage.xml
|
||||||
|
*.cover
|
||||||
|
*.py,cover
|
||||||
|
.hypothesis/
|
||||||
|
.pytest_cache/
|
||||||
|
cover/
|
||||||
|
|
||||||
|
# Translations
|
||||||
|
*.mo
|
||||||
|
*.pot
|
||||||
|
|
||||||
|
# Django stuff:
|
||||||
|
*.log
|
||||||
|
local_settings.py
|
||||||
|
db.sqlite3
|
||||||
|
db.sqlite3-journal
|
||||||
|
|
||||||
|
# Flask stuff:
|
||||||
|
instance/
|
||||||
|
.webassets-cache
|
||||||
|
|
||||||
|
# Scrapy stuff:
|
||||||
|
.scrapy
|
||||||
|
|
||||||
|
# Sphinx documentation
|
||||||
|
docs/_build/
|
||||||
|
|
||||||
|
# PyBuilder
|
||||||
|
.pybuilder/
|
||||||
|
target/
|
||||||
|
|
||||||
|
# Jupyter Notebooks (unless explicitly not ignored)
|
||||||
|
.ipynb_checkpoints
|
||||||
|
**/.ipynb
|
||||||
|
|
||||||
|
# IPython
|
||||||
|
profile_default/
|
||||||
|
ipython_config.py
|
||||||
|
|
||||||
|
# pyenv
|
||||||
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
|
# .python-version
|
||||||
|
|
||||||
|
# pipenv
|
||||||
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
|
# install all needed dependencies.
|
||||||
|
#Pipfile.lock
|
||||||
|
|
||||||
|
# poetry
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
|
# commonly ignored for libraries.
|
||||||
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
|
#poetry.lock
|
||||||
|
|
||||||
|
# pdm
|
||||||
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
|
#pdm.lock
|
||||||
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
|
# in version control.
|
||||||
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
|
.pdm.toml
|
||||||
|
.pdm-python
|
||||||
|
.pdm-build/
|
||||||
|
|
||||||
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
|
__pypackages__/
|
||||||
|
|
||||||
|
# Celery stuff
|
||||||
|
celerybeat-schedule
|
||||||
|
celerybeat.pid
|
||||||
|
|
||||||
|
# SageMath parsed files
|
||||||
|
*.sage.py
|
||||||
|
|
||||||
|
# Environments
|
||||||
|
.env
|
||||||
|
.venv
|
||||||
|
env/
|
||||||
|
venv/
|
||||||
|
ENV/
|
||||||
|
env.bak/
|
||||||
|
venv.bak/
|
||||||
|
|
||||||
|
# Spyder project settings
|
||||||
|
.spyderproject
|
||||||
|
.spyproject
|
||||||
|
|
||||||
|
# Rope project settings
|
||||||
|
.ropeproject
|
||||||
|
|
||||||
|
# mkdocs documentation
|
||||||
|
/site
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
||||||
|
# Pyre type checker
|
||||||
|
.pyre/
|
||||||
|
|
||||||
|
# pytype static type analyzer
|
||||||
|
.pytype/
|
||||||
|
|
||||||
|
# Cython debug symbols
|
||||||
|
cython_debug/
|
||||||
|
|
||||||
|
# PyCharm
|
||||||
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
|
#.idea/
|
||||||
|
|
||||||
|
# VS Code
|
||||||
|
.vscode
|
||||||
|
.history/
|
||||||
|
|
||||||
|
# Slurm
|
||||||
|
*/slurm_logs/
|
||||||
*.err
|
*.err
|
||||||
*.log
|
*.log
|
||||||
*.json
|
|
||||||
data/run.sh
|
# Ruff
|
||||||
|
ruff.toml
|
||||||
|
.ruff_cache
|
||||||
|
|
||||||
|
# Development
|
||||||
|
dev.py
|
||||||
|
|
||||||
|
# Pytest
|
||||||
|
*.benchmarks/
|
||||||
|
|
||||||
|
# Images
|
||||||
|
*.png
|
||||||
|
*.pdf
|
||||||
|
*.svg
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.gif
|
||||||
|
*.bmp
|
||||||
|
*.tiff
|
||||||
|
|
||||||
|
# Versioning
|
||||||
|
**/version.py
|
||||||
|
|
||||||
|
# W&B
|
||||||
|
wandb/
|
||||||
|
|
||||||
|
# Hydra
|
||||||
|
.hydra/
|
||||||
|
|
||||||
|
# Outputs
|
||||||
|
**/outputs/
|
||||||
|
|
||||||
|
# Logs
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Other
|
||||||
|
*.sif
|
||||||
|
*.out
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
**/notebooks/
|
||||||
|
**/models/
|
||||||
|
**/run_scripts/
|
||||||
|
**/scratch/
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +0,0 @@
|
|||||||
# This file is a template, and might need editing before it works on your project.
|
|
||||||
# This is a sample GitLab CI/CD configuration file that should run without any modifications.
|
|
||||||
# It demonstrates a basic 3 stage CI/CD pipeline. Instead of real tests or scripts,
|
|
||||||
# it uses echo commands to simulate the pipeline execution.
|
|
||||||
#
|
|
||||||
# A pipeline is composed of independent jobs that run scripts, grouped into stages.
|
|
||||||
# Stages run in sequential order, but jobs within stages run in parallel.
|
|
||||||
#
|
|
||||||
# For more information, see: https://docs.gitlab.com/ee/ci/yaml/index.html#stages
|
|
||||||
#
|
|
||||||
# You can copy and paste this template into a new `.gitlab-ci.yml` file.
|
|
||||||
# You should not add this template to an existing `.gitlab-ci.yml` file by using the `include:` keyword.
|
|
||||||
#
|
|
||||||
# To contribute improvements to CI/CD templates, please follow the Development guide at:
|
|
||||||
# https://docs.gitlab.com/ee/development/cicd/templates.html
|
|
||||||
# This specific template is located at:
|
|
||||||
# https://gitlab.com/gitlab-org/gitlab/-/blob/master/lib/gitlab/ci/templates/Getting-Started.gitlab-ci.yml
|
|
||||||
|
|
||||||
stages: # List of stages for jobs, and their order of execution
|
|
||||||
- test
|
|
||||||
|
|
||||||
unit-test-job: # This job runs in the test stage.
|
|
||||||
stage: test # It only starts when the job in the build stage completes successfully.
|
|
||||||
rules:
|
|
||||||
- if: '$CI_PIPELINE_SOURCE == "merge_request_event"'
|
|
||||||
script:
|
|
||||||
- echo "Running unit tests"
|
|
||||||
- git submodule update --init
|
|
||||||
- cd rf2aa
|
|
||||||
- srun -p gpu --gres=gpu:a4000:1 --cpus-per-task=4 --mem=32G bash ../ci/run_tests.sh
|
|
||||||
|
|
||||||
65
Makefile
65
Makefile
@@ -24,25 +24,72 @@ format:
|
|||||||
ruff format .
|
ruff format .
|
||||||
ruff check --fix .
|
ruff check --fix .
|
||||||
|
|
||||||
## Create a new conda environment
|
_github_token_error:
|
||||||
|
@echo "==============================================================================="; \
|
||||||
|
echo "Error: Environment variables GITHUB_USER and GITHUB_TOKEN must be set."; \
|
||||||
|
echo ""; \
|
||||||
|
echo "You need to set the environment variables GITHUB_USER and GITHUB_TOKEN."; \
|
||||||
|
echo "You can create a personal access token on GitHub at:"; \
|
||||||
|
echo " https://github.com/settings/tokens"; \
|
||||||
|
echo ""; \
|
||||||
|
echo "For more info see: https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens#creating-a-personal-access-token-classic"; \
|
||||||
|
echo ""; \
|
||||||
|
echo "To expose these variables, you can use:"; \
|
||||||
|
echo "export GITHUB_USER=<github-username>"; \
|
||||||
|
echo "export GITHUB_TOKEN=<github-token>"; \
|
||||||
|
echo ""; \
|
||||||
|
echo "It is recommended that you set these tokens in your .bashrc or .zshrc file for future use."; \
|
||||||
|
echo "===============================================================================";
|
||||||
|
exit 1;
|
||||||
|
|
||||||
|
_check_conda:
|
||||||
|
@echo "... checking if conda/mamba is installed"
|
||||||
|
@command -v $(CONDA_BINARY) >/dev/null 2>&1 || { \
|
||||||
|
echo "Error: Conda/mamba is not installed or not found in PATH" >&2; \
|
||||||
|
exit 1; \
|
||||||
|
}
|
||||||
|
@echo "... found conda executable: $(CONDA_BINARY)"
|
||||||
|
|
||||||
|
_check_tokens:
|
||||||
|
@echo "... checking if GITHUB_USER and GITHUB_TOKEN are set"
|
||||||
|
@if [ -z "$(GITHUB_USER)" ] || [ -z "$(GITHUB_TOKEN)" ]; then \
|
||||||
|
$(MAKE) _github_token_error; \
|
||||||
|
fi
|
||||||
|
@echo "... found GITHUB_USER ($(GITHUB_USER)) and GITHUB_TOKEN."
|
||||||
|
|
||||||
|
## Create a new conda environment and install modelhub
|
||||||
env:
|
env:
|
||||||
$(CONDA_BINARY) env create -n modelhub --file environment.yaml
|
@echo "Creating modelhub conda environment: modelhub"
|
||||||
conda init
|
|
||||||
conda activate modelhub
|
@$(MAKE) --no-print-directory _check_tokens
|
||||||
pip install -e ".[dev]"
|
@$(MAKE) --no-print-directory _check_conda
|
||||||
|
|
||||||
|
@$(CONDA_BINARY) env create -n modelhub --file environment.yaml
|
||||||
|
@conda init
|
||||||
|
@conda activate modelhub
|
||||||
|
@pip install -e ".[dev]"
|
||||||
|
@python -m biotite.setup_ccd
|
||||||
|
|
||||||
|
|
||||||
## Install modelhub locally into the current environment
|
## Install modelhub locally into the current environment
|
||||||
install:
|
install:
|
||||||
# Install the conda requirements in the current activated environment
|
# Install the conda requirements in the current activated environment
|
||||||
$(CONDA_BINARY) env update --file environment.yaml
|
$(CONDA_BINARY) env update --file environment.yaml
|
||||||
# Install the pip requirements in the current activated environment
|
# Install the pip requirements in the current activated environment
|
||||||
pip install -e ".[dev]"
|
@pip install -e ".[dev]"
|
||||||
|
@python -m biotite.setup_ccd
|
||||||
|
|
||||||
## Build the apptainer image
|
## Build the apptainer image
|
||||||
apptainer:
|
base_apptainer:
|
||||||
$(eval DATE := $(shell date +%Y-%m-%d))
|
$(eval DATE := $(shell date +%Y-%m-%d))
|
||||||
$(eval IMAGE := modelhub_$(DATE).sif)
|
bash ./scripts/build_base_apptainer.sh
|
||||||
bash ./scripts/build_apptainer.sh
|
|
||||||
|
# Set INSTALL_PROJECT to true to install modelhub within the apptainer (much slower)
|
||||||
|
# e.g., `make INSTALL_PROJECT=true freeze_apptainer` or `make freeze_apptainer INSTALL_PROJECT=true`
|
||||||
|
INSTALL_PROJECT ?= false
|
||||||
|
freeze_apptainer:
|
||||||
|
$(eval DATE := $(shell date +%Y-%m-%d))
|
||||||
|
bash ./scripts/freeze_apptainer.sh $(INSTALL_PROJECT)
|
||||||
|
|
||||||
## Run pytest and generate coverage report
|
## Run pytest and generate coverage report
|
||||||
test:
|
test:
|
||||||
|
|||||||
219
README.md
219
README.md
@@ -1,51 +1,210 @@
|
|||||||
RoseTTAFold All-Atom
|
# Modelhub
|
||||||
--------------------
|
|
||||||
|
|
||||||
This repository contains the code to training and running inference on
|
- [Modelhub](#modelhub)
|
||||||
RoseTTAFold All-Atom (RFAA), a neural network that can predict the structures
|
- [Background](#background)
|
||||||
of proteins in complex with DNA, RNA, and/or small molecule ligands.
|
- [Division of code between Modelhub, Datahub, and Cifutils](#division-of-code-between-modelhub-datahub-and-cifutils)
|
||||||
|
- [Cifutils](#cifutils)
|
||||||
|
- [Datahub](#datahub)
|
||||||
|
- [Training, Validation, and Inference](#training-validation-and-inference)
|
||||||
|
- [Training and Validation](#training-and-validation)
|
||||||
|
- [Inference](#inference)
|
||||||
|
- [Setup](#setup)
|
||||||
|
- [Apptainers](#apptainers)
|
||||||
|
- [Base Apptainer](#base-apptainer)
|
||||||
|
- [Frozen Apptainer](#frozen-apptainer)
|
||||||
|
- [Shebang](#shebang)
|
||||||
|
- [General Use](#general-use)
|
||||||
|
- [Debugging](#debugging)
|
||||||
|
|
||||||
`rf2aa/` contains the model and training code.
|
## Background
|
||||||
`data/` contains code used to curate the training data from the PDB.
|
|
||||||
|
|
||||||
|
This repository constitutes the base for deep-learning method development at the Institute for Protein Design.
|
||||||
|
|
||||||
## Contributing to RFAA
|
It is symbiotic with two other Institute for Protein Design repositories:
|
||||||
|
- [cifutils](https://github.com/baker-laboratory/cifutils), which manages input parsing and data cleaning
|
||||||
|
- [datahub](https://github.com/baker-laboratory/datahub), which manages input featurization and holds our composable `Transform` components
|
||||||
|
|
||||||
### Set Up
|
Within this ontology, `modelhub` contains the *architectures*, *training* code, and *inference* endpoints.
|
||||||
|
|
||||||
|
## Division of code between Modelhub, Datahub, and Cifutils
|
||||||
|
|
||||||
|
Across our codebases, we balance the need to develop quickly with the need to write code that we can continue to maintain and that is easy to understand. We below lay out some thoughts on what code should live where.
|
||||||
|
|
||||||
|
We enforce a strict dependency flow of `modelhub` -> (depends on) `datahub` -> (depends on) `cifutils`; it would be a circular anti-pattern to thus import any `datahub` or `modelhub` functions from within `cifutils`.
|
||||||
|
|
||||||
|
### Cifutils
|
||||||
|
|
||||||
|
[cifutils](https://github.com/baker-laboratory/cifutils) is the most static of our three codebases. Basic parsing functionality, RDKit and other molecular toolkit utilities, and `AtomArray` quality-of-life tools live in this repository.
|
||||||
|
|
||||||
|
Examples of `cifutils` functions are:
|
||||||
|
- All functions related to **parsing structural files from source**; e.g., keeping/removing hydrogens, resolving occupancy, etc.
|
||||||
|
- Utility functions to manipulate `AtomArrays`, the core API of the `biotite` library, upon which we heavily rely
|
||||||
|
- Utility functions for common bioinformatics software, such as `RDKit`, that interface with `AtomArrays`
|
||||||
|
|
||||||
|
As a foundational library for the Institute for Protein Design, `cifutils` functions most like an open-source codebase. We must keep the code easy-to-understand and easy-to-maintain, both now and into the future. As such, `cifutils`:
|
||||||
|
- Maintains the **highest code quality standard**, requiring well-documented, easy-to-maintain code with adequate test coverage (we aim for **>85%** coverage)
|
||||||
|
- **Strictly versions** to minimize breaking changes with downstream repositories
|
||||||
|
|
||||||
|
You should write code in `cifutils` if:
|
||||||
|
- You are are writing **core** `AtomArray`-level level functionality that will be broadly useful, not only to those at the Institute for Protein Design but possibly the wider bioinformatics community (i.e., without dependencies, or even knowledge of, `datahub` or `modelhub`)
|
||||||
|
- You are willing to spend some additional time to ensure the code is **scalable, well-tested, and maintainable**
|
||||||
|
|
||||||
|
Quick-and-dirty experiments that require modifying `cifutils` can be performed by submoduling or cloning the repository and exporting a local path.
|
||||||
|
|
||||||
|
### Datahub
|
||||||
|
|
||||||
|
[datahub](https://github.com/baker-laboratory/datahub) manages data loading, preprocessing, and featurization pipelines for structure-dependent deep-learning models. We offer three core components: a `Transforms` library, a set of `Preprocessing` scripts, and `Datasets`.
|
||||||
|
- **Transforms**: A series of composable classes that take as input a dictionary containing sequence- and structure-based data (in the form of an `AtomArray`) and perform arbitrary operations, analogous to TorchVision's [approach](https://pytorch.org/vision/main/transforms.html) for computer vision
|
||||||
|
- **Preprocessing**: Scripts and functions for common data cleaning and preparation tasks, including specialized pipelines for frequent use cases (e.g., antibodies, clash detection, cleaning PDB data, etc.). Many of these *scripts* output `parquet` files stored to disk that are sampled from at train-time, while the *functions* are called by the scripts to clean, label, or filter the data (e.g., `has_clash()`, etc.)
|
||||||
|
- **Datasets**: The base `Datasets` and `Sampler` classes used for training, imported by `modelhub`
|
||||||
|
|
||||||
|
`datahub` is less static than `cifutils`; however, it still must operate as a stand-alone library that others can continue to build around and upon, even without `modelhub`. We strive to maintain `datahub` like an open-source software project such that others in the lab can easily understand, and build upon, our base components. We focus on **maintainable** and **flexible** code - if a particular `Transform` is bespoke or non-generalizable (at least initially), then the `/projects` folder within `Modelhub` may be a more appropriate place for initial development.
|
||||||
|
|
||||||
|
You should write code in `datahub` if:
|
||||||
|
- You are writing flexible, generic *pre-processing scripts* or *functions* that others in the lab have expressed interest in using (vs. a single-purpose pipeline or feature to test a hypothesis)
|
||||||
|
- **Example that should live in `datahub`**: You are writing a pre-processing pipeline to label all beta barrels in the PDB. Your scripts, written in a functional manner, may be a good candidate for `datahub/scripts/preprocessing`, so long as you are willing to write them generally and include tests. Similarly, if a single function may be generalizable but the pipeline is bespoke, that single function (with a test) could still be included as a stand-alone element in `datahub`, e.g.,
|
||||||
|
```python
|
||||||
|
atom_array_has_beta_barrel(atom_array: AtomArray) -> bool
|
||||||
|
```
|
||||||
|
- **Example that should live in `modelhub/projects`**: You have pulled together a script that loads PDB files, includes manual annotations, and saves out to CIF. Such a script may be appropriate for the specific use case but is unlikely to generalize across other use cases.
|
||||||
|
- You are writing `Transforms` that generalize to additional use cases beyond the current project
|
||||||
|
- **Example that should live in `datahub`**: Any `Transform` that adds a useful annotation to an `AtomArray` (e.g., annotationg pocket residues, hydrogen bonds, SASA, etc.)
|
||||||
|
- **Example that should live in `datahub`**: A `Transform` that pads DNA with generated B-form structure, as is done in AF-3; such a `Transform` may be applicable to both structure prediction and design, when proven effective
|
||||||
|
- **Example that should live in `modelhub/projects`**: A `Transform` that aggregates and/or concatenates features for a bespoke model pipeline
|
||||||
|
- You are willing to spend some additional time to ensure the code is scalable, well-tested, and maintainable. Otherwise the `projects` folder of `modelhub` may be a more appropriate place in the interim
|
||||||
|
|
||||||
|
## Training, Validation, and Inference
|
||||||
|
|
||||||
|
> If you are developing at the IPD, our `shebang` executables will take care of identifying and executing with the most up-do-date apptainer. If you are not at the IPD, you will need to ensure you have the appropriate apptainer. See below for details.
|
||||||
|
|
||||||
|
NOTE: For Training, Validation, and Inference, we make heavy use of [Hydra](https://hydra.cc/) for configuration management.
|
||||||
|
|
||||||
|
Before running any of the below commands, you will need to ensure `datahub` and `cifutils` are in your `PYTHONPATH`. E.g.,
|
||||||
```
|
```
|
||||||
git clone https://git.ipd.uw.edu/jue/RF2-allatom.git
|
export PYTHONPATH="/home/<USER>/projects/datahub/src:/home/<USER>/projects/cifutils/src"
|
||||||
cd RF2-allatom
|
|
||||||
```
|
```
|
||||||
|
|
||||||
If you are on digs, the S3nv.sif apptainer has all the relevant packages. To get started coding:
|
### Training and Validation
|
||||||
|
|
||||||
|
For Training and Validation, when you execute `train.py` or `validate.py`, you will need to provide an *experiment* Hydra config. Experiments are a Hydra best-practice pattern to enable us to maintain multiple configurations; see more in the [Hydra documentaion](https://hydra.cc/docs/patterns/configuring_experiments/)
|
||||||
|
and in the `configs/experiment` sub-directory.
|
||||||
|
|
||||||
|
For example, to test AF-3 training without confidence, run:
|
||||||
```
|
```
|
||||||
export PYTHONPATH="../RF2-allatom"
|
./src/modelhub/train.py experiment=quick-af3 debug=default
|
||||||
```
|
```
|
||||||
|
|
||||||
First, run the test suite:
|
**Explanation:**
|
||||||
|
- `./src/modelhub/train.py` — we execute our `train.py` like a bash executable, which triggers the `shebang` code to find the correct apptainer. It's equivalent to `apptainer exec --nv /path/to/apptainer python ./src/modelhub/train.py`
|
||||||
|
- `experiment=quick-af3` — we identify the experiment we want to use for training; in this case, `quick-af3`, which can be viewed at `configs/experiment/quick-af3.yaml`. This experiment is a simple test config for AF-3 that loads and runs more rapidly that the full training config
|
||||||
|
- `debug=default` - a setting letter Hydra know we are debugging; when we debug, we perform some automatic time-savings like setting a small diffusion batch size and crop size. You could remove this line if you don't want those options. You can explore more about various `debug` options in `config/debug`
|
||||||
|
|
||||||
|
For validation only, run the following:
|
||||||
```
|
```
|
||||||
apptainer exec --nv /software/containers/versions/SE3nv/SE3nv-20240415.sif pytest tests/
|
./src/modelhub/validate.py experiment=quick-af3 debug=default
|
||||||
```
|
```
|
||||||
If all the tests pass, you have a stable version of the code.
|
|
||||||
|
|
||||||
### Running model training
|
Note that since we use `hydra`, you could specify additional setup arguments using the command line. For example, by default, we `prevalidate` - running validation at the beginning of training so we develop a baseline and catch any errors (especially out-of-memory errors) before training for a full epoch. If you don't want that behavior, you could override in-line:
|
||||||
|
|
||||||
We use a package called hydra to configure different training runs of the model. Config files for different training runs can be found in `rf2aa/config/train`. The base trainable version is `rf2aa/config/train/rf2aa.yaml`, to run training with this version, run:
|
|
||||||
```
|
```
|
||||||
/software/containers/versions/SE3nv/SE3nv-20240415.sif trainer_new.py --config-name rf2aa
|
./src/modelhub/train.py experiment=quick-af3 debug=default trainer.prevalidate=false
|
||||||
```
|
```
|
||||||
These tests are most often run on a4000s on digs. If you have a separate installation of cifutils in your home directory, this can potentially break the tests.
|
|
||||||
|
|
||||||
If you make changes in the code, they should NOT break backwards compatibility, e.g. there should be a flag in the yaml files that would make it as if your changes were never committed.
|
You can view the flattened Hydra configuration to determine how to best override or add additional arguments by:
|
||||||
|
- Running training or validation and viewing the pretty-printed file, which looks like:
|
||||||
|

|
||||||
|
- Adding `--cfg job` to your launch command, which prints the config for the application and then exits
|
||||||
|
|
||||||
|
### Inference
|
||||||
|
|
||||||
|
To support multiple models and multiple projects, we build an `InferenceEngine` for each use case. For end-users the details of the `InferenceEngine` are not necessary; the appropriate engine can be specified with with `inference_engine` argument.
|
||||||
|
|
||||||
|
For example, to run the latest AF-3 model with confidence, we can execute (if `cifutils` and `datahub` are in the `PYTHONPATH`):
|
||||||
|
```
|
||||||
|
./src/modelhub/inference.py inference_engine=af3 inputs='./tests/data/example_with_ncaa.json'
|
||||||
|
```
|
||||||
|
|
||||||
|
We can then modify the command by adding/removing arguments with Hydra to our liking; for example, to dump diffusion trajectories and only include one model per CIF file:
|
||||||
|
```
|
||||||
|
./src/modelhub/inference.py inference_engine=af3 inputs='./tests/data/example_with_ncaa.json' dump_trajectories=true one_model_per_file=true
|
||||||
|
```
|
||||||
|
|
||||||
|
More details can be found in the [inference README](src/modelhub/inference_engines/README.md)
|
||||||
|
|
||||||
|
## Setup
|
||||||
|
|
||||||
|
> If you are developing at the IPD, then our `shebang` executables will handle the Apptainer dependencies; no need to run the commands below. See the `shebang` section below.
|
||||||
|
|
||||||
|
### Apptainers
|
||||||
|
To accelerate development and better contain dependencies, we offer two apptainers:
|
||||||
|
- `base_apptainer`: Contains all of the development dependencies, pre-compiled DeepSpeed, but *NOT* `cifutils` or `datahub`. The rationale for the base apptainer is that you expose these libraries via your PYTHONPATH/PATH to allow you to develop & pull updates for these libraries without having to re-build any apptainer.
|
||||||
|
- `freeze_apptainer`: Takes the `base_apptainer` as its image, and adds versioned `cifutils`, `datahub`, and (optionally) pip-installs `modelhub` as well (useful for releasing self-contained inference code). The rationale for these apptainers is to provide designers with a stable environment to tackle design problems in.
|
||||||
|
|
||||||
|
#### Base Apptainer
|
||||||
|
|
||||||
|
To make the base apptainer, run:
|
||||||
|
```
|
||||||
|
make base_apptainer
|
||||||
|
```
|
||||||
|
from the project root. This container will **not** contain `cifutils` or `datahub`; those paths must be exported explicitly during development (e.g., the paths to their respective submodules or clones elsewhere).
|
||||||
|
|
||||||
|
Building this apptainer pre-compiles DeepSpeed, among other actions, and is slow. You **should not** need to re-build this apptainer often; changes to `datahub` and `cifutils` can be addressed much more efficiently through the `freeze_apptainer` command specified below.
|
||||||
|
|
||||||
|
> NOTE: Since we pre-compile CUDA-specific DeepSpeed, you must run `make base_apptainer` on a GPU node
|
||||||
|
|
||||||
|
> NOTE: You will need to adjust the IPD-speciifc paths to frozen copies of the PDB and the CCD
|
||||||
|
|
||||||
|
#### Frozen Apptainer
|
||||||
|
|
||||||
|
To make a container that contains `cifutils` and `datahub`, but not `modelhub`, run:
|
||||||
|
```
|
||||||
|
make freeze_apptainer
|
||||||
|
```
|
||||||
|
This will use the `base_apptainer` pointed to by the `shebang` symlink as a base. Note that by default the versions of `cifutils` and `datahub` are fixed; update the `freeze_apptainer.spec` file to adjust the version numbers and/or add dependencies.
|
||||||
|
|
||||||
|
To make a container that contains `modelhub`, `datahub`, and `cifutils` (e.g., for production usage across the lab), run
|
||||||
|
```
|
||||||
|
make INSTALL_PROJECT=true freeze_apptainer
|
||||||
|
```
|
||||||
|
> NOTE: Since we build from the `base_apptainer` image, which contains pre-compiled DeepSpeed, `make freeze_apptainer` does NOT need to be run from a GPU
|
||||||
|
|
||||||
|
### Shebang
|
||||||
|
|
||||||
|
#### General Use
|
||||||
|
We use `shebang` to help manage and version apptainers. Namely:
|
||||||
|
- The shebang lines (`#!/bin/bash` ...) at the top of entry point scripts like `train.py` redirect the system to to `scripts/shebang/modelhub_exec.sh`
|
||||||
|
- The script `modelhub_exec.sh` in turn identifies the correct Apptainer and executes your command
|
||||||
|
- Apptainers are symlinks in `scripts/shebang` to elsewhere on the DIGS (where they are versioned); thus, when we update apptainers, we must also update the symlink. This allows us to track which apptainers to use for a given branch of the code at any given time (provided you update the symlinks for your branch when you switch out which apptainer you run with!)
|
||||||
|
|
||||||
|
For example, to launch a dummy training run, one could type (after adding `cifutils` and `datahub` to your `PYTHONPATH`):
|
||||||
|
```
|
||||||
|
cd src/modelhub
|
||||||
|
./train.py experiment=none-00-dummy
|
||||||
|
```
|
||||||
|
> You may need to adjust the permissions on `train.py` (e.g., `chmod +x train.py`) in order to execute the file like a script.
|
||||||
|
|
||||||
|
#### Debugging
|
||||||
|
We also support VSCode-native debugging with Apptainers. To debug:
|
||||||
|
1. Update your `launch.json` to include `Python: Attach`; for example, add the configuration:
|
||||||
|
```
|
||||||
|
{
|
||||||
|
"name": "Python: Attach",
|
||||||
|
"type": "debugpy",
|
||||||
|
"request": "attach",
|
||||||
|
"connect": {
|
||||||
|
"host": "localhost",
|
||||||
|
"port": 2345
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
2. Add any interactive debug breakpoints in VSCode
|
||||||
|
3. Set the `DEBUG_PORT` to `2345`, and then execute your script with `shebang` like normal. That is:
|
||||||
|
```
|
||||||
|
export DEBUG_PORT=2345
|
||||||
|
./train.py experiment=none-00-dummy
|
||||||
|
```
|
||||||
|
4. When prompted in the termal, launch the VSCode debug session (shortcut: `F5`)
|
||||||
|
|
||||||
|
Happy debugging!
|
||||||
|
|
||||||
### Contributing to model code
|
|
||||||
Generally, we follow software engineering practices of:
|
|
||||||
1. Not duplicating functionality that is already in the code
|
|
||||||
2. Keeping functions as short as possible, and splitting complicated functions into multiple functions
|
|
||||||
3. Using object oriented programming, which means subclassing already existing classes when possible.
|
|
||||||
4. Writing tests for our code and sending small functional PRs for review.
|
|
||||||
5. Maintaining code stability and not breaking backwards compatibility for users using the package.
|
|
||||||
|
|
||||||
To write new blocks in RF, you can go to the rf2aa/model directory and add the new block into the simulator_blocks.py file (and be sure to add a relevant name in the blocks_factory dictionary). These names can be referenced in hydra configs: see rf2aa.yaml for an example with any keyword arguments necessary to initialize the block.
|
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ From: ubuntu:24.04
|
|||||||
IncludeCmd: yes
|
IncludeCmd: yes
|
||||||
# NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd`
|
# NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd`
|
||||||
# To build this apptainer, use:
|
# To build this apptainer, use:
|
||||||
# apptainer build --bind $PWD:/modelhub_host modelhub_apptainer.sif apptainer.spec
|
# make apptainer
|
||||||
|
|
||||||
%setup
|
%setup
|
||||||
# Create a directory in the container to bind the host's current working directory
|
# Create a directory in the container to bind the host's current working directory
|
||||||
@@ -15,8 +15,9 @@ IncludeCmd: yes
|
|||||||
# ... for mounting `/squash` with --bind
|
# ... for mounting `/squash` with --bind
|
||||||
mkdir ${APPTAINER_ROOTFS}/squash
|
mkdir ${APPTAINER_ROOTFS}/squash
|
||||||
|
|
||||||
|
|
||||||
%files
|
%files
|
||||||
|
/etc/localtime
|
||||||
|
/etc/hosts
|
||||||
environment.yaml /opt/environment.yaml
|
environment.yaml /opt/environment.yaml
|
||||||
|
|
||||||
%post
|
%post
|
||||||
@@ -47,7 +48,10 @@ IncludeCmd: yes
|
|||||||
apt-get clean
|
apt-get clean
|
||||||
|
|
||||||
# Clone CUTLASS (for DeepSpeed)
|
# Clone CUTLASS (for DeepSpeed)
|
||||||
git clone https://github.com/NVIDIA/cutlass /opt/cutlass
|
git clone https://github.com/NVIDIA/cutlass.git /opt/cutlass
|
||||||
|
|
||||||
|
# Clone DeepSpeed (so we can pre-install the wheel)
|
||||||
|
git clone --branch v0.16.2 https://github.com/deepspeedai/DeepSpeed.git /opt/deepspeed
|
||||||
|
|
||||||
## ENVIRONMENT CREATION & DEPENDENCY INSTALLATION
|
## ENVIRONMENT CREATION & DEPENDENCY INSTALLATION
|
||||||
# Download miniconda
|
# Download miniconda
|
||||||
@@ -66,15 +70,27 @@ IncludeCmd: yes
|
|||||||
# Add conda environment to PATH
|
# Add conda environment to PATH
|
||||||
export PATH=/usr/envs/modelhub-apptainer/bin:$PATH
|
export PATH=/usr/envs/modelhub-apptainer/bin:$PATH
|
||||||
|
|
||||||
|
echo "Proceeding with DeepSpeed reinstallation."
|
||||||
|
|
||||||
|
## PRE-COMPILE DEEPSPEED FROM WHEEL
|
||||||
|
# (Overwrite deepspeed installation from the `environment.yaml`)
|
||||||
|
pip uninstall deepspeed -y # Avoid interactive prompts
|
||||||
|
|
||||||
|
# (Flags for building the Evoformer attention)
|
||||||
|
export TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0;8.6;8.9"
|
||||||
|
export DS_BUILD_EVOFORMER_ATTN=1
|
||||||
|
export CUTLASS_PATH=/opt/cutlass/
|
||||||
|
|
||||||
|
# Reinstall DeepSpeed, pre-compiling the evoformer attentino kernel
|
||||||
|
pip wheel /opt/deepspeed -w /opt/deepspeed
|
||||||
|
pip install /opt/deepspeed/deepspeed-0.16.2+b344c04d-cp311-cp311-linux_x86_64.whl
|
||||||
|
|
||||||
# Run the biotite setup command
|
# Run the biotite setup command
|
||||||
# (Temporary measure until we switch to released Biotite version)
|
# (Temporary measure until we switch to released Biotite version)
|
||||||
. /usr/etc/profile.d/conda.sh
|
. /usr/etc/profile.d/conda.sh
|
||||||
conda activate modelhub-apptainer
|
conda activate modelhub-apptainer
|
||||||
python -m biotite.setup_ccd
|
python -m biotite.setup_ccd
|
||||||
|
|
||||||
# deepspeed
|
|
||||||
pip install deepspeed==0.15.1
|
|
||||||
|
|
||||||
# clean up files to reduce size
|
# clean up files to reduce size
|
||||||
# ... remove conda
|
# ... remove conda
|
||||||
mamba clean -a -y
|
mamba clean -a -y
|
||||||
@@ -93,7 +109,7 @@ IncludeCmd: yes
|
|||||||
|
|
||||||
%runscript
|
%runscript
|
||||||
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
|
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
|
||||||
exec "$@"
|
exec python "$@"
|
||||||
|
|
||||||
%help
|
%help
|
||||||
modelhub environment for running modelhub independently and for development
|
modelhub environment for running modelhub independently and for development
|
||||||
@@ -1,4 +0,0 @@
|
|||||||
APP=/software/containers/versions/rf_diffusion_aa/24-05-21/rf_diffusion_aa.sif
|
|
||||||
PYTHONPATH=.. $APP -mpytest --benchmark-skip --ignore tests/test_semantics.py --durations=10 tests
|
|
||||||
|
|
||||||
|
|
||||||
5
configs/callbacks/default.yaml
Normal file
5
configs/callbacks/default.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- train_logging
|
||||||
|
- metrics_logging
|
||||||
|
- dump_validation_structures
|
||||||
|
- _self_
|
||||||
6
configs/callbacks/dump_validation_structures.yaml
Normal file
6
configs/callbacks/dump_validation_structures.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
dump_validation_structures_callback:
|
||||||
|
_target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback
|
||||||
|
save_dir: ${paths.output_dir}/val_structures
|
||||||
|
dump_predictions: False
|
||||||
|
one_model_per_file: False
|
||||||
|
dump_trajectories: False
|
||||||
14
configs/callbacks/metrics_logging.yaml
Normal file
14
configs/callbacks/metrics_logging.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
store_validation_metrics_in_df_callback:
|
||||||
|
_target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
|
||||||
|
save_dir: ${paths.output_dir}/val_metrics
|
||||||
|
metrics_to_save: "all"
|
||||||
|
|
||||||
|
log_af3_validation_metrics_callback:
|
||||||
|
_target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
|
||||||
|
metrics_to_log:
|
||||||
|
# Only logs if present in the metric output dictionary
|
||||||
|
# Must be subset of metrics_to_save
|
||||||
|
- by_type_lddt
|
||||||
|
- all_atom_lddt
|
||||||
|
- distogram_loss
|
||||||
|
- distogram_comparisons
|
||||||
16
configs/callbacks/train_logging.yaml
Normal file
16
configs/callbacks/train_logging.yaml
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
log_af3_training_losses_callback:
|
||||||
|
_target_: modelhub.callbacks.train_logging.LogAF3TrainingLossesCallback
|
||||||
|
log_every_n: 10
|
||||||
|
log_full_batch_losses: true
|
||||||
|
|
||||||
|
log_learning_rate_callback:
|
||||||
|
_target_: modelhub.callbacks.train_logging.LogLearningRateCallback
|
||||||
|
log_every_n: 10
|
||||||
|
|
||||||
|
log_model_parameters_callback:
|
||||||
|
_target_: modelhub.callbacks.train_logging.LogModelParametersCallback
|
||||||
|
|
||||||
|
log_dataset_sampling_ratios_callback:
|
||||||
|
_target_: modelhub.callbacks.train_logging.LogDatasetSamplingRatiosCallback
|
||||||
|
|
||||||
|
|
||||||
15
configs/dataloader/default.yaml
Normal file
15
configs/dataloader/default.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
train:
|
||||||
|
dataloader_params:
|
||||||
|
# These parameters will be unpacked as kwargs for the DataLoader
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 2
|
||||||
|
prefetch_factor: 3
|
||||||
|
n_fallback_retries: 4
|
||||||
|
|
||||||
|
val:
|
||||||
|
dataloader_params:
|
||||||
|
# These parameters will be unpacked as kwargs for the DataLoader
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 2
|
||||||
|
prefetch_factor: 3
|
||||||
|
n_fallback_retries: 0 # Disable fallback retries for validation
|
||||||
21
configs/datasets/af3.yaml
Normal file
21
configs/datasets/af3.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# AF3 dataset configuration with monomer distillation
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
# The @ symbol specifies the tree under which the item will be attached to the config
|
||||||
|
- train/pdb/af3_train_interface@train.pdb.sub_datasets.interface
|
||||||
|
- train/pdb/af3_train_pn_unit@train.pdb.sub_datasets.pn_unit
|
||||||
|
- train:
|
||||||
|
- monomer_distillation
|
||||||
|
- val/af3_validation@val.af3_validation
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
# Dataloading pipeline to use
|
||||||
|
pipeline_target: datahub.pipelines.af3.build_af3_transform_pipeline
|
||||||
|
|
||||||
|
# Dataset weighting
|
||||||
|
train:
|
||||||
|
pdb:
|
||||||
|
probability: 0.5
|
||||||
|
monomer_distillation:
|
||||||
|
probability: 0.5
|
||||||
12
configs/datasets/base.yaml
Normal file
12
configs/datasets/base.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Base Transform defaults
|
||||||
|
diffusion_batch_size_train: 48
|
||||||
|
diffusion_batch_size_inference: 5
|
||||||
|
|
||||||
|
n_recycles_train: 4
|
||||||
|
n_recycles_validation: 10
|
||||||
|
|
||||||
|
n_msa: 1024
|
||||||
|
crop_size: 384
|
||||||
|
max_atoms_in_crop: 5000
|
||||||
|
|
||||||
|
key_to_balance: n_tokens_total
|
||||||
38
configs/datasets/train/monomer_distillation.yaml
Normal file
38
configs/datasets/train/monomer_distillation.yaml
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
monomer_distillation:
|
||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||||
|
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||||
|
|
||||||
|
# cif parser arguments
|
||||||
|
cif_parser_args:
|
||||||
|
cache_dir: null
|
||||||
|
load_from_cache: False
|
||||||
|
save_to_cache: False
|
||||||
|
|
||||||
|
# metadata parser
|
||||||
|
dataset_parser:
|
||||||
|
_target_: datahub.datasets.parsers.GenericDFParser
|
||||||
|
pn_unit_iid_colnames: null
|
||||||
|
|
||||||
|
# metadata dataset
|
||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.PandasDataset
|
||||||
|
name: af2fb_distillation
|
||||||
|
id_column: example_id
|
||||||
|
data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet
|
||||||
|
columns_to_load:
|
||||||
|
- example_id
|
||||||
|
- path
|
||||||
|
return_key: null
|
||||||
|
transform:
|
||||||
|
_target_: ${datasets.pipeline_target}
|
||||||
|
is_inference: False
|
||||||
|
protein_msa_dirs: [{"dir": "${paths.data.monomer_distillation_data_dir}/msa", "extension": ".a3m", "directory_depth": 2}]
|
||||||
|
rna_msa_dirs: []
|
||||||
|
n_recycles: ${datasets.n_recycles_train}
|
||||||
|
crop_size: ${datasets.crop_size}
|
||||||
|
n_msa: ${datasets.n_msa}
|
||||||
|
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
|
||||||
|
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
|
||||||
|
crop_contiguous_probability: 0.25
|
||||||
|
crop_spatial_probability: 0.75
|
||||||
45
configs/datasets/train/pdb/af3_train_interface.yaml
Normal file
45
configs/datasets/train/pdb/af3_train_interface.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
dataset_parser:
|
||||||
|
_target_: datahub.datasets.parsers.InterfacesDFParser
|
||||||
|
dataset:
|
||||||
|
name: interface
|
||||||
|
data: ${paths.data.pdb_data_dir}/interfaces_df_train.parquet
|
||||||
|
filters:
|
||||||
|
# filters common across all PDB datasets
|
||||||
|
- "deposition_date < '2021-09-30'"
|
||||||
|
- "resolution < 9.0"
|
||||||
|
- "num_polymer_pn_units <= 300"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
# interface specific filters
|
||||||
|
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||||
|
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||||
|
- "is_inter_molecule"
|
||||||
|
columns_to_load:
|
||||||
|
# columns common across all PDB datasets
|
||||||
|
- example_id
|
||||||
|
- pdb_id
|
||||||
|
- assembly_id
|
||||||
|
- deposition_date
|
||||||
|
- resolution
|
||||||
|
- num_polymer_pn_units
|
||||||
|
- method
|
||||||
|
- cluster
|
||||||
|
- n_prot
|
||||||
|
- n_nuc
|
||||||
|
- n_ligand
|
||||||
|
- n_peptide
|
||||||
|
# interface specific columns
|
||||||
|
- pn_unit_1_iid
|
||||||
|
- pn_unit_2_iid
|
||||||
|
- pn_unit_1_non_polymer_res_names
|
||||||
|
- pn_unit_2_non_polymer_res_names
|
||||||
|
- is_inter_molecule
|
||||||
|
- all_pn_unit_iids_after_processing
|
||||||
|
- involves_loi
|
||||||
|
transform:
|
||||||
|
# interface-specific Transform pipeline parameters
|
||||||
|
crop_contiguous_probability: 0.0
|
||||||
|
crop_spatial_probability: 1.0
|
||||||
41
configs/datasets/train/pdb/af3_train_pn_unit.yaml
Normal file
41
configs/datasets/train/pdb/af3_train_pn_unit.yaml
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
dataset_parser:
|
||||||
|
_target_: datahub.datasets.parsers.PNUnitsDFParser
|
||||||
|
dataset:
|
||||||
|
name: pn_unit
|
||||||
|
data: ${paths.data.pdb_data_dir}/pn_units_df_train.parquet
|
||||||
|
filters:
|
||||||
|
# filters common across all PDB datasets
|
||||||
|
- "deposition_date < '2021-09-30'"
|
||||||
|
- "resolution < 9.0"
|
||||||
|
- "num_polymer_pn_units <= 300"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
# pn_unit specific filters
|
||||||
|
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||||
|
columns_to_load:
|
||||||
|
# columns common across all PDB datasets
|
||||||
|
- example_id
|
||||||
|
- pdb_id
|
||||||
|
- assembly_id
|
||||||
|
- deposition_date
|
||||||
|
- resolution
|
||||||
|
- num_polymer_pn_units
|
||||||
|
- method
|
||||||
|
- cluster
|
||||||
|
- n_prot
|
||||||
|
- n_nuc
|
||||||
|
- n_ligand
|
||||||
|
- n_peptide
|
||||||
|
- total_num_atoms_in_unprocessed_assembly
|
||||||
|
# pn_unit specific columns
|
||||||
|
- q_pn_unit_iid
|
||||||
|
- q_pn_unit_non_polymer_res_names
|
||||||
|
- all_pn_unit_iids_after_processing
|
||||||
|
- q_pn_unit_is_loi
|
||||||
|
transform:
|
||||||
|
# pn_unit-specific Transform pipeline parameters
|
||||||
|
crop_contiguous_probability: 0.3333333333333333
|
||||||
|
crop_spatial_probability: 0.6666666666666667
|
||||||
33
configs/datasets/train/pdb/base.yaml
Normal file
33
configs/datasets/train/pdb/base.yaml
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||||
|
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||||
|
cif_parser_args:
|
||||||
|
cache_dir: null
|
||||||
|
load_from_cache: false
|
||||||
|
save_to_cache: false
|
||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.PandasDataset
|
||||||
|
# we will use the example_id as the unique column
|
||||||
|
id_column: example_id
|
||||||
|
# return all keys (do not subset)
|
||||||
|
return_key: null
|
||||||
|
transform:
|
||||||
|
# common Transform pipeline components for all PDB datasets
|
||||||
|
_target_: ${datasets.pipeline_target}
|
||||||
|
is_inference: False
|
||||||
|
protein_msa_dirs: ${paths.data.protein_msa_dirs}
|
||||||
|
rna_msa_dirs: ${paths.data.rna_msa_dirs}
|
||||||
|
n_recycles: ${datasets.n_recycles_train}
|
||||||
|
crop_size: ${datasets.crop_size}
|
||||||
|
n_msa: ${datasets.n_msa}
|
||||||
|
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
|
||||||
|
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
|
||||||
|
|
||||||
|
weights:
|
||||||
|
_target_: datahub.samplers.calculate_weights_for_pdb_dataset_df
|
||||||
|
beta: 0.5
|
||||||
|
alphas:
|
||||||
|
a_prot: 3.0 # 3 for AF-3
|
||||||
|
a_nuc: 0.0 # 3 for AF-3
|
||||||
|
a_ligand: 1.0 # 1 for AF-3
|
||||||
|
a_loi: 5.0 # 5 for AF-3
|
||||||
12
configs/datasets/val/af3_validation.yaml
Normal file
12
configs/datasets/val/af3_validation.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
|
||||||
|
dataset:
|
||||||
|
dataset_parser:
|
||||||
|
_target_: datahub.datasets.parsers.ValidationDFParserLikeAF3
|
||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.PandasDataset
|
||||||
|
data: ${paths.data.pdb_data_dir}/entry_level_val_df.parquet
|
||||||
|
filters:
|
||||||
|
# NOTE: We exclude these examples from validation because they produce an error upon RDKit small molecule processing that causes a data loading fallback
|
||||||
|
- example_id not in ["{['validation']}{7erc}{1}{[]}", "{['validation']}{7qbs}{1}{[]}", "{['validation']}{7z0n}{1}{[]}"]
|
||||||
26
configs/datasets/val/base.yaml
Normal file
26
configs/datasets/val/base.yaml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.StructuralDatasetWrapper
|
||||||
|
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
|
||||||
|
cif_parser_args:
|
||||||
|
cache_dir: null
|
||||||
|
load_from_cache: False
|
||||||
|
save_to_cache: False
|
||||||
|
dataset:
|
||||||
|
_target_: datahub.datasets.datasets.PandasDataset
|
||||||
|
# we will use the example_id as the unique column
|
||||||
|
id_column: example_id
|
||||||
|
# return all keys (do not subset)
|
||||||
|
return_key: null
|
||||||
|
transform:
|
||||||
|
# common Transform pipeline components for all PDB datasets
|
||||||
|
_target_: ${datasets.pipeline_target}
|
||||||
|
is_inference: True
|
||||||
|
protein_msa_dirs: ${paths.data.protein_msa_dirs}
|
||||||
|
rna_msa_dirs: ${paths.data.rna_msa_dirs}
|
||||||
|
n_recycles: ${datasets.n_recycles_validation}
|
||||||
|
crop_size: null # do not crop for inference
|
||||||
|
n_msa: ${datasets.n_msa}
|
||||||
|
diffusion_batch_size: ${datasets.diffusion_batch_size_inference}
|
||||||
|
max_atoms_in_crop: null # do not crop for inference
|
||||||
|
return_atom_array: True # return atom array for inference
|
||||||
|
key_to_balance: ${datasets.key_to_balance}
|
||||||
64
configs/debug/default.yaml
Normal file
64
configs/debug/default.yaml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /logger: null
|
||||||
|
|
||||||
|
# default debugging setup, runs 1 full epoch
|
||||||
|
# other debugging configs can inherit from this one
|
||||||
|
|
||||||
|
# overwrite task name so debugging logs are stored in separate folder
|
||||||
|
task_name: "debug"
|
||||||
|
|
||||||
|
extras:
|
||||||
|
ignore_warnings: False
|
||||||
|
enforce_tags: False
|
||||||
|
|
||||||
|
# sets level of all command line loggers to 'DEBUG'
|
||||||
|
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
|
||||||
|
hydra:
|
||||||
|
job_logging:
|
||||||
|
root:
|
||||||
|
level: DEBUG
|
||||||
|
# use the below to also set hydra loggers to 'DEBUG'
|
||||||
|
verbose: True
|
||||||
|
|
||||||
|
# Print example ID before forward pass
|
||||||
|
callbacks:
|
||||||
|
print_example_id_before_forward_pass:
|
||||||
|
_target_: modelhub.callbacks.train_logging.PrintExampleIDBeforeForwardPassCallback
|
||||||
|
|
||||||
|
dataloader:
|
||||||
|
train:
|
||||||
|
dataloader_params:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 0 # debuggers don't like multiprocessing -- work on main thread
|
||||||
|
pin_memory: False # disable gpu memory pin
|
||||||
|
prefetch_factor: null # must be null for num_workers=0
|
||||||
|
n_fallback_retries: 0 # disable fallback retries for debugging
|
||||||
|
|
||||||
|
val:
|
||||||
|
dataloader_params:
|
||||||
|
batch_size: 1
|
||||||
|
num_workers: 0
|
||||||
|
pin_memory: False
|
||||||
|
prefetch_factor: null # must be null for num_workers=0
|
||||||
|
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
crop_size: 100 # set small crop size for quick debugging
|
||||||
|
diffusion_batch_size_train: 1
|
||||||
|
diffusion_batch_size_inference: 1
|
||||||
|
n_recycles_train: 1
|
||||||
|
n_recycles_validation: 1
|
||||||
|
n_msa: 128
|
||||||
|
key_to_balance: null # otherwise big examples will be processed first
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
devices_per_node: 1
|
||||||
|
limit_train_batches: 1
|
||||||
|
limit_val_batches: 1
|
||||||
|
validate_every_n_epochs: 1
|
||||||
|
|
||||||
|
# Set tags to help identify debugging runs
|
||||||
|
tags:
|
||||||
|
- debug
|
||||||
21
configs/debug/train_specific_examples.yaml
Normal file
21
configs/debug/train_specific_examples.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# See: https://hydra.cc/docs/patterns/configuring_experiments/
|
||||||
|
|
||||||
|
# to execute this experiment run:
|
||||||
|
# python train.py +debug=train_single_example [any other arguments]
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- default
|
||||||
|
- gpu
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
# you can add specific example IDs here to load a subset of the dataset (training)
|
||||||
|
subset_to_example_ids:
|
||||||
|
- "{['pdb', 'pn_units']}{3px1}{1}{['A_3']}"
|
||||||
|
val: null
|
||||||
|
|
||||||
|
tags:
|
||||||
|
- debug
|
||||||
|
- train
|
||||||
|
- specific-examples
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3-elements-as-ligand-atom-names
|
||||||
|
|
||||||
|
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||||
|
defaults:
|
||||||
|
- override /trainer: af3
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3
|
||||||
|
|
||||||
|
tags:
|
||||||
|
# list of tags to add to the run ( & on wandb to easily find & filter runs)
|
||||||
|
- atom-names
|
||||||
|
- experiment
|
||||||
|
|
||||||
|
project: af3
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/weights_with_no_confidence_2025_1_21_new_modelhub.ckpt
|
||||||
|
|
||||||
|
model:
|
||||||
|
lr_scheduler:
|
||||||
|
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
pdb:
|
||||||
|
probability: 1.0
|
||||||
|
sub_datasets:
|
||||||
|
interface:
|
||||||
|
dataset:
|
||||||
|
transform:
|
||||||
|
use_element_for_atom_names_of_atomized_tokens: True
|
||||||
|
dataset:
|
||||||
|
# Same as AF-3 training, but limit to protein-ligand interfaces
|
||||||
|
filters:
|
||||||
|
# (from before)
|
||||||
|
- "deposition_date < '2021-09-30'"
|
||||||
|
- "resolution < 9.0"
|
||||||
|
- "num_polymer_pn_units <= 300"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
- >
|
||||||
|
~(pn_unit_1_non_polymer_res_names.notnull() and
|
||||||
|
pn_unit_1_non_polymer_res_names.str.contains(
|
||||||
|
'${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}',
|
||||||
|
regex=True))
|
||||||
|
- >
|
||||||
|
~(pn_unit_2_non_polymer_res_names.notnull() and
|
||||||
|
pn_unit_2_non_polymer_res_names.str.contains(
|
||||||
|
'${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}',
|
||||||
|
regex=True))
|
||||||
|
- "is_inter_molecule"
|
||||||
|
|
||||||
|
# only protein-ligand interfaces
|
||||||
|
- "(n_prot == 1 and n_nuc == 0 and n_ligand == 1)"
|
||||||
|
pn_unit:
|
||||||
|
dataset:
|
||||||
|
transform:
|
||||||
|
use_element_for_atom_names_of_atomized_tokens: True
|
||||||
|
dataset:
|
||||||
|
# Same as AF-3 training, but limit to protein-ligand interfaces
|
||||||
|
filters:
|
||||||
|
# (from before)
|
||||||
|
- "deposition_date < '2021-09-30'"
|
||||||
|
- "resolution < 9.0"
|
||||||
|
- "num_polymer_pn_units <= 300"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:cifutils.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
|
||||||
|
|
||||||
|
# only proteins or ligands
|
||||||
|
- "(n_prot == 1 or n_ligand == 1)"
|
||||||
|
# Datasets set to null are ignored
|
||||||
|
monomer_distillation: null
|
||||||
|
val:
|
||||||
|
af3_validation:
|
||||||
|
dataset:
|
||||||
|
transform:
|
||||||
|
use_element_for_atom_names_of_atomized_tokens: True
|
||||||
|
dataset:
|
||||||
|
filters:
|
||||||
|
- "n_tokens_total < 400"
|
||||||
|
- "interfaces_to_score.str.contains('protein-ligand')"
|
||||||
|
# Exclude example where RDKit errors
|
||||||
|
- example_id not in ["{['validation']}{7qbs}{1}{[]}"]
|
||||||
30
configs/experiment/ncorley/af3-fine-tune-bfloat-msa.yaml
Normal file
30
configs/experiment/ncorley/af3-fine-tune-bfloat-msa.yaml
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3
|
||||||
|
- override /trainer: af3
|
||||||
|
|
||||||
|
tags:
|
||||||
|
- af3
|
||||||
|
- fine-tune
|
||||||
|
|
||||||
|
project: af3
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt
|
||||||
|
|
||||||
|
model:
|
||||||
|
lr_scheduler:
|
||||||
|
base_lr: 0.9e-3 # 1/2 of original learning rate (1.8e-3)
|
||||||
|
|
||||||
|
# Protein-ligand only for speed
|
||||||
|
val:
|
||||||
|
af3_validation:
|
||||||
|
dataset:
|
||||||
|
dataset:
|
||||||
|
filters:
|
||||||
|
# Only score examples with protein-ligand interfaces
|
||||||
|
- "interfaces_to_score.str.contains('protein-ligand')"
|
||||||
|
- example_id not in ["{['validation']}{7qbs}{1}{[]}"]
|
||||||
37
configs/experiment/ncorley/af3-new-msas-pdb-only.yaml
Normal file
37
configs/experiment/ncorley/af3-new-msas-pdb-only.yaml
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3-new-msas-pdb-only
|
||||||
|
|
||||||
|
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||||
|
defaults:
|
||||||
|
- override /trainer: af3
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3
|
||||||
|
|
||||||
|
tags:
|
||||||
|
# list of tags to add to the run ( & on wandb to easily find & filter runs)
|
||||||
|
- msas
|
||||||
|
- experiment
|
||||||
|
|
||||||
|
project: af3
|
||||||
|
|
||||||
|
paths:
|
||||||
|
data:
|
||||||
|
protein_msa_dirs:
|
||||||
|
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
pdb:
|
||||||
|
# We must adjust the probability, since we set the monomer distillation dataset to null
|
||||||
|
probability: 1.0
|
||||||
|
# Datasets set to null are ignored
|
||||||
|
monomer_distillation: null
|
||||||
|
val:
|
||||||
|
af3_validation:
|
||||||
|
dataset:
|
||||||
|
dataset:
|
||||||
|
filters:
|
||||||
|
- "n_tokens_total < 400"
|
||||||
13
configs/experiment/ncorley/af3-old-msas-pdb-only.yaml
Normal file
13
configs/experiment/ncorley/af3-old-msas-pdb-only.yaml
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3-old-msas-pdb-only
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- af3-new-msas-pdb-only
|
||||||
|
|
||||||
|
paths:
|
||||||
|
data:
|
||||||
|
protein_msa_dirs:
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
15
configs/experiment/none-00-dummy.yaml
Normal file
15
configs/experiment/none-00-dummy.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# NOTE: Dummy experiment that you can use to just run the code
|
||||||
|
# . For actual experiments, please create a new experiment config from copying the template 'user-XX-template.yaml'
|
||||||
|
|
||||||
|
|
||||||
|
name: none-00-dummy
|
||||||
|
|
||||||
|
tags:
|
||||||
|
# list of tags to add to the run ( & on wandb to easily find & filter runs)
|
||||||
|
- experiment
|
||||||
|
- dummy
|
||||||
|
|
||||||
|
project: test
|
||||||
|
|
||||||
15
configs/experiment/pretrained/af3.yaml
Normal file
15
configs/experiment/pretrained/af3.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3
|
||||||
|
- override /trainer: af3
|
||||||
|
|
||||||
|
tags:
|
||||||
|
- af3
|
||||||
|
|
||||||
|
project: af3
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt
|
||||||
15
configs/experiment/pretrained/af3_with_confidence.yaml
Normal file
15
configs/experiment/pretrained/af3_with_confidence.yaml
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
name: af3-with-confidence
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3_with_confidence
|
||||||
|
- override /trainer: af3_with_confidence
|
||||||
|
|
||||||
|
tags:
|
||||||
|
- af3
|
||||||
|
|
||||||
|
project: af3
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/weights_with_confidence_2025_2_27_new_modelhub.ckpt
|
||||||
14
configs/experiment/quick-af3-with-confidence.yaml
Normal file
14
configs/experiment/quick-af3-with-confidence.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Experiment that loads a small dataset for quick testing
|
||||||
|
|
||||||
|
name: quick-af3-with-confidence
|
||||||
|
|
||||||
|
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||||
|
defaults:
|
||||||
|
- quick-af3
|
||||||
|
- override /model: af3_with_confidence
|
||||||
|
- override /trainer: af3_with_confidence
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/weights_with_confidence_2025_2_27_new_modelhub.ckpt
|
||||||
50
configs/experiment/quick-af3.yaml
Normal file
50
configs/experiment/quick-af3.yaml
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
# Experiment that loads a small dataset for quick testing
|
||||||
|
|
||||||
|
name: quick-af3
|
||||||
|
|
||||||
|
# For explanation of the "override" syntax, see: https://hydra.cc/docs/upgrades/1.0_to_1.1/defaults_list_override/
|
||||||
|
defaults:
|
||||||
|
- override /trainer: af3
|
||||||
|
- override /datasets: af3
|
||||||
|
- override /model: af3
|
||||||
|
|
||||||
|
tags:
|
||||||
|
# list of tags to add to the run ( & on wandb to easily find & filter runs)
|
||||||
|
- quick
|
||||||
|
|
||||||
|
project: test
|
||||||
|
|
||||||
|
ckpt_path: /projects/ml/modelhub/inference/rf2aa-af3-repro7_ep680.pt
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
train:
|
||||||
|
pdb:
|
||||||
|
# We must adjust the probability, since we set the monomer distillation dataset to null
|
||||||
|
probability: 1.0
|
||||||
|
sub_datasets:
|
||||||
|
interface:
|
||||||
|
dataset:
|
||||||
|
dataset:
|
||||||
|
# A small dataframe that loads quickly
|
||||||
|
data: /projects/ml/datahub/dfs/pdb/test_dfs/interfaces_df.parquet
|
||||||
|
filters:
|
||||||
|
- "num_polymer_pn_units <= 2"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
pn_unit:
|
||||||
|
dataset:
|
||||||
|
dataset:
|
||||||
|
# A small dataframe that loads quickly
|
||||||
|
data: /projects/ml/datahub/dfs/pdb/test_dfs/pn_units_df.parquet
|
||||||
|
filters:
|
||||||
|
- "num_polymer_pn_units <= 2"
|
||||||
|
- "cluster.notnull()"
|
||||||
|
# Datasets set to null are ignored
|
||||||
|
monomer_distillation: null
|
||||||
|
val:
|
||||||
|
af3_validation:
|
||||||
|
dataset:
|
||||||
|
dataset:
|
||||||
|
filters:
|
||||||
|
- "n_tokens_total < 200"
|
||||||
18
configs/hydra/default.yaml
Normal file
18
configs/hydra/default.yaml
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
# https://hydra.cc/docs/configure_hydra/intro/
|
||||||
|
|
||||||
|
# enable color logging (requires `colorlog` to be installed)
|
||||||
|
# defaults:
|
||||||
|
# - override hydra_logging: colorlog
|
||||||
|
# - override job_logging: colorlog
|
||||||
|
|
||||||
|
|
||||||
|
# output directory, generated dynamically on each run
|
||||||
|
run:
|
||||||
|
dir: ${paths.log_dir}/${task_name}/${name}/${now:%Y-%m-%d}_${now:%H-%M}
|
||||||
|
|
||||||
|
# ... this is where the log file is written (i.e. the programs output)
|
||||||
|
job_logging:
|
||||||
|
handlers:
|
||||||
|
file:
|
||||||
|
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
|
||||||
|
filename: ${hydra.runtime.output_dir}/experiment.log
|
||||||
7
configs/hydra/no_logging.yaml
Normal file
7
configs/hydra/no_logging.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
defaults:
|
||||||
|
- override job_logging: disabled
|
||||||
|
- override hydra_logging: disabled
|
||||||
|
|
||||||
|
output_subdir: null
|
||||||
|
run:
|
||||||
|
dir: .
|
||||||
7
configs/inference.yaml
Normal file
7
configs/inference.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
# @package _global_
|
||||||
|
# ^ The "package" determines where the content of the config is placed in the output config
|
||||||
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- inference_engine: ???
|
||||||
|
- _self_
|
||||||
22
configs/inference_engine/af3.yaml
Normal file
22
configs/inference_engine/af3.yaml
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- base
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
_target_: modelhub.inference_engines.af3.AF3InferenceEngine
|
||||||
|
|
||||||
|
ckpt_path: /net/tukwila/ncorley/modelhub/inference/modelhub_latest.ckpt
|
||||||
|
|
||||||
|
n_recycles: 10
|
||||||
|
diffusion_batch_size: 5
|
||||||
|
residue_renaming_dict: null
|
||||||
|
num_steps: 50
|
||||||
|
solver: "af3"
|
||||||
|
print_config: true
|
||||||
|
seed: null
|
||||||
|
skip_existing: true
|
||||||
|
|
||||||
|
dump_predictions: true
|
||||||
|
dump_trajectories: false
|
||||||
|
one_model_per_file: false
|
||||||
10
configs/inference_engine/base.yaml
Normal file
10
configs/inference_engine/base.yaml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
# @package _global_
|
||||||
|
|
||||||
|
defaults:
|
||||||
|
- /hydra: no_logging
|
||||||
|
|
||||||
|
ckpt_path: ???
|
||||||
|
inputs: ???
|
||||||
|
out_dir: ./
|
||||||
|
num_nodes: 1
|
||||||
|
devices_per_node: 1
|
||||||
6
configs/logger/csv.yaml
Normal file
6
configs/logger/csv.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# https://lightning.ai/docs/fabric/latest/api/generated/lightning.fabric.loggers.CSVLogger.html#lightning.fabric.loggers.CSVLogger
|
||||||
|
|
||||||
|
csv:
|
||||||
|
_target_: lightning.fabric.loggers.CSVLogger
|
||||||
|
root_dir: ${paths.output_dir}
|
||||||
|
flush_logs_every_n_steps: 1
|
||||||
3
configs/logger/default.yaml
Normal file
3
configs/logger/default.yaml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
defaults:
|
||||||
|
- wandb
|
||||||
|
- csv
|
||||||
14
configs/logger/wandb.yaml
Normal file
14
configs/logger/wandb.yaml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
# https://wandb.ai
|
||||||
|
|
||||||
|
wandb:
|
||||||
|
_target_: wandb.integration.lightning.fabric.WandbLogger
|
||||||
|
save_dir: ${paths.output_dir}
|
||||||
|
offline: False
|
||||||
|
id: null # pass correct id (along with checkpoint path, and resume='allow' or 'must') to resume a run
|
||||||
|
anonymous: null # enable anonymous logging
|
||||||
|
project: ${project}
|
||||||
|
prefix: "" # a string to put at the beginning of metric keys
|
||||||
|
log_model: False # do not upload model checkpoints
|
||||||
|
tags: ${tags}
|
||||||
|
# (Default resume to "never" to avoid accidentally resuming runs; we want to be explicit about resuming)
|
||||||
|
resume: never # never, allow, or must (see: https://docs.wandb.ai/guides/runs/resuming/)
|
||||||
7
configs/model/af3.yaml
Normal file
7
configs/model/af3.yaml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
defaults:
|
||||||
|
- optimizers/adam@optimizer
|
||||||
|
- schedulers/af3@lr_scheduler
|
||||||
|
- components/ema@ema
|
||||||
|
- components/af3_net@net
|
||||||
|
|
||||||
|
|
||||||
5
configs/model/af3_with_confidence.yaml
Normal file
5
configs/model/af3_with_confidence.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- af3
|
||||||
|
- components/af3_net_with_confidence_head@net
|
||||||
|
|
||||||
|
|
||||||
177
configs/model/components/af3_net.yaml
Normal file
177
configs/model/components/af3_net.yaml
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
# Model architecture
|
||||||
|
_target_: modelhub.model.AF3.AF3
|
||||||
|
|
||||||
|
# +---------- Channel dimensions ----------+
|
||||||
|
c_s: 384
|
||||||
|
c_z: 128
|
||||||
|
c_atom: 128
|
||||||
|
c_atompair: 16
|
||||||
|
c_s_inputs: 449 # TODO: What is this?
|
||||||
|
|
||||||
|
# +---------- Feature embedding ----------+
|
||||||
|
feature_initializer:
|
||||||
|
# InputFeatureEmbedder
|
||||||
|
input_feature_embedder:
|
||||||
|
features:
|
||||||
|
- restype
|
||||||
|
- profile
|
||||||
|
- deletion_mean
|
||||||
|
atom_attention_encoder:
|
||||||
|
c_token: 384
|
||||||
|
c_atom_1d_features: 389
|
||||||
|
c_tokenpair: ${model.net.c_z}
|
||||||
|
atom_1d_features:
|
||||||
|
- ref_pos
|
||||||
|
- ref_charge
|
||||||
|
- ref_mask
|
||||||
|
- ref_element
|
||||||
|
- ref_atom_name_chars
|
||||||
|
atom_transformer:
|
||||||
|
n_queries: 32
|
||||||
|
n_keys: 128
|
||||||
|
l_max: 40_000 # does not matter
|
||||||
|
diffusion_transformer:
|
||||||
|
n_block: 3
|
||||||
|
diffusion_transformer_block:
|
||||||
|
n_head: 4
|
||||||
|
no_residual_connection_between_attention_and_transition: true
|
||||||
|
kq_norm: false
|
||||||
|
|
||||||
|
# RelativePositionEncoding
|
||||||
|
relative_position_encoding:
|
||||||
|
r_max: 32
|
||||||
|
s_max: 2
|
||||||
|
|
||||||
|
# +---------- Recycler ----------+
|
||||||
|
recycler:
|
||||||
|
# Pairformer
|
||||||
|
n_pairformer_blocks: 48
|
||||||
|
pairformer_block:
|
||||||
|
p_drop: 0.25
|
||||||
|
triangle_multiplication:
|
||||||
|
d_hidden: 128
|
||||||
|
triangle_attention:
|
||||||
|
n_head: 4
|
||||||
|
d_hidden: 32
|
||||||
|
attention_pair_bias:
|
||||||
|
n_head: 16
|
||||||
|
|
||||||
|
# TemplateEmbedder
|
||||||
|
template_embedder:
|
||||||
|
n_block: 2
|
||||||
|
raw_template_dim: 108
|
||||||
|
c: 64
|
||||||
|
p_drop: 0.25
|
||||||
|
|
||||||
|
# MSA module
|
||||||
|
msa_module:
|
||||||
|
n_block: 4
|
||||||
|
c_m: 64
|
||||||
|
p_drop_msa: 0.15
|
||||||
|
p_drop_pair: 0.25
|
||||||
|
msa_subsample_embedder:
|
||||||
|
num_sequences: 1024
|
||||||
|
dim_raw_msa: 34
|
||||||
|
c_s_inputs: ${model.net.c_s_inputs}
|
||||||
|
c_msa_embed: ${model.net.recycler.msa_module.c_m}
|
||||||
|
outer_product:
|
||||||
|
c_msa_embed: ${model.net.recycler.msa_module.c_m}
|
||||||
|
c_outer_product: 32
|
||||||
|
c_out: ${model.net.c_z}
|
||||||
|
msa_pair_weighted_averaging:
|
||||||
|
n_heads: 8
|
||||||
|
c_weighted_average: 32
|
||||||
|
c_msa_embed: ${model.net.recycler.msa_module.c_m}
|
||||||
|
c_z: ${model.net.c_z}
|
||||||
|
separate_gate_for_every_channel: true
|
||||||
|
msa_transition:
|
||||||
|
n: 4
|
||||||
|
c: ${model.net.recycler.msa_module.c_m}
|
||||||
|
triangle_multiplication_outgoing:
|
||||||
|
d_pair: ${model.net.c_z}
|
||||||
|
d_hidden: 128
|
||||||
|
bias: True
|
||||||
|
triangle_multiplication_incoming:
|
||||||
|
d_pair: ${model.net.c_z}
|
||||||
|
d_hidden: 128
|
||||||
|
bias: True
|
||||||
|
triangle_attention_starting:
|
||||||
|
d_pair: ${model.net.c_z}
|
||||||
|
n_head: 4
|
||||||
|
d_hidden: 32
|
||||||
|
p_drop: 0.0 # This does not do anything: TODO: Remove
|
||||||
|
triangle_attention_ending:
|
||||||
|
d_pair: ${model.net.c_z}
|
||||||
|
n_head: 4
|
||||||
|
d_hidden: 32
|
||||||
|
p_drop: 0.0 # This does not do anything; TODO: Remove
|
||||||
|
pair_transition:
|
||||||
|
n: 4
|
||||||
|
c: ${model.net.c_z}
|
||||||
|
|
||||||
|
# +---------- Diffusion module ----------+
|
||||||
|
diffusion_module:
|
||||||
|
sigma_data: 16
|
||||||
|
c_token: 768
|
||||||
|
f_pred: edm
|
||||||
|
diffusion_conditioning:
|
||||||
|
c_s_inputs: ${model.net.c_s_inputs}
|
||||||
|
c_t_embed: 256
|
||||||
|
relative_position_encoding:
|
||||||
|
r_max: 32
|
||||||
|
s_max: 2
|
||||||
|
atom_attention_encoder:
|
||||||
|
c_tokenpair: ${model.net.c_z}
|
||||||
|
c_atom_1d_features: 389
|
||||||
|
atom_1d_features:
|
||||||
|
- ref_pos
|
||||||
|
- ref_charge
|
||||||
|
- ref_mask
|
||||||
|
- ref_element
|
||||||
|
- ref_atom_name_chars
|
||||||
|
atom_transformer:
|
||||||
|
n_queries: 32
|
||||||
|
n_keys: 128
|
||||||
|
l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
|
||||||
|
diffusion_transformer:
|
||||||
|
n_block: 3
|
||||||
|
diffusion_transformer_block:
|
||||||
|
n_head: 4
|
||||||
|
no_residual_connection_between_attention_and_transition: true
|
||||||
|
kq_norm: false
|
||||||
|
broadcast_trunk_feats_on_1dim_old: false
|
||||||
|
use_chiral_features: true
|
||||||
|
diffusion_transformer:
|
||||||
|
n_block: 24
|
||||||
|
diffusion_transformer_block:
|
||||||
|
n_head: 16
|
||||||
|
no_residual_connection_between_attention_and_transition: true
|
||||||
|
kq_norm: true
|
||||||
|
atom_attention_decoder:
|
||||||
|
atom_transformer:
|
||||||
|
n_queries: 32
|
||||||
|
n_keys: 128
|
||||||
|
l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
|
||||||
|
diffusion_transformer:
|
||||||
|
n_block: 3
|
||||||
|
diffusion_transformer_block:
|
||||||
|
n_head: 4
|
||||||
|
no_residual_connection_between_attention_and_transition: true
|
||||||
|
kq_norm: false
|
||||||
|
distogram_head:
|
||||||
|
bins: 65
|
||||||
|
|
||||||
|
# +---------- Inference sampler ----------+
|
||||||
|
inference_sampler:
|
||||||
|
solver: "af3"
|
||||||
|
num_timesteps: 200
|
||||||
|
min_t: 0
|
||||||
|
max_t: 1
|
||||||
|
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||||
|
s_min: 4e-4
|
||||||
|
s_max: 160
|
||||||
|
p: 7
|
||||||
|
gamma_0: 0.8
|
||||||
|
gamma_min: 1.0
|
||||||
|
noise_scale: 1.003
|
||||||
|
step_scale: 1.5
|
||||||
45
configs/model/components/af3_net_with_confidence_head.yaml
Normal file
45
configs/model/components/af3_net_with_confidence_head.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
defaults:
|
||||||
|
- af3_net
|
||||||
|
|
||||||
|
# Model architecture
|
||||||
|
_target_: modelhub.model.AF3.AF3WithConfidence
|
||||||
|
|
||||||
|
# +---------- Mini rollout sampler ----------+
|
||||||
|
# From the AF-3 main text:
|
||||||
|
# > ...To remedy this, we developed a diffusion ‘rollout’ procedure for the full-structure prediction generation during training (using a larger step size than normal)
|
||||||
|
# They do not further elaborate on how they adjusted the step size during diffusion rollout, but this may be a fruitful area of exploration moving forwards
|
||||||
|
mini_rollout_sampler:
|
||||||
|
solver: "af3"
|
||||||
|
num_timesteps: 20 # 20 timesteps for the mini-rollout (vs. 200 for the full rollout during inference)
|
||||||
|
min_t: 0
|
||||||
|
max_t: 1
|
||||||
|
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||||
|
s_min: 4e-4
|
||||||
|
s_max: 160
|
||||||
|
p: 7
|
||||||
|
gamma_0: 0.8
|
||||||
|
gamma_min: 1.0
|
||||||
|
noise_scale: 1.003
|
||||||
|
step_scale: 1.5
|
||||||
|
|
||||||
|
# +---------- Confidence head architecture ----------+
|
||||||
|
confidence_head:
|
||||||
|
c_s: ${model.net.c_s}
|
||||||
|
c_z: ${model.net.c_z}
|
||||||
|
n_pairformer_layers: 4
|
||||||
|
pairformer:
|
||||||
|
p_drop: 0.25
|
||||||
|
triangle_multiplication:
|
||||||
|
d_hidden: 128
|
||||||
|
triangle_attention:
|
||||||
|
n_head: 4
|
||||||
|
d_hidden: 32
|
||||||
|
attention_pair_bias:
|
||||||
|
n_head: 16
|
||||||
|
n_bins_pae: 64
|
||||||
|
n_bins_pde: 64
|
||||||
|
n_bins_plddt: 50
|
||||||
|
n_bins_exp_resolved: 2
|
||||||
|
use_Cb_distances: False
|
||||||
|
use_af3_style_binning_and_final_layer_norms: True
|
||||||
|
symmetrize_Cb_logits: True
|
||||||
1
configs/model/components/ema.yaml
Normal file
1
configs/model/components/ema.yaml
Normal file
@@ -0,0 +1 @@
|
|||||||
|
decay: 0.999 # From AF-3
|
||||||
5
configs/model/optimizers/adam.yaml
Normal file
5
configs/model/optimizers/adam.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Optimizer
|
||||||
|
_target_: torch.optim.Adam
|
||||||
|
lr: 0 # Will be set by the scheduler (starts at 0, increasing to `base_lr`)
|
||||||
|
betas: [0.9, 0.95]
|
||||||
|
eps: 1.0e-8
|
||||||
6
configs/model/schedulers/af3.yaml
Normal file
6
configs/model/schedulers/af3.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
# Learning rate scheduler
|
||||||
|
_target_: modelhub.training.schedulers.AF3Scheduler
|
||||||
|
base_lr: 1.8e-3
|
||||||
|
warmup_steps: 1000
|
||||||
|
decay_factor: 0.95
|
||||||
|
decay_steps: 50000
|
||||||
23
configs/paths/data/default.yaml
Normal file
23
configs/paths/data/default.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# path to directory with training splits
|
||||||
|
pdb_data_dir: /projects/ml/datahub/dfs/af3_splits/2024_12_16/
|
||||||
|
|
||||||
|
# fb monomer distillation dataset
|
||||||
|
monomer_distillation_data_dir: /squash/af2_distillation_facebook/
|
||||||
|
monomer_distillation_parquet_dir: /projects/ml/datahub/dfs/distillation/af2_distillation_facebook
|
||||||
|
|
||||||
|
# path(s) to search for protein MSAs (for PDB datasets)
|
||||||
|
protein_msa_dirs:
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_protein_msas", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/missing_msas_through_2024_08_12", "extension": ".msa0.a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/net/scratch/mkazman/msa/validate_no_leak_taxid", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/net/scratch/mkazman/msa/missing_antibody_msas", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/net/scratch/mkazman/msa/post_training_cutoff_msas/processed_nested", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/net/scratch/mkazman/msa/post_training_cutoff_msas/extra_seqs_processed_nested", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
- {"dir": "/projects/msa/nvidia_renamed_with_seq_hash/maxseq_10k", "extension": ".a3m.gz", "directory_depth": 2}
|
||||||
|
|
||||||
|
# path(s) to search for RNA MSAs
|
||||||
|
rna_msa_dirs:
|
||||||
|
- {"dir": "/projects/msa/rf2aa_af3/rf2aa_paper_model_rna_msas", "extension": ".afa", "directory_depth": 0}
|
||||||
|
|
||||||
|
# path to save examples that fail during the Transform pipeline (null = do not save)
|
||||||
|
failed_examples_dir: null
|
||||||
21
configs/paths/default.yaml
Normal file
21
configs/paths/default.yaml
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
||||||
|
defaults:
|
||||||
|
- _self_
|
||||||
|
- data: default
|
||||||
|
|
||||||
|
# path to root directory (requires the `PROJECT_ROOT` environment variable to be set)
|
||||||
|
# NOTE: This variable is auto-set upon loading via `rootutils`
|
||||||
|
root_dir: ${oc.env:PROJECT_ROOT}
|
||||||
|
|
||||||
|
# where to store data (checkpoints, logs, etc.) of all experiments in general
|
||||||
|
# (this influences the output_dir in the hydra/default.yaml config)
|
||||||
|
# change this to e.g. /scratch if you are running larger experiments with lots lof logs, checkpoints, etc.
|
||||||
|
log_dir: ${.root_dir}/logs/
|
||||||
|
|
||||||
|
# path to output directory for this specific run, created dynamically by hydra
|
||||||
|
# path generation pattern is specified in `configs/hydra/default.yaml`
|
||||||
|
# use it to store all files generated during the run, like ckpts and metrics
|
||||||
|
output_dir: ${hydra:runtime.output_dir}
|
||||||
|
|
||||||
|
# path to working directory (auto-generated by hydra)
|
||||||
|
work_dir: ${hydra:runtime.cwd}
|
||||||
42
configs/train.yaml
Normal file
42
configs/train.yaml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# @package _global_
|
||||||
|
# ^ The "package" determines where the content of the config is placed in the output config
|
||||||
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||||
|
|
||||||
|
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
||||||
|
defaults:
|
||||||
|
- callbacks: default
|
||||||
|
- logger: csv
|
||||||
|
- trainer: ???
|
||||||
|
- paths: default
|
||||||
|
- datasets: ???
|
||||||
|
- dataloader: default
|
||||||
|
- hydra: default
|
||||||
|
- model: ???
|
||||||
|
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
# experiment configs allow for version control of specific hyperparameters
|
||||||
|
# e.g. best hyperparameters for given model and datamodule
|
||||||
|
- experiment: ???
|
||||||
|
|
||||||
|
# debug configs to add onto any experiment for quickly testing or debugging code
|
||||||
|
- debug: null
|
||||||
|
|
||||||
|
|
||||||
|
# DO NOT set these here. Set them in the relevant experiment config file.
|
||||||
|
# ... these are just here to ensure users always specify these fields in their experiment configs.
|
||||||
|
name: ???
|
||||||
|
tags: ???
|
||||||
|
|
||||||
|
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
|
||||||
|
# here.
|
||||||
|
# ... task name (determines the output directory path)
|
||||||
|
task_name: "train"
|
||||||
|
|
||||||
|
project: ??? # required for W&B logging
|
||||||
|
|
||||||
|
seed: 1
|
||||||
|
|
||||||
|
# Provide checkpoint path to resume training from a checkpoint
|
||||||
|
# NOTE: If using W&B, must also set the `id` and `resume` fields in the `logger/wandb` config
|
||||||
|
ckpt_path: null
|
||||||
20
configs/trainer/af3.yaml
Normal file
20
configs/trainer/af3.yaml
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
defaults:
|
||||||
|
- ddp
|
||||||
|
- loss: structure_prediction
|
||||||
|
- metrics: structure_prediction
|
||||||
|
|
||||||
|
_target_: modelhub.trainers.af3.AF3Trainer
|
||||||
|
validate_every_n_epochs: 1
|
||||||
|
max_epochs: 10_000
|
||||||
|
n_examples_per_epoch: 24000
|
||||||
|
prevalidate: True
|
||||||
|
|
||||||
|
# We must pre-specify the number of recycles during training so we can pre-sample recycles per batch consistently for each GPU
|
||||||
|
n_recycles_train: ${datasets.n_recycles_train}
|
||||||
|
|
||||||
|
clip_grad_max_norm: 10.0
|
||||||
|
|
||||||
|
output_dir: ${paths.output_dir}
|
||||||
|
checkpoint_every_n_epochs: 1
|
||||||
|
|
||||||
|
# precision: bf16-mixed # Mixed precision training with bfloat16 (currently does not work)
|
||||||
5
configs/trainer/af3_with_confidence.yaml
Normal file
5
configs/trainer/af3_with_confidence.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
defaults:
|
||||||
|
- af3
|
||||||
|
- override loss: structure_prediction_with_confidence
|
||||||
|
|
||||||
|
_target_: modelhub.trainers.af3.AF3TrainerWithConfidence
|
||||||
6
configs/trainer/cpu.yaml
Normal file
6
configs/trainer/cpu.yaml
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
defaults:
|
||||||
|
- af3
|
||||||
|
|
||||||
|
accelerator: cpu
|
||||||
|
devices_per_node: 1
|
||||||
|
num_nodes: 1
|
||||||
5
configs/trainer/ddp.yaml
Normal file
5
configs/trainer/ddp.yaml
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
strategy: ddp
|
||||||
|
|
||||||
|
accelerator: gpu
|
||||||
|
devices_per_node: 1
|
||||||
|
num_nodes: 1
|
||||||
29
configs/trainer/loss/losses/confidence_loss.yaml
Normal file
29
configs/trainer/loss/losses/confidence_loss.yaml
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
_target_: modelhub.loss.af3_confidence_loss.ConfidenceLoss
|
||||||
|
weight: 1.0
|
||||||
|
|
||||||
|
plddt:
|
||||||
|
weight: 1.0
|
||||||
|
n_bins: 50
|
||||||
|
max_value: 1.0
|
||||||
|
|
||||||
|
pae:
|
||||||
|
weight: 1.0
|
||||||
|
n_bins: 64
|
||||||
|
max_value: 32
|
||||||
|
|
||||||
|
pde:
|
||||||
|
weight: 1.0
|
||||||
|
n_bins: 64
|
||||||
|
max_value: 32
|
||||||
|
|
||||||
|
exp_resolved:
|
||||||
|
weight: 1.0
|
||||||
|
n_bins: 2
|
||||||
|
max_value: 1
|
||||||
|
|
||||||
|
# Adds to loss_dict true and predicted average plddt, pae, and pde per batch, also info about the spread and correlation of those values within a batch
|
||||||
|
log_statistics: True
|
||||||
|
|
||||||
|
rank_loss:
|
||||||
|
use_listnet_loss: False
|
||||||
|
weight: 0.0
|
||||||
9
configs/trainer/loss/losses/diffusion_loss.yaml
Normal file
9
configs/trainer/loss/losses/diffusion_loss.yaml
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
_target_: modelhub.loss.af3_losses.DiffusionLoss
|
||||||
|
weight: 4.0
|
||||||
|
sigma_data: ${model.net.diffusion_module.sigma_data}
|
||||||
|
alpha_dna: 5
|
||||||
|
alpha_rna: 5
|
||||||
|
alpha_ligand: 10
|
||||||
|
edm_lambda: True
|
||||||
|
se3_invariant_loss: True
|
||||||
|
clamp_diffusion_loss: False
|
||||||
2
configs/trainer/loss/losses/distogram_loss.yaml
Normal file
2
configs/trainer/loss/losses/distogram_loss.yaml
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
_target_: modelhub.loss.af3_losses.DistogramLoss
|
||||||
|
weight: 3e-2
|
||||||
4
configs/trainer/loss/structure_prediction.yaml
Normal file
4
configs/trainer/loss/structure_prediction.yaml
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
defaults:
|
||||||
|
# Note that the SmoothedLDDTLoss is included within the DiffusionLoss
|
||||||
|
- losses/diffusion_loss@diffusion_loss
|
||||||
|
- losses/distogram_loss@distogram_loss
|
||||||
@@ -0,0 +1,2 @@
|
|||||||
|
defaults:
|
||||||
|
- losses/confidence_loss@confidence_loss
|
||||||
8
configs/trainer/metrics/structure_prediction.yaml
Normal file
8
configs/trainer/metrics/structure_prediction.yaml
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
by_type_lddt:
|
||||||
|
_target_: modelhub.metrics.lddt.ByTypeLDDT
|
||||||
|
all_atom_lddt:
|
||||||
|
_target_: modelhub.metrics.lddt.AllAtomLDDT
|
||||||
|
distogram:
|
||||||
|
_target_: modelhub.metrics.distogram.DistogramLoss
|
||||||
|
distogram_comparisons:
|
||||||
|
_target_: modelhub.metrics.distogram.DistogramComparisons
|
||||||
49
configs/validate.yaml
Normal file
49
configs/validate.yaml
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
# @package _global_
|
||||||
|
# ^ The "package" determines where the content of the config is placed in the output config
|
||||||
|
# For more information about overriding configs, see: https://hydra.cc/docs/advanced/overriding_packages/#overriding-packages-using-the-defaults-list
|
||||||
|
|
||||||
|
# NOTE: order of defaults determines the order in which configs override each other (higher up items are overridden by lower items)
|
||||||
|
defaults:
|
||||||
|
- callbacks: default
|
||||||
|
- logger: csv
|
||||||
|
- trainer: ???
|
||||||
|
- paths: default
|
||||||
|
- datasets: ???
|
||||||
|
- dataloader: default
|
||||||
|
- hydra: default
|
||||||
|
- model: ???
|
||||||
|
# We must keep _self_ before experiment and debug to ensure that the experiment and debug configs can override
|
||||||
|
- _self_
|
||||||
|
|
||||||
|
# experiment configs allow for version control of specific hyperparameters
|
||||||
|
# e.g. best hyperparameters for given model and datamodule
|
||||||
|
- experiment: ???
|
||||||
|
|
||||||
|
# debug configs to add onto any experiment for quickly testing or debugging code
|
||||||
|
- debug: null
|
||||||
|
|
||||||
|
|
||||||
|
# DO NOT set these here. Set them in the relevant experiment config file.
|
||||||
|
# ... these are just here to ensure users always specify these fields in their experiment configs.
|
||||||
|
name: ???
|
||||||
|
tags: ???
|
||||||
|
|
||||||
|
# NOTE: These values will be overwritten by the experiment config if they are set there. They are just provided as defaults
|
||||||
|
# here.
|
||||||
|
# ... task name (determines the output directory path)
|
||||||
|
task_name: "validate"
|
||||||
|
|
||||||
|
project: ??? # required for W&B logging
|
||||||
|
|
||||||
|
seed: 1
|
||||||
|
|
||||||
|
# Dump CIF files for validation structures
|
||||||
|
callbacks:
|
||||||
|
dump_validation_structures_callback:
|
||||||
|
dump_predictions: True
|
||||||
|
one_model_per_file: False
|
||||||
|
dump_trajectories: False
|
||||||
|
|
||||||
|
# passing checkpoint path required for validation
|
||||||
|
# DO NOT set here; set in the experiment config file
|
||||||
|
ckpt_path: ???
|
||||||
@@ -6,53 +6,80 @@ channels:
|
|||||||
- conda-forge
|
- conda-forge
|
||||||
- defaults
|
- defaults
|
||||||
dependencies:
|
dependencies:
|
||||||
|
# Core dependencies
|
||||||
|
- pip
|
||||||
- python=3.11
|
- python=3.11
|
||||||
- cuda
|
- cuda
|
||||||
- pytorch=2.4
|
- pytorch=2.4
|
||||||
- pytorch-cuda=12.4
|
- pytorch-cuda=12.4
|
||||||
- pytorch-scatter>=2.1.0,<3
|
- pytorch-scatter>=2.1.0,<3
|
||||||
- lightning>=2.4.0,<2.5
|
- lightning>=2.4.0,<2.5
|
||||||
- pandas>=1.4.2,<2.3
|
# Small molecule libraries
|
||||||
- numpy>=1.25.0,<2.1
|
|
||||||
- scipy>=1.13.1,<2
|
|
||||||
- cytoolz>=0.12.3,<1
|
|
||||||
- biopython>=1.83,<2
|
|
||||||
- fire>=0.6.0,<1
|
|
||||||
- ruff>=0.6.2
|
|
||||||
- pytest-dotenv>=0.5.2,<1
|
|
||||||
- pytest-cov>=4.1.0,<5
|
|
||||||
- rdkit>=2024.3.5
|
- rdkit>=2024.3.5
|
||||||
- openbabel=3.1.1
|
- openbabel=3.1.1
|
||||||
- pip
|
|
||||||
- pip:
|
- pip:
|
||||||
- biotite>=1.1.0,<1.2
|
# Project-related dependencies
|
||||||
- seaborn>=0.13.0,<1
|
# ... generic tools
|
||||||
- loguru>=0.7.0,<1
|
- GitPython>=3.0.0,<4 # GitPython is a Python library used to interact with Git repositories
|
||||||
- beartype>=0.18.0,<1
|
- cython>=3.0.0,<4 # Cython compiler for C extensions
|
||||||
|
- cytoolz>=0.12.3,<1 # Cython-optimized tools for itertools and functional programming
|
||||||
|
- assertpy>=1.1.0,<2 # Assertions library
|
||||||
|
- tqdm>=4.65.0,<5 # Fast, extensible progress bar for loops and more
|
||||||
|
- rootutils>=1.0.7,<1.1 # Setting up the project root paths
|
||||||
|
- dm-tree>=0.1.6,<1 # Tree data structure from DeepMind
|
||||||
|
- deepdiff>=8.0.0,<9 # Deep difference and search of any Python object
|
||||||
|
# ... configuration & CLI
|
||||||
|
- fire>=0.6.0,<1 # Better argument parsing than argparse
|
||||||
|
- hydra-core>=1.3.0,<1.4 # Config management framework
|
||||||
|
- environs>=11.0.0,<12
|
||||||
|
# ... linear algebra, maths & ml
|
||||||
|
- numpy>=1.25.0,<2
|
||||||
|
- scipy>=1.13.1,<2
|
||||||
- einops>=0.8.0,<1
|
- einops>=0.8.0,<1
|
||||||
- einx>=0.1.0,<1
|
- einx>=0.1.0,<1
|
||||||
- debugpy>=1.8.5,<2
|
|
||||||
- cython>=3.0.0,<4
|
|
||||||
- pytest>=8.2.0,<9
|
|
||||||
- assertpy>=1.1.0,<2
|
|
||||||
- pre-commit>=3.7.1
|
|
||||||
- tqdm>=4.65.0,<5
|
|
||||||
- py3Dmol>=2.2.1,<3
|
|
||||||
- pyarrow>=17.0.0
|
|
||||||
- fastparquet>=2024.5.0
|
|
||||||
- ipykernel>=6.29.4,<7
|
|
||||||
- jaxtyping>=0.2.17,<1
|
|
||||||
- hydra-core>=1.3.0,<1.4
|
|
||||||
- wandb>=0.15.10,<1
|
|
||||||
- environs>=11.0.0,<12
|
|
||||||
- rootutils>=1.0.7,<1.1
|
|
||||||
- opt_einsum>=3.4.0,<4
|
- opt_einsum>=3.4.0,<4
|
||||||
- rich>=13.9.4,<14
|
- deepspeed>=0.15.1 # will be uninstalled by the apptainer's `spec` file, if pre-compiling
|
||||||
- msgpack>=1.1.0,<2
|
# ... data tools
|
||||||
- pymol-remote>=0.1.0
|
- pandas>=2.2,<2.3 # Data manipulation and analysis
|
||||||
- deepspeed>=0.15.1
|
- pyarrow==17.0.0 # Columnar data format for efficient data storage and processing
|
||||||
|
- fastparquet==2024.5.0 # Fast Parquet file format implementation
|
||||||
|
- seaborn>=0.13.0,<1
|
||||||
|
# ... bioinformatics
|
||||||
|
- biopython>=1.83,<2 # Collection of Python modules for bioinformatics
|
||||||
|
- py3Dmol>=2.2.1,<3 # Python wrapper for 3Dmol.js
|
||||||
|
- pymol-remote>=0.0.5 # Remote access to PyMOL from Python (has no dependencies)
|
||||||
- git+https://github.com/biotite-dev/biotite.git@fab175e7ba4608d9613f092ad4e080661c6cc816
|
- git+https://github.com/biotite-dev/biotite.git@fab175e7ba4608d9613f092ad4e080661c6cc816
|
||||||
- GitPython>=3.0.0,<4 # Git library for Python
|
- hydride==1.2.3 #biotite supported hydrogen addition
|
||||||
|
# ... logging
|
||||||
|
- wandb>=0.15.10,<1
|
||||||
|
- rich>=13.9.4,<14
|
||||||
|
|
||||||
# NOTE: After navigating to the datahub / cifutils directories, you can install the local package in editable mode with:
|
# Formatting & linting (only needed for development)
|
||||||
|
- ruff==0.8.3 # python linter & formatter
|
||||||
|
- pre-commit==3.7.1 # pre-commit hooks for formatting & linting
|
||||||
|
|
||||||
|
# Debugger & interactive tools (only needed for development)
|
||||||
|
- debugpy>=1.8.5,<2 # debugger for python
|
||||||
|
- ipykernel>=6.29.4,<7 # ipython kernel for jupyter
|
||||||
|
- icecream>=2.0.0,<3 # print debugging
|
||||||
|
- pymol-remote>=0.1.0 # Remote access to PyMOL from Python (has no dependencies)
|
||||||
|
- ipdb>=0.13.9 # IPython debugger
|
||||||
|
|
||||||
|
# Pytest plugins (only needed for development)
|
||||||
|
- pytest>=8.2.0,<9 # testing framework
|
||||||
|
- pytest-testmon>=2.1.1,<3 # run only tests related to changed code
|
||||||
|
- pytest-xdist>=3.6.1,<4 # run tests in parallel
|
||||||
|
- pytest-dotenv>=0.5.2,<1 # load environment variables from .env file
|
||||||
|
- pytest-cov>=4.1.0,<5 # generate coverage report
|
||||||
|
- pytest-benchmark>=5.0.0,<6 # benchmark tests for speed
|
||||||
|
|
||||||
|
# Typing & documentation (only needed for development)
|
||||||
|
- jaxtyping>=0.2.17,<1
|
||||||
|
- beartype>=0.18.0,<1
|
||||||
|
|
||||||
|
# NOTE: After navigating to the datahub / cifutils / modelhub directories, you can install the local package in editable mode with:
|
||||||
# pip install -e .
|
# pip install -e .
|
||||||
|
|
||||||
|
# NOTE: By default, DeepSpeed just-in-time compiles, which may take 3-4 minutes when first running the code on a new machine.
|
||||||
|
# It may be possible to pre-compile DeepSpeed within a `conda` environment; see: https://www.deepspeed.ai/tutorials/advanced-install/
|
||||||
|
# By default, the apptainers will have DeepSpeed pre-compiled, so when performance is a concern, it is recommended to use the apptainers.
|
||||||
|
|||||||
115
freeze_apptainer.spec
Normal file
115
freeze_apptainer.spec
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
Bootstrap: localimage
|
||||||
|
From: ./scripts/shebang/modelhub.sif
|
||||||
|
IncludeCmd: yes
|
||||||
|
# NOTE: This apptainer was written using apptainer version `1.1.6+2-g6808b5172-ipd`
|
||||||
|
|
||||||
|
%setup
|
||||||
|
# NOTE: This is executed on the host, not the container
|
||||||
|
# Ensure the token environment variables are set
|
||||||
|
set +x # ... supress bash output to avoid printing the tokens in the output
|
||||||
|
for var in GITHUB_USER GITHUB_TOKEN; do
|
||||||
|
if [ -z "$(eval echo \$$var)" ]; then
|
||||||
|
set -x
|
||||||
|
echo "ERROR: $var is not set. Please create a personal access token at"
|
||||||
|
echo " - GitHub: https://github.com/settings/tokens"
|
||||||
|
echo "Then set the following environment variables:"
|
||||||
|
echo " - GITHUB_USER"
|
||||||
|
echo " - GITHUB_TOKEN"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
set -x
|
||||||
|
# Create temporary `secrets.txt` file from host's environment variables in the container
|
||||||
|
# (which are otherwise not available in the %post section)
|
||||||
|
echo "Creating temporary secrets.txt file with access tokens in the container"
|
||||||
|
set +x
|
||||||
|
touch ${APPTAINER_ROOTFS}/secrets.txt
|
||||||
|
echo "GITHUB_USER=${GITHUB_USER}" >> ${APPTAINER_ROOTFS}/secrets.txt
|
||||||
|
echo "GITHUB_TOKEN=${GITHUB_TOKEN}" >> ${APPTAINER_ROOTFS}/secrets.txt
|
||||||
|
set -x
|
||||||
|
|
||||||
|
# Conditionally copy the project files based on the INSTALL_PROJECT environment variable
|
||||||
|
if [ ${INSTALL_PROJECT} = "true" ]; then
|
||||||
|
echo "Copying project files into the container..."
|
||||||
|
mkdir -p ${APPTAINER_ROOTFS}/opt/modelhub
|
||||||
|
rsync -av ./ ${APPTAINER_ROOTFS}/opt/modelhub/
|
||||||
|
else
|
||||||
|
echo "Skipping copying of project files."
|
||||||
|
fi
|
||||||
|
|
||||||
|
%post
|
||||||
|
# get os name
|
||||||
|
echo "Running on OS name $(lsb_release -i | awk '{ print $3 }')"
|
||||||
|
# get os version
|
||||||
|
echo "... in OS version $(lsb_release -r | awk '{ print $2 }')"
|
||||||
|
|
||||||
|
## SECRETS FILE
|
||||||
|
# Deal with secrets file
|
||||||
|
# ... verify that the secrets file is present on the container
|
||||||
|
if [ ! -e /secrets.txt ]; then
|
||||||
|
echo "ERROR: secrets.txt is not present on the container"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
# ... temporarily set the access token environment variables
|
||||||
|
# from the secrets file
|
||||||
|
echo "Exporting access tokens from secrets.txt"
|
||||||
|
set +x
|
||||||
|
export GITHUB_USER=$(grep GITHUB_USER /secrets.txt | cut -d '=' -f2)
|
||||||
|
export GITHUB_TOKEN=$(grep GITHUB_TOKEN /secrets.txt | cut -d '=' -f2)
|
||||||
|
set -x
|
||||||
|
# ... remove secrets file
|
||||||
|
rm secrets.txt
|
||||||
|
# ... verify that the secrets file is not present on the container
|
||||||
|
if [ -e /secrets.txt ]; then
|
||||||
|
echo "ERROR: secrets.txt is still present on the container"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "Verified that secrets.txt is not present on the container"
|
||||||
|
fi
|
||||||
|
# ... verify that the access token environment variables are set
|
||||||
|
set +x
|
||||||
|
for var in GITHUB_USER GITHUB_TOKEN; do
|
||||||
|
if [ -z "$(eval echo \$$var)" ]; then
|
||||||
|
echo "ERROR: $var is not set"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
set -x
|
||||||
|
echo "Verified that access tokens are set"
|
||||||
|
|
||||||
|
# Install additional libraries
|
||||||
|
|
||||||
|
# Cifutils
|
||||||
|
pip install git+https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/baker-laboratory/cifutils.git@v2.15.0
|
||||||
|
|
||||||
|
# Datahub
|
||||||
|
pip install git+https://${GITHUB_USER}:${GITHUB_TOKEN}@github.com/baker-laboratory/datahub.git@v3.14.1
|
||||||
|
|
||||||
|
# Modelhub (maybe)
|
||||||
|
if [ -d "/opt/modelhub" ]; then
|
||||||
|
echo "Installing the project from /opt/modelhub..."
|
||||||
|
pip install /opt/modelhub
|
||||||
|
else
|
||||||
|
echo "Skipping project installation. /opt/modelhub does not exist."
|
||||||
|
fi
|
||||||
|
|
||||||
|
## CLEANUP
|
||||||
|
# Unset the access token environment variables to avoid possibly
|
||||||
|
# leaking them in the container
|
||||||
|
unset GITHUB_USER
|
||||||
|
unset GITHUB_TOKEN
|
||||||
|
# ... verify that the access token environment variables are unset
|
||||||
|
set +x
|
||||||
|
for var in GITHUB_USER GITHUB_TOKEN; do
|
||||||
|
if [ -n "$(eval echo \$$var)" ]; then
|
||||||
|
set -x
|
||||||
|
echo "ERROR: $var is still set"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
set -x
|
||||||
|
echo "Verified that access tokens are unset."
|
||||||
|
|
||||||
|
%runscript
|
||||||
|
# NOTE: The %runscript is invoked when the container is run without specifying a different command.
|
||||||
|
exec python "$@"
|
||||||
547
notebooks/plot.ipynb
Normal file
547
notebooks/plot.ipynb
Normal file
@@ -0,0 +1,547 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Plotting AF3 Results with CSV Logger"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"This notebook provides examples of how to parse the results of the `CSVLogger` for both training and validation."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 1,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Imports for this notebook\n",
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import matplotlib.pyplot as plt\n",
|
||||||
|
"import seaborn as sns\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"from pathlib import Path"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Path to the log folder\n",
|
||||||
|
"LOG_PATH = Path(\"/path/to/logs\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Validation"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Workflows to plot and visualize validation metrics"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Plot Results for Most Recent Epoch"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"val_df = pd.read_csv(LOG_PATH / \"val_metrics/validation_output_all_epochs.csv\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def plot_validation_results_by_type(\n",
|
||||||
|
" df: pd.DataFrame,\n",
|
||||||
|
" ignore_zeros: bool = False,\n",
|
||||||
|
") -> None:\n",
|
||||||
|
" \"\"\"Visualize metrics across all datasets.\n",
|
||||||
|
"\n",
|
||||||
|
" NOTE: Ensure that you first subset the DataFrame to only include the desired epoch.\n",
|
||||||
|
" \n",
|
||||||
|
" Args:\n",
|
||||||
|
" df: Combined DataFrame containing metrics data\n",
|
||||||
|
" ignore_zeros: Whether to treat zero values as missing data\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # (Copy the DataFrame to avoid modifying the original)\n",
|
||||||
|
" _df = df.copy()\n",
|
||||||
|
"\n",
|
||||||
|
" # ... subset to only include the desired columns\n",
|
||||||
|
" _df = _df[[\"dataset\", \"by_type_lddt.type\", \"by_type_lddt.best_of_1_lddt\", \"by_type_lddt.best_of_5_lddt\"]]\n",
|
||||||
|
" _df = _df.dropna()\n",
|
||||||
|
"\n",
|
||||||
|
" if ignore_zeros:\n",
|
||||||
|
" _df = _df.replace(0, pd.NA)\n",
|
||||||
|
"\n",
|
||||||
|
" # Prepare data\n",
|
||||||
|
" melted = pd.melt(_df,\n",
|
||||||
|
" id_vars=[\"dataset\", \"by_type_lddt.type\"],\n",
|
||||||
|
" value_vars=[\"by_type_lddt.best_of_1_lddt\", \"by_type_lddt.best_of_5_lddt\"],\n",
|
||||||
|
" var_name='metric',\n",
|
||||||
|
" value_name='lddt')\n",
|
||||||
|
" \n",
|
||||||
|
" # Create visualization\n",
|
||||||
|
" sns.set(style=\"whitegrid\", font_scale=1.1)\n",
|
||||||
|
" plt.figure(figsize=(15, 8))\n",
|
||||||
|
" \n",
|
||||||
|
" g = sns.catplot(\n",
|
||||||
|
" data=melted,\n",
|
||||||
|
" x='by_type_lddt.type',\n",
|
||||||
|
" y='lddt',\n",
|
||||||
|
" hue='metric',\n",
|
||||||
|
" col='dataset',\n",
|
||||||
|
" kind='bar',\n",
|
||||||
|
" estimator='mean', # Explicitly set to mean aggregation\n",
|
||||||
|
" ci=None, # Disable confidence intervals\n",
|
||||||
|
" height=6,\n",
|
||||||
|
" aspect=2,\n",
|
||||||
|
" sharey=False,\n",
|
||||||
|
" legend_out=False\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Annotate bars with values\n",
|
||||||
|
" for ax in g.axes.flat:\n",
|
||||||
|
" for p in ax.patches:\n",
|
||||||
|
" ax.annotate(f\"{p.get_height():.2f}\",\n",
|
||||||
|
" (p.get_x() + p.get_width() / 2., p.get_height()),\n",
|
||||||
|
" ha='center', va='center',\n",
|
||||||
|
" fontsize=10,\n",
|
||||||
|
" color='black',\n",
|
||||||
|
" xytext=(0, 7),\n",
|
||||||
|
" textcoords='offset points')\n",
|
||||||
|
" \n",
|
||||||
|
" ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')\n",
|
||||||
|
" ax.set_xlabel('')\n",
|
||||||
|
" ax.set_ylabel('LDDT')\n",
|
||||||
|
"\n",
|
||||||
|
" plt.suptitle(f'Model Performance Comparison', y=1.02)\n",
|
||||||
|
" plt.tight_layout()\n",
|
||||||
|
" plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"current_epoch_df = val_df[val_df[\"epoch\"] == val_df[\"epoch\"].max()]\n",
|
||||||
|
"plot_validation_results_by_type(current_epoch_df)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Plot Validation Curves"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def plot_metric_trend(\n",
|
||||||
|
" df: pd.DataFrame,\n",
|
||||||
|
" metric_type: str,\n",
|
||||||
|
" dataset: str | None = None,\n",
|
||||||
|
" ignore_zeros: bool = False,\n",
|
||||||
|
" last_n_epochs: int | None = None\n",
|
||||||
|
") -> None:\n",
|
||||||
|
" \"\"\"Plot best-of-1 vs best-of-5 trends across epochs for a specific metric type.\n",
|
||||||
|
" \n",
|
||||||
|
" Args:\n",
|
||||||
|
" df: Combined metrics DataFrame\n",
|
||||||
|
" metric_type: The 'type' to plot (e.g., 'protein-ligand')\n",
|
||||||
|
" dataset_filter: Optional specific dataset to filter\n",
|
||||||
|
" ignore_zeros: Whether to exclude zero values\n",
|
||||||
|
" last_n_epochs: Optional number of most recent epochs to plot\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Filter data\n",
|
||||||
|
" filtered = df[df['by_type_lddt.type'] == metric_type]\n",
|
||||||
|
" \n",
|
||||||
|
" if dataset:\n",
|
||||||
|
" filtered = filtered[filtered['dataset'] == dataset]\n",
|
||||||
|
" \n",
|
||||||
|
" if ignore_zeros:\n",
|
||||||
|
" filtered = filtered.replace(0, pd.NA).dropna(\n",
|
||||||
|
" subset=['by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" if filtered.empty:\n",
|
||||||
|
" raise ValueError(f\"No data found for {metric_type} in dataset {dataset or 'any'}\")\n",
|
||||||
|
"\n",
|
||||||
|
" if last_n_epochs:\n",
|
||||||
|
" max_epoch = filtered['epoch'].max()\n",
|
||||||
|
" filtered = filtered[filtered['epoch'] > (max_epoch - last_n_epochs)]\n",
|
||||||
|
"\n",
|
||||||
|
" # Aggregate by epoch\n",
|
||||||
|
" trend_data = filtered.groupby('epoch').agg({\n",
|
||||||
|
" 'by_type_lddt.best_of_1_lddt': 'mean',\n",
|
||||||
|
" 'by_type_lddt.best_of_5_lddt': 'mean'\n",
|
||||||
|
" }).reset_index()\n",
|
||||||
|
"\n",
|
||||||
|
" # Create plot\n",
|
||||||
|
" plt.figure(figsize=(12, 6))\n",
|
||||||
|
" sns.set_style(\"whitegrid\")\n",
|
||||||
|
"\n",
|
||||||
|
" # Plot lines with markers\n",
|
||||||
|
" sns.lineplot(\n",
|
||||||
|
" data=trend_data,\n",
|
||||||
|
" x='epoch',\n",
|
||||||
|
" y='by_type_lddt.best_of_1_lddt',\n",
|
||||||
|
" color='#1f77b4',\n",
|
||||||
|
" label='Best of 1',\n",
|
||||||
|
" marker='o',\n",
|
||||||
|
" markersize=8,\n",
|
||||||
|
" linewidth=2\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" sns.lineplot(\n",
|
||||||
|
" data=trend_data,\n",
|
||||||
|
" x='epoch',\n",
|
||||||
|
" y='by_type_lddt.best_of_5_lddt',\n",
|
||||||
|
" color='#ff7f0e',\n",
|
||||||
|
" label='Best of 5',\n",
|
||||||
|
" marker='s',\n",
|
||||||
|
" markersize=8,\n",
|
||||||
|
" linewidth=2\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" # Style plot\n",
|
||||||
|
" plt.title(f\"{metric_type} LDDT Trends\\nDataset: {dataset or 'All'}\")\n",
|
||||||
|
" plt.xlabel(\"Epoch\")\n",
|
||||||
|
" plt.ylabel(\"Average LDDT\")\n",
|
||||||
|
" plt.legend(title=\"Strategy\")\n",
|
||||||
|
" plt.grid(alpha=0.3)\n",
|
||||||
|
"\n",
|
||||||
|
" # Add padding to autoscaled y-axis\n",
|
||||||
|
" plt.ylim(top=min(1.0, plt.ylim()[1] * 1.05)) # Cap at 1.0 if near upper bound\n",
|
||||||
|
" plt.ylim(bottom=max(0.0, plt.ylim()[0] * 0.95)) # Floor at 0.0 if near lower bound\n",
|
||||||
|
"\n",
|
||||||
|
" # Set x-axis to show only whole numbers\n",
|
||||||
|
" plt.xticks(ticks=trend_data['epoch'], labels=trend_data['epoch'].astype(int))\n",
|
||||||
|
" \n",
|
||||||
|
" # Add final value annotations\n",
|
||||||
|
" last_values = trend_data.iloc[-1]\n",
|
||||||
|
" plt.text(\n",
|
||||||
|
" 0.95, 0.15,\n",
|
||||||
|
" f\"Final Bo1: {last_values['by_type_lddt.best_of_1_lddt']:.2f}\\nFinal Bo5: {last_values['by_type_lddt.best_of_5_lddt']:.2f}\",\n",
|
||||||
|
" ha='right',\n",
|
||||||
|
" va='bottom',\n",
|
||||||
|
" transform=plt.gca().transAxes,\n",
|
||||||
|
" bbox=dict(facecolor='white', alpha=0.8)\n",
|
||||||
|
" )\n",
|
||||||
|
"\n",
|
||||||
|
" plt.tight_layout()\n",
|
||||||
|
" plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plot_metric_trend(val_df, 'protein-ligand', dataset=\"af3_validation\", ignore_zeros=False, last_n_epochs=4)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Visualize Outliers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Identify and visualize outliers"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"structures_path = f\"{LOG_PATH}/val_structures\""
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def get_worst_examples(\n",
|
||||||
|
" df: pd.DataFrame, \n",
|
||||||
|
" metric_type: str, \n",
|
||||||
|
" dataset: str = None, \n",
|
||||||
|
" epoch: int = None\n",
|
||||||
|
") -> list:\n",
|
||||||
|
" \"\"\"Return example IDs sorted by worst performance for a metric type at specific epoch.\"\"\"\n",
|
||||||
|
" filtered = df[df['by_type_lddt.type'] == metric_type]\n",
|
||||||
|
" \n",
|
||||||
|
" if dataset:\n",
|
||||||
|
" filtered = filtered[filtered['dataset'] == dataset]\n",
|
||||||
|
" \n",
|
||||||
|
" # Use latest epoch if none specified\n",
|
||||||
|
" target_epoch = epoch if epoch is not None else filtered['epoch'].max()\n",
|
||||||
|
" filtered = filtered[filtered['epoch'] == target_epoch]\n",
|
||||||
|
" \n",
|
||||||
|
" return (\n",
|
||||||
|
" filtered[['example_id', 'by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']]\n",
|
||||||
|
" .assign(worst_score=lambda x: x[['by_type_lddt.best_of_1_lddt', 'by_type_lddt.best_of_5_lddt']].min(axis=1))\n",
|
||||||
|
" .sort_values('worst_score')\n",
|
||||||
|
" ['example_id']\n",
|
||||||
|
" .tolist()\n",
|
||||||
|
" )"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from datahub.common import parse_example_id\n",
|
||||||
|
"from cifutils.utils.visualize import view\n",
|
||||||
|
"from cifutils import parse\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = \"af3-validation\"\n",
|
||||||
|
"metric = \"protein-ligand\"\n",
|
||||||
|
"latest_epoch = val_df['epoch'].max()\n",
|
||||||
|
"\n",
|
||||||
|
"worst_protein_ligand_examples = get_worst_examples(val_df, metric, dataset=\"af3_validation\", epoch=latest_epoch)\n",
|
||||||
|
"\n",
|
||||||
|
"# Visualize the worst example\n",
|
||||||
|
"parsed_id = parse_example_id(worst_protein_ligand_examples[0])\n",
|
||||||
|
"\n",
|
||||||
|
"# Find the worst example in the structures directory\n",
|
||||||
|
"structure_path_for_epoch = Path(structures_path) / f\"epoch_{latest_epoch}\" / dataset\n",
|
||||||
|
"\n",
|
||||||
|
"if structure_path_for_epoch.exists():\n",
|
||||||
|
" example_path = next(structure_path_for_epoch.glob(f\"*{parsed_id['pdb_id']}_{parsed_id['assembly_id']}*\"))\n",
|
||||||
|
"\n",
|
||||||
|
" # ... and visualize\n",
|
||||||
|
" atom_array = parse(example_path)\n",
|
||||||
|
" view(atom_array[\"assemblies\"][\"1\"][0])\n",
|
||||||
|
"else:\n",
|
||||||
|
" print(f\"No structure found for {parsed_id}\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Training"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"csv_path = f\"{LOG_PATH}/lightning_logs/version_0/metrics.csv\"\n",
|
||||||
|
"_df = pd.read_csv(csv_path)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training Curves"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Subset to the relevant columns\n",
|
||||||
|
"mean_cols = [\"train/batch_mean/diffusion_loss\", \"train/batch_mean/smoothed_lddt_loss\", \"train/batch_mean/total_loss\", \"train/batch_mean/distogram_loss\"]\n",
|
||||||
|
"per_structure_cols = [\"train/per_structure/t\", \"train/per_structure/diffusion_loss\", \"train/per_structure/smoothed_lddt_loss\"]\n",
|
||||||
|
"train_df = _df[mean_cols + per_structure_cols + [\"step\", \"train/learning_rate\"]]\n",
|
||||||
|
"\n",
|
||||||
|
"# Remove rows with all NaN values except for the 'step' column\n",
|
||||||
|
"check_cols = [col for col in train_df.columns if col != 'step']\n",
|
||||||
|
"train_df = train_df.dropna(how='all', subset=check_cols)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def plot_training_metrics(train_df: pd.DataFrame) -> None:\n",
|
||||||
|
" \"\"\"Plot all training metrics from a DataFrame.\"\"\"\n",
|
||||||
|
" \n",
|
||||||
|
" processed = (\n",
|
||||||
|
" train_df\n",
|
||||||
|
" .groupby('step', as_index=False)\n",
|
||||||
|
" .mean()\n",
|
||||||
|
" .melt(id_vars='step', var_name='metric')\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # Create visualization\n",
|
||||||
|
" plt.figure(figsize=(12, 8))\n",
|
||||||
|
" sns.set_style(\"whitegrid\")\n",
|
||||||
|
" \n",
|
||||||
|
" g = sns.FacetGrid(\n",
|
||||||
|
" processed,\n",
|
||||||
|
" col='metric',\n",
|
||||||
|
" col_wrap=3,\n",
|
||||||
|
" height=4,\n",
|
||||||
|
" aspect=1.5,\n",
|
||||||
|
" sharey=False\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" g.map(sns.lineplot, 'step', 'value', color='#2ca02c')\n",
|
||||||
|
" g.set_titles(\"{col_name}\")\n",
|
||||||
|
" g.set_axis_labels(\"Training Step\", \"Value\")\n",
|
||||||
|
" \n",
|
||||||
|
" # Special handling for learning rate\n",
|
||||||
|
" if 'learning_rate' in processed['metric'].unique():\n",
|
||||||
|
" g.axes[-1].set_yscale('log')\n",
|
||||||
|
" \n",
|
||||||
|
" plt.tight_layout()\n",
|
||||||
|
" plt.show()"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"mean_df = train_df[mean_cols + [\"step\"]].copy()\n",
|
||||||
|
"mean_df = mean_df.dropna(subset=mean_cols)\n",
|
||||||
|
"plot_training_metrics(train_df[mean_cols + [\"step\", \"train/learning_rate\"]])"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## Training Loss by T"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def plot_loss_scatter_by_t(\n",
|
||||||
|
" train_df: pd.DataFrame,\n",
|
||||||
|
" loss_column: str,\n",
|
||||||
|
" t_column: str = 'train/per_structure/t',\n",
|
||||||
|
" n_steps: int = 1000\n",
|
||||||
|
") -> None:\n",
|
||||||
|
" \"\"\"Plot loss values as a scatter plot against train/t values for the most recent N steps with a log-scaled x-axis.\n",
|
||||||
|
" \n",
|
||||||
|
" Args:\n",
|
||||||
|
" train_df: DataFrame containing training metrics\n",
|
||||||
|
" loss_column: Name of loss column to plot (e.g., 'train/total_loss')\n",
|
||||||
|
" t_column: Name of the column representing 't' values\n",
|
||||||
|
" n_steps: Number of recent training steps to analyze (default: 1000)\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" train_df = train_df.copy()\n",
|
||||||
|
"\n",
|
||||||
|
" assert loss_column in train_df.columns, f\"Loss column '{loss_column}' not found in DataFrame\"\n",
|
||||||
|
"\n",
|
||||||
|
" # Get most recent steps\n",
|
||||||
|
" unique_steps = train_df['step'].dropna().unique()\n",
|
||||||
|
" \n",
|
||||||
|
" # Get actual number of available steps\n",
|
||||||
|
" n_steps = min(n_steps, len(unique_steps))\n",
|
||||||
|
" latest_steps = np.sort(unique_steps)[-n_steps:]\n",
|
||||||
|
" \n",
|
||||||
|
" # Filter recent data\n",
|
||||||
|
" recent_data = train_df[train_df['step'].isin(latest_steps)]\n",
|
||||||
|
"\n",
|
||||||
|
" # Subset to relevant columns and remove rows with NaN values\n",
|
||||||
|
" recent_data = recent_data[[t_column, loss_column]]\n",
|
||||||
|
" \n",
|
||||||
|
" # Create scatter plot\n",
|
||||||
|
" plt.figure(figsize=(10, 6)) # Fixed figure size\n",
|
||||||
|
" sns.set_style(\"whitegrid\")\n",
|
||||||
|
" \n",
|
||||||
|
" sns.scatterplot(\n",
|
||||||
|
" data=recent_data,\n",
|
||||||
|
" x=t_column,\n",
|
||||||
|
" y=loss_column,\n",
|
||||||
|
" color='#2ca02c',\n",
|
||||||
|
" alpha=0.6\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" plt.xscale('log') # Set x-axis to logarithmic scale\n",
|
||||||
|
" plt.title(f\"{loss_column} vs {t_column} (Log Scale)\\n(Last {n_steps} Steps)\")\n",
|
||||||
|
" plt.xlabel(t_column)\n",
|
||||||
|
" plt.ylabel(loss_column)\n",
|
||||||
|
" \n",
|
||||||
|
" plt.tight_layout()\n",
|
||||||
|
" plt.show()\n"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"plot_loss_scatter_by_t(train_df, 'train/per_structure/smoothed_lddt_loss', n_steps=1000)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "modelhub",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.11.11"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 2
|
||||||
|
}
|
||||||
@@ -20,19 +20,19 @@ build-backend = "hatchling.build"
|
|||||||
source = "vcs"
|
source = "vcs"
|
||||||
|
|
||||||
[tool.hatch.build.hooks.vcs]
|
[tool.hatch.build.hooks.vcs]
|
||||||
version-file = "rf2aa/version.py"
|
version-file = "src/modelhub/version.py"
|
||||||
|
|
||||||
[tool.hatch.metadata]
|
[tool.hatch.metadata]
|
||||||
allow-direct-references = true
|
allow-direct-references = true
|
||||||
|
|
||||||
[tool.hatch.build.targets.wheel]
|
[tool.hatch.build.targets.wheel]
|
||||||
packages = ["rf2aa"]
|
packages = ["src/modelhub"]
|
||||||
|
|
||||||
# Formatting & linting settings -------------------------------------------------------
|
# Formatting & linting settings -------------------------------------------------------
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 88
|
line-length = 88
|
||||||
indent-width = 4
|
indent-width = 4
|
||||||
target-version = "py311"
|
target-version = "py310"
|
||||||
exclude = [
|
exclude = [
|
||||||
".bzr",
|
".bzr",
|
||||||
".direnv",
|
".direnv",
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 1.1 MiB |
@@ -1,479 +0,0 @@
|
|||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
import tree
|
|
||||||
from icecream import ic
|
|
||||||
from lightning import LightningModule, Trainer
|
|
||||||
from lightning.pytorch.callbacks import Callback
|
|
||||||
from scipy.stats import norm
|
|
||||||
|
|
||||||
from rf2aa import pymol, pymol_tools
|
|
||||||
from rf2aa.chemical import ChemicalData as ChemData
|
|
||||||
from rf2aa.debug import pretty_describe_dict
|
|
||||||
from rf2aa.loss.af3_losses import Loss
|
|
||||||
from rf2aa.pymol import cmd
|
|
||||||
from rf2aa.util import writepdb
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def flatten_dictionary(dictionary, parent_key="", separator="."):
|
|
||||||
flattened_dict = {}
|
|
||||||
for key, value in dictionary.items():
|
|
||||||
new_key = f"{parent_key}{separator}{key}" if parent_key else key
|
|
||||||
if isinstance(value, dict):
|
|
||||||
flattened_dict.update(flatten_dictionary(value, new_key, separator))
|
|
||||||
else:
|
|
||||||
flattened_dict[new_key] = value
|
|
||||||
return flattened_dict
|
|
||||||
|
|
||||||
|
|
||||||
class LogMetrics(Callback):
|
|
||||||
def __init__(self, config):
|
|
||||||
super().__init__()
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
def on_train_batch_end(
|
|
||||||
self,
|
|
||||||
trainer: Trainer,
|
|
||||||
pl_module: LightningModule,
|
|
||||||
outputs,
|
|
||||||
batch,
|
|
||||||
batch_idx: int,
|
|
||||||
) -> None:
|
|
||||||
logger.debug("on_train_batch_end outputs:\n" + pretty_describe_dict(outputs))
|
|
||||||
|
|
||||||
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
|
|
||||||
o = {}
|
|
||||||
stratifications = defaultdict(list)
|
|
||||||
for metric in [diffusion_losses, lddt_metrics]:
|
|
||||||
metric_d, stratification_keys = metric(self.config, outputs)
|
|
||||||
stratifications[stratification_keys].extend(metric_d.keys())
|
|
||||||
o.update(metric_d)
|
|
||||||
|
|
||||||
o["t"] = outputs["t"]
|
|
||||||
o["t_quantile_4"] = get_t_quantiles(
|
|
||||||
outputs["t"], self.config.loss.sigma_data, 4
|
|
||||||
)
|
|
||||||
df = pd.DataFrame.from_dict(o)
|
|
||||||
df = df.reindex(sorted(df.columns), axis=1)
|
|
||||||
ic(o)
|
|
||||||
(D,) = outputs["t"].shape
|
|
||||||
df["batch_idx"] = batch_idx
|
|
||||||
df["data_idx"] = np.arange(D)
|
|
||||||
df["global_step"] = trainer.global_step
|
|
||||||
trainer.logger.log_df(df, stratifications=stratifications)
|
|
||||||
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
|
|
||||||
|
|
||||||
def on_validation_batch_end(
|
|
||||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
|
||||||
):
|
|
||||||
outputs = tree.map_structure(lambda x: x.detach().cpu(), outputs)
|
|
||||||
o = {}
|
|
||||||
for metric in [lddt_metrics, lddt_metrics_null, diffusion_losses]:
|
|
||||||
metric_d, stratification_keys = metric(self.config, outputs)
|
|
||||||
o.update(metric_d)
|
|
||||||
df = pd.DataFrame.from_dict(o)
|
|
||||||
df = df.reindex(sorted(df.columns), axis=1)
|
|
||||||
ic(o)
|
|
||||||
df["batch_idx"] = batch_idx
|
|
||||||
df["global_step"] = trainer.global_step
|
|
||||||
|
|
||||||
trainer.logger.log_df(df, stratifications={})
|
|
||||||
return super().on_validation_batch_end(
|
|
||||||
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def lddt_metrics(config, outputs):
|
|
||||||
# compute distances between ground truth atoms
|
|
||||||
ground_truth_distances = torch.cdist(outputs["X_gt_L"], outputs["X_gt_L"])
|
|
||||||
# compute distances between predicted atoms
|
|
||||||
predicted_distances = torch.cdist(outputs["X_L"], outputs["X_L"])
|
|
||||||
# compute LDDT score for each pair of distances
|
|
||||||
difference_distances = torch.abs(ground_truth_distances - predicted_distances)
|
|
||||||
lddt_matrix = torch.zeros_like(difference_distances)
|
|
||||||
lddt_matrix = (
|
|
||||||
0.25 * (difference_distances < 4.0)
|
|
||||||
+ 0.25 * (difference_distances < 2.0)
|
|
||||||
+ 0.25 * (difference_distances < 1.0)
|
|
||||||
+ 0.25 * (difference_distances < 0.5)
|
|
||||||
)
|
|
||||||
# remove unresolved atoms, atoms within same residue
|
|
||||||
is_real_atom = ChemData().heavyatom_mask.to(outputs["seq"].device)[outputs["seq"]]
|
|
||||||
is_resolved_atom_L = outputs["crd_mask_I"][is_real_atom]
|
|
||||||
is_unresolved_distance_LL = (
|
|
||||||
is_resolved_atom_L[..., None] & is_resolved_atom_L[None, ...]
|
|
||||||
)
|
|
||||||
in_same_residue_LL = (
|
|
||||||
outputs["f"]["tok_idx"][:, None] == outputs["f"]["tok_idx"][None, :]
|
|
||||||
)
|
|
||||||
|
|
||||||
lddt_values = {}
|
|
||||||
for mask, mask_type in get_lddt_masks(outputs):
|
|
||||||
mask = mask & is_unresolved_distance_LL & ~in_same_residue_LL
|
|
||||||
lddt = torch.div(lddt_matrix[:, mask].sum(dim=(-1)), mask.sum(dim=(-1, -2)))
|
|
||||||
lddt_values[f"lddt_{mask_type}"] = lddt
|
|
||||||
return lddt_values, ("t_quantile_4",)
|
|
||||||
|
|
||||||
|
|
||||||
def lddt_metrics_null(config, outputs):
|
|
||||||
# compute distances between ground truth atoms
|
|
||||||
ground_truth_distances = torch.cdist(outputs["X_gt_L"], outputs["X_gt_L"])
|
|
||||||
# compute distances between predicted atoms
|
|
||||||
t = outputs["t"]
|
|
||||||
X_noisy_L = outputs["X_noisy_L"]
|
|
||||||
sigma_data = 16
|
|
||||||
|
|
||||||
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[..., None, None] * X_noisy_L
|
|
||||||
|
|
||||||
predicted_distances = torch.cdist(null_pred, null_pred)
|
|
||||||
# compute LDDT score for each pair of distances
|
|
||||||
difference_distances = torch.abs(ground_truth_distances - predicted_distances)
|
|
||||||
lddt_matrix = torch.zeros_like(difference_distances)
|
|
||||||
lddt_matrix = (
|
|
||||||
0.25 * (difference_distances < 4.0)
|
|
||||||
+ 0.25 * (difference_distances < 2.0)
|
|
||||||
+ 0.25 * (difference_distances < 1.0)
|
|
||||||
+ 0.25 * (difference_distances < 0.5)
|
|
||||||
)
|
|
||||||
# remove unresolved atoms, atoms within same residue
|
|
||||||
is_real_atom = ChemData().heavyatom_mask[outputs["seq"]]
|
|
||||||
is_resolved_atom_L = outputs["crd_mask_I"][is_real_atom]
|
|
||||||
is_unresolved_distance_LL = (
|
|
||||||
is_resolved_atom_L[..., None] & is_resolved_atom_L[None, ...]
|
|
||||||
)
|
|
||||||
in_same_residue_LL = (
|
|
||||||
outputs["f"]["tok_idx"][:, None] == outputs["f"]["tok_idx"][None, :]
|
|
||||||
)
|
|
||||||
|
|
||||||
lddt_values = {}
|
|
||||||
for mask, mask_type in get_lddt_masks(outputs):
|
|
||||||
mask = mask & is_unresolved_distance_LL & ~in_same_residue_LL
|
|
||||||
lddt = torch.div(lddt_matrix[:, mask].sum(dim=(-1)), mask.sum(dim=(-1, -2)))
|
|
||||||
lddt_values[f"lddt_{mask_type}_null"] = lddt
|
|
||||||
return lddt_values, ("t_quantile_4",)
|
|
||||||
|
|
||||||
|
|
||||||
def get_lddt_masks(outputs):
|
|
||||||
D, L = outputs["X_L"].shape[:2]
|
|
||||||
|
|
||||||
tok_idx = outputs["f"]["tok_idx"]
|
|
||||||
is_protein_L = outputs["f"]["is_protein"][tok_idx]
|
|
||||||
is_dna_L = outputs["f"]["is_dna"][tok_idx]
|
|
||||||
is_rna_L = outputs["f"]["is_rna"][tok_idx]
|
|
||||||
is_ligand_L = outputs["f"]["is_ligand"][tok_idx]
|
|
||||||
asym_id_L = outputs["f"]["asym_id"][tok_idx]
|
|
||||||
same_chain_LL = asym_id_L[:, None] == asym_id_L[None, :]
|
|
||||||
for mask_type in [
|
|
||||||
"all",
|
|
||||||
"protein_intra",
|
|
||||||
"protein_inter",
|
|
||||||
"ligand_intra",
|
|
||||||
"ligand_inter",
|
|
||||||
]:
|
|
||||||
if mask_type == "all":
|
|
||||||
mask = torch.ones((L, L), dtype=torch.bool, device=outputs["X_L"].device)
|
|
||||||
elif mask_type == "protein_intra":
|
|
||||||
mask = is_protein_L[:, None] & is_protein_L[None, :]
|
|
||||||
mask *= same_chain_LL
|
|
||||||
elif mask_type == "protein_inter":
|
|
||||||
mask = is_protein_L[:, None] & is_protein_L[None, :]
|
|
||||||
mask *= ~same_chain_LL
|
|
||||||
elif mask_type == "ligand_intra":
|
|
||||||
mask = is_ligand_L[:, None] & is_ligand_L[None, :]
|
|
||||||
mask *= same_chain_LL
|
|
||||||
elif mask_type == "ligand_inter":
|
|
||||||
mask = is_ligand_L[:, None] & is_ligand_L[None, :]
|
|
||||||
mask *= ~same_chain_LL
|
|
||||||
elif mask_type == "protein_ligand_inter":
|
|
||||||
mask = is_protein_L[:, None] & is_ligand_L[None, :]
|
|
||||||
yield (mask, mask_type)
|
|
||||||
|
|
||||||
|
|
||||||
def diffusion_losses(config, outputs):
|
|
||||||
loss = Loss(**config.loss)
|
|
||||||
|
|
||||||
loss_dict_by_type = {}
|
|
||||||
t = outputs["t"]
|
|
||||||
X_noisy_L = outputs["X_noisy_L"]
|
|
||||||
sigma_data = 16
|
|
||||||
|
|
||||||
null_pred = (sigma_data**2 / (sigma_data**2 + t**2))[..., None, None] * X_noisy_L
|
|
||||||
|
|
||||||
sigma_gt = torch.var(outputs["X_gt_L"], dim=(1, 2)) ** 0.5
|
|
||||||
for input_type, X_L in (
|
|
||||||
("pred", outputs["X_L"]),
|
|
||||||
# ('input', outputs['X_noisy_L']),
|
|
||||||
("true", outputs["X_gt_L"]),
|
|
||||||
("null_pred", null_pred),
|
|
||||||
):
|
|
||||||
l_total, _, loss_dict_batched = loss(
|
|
||||||
outputs["f"],
|
|
||||||
X_L,
|
|
||||||
outputs["X_gt_L"],
|
|
||||||
outputs["t"],
|
|
||||||
outputs["seq"],
|
|
||||||
outputs["crd_mask_I"],
|
|
||||||
)
|
|
||||||
# loss_dict_by_type[input_type] = loss_dict_batched
|
|
||||||
loss_dict_batched_prefixed = {
|
|
||||||
f"{k}.{input_type}": v for k, v in loss_dict_batched.items()
|
|
||||||
}
|
|
||||||
loss_dict_by_type.update(loss_dict_batched_prefixed)
|
|
||||||
|
|
||||||
# Correcting for EDM : AF3 lambda conversion
|
|
||||||
edm_corr = (t + loss.sigma_data) ** 2 / (t * loss.sigma_data) ** 2
|
|
||||||
loss_dict_batched_edm = {k: v * edm_corr for k, v in loss_dict_batched.items()}
|
|
||||||
loss_dict_batched_prefixed_edm = {
|
|
||||||
f"{k}_edm.{input_type}": v for k, v in loss_dict_batched_edm.items()
|
|
||||||
}
|
|
||||||
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
|
|
||||||
|
|
||||||
# Correcting for Var(gt) != sigma_data
|
|
||||||
expected_loss_gt = (
|
|
||||||
1
|
|
||||||
/ (loss.sigma_data**2 + t**2)
|
|
||||||
* (loss.sigma_data**2 + t**2 * sigma_gt**2 / loss.sigma_data**2)
|
|
||||||
)
|
|
||||||
loss_dict_batched_edm_gt_corr = {
|
|
||||||
k: edm_corr * v / expected_loss_gt for k, v in loss_dict_batched.items()
|
|
||||||
}
|
|
||||||
loss_dict_batched_prefixed_edm = {
|
|
||||||
f"{k}_edm_gt_corr.{input_type}": v
|
|
||||||
for k, v in loss_dict_batched_edm_gt_corr.items()
|
|
||||||
}
|
|
||||||
loss_dict_by_type.update(loss_dict_batched_prefixed_edm)
|
|
||||||
|
|
||||||
o = flatten_dictionary(loss_dict_by_type)
|
|
||||||
o["pred_over_null_pred"] = o["diffusion_loss.pred"] / o["diffusion_loss.null_pred"]
|
|
||||||
o["pred_over_null_pred_norm"] = (
|
|
||||||
o["diffusion_loss_edm_gt_corr.pred"] / o["diffusion_loss_edm_gt_corr.null_pred"]
|
|
||||||
)
|
|
||||||
return o, ("t_quantile_4",)
|
|
||||||
|
|
||||||
|
|
||||||
def get_normal_quantiles(n):
|
|
||||||
# Generate n evenly spaced probabilities between 0 and 1
|
|
||||||
probabilities = np.linspace(0, 1, n)
|
|
||||||
# Use the percent point function (inverse CDF) of the standard normal distribution
|
|
||||||
return norm.ppf(probabilities)
|
|
||||||
|
|
||||||
|
|
||||||
def get_t_quantiles(t, sigma_data, n):
|
|
||||||
bins = sigma_data * np.exp(-1.2 + 1.5 * get_normal_quantiles(n + 1))
|
|
||||||
t_binned_list = []
|
|
||||||
for t in t:
|
|
||||||
t_bin = np.digitize(t, bins) - 1
|
|
||||||
bin_start = bins[t_bin]
|
|
||||||
bin_end = bins[t_bin + 1]
|
|
||||||
t_binned = f"t=[{bin_start:.2f},{bin_end:.2f})"
|
|
||||||
t_binned_list.append(t_binned)
|
|
||||||
return t_binned_list
|
|
||||||
|
|
||||||
|
|
||||||
class NetworkOutputGradSanityCheck(Callback):
|
|
||||||
def __init__(self, call_n_times=0, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.call_n_times = call_n_times
|
|
||||||
self.call_count = 0
|
|
||||||
|
|
||||||
def on_after_backward(self, trainer, pl_module):
|
|
||||||
if self.call_count < self.call_n_times:
|
|
||||||
self.call_count += 1
|
|
||||||
r_projection_weight = pl_module.model.model.diffusion_module.atom_attention_decoder.to_r_update[
|
|
||||||
1
|
|
||||||
].weight
|
|
||||||
ic(
|
|
||||||
torch.linalg.norm(r_projection_weight)
|
|
||||||
if r_projection_weight is not None
|
|
||||||
else None,
|
|
||||||
torch.linalg.norm(r_projection_weight.grad)
|
|
||||||
if r_projection_weight.grad is not None
|
|
||||||
else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MonitorActivations(Callback):
|
|
||||||
def make_hook(self, label):
|
|
||||||
def hook(module, args, kwargs, output):
|
|
||||||
activation_metrics = {
|
|
||||||
f"{label}:inter_batch_cosine_similarity": F.cosine_similarity(
|
|
||||||
torch.flatten(output[0]),
|
|
||||||
torch.flatten(output[1]),
|
|
||||||
dim=0,
|
|
||||||
),
|
|
||||||
f"{label}:intra_batch_cosine_similarity_to_elem_0": F.cosine_similarity(
|
|
||||||
output[0][0:1],
|
|
||||||
output[0],
|
|
||||||
).mean(),
|
|
||||||
}
|
|
||||||
self.log_dict(activation_metrics)
|
|
||||||
|
|
||||||
return hook
|
|
||||||
|
|
||||||
def setup(self, trainer, pl_module, stage):
|
|
||||||
self.pl_module = pl_module
|
|
||||||
self.trainer = trainer
|
|
||||||
|
|
||||||
pl_module.model.model.diffusion_module.atom_attention_decoder.register_forward_hook(
|
|
||||||
self.make_hook(
|
|
||||||
"diffusion_module.atom_attention_decoder",
|
|
||||||
),
|
|
||||||
with_kwargs=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class FindUnusedParameters(Callback):
|
|
||||||
def __init__(self, only_once=True, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.only_once = only_once
|
|
||||||
self.called = False
|
|
||||||
|
|
||||||
def on_after_backward(self, trainer, pl_module):
|
|
||||||
if self.called and self.only_once:
|
|
||||||
return
|
|
||||||
self.called = True
|
|
||||||
# Calculate unused parameters after each batch
|
|
||||||
unused_params = [
|
|
||||||
name for name, param in pl_module.named_parameters() if param.grad is None
|
|
||||||
]
|
|
||||||
|
|
||||||
# Log unused parameters
|
|
||||||
logging.info(
|
|
||||||
f"global_step={pl_module.global_step}: parameters with no gradient: {json.dumps(unused_params, indent=4)}"
|
|
||||||
)
|
|
||||||
if unused_params:
|
|
||||||
raise Exception("storp")
|
|
||||||
|
|
||||||
|
|
||||||
class WriteToPymol(Callback):
|
|
||||||
def __init__(self, only_once=True, *args, **kwargs):
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
self.only_once = only_once
|
|
||||||
self.called = False
|
|
||||||
pymol.init("http://chesaw.dhcp.ipd:9123")
|
|
||||||
|
|
||||||
def on_train_batch_end(
|
|
||||||
self,
|
|
||||||
trainer: Trainer,
|
|
||||||
pl_module: LightningModule,
|
|
||||||
outputs,
|
|
||||||
batch,
|
|
||||||
batch_idx: int,
|
|
||||||
) -> None:
|
|
||||||
if self.called and self.only_once:
|
|
||||||
return
|
|
||||||
self.called = True
|
|
||||||
|
|
||||||
pymol_tools.clear()
|
|
||||||
predicted = outputs
|
|
||||||
|
|
||||||
logger.info("predicted:\n" + pretty_describe_dict(predicted))
|
|
||||||
ic(predicted["loss"])
|
|
||||||
|
|
||||||
D = predicted["X_L"].shape[0]
|
|
||||||
|
|
||||||
max_to_show = 16
|
|
||||||
grid_slot = 1
|
|
||||||
cmd.set("grid_mode", 1)
|
|
||||||
for i in range(min(D, max_to_show)):
|
|
||||||
X_gt_L = predicted["X_gt_L"][i]
|
|
||||||
X_L = predicted["X_L"][i]
|
|
||||||
X_noisy_L = predicted["X_noisy_L"][i]
|
|
||||||
t = predicted["t"][i]
|
|
||||||
|
|
||||||
label = pymol_tools.show_pymol(
|
|
||||||
pymol_tools.to_atom37(X_noisy_L, predicted["crd_mask_I"]),
|
|
||||||
predicted["seq"],
|
|
||||||
predicted["bond_feats"],
|
|
||||||
label=f"input_{i}_t_{t.item():.2f}",
|
|
||||||
)
|
|
||||||
cmd.set("grid_slot", grid_slot, label)
|
|
||||||
cmd.color("yellow", label)
|
|
||||||
|
|
||||||
label = pymol_tools.show_pymol(
|
|
||||||
pymol_tools.to_atom37(X_L, predicted["crd_mask_I"]),
|
|
||||||
predicted["seq"],
|
|
||||||
predicted["bond_feats"],
|
|
||||||
label=f"pred_{i}_t_{t.item():.2f}",
|
|
||||||
)
|
|
||||||
cmd.set("grid_slot", grid_slot, label)
|
|
||||||
cmd.color("green", label)
|
|
||||||
|
|
||||||
label = pymol_tools.show_pymol(
|
|
||||||
pymol_tools.to_atom37(X_gt_L, predicted["crd_mask_I"]),
|
|
||||||
predicted["seq"],
|
|
||||||
predicted["bond_feats"],
|
|
||||||
label=f"gt_{i}",
|
|
||||||
)
|
|
||||||
cmd.set("grid_slot", grid_slot, label)
|
|
||||||
cmd.color("blue", label)
|
|
||||||
grid_slot += 1
|
|
||||||
|
|
||||||
cmd.show_as("licorice", "all")
|
|
||||||
cmd.alter("name CA", "vdw=2.0")
|
|
||||||
# cmd.set('sphere_transparency', 0.0)
|
|
||||||
cmd.show("spheres", "name CA")
|
|
||||||
|
|
||||||
return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
|
|
||||||
|
|
||||||
|
|
||||||
class WritePDB(Callback):
|
|
||||||
def on_validation_batch_end(
|
|
||||||
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
|
||||||
):
|
|
||||||
seq = batch["seq"][0][0]
|
|
||||||
is_real_atom = ChemData().heavyatom_mask.to(seq.device)[seq]
|
|
||||||
X_L = outputs["X_L"]
|
|
||||||
X_gt_L = outputs["X_gt_L"]
|
|
||||||
atom_mask = outputs["crd_mask_I"]
|
|
||||||
bond_feats = batch["bond_feats"]
|
|
||||||
|
|
||||||
X_I = torch.full(
|
|
||||||
(X_L.shape[0], atom_mask.shape[0], ChemData().NTOTAL, 3), np.nan
|
|
||||||
).to(X_L.device)
|
|
||||||
X_I[..., is_real_atom, :] = X_L
|
|
||||||
|
|
||||||
X_gt_I = torch.full(
|
|
||||||
(X_gt_L.shape[0], atom_mask.shape[0], ChemData().NTOTAL, 3), np.nan
|
|
||||||
).to(X_gt_L.device)
|
|
||||||
X_gt_I[..., atom_mask, :] = X_gt_L
|
|
||||||
pdb_path = f"tmp/true_{batch_idx}.pdb"
|
|
||||||
writepdb(
|
|
||||||
pdb_path,
|
|
||||||
X_gt_I[0],
|
|
||||||
seq.long(),
|
|
||||||
bond_feats=bond_feats,
|
|
||||||
)
|
|
||||||
for i in range(X_L.shape[0]):
|
|
||||||
pdb_path = f"tmp/pred_{batch_idx}_{i}.pdb"
|
|
||||||
writepdb(
|
|
||||||
pdb_path,
|
|
||||||
X_I[i],
|
|
||||||
seq.long(),
|
|
||||||
bond_feats=bond_feats,
|
|
||||||
)
|
|
||||||
|
|
||||||
return super().on_validation_batch_end(
|
|
||||||
trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DebugGrads(Callback):
|
|
||||||
def on_after_backward(self, trainer, pl_module):
|
|
||||||
grad_dict = {}
|
|
||||||
for name, param in pl_module.named_parameters():
|
|
||||||
if param.grad is not None and "pairformer" in name:
|
|
||||||
grad_dict[name] = param.grad.clone().detach()
|
|
||||||
ic(
|
|
||||||
name,
|
|
||||||
torch.linalg.norm(param.grad),
|
|
||||||
torch.linalg.norm(param),
|
|
||||||
)
|
|
||||||
torch.save(grad_dict, "grad_dict_unbatched.pt")
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,534 +0,0 @@
|
|||||||
import os
|
|
||||||
import argparse
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import pickle
|
|
||||||
import tempfile
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from os import PathLike
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import hydra
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from biotite.structure import AtomArray, AtomArrayStack, stack
|
|
||||||
from cifutils import parse
|
|
||||||
from cifutils.tools.inference import (
|
|
||||||
build_msa_paths_by_chain_id_from_component_list,
|
|
||||||
components_to_atom_array,
|
|
||||||
)
|
|
||||||
from cifutils.utils.io_utils import to_cif_file
|
|
||||||
from datahub.encoding_definitions import AF3SequenceEncoding
|
|
||||||
import omegaconf
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
|
|
||||||
from rf2aa.metrics.predicted_error import WriteAF3Confidence
|
|
||||||
from rf2aa.trainer_base import trainer_factory
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Define the sequence encoding; needed to decode the restypes when saving to CIF
|
|
||||||
encoding = AF3SequenceEncoding()
|
|
||||||
|
|
||||||
|
|
||||||
def build_stack_from_atom_array_and_batched_coords(
|
|
||||||
coords: np.ndarray,
|
|
||||||
atom_array: AtomArray,
|
|
||||||
annotations_to_keep: list[str] = [
|
|
||||||
"chain_id",
|
|
||||||
"transformation_id",
|
|
||||||
"res_id",
|
|
||||||
"res_name",
|
|
||||||
"element",
|
|
||||||
"atom_name",
|
|
||||||
],
|
|
||||||
) -> AtomArrayStack:
|
|
||||||
"""Builds an AtomArrayStack from an AtomArray and a set of coordinates with a batch dimension.
|
|
||||||
|
|
||||||
Additionally, handles the case where the AtomArray contains multiple transformations and we must adjust the chain_id.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
coords (np.array): The coordinates to be assigned to the AtomArrayStack. Must have shape (nbatch, n_atoms, 3).
|
|
||||||
atom_array (AtomArray): The AtomArray to be stacked. Must have shape (n_atoms,)
|
|
||||||
"""
|
|
||||||
# (Diffusion batch size will become the number of models)
|
|
||||||
n_batch = coords.shape[0]
|
|
||||||
|
|
||||||
# Remove unwanted annotations
|
|
||||||
for annotation in atom_array.get_annotation_categories():
|
|
||||||
if annotation not in annotations_to_keep:
|
|
||||||
atom_array.del_annotation(annotation)
|
|
||||||
|
|
||||||
# Build the stack and assign the coordinates
|
|
||||||
atom_array_stack = stack([atom_array for _ in range(n_batch)])
|
|
||||||
atom_array_stack.coord = coords
|
|
||||||
|
|
||||||
# Adjust chain_id if there are multiple transformations
|
|
||||||
# (Otherwise, we will have ambiguous bond annotations, since only `chain_id` is used for the bond annotations)
|
|
||||||
if (
|
|
||||||
"transformation_id" in atom_array.get_annotation_categories()
|
|
||||||
and len(np.unique(atom_array_stack.transformation_id)) > 1
|
|
||||||
):
|
|
||||||
atom_array_stack.chain_id = (
|
|
||||||
atom_array_stack.chain_id + atom_array_stack.transformation_id
|
|
||||||
)
|
|
||||||
|
|
||||||
return atom_array_stack
|
|
||||||
|
|
||||||
|
|
||||||
def _spoof_cif_from_dictionary(item: dict, temp_dir: PathLike) -> Path:
|
|
||||||
"""Unpacks a dictionary to create a CIF file from its components.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
item (dict): A dictionary containing 'name' and 'components', optionally 'bonds'.
|
|
||||||
temp_dir (Path): Path to the temporary directory for storing CIF files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path: The path to the created CIF file, saved in the temporary directory.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
NotImplementedError: If 'bonds' is present in the dictionary.
|
|
||||||
ValueError: If 'name' or 'components' are missing from the dictionary.
|
|
||||||
"""
|
|
||||||
# Validate the dictionary structure ("name" and "components" are required, "bonds" is optional)
|
|
||||||
assert (
|
|
||||||
"name" in item and "components" in item
|
|
||||||
), "The input dictionary must contain 'name' and 'components' keys."
|
|
||||||
|
|
||||||
# Build components
|
|
||||||
atom_array, component_list = components_to_atom_array(
|
|
||||||
item["components"], return_components=True, bonds=item.get("bonds", None)
|
|
||||||
)
|
|
||||||
msa_paths_by_chain_id = build_msa_paths_by_chain_id_from_component_list(
|
|
||||||
component_list
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create a temporary CIF file from the JSON data
|
|
||||||
cif_path = Path(temp_dir) / f"{item['name']}.cif"
|
|
||||||
save_path = to_cif_file(
|
|
||||||
atom_array,
|
|
||||||
cif_path,
|
|
||||||
extra_categories={"msa_paths_by_chain_id": msa_paths_by_chain_id}
|
|
||||||
if msa_paths_by_chain_id
|
|
||||||
else None,
|
|
||||||
file_type="cif", # Not zipped for efficiency (as it's a temporary directory anyways)
|
|
||||||
)
|
|
||||||
|
|
||||||
return Path(save_path)
|
|
||||||
|
|
||||||
|
|
||||||
def _build_file_paths_for_prediction(inputs: list, temp_dir: PathLike) -> list[Path]:
|
|
||||||
"""Prepare files for prediction based on the input paths.
|
|
||||||
|
|
||||||
Input paths may be dictionary-like format (e.g., JSON, YAML, Pickle), CIF/PDB files, or directories containing these files.
|
|
||||||
Processes directories to find supported file types and converts dictionary-like formats to CIF files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (list): List of input paths (JSON, YAML, Pickle, or CIF/PDB).
|
|
||||||
temp_dir (Path): Path to the temporary directory for storing CIF files.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Path]: List of file paths for prediction.
|
|
||||||
"""
|
|
||||||
DICTIONARY_LIKE_EXTENSIONS = {".json", ".yaml", ".yml", ".pkl"}
|
|
||||||
CIF_LIKE_EXTENSIONS = {".cif", ".pdb", ".bcif", ".cif.gz", ".pdb.gz", ".bcif.gz"}
|
|
||||||
|
|
||||||
# Collect all files from inputs, handling directories and individual files
|
|
||||||
paths_to_raw_input_files = []
|
|
||||||
for input_path in inputs:
|
|
||||||
if Path(input_path).is_dir():
|
|
||||||
paths_to_raw_input_files.extend(
|
|
||||||
_find_files(
|
|
||||||
input_path, DICTIONARY_LIKE_EXTENSIONS | CIF_LIKE_EXTENSIONS
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
paths_to_raw_input_files.append(Path(input_path))
|
|
||||||
|
|
||||||
paths_to_cif_like_files = []
|
|
||||||
for path in paths_to_raw_input_files:
|
|
||||||
#concatenated_suffix = "".join(path.suffixes)
|
|
||||||
concatenated_suffix = path.suffixes[-1]
|
|
||||||
if concatenated_suffix in DICTIONARY_LIKE_EXTENSIONS:
|
|
||||||
# Spoof CIF files from dictionary-like formats
|
|
||||||
with open(path, "rb" if path.suffix == ".pkl" else "r") as file:
|
|
||||||
# Load data based on file extension
|
|
||||||
if path.suffix == ".json":
|
|
||||||
data = json.load(file)
|
|
||||||
elif path.suffix in {".yaml", ".yml"}:
|
|
||||||
raise NotImplementedError("YAML files are not yet supported.")
|
|
||||||
elif path.suffix == ".pkl":
|
|
||||||
data = pickle.load(file)
|
|
||||||
|
|
||||||
if isinstance(data, dict):
|
|
||||||
data = [
|
|
||||||
data
|
|
||||||
] # Convert single dictionary to list for uniform processing
|
|
||||||
|
|
||||||
for item in data:
|
|
||||||
paths_to_cif_like_files.append(
|
|
||||||
_spoof_cif_from_dictionary(item, temp_dir)
|
|
||||||
)
|
|
||||||
elif concatenated_suffix in CIF_LIKE_EXTENSIONS:
|
|
||||||
# Directly use CIF-like files
|
|
||||||
paths_to_cif_like_files.append(path)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unsupported file extension: {path.suffix} (path: {path}; paths: {paths_to_raw_input_files})."
|
|
||||||
)
|
|
||||||
|
|
||||||
return paths_to_cif_like_files
|
|
||||||
|
|
||||||
|
|
||||||
def _find_files(path: PathLike, supported_file_types: list) -> list[Path]:
|
|
||||||
"""Recursively find all files with the given extensions in the specified path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (PathLike): Path to the directory containing the files.
|
|
||||||
supported_file_types (list): List of supported file extensions.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list[Path]: List of files with the given extensions.
|
|
||||||
"""
|
|
||||||
files_with_supported_types = []
|
|
||||||
path = Path(path)
|
|
||||||
|
|
||||||
# Check if the path is a directory
|
|
||||||
if path.is_dir():
|
|
||||||
# Search for files with each supported extension
|
|
||||||
for file_type in supported_file_types:
|
|
||||||
files_with_supported_types.extend(path.glob(f"*{file_type}"))
|
|
||||||
elif path.is_file() and path.suffix in supported_file_types:
|
|
||||||
# If it's a file and has a supported extension, add to the list
|
|
||||||
files_with_supported_types.append(path)
|
|
||||||
|
|
||||||
return files_with_supported_types
|
|
||||||
|
|
||||||
|
|
||||||
def _update_nested_dictconfig(d: Mapping, u: Mapping, depth: int = 0) -> Mapping:
|
|
||||||
"""Recursive function to overwrite contents of one nested omegaconf dictconfig with another.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
d: dictionary of dictconfigs whose contents will be overwritten
|
|
||||||
u: dictionary of items which will overwrite or add to values in d
|
|
||||||
depth: depth of recursion: a positive integer:
|
|
||||||
-used to keep the outermost layer of the config as a dict instead of DictConfig.
|
|
||||||
-set to 1 or higher to return only DictConfig.
|
|
||||||
Returns:
|
|
||||||
d updated to contain values in u
|
|
||||||
"""
|
|
||||||
d = dict(d)
|
|
||||||
u = dict(u)
|
|
||||||
for k, v in u.items():
|
|
||||||
if isinstance(v, Mapping):
|
|
||||||
d[k] = _update_nested_dictconfig(d.get(k, {}), v, depth=depth + 1)
|
|
||||||
else:
|
|
||||||
d[k] = v
|
|
||||||
if depth == 0:
|
|
||||||
return d
|
|
||||||
else:
|
|
||||||
return omegaconf.dictconfig.DictConfig(d)
|
|
||||||
|
|
||||||
|
|
||||||
class EvaluateAF3:
|
|
||||||
"""Class for inference with AF3. Evaluates a trained AF3 model on a set of spoofed CIFs."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
checkpoint_path: PathLike,
|
|
||||||
cif_out_dir: PathLike,
|
|
||||||
n_recycles: int,
|
|
||||||
diffusion_batch_size: int,
|
|
||||||
config_override_path: PathLike | None = None,
|
|
||||||
residue_renaming_dict: dict | None = None,
|
|
||||||
temp_dir: PathLike | None = None,
|
|
||||||
num_steps: int = 200,
|
|
||||||
solver: str = "af3",
|
|
||||||
overwrite: bool = False
|
|
||||||
):
|
|
||||||
"""Initialize the evaluator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
checkpoint_path (PathLike): Path to the checkpoint file, e.g., /path/to/checkpoint.pt.
|
|
||||||
cif_out_dir (PathLike): Directory to save the output (predicted) CIF files.
|
|
||||||
config_override_path (PathLike): Path to a yaml file with config options to override those in the checkpoint file.
|
|
||||||
world_size (int): Number of GPUs to use for evaluation.
|
|
||||||
n_recycles (int): Number of recycles for AF3. The default is 10.
|
|
||||||
diffusion_batch_size (int): Diffusion batch size for AF3. Each predicted structure will be saved as a separate model within the same CIF file.
|
|
||||||
residue_renaming_dict (dict): Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}.
|
|
||||||
temp_dir (PathLike): Temporary directory to store intermediate files. The default is None.
|
|
||||||
num_steps (int): Number of steps for sampling of the diffusion model. The default is 200; we see reasonable results with 50 steps.
|
|
||||||
solver (str): Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'. The default is 'af3'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Load the checkpoint
|
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
||||||
checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
|
|
||||||
|
|
||||||
# Load the config
|
|
||||||
self.config = OmegaConf.create(checkpoint["training_config"])
|
|
||||||
|
|
||||||
if config_override_path is not None:
|
|
||||||
with open(config_override_path, 'r') as fs:
|
|
||||||
config_override_dict = yaml.load(fs, yaml.FullLoader)
|
|
||||||
self.config = _update_nested_dictconfig(self.config, config_override_dict)
|
|
||||||
self.config = OmegaConf.create(self.config)
|
|
||||||
|
|
||||||
# Make sure we aren't using the version with a bug in plddt
|
|
||||||
if (
|
|
||||||
self.config.experiment.name
|
|
||||||
== "rf2aa-af3-repro-rollout_nmw_from_scratch_af3_style_no_cb_normal_crop_cont_3"
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"These weights are outdated and the plddt metric may be inaccurate. Please update to the latest available weights."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sampler sets diffusion batch size based on the following, not strictly on batch size in vaildation transform
|
|
||||||
self.config.dataset_params["diffusion_batch_size_valid"] = diffusion_batch_size
|
|
||||||
self.config.af3_inference["num_steps"] = num_steps
|
|
||||||
self.config.af3_inference["solver"] = solver
|
|
||||||
|
|
||||||
# Load the AF-3 trainer
|
|
||||||
self.trainer = trainer_factory[self.config.experiment.trainer](
|
|
||||||
config=self.config
|
|
||||||
)
|
|
||||||
self.trainer.checkpoint = checkpoint
|
|
||||||
|
|
||||||
# Set the output directory for the CIF files (e.g., predicted structures)
|
|
||||||
self.cif_out_dir = Path(cif_out_dir) if cif_out_dir else Path("./")
|
|
||||||
|
|
||||||
# Model parameters
|
|
||||||
self.n_recycles = n_recycles
|
|
||||||
self.diffusion_batch_size = diffusion_batch_size
|
|
||||||
if "confidence_loss" in self.config.loss:
|
|
||||||
self.confidence_writer = WriteAF3Confidence(
|
|
||||||
**self.config.loss.confidence_loss
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.confidence_writer = None
|
|
||||||
|
|
||||||
# Rename residues
|
|
||||||
self.residue_renaming_dict = residue_renaming_dict
|
|
||||||
self.temp_dir = Path(temp_dir)
|
|
||||||
|
|
||||||
self.overwrite = overwrite
|
|
||||||
|
|
||||||
def construct_pipeline(self):
|
|
||||||
"""Construct the AF3 inference pipeline."""
|
|
||||||
self.config.dataset_params.val.interface.transform.n_recycles = self.n_recycles
|
|
||||||
self.config.dataset_params.val.interface.transform.diffusion_batch_size = (
|
|
||||||
self.diffusion_batch_size
|
|
||||||
)
|
|
||||||
self.config.dataset_params.val.interface.transform.return_atom_array = (
|
|
||||||
True # Required for `to_cif`
|
|
||||||
)
|
|
||||||
|
|
||||||
assert (
|
|
||||||
self.config.dataset_params.val.interface.transform.n_recycles
|
|
||||||
== self.n_recycles
|
|
||||||
), "Number of recycles not set correctly."
|
|
||||||
assert (
|
|
||||||
self.config.dataset_params.val.interface.transform.diffusion_batch_size
|
|
||||||
== self.diffusion_batch_size
|
|
||||||
), "Diffusion batch size not set correctly."
|
|
||||||
pipeline = hydra.utils.instantiate(
|
|
||||||
self.config.dataset_params.val.interface.transform
|
|
||||||
)
|
|
||||||
return pipeline
|
|
||||||
|
|
||||||
def eval(self, files: list[PathLike]):
|
|
||||||
"""Evaluate the model on a set of spoofed CIF files.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
files (list[PathLike]): List of paths to spoofed CIF files or directories containing spoofed CIF files.
|
|
||||||
Coordinates must be present but may contain NaN values. If a directory is provided,
|
|
||||||
all files with the extensions .cif, .pdb, .bcif, .cif.gz, .pdb.gz, .bcif.gz will be processed.
|
|
||||||
"""
|
|
||||||
# Construct the model and load the checkpoint
|
|
||||||
gpu = "cuda:0" if torch.cuda.is_available() else "cpu"
|
|
||||||
self.trainer.construct_model(device=gpu, inference=True)
|
|
||||||
self.trainer.load_model()
|
|
||||||
|
|
||||||
# Set the model to evaluation mode
|
|
||||||
self.trainer.model.eval()
|
|
||||||
|
|
||||||
logger.info("Building Transform pipeline...")
|
|
||||||
|
|
||||||
# Construct the AF3 inference pipeline
|
|
||||||
pipeline = self.construct_pipeline()
|
|
||||||
|
|
||||||
logger.info(f"Found {len(files)} structures to predict: {files}.")
|
|
||||||
|
|
||||||
for structure in files:
|
|
||||||
# ... parse into an AtomArray (`parse` handles all valid formats)
|
|
||||||
logger.info(f"Parsing from path: {structure}")
|
|
||||||
#example_id = structure.name.split(".")[0]
|
|
||||||
example_id = ".".join(structure.name.split(".")[:-1])
|
|
||||||
|
|
||||||
# optionally, skip if output already exists
|
|
||||||
cif_output_path = example_id + '.cif'
|
|
||||||
cif_output_path = self.cif_out_dir / cif_output_path
|
|
||||||
if os.path.exists(cif_output_path) and not self.overwrite:
|
|
||||||
logger.info(f"Existing output for {example_id} found at {cif_output_path}. Skipping this example. Set --overwrite to not skip examples with existing output")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# If we're renaming residues, we do a brute-force replacement in the CIF file
|
|
||||||
if self.residue_renaming_dict:
|
|
||||||
logger.info(
|
|
||||||
f"Renaming residues in {structure} with brute-force find and replace: {self.residue_renaming_dict}"
|
|
||||||
)
|
|
||||||
with open(structure, "r") as f:
|
|
||||||
content = f.read()
|
|
||||||
for old_res, new_res in self.residue_renaming_dict.items():
|
|
||||||
content = content.replace(old_res, new_res)
|
|
||||||
structure = Path(self.temp_dir / structure.name)
|
|
||||||
with open(structure, "w") as f:
|
|
||||||
f.write(content)
|
|
||||||
|
|
||||||
out = parse(structure, remove_hydrogens=True)
|
|
||||||
|
|
||||||
# ... get the atom array and set NaN coordinates to random
|
|
||||||
atom_array = (
|
|
||||||
out["assemblies"]["1"][0]
|
|
||||||
if "assemblies" in out
|
|
||||||
else out["asym_unit"][0]
|
|
||||||
)
|
|
||||||
|
|
||||||
# HACK: Set NaN coordinates to random values to avoid unexpected behavior in the pipeline
|
|
||||||
atom_array.coord[np.isnan(atom_array.coord)] = np.random.rand(
|
|
||||||
*atom_array.coord[np.isnan(atom_array.coord)].shape
|
|
||||||
)
|
|
||||||
|
|
||||||
# ... assemble the pipeline input in a format compatible with the DataHub pipeline
|
|
||||||
pipeline_input = {
|
|
||||||
"example_id": example_id,
|
|
||||||
"atom_array": atom_array,
|
|
||||||
"chain_info": out["chain_info"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# ... run dataloading and featurization
|
|
||||||
pipeline_output = pipeline(pipeline_input)
|
|
||||||
|
|
||||||
# Model inference
|
|
||||||
with torch.no_grad():
|
|
||||||
outputs = self.trainer.sampler.sample(
|
|
||||||
[pipeline_output],
|
|
||||||
n_cycle=self.n_recycles,
|
|
||||||
use_amp=self.config.training_params.use_amp,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Override the AtomArray with the predited coordinates
|
|
||||||
atom_array_stack = build_stack_from_atom_array_and_batched_coords(
|
|
||||||
outputs["X_L"].cpu().numpy(), pipeline_output["atom_array"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Write the atom array to a CIF file
|
|
||||||
# NOTE: To make the secondary structure appear, run `dss` in PyMol (see: https://biology.stackexchange.com/questions/70143/can-pymol-show-cartoon-secondary-structure-for-a-pdb-of-multiple-frames)
|
|
||||||
out_path = to_cif_file(
|
|
||||||
atom_array_stack, self.cif_out_dir / example_id, file_type="cif"
|
|
||||||
)
|
|
||||||
logger.info(f"Prediction for {example_id} written to {out_path}.")
|
|
||||||
|
|
||||||
if "confidence" in outputs:
|
|
||||||
loss_input = {
|
|
||||||
"example_id": example_id,
|
|
||||||
"is_real_atom": pipeline_output["confidence_feats"]["is_real_atom"],
|
|
||||||
}
|
|
||||||
logger.info(f"Writing {example_id}.score to {self.cif_out_dir}")
|
|
||||||
df = self.confidence_writer(None, outputs, loss_input)
|
|
||||||
df.to_csv(self.cif_out_dir / f"{example_id}.score", index=False)
|
|
||||||
logger.info(
|
|
||||||
f"Confidence metrics for {example_id}.cif written to {self.cif_out_dir / example_id}.score."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser(description="Evaluate AF3 using specified paths.")
|
|
||||||
parser.add_argument(
|
|
||||||
"inputs",
|
|
||||||
nargs="+",
|
|
||||||
help="List of paths to supported file types or directories of of supported files.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--checkpoint_path", type=str, required=True, help="Path to the checkpoint file"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--cif_out_dir", type=str, required=True, help="Directory for output CIF files"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--config_override_path",
|
|
||||||
type=str,
|
|
||||||
required=False,
|
|
||||||
help="Path to a yaml file with configs to override those in the checkpoint file",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--n_recycles", type=int, default=10, help="Number of recycles for AF3"
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--diffusion_batch_size",
|
|
||||||
type=int,
|
|
||||||
default=5,
|
|
||||||
help="Diffusion batch size for AF3",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--rename_residues",
|
|
||||||
type=str,
|
|
||||||
default="",
|
|
||||||
help="Dictionary of residue names to rename to avoid CCD clashes, e.g., {'ALA': 'L:1'}",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--num_steps",
|
|
||||||
type=int,
|
|
||||||
default=200,
|
|
||||||
help="Number of steps for sampling of the diffusion model",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--solver",
|
|
||||||
type=str,
|
|
||||||
default="af3",
|
|
||||||
help="Solver to use for inference. Options are 'af3', 'simple', 'euler', and 'heun'.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--overwrite",
|
|
||||||
default=False,
|
|
||||||
action="store_true",
|
|
||||||
help="Overwrite existing .cif outputs with new runs.",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
|
||||||
temp_dir = Path(temp_dir)
|
|
||||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Prepare inputs based on the file types
|
|
||||||
file_paths_for_prediction = _build_file_paths_for_prediction(
|
|
||||||
args.inputs, temp_dir
|
|
||||||
)
|
|
||||||
|
|
||||||
# Rename residues if necessary (e.g., for MPNN outputs that have ligand names that clash with the CCD)
|
|
||||||
residue_renaming_dict = (
|
|
||||||
json.loads(args.rename_residues) if args.rename_residues else {}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Construct the evaluator
|
|
||||||
evaluator = EvaluateAF3(
|
|
||||||
checkpoint_path=args.checkpoint_path,
|
|
||||||
cif_out_dir=args.cif_out_dir,
|
|
||||||
config_override_path=args.config_override_path,
|
|
||||||
n_recycles=args.n_recycles,
|
|
||||||
diffusion_batch_size=args.diffusion_batch_size,
|
|
||||||
residue_renaming_dict=residue_renaming_dict,
|
|
||||||
temp_dir=temp_dir,
|
|
||||||
num_steps=args.num_steps,
|
|
||||||
solver=args.solver,
|
|
||||||
overwrite=args.overwrite
|
|
||||||
)
|
|
||||||
|
|
||||||
# Launch the evaluation
|
|
||||||
evaluator.eval(files=file_paths_for_prediction)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from rf2aa.loss.af3_losses import distogram_loss
|
|
||||||
from rf2aa.metrics.metrics_base import Metric
|
|
||||||
|
|
||||||
|
|
||||||
class DistogramLoss(Metric):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.cce_loss = nn.CrossEntropyLoss(reduction="none")
|
|
||||||
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
pred_distogram = network_output["distogram"]
|
|
||||||
X_rep_atoms_I = loss_input["X_rep_atoms_I"]
|
|
||||||
crd_mask_rep_atoms_I = loss_input["crd_mask_rep_atoms_I"]
|
|
||||||
loss = distogram_loss(
|
|
||||||
pred_distogram, X_rep_atoms_I, crd_mask_rep_atoms_I, self.cce_loss
|
|
||||||
)
|
|
||||||
return {"distogram_loss": loss.detach().item()}
|
|
||||||
|
|
||||||
|
|
||||||
class SaveDistograms(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
pred_distogram = network_output["distogram"]
|
|
||||||
example_id = loss_input["example_id"]
|
|
||||||
torch.save(pred_distogram, f"distograms/{example_id}.pt")
|
|
||||||
return {"distogram_saved": True}
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from rf2aa.metrics.metrics_base import Metric
|
|
||||||
|
|
||||||
|
|
||||||
def calc_lddt(
|
|
||||||
X_L,
|
|
||||||
X_gt_L,
|
|
||||||
crd_mask_L,
|
|
||||||
tok_idx,
|
|
||||||
pairs_to_score=None,
|
|
||||||
distance_cutoff=15.0,
|
|
||||||
use_amp=True,
|
|
||||||
eps=1e-6,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
X_L: predicted coordinates (D, L, 3)
|
|
||||||
X_gt_L: ground truth coordinates (D, L, 3)
|
|
||||||
crd_mask_L: mask of coordinates (D, L,)
|
|
||||||
tok_idx: token index of each atom (L,)
|
|
||||||
pairs_to_score: pairs to score (L, L) | None
|
|
||||||
"""
|
|
||||||
D, L = X_L.shape[:2]
|
|
||||||
if pairs_to_score is None:
|
|
||||||
pairs_to_score = torch.ones((L, L), dtype=torch.bool).triu(0).to(X_L.device)
|
|
||||||
else:
|
|
||||||
assert pairs_to_score.shape == (L, L)
|
|
||||||
pairs_to_score = pairs_to_score.triu(0).to(X_L.device)
|
|
||||||
|
|
||||||
first_index, second_index = torch.nonzero(pairs_to_score, as_tuple=True)
|
|
||||||
|
|
||||||
lddt = []
|
|
||||||
for d in range(D):
|
|
||||||
ground_truth_distances = torch.linalg.norm(
|
|
||||||
X_gt_L[d, first_index] - X_gt_L[d, second_index], dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
pair_mask = torch.logical_and(
|
|
||||||
ground_truth_distances > 0, ground_truth_distances < distance_cutoff
|
|
||||||
)
|
|
||||||
|
|
||||||
# only score pairs that are resolved in the ground truth
|
|
||||||
pair_mask *= crd_mask_L[d, first_index] * crd_mask_L[d, second_index]
|
|
||||||
# don't score pairs that are in the same token
|
|
||||||
pair_mask *= tok_idx[first_index] != tok_idx[second_index]
|
|
||||||
|
|
||||||
valid_pairs = pair_mask.nonzero(as_tuple=True)
|
|
||||||
pair_mask = pair_mask[valid_pairs].to(X_L.dtype)
|
|
||||||
ground_truth_distances = ground_truth_distances[valid_pairs]
|
|
||||||
first_index, second_index = first_index[valid_pairs], second_index[valid_pairs]
|
|
||||||
|
|
||||||
predicted_distances = torch.linalg.norm(
|
|
||||||
X_L[d, first_index] - X_L[d, second_index], dim=-1
|
|
||||||
)
|
|
||||||
|
|
||||||
delta_distances = torch.abs(predicted_distances - ground_truth_distances + eps)
|
|
||||||
del predicted_distances, ground_truth_distances
|
|
||||||
|
|
||||||
lddt.append(
|
|
||||||
0.25
|
|
||||||
* (
|
|
||||||
torch.sum((delta_distances < 4.0) * pair_mask)
|
|
||||||
+ torch.sum((delta_distances < 2.0) * pair_mask)
|
|
||||||
+ torch.sum((delta_distances < 1.0) * pair_mask)
|
|
||||||
+ torch.sum((delta_distances < 0.5) * pair_mask)
|
|
||||||
)
|
|
||||||
/ (torch.sum(pair_mask) + eps)
|
|
||||||
)
|
|
||||||
|
|
||||||
return torch.tensor(lddt)
|
|
||||||
|
|
||||||
|
|
||||||
class InterfaceLDDT(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
interface_lddt = {"interface_lddt_first": [], "interface_lddt_best": []}
|
|
||||||
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
|
|
||||||
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
|
|
||||||
for chain_i, chain_j, interface_type in loss_input["interfaces_to_score"]:
|
|
||||||
# get tokens in chain_i and chain_j
|
|
||||||
chain_i_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
chain_j_tokens = chain_iid_token_lvl == chain_j
|
|
||||||
# convert the token level to the atom level
|
|
||||||
chain_i_atoms = chain_i_tokens[tok_idx]
|
|
||||||
chain_j_atoms = chain_j_tokens[tok_idx]
|
|
||||||
# compute the intersection of chain_i and chain_j
|
|
||||||
|
|
||||||
chain_ij_atoms = torch.einsum(
|
|
||||||
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
|
|
||||||
).to(network_output["X_L"].device)
|
|
||||||
|
|
||||||
# symmetrize
|
|
||||||
chain_ij_atoms = chain_ij_atoms | chain_ij_atoms.T
|
|
||||||
|
|
||||||
# compute lddt using the pairs_to_score from the intersection
|
|
||||||
lddt = calc_lddt(
|
|
||||||
network_output["X_L"],
|
|
||||||
loss_input["X_gt_L"],
|
|
||||||
loss_input["crd_mask_L"],
|
|
||||||
torch.tensor(tok_idx).to(network_output["X_L"].device),
|
|
||||||
pairs_to_score=chain_ij_atoms,
|
|
||||||
distance_cutoff=30.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
interface_lddt["interface_lddt_first"].append(lddt[0].item())
|
|
||||||
interface_lddt["interface_lddt_best"].append(lddt.max().item())
|
|
||||||
return interface_lddt
|
|
||||||
|
|
||||||
|
|
||||||
class ConfidenceInterfaceLDDT(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
interface_lddt = {
|
|
||||||
"interface_lddt_first": [],
|
|
||||||
"interface_lddt_best": [],
|
|
||||||
"interface_lddt_pae": [],
|
|
||||||
"interface_lddt_pde": [],
|
|
||||||
"interface_lddt_plddt": [],
|
|
||||||
"interface_lddt_af3_style_ipae": [],
|
|
||||||
"interface_lddt_af3_style_lig_ipae": [],
|
|
||||||
}
|
|
||||||
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
|
|
||||||
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
|
|
||||||
for chain_i, chain_j, interface_type in loss_input["interfaces_to_score"]:
|
|
||||||
# get tokens in chain_i and chain_j
|
|
||||||
chain_i_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
chain_j_tokens = chain_iid_token_lvl == chain_j
|
|
||||||
# convert the token level to the atom level
|
|
||||||
chain_i_atoms = chain_i_tokens[tok_idx]
|
|
||||||
chain_j_atoms = chain_j_tokens[tok_idx]
|
|
||||||
# compute the intersection of chain_i and chain_j
|
|
||||||
|
|
||||||
chain_ij_atoms = torch.einsum(
|
|
||||||
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
|
|
||||||
).to(network_output["X_L"].device)
|
|
||||||
|
|
||||||
# compute lddt using the pairs_to_score from the intersection
|
|
||||||
lddt = calc_lddt(
|
|
||||||
network_output["X_L"],
|
|
||||||
loss_input["X_gt_L"],
|
|
||||||
loss_input["crd_mask_L"],
|
|
||||||
torch.tensor(tok_idx).to(network_output["X_L"].device),
|
|
||||||
pairs_to_score=chain_ij_atoms,
|
|
||||||
distance_cutoff=30.0,
|
|
||||||
)
|
|
||||||
pae_idx = loss_input["pae_idx"]
|
|
||||||
pde_idx = loss_input["pde_idx"]
|
|
||||||
plddt_idx = loss_input["plddt_idx"]
|
|
||||||
af3_style_ipae_idx = loss_input["best_interface_idx"][
|
|
||||||
f"{chain_i}-{chain_j}"
|
|
||||||
]
|
|
||||||
interface_lddt["interface_lddt_first"].append(lddt[0].item())
|
|
||||||
interface_lddt["interface_lddt_best"].append(lddt.max().item())
|
|
||||||
interface_lddt["interface_lddt_pae"].append(lddt[pae_idx].item())
|
|
||||||
interface_lddt["interface_lddt_pde"].append(lddt[pde_idx].item())
|
|
||||||
interface_lddt["interface_lddt_plddt"].append(lddt[plddt_idx].item())
|
|
||||||
interface_lddt["interface_lddt_af3_style_ipae"].append(
|
|
||||||
lddt[af3_style_ipae_idx].item()
|
|
||||||
)
|
|
||||||
interface_lddt["interface_lddt_af3_style_lig_ipae"].append(
|
|
||||||
lddt[loss_input["best_lig_ipae_idx"][f"{chain_i}-{chain_j}"]].item()
|
|
||||||
)
|
|
||||||
return interface_lddt
|
|
||||||
|
|
||||||
|
|
||||||
class ConfidenceChainLDDT(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
chain_lddt = {
|
|
||||||
"chain_lddt_first": [],
|
|
||||||
"chain_lddt_best": [],
|
|
||||||
"chain_lddt_pae": [],
|
|
||||||
"chain_lddt_pde": [],
|
|
||||||
"chain_lddt_plddt": [],
|
|
||||||
"chain_lddt_af3_style_chain": [],
|
|
||||||
"chain_lddt_af3_style_single_chain": [],
|
|
||||||
}
|
|
||||||
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
|
|
||||||
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
|
|
||||||
for chain_i, chain_type in loss_input["pn_units_to_score"]:
|
|
||||||
# print(chain_type)
|
|
||||||
# get tokens in chain_i and chain_j
|
|
||||||
chain_i_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
chain_j_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
# convert the token level to the atom level
|
|
||||||
chain_i_atoms = chain_i_tokens[tok_idx]
|
|
||||||
chain_j_atoms = chain_j_tokens[tok_idx]
|
|
||||||
# compute the intersection of chain_i and chain_j
|
|
||||||
chain_ij_atoms = torch.einsum(
|
|
||||||
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
|
|
||||||
).to(network_output["X_L"].device)
|
|
||||||
|
|
||||||
# compute lddt using the pairs_to_score from the intersection
|
|
||||||
lddt = calc_lddt(
|
|
||||||
network_output["X_L"],
|
|
||||||
loss_input["X_gt_L"],
|
|
||||||
loss_input["crd_mask_L"],
|
|
||||||
torch.tensor(tok_idx).to(network_output["X_L"].device),
|
|
||||||
pairs_to_score=chain_ij_atoms,
|
|
||||||
)
|
|
||||||
|
|
||||||
chain_lddt["chain_lddt_first"].append(lddt[0].item())
|
|
||||||
chain_lddt["chain_lddt_best"].append(lddt.max().item())
|
|
||||||
chain_lddt["chain_lddt_pae"].append(lddt[loss_input["pae_idx"]].item())
|
|
||||||
chain_lddt["chain_lddt_pde"].append(lddt[loss_input["pde_idx"]].item())
|
|
||||||
chain_lddt["chain_lddt_plddt"].append(lddt[loss_input["plddt_idx"]].item())
|
|
||||||
chain_lddt["chain_lddt_af3_style_chain"].append(
|
|
||||||
lddt[loss_input["best_chain_to_all_idx"][chain_i]].item()
|
|
||||||
)
|
|
||||||
chain_lddt["chain_lddt_af3_style_single_chain"].append(
|
|
||||||
lddt[loss_input["best_chain_to_self_idx"][chain_i]].item()
|
|
||||||
)
|
|
||||||
return chain_lddt
|
|
||||||
|
|
||||||
|
|
||||||
class LigRMSD(Metric):
|
|
||||||
# TODO: move these to a separate file, here for backwards compatibility with configs
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class InterfacePocketLigandRMSD(Metric):
|
|
||||||
# TODO: move these to a separate file, here for backwards compatibility with configs
|
|
||||||
|
|
||||||
"""
|
|
||||||
Compute the Ligand RMSD for each interface in the interfaces_to_score list.
|
|
||||||
|
|
||||||
The ligand RMSD is computed only for interface protein-ligand chains.
|
|
||||||
Given a chain pair (chain_i, chain_j) and the interface type, the RMSD is computed as follows:
|
|
||||||
- if the interface_type is protein_ligand: continue
|
|
||||||
- Rigid align the GT coordinates of onto the predicted coordinates using only the CA atoms within 10A of the ligand in chain_i or chain_j
|
|
||||||
- Compute the RMSD between the aligned GT coordinates and the predicted coordinates of the ligand atoms
|
|
||||||
|
|
||||||
Note: if the interface is not between a protein-ligand pair, the RMSD is set to -1
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
class ChainLDDT(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
chain_lddt = {"chain_lddt_first": [], "chain_lddt_best": []}
|
|
||||||
chain_iid_token_lvl = loss_input["chain_iid_token_lvl"]
|
|
||||||
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
|
|
||||||
for chain_i, chain_type in loss_input["pn_units_to_score"]:
|
|
||||||
# get tokens in chain_i and chain_j
|
|
||||||
chain_i_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
chain_j_tokens = chain_iid_token_lvl == chain_i
|
|
||||||
# convert the token level to the atom level
|
|
||||||
chain_i_atoms = chain_i_tokens[tok_idx]
|
|
||||||
chain_j_atoms = chain_j_tokens[tok_idx]
|
|
||||||
# compute the intersection of chain_i and chain_j
|
|
||||||
|
|
||||||
chain_ij_atoms = torch.einsum(
|
|
||||||
"L, K -> LK", torch.tensor(chain_i_atoms), torch.tensor(chain_j_atoms)
|
|
||||||
).to(network_output["X_L"].device)
|
|
||||||
|
|
||||||
# compute lddt using the pairs_to_score from the intersection
|
|
||||||
lddt = calc_lddt(
|
|
||||||
network_output["X_L"],
|
|
||||||
loss_input["X_gt_L"],
|
|
||||||
loss_input["crd_mask_L"],
|
|
||||||
torch.tensor(tok_idx).to(network_output["X_L"].device),
|
|
||||||
pairs_to_score=chain_ij_atoms,
|
|
||||||
)
|
|
||||||
|
|
||||||
chain_lddt["chain_lddt_first"].append(lddt[0].item())
|
|
||||||
chain_lddt["chain_lddt_best"].append(lddt.max().item())
|
|
||||||
return chain_lddt
|
|
||||||
|
|
||||||
|
|
||||||
class LDDTByDiffusionStep(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
lddt_by_step = {"lddt_by_step": []}
|
|
||||||
tok_idx = network_input["f"]["atom_to_token_map"].cpu().numpy()
|
|
||||||
for i, X_L in enumerate(network_output["X_denoised_L_traj"]):
|
|
||||||
lddt = calc_lddt(
|
|
||||||
X_L,
|
|
||||||
loss_input["X_gt_L"],
|
|
||||||
loss_input["crd_mask_L"],
|
|
||||||
torch.tensor(tok_idx).to(network_output["X_L"].device),
|
|
||||||
)
|
|
||||||
lddt_by_step["lddt_by_step"].append(lddt)
|
|
||||||
return lddt_by_step
|
|
||||||
|
|
||||||
|
|
||||||
class SmoothedLDDT(nn.Module):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
raise NotImplementedError()
|
|
||||||
@@ -1,42 +0,0 @@
|
|||||||
import hydra
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
# class Metric:
|
|
||||||
# def __call__(self, rf_output, loss_calc_items) -> float:
|
|
||||||
# raise NotImplementedError("base class")
|
|
||||||
|
|
||||||
|
|
||||||
class MetricManager(nn.Module):
|
|
||||||
"""
|
|
||||||
Similar syntax to LossManager, but for metrics
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, **metrics):
|
|
||||||
super().__init__()
|
|
||||||
self.to_compute = []
|
|
||||||
for metric_name, metric in metrics.items():
|
|
||||||
metric_fn = hydra.utils.instantiate(metric)
|
|
||||||
print(f"Adding metric {metric_name} to the validation metrics")
|
|
||||||
self.to_compute.append(metric_fn)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
network_input,
|
|
||||||
network_output,
|
|
||||||
loss_input,
|
|
||||||
):
|
|
||||||
loss_dict = {}
|
|
||||||
for loss_fn in self.to_compute:
|
|
||||||
loss_dict_ = loss_fn(network_input, network_output, loss_input)
|
|
||||||
loss_dict.update(loss_dict_)
|
|
||||||
return loss_dict
|
|
||||||
|
|
||||||
|
|
||||||
class Metric:
|
|
||||||
def __call__(self, network_input, network_output, loss_input) -> float:
|
|
||||||
raise NotImplementedError("base class")
|
|
||||||
|
|
||||||
|
|
||||||
class AddExampleID(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
return {"example_id": loss_input["example_id"]}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
from typing import Dict
|
|
||||||
|
|
||||||
from rf2aa.metrics.predicted_error import PAE, PLDDT
|
|
||||||
|
|
||||||
|
|
||||||
class MetricManager:
|
|
||||||
def __init__(self, config) -> None:
|
|
||||||
self.config = config
|
|
||||||
self.metrics = {metric: metrics_factory[metric] for metric in config.metrics}
|
|
||||||
|
|
||||||
def __call__(self, rf_outputs, loss_calc_items) -> Dict:
|
|
||||||
metrics_dict = {}
|
|
||||||
for metric_name, metric in self.metrics:
|
|
||||||
metric_value = metric(rf_outputs, loss_calc_items)
|
|
||||||
metrics_dict[metric_name] = metric_value
|
|
||||||
return metrics_dict
|
|
||||||
|
|
||||||
|
|
||||||
metrics_factory = {"mean_pae": PAE(), "mean_plddt": PLDDT()}
|
|
||||||
@@ -1,359 +0,0 @@
|
|||||||
from itertools import combinations
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import tree
|
|
||||||
|
|
||||||
from rf2aa.chemical import ChemicalData as ChemData
|
|
||||||
from rf2aa.metrics.metric_utils import (
|
|
||||||
compute_mean_over_subsampled_pairs,
|
|
||||||
create_chainwise_masks_1d,
|
|
||||||
create_chainwise_masks_2d,
|
|
||||||
create_interface_masks_2d,
|
|
||||||
spread_batch_into_dictionary,
|
|
||||||
unbin_logits,
|
|
||||||
)
|
|
||||||
from rf2aa.metrics.metrics_base import Metric
|
|
||||||
|
|
||||||
|
|
||||||
class WriteAF3Confidence(Metric):
|
|
||||||
"""
|
|
||||||
Given some config setups of pae, plddt, and pde, computes aggregate metrics for the model's confidence predictions
|
|
||||||
TO be used at inference time for users to know how confident their predictions are.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, pae, plddt, pde, **kwargs):
|
|
||||||
super().__init__()
|
|
||||||
self.pae = pae
|
|
||||||
self.plddt = plddt
|
|
||||||
self.pde = pde
|
|
||||||
|
|
||||||
def __call__(self, network_input, network_output, loss_input) -> Any:
|
|
||||||
plddt_logit_stack = network_output["confidence"]["plddt_logits"]
|
|
||||||
pae_logits = network_output["confidence"]["pae_logits"]
|
|
||||||
pde_logits = network_output["confidence"]["pde_logits"]
|
|
||||||
ch_label = network_output["confidence"]["chain_iid_token_lvl"]
|
|
||||||
is_real_atom = network_output["confidence"]["is_real_atom"]
|
|
||||||
|
|
||||||
# reorder the input tensors to be in (B, n_bins, ...) format for unbinning
|
|
||||||
plddt = unbin_logits(
|
|
||||||
plddt_logit_stack.reshape(
|
|
||||||
-1,
|
|
||||||
plddt_logit_stack.shape[1],
|
|
||||||
ChemData().NHEAVY,
|
|
||||||
self.plddt.n_bins,
|
|
||||||
).permute(0, 3, 1, 2).float(),
|
|
||||||
self.plddt.max_value,
|
|
||||||
self.plddt.n_bins,
|
|
||||||
)
|
|
||||||
pae = unbin_logits(
|
|
||||||
pae_logits.permute(0, 3, 1, 2).float(), self.pae.max_value, self.pae.n_bins
|
|
||||||
)
|
|
||||||
pde = unbin_logits(
|
|
||||||
pde_logits.permute(0, 3, 1, 2).float(), self.pde.max_value, self.pde.n_bins
|
|
||||||
)
|
|
||||||
|
|
||||||
pae_interface = {}
|
|
||||||
pde_interface = {}
|
|
||||||
for interface, pairs_to_score in create_interface_masks_2d(
|
|
||||||
ch_label, device=pae.device
|
|
||||||
).items():
|
|
||||||
pae_interface[interface] = spread_batch_into_dictionary(
|
|
||||||
compute_mean_over_subsampled_pairs(pae, pairs_to_score)
|
|
||||||
)
|
|
||||||
pde_interface[interface] = spread_batch_into_dictionary(
|
|
||||||
compute_mean_over_subsampled_pairs(pde, pairs_to_score)
|
|
||||||
)
|
|
||||||
|
|
||||||
pae_chainwise = {}
|
|
||||||
pde_chainwise = {}
|
|
||||||
for chain, pairs_to_score in create_chainwise_masks_2d(
|
|
||||||
ch_label, device=pae.device
|
|
||||||
).items():
|
|
||||||
pae_chainwise[chain] = spread_batch_into_dictionary(
|
|
||||||
compute_mean_over_subsampled_pairs(pae, pairs_to_score)
|
|
||||||
)
|
|
||||||
pde_chainwise[chain] = spread_batch_into_dictionary(
|
|
||||||
compute_mean_over_subsampled_pairs(pde, pairs_to_score)
|
|
||||||
)
|
|
||||||
|
|
||||||
plddt_chainwise = {}
|
|
||||||
for chain, residue_atom_indices_to_score in create_chainwise_masks_1d(
|
|
||||||
ch_label, device=is_real_atom.device
|
|
||||||
).items():
|
|
||||||
chain_is_real_atom = (
|
|
||||||
is_real_atom[..., : ChemData().NHEAVY]
|
|
||||||
* residue_atom_indices_to_score[:, None]
|
|
||||||
)
|
|
||||||
plddt_chainwise[chain] = spread_batch_into_dictionary(
|
|
||||||
compute_mean_over_subsampled_pairs(plddt, chain_is_real_atom)
|
|
||||||
)
|
|
||||||
|
|
||||||
confidence_data = {
|
|
||||||
"example_id": loss_input["example_id"],
|
|
||||||
"mean_plddt": spread_batch_into_dictionary(plddt.mean(dim=(-1, -2))),
|
|
||||||
"mean_pae": spread_batch_into_dictionary(pae.mean(dim=(-1, -2))),
|
|
||||||
"mean_pde": spread_batch_into_dictionary(pde.mean(dim=(-1, -2))),
|
|
||||||
"chain_wise_mean_plddt": plddt_chainwise,
|
|
||||||
"chain_wise_mean_pae": pae_chainwise,
|
|
||||||
"chain_wise_mean_pde": pde_chainwise,
|
|
||||||
"interface_wise_mean_pae": pae_interface,
|
|
||||||
"interface_wise_mean_pde": pde_interface,
|
|
||||||
}
|
|
||||||
|
|
||||||
num_batches = plddt.shape[0]
|
|
||||||
chains = np.unique(ch_label)
|
|
||||||
num_chains = len(chains)
|
|
||||||
chain_pairs = list(combinations(chains, 2))
|
|
||||||
|
|
||||||
# TODO: refactor to remove for loops
|
|
||||||
rows = []
|
|
||||||
for batch_idx in range(num_batches):
|
|
||||||
for chain_id in range(num_chains):
|
|
||||||
chain = chains[chain_id]
|
|
||||||
row = {
|
|
||||||
"example_id": confidence_data["example_id"],
|
|
||||||
"chain_chainwise": chain,
|
|
||||||
"chainwise_plddt": confidence_data["chain_wise_mean_plddt"][chain][
|
|
||||||
batch_idx
|
|
||||||
],
|
|
||||||
"chainwise_pde": confidence_data["chain_wise_mean_pde"][chain][
|
|
||||||
batch_idx
|
|
||||||
],
|
|
||||||
"chainwise_pae": confidence_data["chain_wise_mean_pae"][chain][
|
|
||||||
batch_idx
|
|
||||||
],
|
|
||||||
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
|
|
||||||
"overall_pde": confidence_data["mean_pde"][batch_idx],
|
|
||||||
"overall_pae": confidence_data["mean_pae"][batch_idx],
|
|
||||||
"batch_idx": batch_idx,
|
|
||||||
}
|
|
||||||
rows.append(row)
|
|
||||||
for interface in chain_pairs:
|
|
||||||
chain_i, chain_j = interface
|
|
||||||
row = {
|
|
||||||
"example_id": confidence_data["example_id"],
|
|
||||||
"chain_i_interface": chain_i,
|
|
||||||
"chain_j_interface": chain_j,
|
|
||||||
"pae_interface": confidence_data["interface_wise_mean_pae"][
|
|
||||||
interface
|
|
||||||
][batch_idx],
|
|
||||||
"pde_interface": confidence_data["interface_wise_mean_pde"][
|
|
||||||
interface
|
|
||||||
][batch_idx],
|
|
||||||
"overall_plddt": confidence_data["mean_plddt"][batch_idx],
|
|
||||||
"overall_pde": confidence_data["mean_pde"][batch_idx],
|
|
||||||
"overall_pae": confidence_data["mean_pae"][batch_idx],
|
|
||||||
"batch_idx": batch_idx,
|
|
||||||
}
|
|
||||||
rows.append(row)
|
|
||||||
|
|
||||||
return pd.DataFrame(rows)
|
|
||||||
|
|
||||||
|
|
||||||
class GetConfidenceIndices(Metric):
|
|
||||||
def __call__(self, network_input, network_output, loss_input):
|
|
||||||
# AF3's ranking metrics work like this, but using ptm instead of ipae:
|
|
||||||
confidence_loss = loss_input["confidence_loss"]
|
|
||||||
del loss_input["confidence_loss"]
|
|
||||||
|
|
||||||
ch_label = loss_input["chain_iid_token_lvl"]
|
|
||||||
scored_chains, interfaces, interface_chains = select_scored_units(loss_input)
|
|
||||||
|
|
||||||
chain_to_all_masks = create_chain_to_all_masks(ch_label, scored_chains)
|
|
||||||
chain_to_self_masks = create_chain_to_self_masks(ch_label, scored_chains)
|
|
||||||
interface_masks, lig_chains = create_interface_masks(
|
|
||||||
ch_label, interfaces, loss_input["is_ligand"]
|
|
||||||
)
|
|
||||||
|
|
||||||
# map everything to gpu
|
|
||||||
gpu = network_output["confidence"]["plddt_logits"].device
|
|
||||||
chain_to_all_masks = tree.map_structure(
|
|
||||||
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_all_masks
|
|
||||||
)
|
|
||||||
chain_to_self_masks = tree.map_structure(
|
|
||||||
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, chain_to_self_masks
|
|
||||||
)
|
|
||||||
interface_masks = tree.map_structure(
|
|
||||||
lambda x: x.to(gpu) if hasattr(x, "cpu") else x, interface_masks
|
|
||||||
)
|
|
||||||
|
|
||||||
confidence = network_output["confidence"]
|
|
||||||
|
|
||||||
plddt_logits = confidence["plddt_logits"]
|
|
||||||
|
|
||||||
# Reshape logits to B, K, L, NHEAVY
|
|
||||||
is_real_atom = network_output["confidence"]["is_real_atom"]
|
|
||||||
plddt_logits = plddt_logits.reshape(
|
|
||||||
-1, plddt_logits.shape[1], ChemData().NHEAVY, confidence_loss.plddt.n_bins
|
|
||||||
).permute(0, 3, 1, 2).float()
|
|
||||||
# Reshape the pae and pde logits to B, K, L, L
|
|
||||||
pae_logits = confidence["pae_logits"].permute(0, 3, 1, 2).float()
|
|
||||||
pde_logits = confidence["pde_logits"].permute(0, 3, 1, 2).float()
|
|
||||||
|
|
||||||
pae_logits_unbinned = unbin_logits(
|
|
||||||
pae_logits, confidence_loss.pae.max_value, confidence_loss.pae.n_bins
|
|
||||||
)
|
|
||||||
plddt_logits_unbinned = unbin_logits(
|
|
||||||
plddt_logits, confidence_loss.plddt.max_value, confidence_loss.plddt.n_bins
|
|
||||||
)
|
|
||||||
pde_logits_unbinned = unbin_logits(
|
|
||||||
pde_logits, confidence_loss.pde.max_value, confidence_loss.pde.n_bins
|
|
||||||
)
|
|
||||||
|
|
||||||
complex_pae = pae_logits_unbinned.mean(dim=(1, 2))
|
|
||||||
complex_pde = pde_logits_unbinned.mean(dim=(1, 2))
|
|
||||||
complex_plddt = (
|
|
||||||
plddt_logits_unbinned * is_real_atom[..., : ChemData().NHEAVY]
|
|
||||||
).sum(dim=(1, 2)) / is_real_atom[..., : ChemData().NHEAVY].sum()
|
|
||||||
|
|
||||||
loss_input["pae_idx"] = torch.argmin(complex_pae)
|
|
||||||
loss_input["pde_idx"] = torch.argmin(complex_pde)
|
|
||||||
loss_input["plddt_idx"] = torch.argmax(complex_plddt)
|
|
||||||
|
|
||||||
chain_to_self_paes = get_masked_error_per_chain(
|
|
||||||
scored_chains, chain_to_self_masks, pae_logits_unbinned
|
|
||||||
)
|
|
||||||
chain_to_all_paes = get_masked_error_per_chain(
|
|
||||||
scored_chains, chain_to_all_masks, pae_logits_unbinned
|
|
||||||
)
|
|
||||||
interface_chain_paes = get_masked_error_per_chain(
|
|
||||||
interface_chains, interface_masks, pae_logits_unbinned
|
|
||||||
)
|
|
||||||
# average over both interfaces
|
|
||||||
average_interface_paes = get_average_error_per_interface(
|
|
||||||
interfaces, lig_chains, interface_chain_paes
|
|
||||||
)
|
|
||||||
|
|
||||||
loss_input["best_chain_to_all_idx"] = get_lowest_error_indices(
|
|
||||||
chain_to_all_paes
|
|
||||||
)
|
|
||||||
loss_input["best_chain_to_self_idx"] = get_lowest_error_indices(
|
|
||||||
chain_to_self_paes
|
|
||||||
)
|
|
||||||
loss_input["best_interface_idx"] = get_lowest_error_indices(
|
|
||||||
average_interface_paes
|
|
||||||
)
|
|
||||||
# for ligands, we don't average the error
|
|
||||||
loss_input["best_lig_ipae_idx"] = get_lowest_error_ligand_indices(
|
|
||||||
interface_chain_paes, interfaces, lig_chains
|
|
||||||
)
|
|
||||||
|
|
||||||
return loss_input
|
|
||||||
|
|
||||||
|
|
||||||
def select_scored_units(loss_input):
|
|
||||||
scored_chains = []
|
|
||||||
interfaces = []
|
|
||||||
interface_chains = []
|
|
||||||
for k in loss_input["interfaces_to_score"]:
|
|
||||||
interfaces.append(f"{k[0]}-{k[1]}")
|
|
||||||
interface_chains.append(k[0])
|
|
||||||
interface_chains.append(k[1])
|
|
||||||
for k in loss_input["pn_units_to_score"]:
|
|
||||||
scored_chains.append(k[0])
|
|
||||||
|
|
||||||
return scored_chains, interfaces, interface_chains
|
|
||||||
|
|
||||||
|
|
||||||
def create_chain_to_all_masks(ch_label, chains_to_score):
|
|
||||||
unique_chains = np.unique(ch_label)
|
|
||||||
I = len(ch_label)
|
|
||||||
chain_to_all_masks = {}
|
|
||||||
for chain in unique_chains:
|
|
||||||
if chain in chains_to_score:
|
|
||||||
indices = torch.from_numpy((ch_label == chain))
|
|
||||||
mask = indices.unsqueeze(0) | indices.unsqueeze(1)
|
|
||||||
# set the diagonal to false
|
|
||||||
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
||||||
chain_to_all_masks[chain] = mask
|
|
||||||
return chain_to_all_masks
|
|
||||||
|
|
||||||
|
|
||||||
def create_chain_to_self_masks(ch_label, chains_to_score):
|
|
||||||
unique_chains = np.unique(ch_label)
|
|
||||||
I = len(ch_label)
|
|
||||||
chain_to_self_masks = {}
|
|
||||||
for chain in unique_chains:
|
|
||||||
if chain in chains_to_score:
|
|
||||||
indices = torch.from_numpy((ch_label == chain))
|
|
||||||
mask = indices.unsqueeze(0) & indices.unsqueeze(1)
|
|
||||||
# set the diagonal to false
|
|
||||||
mask = mask & ~torch.eye(I, device=mask.device, dtype=torch.bool)
|
|
||||||
chain_to_self_masks[chain] = mask
|
|
||||||
return chain_to_self_masks
|
|
||||||
|
|
||||||
|
|
||||||
def create_interface_masks(ch_label, interfaces, is_ligand):
|
|
||||||
interface_masks = {}
|
|
||||||
interface_chains = []
|
|
||||||
ligand_chains = []
|
|
||||||
for interface in interfaces:
|
|
||||||
interface_chains.append(interface.split("-")[0])
|
|
||||||
interface_chains.append(interface.split("-")[1])
|
|
||||||
interface_chains = set(interface_chains)
|
|
||||||
for chain in interface_chains:
|
|
||||||
chain_indices = torch.from_numpy((ch_label == chain))
|
|
||||||
|
|
||||||
to_self = chain_indices.unsqueeze(0) & chain_indices.unsqueeze(1)
|
|
||||||
to_all = chain_indices.unsqueeze(0) | chain_indices.unsqueeze(1)
|
|
||||||
interface_mask = to_all & ~to_self
|
|
||||||
interface_masks[chain] = interface_mask
|
|
||||||
|
|
||||||
if torch.all(is_ligand[chain_indices]):
|
|
||||||
ligand_chains.append(chain)
|
|
||||||
|
|
||||||
return interface_masks, ligand_chains
|
|
||||||
|
|
||||||
|
|
||||||
def get_masked_error_per_chain(chains, masks, unbinned_logits):
|
|
||||||
error = {}
|
|
||||||
for chain in chains:
|
|
||||||
mask = masks[chain]
|
|
||||||
chain_error = compute_mean_over_subsampled_pairs(unbinned_logits, mask)
|
|
||||||
error[chain] = chain_error
|
|
||||||
|
|
||||||
return error
|
|
||||||
|
|
||||||
|
|
||||||
def get_average_error_per_interface(interfaces, lig_chains, interface_errors):
|
|
||||||
average_error = {}
|
|
||||||
for interface in interfaces:
|
|
||||||
chain_a = interface.split("-")[0]
|
|
||||||
chain_b = interface.split("-")[1]
|
|
||||||
average_error[interface] = (
|
|
||||||
interface_errors[chain_a] + interface_errors[chain_b]
|
|
||||||
) / 2
|
|
||||||
|
|
||||||
return average_error
|
|
||||||
|
|
||||||
|
|
||||||
def get_lowest_error_indices(errors):
|
|
||||||
lowest_error_indices = {}
|
|
||||||
for k, v in errors.items():
|
|
||||||
lowest_error_indices[k] = torch.argmin(v)
|
|
||||||
|
|
||||||
return lowest_error_indices
|
|
||||||
|
|
||||||
|
|
||||||
def get_lowest_error_ligand_indices(errors, interfaces, lig_chains):
|
|
||||||
# ligands are a special case in AF3, where they only consider the ligand chain's error and not the average for the interface
|
|
||||||
lowest_error_indices = {}
|
|
||||||
for interface in interfaces:
|
|
||||||
chain_a = interface.split("-")[0]
|
|
||||||
chain_b = interface.split("-")[1]
|
|
||||||
if chain_a in lig_chains or chain_b in lig_chains:
|
|
||||||
if chain_a in lig_chains:
|
|
||||||
lig_chain = chain_a
|
|
||||||
elif chain_b in lig_chains:
|
|
||||||
lig_chain = chain_b
|
|
||||||
|
|
||||||
lowest_error_indices[interface] = torch.argmin(errors[lig_chain])
|
|
||||||
else:
|
|
||||||
# assign a random value to avoid key errors downstream; sorting ligand interfaces
|
|
||||||
# from other types is handles in analysis
|
|
||||||
lowest_error_indices[interface] = 0
|
|
||||||
|
|
||||||
return lowest_error_indices
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
|
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from rf2aa.model.AF3_structure import AtomAttentionDecoder, AtomAttentionEncoder
|
|
||||||
|
|
||||||
|
|
||||||
class NonEquivariantAtomEncoder(nn.Module):
|
|
||||||
def __init__(self, block_params):
|
|
||||||
super().__init__()
|
|
||||||
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
|
|
||||||
self.model = AtomAttentionEncoder(**block_params)
|
|
||||||
|
|
||||||
|
|
||||||
class NonEquivariantAtomDecoder(nn.Module):
|
|
||||||
def __init__(self, block_params):
|
|
||||||
super().__init__()
|
|
||||||
# c_atom, c_atompair, c_token = block_params.c_atom_pair, block_params.c_atom, block_params.c_token
|
|
||||||
self.model = AtomAttentionDecoder(**block_params)
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
import importlib
|
|
||||||
import sys
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_import(path: str) -> Any:
|
|
||||||
"""
|
|
||||||
Import a module from a string path.
|
|
||||||
If the module is not already imported, we dynamically import
|
|
||||||
with `importlib.import_module` and return the module object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path (str): The path to the module.
|
|
||||||
|
|
||||||
Example usage with Hydra, assuming the module `rf2aa.setup` exists within the PYTHONPATH:
|
|
||||||
```yaml
|
|
||||||
# config.yaml
|
|
||||||
setup: ${resolve_import:rf2aa.setup}
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
namespace, name = path.rsplit(".", maxsplit=1)
|
|
||||||
importlib.import_module(namespace)
|
|
||||||
return sys.modules[namespace].__dict__[name]
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -1,183 +0,0 @@
|
|||||||
asd
|
|
||||||
OpenBabel10022416543D
|
|
||||||
|
|
||||||
86 91 0 0 1 0 0 0 0 0999 V2000
|
|
||||||
0.0000 0.0000 0.0000 Ru 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.9811 0.0000 0.0000 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.8984 0.0000 0.9976 N 0 3 0 0 0 4 0 0 0 0 0 0
|
|
||||||
4.2283 -0.4311 0.5225 C 0 0 1 0 0 0 0 0 0 0 0 0
|
|
||||||
4.1294 -0.1510 -0.9789 C 0 0 2 0 0 0 0 0 0 0 0 0
|
|
||||||
2.6689 -0.1856 -1.1704 N 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
4.5104 0.8502 -1.2196 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
4.3148 -1.5116 0.6738 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.0926 0.0177 -2.4711 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.0551 1.3215 -3.0085 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.4395 1.5092 -4.2457 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.8688 0.4467 -4.9577 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.9728 -0.8382 -4.4242 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.5980 -1.0817 -3.1942 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.3946 2.5134 -4.6607 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.5711 -1.6818 -4.9808 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.7839 -2.5051 -2.7336 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.6142 -2.9688 -3.2824 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.9963 -2.5662 -1.6680 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.8879 -3.0998 -2.9345 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.6908 2.4918 -2.3015 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.2457 3.4308 -2.6402 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
3.7660 2.5403 -2.5179 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.5630 2.4326 -1.2198 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.1546 0.6923 -6.2641 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.0871 -0.2224 -6.8610 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
0.6644 1.4570 -6.8598 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-0.8688 1.0481 -6.0884 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.6482 0.2211 2.3858 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.4836 1.5443 2.8329 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.2054 1.7521 4.1866 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.0596 2.7690 4.5422 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.2790 -0.6127 4.6093 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.5615 -0.8708 3.2641 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.1903 -1.4511 5.2957 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.0965 0.6881 5.0892 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.7358 -2.2864 2.7757 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.3651 -2.9970 3.5199 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
3.7928 -2.5228 2.5978 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.1924 -2.4512 1.8403 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
3.4096 2.6945 1.2403 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.5155 2.7062 1.8725 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.4962 3.6544 2.4168 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.6520 2.6776 1.1976 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.8132 0.9407 6.5497 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.1857 1.8271 6.6864 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
2.7445 1.1113 7.1053 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1.3095 0.0856 7.0112 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-0.3034 2.2463 -0.7911 Cl 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-0.0311 -2.4021 0.1684 Cl 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-0.3847 0.3239 1.7667 C 0 0 0 0 0 3 0 0 0 0 0 0
|
|
||||||
0.3401 0.5495 2.5448 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-1.7569 0.2734 2.2121 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-2.7733 0.0216 1.2582 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-4.1105 -0.0598 1.6390 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-4.4360 0.1136 2.9889 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.4526 0.3669 3.9517 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-2.1198 0.4461 3.5620 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-1.3353 0.6371 4.2895 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.7295 0.4995 4.9924 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-5.4785 0.0501 3.2858 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-4.8949 -0.2514 0.9184 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-2.2963 -0.1404 -0.0112 O 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
5.5671 -0.2126 2.2083 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
5.1577 1.3111 1.4179 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
6.6571 0.2637 0.4773 N 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
5.3944 0.2611 1.2403 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
6.8286 1.1116 -0.0564 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
7.1294 -1.0762 -0.3820 S 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
6.3097 -0.9086 -1.8483 N 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
8.5424 -0.9362 -0.7279 O 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
6.6677 -2.2360 0.3905 O 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
4.5395 -1.0515 -2.8887 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
4.6166 -2.1893 -1.5280 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
4.8612 -1.1705 -1.8517 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
6.8079 -1.4416 -2.5601 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.2077 -0.2475 -1.1550 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.9743 -0.9748 -0.8710 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-2.3957 -0.8002 -2.3143 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-1.6290 -0.0874 -2.6345 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-1.9096 -1.7391 -2.0383 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.0669 -0.9839 -3.1595 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.8219 1.1126 -1.4561 C 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-4.5239 1.0169 -2.2909 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-4.3670 1.5131 -0.5970 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
-3.0331 1.8184 -1.7297 H 0 0 0 0 0 0 0 0 0 0 0 0
|
|
||||||
1 50 1 0 0 0 0
|
|
||||||
1 51 1 0 0 0 0
|
|
||||||
2 1 1 0 0 0 0
|
|
||||||
2 3 2 0 0 0 0
|
|
||||||
3 29 1 0 0 0 0
|
|
||||||
4 8 1 1 0 0 0
|
|
||||||
4 3 1 0 0 0 0
|
|
||||||
4 67 1 0 0 0 0
|
|
||||||
5 7 1 6 0 0 0
|
|
||||||
5 4 1 0 0 0 0
|
|
||||||
6 5 1 0 0 0 0
|
|
||||||
6 2 1 0 0 0 0
|
|
||||||
9 6 1 0 0 0 0
|
|
||||||
10 9 2 0 0 0 0
|
|
||||||
10 21 1 0 0 0 0
|
|
||||||
11 10 1 0 0 0 0
|
|
||||||
12 13 1 0 0 0 0
|
|
||||||
12 11 2 0 0 0 0
|
|
||||||
13 14 2 0 0 0 0
|
|
||||||
14 17 1 0 0 0 0
|
|
||||||
14 9 1 0 0 0 0
|
|
||||||
15 11 1 0 0 0 0
|
|
||||||
16 13 1 0 0 0 0
|
|
||||||
17 19 1 0 0 0 0
|
|
||||||
18 17 1 0 0 0 0
|
|
||||||
20 17 1 0 0 0 0
|
|
||||||
21 24 1 0 0 0 0
|
|
||||||
22 21 1 0 0 0 0
|
|
||||||
23 21 1 0 0 0 0
|
|
||||||
25 28 1 0 0 0 0
|
|
||||||
25 12 1 0 0 0 0
|
|
||||||
26 25 1 0 0 0 0
|
|
||||||
27 25 1 0 0 0 0
|
|
||||||
29 30 2 0 0 0 0
|
|
||||||
29 34 1 0 0 0 0
|
|
||||||
30 31 1 0 0 0 0
|
|
||||||
31 32 1 0 0 0 0
|
|
||||||
31 36 2 0 0 0 0
|
|
||||||
33 36 1 0 0 0 0
|
|
||||||
33 35 1 0 0 0 0
|
|
||||||
34 33 2 0 0 0 0
|
|
||||||
36 45 1 0 0 0 0
|
|
||||||
37 34 1 0 0 0 0
|
|
||||||
37 38 1 0 0 0 0
|
|
||||||
39 37 1 0 0 0 0
|
|
||||||
40 37 1 0 0 0 0
|
|
||||||
41 42 1 0 0 0 0
|
|
||||||
42 43 1 0 0 0 0
|
|
||||||
42 30 1 0 0 0 0
|
|
||||||
44 42 1 0 0 0 0
|
|
||||||
45 46 1 0 0 0 0
|
|
||||||
45 48 1 0 0 0 0
|
|
||||||
45 47 1 0 0 0 0
|
|
||||||
49 1 1 0 0 0 0
|
|
||||||
51 53 1 0 0 0 0
|
|
||||||
51 52 1 0 0 0 0
|
|
||||||
53 58 2 0 0 0 0
|
|
||||||
54 55 2 0 0 0 0
|
|
||||||
54 53 1 0 0 0 0
|
|
||||||
55 56 1 0 0 0 0
|
|
||||||
56 61 1 0 0 0 0
|
|
||||||
56 57 2 0 0 0 0
|
|
||||||
57 60 1 0 0 0 0
|
|
||||||
58 57 1 0 0 0 0
|
|
||||||
58 59 1 0 0 0 0
|
|
||||||
62 55 1 0 0 0 0
|
|
||||||
63 54 1 0 0 0 0
|
|
||||||
63 1 1 0 0 0 0
|
|
||||||
66 67 1 0 0 0 0
|
|
||||||
67 65 1 0 0 0 0
|
|
||||||
67 64 1 0 0 0 0
|
|
||||||
68 66 1 0 0 0 0
|
|
||||||
69 72 2 0 0 0 0
|
|
||||||
69 66 1 0 0 0 0
|
|
||||||
70 69 1 0 0 0 0
|
|
||||||
71 69 2 0 0 0 0
|
|
||||||
73 75 1 0 0 0 0
|
|
||||||
75 70 1 0 0 0 0
|
|
||||||
75 74 1 0 0 0 0
|
|
||||||
75 5 1 0 0 0 0
|
|
||||||
76 70 1 0 0 0 0
|
|
||||||
77 78 1 0 0 0 0
|
|
||||||
77 63 1 0 0 0 0
|
|
||||||
79 81 1 0 0 0 0
|
|
||||||
79 77 1 0 0 0 0
|
|
||||||
80 79 1 0 0 0 0
|
|
||||||
82 79 1 0 0 0 0
|
|
||||||
83 77 1 0 0 0 0
|
|
||||||
83 85 1 0 0 0 0
|
|
||||||
84 83 1 0 0 0 0
|
|
||||||
86 83 1 0 0 0 0
|
|
||||||
M END
|
|
||||||
$$$$
|
|
||||||
@@ -1,144 +0,0 @@
|
|||||||
import os
|
|
||||||
from functools import partial
|
|
||||||
|
|
||||||
import hydra
|
|
||||||
import pandas as pd
|
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
from omegaconf import OmegaConf
|
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
||||||
|
|
||||||
from rf2aa.chemical import initialize_chemdata
|
|
||||||
from rf2aa.data.compose_dataset import compose_posebusters
|
|
||||||
from rf2aa.data.dataloader_adaptor import get_loss_calc_items
|
|
||||||
from rf2aa.trainer_new import trainer_factory
|
|
||||||
from rf2aa.util import writepdb
|
|
||||||
|
|
||||||
|
|
||||||
class PoseBustersBenchmark:
|
|
||||||
def __init__(self, config):
|
|
||||||
# config file logic for validation, low->high prio:
|
|
||||||
# 1) use default parameters in config/train/base.yml
|
|
||||||
# 2) use parameters saved in model
|
|
||||||
# 3) use specific params in config/inference
|
|
||||||
default_config_path = os.path.join(
|
|
||||||
os.path.dirname(os.path.abspath(__file__)), "config/train/base.yaml"
|
|
||||||
)
|
|
||||||
base_config = OmegaConf.load(default_config_path)
|
|
||||||
tmp_data = torch.load(config.eval_params.checkpoint_path, mmap=True)
|
|
||||||
if "training_config" in tmp_data:
|
|
||||||
train_config = tmp_data["training_config"]
|
|
||||||
self.config = OmegaConf.merge(base_config, train_config, config)
|
|
||||||
else:
|
|
||||||
self.config = OmegaConf.merge(base_config, config)
|
|
||||||
tmp_data = None
|
|
||||||
|
|
||||||
assert self.config.ddp_params.batch_size == 1, "batch size is assumed to be 1"
|
|
||||||
if self.config.experiment.output_dir is not None:
|
|
||||||
self.output_dir = self.config.experiment.output_dir
|
|
||||||
else:
|
|
||||||
self.output_dir = "output/"
|
|
||||||
if not os.path.exists(self.output_dir):
|
|
||||||
os.makedirs(self.output_dir)
|
|
||||||
|
|
||||||
self.trainer = trainer_factory[self.config.experiment.trainer](
|
|
||||||
config=self.config
|
|
||||||
)
|
|
||||||
|
|
||||||
def construct_dataset(self, rank, world_size):
|
|
||||||
# fd initialize chemical data based on input arguments
|
|
||||||
# this needs to be initialized first
|
|
||||||
init = partial(initialize_chemdata, self.config)
|
|
||||||
init()
|
|
||||||
|
|
||||||
return compose_posebusters(init, self.config.loader_params, rank, world_size)
|
|
||||||
|
|
||||||
def launch_distributed_eval(self):
|
|
||||||
world_size = torch.cuda.device_count()
|
|
||||||
if "MASTER_ADDR" not in os.environ:
|
|
||||||
os.environ["MASTER_ADDR"] = (
|
|
||||||
"127.0.0.1" # multinode requires this set in submit script
|
|
||||||
)
|
|
||||||
if "MASTER_PORT" not in os.environ:
|
|
||||||
os.environ["MASTER_PORT"] = "%d" % self.config.ddp_params.port
|
|
||||||
|
|
||||||
world_size = torch.cuda.device_count()
|
|
||||||
|
|
||||||
if world_size == 0:
|
|
||||||
print("Error! No GPUs found!")
|
|
||||||
elif world_size == 1:
|
|
||||||
# No need for multiple processes with 1 GPU
|
|
||||||
self.evaluate_model(0, world_size)
|
|
||||||
else:
|
|
||||||
mp.spawn(
|
|
||||||
self.evaluate_model, args=(world_size,), nprocs=world_size, join=True
|
|
||||||
)
|
|
||||||
|
|
||||||
def evaluate_model(self, rank, world_size):
|
|
||||||
gpu = self.trainer.init_process_group(rank, world_size)
|
|
||||||
benchmark_loader = self.construct_dataset(rank, world_size)
|
|
||||||
|
|
||||||
# move global information to device
|
|
||||||
self.trainer.move_constants_to_device(gpu)
|
|
||||||
|
|
||||||
self.trainer.construct_model(device=gpu)
|
|
||||||
self.trainer.model = DDP(
|
|
||||||
self.trainer.model,
|
|
||||||
device_ids=[gpu],
|
|
||||||
find_unused_parameters=False,
|
|
||||||
broadcast_buffers=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.trainer.load_checkpoint(rank)
|
|
||||||
self.trainer.load_model()
|
|
||||||
self.trainer.model.eval()
|
|
||||||
records = []
|
|
||||||
for inputs in benchmark_loader:
|
|
||||||
item = inputs[-1]
|
|
||||||
with torch.no_grad():
|
|
||||||
loss, loss_dict, outputs = self.trainer.train_step(
|
|
||||||
inputs,
|
|
||||||
self.config.loader_params.maxcycle,
|
|
||||||
nograds=True,
|
|
||||||
return_outputs=True,
|
|
||||||
)
|
|
||||||
loss_dict["CHAINID"] = item["CHAINID"][0]
|
|
||||||
for k, v in loss_dict.items():
|
|
||||||
if torch.is_tensor(v):
|
|
||||||
loss_dict[k] = v.item()
|
|
||||||
records.append(loss_dict)
|
|
||||||
df = pd.DataFrame(records)
|
|
||||||
df.to_csv(
|
|
||||||
f"{self.output_dir}/{self.config.experiment.name}_{rank}_posebusters.csv"
|
|
||||||
)
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
|
|
||||||
true_crds = inputs[5]
|
|
||||||
seq, _, idx_pdb, bond_feats, _, _ = get_loss_calc_items(inputs, device=gpu)
|
|
||||||
pred_crds, alphas, pred_lddts = outputs[5], outputs[6], outputs[8]
|
|
||||||
_, pred_allatom = self.trainer.xyz_converter.compute_all_atom(
|
|
||||||
seq[:, 0], pred_crds[-1], alphas[-1]
|
|
||||||
)
|
|
||||||
|
|
||||||
writepdb(
|
|
||||||
f"{self.output_dir}/{item['CHAINID'][0]}_nat.pdb",
|
|
||||||
true_crds[:, 0],
|
|
||||||
seq[:, 0].long(),
|
|
||||||
bond_feats=bond_feats,
|
|
||||||
)
|
|
||||||
writepdb(
|
|
||||||
f"{self.output_dir}/{item['CHAINID'][0]}_pred.pdb",
|
|
||||||
pred_allatom[0],
|
|
||||||
seq[:, 0].long(),
|
|
||||||
bond_feats=bond_feats,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(version_base=None, config_path="config/inference")
|
|
||||||
def main(config):
|
|
||||||
benchmarker = PoseBustersBenchmark(config=config)
|
|
||||||
benchmarker.launch_distributed_eval()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
33
scripts/build_base_apptainer.sh
Normal file
33
scripts/build_base_apptainer.sh
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# This script builds a datahub apptainer container.
|
||||||
|
set -e # Exit on error
|
||||||
|
|
||||||
|
echo "Running from $PWD"
|
||||||
|
|
||||||
|
# Check if apptainer/singularity is available
|
||||||
|
APPTAINER_BINARY=$(command -v apptainer || command -v singularity)
|
||||||
|
if [ -z "$APPTAINER_BINARY" ]; then
|
||||||
|
echo "Error: Neither apptainer nor singularity found in PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Using apptainer at: $APPTAINER_BINARY"
|
||||||
|
|
||||||
|
# Generate the image name with today's date
|
||||||
|
DATE=$(date +%Y-%m-%d)
|
||||||
|
IMAGE_NAME="modelhub_${DATE}.sif"
|
||||||
|
echo "Building apptainer image: $IMAGE_NAME"
|
||||||
|
|
||||||
|
# Build Phase
|
||||||
|
echo
|
||||||
|
echo "=== Starting Build Phase ==="
|
||||||
|
echo "Running: $APPTAINER_BINARY build --notest '$IMAGE_NAME' base_apptainer.spec"
|
||||||
|
echo "----------------------------------------"
|
||||||
|
$APPTAINER_BINARY build \
|
||||||
|
--nv \
|
||||||
|
--notest \
|
||||||
|
"$IMAGE_NAME" base_apptainer.spec
|
||||||
|
echo "----------------------------------------"
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "=== Build Complete ==="
|
||||||
|
echo "Container is available at: $PWD/$IMAGE_NAME"
|
||||||
50
scripts/freeze_apptainer.sh
Executable file
50
scripts/freeze_apptainer.sh
Executable file
@@ -0,0 +1,50 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
# This script freezes CIFUtils, Datahub, and Modelhub versions within an existing apptainer.
|
||||||
|
set -e # Exit on error
|
||||||
|
|
||||||
|
echo "Running from $PWD"
|
||||||
|
|
||||||
|
SCRIPT_PATH=$(realpath $0)
|
||||||
|
SCRIPT_DIR=$(dirname $SCRIPT_PATH)
|
||||||
|
|
||||||
|
# Check if apptainer/singularity is available
|
||||||
|
APPTAINER_BINARY=$(command -v apptainer || command -v singularity)
|
||||||
|
if [ -z "$APPTAINER_BINARY" ]; then
|
||||||
|
echo "Error: Neither apptainer nor singularity found in PATH"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "Using apptainer at: $APPTAINER_BINARY"
|
||||||
|
|
||||||
|
# This is the default apptainer that you can build from 'make apptainer'
|
||||||
|
echo "... looking for a local apptainer image at '$SCRIPT_DIR/modelhub.sif'"
|
||||||
|
SIF_PATH="$SCRIPT_DIR/modelhub.sif"
|
||||||
|
SIF_PATH=$(readlink -f "$SCRIPT_DIR/modelhub.sif" )
|
||||||
|
echo "Base SIF path to build from: $SIF_PATH"
|
||||||
|
|
||||||
|
# Generate the image name with today's date
|
||||||
|
DATE=$(date +%Y-%m-%d)
|
||||||
|
IMAGE_NAME="frozen_modelhub_${DATE}.sif"
|
||||||
|
echo "Building apptainer from image with frozen dependencies: $IMAGE_NAME"
|
||||||
|
|
||||||
|
# Check if INSTALL_PROJECT is set to true and set the image name accordingly
|
||||||
|
if ${INSTALL_PROJECT}; then
|
||||||
|
echo "Modelhub WILL be installed in the apptainer! Ensure that this is intentional."
|
||||||
|
IMAGE_NAME="frozen_modelhub_datahub_cifutils_${DATE}.sif"
|
||||||
|
else
|
||||||
|
IMAGE_NAME="frozen_datahub_cifutils_${DATE}.sif"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Build Phase
|
||||||
|
echo
|
||||||
|
echo "=== Starting Build Phase ==="
|
||||||
|
echo "Running: $APPTAINER_BINARY build --notest '$IMAGE_NAME' freeze_apptainer.spec"
|
||||||
|
echo "----------------------------------------"
|
||||||
|
INSTALL_PROJECT=$INSTALL_PROJECT $APPTAINER_BINARY build \
|
||||||
|
--nv \
|
||||||
|
--notest \
|
||||||
|
"$IMAGE_NAME" freeze_apptainer.spec
|
||||||
|
echo "----------------------------------------"
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "=== Build Complete ==="
|
||||||
|
echo "Container is available at: $PWD/$IMAGE_NAME"
|
||||||
7
scripts/shebang/README.md
Normal file
7
scripts/shebang/README.md
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
This directory contains scripts that are not to be run directly by the user.
|
||||||
|
They are [SHEBANG scripts](https://en.wikipedia.org/wiki/Shebang_(Unix)) that are used to run the appropriate apptainer container.
|
||||||
|
|
||||||
|
For example, the script `modelhub_exec.sh` is used to run the modelhub apptainer container with the latest apptainer image
|
||||||
|
stored locally or at the IPD.
|
||||||
|
|
||||||
|
The shebang lines (`#!/bin/bash` ...) at the top of entry point scripts like `train.py` redirect the system to here to find the correct apptainer container.
|
||||||
1
scripts/shebang/modelhub.sif
Symbolic link
1
scripts/shebang/modelhub.sif
Symbolic link
@@ -0,0 +1 @@
|
|||||||
|
/projects/ml/modelhub/apptainer/modelhub_2025-03-19.sif
|
||||||
151
scripts/shebang/modelhub_exec.sh
Executable file
151
scripts/shebang/modelhub_exec.sh
Executable file
@@ -0,0 +1,151 @@
|
|||||||
|
#!/usr/bin/bash
|
||||||
|
|
||||||
|
###################
|
||||||
|
# You can add the path to this file as the shebang line in your python script.
|
||||||
|
# Then by default, the python script will be executed with the python interpreter
|
||||||
|
# in the SIF_PATH container. Here, we launch the container with nvidia gpu and slurm support.
|
||||||
|
#
|
||||||
|
# Example shebang: #!/usr/bin/env -S /bin/sh -c '"$(dirname "$0")/scripts/shebang/modelhub_exec.sh" "$0" "$@"'
|
||||||
|
###################
|
||||||
|
|
||||||
|
# Let the user know this script is setting things up behind the scene
|
||||||
|
SCRIPT_PATH=$(realpath $0)
|
||||||
|
SCRIPT_DIR=$(dirname $SCRIPT_PATH)
|
||||||
|
echo '################## Start shebang info ##################'
|
||||||
|
echo "The file $SCRIPT_PATH is being run as a shebang executable.
|
||||||
|
It will...
|
||||||
|
|
||||||
|
1. Add the 'modelhub' and 'src/modelhub' repo directories to your PYTHONPATH.
|
||||||
|
2. Run your python script from the right container, which contains all dependencies.
|
||||||
|
3. Launch the container with slurm and nvidia gpu support."
|
||||||
|
|
||||||
|
# Extract the path to the Python script from the arguments
|
||||||
|
PYTHON_SCRIPT=$(realpath "$1")
|
||||||
|
shift
|
||||||
|
|
||||||
|
# Find repository root by looking for .project-root file
|
||||||
|
find_repo_root() {
|
||||||
|
local current_dir="$1"
|
||||||
|
while [ "$current_dir" != "/" ]; do
|
||||||
|
if [ -f "$current_dir/.project-root" ]; then
|
||||||
|
echo "$current_dir"
|
||||||
|
return 0
|
||||||
|
fi
|
||||||
|
current_dir="$(dirname "$current_dir")"
|
||||||
|
done
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "Searching for repository root directory..."
|
||||||
|
REPO_ROOT=$(find_repo_root "$(dirname "$PYTHON_SCRIPT")")
|
||||||
|
if [ -z "$REPO_ROOT" ]; then
|
||||||
|
echo "Error: Could not find .project-root file in any parent directory"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo "... found repository root at '$REPO_ROOT'"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Function to add a directory to PYTHONPATH if it's not already included
|
||||||
|
add_to_pythonpath() {
|
||||||
|
local dir_path="$1"
|
||||||
|
if [[ ":$PYTHONPATH:" != *":$dir_path:"* ]]; then
|
||||||
|
export PYTHONPATH="$dir_path:$PYTHONPATH"
|
||||||
|
echo "Added '$dir_path' to PYTHONPATH."
|
||||||
|
else
|
||||||
|
echo "'$dir_path' is already in PYTHONPATH."
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add the src directory to PYTHONPATH if not already present
|
||||||
|
echo
|
||||||
|
echo "Checking and adding 'src' directory to PYTHONPATH..."
|
||||||
|
SRC_PATH="$REPO_ROOT/src"
|
||||||
|
add_to_pythonpath "$SRC_PATH"
|
||||||
|
|
||||||
|
# Add modelhub to PYTHONPATH if not already present
|
||||||
|
echo
|
||||||
|
echo "Checking and adding 'modelhub' directory to PYTHONPATH..."
|
||||||
|
MODELHUB_PATH="$SRC_PATH/modelhub"
|
||||||
|
add_to_pythonpath "$MODELHUB_PATH"
|
||||||
|
|
||||||
|
# Load the .env file environment variables from the repo root
|
||||||
|
echo
|
||||||
|
echo "Attempting to load environment variables from .env file:"
|
||||||
|
if [ -f "$REPO_ROOT/.env" ]; then
|
||||||
|
echo "... loading environment variables from '$REPO_ROOT/.env'"
|
||||||
|
export $(cat "$REPO_ROOT/.env" | grep -v '#' | xargs)
|
||||||
|
else
|
||||||
|
echo " Warning: No .env file found at repository root ($REPO_ROOT)"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# check if we are at the IPD
|
||||||
|
IPD_FILE="/software/containers/versions/rf_diffusion_aa/ipd.txt"
|
||||||
|
|
||||||
|
SIF_PATH=""
|
||||||
|
|
||||||
|
echo
|
||||||
|
echo "Fetching the appropriate apptainer image..."
|
||||||
|
|
||||||
|
if [ -z "$APPTAINER_NAME" ]; then
|
||||||
|
if [ -n "$PROJECT_PATH" ]; then
|
||||||
|
# Attempt to find any .sif file in the PROJECT_PATH/scripts/shebang directory
|
||||||
|
SIF_DIR="$PROJECT_PATH/scripts/shebang"
|
||||||
|
SIF_FILE=$(find "$SIF_DIR" -maxdepth 1 -name "*.sif" -print -quit)
|
||||||
|
|
||||||
|
if [ -n "$SIF_FILE" ]; then
|
||||||
|
SIF_PATH="$SIF_FILE"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# If SIF_PATH is still empty, use the default SIF
|
||||||
|
if [ -z "$SIF_PATH" ]; then
|
||||||
|
SIF_NAME="modelhub.sif"
|
||||||
|
SIF_PATH="$SCRIPT_DIR/$SIF_NAME"
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "... looking for a local apptainer image at '$SIF_PATH'"
|
||||||
|
# Check if the SIF file exists
|
||||||
|
if [ ! -f "$SIF_PATH" ]; then
|
||||||
|
echo "... apptainer not found. To run with your own apptainer image, you can build it with 'make apptainer' and place it here: '$SIF_PATH'"
|
||||||
|
echo "Attempting to run $PYTHON_SCRIPT with $(which python)"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Already running inside container $APPTAINER_NAME. Executing $PYTHON_SCRIPT with $(which python) in the existing container."
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Function to print debug=mode warning
|
||||||
|
print_debug_warning() {
|
||||||
|
echo
|
||||||
|
echo "###############################################################################"
|
||||||
|
echo "# #"
|
||||||
|
echo "# ⚠️ WARNING ⚠️ #"
|
||||||
|
echo "# RUNNING WITH DEBUGPY ON PORT $DEBUG_PORT #"
|
||||||
|
echo "# DON'T FORGET TO ATTACH A DEBUGGER #"
|
||||||
|
echo "# #"
|
||||||
|
echo "###############################################################################"
|
||||||
|
echo
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ -n "$DEBUG_PORT" ]; then
|
||||||
|
print_debug_warning
|
||||||
|
python_cmd="python -m debugpy --listen $DEBUG_PORT --wait-for-client"
|
||||||
|
else
|
||||||
|
python_cmd="python"
|
||||||
|
echo
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [ ! -z $SIF_PATH ]; then
|
||||||
|
echo "Running $PYTHON_SCRIPT with apptainer: $SIF_PATH."
|
||||||
|
echo '################## End shebang info ####################'
|
||||||
|
echo
|
||||||
|
/usr/bin/apptainer exec --nv --slurm \
|
||||||
|
--bind "$REPO_ROOT:$REPO_ROOT" \
|
||||||
|
--env PYTHONPATH="\$PYTHONPATH:$PYTHONPATH" \
|
||||||
|
$SIF_PATH $python_cmd "$PYTHON_SCRIPT" "$@"
|
||||||
|
else
|
||||||
|
echo "Running $PYTHON_SCRIPT with python: $(which python)"
|
||||||
|
echo '################## End shebang info ####################'
|
||||||
|
echo
|
||||||
|
$python_cmd "$PYTHON_SCRIPT" "$@"
|
||||||
|
fi
|
||||||
78
scripts/slurm/launch.sh
Normal file
78
scripts/slurm/launch.sh
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
#SBATCH -p gpu-train
|
||||||
|
#SBATCH --nodes 2
|
||||||
|
#SBATCH --gres=gpu:l40:8
|
||||||
|
#SBATCH --ntasks-per-node 8
|
||||||
|
#SBATCH --mem=512g
|
||||||
|
#SBATCH -t 7-00:00:00
|
||||||
|
#SBATCH -J af3-old-msas-pdb-only-experimental
|
||||||
|
#SBATCH -o slurm_logs/%x_%j.out
|
||||||
|
#SBATCH -e slurm_logs/%x_%j.err
|
||||||
|
#SBATCH --no-kill=off
|
||||||
|
|
||||||
|
### Excluded Nodes:
|
||||||
|
|
||||||
|
### To call this script run: `sbatch launch.sh` from this directory
|
||||||
|
### For reference, see the Lightning Fabric + SLURM guide: https://lightning.ai/docs/fabric/stable/guide/multi_node/slurm.html
|
||||||
|
|
||||||
|
# (In case we're still running in debug mode)
|
||||||
|
unset DEBUG_PORT
|
||||||
|
unset PROJECT_PATH
|
||||||
|
|
||||||
|
# (SLURM setup, ensuring we have a unique port per job, and setting the master address to Rank 0)
|
||||||
|
export MASTER_PORT=$((1024 + RANDOM % 64512))
|
||||||
|
export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
|
||||||
|
|
||||||
|
### Set custom paths
|
||||||
|
# WARNING: You will need to update these paths to match your local setup
|
||||||
|
# ... cifutils and datahub
|
||||||
|
export PYTHONPATH="/home/ncorley/projects/datahub/src:/home/ncorley/projects/cifutils/src:/home/ncorley/projects/modelhub/src"
|
||||||
|
# ... project path (if not using root src/modelhub)
|
||||||
|
export PROJECT_PATH="/home/ncorley/projects/modelhub/projects/rfscore"
|
||||||
|
# ... cache directory for Triton kernels (e.g., DeepSpeed4Science fused kernels)
|
||||||
|
export TRITON_CACHE_DIR="/home/ncorley/.triton" # Change this to a directory with write permissions
|
||||||
|
|
||||||
|
### Environment flags
|
||||||
|
|
||||||
|
# Debugging flags (optional)
|
||||||
|
export NCCL_DEBUG=INFO # NCCL internal debugging
|
||||||
|
export PYTHONFAULTHANDLER=1 # Catches Python core dumps (e.g., segmentation faults)
|
||||||
|
|
||||||
|
# Expand CUDA memory
|
||||||
|
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
|
||||||
|
|
||||||
|
# Turn off NVLink (L40 do not have NVLink)
|
||||||
|
export NCCL_P2P_DISABLE=1
|
||||||
|
|
||||||
|
# OPENMP and OPENBLAS optimizations
|
||||||
|
# https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#utilize-openmp
|
||||||
|
# NOTE: Must be optimized per-system; see: https://github.com/pytorch/pytorch/blob/65e6194aeb3269a182cfe2c05c122159da12770f/torch/distributed/run.py#L596-L608
|
||||||
|
export OMP_NUM_THREADS=4
|
||||||
|
export OPENBLAS_NUM_THREADS=4
|
||||||
|
|
||||||
|
#######################################################################################################
|
||||||
|
### WARNING: The command below is just an example. It will fail if you don't update the experiment ###
|
||||||
|
### config in the command below. Please adapt according to your target experiment ###
|
||||||
|
#######################################################################################################
|
||||||
|
|
||||||
|
### Set the effective batch size
|
||||||
|
EFFECTIVE_BATCH_SIZE=16
|
||||||
|
|
||||||
|
### Compose the training script
|
||||||
|
DEVICES_PER_NODE=${SLURM_NTASKS_PER_NODE:-8} # Default to 8 if not set
|
||||||
|
echo "Running on $SLURM_NNODES nodes with $DEVICES_PER_NODE tasks per node"
|
||||||
|
|
||||||
|
### Calculate grad_accum_steps
|
||||||
|
GRAD_ACCUM_STEPS=$((EFFECTIVE_BATCH_SIZE / (DEVICES_PER_NODE * SLURM_NNODES)))
|
||||||
|
echo "Grad Accumulation Steps: $GRAD_ACCUM_STEPS"
|
||||||
|
|
||||||
|
command="srun --kill-on-bad-exit ../../src/modelhub/train.py \
|
||||||
|
experiment=$SLURM_JOB_NAME \
|
||||||
|
++trainer.devices_per_node=$DEVICES_PER_NODE \
|
||||||
|
++trainer.num_nodes=$SLURM_NNODES \
|
||||||
|
++trainer.grad_accum_steps=$GRAD_ACCUM_STEPS"
|
||||||
|
|
||||||
|
echo -e "command\t$command"
|
||||||
|
|
||||||
|
# Let 'er rip
|
||||||
|
$command
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user