mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge pull request #438 from jnwei/pl_upgrades
Upgrades pl_upgrades to match main branch changes.
This commit is contained in:
4
.github/workflows/docker-image.yml
vendored
4
.github/workflows/docker-image.yml
vendored
@@ -11,5 +11,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Cleanup
|
||||
run: sudo rm -rf /usr/share/dotnet && sudo rm -rf /opt/ghc && sudo rm -rf "/usr/local/share/boost" && sudo rm -rf "$AGENT_TOOLSDIRECTORY"
|
||||
- name: Build the Docker image
|
||||
run: docker build . --file Dockerfile --tag openfold:$(date +%s)
|
||||
run: docker build . --file Dockerfile --tag openfold:$(date +%s)
|
||||
|
||||
14
.readthedocs.yaml
Normal file
14
.readthedocs.yaml
Normal file
@@ -0,0 +1,14 @@
|
||||
version: 2
|
||||
|
||||
# Set the OS, Python version and other tools you might need
|
||||
build:
|
||||
os: ubuntu-22.04
|
||||
tools:
|
||||
python: "mambaforge-4.10"
|
||||
|
||||
# Build documentation in the "docs/" directory with Sphinx
|
||||
sphinx:
|
||||
configuration: docs/source/conf.py
|
||||
|
||||
conda:
|
||||
environment: docs/environment.yml
|
||||
17
Dockerfile
17
Dockerfile
@@ -1,17 +1,20 @@
|
||||
FROM nvidia/cuda:11.3.1-cudnn8-devel-ubuntu18.04
|
||||
FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
|
||||
|
||||
# metainformation
|
||||
LABEL org.opencontainers.image.version = "1.0.0"
|
||||
LABEL org.opencontainers.image.authors = "Gustaf Ahdritz"
|
||||
LABEL org.opencontainers.image.version = "2.0.0"
|
||||
LABEL org.opencontainers.image.authors = "OpenFold Team"
|
||||
LABEL org.opencontainers.image.source = "https://github.com/aqlaboratory/openfold"
|
||||
LABEL org.opencontainers.image.licenses = "Apache License 2.0"
|
||||
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:10.2-cudnn8-runtime-ubuntu18.04"
|
||||
LABEL org.opencontainers.image.base.name="docker.io/nvidia/cuda:12.4.1-devel-ubuntu22.04"
|
||||
|
||||
RUN apt-get update && apt-get install -y wget
|
||||
|
||||
RUN apt-key del 7fa2af80
|
||||
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/7fa2af80.pub
|
||||
RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
|
||||
RUN wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
|
||||
RUN dpkg -i cuda-keyring_1.0-1_all.deb
|
||||
|
||||
RUN apt-get install -y libxml2 cuda-minimal-build-12-1 libcusparse-dev-12-1 libcublas-dev-12-1 libcusolver-dev-12-1 git
|
||||
|
||||
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
|
||||
RUN wget -P /tmp \
|
||||
"https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
|
||||
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
|
||||
|
||||
2
LICENSE
2
LICENSE
@@ -187,7 +187,7 @@
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
Copyright 2024 AlQuraishi Laboratory
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
|
||||
548
README.md
548
README.md
@@ -1,557 +1,15 @@
|
||||

|
||||
_Figure: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
|
||||
|
||||
|
||||
# OpenFold
|
||||
|
||||
A faithful but trainable PyTorch reproduction of DeepMind's
|
||||
[AlphaFold 2](https://github.com/deepmind/alphafold).
|
||||
|
||||
## Contents
|
||||
# Documentation
|
||||
See our new home for docs at [openfold.readthedocs.io](https://openfold.readthedocs.io/en/latest/), with instructions for installation and model inference/training.
|
||||
|
||||
- [OpenFold](#openfold)
|
||||
- [Contents](#contents)
|
||||
- [Features](#features)
|
||||
- [Installation (Linux)](#installation-linux)
|
||||
- [Download Alignment Databases](#download-alignment-databases)
|
||||
- [Inference](#inference)
|
||||
- [Monomer inference](#monomer-inference)
|
||||
- [Multimer Inference](#multimer-inference)
|
||||
- [Soloseq Inference](#soloseq-inference)
|
||||
- [Training](#training)
|
||||
- [Testing](#testing)
|
||||
- [Building and Using the Docker Container](#building-and-using-the-docker-container)
|
||||
- [Copyright Notice](#copyright-notice)
|
||||
- [Contributing](#contributing)
|
||||
- [Citing this Work](#citing-this-work)
|
||||
|
||||
## Features
|
||||
|
||||
OpenFold carefully reproduces (almost) all of the features of the original open
|
||||
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
|
||||
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
|
||||
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
|
||||
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
|
||||
latter.
|
||||
|
||||
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
|
||||
and we've trained it from scratch, matching the performance of the original.
|
||||
We've publicly released model weights and our training data — some 400,000
|
||||
MSAs and PDB70 template hit files — under a permissive license. Model weights
|
||||
are available via scripts in this repository while the MSAs are hosted by the
|
||||
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
|
||||
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
|
||||
|
||||
OpenFold also supports inference using AlphaFold's official parameters, and
|
||||
vice versa (see `scripts/convert_of_weights_to_jax.py`).
|
||||
|
||||
OpenFold has the following advantages over the reference implementation:
|
||||
|
||||
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
|
||||
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
|
||||
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
|
||||
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
|
||||
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
|
||||
kernels support in-place attention during inference and training. They use
|
||||
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
|
||||
implementations, respectively.
|
||||
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
|
||||
- **FlashAttention** support greatly speeds up MSA attention.
|
||||
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and significantly reduces the model's peak device memory requirement by 13X. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference. To use this feature, simply set the `use_deepspeed_evo_attention` option in `openfold/config.py`.
|
||||
|
||||
## Installation (Linux)
|
||||
|
||||
All Python dependencies are specified in `environment.yml`. For producing sequence
|
||||
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
|
||||
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
|
||||
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
|
||||
Finally, some download scripts require `aria2c` and `aws`.
|
||||
|
||||
This package is currently supported for CUDA 11 and Pytorch 1.12
|
||||
|
||||
To install:
|
||||
1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
|
||||
1. From the `openfold` repo:
|
||||
- Create a [Mamba]("https://github.com/conda-forge/miniforge/releases/latest/download/) environment, e.g.
|
||||
`mamba env create -n openfold_env -f environment.yml`
|
||||
Mamba is recommended as the dependencies required by OpenFold are quite large and mamba can speed up the process.
|
||||
- Activate the environment, e.g `conda activate openfold_env`
|
||||
1. Run `scripts/install_third_party_dependencies.sh` to configure kernels and folding resources.
|
||||
|
||||
For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need this for each bash instance.
|
||||
|
||||
|
||||
## Download Alignment Databases
|
||||
|
||||
If you intend to generate your own alignments, e.g. for inference, you have two
|
||||
choices for downloading protein databases, depending on whether you want to use
|
||||
DeepMind's MSA generation pipeline (w/ HMMR & HHblits) or
|
||||
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
|
||||
MMseqs2 instead. For the former, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_alphafold_dbs.sh data/
|
||||
```
|
||||
|
||||
For the latter, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_mmseqs_dbs.sh data/ # downloads .tar files
|
||||
bash scripts/prep_mmseqs_dbs.sh data/ # unpacks and preps the databases
|
||||
```
|
||||
|
||||
Make sure to run the latter command on the machine that will be used for MSA
|
||||
generation (the script estimates how the precomputed database index used by
|
||||
MMseqs2 should be split according to the memory available on the system).
|
||||
|
||||
If you're using your own precomputed MSAs or MSAs from the RODA repository,
|
||||
there's no need to download these alignment databases. Simply make sure that
|
||||
the `alignment_dir` contains one directory per chain and that each of these
|
||||
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain. You
|
||||
can use `scripts/flatten_roda.sh` to reformat RODA downloads in this way.
|
||||
Note that the RODA alignments are NOT compatible with the recent .cif ground
|
||||
truth files downloaded by `scripts/download_alphafold_dbs.sh`. To fetch .cif
|
||||
files that match the RODA MSAs, once the alignments are flattened, use
|
||||
`scripts/download_roda_pdbs.sh`. That script outputs a list of alignment dirs
|
||||
for which matching .cif files could not be found. These should be removed from
|
||||
the alignment directory.
|
||||
|
||||
Alternatively, you can use raw MSAs from
|
||||
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
|
||||
that database, use `scripts/prep_proteinnet_msas.py` to convert the data
|
||||
into a format recognized by the OpenFold parser. The resulting directory
|
||||
becomes the `alignment_dir` used in subsequent steps. Use
|
||||
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
|
||||
files.
|
||||
|
||||
For both inference and training, the model's hyperparameters can be tuned from
|
||||
`openfold/config.py`. Of course, if you plan to perform inference using
|
||||
DeepMind's pretrained parameters, you will only be able to make changes that
|
||||
do not affect the shapes of model parameters. For an example of initializing
|
||||
the model, consult `run_pretrained_openfold.py`.
|
||||
|
||||
## Inference
|
||||
|
||||
OpenFold now supports three inference modes:
|
||||
- [Monomer Inference](#monomer-inference): OpenFold reproduction of AlphaFold2. Inference available with either DeepMind's pretrained parameters or OpenFold trained parameters.
|
||||
- [Multimer Inference](#multimer-inference): OpenFold reproduction of AlphaFold-Multimer. Inference available with DeepMind's pre-trained parameters.
|
||||
- [Single Sequence Inference (SoloSeq)](#soloseq-inference): Language Model based structure prediction, using [ESM-1b](https://github.com/facebookresearch/esm) embeddings.
|
||||
|
||||
More instructions for each inference mode are provided below:
|
||||
|
||||
### Monomer inference
|
||||
|
||||
To run inference on a sequence or multiple sequences using a set of DeepMind's
|
||||
pretrained parameters, first download the OpenFold weights e.g.:
|
||||
|
||||
```bash
|
||||
bash scripts/download_openfold_params.sh openfold/resources
|
||||
```
|
||||
|
||||
then run e.g.:
|
||||
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
--config_preset "model_1_ptm" \
|
||||
--model_device "cuda:0" \
|
||||
--output_dir ./ \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
|
||||
```
|
||||
|
||||
where `data` is the same directory as in the previous step. If `jackhmmer`,
|
||||
`hhblits`, `hhsearch` and `kalign` are available at the default path of
|
||||
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
|
||||
If you've already computed alignments for the query, you have the option to
|
||||
skip the expensive alignment computation here with
|
||||
`--use_precomputed_alignments`.
|
||||
|
||||
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
|
||||
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
|
||||
respectively. For a breakdown of the differences between the different parameter
|
||||
files, see the README downloaded to `openfold/resources/openfold_params/`. Since
|
||||
OpenFold was trained under a newer training schedule than the one from which the
|
||||
`model_n` config presets are derived, there is no clean correspondence between
|
||||
`config_preset` settings and OpenFold checkpoints; the only restraints are that
|
||||
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
|
||||
checkpoints are only compatible with template-less presets (`model_3` and above).
|
||||
|
||||
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
|
||||
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
|
||||
to `None` in the config. If a value is specified, OpenFold will attempt to
|
||||
dynamically tune it, considering the chunk size specified in the config as a
|
||||
minimum. This tuning process automatically ensures consistently fast runtimes
|
||||
regardless of input sequence length, but it also introduces some runtime
|
||||
variability, which may be undesirable for certain users. It is also recommended
|
||||
to disable this feature for very long chains (see below). To do so, set the
|
||||
`tune_chunk_size` option in the config to `False`.
|
||||
|
||||
For large-scale batch inference, we offer an optional tracing mode, which
|
||||
massively improves runtimes at the cost of a lengthy model compilation process.
|
||||
To enable it, add `--trace_model` to the inference command.
|
||||
|
||||
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
|
||||
in the config. Note that it appears to work best for sequences with < 1000 residues.
|
||||
|
||||
To minimize memory usage during inference on long sequences, consider the
|
||||
following changes:
|
||||
|
||||
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
|
||||
stack is a major memory bottleneck for inference on long sequences. OpenFold
|
||||
supports two mutually exclusive inference modes to address this issue. One,
|
||||
`average_templates` in the `template` section of the config, is similar to the
|
||||
solution offered by AlphaFold-Multimer, which is simply to average individual
|
||||
template representations. Our version is modified slightly to accommodate
|
||||
weights trained using the standard template algorithm. Using said weights, we
|
||||
notice no significant difference in performance between our averaged template
|
||||
embeddings and the standard ones. The second, `offload_templates`, temporarily
|
||||
offloads individual template embeddings into CPU memory. The former is an
|
||||
approximation while the latter is slightly slower; both are memory-efficient
|
||||
and allow the model to utilize arbitrarily many templates across sequence
|
||||
lengths. Both are disabled by default, and it is up to the user to determine
|
||||
which best suits their needs, if either.
|
||||
- Inference-time low-memory attention (LMA) can be enabled in the model config.
|
||||
This setting trades off speed for vastly improved memory usage. By default,
|
||||
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
|
||||
These represent a favorable tradeoff in most memory-constrained cases.
|
||||
Powerusers can choose to tweak these settings in
|
||||
`openfold/model/primitives.py`. For more information on the LMA algorithm,
|
||||
see the aforementioned Staats & Rabe preprint.
|
||||
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
|
||||
wastes time.
|
||||
- As a last resort, consider enabling `offload_inference`. This enables more
|
||||
extensive CPU offloading at various bottlenecks throughout the model.
|
||||
- Disable FlashAttention, which seems unstable on long sequences.
|
||||
|
||||
Using the most conservative settings, we were able to run inference on a
|
||||
4600-residue complex with a single A100. Compared to AlphaFold's own memory
|
||||
offloading mode, ours is considerably faster; the same complex takes the more
|
||||
efficent AlphaFold-Multimer more than double the time. Use the
|
||||
`long_sequence_inference` config option to enable all of these interventions
|
||||
at once. The `run_pretrained_openfold.py` script can enable this config option with the
|
||||
`--long_sequence_inference` command line option
|
||||
|
||||
Input FASTA files containing multiple sequences are treated as complexes. In
|
||||
this case, the inference script runs AlphaFold-Gap, a hack proposed
|
||||
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
|
||||
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
|
||||
|
||||
### Multimer Inference
|
||||
|
||||
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
|
||||
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
|
||||
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
|
||||
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
|
||||
--uniprot_database_path data/uniprot/uniprot.fasta \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
|
||||
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
--config_preset "model_1_multimer_v3" \
|
||||
--model_device "cuda:0" \
|
||||
--output_dir ./
|
||||
```
|
||||
|
||||
As with monomer inference, if you've already computed alignments for the query, you can use
|
||||
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
|
||||
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
|
||||
|
||||
**Upgrade from an existing OpenFold installation**
|
||||
|
||||
The above command requires several upgrades to existing openfold installations.
|
||||
|
||||
1. Re-download the alphafold parameters to get the latest
|
||||
AlphaFold-Multimer v3 weights:
|
||||
|
||||
```bash
|
||||
bash scripts/download_alphafold_params.sh openfold/resources
|
||||
```
|
||||
|
||||
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
|
||||
and [PDB SeqRes](https://www.rcsb.org/) databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniprot.sh data/
|
||||
```
|
||||
|
||||
The PDB SeqRes and PDB databases must be from the same date to avoid potential
|
||||
errors during template searching. Remove the existing `data/pdb_mmcif` directory
|
||||
and download both databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_pdb_mmcif.sh data/
|
||||
bash scripts/download_pdb_seqres.sh data/
|
||||
```
|
||||
|
||||
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
|
||||
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniref30.sh data/
|
||||
bash scripts/download_mgnify.sh data/
|
||||
```
|
||||
Multimer inference can also run with the older database versions if desired.
|
||||
|
||||
|
||||
### Soloseq Inference
|
||||
|
||||
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
|
||||
|
||||
For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:
|
||||
|
||||
```bash
|
||||
python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
|
||||
```
|
||||
|
||||
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.
|
||||
|
||||
Then download the SoloSeq model weights, e.g.:
|
||||
|
||||
|
||||
```bash
|
||||
bash scripts/download_openfold_soloseq_params.sh openfold/resources
|
||||
```
|
||||
|
||||
|
||||
Now, you are ready to run inference:
|
||||
```bash
|
||||
python run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--use_precomputed_alignments embeddings_output_dir \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
|
||||
```
|
||||
|
||||
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
```
|
||||
|
||||
For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.
|
||||
|
||||
SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.
|
||||
|
||||
**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
|
||||
|
||||
## Training
|
||||
|
||||
To train the model, you will first need to precompute protein alignments.
|
||||
|
||||
You have two options. You can use the same procedure DeepMind used by running
|
||||
the following:
|
||||
|
||||
```bash
|
||||
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--cpus_per_task 16 \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
|
||||
```
|
||||
|
||||
As noted before, you can skip the `binary_path` arguments if these binaries are
|
||||
at `/usr/bin`. Expect this step to take a very long time, even for small
|
||||
numbers of proteins.
|
||||
|
||||
Alternatively, you can generate MSAs with the ColabFold pipeline (and templates
|
||||
with HHsearch) with:
|
||||
|
||||
```bash
|
||||
python3 scripts/precompute_alignments_mmseqs.py input.fasta \
|
||||
data/mmseqs_dbs \
|
||||
uniref30_2103_db \
|
||||
alignment_dir \
|
||||
~/MMseqs2/build/bin/mmseqs \
|
||||
/usr/bin/hhsearch \
|
||||
--env_db colabfold_envdb_202108_db
|
||||
--pdb70 data/pdb70/pdb70
|
||||
```
|
||||
|
||||
where `input.fasta` is a FASTA file containing one or more query sequences. To
|
||||
generate an input FASTA from a directory of mmCIF and/or ProteinNet .core
|
||||
files, we provide `scripts/data_dir_to_fasta.py`.
|
||||
|
||||
Next, generate a cache of certain datapoints in the template mmCIF files:
|
||||
|
||||
```bash
|
||||
python3 scripts/generate_mmcif_cache.py \
|
||||
mmcif_dir/ \
|
||||
mmcif_cache.json \
|
||||
--no_workers 16
|
||||
```
|
||||
|
||||
This cache is used to pre-filter templates.
|
||||
|
||||
Next, generate a separate chain-level cache with data used for training-time
|
||||
data filtering:
|
||||
|
||||
```bash
|
||||
python3 scripts/generate_chain_data_cache.py \
|
||||
mmcif_dir/ \
|
||||
chain_data_cache.json \
|
||||
--cluster_file clusters-by-entity-40.txt \
|
||||
--no_workers 16
|
||||
```
|
||||
|
||||
where the `cluster_file` argument is a file of chain clusters, one cluster
|
||||
per line.
|
||||
|
||||
Optionally, download an AlphaFold-style validation set from
|
||||
[CAMEO](https://cameo3d.org) using `scripts/download_cameo.py`. Use the
|
||||
resulting FASTA files to generate validation alignments and then specify
|
||||
the validation set's location using the `--val_...` family of training script
|
||||
flags.
|
||||
|
||||
Finally, call the training script:
|
||||
|
||||
```bash
|
||||
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
|
||||
2021-10-10 \
|
||||
--template_release_dates_cache_path mmcif_cache.json \
|
||||
--precision bf16 \
|
||||
--gpus 8 --replace_sampler_ddp=True \
|
||||
--seed 4242022 \ # in multi-gpu settings, the seed must be specified
|
||||
--deepspeed_config_path deepspeed_config.json \
|
||||
--checkpoint_every_epoch \
|
||||
--resume_from_ckpt ckpt_dir/ \
|
||||
--train_chain_data_cache_path chain_data_cache.json \
|
||||
--obsolete_pdbs_file_path obsolete.dat
|
||||
```
|
||||
|
||||
where `--template_release_dates_cache_path` is a path to the mmCIF cache.
|
||||
Note that `template_mmcif_dir` can be the same as `mmcif_dir` which contains
|
||||
training targets. A suitable DeepSpeed configuration file can be generated with
|
||||
`scripts/build_deepspeed_config.py`. The training script is
|
||||
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
|
||||
and supports the full range of training options that entails, including
|
||||
multi-node distributed training, validation, and so on. For more information,
|
||||
consult PyTorch Lightning documentation and the `--help` flag of the training
|
||||
script.
|
||||
|
||||
Note that, despite its variable name, `mmcif_dir` can also contain PDB files
|
||||
or even ProteinNet .core files.
|
||||
|
||||
To emulate the AlphaFold training procedure, which uses a self-distillation set
|
||||
subject to special preprocessing steps, use the family of `--distillation` flags.
|
||||
|
||||
In cases where it may be burdensome to create separate files for each chain's
|
||||
alignments, alignment directories can be consolidated using the scripts in
|
||||
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
|
||||
consolidate an alignment directory into a pair of database and index files.
|
||||
Once all alignment directories (or shards of a single alignment directory)
|
||||
have been compiled, unify the indices with `unify_alignment_db_indices.py`. The
|
||||
resulting index, `super.index`, can be passed to the training script flags
|
||||
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
|
||||
flags instead represent the directory containing the compiled alignment
|
||||
databases. Both the training and distillation datasets can be compiled in this
|
||||
way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
|
||||
|
||||
## Testing
|
||||
|
||||
To run unit tests, use
|
||||
|
||||
```bash
|
||||
scripts/run_unit_tests.sh
|
||||
```
|
||||
|
||||
The script is a thin wrapper around Python's `unittest` suite, and recognizes
|
||||
`unittest` arguments. E.g., to run a specific test verbosely:
|
||||
|
||||
```bash
|
||||
scripts/run_unit_tests.sh -v tests.test_model
|
||||
```
|
||||
|
||||
Certain tests require that AlphaFold (v2.0.1) be installed in the same Python
|
||||
environment. These run components of AlphaFold and OpenFold side by side and
|
||||
ensure that output activations are adequately similar. For most modules, we
|
||||
target a maximum pointwise difference of `1e-4`.
|
||||
|
||||
## Building and Using the Docker Container
|
||||
|
||||
**Building the Docker Image**
|
||||
|
||||
Openfold can be built as a docker container using the included dockerfile. To build it, run the following command from the root of this repository:
|
||||
|
||||
```bash
|
||||
docker build -t openfold .
|
||||
```
|
||||
|
||||
**Running the Docker Container**
|
||||
|
||||
The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the Usage section above.
|
||||
|
||||
The docker container installs all conda components to the base conda environment in `/opt/conda`, and installs openfold itself in `/opt/openfold`,
|
||||
|
||||
Before running the docker container, you can verify that your docker installation is able to properly communicate with your GPU by running the following command:
|
||||
|
||||
|
||||
```bash
|
||||
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
|
||||
```
|
||||
|
||||
Note the `--gpus all` option passed to `docker run`. This option is necessary in order for the container to use the GPUs on the host machine.
|
||||
|
||||
To run the inference code under docker, you can use a command like the one below. In this example, parameters and sequences from the alphafold dataset are being used and are located at `/mnt/alphafold_database` on the host machine, and the input files are located in the current working directory. You can adjust the volume mount locations as needed to reflect the locations of your data.
|
||||
|
||||
```bash
|
||||
docker run \
|
||||
--gpus all \
|
||||
-v $PWD/:/data \
|
||||
-v /mnt/alphafold_database/:/database \
|
||||
-ti openfold:latest \
|
||||
python3 /opt/openfold/run_pretrained_openfold.py \
|
||||
/data/fasta_dir \
|
||||
/database/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path /database/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path /database/pdb70/pdb70 \
|
||||
--uniclust30_database_path /database/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--output_dir /data \
|
||||
--bfd_database_path /database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--model_device cuda:0 \
|
||||
--jackhmmer_binary_path /opt/conda/bin/jackhmmer \
|
||||
--hhblits_binary_path /opt/conda/bin/hhblits \
|
||||
--hhsearch_binary_path /opt/conda/bin/hhsearch \
|
||||
--kalign_binary_path /opt/conda/bin/kalign \
|
||||
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
|
||||
```
|
||||
Much of the content from this page may be found [here.](https://github.com/aqlaboratory/openfold/blob/main/docs/source/original_readme.md)
|
||||
|
||||
## Copyright Notice
|
||||
|
||||
|
||||
20
docs/Makefile
Normal file
20
docs/Makefile
Normal file
@@ -0,0 +1,20 @@
|
||||
# Minimal makefile for Sphinx documentation
|
||||
#
|
||||
|
||||
# You can set these variables from the command line, and also
|
||||
# from the environment for the first two.
|
||||
SPHINXOPTS ?=
|
||||
SPHINXBUILD ?= sphinx-build
|
||||
SOURCEDIR = source
|
||||
BUILDDIR = build
|
||||
|
||||
# Put it first so that "make" without argument is like "make help".
|
||||
help:
|
||||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
|
||||
.PHONY: help Makefile
|
||||
|
||||
# Catch-all target: route all unknown targets to Sphinx using the new
|
||||
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
||||
%: Makefile
|
||||
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
||||
7
docs/environment.yml
Normal file
7
docs/environment.yml
Normal file
@@ -0,0 +1,7 @@
|
||||
name: openfold-docs
|
||||
channels:
|
||||
- conda-forge
|
||||
dependencies:
|
||||
- sphinx=7
|
||||
- myst-parser=3
|
||||
- furo
|
||||
BIN
docs/imgs/of_banner.png
Executable file
BIN
docs/imgs/of_banner.png
Executable file
Binary file not shown.
|
After Width: | Height: | Size: 1.8 MiB |
203
docs/source/Aux_seq_files.md
Normal file
203
docs/source/Aux_seq_files.md
Normal file
@@ -0,0 +1,203 @@
|
||||
# Auxiliary Sequence Files for OpenFold Training
|
||||
|
||||
The training dataset of OpenFold is very large. The `pdb` directory alone contains 185,000 mmcifs; each chain for has multiple sequence alignment files and mmcif files.
|
||||
|
||||
OpenFold introduces a few new file structures for faster access to alignments and mmcif data.
|
||||
|
||||
This documentation will explain the benefits of having the condensed file structure, and explain the contents of each of the files.
|
||||
|
||||
## Default alignment file structure
|
||||
|
||||
One way to store mmcifs and alignments files would be to have a directory for each mmcif chain.
|
||||
|
||||
For example, consider two protein as a case study
|
||||
```
|
||||
- OpenProteinSet
|
||||
└── mmcifs
|
||||
├── 3lrm.cif
|
||||
└── 6kwc.cif
|
||||
...
|
||||
```
|
||||
|
||||
In the `alignments` directory, [PDB:6KWC](https://www.rcsb.org/structure/6KWC) is a monomer with one chain, and thus would have one alignment direcotry. [PDB:3LRM](https://www.rcsb.org/structure/3lrm), a homotetramer, would have one alignment directory for each of its four chains.
|
||||
```
|
||||
- OpenProteinSet
|
||||
└── alignments
|
||||
└── 3lrm_A
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── mgnify_hits.a3m
|
||||
├── pdb70_hits.hhr
|
||||
└── uniref90_hits.a3m
|
||||
└── 3lrm_B
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── mgnify_hits.a3m
|
||||
├── pdb70_hits.hhr
|
||||
└── uniref90_hits.a3m
|
||||
└── 3lrm_C
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── mgnify_hits.a3m
|
||||
├── pdb70_hits.hhr
|
||||
└── uniref90_hits.a3m
|
||||
└── 3lrm_D
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── mgnify_hits.a3m
|
||||
├── pdb70_hits.hhr
|
||||
└── uniref90_hits.a3m
|
||||
└── 6kwc_A
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── mgnify_hits.a3m
|
||||
├── pdb70_hits.hhr
|
||||
└── uniref90_hits.a3m
|
||||
...
|
||||
```
|
||||
|
||||
In practice, the IO overhead of having one directory per protein chain makes accessing the alignments slow.
|
||||
|
||||
## OpenFold DB file structure
|
||||
|
||||
Here we describe a new filesystem that can be used by OpenFold for more efficient access of alignment file and index file contents
|
||||
|
||||
All together, the file directory would look like:
|
||||
```
|
||||
- OpenProteinSet
|
||||
├── duplicate_pdb_chains.txt
|
||||
└── pdb
|
||||
├── mmcif_cache.json
|
||||
└── mmcifs
|
||||
├── 3lrm.cif
|
||||
└── 6kwc.cif
|
||||
└── alignment_db
|
||||
├── alignment_db_0.db
|
||||
├── alignment_db_1.db
|
||||
...
|
||||
├── alignment_db_9.db
|
||||
└── alignment_db.index
|
||||
```
|
||||
|
||||
We will describe each of the file types here.
|
||||
|
||||
### Alignments db files and index files
|
||||
|
||||
To speed up access of MSAs, OpenFold has an alternate alignments storage procedure. Instead of storing dedicated files for each single alignment, we consolidate large sets of alignments to single files referred to as _alignments_db's_. This can reduce I/O overhead and in practice we recommend using around 10 `alignments_db_x.db` files to store the total training set of OpenFold. During training, OpenFold can access each alignment using byte index pointers that are stored in a separate index file (`alignments_db.index`). The alignments for the `3LRM` and `6KWC` examples would be recorded in the index file as follows:
|
||||
|
||||
```alignments_db.index
|
||||
{
|
||||
...
|
||||
"3lrm_A": {
|
||||
"db": "alignment_db_0.db",
|
||||
"files": [
|
||||
["bfd_uniclust_hits.a3m", 212896478938, 1680200],
|
||||
["mgnify_hits.a3m", 212893696883, 2782055],
|
||||
["pdb70_hits.hhr", 212898159138, 614978],
|
||||
["uniref90_hits.a3m", 212898774116, 6165789]
|
||||
]
|
||||
},
|
||||
"6kwc_A": {
|
||||
"db": "alignment_db_1.db",
|
||||
"files": [
|
||||
["bfd_uniclust_hits.a3m", 415618723280, 380289],
|
||||
["mgnify_hits.a3m", 415618556077, 167203],
|
||||
["pdb70_hits.hhr", 415619103569, 148672],
|
||||
["uniref90_hits.a3m", 415617547852, 1008225]
|
||||
]
|
||||
}
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
For each entry, the corresponding `alignment_db` file and the byte start location and number of bytes to read the respective alignments are given. For example, the alignment information in `bfd_uniclust_hits.a3m` for chain `3lrm_A` can be found in the database file `alignment_db_0.db`, starting at byte location `212896478938` and reading in the next `1680200` bytes.
|
||||
|
||||
### Chain cache files and mmCIF cache files
|
||||
|
||||
Information from the mmcif files can be parsed in advance to create a `chain_cache.json` or a `mmcif_cache.json`. For OpenFold, the `chain_cache.json` is used to sample chains for training, and the `mmcif_cache.json` is used to prefilter templates.
|
||||
|
||||
Here's what the chain_cache.json entry looks like for our examples:
|
||||
|
||||
```chain_cache.json
|
||||
{
|
||||
...
|
||||
"3lrm_A": {
|
||||
"release_date": "2010-06-30",
|
||||
"seq": "MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"resolution": 2.7,
|
||||
"cluster_size": 6
|
||||
},
|
||||
"3lrm_B": {
|
||||
"release_date": "2010-06-30",
|
||||
"seq": "MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"resolution": 2.7,
|
||||
"cluster_size": 6
|
||||
},
|
||||
"3lrm_C": {
|
||||
"release_date": "2010-06-30",
|
||||
"seq": "MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"resolution": 2.7,
|
||||
"cluster_size": 6
|
||||
},
|
||||
"3lrm_D": {
|
||||
"release_date": "2010-06-30",
|
||||
"seq": "MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"resolution": 2.7,
|
||||
"cluster_size": 6
|
||||
},
|
||||
"6kwc_A": {
|
||||
"release_date": "2021-01-27",
|
||||
"seq": "GSTIQPGTGYNNGYFYSYWNDGHGGVTYTNGPGGQFSVNWSNSGEFVGGKGWQPGTKNKVINFSGSYNPNGNSYLSVYGWSRNPLIEYYIVENFGTYNPSTGATKLGEVTSDGSVYDIYRTQRVNQPSIIGTATFYQYWSVRRNHRSSGSVNTANHFNAWAQQGLTLGTMDYQIVAVQGYFSSGSASITVS",
|
||||
"resolution": 1.297,
|
||||
"cluster_size": 195
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
The mmcif_cache.json file would contain similar information, but condensed by mmcif id, e.g.
|
||||
|
||||
```mmcif_cache.json
|
||||
{
|
||||
"3lrm": {
|
||||
"release_date": "2010-06-30",
|
||||
"chain_ids": [
|
||||
"A",
|
||||
"B",
|
||||
"C",
|
||||
"D"
|
||||
],
|
||||
"seqs": [
|
||||
"MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK",
|
||||
"MFAFYFLTACISLKGVFGVSPSYNGLGLTPQMGWDNWNTFACDVSEQLLLDTADRISDLGLKDMGYKYIILDDCWSSGRDSDGFLVADEQKFPNGMGHVADHLHNNSFLFGMYSSAGEYTCAGYPGSLGREEEDAQFFANNRVDYLKYANCYNKGQFGTPEISYHRYKAMSDALNKTGRPVFYSLCNWGQDLTFYWGSGIANSWRMSGDVTAEFTRPDSRCPCDGDEYDCKYAGFHCSIMNILNKAAPMGQNAGVGGWNDLDNLEVGVGNLTDDEEKAHFSMWAMVKSPLIIGANVNNLKASSYSIYSQASVIAINQDSNGIPATRVWRYYVSDTDEYGQGEIQMWSGPLDNGDQVVALLNGGSVSRPMNTTLEEIFFDSNLGSKKLTSTWDIYDLWANRVDNSTASAILGRNKTATGILYNATEQSYKDGLSKNDTRLFGQKIGSLSPNAILNTTVPAHGIAFYRLRPSSDYKDDDDK"
|
||||
],
|
||||
"no_chains": 4,
|
||||
"resolution": 2.7
|
||||
},
|
||||
"6kwc": {
|
||||
"release_date": "2021-01-27",
|
||||
"chain_ids": [
|
||||
"A"
|
||||
],
|
||||
"seqs": [
|
||||
"GSTIQPGTGYNNGYFYSYWNDGHGGVTYTNGPGGQFSVNWSNSGEFVGGKGWQPGTKNKVINFSGSYNPNGNSYLSVYGWSRNPLIEYYIVENFGTYNPSTGATKLGEVTSDGSVYDIYRTQRVNQPSIIGTATFYQYWSVRRNHRSSGSVNTANHFNAWAQQGLTLGTMDYQIVAVQGYFSSGSASITVS"
|
||||
],
|
||||
"no_chains": 1,
|
||||
"resolution": 1.297
|
||||
},
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
### Duplicate pdb chain files
|
||||
|
||||
Duplicate chains occur across pdb entries. Some of these chains are the homomeric units of a multimer, others are subunits that are shared across different protein.
|
||||
|
||||
To reduce storage overhead of creating / storing identical data for duplicate entries, we have a duplicate chain file. Each line stores the all chains that are identical. Our `6kwc` and `3lrm` examples would be stored as follows.
|
||||
|
||||
```duplicate_pdb_chains.txt
|
||||
...
|
||||
6kwc_A
|
||||
3lrm_A 3lrm_B 3lrm_C 3lrm_D
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
34
docs/source/FAQ.md
Normal file
34
docs/source/FAQ.md
Normal file
@@ -0,0 +1,34 @@
|
||||
# FAQ
|
||||
|
||||
Frequently asked questions or encountered issues when running OpenFold.
|
||||
|
||||
## Setup
|
||||
|
||||
- When running unit tests (e.g. [`./scripts/run_unit_tests.sh`](https://github.com/aqlaboratory/openfold/blob/main/scripts/run_unit_tests.sh)), I see an error such as
|
||||
```
|
||||
ImportError: version GLIBCXX_3.4.30 not found
|
||||
```
|
||||
|
||||
> Solution: Make sure that the `$LD_LIBRARY_PATH` environment has been set to include the conda path, e.g. `export $LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH`
|
||||
|
||||
- I see a CUDA mismatch error, eg.
|
||||
```
|
||||
The detected CUDA version (11.8) mismatches the version that was used to compile
|
||||
PyTorch (12.1). Please make sure to use the same CUDA versions.
|
||||
```
|
||||
|
||||
> Solution: Ensure that your system's CUDA driver and toolkit match your intended OpenFold installation (CUDA 11 by default). You can check the CUDA driver version with a command such as `nvidia-smi`
|
||||
|
||||
- I get some error involving `fatal error: cuda_runtime.h: No such file or directory` and or `ninja: build stopped: subcommand failed.`.
|
||||
|
||||
> Solution: Something went wrong with setting up some of the custom kernels. Try running `install_third_party_dependencies.sh` again or try `python3 setup.py install` from inside the OpenFold folder. Make sure to prepend the conda environment as described above before running this.
|
||||
|
||||
## Training
|
||||
|
||||
- My model training is hanging on the data loading step:
|
||||
> Solution: While each system is different, a few general suggestions:
|
||||
- Check your `$KMP_AFFINITY` environment setting and see if it is suitable for your system.
|
||||
- Adjust the number of data workers used to prepare data with the `--num_workers` setting. Increasing the number could help with dataset processing speed. However, to many workers could cause an OOM issue.
|
||||
|
||||
- When I reload my pretrained model weights or checkpoints, I get `RuntimeError: Error(s) in loading state_dict for OpenFoldWrapper: Unexpected key(s) in state_dict:`
|
||||
> Solution: This suggests that your checkpoint / model weights are in OpenFold v1 format with outdated model layer names. Convert your weights/checkpoints following [this guide](convert_of_v1_weights.md).
|
||||
170
docs/source/Inference.md
Normal file
170
docs/source/Inference.md
Normal file
@@ -0,0 +1,170 @@
|
||||
# OpenFold Inference
|
||||
|
||||
In this guide, we will cover how to use OpenFold to make structure predictions.
|
||||
|
||||
## Background
|
||||
|
||||
We currently offer three modes of inference prediction:
|
||||
|
||||
- Monomer
|
||||
- Multimer
|
||||
- Single Sequence (Soloseq)
|
||||
|
||||
This guide will focus on monomer prediction, the next sections will describe [Multimer](Multimer_Inference.md) and [Single Sequence](Single_Sequence_Inference.md) prediction.
|
||||
`
|
||||
### Pre-requisites:
|
||||
|
||||
- OpenFold Conda Environment. See [OpenFold Installation](Installation.md) for instructions on how to build this environment.
|
||||
- Downloading sequence databases for performing multiple sequence alignments. We provide a script to download the AlphaFold databases [here](https://github.com/aqlaboratory/openfold/blob/main/scripts/download_alphafold_dbs.sh).
|
||||
|
||||
|
||||
## Running AlphaFold Model Inference
|
||||
|
||||
The script [`run_pretrained_openfold.py`](https://github.com/aqlaboratory/openfold/blob/main/run_pretrained_openfold.py) performs model inference. We will go through the steps of how to use this script.
|
||||
|
||||
An example directory for performing infernce on [PDB:6KWC](https://www.rcsb.org/structure/6KWC) is provided [here](https://github.com/aqlaboratory/openfold/tree/main/examples/monomer). We refer to this example directory for the below examples.
|
||||
|
||||
### Download Model Parameters
|
||||
|
||||
For monomer inference, you may either use the model parameters provided by Deepmind, or you may use the OpenFold trained parameters. Both models should give similar performance, please see [our main paper](https://www.biorxiv.org/content/10.1101/2022.11.20.517210v3) for further reference.
|
||||
|
||||
The model parameters provided by Deepmind can be downloaded with the following script located in this repository's `scripts/` directory:
|
||||
|
||||
```
|
||||
$ bash scripts/download_alphafold_params.sh $PARAMS_DIR
|
||||
```
|
||||
|
||||
To use the OpenFold trained parameters, you can use the following script
|
||||
|
||||
```
|
||||
$ bash scripts/download_openfold_params.sh $PARAMS_DIR
|
||||
```
|
||||
|
||||
We recommend selecting `openfold/resources` as the params directory as this is the default directory used by the `run_pretrained_openfold.py` to locate parameters.
|
||||
|
||||
If you choose to use a different directory, you may make a symlink to the `openfold/resources` directory, or specify an alternate parameter path with the command line argument `--jax_path` for AlphaFold parameters or `--openfold_checkpoint_path` for OpenFold parameters.
|
||||
|
||||
|
||||
### Model Inference
|
||||
|
||||
The input to [`run_pretrained_openfold.py`](https://github.com/aqlaboratory/openfold/blob/main/run_pretrained_openfold.py) is a directory of FASTA files. AlphaFold-style models also require a sequence alignment to perform inference.
|
||||
|
||||
If you do not have sequence alignments for your input sequences, you can compute them using the inference script directly by following the instructions for the following section [inference without pre-computed alignments](#model-inference-without-pre-computed-alignments).
|
||||
|
||||
Otherwise, if you already have alignments for your input FASTA sequences, skip ahead to the [inference with pre-computed alignments](#model-inference-with-pre-computed-alignments) section.
|
||||
|
||||
#### Model inference without pre-computed alignments
|
||||
The following command performs a sequence alignment against the OpenProteinSet databases and performs model inference.
|
||||
|
||||
```
|
||||
python3 run_pretrained_openfold.py \
|
||||
$INPUT_FASTA_DIR \
|
||||
$TEMPLATE_MMCIF_DIR
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--config_preset model_1_ptm \
|
||||
--uniref90_database_path $BASE_DATA_DIR/uniref90 \
|
||||
--mgnify_database_path $BASE_DATA_DIR/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path $BASE_DATA_DIR/pdb70 \
|
||||
--uniclust30_database_path $BASE_DATA_DIR/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--bfd_database_path $BASE_DATA_DIR/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--model_device "cuda:0"
|
||||
```
|
||||
|
||||
**Required arguments:**
|
||||
- `--output_dir`: specify the output directory
|
||||
- `$INPUT_FASTA_DIR`: Directory of query fasta files, one sequence per file,e.g. `examples/monomer/fasta_dir`
|
||||
- `$TEMPLATE_MMCIF_DIR`: MMCIF files to use for template matching. This directory is required even if using template free inference.
|
||||
- `*_database_path`: Paths to sequence databases for sequence alignment.
|
||||
- `--model_device`: Specify to use a GPU is one is available.
|
||||
|
||||
#### Model inference with pre-computed alignments
|
||||
To perform model inference with pre-computed alignments, use the following command
|
||||
|
||||
```
|
||||
python3 run_pretrained_openfold.py ${INPUT_FASTA_DIR} \
|
||||
$TEMPLATE_MMCIF_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--use_precomputed_alignments $PRECOMPUTED_ALIGNMENTS \
|
||||
--config_preset model_1_ptm \
|
||||
--model_device "cuda:0" \
|
||||
```
|
||||
|
||||
where `$PRECOMPUTED_ALIGNMENTS` is a directory that contains alignments. A sample alignments directory structure for a single query is:
|
||||
|
||||
```
|
||||
alignments
|
||||
└── 6KWC_1
|
||||
├── bfd_uniclust_hits.a3m
|
||||
├── hhsearch_output.hhr
|
||||
├── mgnify_hits.sto
|
||||
└── uniref90_hits.sto
|
||||
```
|
||||
|
||||
`bfd_uniclust_hits.a3m`, `mgnify_hits.sto`, and `uniref90_hits.sto` are all alignments of the query structure against the BFD, Mgnify, and Uniref90 datasets respsectively. `hhsearch_output.hhr` contains hits against the PDB70 database used for template matching. The example directory `examples/monomer/alignments` shows examples of expected directories.
|
||||
|
||||
|
||||
#### Configuration settings for template modeling / pTM scoring
|
||||
There are a few configuration settings available for template based and template-free modeling, and for the option to estimate a predicted template modeling score (pTM).
|
||||
|
||||
This table provides guidance on which setting to use for each set of predictions, as well as the parameters to select for each preset.
|
||||
|
||||
| Setting | `config_preset` | AlphaFold params (match config name) | OpenFold params (any are allowed) |
|
||||
| -------------------------: | ----------------------------------------: | :-------------------------------------------------------------------------------- | :--------------------------------- |
|
||||
| With template, no ptm | model_1<br>model_2 | `parms_model_1.npz`<br>`parms_model_2.npz` | `finetuning_[2-5].pt` |
|
||||
| With template, with ptm | model_1_ptm<br>model_2_ptm | `params_model_1_ptm.npz`<br>`params_model_2_ptm.npz` | `finetuning_ptm_[1-2].pt` |
|
||||
| Without template, no ptm | model_3<br>model_4<br>model_5 | `parms_model_3.npz`<br>`parms_model_4.npz`<br>`parms_model_5.npz` | `finetuning_no_templ_[1-2].pt` |
|
||||
| Without template, with ptm | model_3_ptm<br>model_4_ptm<br>model_5_ptm | `parms_model_3_ptm.npz`<br>`parms_model_4_ptm.npz`<br>`parms_model_5_ptm.npz`<br> | `finetuning_no_templ_ptm_1.pt` |
|
||||
|
||||
If you use AlphaFold parameters, and the AlphaFold parameters are located in the default parameter directory (e.g. `openfold/resources`) the parameters that match the `--config_preset` will be selected.
|
||||
|
||||
The full set of configurations available for all 5 AlphaFold model presets can be viewed in [`config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L105). The [OpenFold Parameters](OpenFold_Parameters.md) page contains more information about the individual OpenFold parameter files.
|
||||
|
||||
|
||||
#### Model outputs
|
||||
|
||||
The expected output contents are as follows:
|
||||
- `alignments`: Directory of alignments. One directory is made per query sequence, and each directory contains alignments against each of the databases used.
|
||||
- `predictions`: PDB files for predicted structures
|
||||
- `timings.json`: Json with timings for inference and relaxation, if specified
|
||||
|
||||
|
||||
### Optional Flags
|
||||
|
||||
Some commonly used command line flags are here. A full list of flags can be viewed from the `--help` menu
|
||||
|
||||
- `--config_preset`: Specify a different model configuration. There are 5 available model preset settings, some of which support template modeling, others support template-free modeling. The default is `model_1`. More details can be below in the [[Inference#Template-free modeling]] section
|
||||
- `--hmmsearch_binary_path`, `--hmmbuild_binary_path`, etc. Hmmer, HHsuite, kalign are required to run alignments. `run_pretrained_openfold.py` will search for these packages in the `bin/` directory of your conda environment. If needed, you can specify a different binary directory with these arguments.
|
||||
- `--openfold_checkpoint_path` : Uses an checkpoint or parameter file. Expected types are Deepspeed checkpoint files or `.pt` files. Make sure your selected checkpoint file matches the configuration setting chosen in `--config_preset`.
|
||||
- `--data_random_seed`: Specifies a random seed to use.
|
||||
- `--save_outputs`: Saves a copy of all outputs from the model, e.g. the output of the msa track, ptm heads.
|
||||
- `--experiment_config_json`: Specify configuration settings using a json file. For example, passing a json with `{globals.relax.max_iterations = 10}` specifies 10 as the maximum number of relaxation iterations. See for [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py#L283) the full dictionary of configuration settings. Any parameters that are not manually set in these configuration settings will refer to the defaults specified by your `config_preset`.
|
||||
|
||||
|
||||
### Advanced Options for Increasing Efficiency
|
||||
|
||||
#### Speeding up inference
|
||||
|
||||
The **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative.
|
||||
|
||||
If your system supports deepseed, using deepspeed generally leads an inference speedup of 2 - 3x without significant additional memory use. You may specify this option by selecting the `--use_deepspeed_inference` argument.
|
||||
|
||||
If DeepSpeed is unavailable for your system, you may also try using [FlashAttention](https://github.com/HazyResearch/flash-attention) by adding `globals.use_flash = True` to the `--experiment_config_json`. Note that FlashAttention appears to work best for sequences with < 1000 residues.
|
||||
|
||||
#### Large-scale batch inference
|
||||
For large-scale batch inference, we offer an optional tracing mode, which massively improves runtimes at the cost of a lengthy model compilation process. To enable it, add `--trace_model` to the inference command.
|
||||
|
||||
#### Configuring the chunk size for sequence alignments
|
||||
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement) is enabled by default in inference mode. To disable it, set `globals.chunk_size` to `None` in the config. If a value is specified, OpenFold will attempt to dynamically tune it, considering the chunk size specified in the config as a minimum. This tuning process automatically ensures consistently fast runtimes regardless of input sequence length, but it also introduces some runtime variability, which may be undesirable for certain users. It is also recommended to disable this feature for very long chains (see below). To do so, set the `tune_chunk_size` option in the config to `False`.
|
||||
|
||||
#### Long sequence inference
|
||||
To minimize memory usage during inference on long sequences, consider the following changes:
|
||||
|
||||
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template stack is a major memory bottleneck for inference on long sequences. OpenFold supports two mutually exclusive inference modes to address this issue. One, `average_templates` in the `template` section of the config, is similar to the solution offered by AlphaFold-Multimer, which is simply to average individual template representations. Our version is modified slightly to accommodate weights trained using the standard template algorithm. Using said weights, we notice no significant difference in performance between our averaged template embeddings and the standard ones. The second, `offload_templates`, temporarily offloads individual template embeddings into CPU memory. The former is an approximation while the latter is slightly slower; both are memory-efficient and allow the model to utilize arbitrarily many templates across sequence lengths. Both are disabled by default, and it is up to the user to determine which best suits their needs, if either.
|
||||
- Inference-time low-memory attention (LMA) can be enabled in the model config. This setting trades off speed for vastly improved memory usage. By default, LMA is run with query and key chunk sizes of 1024 and 4096, respectively. These represent a favorable tradeoff in most memory-constrained cases. Powerusers can choose to tweak these settings in `openfold/model/primitives.py`. For more information on the LMA algorithm, see the aforementioned Staats & Rabe preprint.
|
||||
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only wastes time.
|
||||
- As a last resort, consider enabling `offload_inference`. This enables more extensive CPU offloading at various bottlenecks throughout the model.
|
||||
- Disable FlashAttention, which seems unstable on long sequences.
|
||||
|
||||
Using the most conservative settings, we were able to run inference on a 4600-residue complex with a single A100. Compared to AlphaFold's own memory offloading mode, ours is considerably faster; the same complex takes the more efficent AlphaFold-Multimer more than double the time. Use the `long_sequence_inference` config option to enable all of these interventions at once. The `run_pretrained_openfold.py` script can enable this config option with the `--long_sequence_inference` command line option
|
||||
|
||||
Input FASTA files containing multiple sequences are treated as complexes. In this case, the inference script runs AlphaFold-Gap, a hack proposed [here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
|
||||
79
docs/source/Multimer_Inference.md
Normal file
79
docs/source/Multimer_Inference.md
Normal file
@@ -0,0 +1,79 @@
|
||||
# Multimer Inference
|
||||
|
||||
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, you will need:
|
||||
|
||||
- AlphaFold Multimer v2.3 parameters
|
||||
- Updated sequence databases, with UniRef and PDB Seqres databases.
|
||||
|
||||
|
||||
## Upgrade from a previous OpenFold Installation
|
||||
|
||||
If you had previously downloaded OpenFold parameters and or AlphaFold databases, you will need to download updated versions. Here are some instructions for upgrading from an existing openfold installations.
|
||||
|
||||
### Download AlphaFold-Multimer v2.3 Model Parameters
|
||||
1. Re-download the alphafold parameters to get the latest
|
||||
AlphaFold-Multimer v2.3 weights:
|
||||
|
||||
```bash
|
||||
bash scripts/download_alphafold_params.sh openfold/resources
|
||||
```
|
||||
|
||||
### Download AlphaFold Databases for Multimer
|
||||
|
||||
1. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
|
||||
and [PDB SeqRes](https://www.rcsb.org/) databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniprot.sh data/
|
||||
```
|
||||
|
||||
The PDB SeqRes and PDB databases must be from the same date to avoid potential
|
||||
errors during template searching. Remove the existing `data/pdb_mmcif` directory
|
||||
and download both databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_pdb_mmcif.sh data/
|
||||
bash scripts/download_pdb_seqres.sh data/
|
||||
```
|
||||
|
||||
1. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
|
||||
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniref30.sh data/
|
||||
bash scripts/download_mgnify.sh data/
|
||||
```
|
||||
|
||||
```{note}
|
||||
Multimer inference can also run with the older database versions if desired.
|
||||
```
|
||||
|
||||
## Running Multimer Inference
|
||||
|
||||
The [`run_pretrained_openfold.py`](https://github.com/aqlaboratory/openfold/blob/main/run_pretrained_openfold.py) script can be used to run multimer inference with the follwoing command.
|
||||
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
|
||||
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
|
||||
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
|
||||
--uniprot_database_path data/uniprot/uniprot.fasta \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
|
||||
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
--config_preset "model_1_multimer_v3" \
|
||||
--model_device "cuda:0" \
|
||||
--output_dir ./
|
||||
```
|
||||
|
||||
Note that template searching in the multimer pipeline
|
||||
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
|
||||
|
||||
As with monomer inference, if you've already computed alignments for the query, you can use
|
||||
the `--use_precomputed_alignments` option.
|
||||
55
docs/source/OpenFold_Parameters.md
Normal file
55
docs/source/OpenFold_Parameters.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Notes on OpenFold Training and Parameters
|
||||
|
||||
For OpenFold model parameters, v. 06_22.
|
||||
|
||||
## Training details
|
||||
|
||||
OpenFold was trained using OpenFold on 44 A100s using the training schedule from Table 4 in
|
||||
the AlphaFold supplement. AlphaFold was used as the pre-distillation model.
|
||||
Training data is hosted publicly in the "OpenFold Training Data" RODA repository.
|
||||
|
||||
To improve model diversity, we forked training after the initial training phase
|
||||
and finetuned an additonal branch without templates.
|
||||
|
||||
## Parameter files
|
||||
|
||||
Parameter files fall into the following categories:
|
||||
|
||||
initial_training.pt:
|
||||
OpenFold at the end of the initial training phase.
|
||||
finetuning_x.pt:
|
||||
Checkpoints in chronological order corresponding to peaks in the
|
||||
validation LDDT-Ca during the finetuning phase. Roughly evenly spaced
|
||||
across the 45 finetuning epochs.
|
||||
|
||||
NOTE: finetuning_1.pt, which was included in a previous release, has
|
||||
been deprecated.
|
||||
finetuning_no_templ_x.pt
|
||||
Checkpoints in chronological order corresponding to peaks during an
|
||||
additional finetuning phase also starting from the 'initial_training.pt'
|
||||
checkpoint but with templates disabled.
|
||||
finetuning_no_templ_ptm_x.pt
|
||||
Checkpoints in chronological order corresponding to peaks during the
|
||||
pTM training phase of the `no_templ` branch. Models in this category
|
||||
include the pTM module and comprise the most recent of the checkpoints
|
||||
in said branch.
|
||||
finetuning_ptm_x.pt:
|
||||
Checkpoints in chronological order corresponding to peaks in the pTM
|
||||
training phase of the mainline branch. Models in this category include
|
||||
the pTM module and comprise the most recent of the checkpoints in said
|
||||
branch.
|
||||
|
||||
Average validation LDDT-Ca scores for each of the checkpoints are listed below.
|
||||
The validation set contains approximately 180 chains drawn from CAMEO over a
|
||||
three-month period at the end of 2021.
|
||||
|
||||
initial_training: 0.9088
|
||||
finetuning_2: 0.9061
|
||||
finetuning_3: 0.9075
|
||||
finetuning_4: 0.9059
|
||||
finetuning_5: 0.9054
|
||||
finetuning_no_templ_1: 0.9014
|
||||
finetuning_no_templ_2: 0.9032
|
||||
finetuning_no_templ_ptm_1: 0.9025
|
||||
finetuning_ptm_1: 0.9075
|
||||
finetuning_ptm_2: 0.9097
|
||||
131
docs/source/OpenFold_Training_Setup.md
Normal file
131
docs/source/OpenFold_Training_Setup.md
Normal file
@@ -0,0 +1,131 @@
|
||||
# Setting up the OpenFold PDB training set from RODA
|
||||
|
||||
The multiple sequence alignments of OpenProteinSet and mmCIF structure files required to train OpenFold are freely available at the [Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold/). Additionally, OpenFold requires some postprocessing and [auxiliary files](Aux_seq_files.md) for training that need to be generated from the AWS data manually. This documentation is intended to give a full overview of those steps starting from the data download.
|
||||
|
||||
### Pre-Requisites:
|
||||
- OpenFold conda environment. See [OpenFold Installation](Installation.md) for instructions on how to build this environment.
|
||||
- In particular, the [AWS CLI](https://aws.amazon.com/cli/) is used to download data from RODA.
|
||||
- For this guide, we assume that the OpenFold codebase is located at `$OF_DIR`.
|
||||
|
||||
## 1. Downloading alignments and structure files
|
||||
To fetch all the alignments corresponding to the original PDB training set of OpenFold alongside their mmCIF 3D structures, you can run the following commands:
|
||||
|
||||
```bash
|
||||
mkdir -p alignment_data/alignment_dir_roda
|
||||
aws s3 cp s3://openfold/pdb/ alignment_data/alignment_dir_roda/ --recursive --no-sign-request
|
||||
|
||||
mkdir pdb_data
|
||||
aws s3 cp s3://openfold/pdb_mmcif.zip pdb_data/ --no-sign-request
|
||||
aws s3 cp s3://openfold/duplicate_pdb_chains.txt . --no-sign-request
|
||||
unzip pdb_mmcif.zip -d pdb_data
|
||||
```
|
||||
|
||||
The nested alignment directory structure is not yet exactly what OpenFold expects, so you can run the `flatten_roda.sh` script to convert them to the correct format:
|
||||
|
||||
```bash
|
||||
bash $OF_DIR/scripts/flatten_roda.sh alignment_data/alignment_dir_roda alignment_data/
|
||||
```
|
||||
|
||||
Afterwards, the old directory can be safely removed:
|
||||
|
||||
```bash
|
||||
rm -r alignment_data/alignment_dir_roda
|
||||
```
|
||||
|
||||
## 2. Creating alignment DBs (optional)
|
||||
As further explained in [Auxiliary Sequence Files in OpenFold](Aux_seq_files.md), OpenFold supports an alternate format for storing alignments that can increase training performance in I/O bottlenecked systems. These so-called `alignment_db` files can be generated with the following script:
|
||||
|
||||
```bash
|
||||
python $OF_DIR/scripts/alignment_db_scripts/create_alignment_db_sharded.py \
|
||||
alignment_data/alignments \
|
||||
alignment_data/alignment_dbs \
|
||||
alignment_db \
|
||||
--n_shards 10 \
|
||||
--duplicate_chains_file pdb_data/duplicate_pdb_chains.txt
|
||||
```
|
||||
|
||||
We recommend creating 10 total `alignment_db` files (= "shards") for better
|
||||
filesystem health and fast preprocessing, but note that this script will only run
|
||||
optimally if the number of CPUs on your machine is at least as big as the number
|
||||
of shards you are creating.
|
||||
|
||||
As an optional check, you can run the following command which should return $634,434$:
|
||||
|
||||
```bash
|
||||
grep "files" alignment_data/alignment_dbs/alignment_db.index | wc -l
|
||||
```
|
||||
|
||||
## 3. Adding duplicate chains to alignments (skip if step 2 was used)
|
||||
To save space, the OpenProteinSet alignment database is stored without duplicates, meaning that only one representative alignment is stored for all chains with identical sequences in the PDB and duplicate instances are tracked with a [`duplicate_chains.txt`](Aux_seq_files.md#duplicate-pdb-chain-files) file. As OpenFold will select chains during training based on the chains in the alignment directory (or `alignment_db`), we therefore need to add those duplicate chains back in in order to train on the full conformational diversity of chains in the PDB.
|
||||
|
||||
If you've followed the optional Step 2, the `.index` file of your `alignment_db` files will have already been adjusted for duplicates and you can proceed to the next step. Otherwise, the standard alignment directory can be expanded to accommodate duplicates by inserting symlinked directories for the duplicate chains that point to their representative alignments:
|
||||
|
||||
```bash
|
||||
python $OF_DIR/scripts/expand_alignment_duplicates.py \
|
||||
alignment_data/alignments \
|
||||
pdb_data/duplicate_pdb_chains.txt
|
||||
```
|
||||
|
||||
As an optional check, the following command should return $634,434$:
|
||||
|
||||
```bash
|
||||
ls alignment_data/alignments/ | wc -l
|
||||
```
|
||||
|
||||
## 4. Generating cluster-files
|
||||
The AlphaFold dataloader adjusts the sampling probability of chains by their inverse cluster size, so we need to generate these sequence clusters for our training set.
|
||||
|
||||
As a first step, we'll need a `.fasta` file of all sequences in the training set. This can be generated with the following scripts, depending on how you set up your alignment data in the previous steps:
|
||||
|
||||
**Use this if you set up the duplicate-expanded alignment directory (faster):**
|
||||
```bash
|
||||
python $OF_DIR/scripts/alignment_data_to_fasta.py \
|
||||
alignment_data/all-seqs.fasta \
|
||||
--alignment_dir alignment_data/alignments
|
||||
```
|
||||
|
||||
**Use this if you set up the `alignment_db` files:**
|
||||
```bash
|
||||
python $OF_DIR/scripts/alignment_data_to_fasta.py \
|
||||
alignment_data/all-seqs.fasta \
|
||||
--alignment_db_index alignment_data/alignment_dbs/alignment_db.index
|
||||
```
|
||||
|
||||
Next, we need to generate a cluster file at 40% sequence identity, which will contain all chains in a particular cluster on the same line. You'll need [MMSeqs2](https://github.com/soedinglab/MMseqs2?tab=readme-ov-file#installation) for this as well, which can be set up either in a conda environment or as a binary.
|
||||
|
||||
```bash
|
||||
python $OF_DIR/scripts/fasta_to_clusterfile.py \
|
||||
alignment_data/all-seqs.fasta \
|
||||
alignment_data/all-seqs_clusters-40.txt \
|
||||
/path/to/mmseqs \
|
||||
--seq-id 0.4
|
||||
```
|
||||
|
||||
## 5. Generating cluster-files
|
||||
As a last step, OpenFold requires ["cache" files](Aux_seq_files.md#chain-cache-files-and-mmcif-cache-files) with metadata information for each chain that are used for choosing templates and samples during training.
|
||||
|
||||
The data caches for OpenProteinSet can be downloaded from RODA with the following:
|
||||
|
||||
```bash
|
||||
aws s3 cp s3://openfold/data_caches/ pdb_data/ --recursive --no-sign-request
|
||||
```
|
||||
If you wish to create data caches for your own datasets, the steps to generate the cache are as follows:
|
||||
|
||||
```bash
|
||||
mkdir pdb_data/data_caches
|
||||
|
||||
python $OF_DIR/scripts/generate_mmcif_cache.py \
|
||||
pdb_data/mmcif_files \
|
||||
pdb_data/data_caches/mmcif_cache.json \
|
||||
--no_workers 16
|
||||
```
|
||||
|
||||
The chain-data-cache is used for filtering training samples and adjusting per-chain sampling probabilities and can be generated with the following script:
|
||||
|
||||
```bash
|
||||
python $OF_DIR/scripts/generate_chain_data_cache.py \
|
||||
pdb_data/mmcif_files \
|
||||
pdb_data/data_caches/chain_data_cache.json \
|
||||
--cluster_file alignment_data/all-seqs_clusters-40.txt \
|
||||
--no_workers 16
|
||||
```
|
||||
57
docs/source/Single_Sequence_Inference.md
Normal file
57
docs/source/Single_Sequence_Inference.md
Normal file
@@ -0,0 +1,57 @@
|
||||
### Soloseq inference
|
||||
|
||||
MSA-free sequence to structure prediction using the [ESM-1b model](https://github.com/facebookresearch/esm) embeddings.
|
||||
|
||||
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
|
||||
|
||||
For generating ESM-1b embeddings in bulk, use the provided script: [`scripts/precompute_embeddings.py`](https://github.com/aqlaboratory/openfold/blob/main/scripts/precompute_embeddings.py). The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:
|
||||
|
||||
```shell
|
||||
python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
|
||||
```
|
||||
|
||||
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.
|
||||
|
||||
Then download the SoloSeq model weights, e.g.:
|
||||
|
||||
```shell
|
||||
bash scripts/download_openfold_soloseq_params.sh openfold/resources
|
||||
```
|
||||
|
||||
Now, you are ready to run inference:
|
||||
|
||||
```shell
|
||||
python run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--use_precomputed_alignments embeddings_output_dir \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
|
||||
```
|
||||
|
||||
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.
|
||||
|
||||
```shell
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
```
|
||||
|
||||
For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.
|
||||
|
||||
SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.
|
||||
|
||||
```{note}
|
||||
Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
|
||||
```
|
||||
170
docs/source/Training_OpenFold.md
Normal file
170
docs/source/Training_OpenFold.md
Normal file
@@ -0,0 +1,170 @@
|
||||
# Training OpenFold
|
||||
## Background
|
||||
|
||||
This guide covers how to train an OpenFold model for monomers. Some additional instructions are provided at the end for fine-tuning your model.
|
||||
|
||||
### Pre-requisites:
|
||||
|
||||
This guide requires the following:
|
||||
- [Installation of OpenFold and dependencies](Installation.md) (Including jackhmmer and hhblits depedencies)
|
||||
- A preprocessed dataset:
|
||||
- For this guide, we will use the original OpenFold dataset which is available on RODA, processed with [these instructions](OpenFold_Training_Setup.md).
|
||||
- GPUs configured with CUDA. Training OpenFold with CPUs only is not supported.
|
||||
|
||||
## Training a new OpenFold model
|
||||
|
||||
#### Basic command
|
||||
|
||||
For a dataset that has the default alignment file structure, e.g.
|
||||
|
||||
```
|
||||
-$DATA_DIR
|
||||
└── pdb_data
|
||||
├── mmcifs
|
||||
├── 3lrm.cif
|
||||
└── 6kwc.cif
|
||||
...
|
||||
├── obsolete.dat
|
||||
├── duplicate_pdb_chains.txt
|
||||
└── data_caches
|
||||
├── duplicate_pdb_chains.txt
|
||||
└── data_caches
|
||||
└── alignment_data
|
||||
└── alignments
|
||||
├── 3lrm_A/
|
||||
├── 3lrm_B/
|
||||
└── 6kwc_A/
|
||||
...
|
||||
```
|
||||
|
||||
The basic command to train a new OpenFold model is:
|
||||
|
||||
```
|
||||
python3 train_openfold.py $DATA_DIR/pdb/mmcifs $DATA_DIR/alignment_data/alignments $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
|
||||
--max_template_date 2021-10-10 \
|
||||
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
|
||||
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
|
||||
--config_preset initial_training \
|
||||
--seed 42 \
|
||||
--obsolete_pdbs_file_path $DATA_DIR/pdb_data/obsolete.dat \
|
||||
--num_nodes 1 \
|
||||
--gpus 4 \
|
||||
--num_workers 4
|
||||
```
|
||||
|
||||
The required arguments are:
|
||||
- `mmcif_dir` : Mmcif files for the training set.
|
||||
- `alignments_dir`: Alignments for the sequences in `mmcif_dir`, see expected directory structure
|
||||
- `template_mmcif_dir`: Template mmcif files with structures, which can be the same directory as mmcif_dir. The `max_template_date` and `template_release_dates_cache_path` will specify which templates will be allowed based on a date cutoff
|
||||
- `output_dir` : Where model checkpoint files and other outputs will be saved.
|
||||
|
||||
Commonly used flags include:
|
||||
- `config_preset`: Specifies which selection of hyperparameters should be used for initial model training. Commonly used configs are defined in [`openfold/config.py`](https://github.com/aqlaboratory/openfold)
|
||||
- `num_nodes` and `gpus`: Specifies number of nodes and GPUs available to train OpenFold.
|
||||
- `seed` - Specifies random seed
|
||||
- `num_workers`: Number of CPU workers to assign for creating dataset examples
|
||||
- `obsolete_pdbs_file_path`: Specifies obsolete pdb IDs that should be excluded from training.
|
||||
- `val_data_dir` and `val_alignment_dir`: Specifies data directory and alignments for validation dataset.
|
||||
|
||||
```{note}
|
||||
Note that `--seed` must be specified to correctly configure training examples on multi-GPU training runs
|
||||
```
|
||||
|
||||
#### Train with OpenFold Dataset Configuration
|
||||
|
||||
If the [OpenFold alignment database](OpenFold_Training_Setup.md#2-creating-alignment-dbs-optional) setup is used, resulting in a data directory such as:
|
||||
```
|
||||
- $DATA_DIR
|
||||
├── duplicate_pdb_chains.txt
|
||||
├── pdb_data
|
||||
└── mmcifs
|
||||
├── 3lrm.cif
|
||||
└── 6kwc.cif
|
||||
└── alignment_data
|
||||
└── alignment_db
|
||||
├── alignment_db_0.db
|
||||
├── alignment_db_1.db
|
||||
...
|
||||
├── alignment_db_9.db
|
||||
└── alignment_db.index
|
||||
```
|
||||
|
||||
The training command will use the `alignment_index_path` argument to specify `db.index` files, e.g.:
|
||||
|
||||
```
|
||||
python3 train_openfold.py $DATA_DIR/pdb_data/mmcifs $DATA_DIR/alignment_data/alignment_db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
|
||||
--max_template_date 2021-10-10 \
|
||||
--train_chain_data_cache_path $DATA_DIR/pdb_data/data_caches/chain_data_cache.json \
|
||||
--template_release_dates_cache_path $DATA_DIR/pdb_data/data_caches/mmcif_cache.json \
|
||||
--alignment_index_path $DATA_DIR/pdb/alignment_db.index
|
||||
--config_preset initial_training \
|
||||
--seed 42 \
|
||||
--obsolete_pdbs_file_path $DATA_DIR/pdb/obsolete.dat \
|
||||
--num_nodes 1 \
|
||||
--gpus 4 \
|
||||
--num_workers 4
|
||||
```
|
||||
|
||||
#### Additional command line flag options:
|
||||
|
||||
Here we provide brief descriptions for customizing your training run of OpenFold. A full description of all flags can be accessed by using the `--help` option in the script
|
||||
|
||||
- **Use Deepspeed acceleration strategy:** `--deepspeed_config` This option configures OpenFold to use custom Deepspeed kernels. This option requires a deepspeed_config.json, you can create your own, or use the one in the OpenFold directory
|
||||
|
||||
- **Use a validation dataset:** Specify validation database paths with `--val_data_dir` + `--val_alignment_dir`. Validation metrics will be evaluated on these datasets.
|
||||
|
||||
- **Use a self-distillation dataset:** Specify paths with `--distillation_data_dir` and `--distillation_alignment_dir` flags
|
||||
|
||||
- **Change specific parameters in the model or data setup:** `--experiment_config_json`. These parameters must be defined in the [`openfold/config.py`](https://github.com/aqlaboratory/openfold/blob/main/openfold/config.py). For example to change the crop size for training a model, you can write the following json:
|
||||
```cropsize.json
|
||||
{
|
||||
"data.train.crop_size": 128
|
||||
}
|
||||
```
|
||||
|
||||
- **Configure training settings with PyTorch Lightning**
|
||||
|
||||
Some flags e.g. `--precision`, `--max_epochs` configure training behavior. See the Pytorch Lightning Trainer args section in the `--help` menu for more information and consult [Pytorch lightning documentation](https://lightning.ai/docs/pytorch/stable/)
|
||||
|
||||
- Precision: On A100s, OpenFold training works best with bfloat 16 precision (e.g. `--precision bf16-mixed`)
|
||||
|
||||
- **Restart training from an existing checkpoint:** Use the `--resume_from_ckpt` to restart training from an existing checkpoint.
|
||||
|
||||
## Advanced Training Configurations
|
||||
:::
|
||||
|
||||
### Fine tuning from existing model weights
|
||||
|
||||
If you have existing model weights, you can fine tune the model by specifying a checkpoint path with `--resume_from_ckpt` and `--resume_model_weights_only` arguments, e.g.
|
||||
|
||||
```
|
||||
python3 train_openfold.py $DATA_DIR/mmcifs $DATA_DIR/alignment.db $TEMPLATE_MMCIF_DIR $OUTPUT_DIR \
|
||||
--max_template_date 2021-10-10 \
|
||||
--train_chain_data_cache_path chain_data_cache.json \
|
||||
--template_release_dates_cache_path mmcif_cache.json \
|
||||
--config_preset finetuning \
|
||||
--alignment_index_path $DATA_DIR/pdb/alignment_db.index \
|
||||
--seed 4242022 \
|
||||
--obsolete_pdbs_file_path obsolete.dat \
|
||||
--num_nodes 1 \
|
||||
--gpus 4 \
|
||||
--num_workers 4 \
|
||||
--resume_from_ckpt $CHECKPOINT_PATH \
|
||||
--resume_model_weights_only
|
||||
```
|
||||
|
||||
If you have model parameters from OpenFold v1.x, you may need to convert your checkpoint file or parameter. See [Converting OpenFold v1 Weights](convert_of_v1_weights.md) for more details.
|
||||
|
||||
### Using MPI
|
||||
|
||||
If MPI is configured on your system, and you would like to use MPI to train OpenFold models, you may do so with the following step:
|
||||
|
||||
1. Add the `mpi4py` package, which are available through pip and conda. Please see [mpi4py documentation](https://pypi.org/project/mpi4py/) for more instructions on installation.
|
||||
2. Add the `--mpi_plugin` flag to your training command.
|
||||
|
||||
|
||||
### Training Multimer models
|
||||
|
||||
```{note}
|
||||
Coming soon.
|
||||
```
|
||||
31
docs/source/conf.py
Normal file
31
docs/source/conf.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Configuration file for the Sphinx documentation builder.
|
||||
#
|
||||
# For the full list of built-in configuration values, see the documentation:
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html
|
||||
|
||||
# -- Project information -----------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information
|
||||
|
||||
project = 'OpenFold'
|
||||
copyright = '2024, OpenFold Team'
|
||||
author = 'OpenFold Team'
|
||||
|
||||
# -- General configuration ---------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
|
||||
|
||||
extensions = [
|
||||
'myst_parser',
|
||||
]
|
||||
|
||||
templates_path = ['_templates']
|
||||
exclude_patterns = []
|
||||
|
||||
|
||||
|
||||
# -- Options for HTML output -------------------------------------------------
|
||||
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
|
||||
|
||||
html_theme = 'furo'
|
||||
html_static_path = ['_static']
|
||||
myst_enable_extensions = ["colon_fence", "dollarmath", "amsmath"]
|
||||
|
||||
38
docs/source/convert_of_v1_weights.md
Normal file
38
docs/source/convert_of_v1_weights.md
Normal file
@@ -0,0 +1,38 @@
|
||||
## Weights Renaming
|
||||
|
||||
As part of the [OpenFold v2 update](https://github.com/aqlaboratory/openfold/releases/tag/v2.0.0) with the integration of multimer prediction, certain model layers of the AlphaFold model were renamed. For example.
|
||||
|
||||
`module.model.template_angle_embedder.*` is now referred to as
|
||||
`module.model.template_embedder.template_single_embedder.*`
|
||||
|
||||
If you have some checkpoints that were trained using OpenFold v1 or older, and now want to resume training on OpenFold v2, you may need to convert your checkpoints.
|
||||
|
||||
## FAQ
|
||||
|
||||
### Do I need to convert my checkpoints / model weights?
|
||||
|
||||
If you want to run inference or resume training from a checkpoint that was trained with OpenFold V1, you will need to convert your checkpoint.
|
||||
|
||||
If you want load model weights only, without starting from a specific time step, then you should not need to convert your checkpoints. The training of the model will begin from `step=0` in this case. To do so, you'll need both the `--resume_from_ckpt` and `--resume_model_weights_only` flags. This example allows you train starting from the pre-trained openfold weights:
|
||||
|
||||
```bash
|
||||
$ python3 $OPENFOLD_DIR/train_openfold.py test_data_epoch/mmcifs test_data_epoch/alignments test_data_epoch/template_mmcifs $OUTPUT_DIR 2021-09-30 \
|
||||
...
|
||||
--resume_from_ckpt openfold/resources/openfold_params/finetuning_2.pt \
|
||||
--resume_model_weights_only
|
||||
|
||||
```
|
||||
|
||||
### How do I convert my checkpoints?
|
||||
|
||||
Use [`scripts/convert_v1_to_v2_weights.py`](https://github.com/aqlaboratory/openfold/blob/main/scripts/convert_v1_to_v2_weights.py) e.g.
|
||||
|
||||
`python scripts/convert_v1_to_v2_weights.py checkpoints/6-209.ckpt checkpoints/6-209.ckpt.converted`
|
||||
|
||||
Then, to resume training, set the following flags:
|
||||
|
||||
```bash
|
||||
$ python3 $OPENFOLD_DIR/train_openfold.py test_data_epoch/mmcifs test_data_epoch/alignments test_data_epoch/template_mmcifs $OUTPUT_DIR 2021-09-30 \
|
||||
...
|
||||
--resume_from_ckpt checkpoints/6-209.ckpt.converted
|
||||
```
|
||||
118
docs/source/index.md
Normal file
118
docs/source/index.md
Normal file
@@ -0,0 +1,118 @@
|
||||
# OpenFold
|
||||
|
||||
```{figure} ../imgs/of_banner.png
|
||||
:width: 900px
|
||||
:align: center
|
||||
:alt: Comparison of OpenFold and AlphaFold2 predictions to the experimental structure of PDB 7KDX, chain B._
|
||||
```
|
||||
Welcome to the Documentation for OpenFold, the fully open source, trainable, PyTorch-based reproduction of DeepMind's
|
||||
[AlphaFold 2](https://github.com/deepmind/alphafold).
|
||||
|
||||
Here, you will find guides and documentation for:
|
||||
- [Getting started with OpenFold](installation.md)!
|
||||
- Learn how to [run inference with OpenFold](Inference.md)
|
||||
- [Train your own OpenFold models](Training_OpenFold.md)
|
||||
- Find guidance for setup and running OpenFold in the [FAQ](FAQ.md).
|
||||
|
||||
We also have a [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb) that can be used for single structure / multimer prediction.
|
||||
|
||||
Some portions of the documentation are still under migration from the original README, which can be found [here](original_readme.md).
|
||||
|
||||
# Features
|
||||
|
||||
OpenFold carefully reproduces (almost) all of the features of the original open
|
||||
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
|
||||
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
|
||||
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
|
||||
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
|
||||
latter.
|
||||
|
||||
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
|
||||
and we've trained it from scratch, matching the performance of the original.
|
||||
We've publicly released model weights and our training data — some 400,000
|
||||
MSAs and PDB70 template hit files — under a permissive license. Model weights
|
||||
are available via scripts in this repository while the MSAs are hosted by the
|
||||
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
|
||||
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
|
||||
|
||||
OpenFold also supports inference using AlphaFold's official parameters, and
|
||||
vice versa (see `scripts/convert_of_weights_to_jax.py`).
|
||||
|
||||
OpenFold has the following advantages over the reference implementation:
|
||||
|
||||
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
|
||||
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
|
||||
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
|
||||
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
|
||||
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
|
||||
kernels support in-place attention during inference and training. They use
|
||||
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
|
||||
implementations, respectively.
|
||||
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
|
||||
- **FlashAttention** support greatly speeds up MSA attention.
|
||||
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and significantly reduces the model's peak device memory requirement by 13X. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference.
|
||||
|
||||
# Copyright Notice
|
||||
|
||||
While AlphaFold's and, by extension, OpenFold's source code is licensed under
|
||||
the permissive Apache Licence, Version 2.0, DeepMind's pretrained parameters
|
||||
fall under the CC BY 4.0 license, a copy of which is downloaded to
|
||||
`openfold/resources/params` by the installation script. Note that the latter
|
||||
replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you encounter problems using OpenFold, feel free to create an issue! We also
|
||||
welcome pull requests from the community.
|
||||
|
||||
## Citing this Work
|
||||
|
||||
Please cite our paper:
|
||||
|
||||
```bibtex
|
||||
@article {Ahdritz2022.11.20.517210,
|
||||
author = {Ahdritz, Gustaf and Bouatta, Nazim and Floristean, Christina and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
|
||||
title = {{O}pen{F}old: {R}etraining {A}lpha{F}old2 yields new insights into its learning mechanisms and capacity for generalization},
|
||||
elocation-id = {2022.11.20.517210},
|
||||
year = {2022},
|
||||
doi = {10.1101/2022.11.20.517210},
|
||||
publisher = {Cold Spring Harbor Laboratory},
|
||||
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
|
||||
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
|
||||
journal = {bioRxiv}
|
||||
}
|
||||
```
|
||||
If you use OpenProteinSet, please also cite:
|
||||
|
||||
```bibtex
|
||||
@misc{ahdritz2023openproteinset,
|
||||
title={{O}pen{P}rotein{S}et: {T}raining data for structural biology at scale},
|
||||
author={Gustaf Ahdritz and Nazim Bouatta and Sachin Kadyan and Lukas Jarosch and Daniel Berenberg and Ian Fisk and Andrew M. Watkins and Stephen Ra and Richard Bonneau and Mohammed AlQuraishi},
|
||||
year={2023},
|
||||
eprint={2308.05326},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={q-bio.BM}
|
||||
}
|
||||
```
|
||||
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
|
||||
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
:caption: Guides
|
||||
Installation.md
|
||||
Inference.md
|
||||
Single_Sequence_Inference.md
|
||||
Multimer_Inference.md
|
||||
OpenFold_Training_Setup.md
|
||||
Training_OpenFold.md
|
||||
```
|
||||
|
||||
```{toctree}
|
||||
:hidden:
|
||||
:caption: Reference
|
||||
Aux_seq_files.md
|
||||
OpenFold_Parameters.md
|
||||
FAQ.md
|
||||
original_readme.md
|
||||
```
|
||||
62
docs/source/installation.md
Normal file
62
docs/source/installation.md
Normal file
@@ -0,0 +1,62 @@
|
||||
# Setting Up OpenFold
|
||||
|
||||
In this guide, we will OpenFold and its dependencies.
|
||||
|
||||
**Pre-requisites**
|
||||
|
||||
This package is currently supported for CUDA 11 and Pytorch 1.12. All dependencies are listed in the [`environment.yml`](https://github.com/aqlaboratory/openfold/blob/main/environment.yml)
|
||||
|
||||
At this time, only Linux systems are supported.
|
||||
|
||||
## Instructions
|
||||
:::
|
||||
|
||||
### Installation:
|
||||
1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
|
||||
1. From the `openfold` repo:
|
||||
- Create a [Mamba]("https://github.com/conda-forge/miniforge/releases/latest/download/) environment, e.g.
|
||||
`mamba env create -n openfold_env -f environment.yml`
|
||||
Mamba is recommended as the dependencies required by OpenFold are quite large and mamba can speed up the process.
|
||||
- Activate the environment, e.g `conda activate openfold_env`
|
||||
1. Run the setup script to configure kernels and folding resources.
|
||||
> scripts/install_third_party_dependencies.sh`
|
||||
3. Prepend the conda environment to the $LD_LIBRARY_PATH., e.g.
|
||||
`export $LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH``. You may optionally set this as a conda environment variable according to the [conda docs](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html#saving-environment-variables) to activate each time the environment is used.
|
||||
4. Download parameters. We recommend using a destination as `openfold/resources` as our unittests will look for the weights there.
|
||||
- For AlphaFold2 weights, use
|
||||
> ./scripts/download_alphafold_params.sh <dest>
|
||||
- For OpenFold weights, use :
|
||||
> ./scripts/download_openfold_params.sh <dest>
|
||||
- For OpenFold SoloSeq weights, use:
|
||||
> ./scripts/download_openfold_soloseq_params.sh <dest>
|
||||
|
||||
### Checking your build with unit tests:
|
||||
|
||||
To test your installation, you can run OpenFold unit tests. Make sure that the OpenFold and AlphaFold parameters have been downloaded, and that they are located (or symlinked) in the directory `openfold/resources`
|
||||
|
||||
Run with the following script:
|
||||
> scripts/run_unit_tests.sh
|
||||
|
||||
The script is a thin wrapper around Python's `unittest` suite, and recognizes `unittest` arguments. E.g., to run a specific test verbosely:
|
||||
|
||||
> scripts/run_unit_tests.sh -v tests.test_model
|
||||
|
||||
**Alphafold Comparison tests:**
|
||||
Certain tests perform equivalence comparisons with the AlphaFold implementation. Instructions to run this level of tests requires an environment with both AlphaFold 2.0.1 and OpenFold installed, and is not covered in this guide. These tests are skipped by default if no installation of AlphaFold is found.
|
||||
|
||||
## Environment specific modifications
|
||||
|
||||
### CUDA 12
|
||||
To use OpenFold on CUDA 12 environment rather than a CUDA 11 environment.
|
||||
In step 1, use the branch [`pl_upgrades`](https://github.com/aqlaboratory/openfold/tree/pl_upgrades) rather than the main branch, i.e. replace the URL in step 1 with https://github.com/aqlaboratory/openfold/tree/pl_upgrades
|
||||
Follow the rest of the steps of [Installation Guide](#Installation)
|
||||
|
||||
### Install OpenFold parameters without aws
|
||||
If you don't have access to `aws` on your system, you can use a different download source:
|
||||
|
||||
- HuggingFace (requires `git-lts`): `scripts/download_openfold_params_huggingface.sh`
|
||||
- Google Drive: `scripts/download_openfold_params_gdrive.sh`
|
||||
|
||||
### Docker setup
|
||||
|
||||
A [`Dockerfile`] is provided to build an OpenFold Docker image. Additional notes for setting up a docker container for OpenFold and running inference can be found [here](original_readme.md#building-and-using-the-docker-container).
|
||||
594
docs/source/original_readme.md
Normal file
594
docs/source/original_readme.md
Normal file
@@ -0,0 +1,594 @@
|
||||
# Original OpenFold README
|
||||
|
||||
A faithful but trainable PyTorch reproduction of DeepMind's
|
||||
[AlphaFold 2](https://github.com/deepmind/alphafold).
|
||||
|
||||
## Contents
|
||||
|
||||
- [OpenFold]
|
||||
- [Contents]
|
||||
- [Features](#features)
|
||||
- [Installation (Linux)](#installation-linux)
|
||||
- [Download Alignment Databases](#download-alignment-databases)
|
||||
- [Inference](#inference)
|
||||
- [Monomer inference](#monomer-inference)
|
||||
- [Multimer Inference](#multimer-inference)
|
||||
- [Soloseq Inference](#soloseq-inference)
|
||||
- [Training](#training)
|
||||
- [Testing](#testing)
|
||||
- [Building and Using the Docker Container](#building-and-using-the-docker-container)
|
||||
- [Copyright Notice](#copyright-notice)
|
||||
- [Contributing](#contributing)
|
||||
- [Citing this Work](#citing-this-work)
|
||||
|
||||
## Features
|
||||
|
||||
OpenFold carefully reproduces (almost) all of the features of the original open
|
||||
source monomer (v2.0.1) and multimer (v2.3.2) inference code. The sole exception is
|
||||
model ensembling, which fared poorly in DeepMind's own ablation testing and is being
|
||||
phased out in future DeepMind experiments. It is omitted here for the sake of reducing
|
||||
clutter. In cases where the *Nature* paper differs from the source, we always defer to the
|
||||
latter.
|
||||
|
||||
OpenFold is trainable in full precision, half precision, or `bfloat16` with or without DeepSpeed,
|
||||
and we've trained it from scratch, matching the performance of the original.
|
||||
We've publicly released model weights and our training data — some 400,000
|
||||
MSAs and PDB70 template hit files — under a permissive license. Model weights
|
||||
are available via scripts in this repository while the MSAs are hosted by the
|
||||
[Registry of Open Data on AWS (RODA)](https://registry.opendata.aws/openfold).
|
||||
Try out running inference for yourself with our [Colab notebook](https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb).
|
||||
|
||||
OpenFold also supports inference using AlphaFold's official parameters, and
|
||||
vice versa (see `scripts/convert_of_weights_to_jax.py`).
|
||||
|
||||
OpenFold has the following advantages over the reference implementation:
|
||||
|
||||
- **Faster inference** on GPU, sometimes by as much as 2x. The greatest speedups are achieved on Ampere or higher architecture GPUs.
|
||||
- **Inference on extremely long chains**, made possible by our implementation of low-memory attention
|
||||
([Rabe & Staats 2021](https://arxiv.org/pdf/2112.05682.pdf)). OpenFold can predict the structures of
|
||||
sequences with more than 4000 residues on a single A100, and even longer ones with CPU offloading.
|
||||
- **Custom CUDA attention kernels** modified from [FastFold](https://github.com/hpcaitech/FastFold)'s
|
||||
kernels support in-place attention during inference and training. They use
|
||||
4x and 5x less GPU memory than equivalent FastFold and stock PyTorch
|
||||
implementations, respectively.
|
||||
- **Efficient alignment scripts** using the original AlphaFold HHblits/JackHMMER pipeline or [ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster MMseqs2 instead. We've used them to generate millions of alignments.
|
||||
- **FlashAttention** support greatly speeds up MSA attention.
|
||||
- **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative. The kernel provides substantial speedups for training and inference, and significantly reduces the model's peak device memory requirement by 13X. The model is 15% faster during the initial training and finetuning stages, and up to 4x faster during inference. To use this feature, simply set the `use_deepspeed_evo_attention` option in `openfold/config.py`.
|
||||
|
||||
## Installation (Linux)
|
||||
|
||||
All Python dependencies are specified in `environment.yml`. For producing sequence
|
||||
alignments, you'll also need `kalign`, the [HH-suite](https://github.com/soedinglab/hh-suite),
|
||||
and one of {`jackhmmer`, [MMseqs2](https://github.com/soedinglab/mmseqs2) (nightly build)}
|
||||
installed on your system. You'll need `git-lfs` to download OpenFold parameters.
|
||||
Finally, some download scripts require `aria2c` and `aws`.
|
||||
|
||||
This package is currently supported for CUDA 11 and Pytorch 1.12
|
||||
|
||||
To install:
|
||||
1. Clone the repository, e.g. `git clone https://github.com/aqlaboratory/openfold.git`
|
||||
1. From the `openfold` repo:
|
||||
- Create a [Mamba]("https://github.com/conda-forge/miniforge/releases/latest/download/) environment, e.g.
|
||||
`mamba env create -n openfold_env -f environment.yml`
|
||||
Mamba is recommended as the dependencies required by OpenFold are quite large and mamba can speed up the process.
|
||||
- Activate the environment, e.g `conda activate openfold_env`
|
||||
1. Run `scripts/install_third_party_dependencies.sh` to configure kernels and folding resources.
|
||||
|
||||
For some systems, it may help to append the Conda environment library path to `$LD_LIBRARY_PATH`. The `install_third_party_dependencies.sh` script does this once, but you may need this for each bash instance.
|
||||
|
||||
|
||||
## Download Alignment Databases
|
||||
|
||||
If you intend to generate your own alignments, e.g. for inference, you have two
|
||||
choices for downloading protein databases, depending on whether you want to use
|
||||
DeepMind's MSA generation pipeline (w/ HMMR & HHblits) or
|
||||
[ColabFold](https://github.com/sokrypton/ColabFold)'s, which uses the faster
|
||||
MMseqs2 instead. For the former, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_alphafold_dbs.sh data/
|
||||
```
|
||||
|
||||
For the latter, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_mmseqs_dbs.sh data/ # downloads .tar files
|
||||
bash scripts/prep_mmseqs_dbs.sh data/ # unpacks and preps the databases
|
||||
```
|
||||
|
||||
Make sure to run the latter command on the machine that will be used for MSA
|
||||
generation (the script estimates how the precomputed database index used by
|
||||
MMseqs2 should be split according to the memory available on the system).
|
||||
|
||||
If you're using your own precomputed MSAs or MSAs from the RODA repository,
|
||||
there's no need to download these alignment databases. Simply make sure that
|
||||
the `alignment_dir` contains one directory per chain and that each of these
|
||||
contains alignments (.sto, .a3m, and .hhr) corresponding to that chain. You
|
||||
can use `scripts/flatten_roda.sh` to reformat RODA downloads in this way.
|
||||
Note that the RODA alignments are NOT compatible with the recent .cif ground
|
||||
truth files downloaded by `scripts/download_alphafold_dbs.sh`. To fetch .cif
|
||||
files that match the RODA MSAs, once the alignments are flattened, use
|
||||
`scripts/download_roda_pdbs.sh`. That script outputs a list of alignment dirs
|
||||
for which matching .cif files could not be found. These should be removed from
|
||||
the alignment directory.
|
||||
|
||||
Alternatively, you can use raw MSAs from
|
||||
[ProteinNet](https://github.com/aqlaboratory/proteinnet). After downloading
|
||||
that database, use `scripts/prep_proteinnet_msas.py` to convert the data
|
||||
into a format recognized by the OpenFold parser. The resulting directory
|
||||
becomes the `alignment_dir` used in subsequent steps. Use
|
||||
`scripts/unpack_proteinnet.py` to extract `.core` files from ProteinNet text
|
||||
files.
|
||||
|
||||
For both inference and training, the model's hyperparameters can be tuned from
|
||||
`openfold/config.py`. Of course, if you plan to perform inference using
|
||||
DeepMind's pretrained parameters, you will only be able to make changes that
|
||||
do not affect the shapes of model parameters. For an example of initializing
|
||||
the model, consult `run_pretrained_openfold.py`.
|
||||
|
||||
## Inference
|
||||
|
||||
OpenFold now supports three inference modes:
|
||||
- [Monomer Inference](#monomer-inference): OpenFold reproduction of AlphaFold2. Inference available with either DeepMind's pretrained parameters or OpenFold trained parameters.
|
||||
- [Multimer Inference](#multimer-inference): OpenFold reproduction of AlphaFold-Multimer. Inference available with DeepMind's pre-trained parameters.
|
||||
- [Single Sequence Inference (SoloSeq)](#soloseq-inference): Language Model based structure prediction, using [ESM-1b](https://github.com/facebookresearch/esm) embeddings.
|
||||
|
||||
More instructions for each inference mode are provided below:
|
||||
|
||||
### Monomer inference
|
||||
|
||||
To run inference on a sequence or multiple sequences using a set of DeepMind's
|
||||
pretrained parameters, first download the OpenFold weights e.g.:
|
||||
|
||||
```bash
|
||||
bash scripts/download_openfold_params.sh openfold/resources
|
||||
```
|
||||
|
||||
then run e.g.:
|
||||
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
--config_preset "model_1_ptm" \
|
||||
--model_device "cuda:0" \
|
||||
--output_dir ./ \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_params/finetuning_ptm_2.pt
|
||||
```
|
||||
|
||||
where `data` is the same directory as in the previous step. If `jackhmmer`,
|
||||
`hhblits`, `hhsearch` and `kalign` are available at the default path of
|
||||
`/usr/bin`, their `binary_path` command-line arguments can be dropped.
|
||||
If you've already computed alignments for the query, you have the option to
|
||||
skip the expensive alignment computation here with
|
||||
`--use_precomputed_alignments`.
|
||||
|
||||
`--openfold_checkpoint_path` or `--jax_param_path` accept comma-delineated lists
|
||||
of .pt/DeepSpeed OpenFold checkpoints and AlphaFold's .npz JAX parameter files,
|
||||
respectively. For a breakdown of the differences between the different parameter
|
||||
files, see the README downloaded to `openfold/resources/openfold_params/`. Since
|
||||
OpenFold was trained under a newer training schedule than the one from which the
|
||||
`model_n` config presets are derived, there is no clean correspondence between
|
||||
`config_preset` settings and OpenFold checkpoints; the only restraints are that
|
||||
`*_ptm` checkpoints must be run with `*_ptm` config presets and that `_no_templ_`
|
||||
checkpoints are only compatible with template-less presets (`model_3` and above).
|
||||
|
||||
Note that chunking (as defined in section 1.11.8 of the AlphaFold 2 supplement)
|
||||
is enabled by default in inference mode. To disable it, set `globals.chunk_size`
|
||||
to `None` in the config. If a value is specified, OpenFold will attempt to
|
||||
dynamically tune it, considering the chunk size specified in the config as a
|
||||
minimum. This tuning process automatically ensures consistently fast runtimes
|
||||
regardless of input sequence length, but it also introduces some runtime
|
||||
variability, which may be undesirable for certain users. It is also recommended
|
||||
to disable this feature for very long chains (see below). To do so, set the
|
||||
`tune_chunk_size` option in the config to `False`.
|
||||
|
||||
For large-scale batch inference, we offer an optional tracing mode, which
|
||||
massively improves runtimes at the cost of a lengthy model compilation process.
|
||||
To enable it, add `--trace_model` to the inference command.
|
||||
|
||||
To get a speedup during inference, enable [FlashAttention](https://github.com/HazyResearch/flash-attention)
|
||||
in the config. Note that it appears to work best for sequences with < 1000 residues.
|
||||
|
||||
To minimize memory usage during inference on long sequences, consider the
|
||||
following changes:
|
||||
|
||||
- As noted in the AlphaFold-Multimer paper, the AlphaFold/OpenFold template
|
||||
stack is a major memory bottleneck for inference on long sequences. OpenFold
|
||||
supports two mutually exclusive inference modes to address this issue. One,
|
||||
`average_templates` in the `template` section of the config, is similar to the
|
||||
solution offered by AlphaFold-Multimer, which is simply to average individual
|
||||
template representations. Our version is modified slightly to accommodate
|
||||
weights trained using the standard template algorithm. Using said weights, we
|
||||
notice no significant difference in performance between our averaged template
|
||||
embeddings and the standard ones. The second, `offload_templates`, temporarily
|
||||
offloads individual template embeddings into CPU memory. The former is an
|
||||
approximation while the latter is slightly slower; both are memory-efficient
|
||||
and allow the model to utilize arbitrarily many templates across sequence
|
||||
lengths. Both are disabled by default, and it is up to the user to determine
|
||||
which best suits their needs, if either.
|
||||
- Inference-time low-memory attention (LMA) can be enabled in the model config.
|
||||
This setting trades off speed for vastly improved memory usage. By default,
|
||||
LMA is run with query and key chunk sizes of 1024 and 4096, respectively.
|
||||
These represent a favorable tradeoff in most memory-constrained cases.
|
||||
Powerusers can choose to tweak these settings in
|
||||
`openfold/model/primitives.py`. For more information on the LMA algorithm,
|
||||
see the aforementioned Staats & Rabe preprint.
|
||||
- Disable `tune_chunk_size` for long sequences. Past a certain point, it only
|
||||
wastes time.
|
||||
- As a last resort, consider enabling `offload_inference`. This enables more
|
||||
extensive CPU offloading at various bottlenecks throughout the model.
|
||||
- Disable FlashAttention, which seems unstable on long sequences.
|
||||
|
||||
Using the most conservative settings, we were able to run inference on a
|
||||
4600-residue complex with a single A100. Compared to AlphaFold's own memory
|
||||
offloading mode, ours is considerably faster; the same complex takes the more
|
||||
efficent AlphaFold-Multimer more than double the time. Use the
|
||||
`long_sequence_inference` config option to enable all of these interventions
|
||||
at once. The `run_pretrained_openfold.py` script can enable this config option with the
|
||||
`--long_sequence_inference` command line option
|
||||
|
||||
Input FASTA files containing multiple sequences are treated as complexes. In
|
||||
this case, the inference script runs AlphaFold-Gap, a hack proposed
|
||||
[here](https://twitter.com/minkbaek/status/1417538291709071362?lang=en), using
|
||||
the specified stock AlphaFold/OpenFold parameters (NOT AlphaFold-Multimer).
|
||||
|
||||
### Multimer Inference
|
||||
|
||||
To run inference on a complex or multiple complexes using a set of DeepMind's pretrained parameters, run e.g.:
|
||||
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2022_05.fa \
|
||||
--pdb_seqres_database_path data/pdb_seqres/pdb_seqres.txt \
|
||||
--uniref30_database_path data/uniref30/UniRef30_2021_03 \
|
||||
--uniprot_database_path data/uniprot/uniprot.fasta \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hmmsearch_binary_path lib/conda/envs/openfold_venv/bin/hmmsearch \
|
||||
--hmmbuild_binary_path lib/conda/envs/openfold_venv/bin/hmmbuild \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
--config_preset "model_1_multimer_v3" \
|
||||
--model_device "cuda:0" \
|
||||
--output_dir ./
|
||||
```
|
||||
|
||||
As with monomer inference, if you've already computed alignments for the query, you can use
|
||||
the `--use_precomputed_alignments` option. Note that template searching in the multimer pipeline
|
||||
uses HMMSearch with the PDB SeqRes database, replacing HHSearch and PDB70 used in the monomer pipeline.
|
||||
|
||||
**Upgrade from an existing OpenFold installation**
|
||||
|
||||
The above command requires several upgrades to existing openfold installations.
|
||||
|
||||
1. Re-download the alphafold parameters to get the latest
|
||||
AlphaFold-Multimer v3 weights:
|
||||
|
||||
```bash
|
||||
bash scripts/download_alphafold_params.sh openfold/resources
|
||||
```
|
||||
|
||||
2. Download the [UniProt](https://www.uniprot.org/uniprotkb/)
|
||||
and [PDB SeqRes](https://www.rcsb.org/) databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniprot.sh data/
|
||||
```
|
||||
|
||||
The PDB SeqRes and PDB databases must be from the same date to avoid potential
|
||||
errors during template searching. Remove the existing `data/pdb_mmcif` directory
|
||||
and download both databases:
|
||||
|
||||
```bash
|
||||
bash scripts/download_pdb_mmcif.sh data/
|
||||
bash scripts/download_pdb_seqres.sh data/
|
||||
```
|
||||
|
||||
3. Additionally, AlphaFold-Multimer uses upgraded versions of the [MGnify](https://www.ebi.ac.uk/metagenomics)
|
||||
and [UniRef30](https://uniclust.mmseqs.com/) (previously UniClust30) databases. To download the upgraded databases, run:
|
||||
|
||||
```bash
|
||||
bash scripts/download_uniref30.sh data/
|
||||
bash scripts/download_mgnify.sh data/
|
||||
```
|
||||
Multimer inference can also run with the older database versions if desired.
|
||||
|
||||
|
||||
### Soloseq Inference
|
||||
|
||||
To run inference for a sequence using the SoloSeq single-sequence model, you can either precompute ESM-1b embeddings in bulk, or you can generate them during inference.
|
||||
|
||||
For generating ESM-1b embeddings in bulk, use the provided script: `scripts/precompute_embeddings.py`. The script takes a directory of FASTA files (one sequence per file) and generates ESM-1b embeddings in the same format and directory structure as required by SoloSeq. Following is an example command to use the script:
|
||||
|
||||
```bash
|
||||
python scripts/precompute_embeddings.py fasta_dir/ embeddings_output_dir/
|
||||
```
|
||||
|
||||
In the same per-label subdirectories inside `embeddings_output_dir`, you can also place `*.hhr` files (outputs from HHSearch), which can contain the details about the structures that you want to use as templates. If you do not place any such file, templates will not be used and only the ESM-1b embeddings will be used to predict the structure. If you want to use templates, you need to pass the PDB MMCIF dataset to the command.
|
||||
|
||||
Then download the SoloSeq model weights, e.g.:
|
||||
|
||||
|
||||
```bash
|
||||
bash scripts/download_openfold_soloseq_params.sh openfold/resources
|
||||
```
|
||||
|
||||
|
||||
Now, you are ready to run inference:
|
||||
```bash
|
||||
python run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--use_precomputed_alignments embeddings_output_dir \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt
|
||||
```
|
||||
|
||||
For generating the embeddings during inference, skip the `--use_precomputed_alignments` argument. The `*.hhr` files will be generated as well if you pass the paths to the relevant databases and tools, as specified in the command below. If you skip the database and tool arguments, HHSearch will not be used to find templates and only generated ESM-1b embeddings will be used to predict the structure.
|
||||
```bash
|
||||
python3 run_pretrained_openfold.py \
|
||||
fasta_dir \
|
||||
data/pdb_mmcif/mmcif_files/ \
|
||||
--output_dir ./ \
|
||||
--model_device "cuda:0" \
|
||||
--config_preset "seq_model_esm1b_ptm" \
|
||||
--openfold_checkpoint_path openfold/resources/openfold_soloseq_params/seq_model_esm1b_ptm.pt \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign \
|
||||
```
|
||||
|
||||
For generating template information, you will need the UniRef90 and PDB70 databases and the JackHmmer and HHSearch binaries.
|
||||
|
||||
SoloSeq allows you to use the same flags and optimizations as the MSA-based OpenFold. For example, you can skip relaxation using `--skip_relaxation`, save all model outputs using `--save_outputs`, and generate output files in MMCIF format using `--cif_output`.
|
||||
|
||||
**NOTE:** Due to the nature of the ESM-1b embeddings, the sequence length for inference using the SoloSeq model is limited to 1022 residues. Sequences longer than that will be truncated.
|
||||
|
||||
## Training
|
||||
|
||||
To train the model, you will first need to precompute protein alignments.
|
||||
|
||||
You have two options. You can use the same procedure DeepMind used by running
|
||||
the following:
|
||||
|
||||
```bash
|
||||
python3 scripts/precompute_alignments.py mmcif_dir/ alignment_dir/ \
|
||||
--uniref90_database_path data/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path data/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path data/pdb70/pdb70 \
|
||||
--uniclust30_database_path data/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--bfd_database_path data/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--cpus_per_task 16 \
|
||||
--jackhmmer_binary_path lib/conda/envs/openfold_venv/bin/jackhmmer \
|
||||
--hhblits_binary_path lib/conda/envs/openfold_venv/bin/hhblits \
|
||||
--hhsearch_binary_path lib/conda/envs/openfold_venv/bin/hhsearch \
|
||||
--kalign_binary_path lib/conda/envs/openfold_venv/bin/kalign
|
||||
```
|
||||
|
||||
As noted before, you can skip the `binary_path` arguments if these binaries are
|
||||
at `/usr/bin`. Expect this step to take a very long time, even for small
|
||||
numbers of proteins.
|
||||
|
||||
Alternatively, you can generate MSAs with the ColabFold pipeline (and templates
|
||||
with HHsearch) with:
|
||||
|
||||
```bash
|
||||
python3 scripts/precompute_alignments_mmseqs.py input.fasta \
|
||||
data/mmseqs_dbs \
|
||||
uniref30_2103_db \
|
||||
alignment_dir \
|
||||
~/MMseqs2/build/bin/mmseqs \
|
||||
/usr/bin/hhsearch \
|
||||
--env_db colabfold_envdb_202108_db
|
||||
--pdb70 data/pdb70/pdb70
|
||||
```
|
||||
|
||||
where `input.fasta` is a FASTA file containing one or more query sequences. To
|
||||
generate an input FASTA from a directory of mmCIF and/or ProteinNet .core
|
||||
files, we provide `scripts/data_dir_to_fasta.py`.
|
||||
|
||||
Next, generate a cache of certain datapoints in the template mmCIF files:
|
||||
|
||||
```bash
|
||||
python3 scripts/generate_mmcif_cache.py \
|
||||
mmcif_dir/ \
|
||||
mmcif_cache.json \
|
||||
--no_workers 16
|
||||
```
|
||||
|
||||
This cache is used to pre-filter templates.
|
||||
|
||||
Next, generate a separate chain-level cache with data used for training-time
|
||||
data filtering:
|
||||
|
||||
```bash
|
||||
python3 scripts/generate_chain_data_cache.py \
|
||||
mmcif_dir/ \
|
||||
chain_data_cache.json \
|
||||
--cluster_file clusters-by-entity-40.txt \
|
||||
--no_workers 16
|
||||
```
|
||||
|
||||
where the `cluster_file` argument is a file of chain clusters, one cluster
|
||||
per line.
|
||||
|
||||
Optionally, download an AlphaFold-style validation set from
|
||||
[CAMEO](https://cameo3d.org) using `scripts/download_cameo.py`. Use the
|
||||
resulting FASTA files to generate validation alignments and then specify
|
||||
the validation set's location using the `--val_...` family of training script
|
||||
flags.
|
||||
|
||||
Finally, call the training script:
|
||||
|
||||
```bash
|
||||
python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ output_dir/ \
|
||||
2021-10-10 \
|
||||
--template_release_dates_cache_path mmcif_cache.json \
|
||||
--precision bf16 \
|
||||
--gpus 8 --replace_sampler_ddp=True \
|
||||
--seed 4242022 \ # in multi-gpu settings, the seed must be specified
|
||||
--deepspeed_config_path deepspeed_config.json \
|
||||
--checkpoint_every_epoch \
|
||||
--resume_from_ckpt ckpt_dir/ \
|
||||
--train_chain_data_cache_path chain_data_cache.json \
|
||||
--obsolete_pdbs_file_path obsolete.dat
|
||||
```
|
||||
|
||||
where `--template_release_dates_cache_path` is a path to the mmCIF cache.
|
||||
Note that `template_mmcif_dir` can be the same as `mmcif_dir` which contains
|
||||
training targets. A suitable DeepSpeed configuration file can be generated with
|
||||
`scripts/build_deepspeed_config.py`. The training script is
|
||||
written with [PyTorch Lightning](https://github.com/PyTorchLightning/pytorch-lightning)
|
||||
and supports the full range of training options that entails, including
|
||||
multi-node distributed training, validation, and so on. For more information,
|
||||
consult PyTorch Lightning documentation and the `--help` flag of the training
|
||||
script.
|
||||
|
||||
Note that, despite its variable name, `mmcif_dir` can also contain PDB files
|
||||
or even ProteinNet .core files.
|
||||
|
||||
To emulate the AlphaFold training procedure, which uses a self-distillation set
|
||||
subject to special preprocessing steps, use the family of `--distillation` flags.
|
||||
|
||||
In cases where it may be burdensome to create separate files for each chain's
|
||||
alignments, alignment directories can be consolidated using the scripts in
|
||||
`scripts/alignment_db_scripts/`. First, run `create_alignment_db.py` to
|
||||
consolidate an alignment directory into a pair of database and index files.
|
||||
Once all alignment directories (or shards of a single alignment directory)
|
||||
have been compiled, unify the indices with `unify_alignment_db_indices.py`. The
|
||||
resulting index, `super.index`, can be passed to the training script flags
|
||||
containing the phrase `alignment_index`. In this scenario, the `alignment_dir`
|
||||
flags instead represent the directory containing the compiled alignment
|
||||
databases. Both the training and distillation datasets can be compiled in this
|
||||
way. Anecdotally, this can speed up training in I/O-bottlenecked environments.
|
||||
|
||||
## Testing
|
||||
|
||||
To run unit tests, use
|
||||
|
||||
```bash
|
||||
scripts/run_unit_tests.sh
|
||||
```
|
||||
|
||||
The script is a thin wrapper around Python's `unittest` suite, and recognizes
|
||||
`unittest` arguments. E.g., to run a specific test verbosely:
|
||||
|
||||
```bash
|
||||
scripts/run_unit_tests.sh -v tests.test_model
|
||||
```
|
||||
|
||||
Certain tests require that AlphaFold (v2.0.1) be installed in the same Python
|
||||
environment. These run components of AlphaFold and OpenFold side by side and
|
||||
ensure that output activations are adequately similar. For most modules, we
|
||||
target a maximum pointwise difference of `1e-4`.
|
||||
|
||||
## Building and Using the Docker Container
|
||||
|
||||
**Building the Docker Image**
|
||||
|
||||
Openfold can be built as a docker container using the included dockerfile. To build it, run the following command from the root of this repository:
|
||||
|
||||
```bash
|
||||
docker build -t openfold .
|
||||
```
|
||||
|
||||
**Running the Docker Container**
|
||||
|
||||
The built container contains both `run_pretrained_openfold.py` and `train_openfold.py` as well as all necessary software dependencies. It does not contain the model parameters, sequence, or structural databases. These should be downloaded to the host machine following the instructions in the Usage section above.
|
||||
|
||||
The docker container installs all conda components to the base conda environment in `/opt/conda`, and installs openfold itself in `/opt/openfold`,
|
||||
|
||||
Before running the docker container, you can verify that your docker installation is able to properly communicate with your GPU by running the following command:
|
||||
|
||||
|
||||
```bash
|
||||
docker run --rm --gpus all nvidia/cuda:11.0-base nvidia-smi
|
||||
```
|
||||
|
||||
Note the `--gpus all` option passed to `docker run`. This option is necessary in order for the container to use the GPUs on the host machine.
|
||||
|
||||
To run the inference code under docker, you can use a command like the one below. In this example, parameters and sequences from the alphafold dataset are being used and are located at `/mnt/alphafold_database` on the host machine, and the input files are located in the current working directory. You can adjust the volume mount locations as needed to reflect the locations of your data.
|
||||
|
||||
```bash
|
||||
docker run \
|
||||
--gpus all \
|
||||
-v $PWD/:/data \
|
||||
-v /mnt/alphafold_database/:/database \
|
||||
-ti openfold:latest \
|
||||
python3 /opt/openfold/run_pretrained_openfold.py \
|
||||
/data/fasta_dir \
|
||||
/database/pdb_mmcif/mmcif_files/ \
|
||||
--uniref90_database_path /database/uniref90/uniref90.fasta \
|
||||
--mgnify_database_path /database/mgnify/mgy_clusters_2018_12.fa \
|
||||
--pdb70_database_path /database/pdb70/pdb70 \
|
||||
--uniclust30_database_path /database/uniclust30/uniclust30_2018_08/uniclust30_2018_08 \
|
||||
--output_dir /data \
|
||||
--bfd_database_path /database/bfd/bfd_metaclust_clu_complete_id30_c90_final_seq.sorted_opt \
|
||||
--model_device cuda:0 \
|
||||
--jackhmmer_binary_path /opt/conda/bin/jackhmmer \
|
||||
--hhblits_binary_path /opt/conda/bin/hhblits \
|
||||
--hhsearch_binary_path /opt/conda/bin/hhsearch \
|
||||
--kalign_binary_path /opt/conda/bin/kalign \
|
||||
--openfold_checkpoint_path /database/openfold_params/finetuning_ptm_2.pt
|
||||
```
|
||||
|
||||
## Copyright Notice
|
||||
|
||||
While AlphaFold's and, by extension, OpenFold's source code is licensed under
|
||||
the permissive Apache Licence, Version 2.0, DeepMind's pretrained parameters
|
||||
fall under the CC BY 4.0 license, a copy of which is downloaded to
|
||||
`openfold/resources/params` by the installation script. Note that the latter
|
||||
replaces the original, more restrictive CC BY-NC 4.0 license as of January 2022.
|
||||
|
||||
## Contributing
|
||||
|
||||
If you encounter problems using OpenFold, feel free to create an issue! We also
|
||||
welcome pull requests from the community.
|
||||
|
||||
## Citing this Work
|
||||
|
||||
Please cite our paper:
|
||||
|
||||
```bibtex
|
||||
@article {Ahdritz2022.11.20.517210,
|
||||
author = {Ahdritz, Gustaf and Bouatta, Nazim and Floristean, Christina and Kadyan, Sachin and Xia, Qinghui and Gerecke, William and O{\textquoteright}Donnell, Timothy J and Berenberg, Daniel and Fisk, Ian and Zanichelli, Niccolò and Zhang, Bo and Nowaczynski, Arkadiusz and Wang, Bei and Stepniewska-Dziubinska, Marta M and Zhang, Shang and Ojewole, Adegoke and Guney, Murat Efe and Biderman, Stella and Watkins, Andrew M and Ra, Stephen and Lorenzo, Pablo Ribalta and Nivon, Lucas and Weitzner, Brian and Ban, Yih-En Andrew and Sorger, Peter K and Mostaque, Emad and Zhang, Zhao and Bonneau, Richard and AlQuraishi, Mohammed},
|
||||
title = {{O}pen{F}old: {R}etraining {A}lpha{F}old2 yields new insights into its learning mechanisms and capacity for generalization},
|
||||
elocation-id = {2022.11.20.517210},
|
||||
year = {2022},
|
||||
doi = {10.1101/2022.11.20.517210},
|
||||
publisher = {Cold Spring Harbor Laboratory},
|
||||
URL = {https://www.biorxiv.org/content/10.1101/2022.11.20.517210},
|
||||
eprint = {https://www.biorxiv.org/content/early/2022/11/22/2022.11.20.517210.full.pdf},
|
||||
journal = {bioRxiv}
|
||||
}
|
||||
```
|
||||
If you use OpenProteinSet, please also cite:
|
||||
|
||||
```bibtex
|
||||
@misc{ahdritz2023openproteinset,
|
||||
title={{O}pen{P}rotein{S}et: {T}raining data for structural biology at scale},
|
||||
author={Gustaf Ahdritz and Nazim Bouatta and Sachin Kadyan and Lukas Jarosch and Daniel Berenberg and Ian Fisk and Andrew M. Watkins and Stephen Ra and Richard Bonneau and Mohammed AlQuraishi},
|
||||
year={2023},
|
||||
eprint={2308.05326},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={q-bio.BM}
|
||||
}
|
||||
```
|
||||
Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
|
||||
@@ -1,18 +1,18 @@
|
||||
name: openfold-venv
|
||||
name: openfold-env
|
||||
channels:
|
||||
- conda-forge
|
||||
- bioconda
|
||||
- pytorch
|
||||
- nvidia
|
||||
dependencies:
|
||||
- python=3.9
|
||||
- python=3.10
|
||||
- libgcc=7.2
|
||||
- setuptools=59.5.0
|
||||
- pip
|
||||
- openmm=7.7
|
||||
- pdbfixer
|
||||
- pytorch-lightning
|
||||
- biopython==1.79
|
||||
- biopython
|
||||
- numpy
|
||||
- pandas
|
||||
- PyYAML==5.4.1
|
||||
@@ -24,11 +24,12 @@ dependencies:
|
||||
- modelcif==0.7
|
||||
- awscli
|
||||
- ml-collections
|
||||
- mkl=2022.1
|
||||
- aria2
|
||||
- git
|
||||
- bioconda::hmmer==3.3.2
|
||||
- bioconda::hhsuite==3.3.0
|
||||
- bioconda::kalign2==2.04
|
||||
- bioconda::hmmer
|
||||
- bioconda::hhsuite
|
||||
- bioconda::kalign2
|
||||
- pytorch::pytorch=2.1
|
||||
- pytorch::pytorch-cuda=12.1
|
||||
- pip:
|
||||
|
||||
2856
examples/monomer/alignments/6KWC_1/bfd_uniref_hits.a3m
Normal file
2856
examples/monomer/alignments/6KWC_1/bfd_uniref_hits.a3m
Normal file
File diff suppressed because it is too large
Load Diff
1877
examples/monomer/alignments/6KWC_1/hhsearch_output.hhr
Normal file
1877
examples/monomer/alignments/6KWC_1/hhsearch_output.hhr
Normal file
File diff suppressed because it is too large
Load Diff
10200
examples/monomer/alignments/6KWC_1/mgnify_hits.sto
Normal file
10200
examples/monomer/alignments/6KWC_1/mgnify_hits.sto
Normal file
File diff suppressed because it is too large
Load Diff
19860
examples/monomer/alignments/6KWC_1/uniref90_hits.sto
Normal file
19860
examples/monomer/alignments/6KWC_1/uniref90_hits.sto
Normal file
File diff suppressed because it is too large
Load Diff
2
examples/monomer/fasta_dir/6kwc.fasta
Normal file
2
examples/monomer/fasta_dir/6kwc.fasta
Normal file
@@ -0,0 +1,2 @@
|
||||
>6KWC_1
|
||||
GSTIQPGTGYNNGYFYSYWNDGHGGVTYTNGPGGQFSVNWSNSGEFVGGKGWQPGTKNKVINFSGSYNPNGNSYLSVYGWSRNPLIEYYIVENFGTYNPSTGATKLGEVTSDGSVYDIYRTQRVNQPSIIGTATFYQYWSVRRNHRSSGSVNTANHFNAWAQQGLTLGTMDYQIVAVQGYFSSGSASITVS
|
||||
16
examples/monomer/inference.sh
Executable file
16
examples/monomer/inference.sh
Executable file
@@ -0,0 +1,16 @@
|
||||
#!/bin/bash
|
||||
export LD_LIBRARY_PATH=$CONDA_PREFIX/lib:$LD_LIBRARY_PATH
|
||||
export LIBRARY_PATH=$CONDA_PREFIX/lib:$LIBRARY_PATH
|
||||
|
||||
export FASTA_DIR=./fasta_dir
|
||||
export OUTPUT_DIR=./
|
||||
export PRECOMPUTED_ALIGNMENT_DIR=./alignments
|
||||
export MMCIF_DIR=/mmcifs # UPDATE with path to your mmcifs directory
|
||||
|
||||
python3 run_pretrained_openfold.py $FASTA_DIR \
|
||||
$MMCIF_DIR \
|
||||
--output_dir $OUTPUT_DIR \
|
||||
--config_preset model_1_ptm \
|
||||
--model_device "cuda:0" \
|
||||
--data_random_seed 42 \
|
||||
--use_precomputed_alignments $PRECOMPUTED_ALIGNMENT_DIR
|
||||
2843
examples/monomer/sample_predictions/6KWC_1_model_1_ptm_relaxed.pdb
Normal file
2843
examples/monomer/sample_predictions/6KWC_1_model_1_ptm_relaxed.pdb
Normal file
File diff suppressed because it is too large
Load Diff
1488
examples/monomer/sample_predictions/6KWC_1_model_1_ptm_unrelaxed.pdb
Normal file
1488
examples/monomer/sample_predictions/6KWC_1_model_1_ptm_unrelaxed.pdb
Normal file
File diff suppressed because it is too large
Load Diff
472
experiments/test_templates_openfold.py
Normal file
472
experiments/test_templates_openfold.py
Normal file
@@ -0,0 +1,472 @@
|
||||
# Adapted from https://www.github.com/jproney/AF2Rank/blob/master/test_templates.py
|
||||
|
||||
# Copyright 2024 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import traceback
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
import re
|
||||
import subprocess
|
||||
import torch
|
||||
from collections import namedtuple
|
||||
from copy import deepcopy
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("name", help="name to save everything under")
|
||||
parser.add_argument("--target_list", nargs='*', help="List of target names to run")
|
||||
parser.add_argument("--targets_file", default="", help="File with list of target names to run")
|
||||
parser.add_argument("--recycles", type=int, default=1, help="Number of recycles when predicting")
|
||||
parser.add_argument("--model_name", type=str, default="model_1_ptm", help="Which OF model to use")
|
||||
parser.add_argument("--seed", type=int, default=0, help="RNG Seed")
|
||||
parser.add_argument("--verbose", action='store_true', help="print extra")
|
||||
parser.add_argument("--deterministic", action='store_true', help="make all data processing deterministic (no masking, etc.)")
|
||||
parser.add_argument("--use_native", action='store_true', help="add the native structure as a decoy, and compare outputs against it")
|
||||
parser.add_argument("--mask_sidechains", action='store_true', help="mask out sidechain atoms except for C-Beta")
|
||||
parser.add_argument("--mask_sidechains_add_cb", action='store_true', help="mask out sidechain atoms except for C-Beta, and add C-Beta to glycines")
|
||||
parser.add_argument("--seq_replacement", default='', help="Amino acid residue to fill the decoy sequence with. Default keeps target sequence")
|
||||
parser.add_argument("--of_dir", default="/home/user/openfold/", help="OpenFold code and weights directory")
|
||||
parser.add_argument("--esm_dir", help="ESM1b embeddings directory, containing embeddings as *.pt")
|
||||
parser.add_argument("--decoy_dir", default="/home/user/openfold/rosetta_decoy_set/", help="Rosetta decoy directory")
|
||||
parser.add_argument("--output_dir", default="/home/user/ofss_ranking_experiment/outputs/", help="Rosetta decoy directory")
|
||||
parser.add_argument("--openfold_checkpoint_path", help="Path to the OpenFold model checkpoint")
|
||||
parser.add_argument("--jax_param_path", help="Path to the JAX parameters checkpoint")
|
||||
parser.add_argument("--model_device", default="cpu", help="Device to run the model on")
|
||||
parser.add_argument("--tm_exec", default="/home/user/tmscore/TMscore", help="TMScore executable")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
sys.path.insert(0, args.of_dir)
|
||||
|
||||
|
||||
# openfold imports
|
||||
from openfold import config
|
||||
|
||||
from openfold.data import data_pipeline
|
||||
from openfold.data import feature_pipeline
|
||||
|
||||
from openfold.np import protein
|
||||
from openfold.np import residue_constants
|
||||
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
from openfold.utils.script_utils import load_models_from_command_line, run_model
|
||||
|
||||
|
||||
# helper functions
|
||||
|
||||
"""
|
||||
Read in a PDB file from a path
|
||||
"""
|
||||
def pdb_to_string(pdb_file):
|
||||
lines = []
|
||||
for line in open(pdb_file,"r"):
|
||||
if line[:6] == "HETATM" and line[17:20] == "MSE":
|
||||
line = "ATOM "+line[6:17]+"MET"+line[20:]
|
||||
if line[:4] == "ATOM":
|
||||
lines.append(line)
|
||||
return "".join(lines)
|
||||
|
||||
"""
|
||||
Compute aligned RMSD between two corresponding sets of points
|
||||
true -- set of reference points. Numpy array of dimension N x 3
|
||||
pred -- set of predicted points, Numpy array of dimension N x 3
|
||||
"""
|
||||
def jnp_rmsd(true, pred):
|
||||
def kabsch(P, Q):
|
||||
V, S, W = jnp.linalg.svd(P.T @ Q, full_matrices=False)
|
||||
flip = jax.nn.sigmoid(-10 * jnp.linalg.det(V) * jnp.linalg.det(W))
|
||||
S = flip * S.at[-1].set(-S[-1]) + (1-flip) * S
|
||||
V = flip * V.at[:,-1].set(-V[:,-1]) + (1-flip) * V
|
||||
return V@W
|
||||
p = true - true.mean(0,keepdims=True)
|
||||
q = pred - pred.mean(0,keepdims=True)
|
||||
p = p @ kabsch(p,q)
|
||||
loss = jnp.sqrt(jnp.square(p-q).sum(-1).mean() + 1e-8)
|
||||
return float(loss)
|
||||
|
||||
"""
|
||||
Create an OpenFold model runner
|
||||
name -- The name of the model to get the parameters from. Options: model_[1-5]
|
||||
"""
|
||||
def make_model_runner(name, recycles, args):
|
||||
cfg = config.model_config(name)
|
||||
|
||||
cfg.data.common.max_recycling_iters = recycles
|
||||
|
||||
if args.deterministic:
|
||||
cfg.data.eval.masked_msa_replace_fraction = 0.0
|
||||
cfg.data.predict.masked_msa_replace_fraction = 0.0
|
||||
|
||||
model_generator = load_models_from_command_line(cfg, args.model_device, args.openfold_checkpoint_path, args.jax_param_path, args.output_dir)
|
||||
model, _ = model_generator.__next__()
|
||||
|
||||
return model, cfg
|
||||
|
||||
"""
|
||||
Make a set of empty features for no-template evaluations
|
||||
"""
|
||||
def empty_placeholder_template_features(num_templates, num_res):
|
||||
return {
|
||||
'template_aatype': np.zeros((num_templates, num_res), dtype=np.int64),
|
||||
'template_all_atom_mask': np.zeros(
|
||||
(num_templates, num_res, residue_constants.atom_type_num),
|
||||
dtype=np.float32),
|
||||
'template_all_atom_positions': np.zeros(
|
||||
(num_templates, num_res, residue_constants.atom_type_num, 3),
|
||||
dtype=np.float32),
|
||||
'template_domain_names': np.zeros([num_templates], dtype=object),
|
||||
'template_sequence': np.zeros([num_templates], dtype=object),
|
||||
'template_sum_probs': np.zeros([num_templates, 1], dtype=np.float32),
|
||||
}
|
||||
|
||||
def make_embedding_features(args, label):
|
||||
seqemb_features = {}
|
||||
|
||||
path = os.path.join(args.esm_dir, label+'.pt')
|
||||
|
||||
# Load embedding file
|
||||
seqemb_data = torch.load(path)
|
||||
seqemb_features["seq_embedding"] = seqemb_data["representations"][33]
|
||||
|
||||
return seqemb_features
|
||||
|
||||
"""
|
||||
Create a feature dictionary for input to OpenFold
|
||||
runner - The model runner being invoked. Returned from `make_model_runner`
|
||||
sequence - The target sequence being predicted
|
||||
templates - The template features being added to the inputs
|
||||
seed - The random seed being used for data processing
|
||||
"""
|
||||
def make_processed_feature_dict(cfg, sequence, name="test", templates=None, seed=0):
|
||||
feature_dict = {}
|
||||
feature_dict.update(data_pipeline.make_sequence_features(sequence, name, len(sequence)))
|
||||
|
||||
msa = [[sequence]]
|
||||
deletion_matrix = [[[0 for _ in sequence]]]
|
||||
|
||||
feature_dict.update(data_pipeline.make_msa_features(msa, deletion_matrix))
|
||||
|
||||
if templates is not None:
|
||||
feature_dict.update(templates)
|
||||
else:
|
||||
feature_dict.update(empty_placeholder_template_features(num_templates=0, num_res=len(sequence)))
|
||||
|
||||
feature_dict.update(make_embedding_features(args, name.split('_')[0]))
|
||||
|
||||
feature_processor = feature_pipeline.FeaturePipeline(cfg.data)
|
||||
processed_feature_dict = feature_processor.process_features(feature_dict, mode='predict')
|
||||
processed_feature_dict = {
|
||||
k: torch.as_tensor(v, device=args.model_device)
|
||||
for k, v in processed_feature_dict.items()
|
||||
}
|
||||
|
||||
return processed_feature_dict
|
||||
|
||||
"""
|
||||
Package OpenFold's output into an easy-to-use dictionary
|
||||
prediction_result - output from running OpenFold on an input dictionary
|
||||
processed_feature_dict -- The dictionary passed to OpenFold as input. Returned by `make_processed_feature_dict`.
|
||||
"""
|
||||
def parse_results(prediction_result, processed_feature_dict):
|
||||
b_factors = prediction_result['plddt'][:,None] * prediction_result['final_atom_mask']
|
||||
|
||||
out = {"unrelaxed_protein": protein.from_prediction(processed_feature_dict, prediction_result, b_factors=b_factors),
|
||||
"plddt": prediction_result['plddt'],
|
||||
"pLDDT": prediction_result['plddt'].mean(),}
|
||||
|
||||
out.update({"pTMscore": prediction_result['predicted_tm_score']})
|
||||
|
||||
return out
|
||||
|
||||
|
||||
'''
|
||||
Function used to add C-Beta to glycine resides
|
||||
input: 3 coords (a,b,c), (L)ength, (A)ngle, and (D)ihedral
|
||||
output: 4th coord
|
||||
'''
|
||||
def extend(a,b,c, L,A,D):
|
||||
N = lambda x: x/np.sqrt(np.square(x).sum(-1,keepdims=True) + 1e-8)
|
||||
bc = N(b-c)
|
||||
n = N(np.cross(b-a, bc))
|
||||
m = [bc,np.cross(n,bc),n]
|
||||
d = [L*np.cos(A), L*np.sin(A)*np.cos(D), -L*np.sin(A)*np.sin(D)]
|
||||
return c + sum([m*d for m,d in zip(m,d)])
|
||||
|
||||
"""
|
||||
Ingest a decoy protein, pass it to OpenFold as a template, and extract the parsed output
|
||||
target_seq -- the sequence to be predicted
|
||||
decoy_prot -- the decoy structure to be injected as a template
|
||||
model_runner -- the model runner to execute
|
||||
name -- the name associated with this prediction
|
||||
"""
|
||||
def score_decoy(target_seq, decoy_prot, model_runner, name):
|
||||
decoy_seq_in = "".join([residue_constants.restypes[x] for x in decoy_prot.aatype]) # the sequence in the decoy PDB file
|
||||
|
||||
mismatch = False
|
||||
if decoy_seq_in == target_seq:
|
||||
assert jnp.all(decoy_prot.residue_index - 1 == np.arange(len(target_seq)))
|
||||
else: # case when template is missing some residues
|
||||
if args.verbose:
|
||||
print("Sequence mismatch: {}".format(name))
|
||||
mismatch=True
|
||||
|
||||
assert "".join(target_seq[i-1] for i in decoy_prot.residue_index) == decoy_seq_in
|
||||
|
||||
# use this to index into the template features
|
||||
template_idxs = decoy_prot.residue_index-1
|
||||
template_idx_set = set(template_idxs)
|
||||
|
||||
# The sequence associated with the decoy. Always has same length as target sequence.
|
||||
decoy_seq = args.seq_replacement*len(target_seq) if len(args.seq_replacement) == 1 else target_seq
|
||||
|
||||
# create empty template features
|
||||
pos = np.zeros([1,len(decoy_seq), 37, 3])
|
||||
atom_mask = np.zeros([1, len(decoy_seq), 37])
|
||||
|
||||
if args.mask_sidechains_add_cb:
|
||||
pos[0, template_idxs, :5] = decoy_prot.atom_positions[:,:5]
|
||||
|
||||
# residues where we have all of the key backbone atoms (N CA C)
|
||||
backbone_modelled = np.asarray(jnp.all(decoy_prot.atom_mask[:,[0,1,2]] == 1, axis=1))
|
||||
backbone_idx_set = set(decoy_prot.residue_index[backbone_modelled] - 1)
|
||||
|
||||
projected_cb = [i-1 for i,b,m in zip(decoy_prot.residue_index, backbone_modelled, decoy_prot.atom_mask) if m[3] == 0 and b]
|
||||
projected_cb_set = set(projected_cb)
|
||||
gly_idx = [i for i,a in enumerate(target_seq) if a == "G"]
|
||||
assert all([k in projected_cb_set for k in gly_idx if k in template_idx_set and k in backbone_idx_set]) # make sure we are adding CBs to all of the glycines
|
||||
|
||||
cbs = np.array([extend(c,n,ca, 1.522, 1.927, -2.143) for c, n ,ca in zip(pos[0,:,2], pos[0,:,0], pos[0,:,1])])
|
||||
|
||||
pos[0, projected_cb, 3] = cbs[projected_cb]
|
||||
atom_mask[0, template_idxs, :5] = decoy_prot.atom_mask[:, :5]
|
||||
atom_mask[0, projected_cb, 3] = 1
|
||||
|
||||
template = {"template_aatype":residue_constants.sequence_to_onehot(decoy_seq, residue_constants.HHBLITS_AA_TO_ID)[None],
|
||||
"template_all_atom_mask": atom_mask.astype(np.float32),
|
||||
"template_all_atom_positions":pos.astype(np.float32),
|
||||
"template_domain_names":np.asarray(["None"])}
|
||||
elif args.mask_sidechains:
|
||||
pos[0, template_idxs, :5] = decoy_prot.atom_positions[:,:5]
|
||||
atom_mask[0, template_idxs, :5] = decoy_prot.atom_mask[:,:5]
|
||||
|
||||
template = {"template_aatype":residue_constants.sequence_to_onehot(decoy_seq, residue_constants.HHBLITS_AA_TO_ID)[None],
|
||||
"template_all_atom_mask": atom_mask.astype(np.float32),
|
||||
"template_all_atom_positions": pos.astype(np.float32),
|
||||
"template_domain_names":np.asarray(["None"])}
|
||||
else:
|
||||
pos[0, template_idxs] = decoy_prot.atom_positions
|
||||
atom_mask[0, template_idxs] = decoy_prot.atom_mask
|
||||
|
||||
template = {"template_aatype":residue_constants.sequence_to_onehot(decoy_seq, residue_constants.HHBLITS_AA_TO_ID)[None],
|
||||
"template_all_atom_mask":decoy_prot.atom_mask[None].astype(np.float32),
|
||||
"template_all_atom_positions":decoy_prot.atom_positions[None].astype(np.float32),
|
||||
"template_domain_names":np.asarray(["None"])}
|
||||
|
||||
features = make_processed_feature_dict(cfg, target_seq, name=name, templates=template, seed=args.seed)
|
||||
#with open(os.path.join(args.output_dir, name + '_features.pt'), 'wb') as outfile:
|
||||
# torch.save(features, outfile)
|
||||
working_batch = deepcopy(features)
|
||||
out, inference_time = run_model(model_runner, working_batch, name, args.output_dir)
|
||||
print(f"{name} done. Inference time: ", inference_time)
|
||||
working_batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), working_batch)
|
||||
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
|
||||
result = parse_results(out, working_batch)
|
||||
return result, mismatch
|
||||
|
||||
|
||||
tm_re = re.compile(r'TM-score[\s]*=[\s]*(\d.\d+)')
|
||||
ref_len_re = re.compile(r'Length=[\s]*(\d+)[\s]*\(by which all scores are normalized\)')
|
||||
common_re = re.compile(r'Number of residues in common=[\s]*(\d+)')
|
||||
super_re = re.compile(r'\(":" denotes the residue pairs of distance < 5\.0 Angstrom\)\\n([A-Z\-]+)\\n[" ", :]+\\n([A-Z\-]+)\\n')
|
||||
|
||||
"""
|
||||
Compute TM Scores between two PDBs and parse outputs
|
||||
pdb_pred -- The path to the predicted PDB
|
||||
pdb_native -- The path to the native PDB
|
||||
test_len -- run asserts that the input and output should have the same length
|
||||
"""
|
||||
def compute_tmscore(pdb_pred, pdb_native, test_len=True):
|
||||
cmd = ([args.tm_exec, pdb_pred, pdb_native])
|
||||
tmscore_output = str(subprocess.check_output(cmd))
|
||||
try:
|
||||
tm_out = float(tm_re.search(tmscore_output).group(1))
|
||||
reflen = int(ref_len_re.search(tmscore_output).group(1))
|
||||
common = int(common_re.search(tmscore_output).group(1))
|
||||
|
||||
seq1 = super_re.search(tmscore_output).group(1)
|
||||
seq2 = super_re.search(tmscore_output).group(1)
|
||||
except Exception as e:
|
||||
print("Failed on: " + " ".join(cmd))
|
||||
raise e
|
||||
|
||||
if test_len:
|
||||
assert reflen == common, cmd
|
||||
assert seq1 == seq2, cmd
|
||||
assert len(seq1) == reflen, cmd
|
||||
|
||||
return tm_out
|
||||
|
||||
# Simple wrapper for keeping track of the information associated with each decoy.
|
||||
decoy_fields_list = ['target', 'decoy_id', 'decoy_path', 'rmsd', 'rosettascore', 'gdt_ts', 'tmscore', 'danscore']
|
||||
Decoy = namedtuple("Decoy", decoy_fields_list)
|
||||
|
||||
|
||||
# headers for csv outputs
|
||||
csv_headers = decoy_fields_list + ['output_path', 'rmsd_out', 'tm_diff', 'tm_out', 'plddt', 'ptm']
|
||||
|
||||
def write_results(decoy, af_result, prot_native=None, pdb_native=None, mismatch=False):
|
||||
plddt = float(af_result['pLDDT'])
|
||||
if "pTMscore" not in af_result:
|
||||
ptm = -1
|
||||
else:
|
||||
ptm = float(af_result["pTMscore"])
|
||||
if prot_native is None:
|
||||
rms_out = -1
|
||||
else:
|
||||
rms_out = jnp_rmsd(jnp.asarray(prot_native.atom_positions[:,1,:]), jnp.asarray(af_result['unrelaxed_protein'].atom_positions[:,1,:]))
|
||||
|
||||
pdb_lines = protein.to_pdb(af_result["unrelaxed_protein"])
|
||||
pdb_out_path = args.output_dir + args.name + "/pdbs/" + decoy.target + "_" + decoy.decoy_id
|
||||
with open(pdb_out_path, 'w') as f:
|
||||
f.write(pdb_lines)
|
||||
|
||||
if decoy.decoy_id != "none.pdb":
|
||||
tm_diff = compute_tmscore(decoy.decoy_path, pdb_out_path, test_len = not mismatch)
|
||||
else:
|
||||
tm_diff = -1
|
||||
|
||||
if pdb_native is None:
|
||||
tm_out = -1
|
||||
else:
|
||||
tm_out = compute_tmscore(pdb_out_path, pdb_native)
|
||||
|
||||
if not os.path.exists(args.output_dir + args.name + "/results/results_{}.csv".format(decoy.target)):
|
||||
with open(args.output_dir + args.name + "/results/results_{}.csv".format(decoy.target), "w") as f:
|
||||
f.write(",".join(csv_headers) + "\n")
|
||||
|
||||
|
||||
with open(args.output_dir + args.name + "/results/results_{}.csv".format(decoy.target), "a") as f:
|
||||
result_fields = [str(x) for x in list(decoy) + [pdb_out_path, rms_out, tm_diff, tm_out, plddt, ptm]]
|
||||
f.write(",".join(result_fields) + "\n")
|
||||
|
||||
if args.verbose:
|
||||
print(",".join([x + "=" + y for x,y in zip(csv_headers, result_fields)]))
|
||||
|
||||
|
||||
# create all of the output directories
|
||||
os.makedirs(args.output_dir + args.name, exist_ok=True)
|
||||
os.makedirs(args.output_dir + args.name + "/pdbs", exist_ok=True)
|
||||
os.makedirs(args.output_dir + args.name + "/results", exist_ok=True)
|
||||
|
||||
if len(args.targets_file) > 0:
|
||||
natives_list = open(args.targets_file, 'r').read().split("\n")[:-1]
|
||||
else:
|
||||
natives_list = args.target_list
|
||||
|
||||
|
||||
finished_decoys = []
|
||||
for n in natives_list:
|
||||
if os.path.exists(args.output_dir + args.name + "/results/results_{}.csv".format(n)):
|
||||
finished_decoys += [x.split(",")[0] + "_" + x.split(",")[1] for x in open(args.output_dir + args.name + "/results/results_{}.csv".format(n), "r").readlines()]
|
||||
finished_decoys = set(finished_decoys)
|
||||
|
||||
|
||||
if os.path.exists(args.output_dir + args.name + "/finished_targets.txt"):
|
||||
finished_targets = set(open(args.output_dir + args.name + "/finished_targets.txt", 'r').read().split("\n")[:-1])
|
||||
else:
|
||||
finished_targets = []
|
||||
|
||||
|
||||
# info of the form "target decoy_id"
|
||||
decoy_list = [x.split() for x in open(args.decoy_dir + "decoy_list.txt", 'r').read().split("\n")[:-1]]
|
||||
|
||||
# parse all of the information about the decoys
|
||||
decoy_data = {}
|
||||
for field in decoy_fields_list[2:]:
|
||||
if os.path.exists(args.decoy_dir + field + ".txt"):
|
||||
lines = [x.split() for x in open(args.decoy_dir + field + ".txt", 'r').read().split("\n")[:-1]] # form "target decoy_id metric value"
|
||||
|
||||
# make sure everything is in the same order
|
||||
for i,l in enumerate(lines):
|
||||
assert l[0] == decoy_list[i][0]
|
||||
assert l[1] == decoy_list[i][1]
|
||||
|
||||
decoy_data[field] = [l[-1] for l in lines]
|
||||
else:
|
||||
decoy_data[field] = [-1]*len(decoy_list) # -1 as a placeholder
|
||||
|
||||
decoy_dict = {n : [] for n in natives_list if n not in finished_targets} # key = target name, value = list of Decoy objects
|
||||
|
||||
for i, d in enumerate(decoy_list):
|
||||
|
||||
decoy = Decoy(target=d[0], decoy_id=d[1], decoy_path=args.decoy_dir + "decoys/" + d[0] + "/" + d[1],
|
||||
rmsd = decoy_data["rmsd"][i], rosettascore = decoy_data["rosettascore"][i], gdt_ts = decoy_data["gdt_ts"][i],
|
||||
tmscore=decoy_data["tmscore"][i], danscore = decoy_data["danscore"][i])
|
||||
|
||||
if decoy.target in decoy_dict and decoy.target + "_" + decoy.decoy_id not in finished_decoys:
|
||||
decoy_dict[decoy.target].append(decoy)
|
||||
|
||||
# add another decoy entry for the native structure
|
||||
if args.use_native:
|
||||
for n in decoy_dict.keys():
|
||||
if n + "_native" not in finished_decoys:
|
||||
decoy_dict[n].insert(0, Decoy(target=n, decoy_id="native.pdb", decoy_path=args.decoy_dir + "natives/" + n + ".pdb",
|
||||
rmsd = 0, rosettascore = -1, gdt_ts = 1, tmscore = 1, danscore = -1))
|
||||
|
||||
if args.verbose:
|
||||
print(finished_decoys)
|
||||
|
||||
model_name = args.model_name
|
||||
results_key = model_name + "_seed_{}".format(args.seed)
|
||||
for n in natives_list:
|
||||
try:
|
||||
pdb_native = args.decoy_dir + "natives/" + n + ".pdb"
|
||||
prot_native = protein.from_pdb_string(pdb_to_string(pdb_native))
|
||||
seq_native = "".join([residue_constants.restypes[x] for x in prot_native.aatype])
|
||||
runner, cfg = make_model_runner(model_name, args.recycles, args)
|
||||
|
||||
if n + "_none.pdb" not in finished_decoys:
|
||||
|
||||
# run the model with no templates
|
||||
features = make_processed_feature_dict(cfg, seq_native, name=n + "_none", seed=args.seed)
|
||||
working_batch = deepcopy(features)
|
||||
out, inference_time = run_model(runner, working_batch, n + "_none", args.output_dir)
|
||||
print(f"{n}_none done. Inference time: ", inference_time)
|
||||
working_batch = tensor_tree_map(lambda x: np.array(x[..., -1].cpu()), working_batch)
|
||||
out = tensor_tree_map(lambda x: np.array(x.cpu()), out)
|
||||
result = parse_results(out, working_batch)
|
||||
|
||||
dummy_decoy = Decoy(target=n, decoy_id="none.pdb", decoy_path="_", rmsd=-1, rosettascore=-1, gdt_ts=-1, tmscore=-1,danscore=-1)
|
||||
write_results(dummy_decoy, result, prot_native=prot_native if args.use_native else None, pdb_native=pdb_native if args.use_native else None)
|
||||
|
||||
|
||||
# run the model with all of the decoys passed as templates
|
||||
for d in decoy_dict[n]:
|
||||
prot = protein.from_pdb_string(pdb_to_string(d.decoy_path))
|
||||
result, mismatch = score_decoy(seq_native, prot, runner, d.target + "_" + d.decoy_id)
|
||||
write_results(d, result, prot_native=prot_native if args.use_native else None, pdb_native=pdb_native if args.use_native else None, mismatch=mismatch)
|
||||
|
||||
|
||||
with open(args.output_dir + args.name + "/finished_targets.txt", 'a') as f:
|
||||
f.write(n + "\n")
|
||||
except AssertionError as ae:
|
||||
print(f"AssertionError encountered while processing a decoy of native {n}")
|
||||
traceback.print_exc()
|
||||
except Exception as e:
|
||||
print(f"Exception encountered while processing a decoy of native {n}")
|
||||
traceback.print_exc()
|
||||
@@ -1,14 +1,5 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/aqlaboratory/openfold/blob/main/notebooks/OpenFold.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
@@ -136,7 +127,7 @@
|
||||
"\n",
|
||||
" %shell mkdir -p /content/openfold/openfold/resources\n",
|
||||
"\n",
|
||||
" commit = \"e2e19f16676b1a409f9ba3a6f69b11ee7f5887c2\"\n",
|
||||
" commit = \"a96ffd67f8c96f8c4decc3abdd2cffbb57fc5764\"\n",
|
||||
" os.system(f\"pip install -q git+https://github.com/aqlaboratory/openfold.git@{commit}\")\n",
|
||||
"\n",
|
||||
" os.system(f\"cp -f -p /content/stereo_chemical_props.txt /usr/local/lib/python{python_version}/site-packages/openfold/resources/\")\n",
|
||||
@@ -259,7 +250,7 @@
|
||||
"from openfold.np import protein\n",
|
||||
"from openfold.np.relax import relax\n",
|
||||
"from openfold.np.relax.utils import overwrite_b_factors\n",
|
||||
"from openfold.utils.import_weights import import_jax_weights_\n",
|
||||
"from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_\n",
|
||||
"from openfold.utils.tensor_utils import tensor_tree_map\n",
|
||||
"\n",
|
||||
"from IPython import display\n",
|
||||
@@ -582,7 +573,7 @@
|
||||
" model_name,\n",
|
||||
" )\n",
|
||||
" d = torch.load(params_name)\n",
|
||||
" openfold_model.load_state_dict(d)\n",
|
||||
" import_openfold_weights_(model=openfold_model, state_dict=d)\n",
|
||||
" else:\n",
|
||||
" raise ValueError(f\"Invalid weight set: {weight_set}\")\n",
|
||||
"\n",
|
||||
|
||||
@@ -62,7 +62,8 @@ def model_config(
|
||||
name,
|
||||
train=False,
|
||||
low_prec=False,
|
||||
long_sequence_inference=False
|
||||
long_sequence_inference=False,
|
||||
use_deepspeed_evoformer_attention=False,
|
||||
):
|
||||
c = copy.deepcopy(config)
|
||||
# TRAINING PRESETS
|
||||
@@ -237,6 +238,9 @@ def model_config(
|
||||
c.model.extra_msa.extra_msa_stack.tune_chunk_size = False
|
||||
c.model.evoformer_stack.tune_chunk_size = False
|
||||
|
||||
if use_deepspeed_evoformer_attention:
|
||||
c.globals.use_deepspeed_evo_attention = True
|
||||
|
||||
if train:
|
||||
c.globals.blocks_per_ckpt = 1
|
||||
c.globals.chunk_size = None
|
||||
|
||||
@@ -1053,7 +1053,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
|
||||
def val_dataloader(self):
|
||||
if self.eval_dataset is not None:
|
||||
return self._gen_dataloader("eval")
|
||||
# Temp fix to pass the validation step
|
||||
return []
|
||||
|
||||
def predict_dataloader(self):
|
||||
|
||||
@@ -24,7 +24,7 @@ import os
|
||||
from typing import Any, Mapping, Optional, Sequence, Tuple
|
||||
|
||||
from Bio import PDB
|
||||
from Bio.Data import SCOPData
|
||||
from Bio.Data import PDBData
|
||||
import numpy as np
|
||||
|
||||
from openfold.data.errors import MultipleChainsError
|
||||
@@ -283,7 +283,7 @@ def parse(
|
||||
author_chain = mmcif_to_author_chain_id[chain_id]
|
||||
seq = []
|
||||
for monomer in seq_info:
|
||||
code = SCOPData.protein_letters_3to1.get(monomer.id, "X")
|
||||
code = PDBData.protein_letters_3to1.get(monomer.id, "X")
|
||||
seq.append(code if len(code) == 1 else "X")
|
||||
seq = "".join(seq)
|
||||
author_chain_to_sequence[author_chain] = seq
|
||||
@@ -347,6 +347,7 @@ def _get_header(parsed_info: MmCIFDict) -> PdbHeader:
|
||||
try:
|
||||
raw_resolution = parsed_info[res_key][0]
|
||||
header["resolution"] = float(raw_resolution)
|
||||
break
|
||||
except ValueError:
|
||||
logging.debug(
|
||||
"Invalid resolution format: %s", parsed_info[res_key]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import logging
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
@@ -681,15 +682,18 @@ def convert_deprecated_v1_keys(state_dict):
|
||||
}
|
||||
|
||||
convert_key_re = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
|
||||
template_emb_re = re.compile(r"^((module\.)?(model\.)?)(template(?!_embedder).*)")
|
||||
|
||||
converted_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# For each match, look-up replacement value in the dictionary
|
||||
new_key = convert_key_re.sub(lambda m: replacements[m.group()], key)
|
||||
new_key = convert_key_re.sub(lambda m: replacements[m.group(1)], key)
|
||||
|
||||
# Add prefix for template modules
|
||||
if new_key.startswith('template'):
|
||||
new_key = f'template_embedder.{new_key}'
|
||||
# Add prefix for template layers
|
||||
template_match = re.match(template_emb_re, new_key)
|
||||
if template_match:
|
||||
prefix = template_match.group(1)
|
||||
new_key = f'{prefix if prefix else ""}template_embedder.{template_match.group(4)}'
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import random
|
||||
import torch
|
||||
|
||||
from typing import Tuple, List, Dict
|
||||
from openfold.np import residue_constants as rc
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -13,6 +13,17 @@ def compute_rmsd(
|
||||
atom_mask: torch.Tensor = None,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Function to calculate RMSD between predicted and ground truth atom position
|
||||
|
||||
Args:
|
||||
true_atom_pos: a [nres*3] tensor
|
||||
pred_atom_pos: a [nres*3] tensor
|
||||
atom_mask: a [1*nres] tensor
|
||||
|
||||
Return:
|
||||
RMSD value between true and predicted atom positions
|
||||
"""
|
||||
sq_diff = torch.square(true_atom_pos - pred_atom_pos).sum(dim=-1, keepdim=False)
|
||||
if atom_mask is not None:
|
||||
sq_diff = torch.masked_select(sq_diff, atom_mask.to(sq_diff.device))
|
||||
@@ -21,7 +32,7 @@ def compute_rmsd(
|
||||
return torch.sqrt(msd + eps) # prevent sqrt 0
|
||||
|
||||
|
||||
def kabsch_rotation(P, Q):
|
||||
def kabsch_rotation(P: torch.Tensor, Q: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate the best rotation that minimises the RMSD between P and Q.
|
||||
|
||||
@@ -29,11 +40,11 @@ def kabsch_rotation(P, Q):
|
||||
https://en.wikipedia.org/wiki/Kabsch_algorithm
|
||||
|
||||
Args:
|
||||
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
|
||||
Q: [N * 3] the same dimension as P
|
||||
P: [N * 3] Nres is the number of atoms and each row corresponds to the atom's x,y,z coordinates
|
||||
Q: [N * 3] the same dimension as P
|
||||
|
||||
return:
|
||||
A 3*3 rotation matrix
|
||||
one 3*3 rotation matrix that best aligns the sorce and target atoms
|
||||
"""
|
||||
assert P.shape == torch.Size([Q.shape[0], Q.shape[1]])
|
||||
|
||||
@@ -54,11 +65,20 @@ def get_optimal_transform(
|
||||
src_atoms: torch.Tensor,
|
||||
tgt_atoms: torch.Tensor,
|
||||
mask: torch.Tensor = None,
|
||||
):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
src_atoms: predicted CA positions, shape:[num_res,3]
|
||||
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
|
||||
mask: a vector of boolean values, shape:[num_res]
|
||||
A function that obtain the transformation that optimally align
|
||||
src_atoms with tgt_atoms
|
||||
|
||||
Args:
|
||||
src_atoms: predicted CA positions, shape:[num_res,3]
|
||||
tgt_atoms: ground-truth CA positions, shape:[num_res,3]
|
||||
mask: a vector of boolean values, shape:[num_res]
|
||||
|
||||
Returns:
|
||||
a rotation matrix that record the optimal rotation
|
||||
that will best align selected anchor prediction to selected anchor truth
|
||||
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
|
||||
"""
|
||||
assert src_atoms.shape == tgt_atoms.shape, (src_atoms.shape, tgt_atoms.shape)
|
||||
assert src_atoms.shape[-1] == 3
|
||||
@@ -88,7 +108,7 @@ def get_optimal_transform(
|
||||
return r, x
|
||||
|
||||
|
||||
def get_least_asym_entity_or_longest_length(batch, input_asym_id):
|
||||
def get_least_asym_entity_or_longest_length(batch: dict, input_asym_id: list) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""
|
||||
First check how many subunit(s) one sequence has. Select the subunit that is less
|
||||
common, e.g. if the protein was AABBB then select one of the A as anchor
|
||||
@@ -97,15 +117,15 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
|
||||
then choose one of the corresponding subunits as anchor
|
||||
|
||||
Args:
|
||||
batch: in this function batch is the full ground truth features
|
||||
input_asym_id: A list of asym_ids that are in the cropped input features
|
||||
batch: in this function batch is the full ground truth features
|
||||
input_asym_id: A list of asym_ids that are in the cropped input features
|
||||
|
||||
Return:
|
||||
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
|
||||
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
|
||||
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
|
||||
anchor_pred_asym_ids: list(Tensor(int)) a list of all possible pred anchor candidates
|
||||
"""
|
||||
entity_2_asym_list = get_entity_2_asym_list(batch)
|
||||
unique_entity_ids = torch.unique(batch["entity_id"])
|
||||
unique_entity_ids = [i for i in torch.unique(batch["entity_id"]) if i !=0]# if entity_id is 0, that means this entity_id comes from padding
|
||||
entity_asym_count = {}
|
||||
entity_length = {}
|
||||
|
||||
@@ -145,19 +165,38 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
|
||||
|
||||
|
||||
def greedy_align(
|
||||
batch,
|
||||
per_asym_residue_index,
|
||||
entity_2_asym_list,
|
||||
pred_ca_pos,
|
||||
pred_ca_mask,
|
||||
true_ca_poses,
|
||||
true_ca_masks,
|
||||
):
|
||||
batch: dict,
|
||||
per_asym_residue_index: dict,
|
||||
entity_2_asym_list: dict,
|
||||
pred_ca_pos: torch.Tensor,
|
||||
pred_ca_mask: torch.Tensor,
|
||||
true_ca_poses: list,
|
||||
true_ca_masks: list
|
||||
) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
Implement Algorithm 4 in the Supplementary Information of AlphaFold-Multimer paper:
|
||||
Evans,R et al., 2022 Protein complex prediction with AlphaFold-Multimer, bioRxiv 2021.10.04.463034; doi: https://doi.org/10.1101/2021.10.04.463034
|
||||
|
||||
Args:
|
||||
batch: a dictionary of ground truth features
|
||||
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
|
||||
entity_2_asym_list: a dictionary recording which asym_id(s) belong to which entity_id
|
||||
pred_ca_pos: predicted positions of c-alpha atoms from the results of model.forward()
|
||||
pred_ca_mask: a boolean tensor that masks pred_ca_pos
|
||||
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
|
||||
true_ca_masks: a list of tensors, corresponding to the masks of c-alpha positions of the ground truth structure. If there are 5 chains, this list will have a length of 5
|
||||
|
||||
Return:
|
||||
A list of tuple(int,int) that provides instructions of how the ground truth chains should be permuated
|
||||
e.g. if 3 chains in the imput model have the same sequences, an example return would be:
|
||||
[(0,2),(1,1),(2,0)], meaning the 1st chain in the predicted structure should be aligned to the 3rd chain in the ground truth,
|
||||
and the 2nd chain in the predicted structure is ok to stay with the 2nd chain in the ground truth.
|
||||
|
||||
Note: the tuples in the returned list begin with 0 indexing but aym_id begins with 1. The reason why tuples in the return are 0-indexing
|
||||
is that at the stage of loss calculation, the ground truth atom positions: true_ca_poses, are already split up into a list of matrices.
|
||||
Hence, now this function needs to return tuples that provide the index to select from the list: true_ca_poses, and list index starts from 0.
|
||||
"""
|
||||
used = [False for _ in range(len(true_ca_poses))]
|
||||
used = [False for _ in range(len(true_ca_poses))] # a list the keeps recording whether a ground truth chain has been used or not
|
||||
align = []
|
||||
unique_asym_ids = [i for i in torch.unique(batch["asym_id"]) if i != 0]
|
||||
for cur_asym_id in unique_asym_ids:
|
||||
@@ -189,21 +228,38 @@ def greedy_align(
|
||||
return align
|
||||
|
||||
|
||||
def pad_features(feature_tensor, nres_pad, pad_dim):
|
||||
"""Pad input feature tensor"""
|
||||
def pad_features(feature_tensor: torch.Tensor, nres_pad: int, pad_dim: int) -> torch.Tensor:
|
||||
"""
|
||||
Pad input feature tensor. Padding values will be 0 and put behind the true feature values
|
||||
|
||||
Args:
|
||||
feature_tensor: A feature tensor
|
||||
nres_pad: number of residues to add
|
||||
pad_dim: along which dimension of the feature_tensor to pad
|
||||
|
||||
Returns:
|
||||
a padded feature tensor
|
||||
"""
|
||||
pad_shape = list(feature_tensor.shape)
|
||||
pad_shape[pad_dim] = nres_pad
|
||||
padding_tensor = feature_tensor.new_zeros(pad_shape, device=feature_tensor.device)
|
||||
return torch.concat((feature_tensor, padding_tensor), dim=pad_dim)
|
||||
|
||||
|
||||
def merge_labels(per_asym_residue_index, labels, align, original_nres):
|
||||
def merge_labels(per_asym_residue_index: Dict[int,List[int]],
|
||||
labels: List[Dict], align: List[Tuple[int, int]],
|
||||
original_nres: int) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Merge ground truth labels according to the permutation results
|
||||
|
||||
labels: list of original ground truth feats
|
||||
align: list of tuples, each entry specify the corresponding label of the asym.
|
||||
Args:
|
||||
per_asym_residue_index: a dictionary recording which residues belong to which aysm_id
|
||||
labels: list of original ground truth feats e.g. if there're 5 chains, labels will have a length of 5
|
||||
align: list of tuples, each entry specify the corresponding label of the asym.
|
||||
original_nres: int, corresponding to the number of residues specified by crop_size in config.py
|
||||
|
||||
Returns:
|
||||
A new dictionary of permuated ground truth features
|
||||
modified based on UniFold:
|
||||
https://github.com/dptech-corp/Uni-Fold/blob/b1c89a2cebd4e4ee4c47b4e443f92beeb9138fbb/unifold/losses/chain_align.py#L176C1-L176C1
|
||||
"""
|
||||
@@ -230,13 +286,20 @@ def merge_labels(per_asym_residue_index, labels, align, original_nres):
|
||||
return outs
|
||||
|
||||
|
||||
def split_ground_truth_labels(gt_features):
|
||||
def split_ground_truth_labels(gt_features: dict) -> List[Dict]:
|
||||
"""
|
||||
Splits ground truth features according to chains
|
||||
|
||||
Args:
|
||||
gt_features: A dictionary within a the PyTorch DataSet iteration, which returns by the upstream DataLoader.iter() method
|
||||
In the DataLoader pipeline, all tensors belonging to all the ground truth changes are concatenated so it stays the same as monomer data input format/pipeline,
|
||||
thus, this function is needed to 1) detect the number of chains i.e. unique(asym_id)
|
||||
2) split the concatenated tensors back to individual ones that correspond to individual asym_ids
|
||||
|
||||
Returns:
|
||||
a list of feature dictionaries with only necessary ground truth features
|
||||
required to finish multi-chain permutation
|
||||
a list of feature dictionaries with only necessary ground truth features
|
||||
required to finish multi-chain permutation, e.g. it will be a list of 5 elements if there
|
||||
are 5 chains in total.
|
||||
"""
|
||||
unique_asym_ids, asym_id_counts = torch.unique(gt_features["asym_id"], sorted=True, return_counts=True)
|
||||
n_res = gt_features["asym_id"].shape[-1]
|
||||
@@ -251,7 +314,16 @@ def split_ground_truth_labels(gt_features):
|
||||
return labels
|
||||
|
||||
|
||||
def get_per_asym_residue_index(features):
|
||||
def get_per_asym_residue_index(features: dict) -> Dict[int, torch.Tensor]:
|
||||
"""
|
||||
A function that retrieve which residues belong to which asym_id
|
||||
|
||||
Args:
|
||||
features: a dictionary that contains input features after cropping
|
||||
|
||||
Returns:
|
||||
A dictionary that records which region of the sequence belongs to which asym_id
|
||||
"""
|
||||
unique_asym_ids = [i for i in torch.unique(features["asym_id"]) if i != 0]
|
||||
per_asym_residue_index = {}
|
||||
for cur_asym_id in unique_asym_ids:
|
||||
@@ -261,34 +333,36 @@ def get_per_asym_residue_index(features):
|
||||
return per_asym_residue_index
|
||||
|
||||
|
||||
def get_entity_2_asym_list(batch):
|
||||
def get_entity_2_asym_list(features: dict) -> Dict[int, list]:
|
||||
"""
|
||||
Generates a dictionary mapping unique entity IDs to lists of unique asymmetry IDs (asym_id) for each entity.
|
||||
|
||||
Args:
|
||||
batch (dict): A dictionary containing data batches, including "entity_id" and "asym_id" tensors.
|
||||
features (dict): A dictionary containing data features, including "entity_id" and "asym_id" tensors.
|
||||
|
||||
Returns:
|
||||
entity_2_asym_list (dict): A dictionary where keys are unique entity IDs, and values are lists of unique asymmetry IDs
|
||||
associated with each entity.
|
||||
"""
|
||||
entity_2_asym_list = {}
|
||||
unique_entity_ids = torch.unique(batch["entity_id"])
|
||||
unique_entity_ids = torch.unique(features["entity_id"])
|
||||
for cur_ent_id in unique_entity_ids:
|
||||
ent_mask = batch["entity_id"] == cur_ent_id
|
||||
cur_asym_id = torch.unique(batch["asym_id"][ent_mask])
|
||||
ent_mask = features["entity_id"] == cur_ent_id
|
||||
cur_asym_id = torch.unique(features["asym_id"][ent_mask])
|
||||
entity_2_asym_list[int(cur_ent_id)] = cur_asym_id
|
||||
return entity_2_asym_list
|
||||
|
||||
|
||||
def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
|
||||
asym_mask, pred_ca_mask):
|
||||
def calculate_input_mask(true_ca_masks: List[torch.Tensor], anchor_gt_idx: torch.Tensor,
|
||||
anchor_gt_residue: torch.Tensor,
|
||||
asym_mask: torch.Tensor, pred_ca_mask: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Calculate an input mask for downstream optimal transformation computation
|
||||
|
||||
Args:
|
||||
true_ca_masks (Tensor): ca mask from ground truth.
|
||||
anchor_gt_idx (Tensor): The index of selected ground truth anchor.
|
||||
true_ca_masks: list of masks from ground truth chains.
|
||||
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
|
||||
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
|
||||
asym_mask (Tensor): Boolean tensor indicating which regions are selected predicted anchor.
|
||||
pred_ca_mask (Tensor): ca mask from predicted structure.
|
||||
|
||||
@@ -303,11 +377,38 @@ def calculate_input_mask(true_ca_masks, anchor_gt_idx, anchor_gt_residue,
|
||||
return input_mask
|
||||
|
||||
|
||||
def calculate_optimal_transform(true_ca_poses,
|
||||
anchor_gt_idx, anchor_gt_residue,
|
||||
true_ca_masks, pred_ca_mask,
|
||||
asym_mask,
|
||||
pred_ca_pos):
|
||||
def calculate_optimal_transform(true_ca_poses: List[torch.Tensor],
|
||||
anchor_gt_idx: int, anchor_gt_residue: torch.Tensor,
|
||||
true_ca_masks: List[torch.Tensor], pred_ca_mask: torch.Tensor,
|
||||
asym_mask: torch.Tensor,
|
||||
pred_ca_pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
"""
|
||||
Takes selected anchor ground truth c-alpha positions and
|
||||
selected predicted anchor c-alpha position then calculate the optimal rotation matrix
|
||||
to align ground-truth anchor and predicted anchor
|
||||
|
||||
Args:
|
||||
true_ca_poses: a list of tensors, corresponding to the c-alpha positions of the ground truth structure. e.g. If there are 5 chains, this list will have a length of 5
|
||||
anchor_gt_idx (Tensor): a tensor with one integer in it. The index of selected ground truth anchor.
|
||||
anchor_gt_residue:a 1D vector tensor of residue indexes that belongs to the selected ground truth anchor
|
||||
true_ca_masks: list of masks from ground truth chains e.g. it will be length=5 if there are 5 chains in ground truth structure
|
||||
pred_ca_mask: A boolean tensor corresponds to the mask to mask the predicted features
|
||||
asym_mask: A boolean tensor that mask out other elements in a tensor if they do not belong to a this asym_id
|
||||
pred_ca_pos: a [nres*3] tensor of predicted c-alpha atom positions
|
||||
|
||||
Process:
|
||||
1) select an achor chain from ground truth, denoted by anchor_gt_idx, and
|
||||
an chor chain from the predicted structure. Both anchor_gt and anchor_pred have exactly the same sequence
|
||||
2) obtain the C-alpha positions corresponding to the selected anchor_gt, done be slicing the true_ca_pose according to anchor_gt_residue
|
||||
3) calculate the optimal transformation that can best align the C-alpha atoms of anchor_pred to those of anchor_gt,
|
||||
done by Kabsch algorithm: source https://en.wikipedia.org/wiki/Kabsch_algorithm
|
||||
|
||||
Returns:
|
||||
a rotation matrix that record the optimal rotation
|
||||
that will best align selected anchor prediction to selected anchor truth
|
||||
a matrix records how the atoms should be shifted after applying r i.e. optimal alignment requires 1) rotate 2) shift the positions
|
||||
"""
|
||||
input_mask = calculate_input_mask(true_ca_masks,
|
||||
anchor_gt_idx,
|
||||
anchor_gt_residue,
|
||||
@@ -326,13 +427,27 @@ def calculate_optimal_transform(true_ca_poses,
|
||||
return r, x
|
||||
|
||||
|
||||
def compute_permutation_alignment(out, features, ground_truth):
|
||||
def compute_permutation_alignment(out: Dict[str,torch.Tensor],
|
||||
features: Dict[str,torch.Tensor],
|
||||
ground_truth: List[Dict[str, torch.Tensor]]) -> Tuple[List[Tuple[int, int]], Dict[int, List[int]]]:
|
||||
"""
|
||||
A class method that first permutate chains in ground truth first
|
||||
before calculating the loss.
|
||||
A method that permutes chains in ground truth before calculating the loss
|
||||
because the mapping between the predicted and ground-truth will become arbitrary.
|
||||
The model cannot be assumed to predict chains in the same order as the ground truth.
|
||||
Thus, this function pick the optimal permutaion of predicted chains that best matches the ground truth,
|
||||
by minimising the RMSD i.e. the best permutation of ground truth chains is selected based on which permutation has the lowest RMSD calculation
|
||||
|
||||
Details are described in Section 7.3 in the Supplementary of AlphaFold-Multimer paper:
|
||||
https://www.biorxiv.org/content/10.1101/2021.10.04.463034v2
|
||||
|
||||
Args:
|
||||
out: a dictionary of output tensors from model.forward()
|
||||
features: a dictionary of feature tensors that are used as input for model.forward()
|
||||
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
|
||||
|
||||
Returns:
|
||||
a list of tuple(int,int) that instructs how ground truth chains should be permutated
|
||||
a dictionary recording which residues belong to which aysm_id
|
||||
"""
|
||||
unique_asym_ids = set(torch.unique(features['asym_id']).tolist())
|
||||
unique_asym_ids.discard(0) # Remove padding asym_id
|
||||
@@ -397,13 +512,19 @@ def compute_permutation_alignment(out, features, ground_truth):
|
||||
return best_align, per_asym_residue_index
|
||||
|
||||
|
||||
def multi_chain_permutation_align(out, features, ground_truth):
|
||||
"""Compute multi-chain permutation alignment.
|
||||
def multi_chain_permutation_align(out: Dict[str, torch.Tensor],
|
||||
features: Dict[str, torch.Tensor],
|
||||
ground_truth: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute multi-chain permutation alignment.
|
||||
|
||||
Args:
|
||||
out: The output of model.forward()
|
||||
features: Input features
|
||||
ground_truth: Ground truth features
|
||||
out: a dictionary of output tensors from model.forward()
|
||||
features: a dictionary of feature tensors that are used as input for model.forward()
|
||||
ground_truth: a list of dictionaries of features corresponding to chains in ground truth structure e.g. it will be a length of 5 if there are 5 chains in ground truth structure
|
||||
|
||||
Returns:
|
||||
features: a dictionary with updated ground truth feature tensors, ready for downstream loss calculations.
|
||||
"""
|
||||
|
||||
labels = split_ground_truth_labels(ground_truth)
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import random
|
||||
import numpy as np
|
||||
from pytorch_lightning import seed_everything
|
||||
|
||||
from openfold.utils.suppress_output import SuppressLogging
|
||||
|
||||
|
||||
def seed_globally(seed=None):
|
||||
if("PL_GLOBAL_SEED" not in os.environ):
|
||||
if(seed is None):
|
||||
seed = random.randint(0, np.iinfo(np.uint32).max)
|
||||
os.environ["PL_GLOBAL_SEED"] = str(seed)
|
||||
logging.info(f'os.environ["PL_GLOBAL_SEED"] set to {seed}')
|
||||
|
||||
# seed_everything is a bit log-happy
|
||||
with SuppressLogging(logging.INFO):
|
||||
seed_everything(seed=None)
|
||||
@@ -35,10 +35,10 @@ def _superimpose_np(reference, coords):
|
||||
|
||||
|
||||
def _superimpose_single(reference, coords):
|
||||
reference_np = reference.detach().cpu().numpy()
|
||||
coords_np = coords.detach().cpu().numpy()
|
||||
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
|
||||
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
|
||||
reference_np = reference.detach().to(torch.float).cpu().numpy()
|
||||
coords_np = coords.detach().to(torch.float).cpu().numpy()
|
||||
superimposed, rmsd = _superimpose_np(reference_np, coords_np)
|
||||
return coords.new_tensor(superimposed), coords.new_tensor(rmsd)
|
||||
|
||||
|
||||
def superimpose(reference, coords, mask):
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
class SuppressStdout:
|
||||
def __enter__(self):
|
||||
self.stdout = sys.stdout
|
||||
dev_null = open("/dev/null", "w")
|
||||
sys.stdout = dev_null
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
fp = sys.stdout
|
||||
sys.stdout = self.stdout
|
||||
fp.close()
|
||||
|
||||
|
||||
class SuppressLogging:
|
||||
def __init__(self, level):
|
||||
self.level = level
|
||||
|
||||
def __enter__(self):
|
||||
logging.disable(self.level)
|
||||
|
||||
def __exit__(self, typ, value, traceback):
|
||||
logging.disable(logging.NOTSET)
|
||||
|
||||
@@ -114,8 +114,7 @@ def tree_map(fn, tree, leaf_type):
|
||||
elif isinstance(tree, leaf_type):
|
||||
return fn(tree)
|
||||
else:
|
||||
print(type(tree))
|
||||
raise ValueError("Not supported")
|
||||
raise ValueError(f"Tree of type {type(tree)} not supported")
|
||||
|
||||
|
||||
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
|
||||
|
||||
@@ -20,6 +20,7 @@ import os
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import json
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__file__)
|
||||
@@ -131,7 +132,16 @@ def generate_feature_dict(
|
||||
args,
|
||||
):
|
||||
tmp_fasta_path = os.path.join(args.output_dir, f"tmp_{os.getpid()}.fasta")
|
||||
if len(seqs) == 1:
|
||||
|
||||
if "multimer" in args.config_preset:
|
||||
with open(tmp_fasta_path, "w") as fp:
|
||||
fp.write(
|
||||
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
|
||||
)
|
||||
feature_dict = data_processor.process_fasta(
|
||||
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
|
||||
)
|
||||
elif len(seqs) == 1:
|
||||
tag = tags[0]
|
||||
seq = seqs[0]
|
||||
with open(tmp_fasta_path, "w") as fp:
|
||||
@@ -143,14 +153,6 @@ def generate_feature_dict(
|
||||
alignment_dir=local_alignment_dir,
|
||||
seqemb_mode=args.use_single_seq_mode,
|
||||
)
|
||||
elif "multimer" in args.config_preset:
|
||||
with open(tmp_fasta_path, "w") as fp:
|
||||
fp.write(
|
||||
'\n'.join([f">{tag}\n{seq}" for tag, seq in zip(tags, seqs)])
|
||||
)
|
||||
feature_dict = data_processor.process_fasta(
|
||||
fasta_path=tmp_fasta_path, alignment_dir=alignment_dir,
|
||||
)
|
||||
else:
|
||||
with open(tmp_fasta_path, "w") as fp:
|
||||
fp.write(
|
||||
@@ -177,7 +179,21 @@ def main(args):
|
||||
if args.config_preset.startswith("seq"):
|
||||
args.use_single_seq_mode = True
|
||||
|
||||
config = model_config(args.config_preset, long_sequence_inference=args.long_sequence_inference)
|
||||
config = model_config(
|
||||
args.config_preset,
|
||||
long_sequence_inference=args.long_sequence_inference,
|
||||
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
|
||||
)
|
||||
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
if args.trace_model:
|
||||
if not config.data.predict.fixed_size:
|
||||
@@ -452,6 +468,13 @@ if __name__ == "__main__":
|
||||
"--cif_output", action="store_true", default=False,
|
||||
help="Output predicted models in ModelCIF format instead of PDB format (default)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
|
||||
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
|
||||
)
|
||||
add_data_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
|
||||
def main(args):
|
||||
# get the super index
|
||||
with open(args.alignment_db_super_index_path, "r") as fp:
|
||||
super_index = json.load(fp)
|
||||
|
||||
# get all chains and sequences
|
||||
chains_to_seqs = {}
|
||||
with open(args.all_chains_fasta, "r") as fp:
|
||||
lines = fp.readlines()
|
||||
|
||||
# iterate through chain-sequence pairs
|
||||
for chain_idx in range(0, len(lines), 2):
|
||||
chain = lines[chain_idx][1:].strip()
|
||||
seq = lines[chain_idx + 1].strip()
|
||||
chains_to_seqs[chain] = seq
|
||||
|
||||
chains_w_alignments = set(super_index.keys())
|
||||
chains_wo_alignments = set(chains_to_seqs.keys()) - chains_w_alignments
|
||||
|
||||
seq_to_chain_w_alignment = {
|
||||
chains_to_seqs[chain]: chain for chain in chains_w_alignments
|
||||
}
|
||||
|
||||
print("Unique sequences with alignments:", len(seq_to_chain_w_alignment))
|
||||
|
||||
# map chain without alignment to alignment entry of another chain with the
|
||||
# same sequence
|
||||
remaining_unaligned_chains = []
|
||||
for chain in chains_wo_alignments:
|
||||
seq = chains_to_seqs[chain]
|
||||
|
||||
try:
|
||||
corresponding_alignment = super_index[seq_to_chain_w_alignment[seq]]
|
||||
# no corresponding chain with alignment found
|
||||
except KeyError:
|
||||
remaining_unaligned_chains.append(chain)
|
||||
continue
|
||||
|
||||
super_index[chain] = corresponding_alignment
|
||||
|
||||
with open(args.output_path, "w") as fp:
|
||||
json.dump(super_index, fp)
|
||||
|
||||
print(
|
||||
f"No corresponding alignment found for the following {len(remaining_unaligned_chains)} chains:",
|
||||
remaining_unaligned_chains,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(
|
||||
description="""
|
||||
If the alignment-db index was created on unique-chain alignments only,
|
||||
this will add the missing chain entries to the super-index file based on
|
||||
a .fasta file that contains sequences for all chains.
|
||||
|
||||
Note that this only modifies the index and not the database itself, as
|
||||
the duplicate sequences will just point to the same alignments.
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"alignment_db_super_index_path",
|
||||
type=Path,
|
||||
help="Path to alignment-db super index file.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_path", type=Path, help="Write the output super index to this path."
|
||||
)
|
||||
parser.add_argument(
|
||||
"all_chains_fasta",
|
||||
type=Path,
|
||||
help="Path to the fasta file containing sequences for all chains.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
193
scripts/alignment_db_scripts/create_alignment_db_sharded.py
Normal file
193
scripts/alignment_db_scripts/create_alignment_db_sharded.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""
|
||||
This is a modified version of the create_alignment_db.py script in OpenFold
|
||||
which supports sharding into multiple files. The created index is already a
|
||||
super index, meaning that "unify_alignment_db_indices.py" does not need to be
|
||||
run on the output index. Additionally this script uses threading and
|
||||
multiprocessing and is much faster than the old version.
|
||||
"""
|
||||
import argparse
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from tqdm import tqdm
|
||||
from math import ceil
|
||||
|
||||
|
||||
def split_file_list(file_list, n_shards):
|
||||
"""
|
||||
Split up the total file list into n_shards sublists.
|
||||
"""
|
||||
split_list = []
|
||||
|
||||
for i in range(n_shards):
|
||||
split_list.append(file_list[i::n_shards])
|
||||
|
||||
assert len([f for sublist in split_list for f in sublist]) == len(file_list)
|
||||
|
||||
return split_list
|
||||
|
||||
|
||||
def chunked_iterator(lst, chunk_size):
|
||||
"""Iterate over a list in chunks of size chunk_size."""
|
||||
for i in range(0, len(lst), chunk_size):
|
||||
yield lst[i : i + chunk_size]
|
||||
|
||||
|
||||
def read_chain_dir(chain_dir) -> dict:
|
||||
"""
|
||||
Read all alignment files in a single chain directory and return a dict
|
||||
mapping chain name to file names and bytes.
|
||||
"""
|
||||
if not chain_dir.is_dir():
|
||||
raise ValueError(f"chain_dir must be a directory, but is {chain_dir}")
|
||||
|
||||
# ensure that PDB IDs are all lowercase
|
||||
pdb_id, chain = chain_dir.name.split("_")
|
||||
pdb_id = pdb_id.lower()
|
||||
chain_name = f"{pdb_id}_{chain}"
|
||||
|
||||
|
||||
file_data = []
|
||||
|
||||
for file_path in sorted(chain_dir.iterdir()):
|
||||
file_name = file_path.name
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_bytes = file.read()
|
||||
|
||||
file_data.append((file_name, file_bytes))
|
||||
|
||||
return {chain_name: file_data}
|
||||
|
||||
|
||||
def process_chunk(chain_files: List[Path]) -> dict:
|
||||
"""
|
||||
Returns the file names and bytes for all chains in a chunk of files.
|
||||
"""
|
||||
chunk_data = {}
|
||||
|
||||
with ThreadPoolExecutor() as executor:
|
||||
for file_data in executor.map(read_chain_dir, chain_files):
|
||||
chunk_data.update(file_data)
|
||||
|
||||
return chunk_data
|
||||
|
||||
|
||||
def create_index_default_dict() -> dict:
|
||||
"""
|
||||
Returns a default dict for the index entries).
|
||||
"""
|
||||
return {"db": None, "files": []}
|
||||
|
||||
|
||||
def create_shard(
|
||||
shard_files: List[Path], output_dir: Path, output_name: str, shard_num: int
|
||||
) -> dict:
|
||||
"""
|
||||
Creates a single shard of the alignment database, and returns the
|
||||
corresponding indices for the super index.
|
||||
"""
|
||||
CHUNK_SIZE = 200
|
||||
shard_index = defaultdict(
|
||||
create_index_default_dict
|
||||
) # {chain_name: {db: str, files: [(file_name, db_offset, file_length)]}, ...}
|
||||
chunk_iter = chunked_iterator(shard_files, CHUNK_SIZE)
|
||||
|
||||
pbar_desc = f"Shard {shard_num}"
|
||||
output_path = output_dir / f"{output_name}_{shard_num}.db"
|
||||
|
||||
db_offset = 0
|
||||
db_file = open(output_path, "wb")
|
||||
for files_chunk in tqdm(
|
||||
chunk_iter, total=ceil(len(shard_files) / CHUNK_SIZE), desc=pbar_desc, position=shard_num, leave=False
|
||||
):
|
||||
# get processed files for one chunk
|
||||
chunk_data = process_chunk(files_chunk)
|
||||
|
||||
# write to db and store info in index
|
||||
for chain_name, file_data in chunk_data.items():
|
||||
shard_index[chain_name]["db"] = output_path.name
|
||||
|
||||
for file_name, file_bytes in file_data:
|
||||
file_length = len(file_bytes)
|
||||
shard_index[chain_name]["files"].append(
|
||||
(file_name, db_offset, file_length)
|
||||
)
|
||||
db_file.write(file_bytes)
|
||||
db_offset += file_length
|
||||
db_file.close()
|
||||
|
||||
return shard_index
|
||||
|
||||
|
||||
def main(args):
|
||||
alignment_dir = args.alignment_dir
|
||||
output_dir = args.output_db_path
|
||||
output_db_name = args.output_db_name
|
||||
n_shards = args.n_shards
|
||||
|
||||
# get all chain dirs in alignment_dir
|
||||
print("Getting chain directories...")
|
||||
all_chain_dirs = sorted([f for f in tqdm(alignment_dir.iterdir())])
|
||||
|
||||
# split chain dirs into n_shards sublists
|
||||
chain_dir_shards = split_file_list(all_chain_dirs, n_shards)
|
||||
|
||||
# total index for all shards
|
||||
super_index = {}
|
||||
|
||||
# create a shard for each sublist
|
||||
print(f"Creating {n_shards} alignment-db files...")
|
||||
with ProcessPoolExecutor() as executor:
|
||||
futures = [
|
||||
executor.submit(
|
||||
create_shard, shard_files, output_dir, output_db_name, shard_index
|
||||
)
|
||||
for shard_index, shard_files in enumerate(chain_dir_shards)
|
||||
]
|
||||
|
||||
for future in as_completed(futures):
|
||||
shard_index = future.result()
|
||||
super_index.update(shard_index)
|
||||
print("\nCreated all shards.")
|
||||
|
||||
# write super index to file
|
||||
print("\nWriting super index...")
|
||||
index_path = output_dir / f"{output_db_name}.index"
|
||||
with open(index_path, "w") as fp:
|
||||
json.dump(super_index, fp, indent=4)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="""
|
||||
This script creates an alignment database format from a directory of
|
||||
precomputed alignments. For better file system health, the total
|
||||
database is split into n_shards files, where each shard contains a
|
||||
subset of the total alignments. The output is a directory containing the
|
||||
n_shards database files, and a single index file mapping chain names to
|
||||
the database file and byte offsets for each alignment file.
|
||||
|
||||
Note: For optimal performance, your machine should have at least as many
|
||||
cores as shards you want to create.
|
||||
"""
|
||||
)
|
||||
parser.add_argument(
|
||||
"alignment_dir",
|
||||
type=Path,
|
||||
help="""Path to precomputed alignment directory, with one subdirectory
|
||||
per chain.""",
|
||||
)
|
||||
parser.add_argument("output_db_path", type=Path)
|
||||
parser.add_argument("output_db_name", type=str)
|
||||
parser.add_argument(
|
||||
"n_shards", type=int, help="Number of shards to split the database into"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
97
scripts/convert_v1_to_v2_weights.py
Executable file
97
scripts/convert_v1_to_v2_weights.py
Executable file
@@ -0,0 +1,97 @@
|
||||
# Copyright 2022 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
|
||||
# used to run inference using DeepMind's JAX code.
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import os
|
||||
import shutil
|
||||
import torch
|
||||
|
||||
from openfold.utils.import_weights import convert_deprecated_v1_keys
|
||||
from deepspeed.utils.zero_to_fp32 import (
|
||||
get_optim_files, parse_optim_states, get_model_state_file
|
||||
)
|
||||
|
||||
|
||||
def convert_v1_to_v2_weights(args):
|
||||
checkpoint_path = args.input_ckpt_path
|
||||
is_dir = os.path.isdir(checkpoint_path)
|
||||
if is_dir:
|
||||
# A DeepSpeed checkpoint
|
||||
logging.info(
|
||||
'Converting deepspeed checkpoint found at {args.input_checkpoint_path}')
|
||||
state_dict_key = 'module'
|
||||
latest_path = os.path.join(checkpoint_path, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
|
||||
ds_checkpoint_dir = os.path.join(checkpoint_path, tag)
|
||||
model_output_path = os.path.join(args.output_ckpt_path, tag)
|
||||
optim_files = get_optim_files(ds_checkpoint_dir)
|
||||
zero_stage, _, _ = parse_optim_states(optim_files, ds_checkpoint_dir)
|
||||
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
|
||||
else:
|
||||
# A Pytorch Lightning checkpoint
|
||||
logging.info(
|
||||
'Converting pytorch lightning checkpoint found at {args.input_checkpoint_path}')
|
||||
state_dict_key = 'state_dict'
|
||||
model_output_path = args.output_ckpt_path
|
||||
model_file = checkpoint_path
|
||||
|
||||
model_dict = torch.load(model_file, map_location=torch.device('cpu'))
|
||||
model_dict[state_dict_key] = convert_deprecated_v1_keys(
|
||||
model_dict[state_dict_key])
|
||||
|
||||
if 'ema' in model_dict:
|
||||
ema_state_dict = model_dict['ema']['params']
|
||||
model_dict['ema']['params'] = convert_deprecated_v1_keys(
|
||||
ema_state_dict)
|
||||
|
||||
if is_dir:
|
||||
param_shapes = convert_deprecated_v1_keys(
|
||||
model_dict['param_shapes'][0])
|
||||
model_dict['param_shapes'] = [param_shapes]
|
||||
|
||||
shutil.copytree(checkpoint_path, args.output_ckpt_path)
|
||||
out_fname = os.path.join(
|
||||
model_output_path, os.path.basename(model_file))
|
||||
|
||||
for optim_file in optim_files:
|
||||
optim_dict = torch.load(optim_file)
|
||||
new_optim_dict = optim_dict.copy()
|
||||
new_optim_dict['optimizer_state_dict']['param_slice_mappings'][0] = convert_deprecated_v1_keys(
|
||||
optim_dict['optimizer_state_dict']['param_slice_mappings'][0])
|
||||
out_optim_fname = os.path.join(
|
||||
model_output_path, os.path.basename(optim_file))
|
||||
torch.save(new_optim_dict, out_optim_fname)
|
||||
else:
|
||||
out_fname = model_output_path
|
||||
|
||||
torch.save(model_dict, out_fname)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("input_ckpt_path", type=str)
|
||||
parser.add_argument("output_ckpt_path", type=str)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
convert_v1_to_v2_weights(args)
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright 2024 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,9 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Downloads OpenFold parameters.
|
||||
# Downloads OpenFold SoloSeq (single sequence model) parameters.
|
||||
#
|
||||
# Usage: bash download_openfold_params_huggingface.sh /path/to/download/directory
|
||||
# Usage: bash download_openfold_soloseq_params.sh /path/to/download/directory
|
||||
set -e
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
|
||||
34
scripts/download_soloseq_embeddings.sh
Executable file
34
scripts/download_soloseq_embeddings.sh
Executable file
@@ -0,0 +1,34 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# Copyright 2024 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Downloads ESM-1b embeddings used to train OpenFold SoloSeq single-seq model.
|
||||
#
|
||||
# Usage: bash download_soloseq_embeddings.sh /path/to/download/directory
|
||||
set -e
|
||||
|
||||
if [[ $# -eq 0 ]]; then
|
||||
echo "Error: download directory must be provided as an input argument."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! command -v aws &> /dev/null ; then
|
||||
echo "Error: aws could not be found. Please install aws."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
DOWNLOAD_DIR="${1}/soloseq_embeddings"
|
||||
mkdir -p "${DOWNLOAD_DIR}"
|
||||
aws s3 cp --no-sign-request --region us-east-1 s3://openfold/soloseq_embeddings/ "${DOWNLOAD_DIR}" --recursive
|
||||
105
scripts/fasta_to_clusterfile.py
Normal file
105
scripts/fasta_to_clusterfile.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""
|
||||
This script takes a .fasta file as input and then clusters it on a given
|
||||
sequence identity threshold using mmseqs2. The mmseqs2 flags are identical to
|
||||
what PDB officially uses to provide their official sequence clusters
|
||||
(https://github.com/soedinglab/MMseqs2/issues/452).
|
||||
"""
|
||||
import shutil
|
||||
import subprocess
|
||||
from argparse import ArgumentParser
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def reformat_cluster_file(cluster_file: Path, output_file: Path):
|
||||
"""
|
||||
This function takes a mmseqs2 output cluster file and reformats it to a text
|
||||
file where each line contains a space-separated list of {PDB_ID}_{CHAIN_ID}
|
||||
belonging to the same cluster.
|
||||
"""
|
||||
cluster_to_chains = defaultdict(list)
|
||||
|
||||
# extract all chains belonging to each cluster
|
||||
with open(cluster_file, "r") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
cluster_name, chain_id = line.split()
|
||||
cluster_to_chains[cluster_name].append(chain_id)
|
||||
|
||||
# write all chains belonging to the same cluster on the same line
|
||||
with open(output_file, "w") as f:
|
||||
for chains in cluster_to_chains.values():
|
||||
f.write(f"{' '.join(chains)}\n")
|
||||
|
||||
|
||||
def main(args):
|
||||
input_file = args.input_fasta.absolute()
|
||||
output_file = args.output_file.absolute()
|
||||
output_dir = args.output_file.parent
|
||||
|
||||
# prefix that all output files get
|
||||
mmseqs_prefix = "_mmseqs_out"
|
||||
|
||||
# temporary directory that mmseqs2 uses
|
||||
tmp_name = f"{mmseqs_prefix}_temp"
|
||||
tmp_dir = output_dir / tmp_name
|
||||
|
||||
mmseqs_command = [
|
||||
args.mmseqs_binary_path,
|
||||
"easy-cluster",
|
||||
input_file,
|
||||
mmseqs_prefix,
|
||||
tmp_name,
|
||||
"--min-seq-id",
|
||||
str(args.seq_id),
|
||||
"-c",
|
||||
"0.9",
|
||||
"-s",
|
||||
"8",
|
||||
"--max-seqs",
|
||||
"1000",
|
||||
"--cluster-mode",
|
||||
"1",
|
||||
]
|
||||
|
||||
# run mmseqs with PDB settings
|
||||
print("Running mmseqs2...")
|
||||
subprocess.run(mmseqs_command, check=True, cwd=output_dir)
|
||||
|
||||
cluster_file = output_dir / "_mmseqs_out_cluster.tsv"
|
||||
|
||||
print("Reformatting output file...")
|
||||
reformat_cluster_file(cluster_file, output_file)
|
||||
|
||||
print("Cleaning up mmseqs2 output...")
|
||||
mmseqs_outputs = [
|
||||
output_dir / f"{mmseqs_prefix}_{suffix}"
|
||||
for suffix in ["cluster.tsv", "rep_seq.fasta", "all_seqs.fasta"]
|
||||
]
|
||||
for file in mmseqs_outputs:
|
||||
file.unlink()
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
print("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser(
|
||||
description="Creates a sequence cluster file from a .fasta file using mmseqs2 with PDB settings."
|
||||
)
|
||||
parser.add_argument(
|
||||
"input_fasta",
|
||||
type=Path,
|
||||
help="Input .fasta file. Sequence names should be in format >{PDB_ID}_{CHAIN_ID}",
|
||||
)
|
||||
parser.add_argument(
|
||||
"output_file",
|
||||
type=Path,
|
||||
help="Output file. Each line will contain a space-separated list of {PDB_ID}_{CHAIN_ID} belonging to the same cluster.",
|
||||
)
|
||||
parser.add_argument("mmseqs_binary_path", type=str, help="Path to mmseqs binary")
|
||||
parser.add_argument("--seq-id", type=float, default=0.4, help="Sequence identity threshold for clustering.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
main(args)
|
||||
@@ -1,8 +1,14 @@
|
||||
import argparse
|
||||
import ctypes
|
||||
from datetime import date
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
if 'CONDA_PREFIX' in os.environ:
|
||||
CONDA_ENV_BINARY_PATH= Path(os.environ['CONDA_PREFIX']) / 'bin'
|
||||
else:
|
||||
CONDA_ENV_BINARY_PATH = Path('/bin')
|
||||
|
||||
def add_data_args(parser: argparse.ArgumentParser):
|
||||
parser.add_argument(
|
||||
@@ -30,22 +36,22 @@ def add_data_args(parser: argparse.ArgumentParser):
|
||||
'--bfd_database_path', type=str, default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
'--jackhmmer_binary_path', type=str, default='/usr/bin/jackhmmer'
|
||||
'--jackhmmer_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'jackhmmer'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--hhblits_binary_path', type=str, default='/usr/bin/hhblits'
|
||||
'--hhblits_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhblits'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--hhsearch_binary_path', type=str, default='/usr/bin/hhsearch'
|
||||
'--hhsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hhsearch'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--hmmsearch_binary_path', type=str, default='/usr/bin/hmmsearch'
|
||||
'--hmmsearch_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmsearch'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--hmmbuild_binary_path', type=str, default='/usr/bin/hmmbuild'
|
||||
'--hmmbuild_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'hmmbuild'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--kalign_binary_path', type=str, default='/usr/bin/kalign'
|
||||
'--kalign_binary_path', type=str, default=str(CONDA_ENV_BINARY_PATH / 'kalign'),
|
||||
)
|
||||
parser.add_argument(
|
||||
'--max_template_date', type=str,
|
||||
|
||||
@@ -1,465 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# This script extracts fp32 consolidated weights from a zero 2 and 3 DeepSpeed checkpoints. It gets
|
||||
# copied into the top level checkpoint dir, so the user can easily do the conversion at any point in
|
||||
# the future. Once extracted, the weights don't require DeepSpeed and can be used in any
|
||||
# application.
|
||||
#
|
||||
# example: python zero_to_fp32.py . pytorch_model.bin
|
||||
|
||||
import argparse
|
||||
import torch
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
import re
|
||||
|
||||
# while this script doesn't use deepspeed to recover data, since the checkpoints are pickled with
|
||||
# DeepSpeed data structures it has to be available in the current python environment.
|
||||
import deepspeed
|
||||
from deepspeed.utils import logger
|
||||
|
||||
debug = 0
|
||||
|
||||
# load to cpu
|
||||
device = torch.device('cpu')
|
||||
|
||||
|
||||
def get_model_state_file(checkpoint_dir, zero_stage):
|
||||
if not os.path.isdir(checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{checkpoint_dir}' doesn't exist")
|
||||
|
||||
# there should be only one file
|
||||
if zero_stage == 2:
|
||||
file = os.path.join(checkpoint_dir, "mp_rank_00_model_states.pt")
|
||||
elif zero_stage == 3:
|
||||
file = os.path.join(checkpoint_dir, "zero_pp_rank_0_mp_rank_00_model_states.pt")
|
||||
|
||||
if not os.path.exists(file):
|
||||
raise FileNotFoundError(f"can't find model states file at '{file}'")
|
||||
|
||||
return file
|
||||
|
||||
|
||||
def get_optim_files(checkpoint_dir):
|
||||
# XXX: need to test that this simple glob rule works for multi-node setup too
|
||||
optim_files = sorted(glob.glob(os.path.join(checkpoint_dir, "*_optim_states.pt")))
|
||||
|
||||
if len(optim_files) == 0:
|
||||
raise FileNotFoundError(
|
||||
f"can't find '*_optim_states.pt' files in directory '{checkpoint_dir}'")
|
||||
|
||||
return optim_files
|
||||
|
||||
|
||||
def parse_model_state(file):
|
||||
state_dict = torch.load(file, map_location=device)
|
||||
|
||||
if "buffer_names" not in state_dict:
|
||||
raise ValueError(f"{file} is not a model state checkpoint")
|
||||
buffer_names = state_dict["buffer_names"]
|
||||
if debug:
|
||||
print("Found buffers:", buffer_names)
|
||||
|
||||
# recover just the buffers while restoring them to fp32 if they were saved in fp16
|
||||
buffers = {
|
||||
k: v.float()
|
||||
for k,
|
||||
v in state_dict["module"].items() if k in buffer_names
|
||||
}
|
||||
return buffers
|
||||
|
||||
|
||||
def parse_optim_states(files, ds_checkpoint_dir):
|
||||
|
||||
total_files = len(files)
|
||||
state_dicts = []
|
||||
for f in files:
|
||||
state_dicts.append(torch.load(f, map_location=device))
|
||||
|
||||
if not "zero_stage" in state_dicts[0]['optimizer_state_dict']:
|
||||
raise ValueError(f"{files[0]} is not a zero checkpoint")
|
||||
zero_stage = state_dicts[0]['optimizer_state_dict']["zero_stage"]
|
||||
world_size = state_dicts[0]['optimizer_state_dict']["partition_count"]
|
||||
param_shapes = state_dicts[0]["param_shapes"]
|
||||
# For ZeRO-2 each param group can have different partition_count as data parallelism for expert
|
||||
# parameters can be different from data parallelism for non-expert parameters. So we can just
|
||||
# use the max of the partition_count to get the dp world_size.
|
||||
|
||||
if type(world_size) is list:
|
||||
world_size = max(world_size)
|
||||
|
||||
if world_size != total_files:
|
||||
raise ValueError(
|
||||
f"Expected {world_size} of '*_optim_states.pt' under '{ds_checkpoint_dir}' but found {total_files} files. "
|
||||
"Possibly due to an overwrite of an old checkpoint, or a checkpoint didn't get saved by one or more processes."
|
||||
)
|
||||
|
||||
# the groups are named differently in each stage
|
||||
if zero_stage == 2:
|
||||
fp32_groups_key = "single_partition_of_fp32_groups"
|
||||
elif zero_stage == 3:
|
||||
fp32_groups_key = "fp32_flat_groups"
|
||||
else:
|
||||
raise ValueError(f"unknown zero stage {zero_stage}")
|
||||
|
||||
if zero_stage == 2:
|
||||
fp32_flat_groups = [
|
||||
state_dicts[i]['optimizer_state_dict'][fp32_groups_key]
|
||||
for i in range(len(state_dicts))
|
||||
]
|
||||
elif zero_stage == 3:
|
||||
# if there is more than one param group, there will be multiple flattened tensors - one
|
||||
# flattened tensor per group - for simplicity merge them into a single tensor
|
||||
#
|
||||
# XXX: could make the script more memory efficient for when there are multiple groups - it
|
||||
# will require matching the sub-lists of param_shapes for each param group flattened tensor
|
||||
|
||||
fp32_flat_groups = [
|
||||
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
|
||||
0) for i in range(len(state_dicts))
|
||||
]
|
||||
|
||||
return zero_stage, world_size, param_shapes, fp32_flat_groups
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir):
|
||||
"""
|
||||
Returns fp32 state_dict reconstructed from ds checkpoint
|
||||
|
||||
Args:
|
||||
- ``ds_checkpoint_dir``: path to the deepspeed checkpoint folder (where the optimizer files are)
|
||||
|
||||
"""
|
||||
print(f"Processing zero checkpoint '{ds_checkpoint_dir}'")
|
||||
|
||||
optim_files = get_optim_files(ds_checkpoint_dir)
|
||||
zero_stage, world_size, param_shapes, fp32_flat_groups = parse_optim_states(optim_files, ds_checkpoint_dir)
|
||||
print(
|
||||
f"Detected checkpoint of type zero stage {zero_stage}, world_size: {world_size}")
|
||||
|
||||
model_file = get_model_state_file(ds_checkpoint_dir, zero_stage)
|
||||
buffers = parse_model_state(model_file)
|
||||
|
||||
if zero_stage == 2:
|
||||
return _get_fp32_state_dict_from_zero2_checkpoint(world_size,
|
||||
param_shapes,
|
||||
fp32_flat_groups,
|
||||
buffers)
|
||||
elif zero_stage == 3:
|
||||
return _get_fp32_state_dict_from_zero3_checkpoint(world_size,
|
||||
param_shapes,
|
||||
fp32_flat_groups,
|
||||
buffers)
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero2_checkpoint(world_size,
|
||||
param_shapes,
|
||||
fp32_flat_groups,
|
||||
buffers):
|
||||
|
||||
# Reconstruction protocol:
|
||||
#
|
||||
# XXX: document this
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
for j in range(len(fp32_flat_groups[0])):
|
||||
print(f"fp32_flat_groups[{i}][{j}].shape={fp32_flat_groups[i][j].shape}")
|
||||
|
||||
# XXX: memory usage doubles here (zero2)
|
||||
num_param_groups = len(fp32_flat_groups[0])
|
||||
merged_single_partition_of_fp32_groups = []
|
||||
for i in range(num_param_groups):
|
||||
merged_partitions = [sd[i] for sd in fp32_flat_groups]
|
||||
full_single_fp32_vector = torch.cat(merged_partitions, 0)
|
||||
merged_single_partition_of_fp32_groups.append(full_single_fp32_vector)
|
||||
avail_numel = sum([
|
||||
full_single_fp32_vector.numel()
|
||||
for full_single_fp32_vector in merged_single_partition_of_fp32_groups
|
||||
])
|
||||
|
||||
if debug:
|
||||
wanted_params = sum([len(shapes) for shapes in param_shapes])
|
||||
wanted_numel = sum(
|
||||
[sum(shape.numel() for shape in shapes.values()) for shapes in param_shapes])
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
print(f"Have {avail_numel} numels to process.")
|
||||
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for shapes, full_single_fp32_vector in zip(param_shapes, merged_single_partition_of_fp32_groups):
|
||||
offset = 0
|
||||
avail_numel = full_single_fp32_vector.numel()
|
||||
for name, shape in shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"{name} full shape: {shape} unpartitioned numel {unpartitioned_numel} "
|
||||
)
|
||||
state_dict[name] = full_single_fp32_vector.narrow(
|
||||
0,
|
||||
offset,
|
||||
unpartitioned_numel).view(shape)
|
||||
offset += unpartitioned_numel
|
||||
|
||||
# Z2 started to align to 2*world_size to improve nccl performance. Therefore both offset and
|
||||
# avail_numel can differ by anywhere between 0..2*world_size. Due to two unrelated complex
|
||||
# paddings performed in the code it's almost impossible to predict the exact numbers w/o the
|
||||
# live optimizer object, so we are checking that the numbers are within the right range
|
||||
align_to = 2 * world_size
|
||||
|
||||
def zero2_align(x):
|
||||
return align_to * math.ceil(x / align_to)
|
||||
|
||||
if debug:
|
||||
print(f"original offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
offset = zero2_align(offset)
|
||||
avail_numel = zero2_align(avail_numel)
|
||||
|
||||
if debug:
|
||||
print(f"aligned offset={offset}, avail_numel={avail_numel}")
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(
|
||||
f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(
|
||||
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def zero3_partitioned_param_info(unpartitioned_numel, world_size):
|
||||
remainder = unpartitioned_numel % world_size
|
||||
padding_numel = (world_size - remainder) if remainder else 0
|
||||
partitioned_numel = math.ceil(unpartitioned_numel / world_size)
|
||||
return partitioned_numel, padding_numel
|
||||
|
||||
|
||||
def _get_fp32_state_dict_from_zero3_checkpoint(world_size,
|
||||
param_shapes,
|
||||
fp32_flat_groups,
|
||||
buffers):
|
||||
|
||||
# Reconstruction protocol: For zero3 we need to zip the partitions together at boundary of each
|
||||
# param, re-consolidating each param, while dealing with padding if any
|
||||
|
||||
avail_numel = fp32_flat_groups[0].numel() * world_size
|
||||
# merge list of dicts, preserving order
|
||||
param_shapes = {k: v for d in param_shapes for k, v in d.items()}
|
||||
|
||||
if debug:
|
||||
for i in range(world_size):
|
||||
print(f"fp32_flat_groups[{i}].shape={fp32_flat_groups[i].shape}")
|
||||
|
||||
wanted_params = len(param_shapes)
|
||||
wanted_numel = sum(shape.numel() for shape in param_shapes.values())
|
||||
# not asserting if there is a mismatch due to possible padding
|
||||
print(f"Have {avail_numel} numels to process.")
|
||||
print(f"Need {wanted_numel} numels in {wanted_params} params.")
|
||||
|
||||
state_dict = OrderedDict()
|
||||
|
||||
# buffers
|
||||
state_dict.update(buffers)
|
||||
if debug:
|
||||
print(f"added {len(buffers)} buffers")
|
||||
|
||||
# params
|
||||
# XXX: for huge models that can't fit into the host's RAM we will have to recode this to support
|
||||
# out-of-core computing solution
|
||||
offset = 0
|
||||
total_numel = 0
|
||||
total_params = 0
|
||||
for name, shape in param_shapes.items():
|
||||
|
||||
unpartitioned_numel = shape.numel()
|
||||
total_numel += unpartitioned_numel
|
||||
total_params += 1
|
||||
|
||||
partitioned_numel, partitioned_padding_numel = zero3_partitioned_param_info(unpartitioned_numel, world_size)
|
||||
|
||||
if debug:
|
||||
print(
|
||||
f"{total_params} {name} full shape: {shape} partition0 numel={partitioned_numel} partitioned_padding_numel={partitioned_padding_numel}"
|
||||
)
|
||||
|
||||
# XXX: memory usage doubles here
|
||||
state_dict[name] = torch.cat(
|
||||
tuple(fp32_flat_groups[i].narrow(0,
|
||||
offset,
|
||||
partitioned_numel)
|
||||
for i in range(world_size)),
|
||||
0).narrow(0,
|
||||
0,
|
||||
unpartitioned_numel).view(shape)
|
||||
offset += partitioned_numel
|
||||
|
||||
offset *= world_size
|
||||
|
||||
# Sanity check
|
||||
if offset != avail_numel:
|
||||
raise ValueError(
|
||||
f"consumed {offset} numels out of {avail_numel} - something is wrong")
|
||||
|
||||
print(
|
||||
f"Reconstructed fp32 state dict with {total_params} params {total_numel} elements"
|
||||
)
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag=None):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated state_dict that can be loaded with
|
||||
``load_state_dict()`` and used for training without DeepSpeed or shared with others, for example
|
||||
via a model hub.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in 'latest' file. e.g., ``global_step14``
|
||||
|
||||
Returns:
|
||||
- pytorch ``state_dict``
|
||||
|
||||
Note: this approach may not work if your application doesn't have sufficient free CPU memory and
|
||||
you may need to use the offline approach using the ``zero_to_fp32.py`` script that is saved with
|
||||
the checkpoint.
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import get_fp32_state_dict_from_zero_checkpoint
|
||||
# do the training and checkpoint saving
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir) # already on cpu
|
||||
model = model.cpu() # move to cpu
|
||||
model.load_state_dict(state_dict)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
In this example the ``model`` will no longer be usable in the deepspeed context of the same
|
||||
application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
If you want it all done for you, use ``load_state_dict_from_zero_checkpoint`` instead.
|
||||
|
||||
"""
|
||||
if tag is None:
|
||||
latest_path = os.path.join(checkpoint_dir, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
|
||||
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
||||
|
||||
if not os.path.isdir(ds_checkpoint_dir):
|
||||
raise FileNotFoundError(f"Directory '{ds_checkpoint_dir}' doesn't exist")
|
||||
|
||||
return _get_fp32_state_dict_from_zero_checkpoint(ds_checkpoint_dir)
|
||||
|
||||
|
||||
def convert_zero_checkpoint_to_fp32_state_dict(checkpoint_dir, output_file, tag=None):
|
||||
"""
|
||||
Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict`` file that can be
|
||||
loaded with ``torch.load(file)`` + ``load_state_dict()`` and used for training without DeepSpeed.
|
||||
|
||||
Args:
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``output_file``: path to the pytorch fp32 state_dict output file (e.g. path/pytorch_model.bin)
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
"""
|
||||
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
||||
print(f"Saving fp32 state dict to {output_file}")
|
||||
torch.save(state_dict, output_file)
|
||||
|
||||
|
||||
def load_state_dict_from_zero_checkpoint(model, checkpoint_dir, tag=None):
|
||||
"""
|
||||
1. Put the provided model to cpu
|
||||
2. Convert ZeRO 2 or 3 checkpoint into a single fp32 consolidated ``state_dict``
|
||||
3. Load it into the provided model
|
||||
|
||||
Args:
|
||||
- ``model``: the model object to update
|
||||
- ``checkpoint_dir``: path to the desired checkpoint folder. (one that contains the tag-folder, like ``global_step14``)
|
||||
- ``tag``: checkpoint tag used as a unique identifier for checkpoint. If not provided will attempt to load tag in the file named ``latest`` in the checkpoint folder, e.g., ``global_step14``
|
||||
|
||||
Returns:
|
||||
- ``model`: modified model
|
||||
|
||||
Make sure you have plenty of CPU memory available before you call this function. If you don't
|
||||
have enough use the ``zero_to_fp32.py`` utility to do the conversion. You will find it
|
||||
conveniently placed for you in the checkpoint folder.
|
||||
|
||||
A typical usage might be ::
|
||||
|
||||
from deepspeed.utils.zero_to_fp32 import load_state_dict_from_zero_checkpoint
|
||||
model = load_state_dict_from_zero_checkpoint(trainer.model, checkpoint_dir)
|
||||
# submit to model hub or save the model to share with others
|
||||
|
||||
Note, that once this was run, the ``model`` will no longer be usable in the deepspeed context
|
||||
of the same application. i.e. you will need to re-initialize the deepspeed engine, since
|
||||
``model.load_state_dict(state_dict)`` will remove all the deepspeed magic from it.
|
||||
|
||||
"""
|
||||
logger.info(f"Extracting fp32 weights")
|
||||
state_dict = get_fp32_state_dict_from_zero_checkpoint(checkpoint_dir, tag)
|
||||
|
||||
logger.info(f"Overwriting model with fp32 weights")
|
||||
model = model.cpu()
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
return model
|
||||
|
||||
def get_global_step_from_zero_checkpoint(checkpoint_dir):
|
||||
global_step = -1
|
||||
latest_path = os.path.join(checkpoint_dir, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
match = re.match(r"global_step([0-9]+)", tag)
|
||||
global_step = int(match.group(1))
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
return global_step
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"checkpoint_dir",
|
||||
type=str,
|
||||
help="path to the desired checkpoint folder, e.g., path/checkpoint-12")
|
||||
parser.add_argument(
|
||||
"output_file",
|
||||
type=str,
|
||||
help=
|
||||
"path to the pytorch fp32 state_dict output file (e.g. path/checkpoint-12/pytorch_model.bin)"
|
||||
)
|
||||
parser.add_argument("-d", "--debug", action='store_true', help="enable debug")
|
||||
args = parser.parse_args()
|
||||
|
||||
debug = args.debug
|
||||
|
||||
convert_zero_checkpoint_to_fp32_state_dict(args.checkpoint_dir, args.output_file)
|
||||
3
setup.py
3
setup.py
@@ -52,7 +52,6 @@ def get_cuda_bare_metal_version(cuda_dir):
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
compute_capabilities = set([
|
||||
(3, 7), # K80, e.g.
|
||||
(5, 2), # Titan X
|
||||
(6, 1), # GeForce 1000-series
|
||||
])
|
||||
@@ -130,7 +129,7 @@ setup(
|
||||
classifiers=[
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Programming Language :: Python :: 3.9,'
|
||||
'Programming Language :: Python :: 3.10,'
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
],
|
||||
)
|
||||
|
||||
@@ -202,4 +202,4 @@ class TestModel(unittest.TestCase):
|
||||
out_repro = out_repro["sm"]["positions"][-1]
|
||||
out_repro = out_repro.squeeze(0)
|
||||
|
||||
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 1e-3)
|
||||
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 1e-3)
|
||||
|
||||
@@ -48,7 +48,8 @@ class TestPermutation(unittest.TestCase):
|
||||
self.chain_a_num_res = 9
|
||||
self.chain_b_num_res = 13
|
||||
# below create default fake ground truth structures for a hetero-pentamer A2B3
|
||||
self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
|
||||
self.residue_index = list(
|
||||
range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
|
||||
self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
|
||||
self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res + [
|
||||
3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res], device=device)
|
||||
@@ -63,19 +64,44 @@ class TestPermutation(unittest.TestCase):
|
||||
'entity_id': self.entity_id,
|
||||
'seq_length': torch.tensor([57])
|
||||
}
|
||||
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
|
||||
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(
|
||||
batch, batch['asym_id'])
|
||||
anchor_gt_asym = int(anchor_gt_asym)
|
||||
anchor_pred_asym = {int(i) for i in anchor_pred_asym}
|
||||
expected_anchors = {1, 2}
|
||||
expected_non_anchors = {3, 4, 5}
|
||||
|
||||
self.assertIn(anchor_gt_asym, expected_anchors)
|
||||
self.assertIn(anchor_gt_asym, expected_anchors)
|
||||
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
|
||||
# Check that predicted anchors are within expected anchor set
|
||||
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
|
||||
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
|
||||
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
|
||||
|
||||
def test_2_permutation_pentamer(self):
|
||||
"""
|
||||
Test the permutation results on a pentamer A2B3, in which protein A has 9 residues
|
||||
and protein B has 13 residues.
|
||||
|
||||
Expected outputs:
|
||||
Only protein A should be selected as an anchor thus, in the output list, either [(0,1), (1,0)] or [(0,0), (1,1)] are allowed
|
||||
The 3 chains from protein B should ALWAYS be aligned in a way that predicted b1 to be aligned with ground truth b1, pred b2 to ground truth b2
|
||||
as shown below:
|
||||
|
||||
predicted structure: a2 - a1 - b2 - b3 - b1
|
||||
indexes in the predicted list: 0 1 2 3 4
|
||||
|
||||
ground truth structure: a1 - a2 - b1 - b2 - b3
|
||||
indexes in the ground truth list: 0 1 2 3 4
|
||||
|
||||
then the 2 protein A chains are free to be aligned by either order, thus either [(0,1),(1,0)] or [(0,0),(1,1)] is valid.
|
||||
|
||||
However, the 3 protein B chains should be strictly aligned in the following order:
|
||||
[(2,3), (3,4), (4,2)], regardless of how protein A chains are aligned.
|
||||
|
||||
Therefore, the only 2 correct permutations are :
|
||||
[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)] and
|
||||
[(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
|
||||
"""
|
||||
batch = {
|
||||
'asym_id': self.asym_id,
|
||||
'sym_id': self.sym_id,
|
||||
@@ -85,7 +111,7 @@ class TestPermutation(unittest.TestCase):
|
||||
}
|
||||
batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res)
|
||||
batch["residue_index"] = torch.tensor([self.residue_index])
|
||||
# create fake ground truth atom positions
|
||||
# create fake ground truth atom positions
|
||||
chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
|
||||
dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
|
||||
chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
|
||||
@@ -93,16 +119,22 @@ class TestPermutation(unittest.TestCase):
|
||||
chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
|
||||
dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
|
||||
chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
|
||||
chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
|
||||
# Below permutate predicted chain positions
|
||||
pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
|
||||
chain_b3_pos = torch.matmul(torch.matmul(
|
||||
chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
|
||||
# Below permutate predicted chain positions
|
||||
# here the b2 chain from the ground truth is deliberately put in b1 chain's position, and predicted b3 chain to b2's position
|
||||
# and predicted b1 chain to b3's position
|
||||
pred_atom_position = torch.cat(
|
||||
(chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
|
||||
|
||||
pred_atom_mask = torch.ones((1, self.num_res, 37))
|
||||
out = {
|
||||
'final_atom_positions': pred_atom_position,
|
||||
'final_atom_mask': pred_atom_mask
|
||||
}
|
||||
|
||||
true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
|
||||
true_atom_position = torch.cat(
|
||||
(chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
|
||||
true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
|
||||
torch.ones((1, self.chain_a_num_res, 37)),
|
||||
torch.ones((1, self.chain_b_num_res, 37)),
|
||||
@@ -111,27 +143,47 @@ class TestPermutation(unittest.TestCase):
|
||||
batch['all_atom_positions'] = true_atom_position
|
||||
batch['all_atom_mask'] = true_atom_mask
|
||||
|
||||
aligns, _ = compute_permutation_alignment(out, batch,
|
||||
batch)
|
||||
print(f"##### aligns is {aligns}")
|
||||
possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]]
|
||||
wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]]
|
||||
self.assertIn(aligns, possible_outcome)
|
||||
self.assertNotIn(aligns, wrong_outcome)
|
||||
aligns, per_asym_residue_index = compute_permutation_alignment(out, batch,
|
||||
batch)
|
||||
|
||||
expected_asym_residue_index = {
|
||||
1: torch.tensor(list(range(self.chain_a_num_res))),
|
||||
2: torch.tensor(list(range(self.chain_a_num_res))),
|
||||
3: torch.tensor(list(range(self.chain_b_num_res))),
|
||||
4: torch.tensor(list(range(self.chain_b_num_res))),
|
||||
5: torch.tensor(list(range(self.chain_b_num_res)))
|
||||
}
|
||||
chain_a_permutated_chain_b_permutated = [
|
||||
(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)]
|
||||
chain_a_not_permutated_chain_b_permutated = [
|
||||
(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]
|
||||
chain_a_permutated_chain_b_not_permuated = [
|
||||
(0, 1), (1, 0), (2, 2), (3, 3), (4, 4)]
|
||||
chain_a_not_permutated_chain_b_not_permuated = [
|
||||
(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
|
||||
|
||||
# test on the permutation alignments
|
||||
self.assertIn(aligns, [chain_a_permutated_chain_b_permutated,
|
||||
chain_a_not_permutated_chain_b_permutated])
|
||||
self.assertNotIn(aligns, [chain_a_permutated_chain_b_not_permuated,
|
||||
chain_a_not_permutated_chain_b_not_permuated])
|
||||
|
||||
# test on the per_aysm_residue_index
|
||||
for k, v in expected_asym_residue_index.items():
|
||||
self.assertTrue(torch.equal(v, per_asym_residue_index[k]))
|
||||
|
||||
@unittest.skip("Test needs to be fixed post-refactor")
|
||||
def test_3_merge_labels(self):
|
||||
nres_pad = 325 - 57 # suppose the cropping size is 325
|
||||
batch = {
|
||||
'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
|
||||
'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
|
||||
'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
|
||||
'aatype': torch.randint(21, size=(1, 325)),
|
||||
'asym_id': self.asym_id,
|
||||
'sym_id': self.sym_id,
|
||||
'entity_id': self.entity_id,
|
||||
'aatype': torch.randint(21, size=(1, 57)),
|
||||
'seq_length': torch.tensor([57])
|
||||
}
|
||||
batch['asym_id'] = batch['asym_id'].reshape(1, 325)
|
||||
batch["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1)
|
||||
# create fake ground truth atom positions
|
||||
batch['asym_id'] = batch['asym_id'].reshape(1, 57)
|
||||
batch["residue_index"] = torch.tensor([self.residue_index])
|
||||
# create fake ground truth atom positions
|
||||
chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
|
||||
dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
|
||||
chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
|
||||
@@ -139,39 +191,64 @@ class TestPermutation(unittest.TestCase):
|
||||
chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
|
||||
dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
|
||||
chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
|
||||
chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
|
||||
# Below permutate predicted chain positions
|
||||
pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
|
||||
chain_b3_pos = torch.matmul(torch.matmul(
|
||||
chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
|
||||
# Below permutate predicted chain positions
|
||||
pred_atom_position = torch.cat(
|
||||
(chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
|
||||
pred_atom_mask = torch.ones((1, self.num_res, 37))
|
||||
pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1)
|
||||
pred_atom_position = pad_features(
|
||||
pred_atom_position, nres_pad, pad_dim=1)
|
||||
pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1)
|
||||
out = {
|
||||
'final_atom_positions': pred_atom_position,
|
||||
'final_atom_mask': pred_atom_mask
|
||||
}
|
||||
true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
|
||||
true_atom_position = torch.cat(
|
||||
(chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
|
||||
true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
|
||||
torch.ones((1, self.chain_a_num_res, 37)),
|
||||
torch.ones((1, self.chain_b_num_res, 37)),
|
||||
torch.ones((1, self.chain_b_num_res, 37)),
|
||||
torch.ones((1, self.chain_b_num_res, 37))), dim=1)
|
||||
batch['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1)
|
||||
batch['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1)
|
||||
|
||||
# tensor_to_cuda = lambda t: t.to('cuda')
|
||||
# ground_truth = tensor_tree_map(tensor_to_cuda,ground_truth)
|
||||
batch['all_atom_positions'] = true_atom_position
|
||||
batch['all_atom_mask'] = true_atom_mask
|
||||
|
||||
# Below create a fake_input_features
|
||||
fake_input_features = {
|
||||
'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
|
||||
'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
|
||||
'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
|
||||
'aatype': torch.randint(21, size=(1, 325)),
|
||||
'seq_length': torch.tensor([57])
|
||||
}
|
||||
fake_input_features['asym_id'] = fake_input_features['asym_id'].reshape(
|
||||
1, 325)
|
||||
fake_input_features["residue_index"] = pad_features(
|
||||
torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1)
|
||||
fake_input_features['all_atom_positions'] = pad_features(
|
||||
true_atom_position, nres_pad, pad_dim=1)
|
||||
fake_input_features['all_atom_mask'] = pad_features(
|
||||
true_atom_mask, nres_pad=nres_pad, pad_dim=1)
|
||||
|
||||
# NOTE
|
||||
# batch: simulates ground_truth features
|
||||
# fake_input_features: simulates the data that are going be used as input for model.forward(fake_input_features)
|
||||
# out: simulates the output of model.forward(fake_input_features)
|
||||
aligns, per_asym_residue_index = compute_permutation_alignment(out,
|
||||
batch,
|
||||
fake_input_features,
|
||||
batch)
|
||||
print(f"##### aligns is {aligns}")
|
||||
labels = split_ground_truth_labels(batch)
|
||||
|
||||
labels = merge_labels(per_asym_residue_index, labels, aligns,
|
||||
original_nres=batch['aatype'].shape[-1])
|
||||
|
||||
self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index']))
|
||||
self.assertTrue(torch.equal(
|
||||
labels['residue_index'], batch['residue_index']))
|
||||
|
||||
expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
|
||||
dim=1)
|
||||
expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1)
|
||||
self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos))
|
||||
|
||||
self.assertTrue(torch.equal(
|
||||
labels['all_atom_positions'], expected_permutated_gt_pos))
|
||||
|
||||
@@ -2,20 +2,25 @@ import argparse
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import json
|
||||
|
||||
import pytorch_lightning as pl
|
||||
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
|
||||
from pytorch_lightning.callbacks import DeviceStatsMonitor
|
||||
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
|
||||
from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy
|
||||
from pytorch_lightning.plugins.environments import MPIEnvironment
|
||||
from pytorch_lightning import seed_everything
|
||||
import torch
|
||||
import wandb
|
||||
from deepspeed.utils import zero_to_fp32
|
||||
|
||||
from openfold.config import model_config
|
||||
from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
|
||||
from openfold.model.model import AlphaFold
|
||||
from openfold.model.torchscript import script_preset_
|
||||
from openfold.np import residue_constants
|
||||
from openfold.utils.argparse_utils import remove_arguments
|
||||
from openfold.utils.callbacks import (
|
||||
EarlyStoppingVerbose,
|
||||
)
|
||||
@@ -23,7 +28,6 @@ from openfold.utils.exponential_moving_average import ExponentialMovingAverage
|
||||
from openfold.utils.loss import AlphaFoldLoss, lddt_ca
|
||||
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
|
||||
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
|
||||
from openfold.utils.seed import seed_everything
|
||||
from openfold.utils.superimposition import superimpose
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
from openfold.utils.validation_metrics import (
|
||||
@@ -35,11 +39,6 @@ from openfold.utils.import_weights import (
|
||||
import_jax_weights_,
|
||||
import_openfold_weights_
|
||||
)
|
||||
from scripts.zero_to_fp32 import (
|
||||
get_fp32_state_dict_from_zero_checkpoint,
|
||||
get_global_step_from_zero_checkpoint
|
||||
)
|
||||
|
||||
from openfold.utils.logger import PerformanceLoggingCallback
|
||||
|
||||
|
||||
@@ -58,6 +57,7 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
|
||||
self.cached_weights = None
|
||||
self.last_lr_step = -1
|
||||
self.save_hyperparameters()
|
||||
|
||||
def forward(self, batch):
|
||||
return self.model(batch)
|
||||
@@ -66,16 +66,17 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
phase = "train" if train else "val"
|
||||
for loss_name, indiv_loss in loss_breakdown.items():
|
||||
self.log(
|
||||
f"{phase}/{loss_name}",
|
||||
indiv_loss,
|
||||
on_step=train, on_epoch=(not train), logger=True,
|
||||
f"{phase}/{loss_name}",
|
||||
indiv_loss,
|
||||
prog_bar=(loss_name == 'loss'),
|
||||
on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
if (train):
|
||||
self.log(
|
||||
f"{phase}/{loss_name}_epoch",
|
||||
indiv_loss,
|
||||
on_step=False, on_epoch=True, logger=True,
|
||||
on_step=False, on_epoch=True, logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -89,7 +90,8 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
self.log(
|
||||
f"{phase}/{k}",
|
||||
torch.mean(v),
|
||||
on_step=False, on_epoch=True, logger=True
|
||||
prog_bar = (k == 'loss'),
|
||||
on_step=False, on_epoch=True, logger=True, sync_dist=False,
|
||||
)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
@@ -152,8 +154,8 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
)
|
||||
|
||||
self._log(loss_breakdown, batch, outputs, train=False)
|
||||
|
||||
def on_validation_epoch_end(self, _):
|
||||
|
||||
def on_validation_epoch_end(self):
|
||||
# Restore the model weights to normal
|
||||
self.model.load_state_dict(self.cached_weights)
|
||||
self.cached_weights = None
|
||||
@@ -212,15 +214,10 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
|
||||
return metrics
|
||||
|
||||
def configure_optimizers(self,
|
||||
learning_rate: float = 1e-3,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.optim.Adam:
|
||||
# return torch.optim.Adam(
|
||||
# self.model.parameters(),
|
||||
# lr=learning_rate,
|
||||
# eps=eps
|
||||
# )
|
||||
def configure_optimizers(self,
|
||||
learning_rate: float = 1e-3,
|
||||
eps: float = 1e-5,
|
||||
) -> torch.optim.Adam:
|
||||
# Ignored as long as a DeepSpeed optimizer is configured
|
||||
optimizer = torch.optim.Adam(
|
||||
self.model.parameters(),
|
||||
@@ -271,37 +268,69 @@ class OpenFoldWrapper(pl.LightningModule):
|
||||
self.model, jax_path, version=model_version
|
||||
)
|
||||
|
||||
def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
|
||||
latest_path = os.path.join(checkpoint_dir, 'latest')
|
||||
if os.path.isfile(latest_path):
|
||||
with open(latest_path, 'r') as fd:
|
||||
tag = fd.read().strip()
|
||||
else:
|
||||
raise ValueError(f"Unable to find 'latest' file at {latest_path}")
|
||||
|
||||
ds_checkpoint_dir = os.path.join(checkpoint_dir, tag)
|
||||
_DS_CHECKPOINT_VERSION = 2 # based on manual parsing of checkpoint files
|
||||
state_file = zero_to_fp32.get_model_state_file(ds_checkpoint_dir, _DS_CHECKPOINT_VERSION)
|
||||
return torch.load(state_file)
|
||||
|
||||
def main(args):
|
||||
if (args.seed is not None):
|
||||
seed_everything(args.seed)
|
||||
if(args.seed is not None):
|
||||
seed_everything(args.seed, workers=True)
|
||||
|
||||
is_low_precision = args.precision in [
|
||||
"bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
|
||||
|
||||
config = model_config(
|
||||
args.config_preset,
|
||||
train=True,
|
||||
low_prec=(str(args.precision) == "16")
|
||||
)
|
||||
args.config_preset,
|
||||
train=True,
|
||||
low_prec=is_low_precision,
|
||||
)
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
model_module = OpenFoldWrapper(config)
|
||||
|
||||
if (args.resume_from_ckpt):
|
||||
if (os.path.isdir(args.resume_from_ckpt)):
|
||||
last_global_step = get_global_step_from_zero_checkpoint(
|
||||
args.resume_from_ckpt)
|
||||
else:
|
||||
sd = torch.load(args.resume_from_ckpt)
|
||||
if args.resume_from_ckpt:
|
||||
if args.resume_model_weights_only:
|
||||
# Load the checkpoint
|
||||
if os.path.isdir(args.resume_from_ckpt):
|
||||
sd = zero_to_fp32.get_fp32_state_dict_from_zero_checkpoint(
|
||||
args.resume_from_ckpt)
|
||||
else:
|
||||
sd = torch.load(args.resume_from_ckpt)
|
||||
# Process the state dict
|
||||
if 'module' in sd:
|
||||
sd = {k[len('module.'):]: v for k, v in sd['module'].items()}
|
||||
import_openfold_weights_(model=model_module, state_dict=sd)
|
||||
elif 'state_dict' in sd:
|
||||
import_openfold_weights_(
|
||||
model=model_module, state_dict=sd['state_dict'])
|
||||
else:
|
||||
# Loading from pre-trained model
|
||||
sd = {'model.'+k: v for k, v in sd.items()}
|
||||
import_openfold_weights_(model=model_module, state_dict=sd)
|
||||
logging.info("Successfully loaded model weights...")
|
||||
|
||||
else: # Loads a checkpoint to start from a specific time step
|
||||
if os.path.isdir(args.resume_from_ckpt):
|
||||
sd = get_model_state_dict_from_ds_checkpoint(args.resume_from_ckpt)
|
||||
else:
|
||||
sd = torch.load(args.resume_from_ckpt)
|
||||
last_global_step = int(sd['global_step'])
|
||||
model_module.resume_last_lr_step(last_global_step)
|
||||
logging.info("Successfully loaded last lr step...")
|
||||
if (args.resume_from_ckpt and args.resume_model_weights_only):
|
||||
if (os.path.isdir(args.resume_from_ckpt)):
|
||||
sd = get_fp32_state_dict_from_zero_checkpoint(
|
||||
args.resume_from_ckpt)
|
||||
else:
|
||||
sd = torch.load(args.resume_from_ckpt)
|
||||
sd = {k[len("module."):]: v for k, v in sd.items()}
|
||||
import_openfold_weights_(model=model_module, state_dict=sd)
|
||||
logging.info("Successfully loaded model weights...")
|
||||
if (args.resume_from_jax_params):
|
||||
model_module.resume_last_lr_step(last_global_step)
|
||||
logging.info("Successfully loaded last lr step...")
|
||||
|
||||
if args.resume_from_jax_params:
|
||||
model_module.load_from_jax(args.resume_from_jax_params)
|
||||
logging.info(
|
||||
f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
|
||||
@@ -360,7 +389,20 @@ def main(args):
|
||||
callbacks.append(lr_monitor)
|
||||
|
||||
loggers = []
|
||||
if (args.wandb):
|
||||
is_rank_zero = args.mpi_plugin and (int(os.environ.get("PMI_RANK")) == 0)
|
||||
if(args.wandb):
|
||||
if args.mpi_plugin and is_rank_zero:
|
||||
wandb_init_dict = dict(
|
||||
name=args.experiment_name,
|
||||
project=args.wandb_project,
|
||||
id=args.wandb_id,
|
||||
dir=args.output_dir,
|
||||
resume="allow",
|
||||
anonymous=None,
|
||||
entity=args.wandb_entity
|
||||
)
|
||||
wandb.run = wandb.init(**wandb_init_dict)
|
||||
|
||||
wdb_logger = WandbLogger(
|
||||
name=args.experiment_name,
|
||||
save_dir=args.output_dir,
|
||||
@@ -370,27 +412,28 @@ def main(args):
|
||||
)
|
||||
loggers.append(wdb_logger)
|
||||
|
||||
if (args.deepspeed_config_path is not None):
|
||||
cluster_environment = MPIEnvironment() if args.mpi_plugin else None
|
||||
if(args.deepspeed_config_path is not None):
|
||||
strategy = DeepSpeedStrategy(
|
||||
config=args.deepspeed_config_path,
|
||||
cluster_environment=cluster_environment,
|
||||
)
|
||||
if (args.wandb):
|
||||
if(args.wandb and is_rank_zero):
|
||||
wdb_logger.experiment.save(args.deepspeed_config_path)
|
||||
wdb_logger.experiment.save("openfold/config.py")
|
||||
elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
|
||||
strategy = DDPStrategy(find_unused_parameters=False)
|
||||
strategy = DDPStrategy(find_unused_parameters=False,
|
||||
cluster_environment=cluster_environment)
|
||||
else:
|
||||
strategy = None
|
||||
|
||||
if (args.wandb):
|
||||
|
||||
if(args.wandb and is_rank_zero):
|
||||
freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
|
||||
os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
|
||||
wdb_logger.experiment.save(f"{freeze_path}")
|
||||
|
||||
# Raw dump of all args from pl.Trainer constructor
|
||||
trainer_kws = set([
|
||||
'accelerator', 'strategy', 'devices', 'num_nodes', 'precision', 'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs', 'max_steps', 'min_steps', 'max_tim', 'limit_train_batches', 'limit_val_batches', 'limit_test_batches', 'limit_predict_batches', 'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch', 'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing', 'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches', 'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic', 'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler', 'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm', 'reload_dataloaders_every_n_epochs', 'default_root_dir',
|
||||
])
|
||||
trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
|
||||
'flush_logs_ever_n_steps', 'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
|
||||
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
|
||||
trainer_args.update({
|
||||
'default_root_dir': args.output_dir,
|
||||
@@ -400,6 +443,7 @@ def main(args):
|
||||
})
|
||||
trainer = pl.Trainer(**trainer_args)
|
||||
|
||||
|
||||
if (args.resume_model_weights_only):
|
||||
ckpt_path = None
|
||||
else:
|
||||
@@ -606,44 +650,41 @@ if __name__ == "__main__":
|
||||
help="Distillation alignment index. See the README for instructions."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experiment_config_json", default="", help="Path to a json file with custom config values to overwrite config setting",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpus", type=int, default=1, help='For determining optimal strategy and effective batch size.'
|
||||
)
|
||||
parser.add_argument("--mpi_plugin", action="store_true", default=False,
|
||||
help="Whether to use MPI for parallele processing")
|
||||
|
||||
trainer_group = parser.add_argument_group(
|
||||
'Arguments to pass to PyTorch Lightning Trainer')
|
||||
trainer_group.add_argument(
|
||||
"--num_nodes", type=int, default=1,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpus", type=int, default=1,
|
||||
trainer_group.add_argument(
|
||||
"--precision", type=str, default='bf16',
|
||||
help='Sets precision, lower precision improves runtime performance.',
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", type=str, default=None,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--replace_sampler_ddp", type=bool_type, default=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
trainer_group.add_argument(
|
||||
"--max_epochs", type=int, default=1,
|
||||
)
|
||||
parser.add_argument(
|
||||
trainer_group.add_argument(
|
||||
"--log_every_n_steps", type=int, default=25,
|
||||
)
|
||||
parser.add_argument(
|
||||
trainer_group.add_argument(
|
||||
"--flush_logs_every_n_steps", type=int, default=5,
|
||||
)
|
||||
trainer_group.add_argument(
|
||||
"--num_sanity_val_steps", type=int, default=0,
|
||||
)
|
||||
|
||||
# parser = pl.Trainer.add_argparse_args(parser)
|
||||
#
|
||||
# # Disable the initial validation pass
|
||||
# parser.set_defaults(
|
||||
# num_sanity_val_steps=0,
|
||||
# )
|
||||
|
||||
# # Remove some buggy/redundant arguments introduced by the Trainer
|
||||
# remove_arguments(
|
||||
# parser,
|
||||
# [
|
||||
# "--accelerator",
|
||||
# "--resume_from_checkpoint",
|
||||
# "--reload_dataloaders_every_epoch",
|
||||
# "--reload_dataloaders_every_n_epochs",
|
||||
# ]
|
||||
# )
|
||||
trainer_group.add_argument(
|
||||
"--reload_dataloaders_every_n_epochs", type=int, default=1,
|
||||
)
|
||||
trainer_group.add_argument(
|
||||
"--accumulate_grad_batches", type=int, default=1,
|
||||
help="Accumulate gradients over k batches before next optimizer step.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -659,7 +700,4 @@ if __name__ == "__main__":
|
||||
raise ValueError(
|
||||
"Choose between loading pretrained Jax-weights and a checkpoint-path")
|
||||
|
||||
# This re-applies the training-time filters at the beginning of every epoch
|
||||
args.reload_dataloaders_every_n_epochs = 1
|
||||
|
||||
main(args)
|
||||
|
||||
Reference in New Issue
Block a user