mirror of
https://github.com/gcorso/DiffDock.git
synced 2026-06-04 09:54:21 +08:00
first commit v1.1
This commit is contained in:
218
README.md
218
README.md
@@ -1,52 +1,75 @@
|
||||
# DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking
|
||||
[](https://paperswithcode.com/sota/blind-docking-on-pdbbind?p=diffdock-diffusion-steps-twists-and-turns-for)
|
||||
[](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web)
|
||||
|
||||
### [Paper on arXiv](https://arxiv.org/abs/2210.01776)
|
||||
|
||||

|
||||
|
||||
### [Original paper on arXiv](https://arxiv.org/abs/2210.01776)
|
||||
|
||||
Implementation of DiffDock, state-of-the-art method for molecular docking, by Gabriele Corso*, Hannes Stark*, Bowen Jing*, Regina Barzilay and Tommi Jaakkola.
|
||||
This repository contains all code, instructions and model weights necessary to run the method or to retrain a model.
|
||||
This repository contains code and instructions to run the method.
|
||||
If you have any question, feel free to open an issue or reach out to us: [gcorso@mit.edu](gcorso@mit.edu), [hstark@mit.edu](hstark@mit.edu), [bjing@mit.edu](bjing@mit.edu).
|
||||
|
||||

|
||||
**Update February 2024:** We have released DiffDock-L, a new version of DiffDock that provides a significant improvement in performance and generalization capacity (see the description of the new version in [our new paper]()). By default the repository now runs the new model, please use GitHub commit history to run the original DiffDock model. Further we now provide instructions for Docker and to set up your own local UI interface.
|
||||
|
||||
The repository also contains all the scripts to run the baselines and generate the figures.
|
||||
Additionally, there are visualization videos in `visualizations`.
|
||||
|
||||
You might also be interested in this [Google Colab notebook](https://colab.research.google.com/drive/1CTtUGg05-2MtlWmfJhqzLTtkDDaxCDOQ#scrollTo=zlPOKLIBsiPU) to run DiffDock by Brian Naughton.
|
||||
|
||||
# Dataset
|
||||
|
||||
The files in `data` contain the names for the time-based data split.
|
||||
|
||||
If you want to train one of our models with the data then:
|
||||
1. download it from [zenodo](https://zenodo.org/record/6408497)
|
||||
2. unzip the directory and place it into `data` such that you have the path `data/PDBBind_processed`
|
||||
You can also try out the model on [Hugging Face Spaces](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web).
|
||||
|
||||
|
||||
|
||||
## Setup Environment
|
||||
<details><summary><b>Citation</b></summary>
|
||||
|
||||
We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html). Clone the
|
||||
current repo
|
||||
If you use this code or the models in your research, please cite the following paper:
|
||||
|
||||
git clone https://github.com/gcorso/DiffDock.git
|
||||
```bibtex
|
||||
@inproceedings{corso2023diffdock,
|
||||
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
|
||||
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
This is an example for how to set up a working conda environment to run the code (but make sure to use the correct pytorch, pytorch-geometric, cuda versions or cpu only versions):
|
||||
If you use the latest version, DiffDock-L, please also cite the following paper:
|
||||
|
||||
conda create --name diffdock python=3.9
|
||||
conda activate diffdock
|
||||
conda install pytorch==1.11.0 pytorch-cuda=11.7 -c pytorch -c nvidia
|
||||
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.0.4 -f https://data.pyg.org/whl/torch-1.11.0+cu117.html
|
||||
python -m pip install PyYAML scipy "networkx[default]" biopython rdkit-pypi e3nn spyrmsd pandas biopandas
|
||||
```bibtex
|
||||
@inproceedings{corso2024discovery,
|
||||
title={Deep Confident Steps to New Pockets: Strategies for Docking Generalization},
|
||||
author={Corso, Gabriele and Deng, Arthur and Polizzi, Nicholas and Barzilay, Regina and Jaakkola, Tommi},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
Then you need to install ESM that we use both for protein sequence embeddings and for the protein structure prediction in case you only have the sequence of your target. Note that OpenFold (and so ESMFold) requires a GPU. If you don't have a GPU, you can still use DiffDock with existing protein structures.
|
||||
```
|
||||
</details>
|
||||
|
||||
pip install "fair-esm[esmfold]"
|
||||
pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
|
||||
pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'
|
||||
<details open><summary><b>Table of contents</b></summary>
|
||||
|
||||
- [Usage](#usage)
|
||||
- [Quick Start](#quickstart)
|
||||
- [Setup Environment](#environment)
|
||||
- [Docking Prediction](#inference)
|
||||
- [FAQ](#faq)
|
||||
- [Datasets](#datasets)
|
||||
- [Replicate results](#replicate)
|
||||
- [Citations](#citations)
|
||||
- [License](#license)
|
||||
- [Acknowledgements](#acknowledgements)
|
||||
</details>
|
||||
|
||||
|
||||
# Running DiffDock on your own complexes
|
||||
|
||||
## Usage <a name="usage"></a>
|
||||
|
||||
### Quick Start <a name="quickstart"></a>
|
||||
|
||||
You can directly try out the model without the need of installing anything through [Hugging Face Spaces](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web). Credit for the current HF interface goes to Jacob Silterra and for the previous version to Simon Duerr.
|
||||
|
||||
### Setup Environment <a name="environment"></a>
|
||||
|
||||
|
||||
### Docking Prediction <a name="inference"></a>
|
||||
|
||||
We support multiple input formats depending on whether you only want to make predictions for a single complex or for many at once.\
|
||||
The protein inputs need to be `.pdb` files or sequences that will be folded with ESMFold. The ligand input can either be a SMILES string or a filetype that RDKit can read like `.sdf` or `.mol2`.
|
||||
|
||||
@@ -57,18 +80,60 @@ An example .csv is at `data/protein_ligand_example_csv.csv` and you would use it
|
||||
|
||||
And you are ready to run inference:
|
||||
|
||||
python -m inference --protein_ligand_csv data/protein_ligand_example_csv.csv --out_dir results/user_predictions_small --inference_steps 20 --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
|
||||
python -m inference --config inference_args.yaml --protein_ligand_csv data/protein_ligand_example_csv.csv --out_dir results/user_predictions_small
|
||||
|
||||
When providing the `.pdb` files you can run DiffDock also on CPU, however, if possible, we recommend using a GPU as the model runs significantly faster. Note that the first time you run DiffDock on a device the program will precompute and store in cache look-up tables for SO(2) and SO(3) distributions (typically takes a couple of minutes), this won't be repeated in following runs.
|
||||
|
||||
|
||||
# Retraining DiffDock
|
||||
Download the data and place it as described in the "Dataset" section above.
|
||||
## FAQ <a name="faq"></a>
|
||||
|
||||
<details>
|
||||
<summary><b>How to interpret the DiffDock output confidence score?</b> </summary>
|
||||
It can be hard to interpret and compare confidence score of different complexes or different protein conformations, however, here a rough guideline that we typically use (c is the confidence score of the top pose):
|
||||
|
||||
- c > 0 high confidence
|
||||
- -1.5 < c < 0 moderate confidence
|
||||
- c < -1.5 low confidence
|
||||
|
||||
This is assuming the complex is similar to what DiffDock saw in the training set i.e. a not too large drug-like molecule bound to medium size protein (1 or 2 chains) in a conformation that is similar to the bound one (e.g. if it comes from an homologue crystal structure). If you are dealing with a large ligand, a large protein complex and/or an app/unbound protein conformation you should shift these intervals down.
|
||||
</details>
|
||||
|
||||
|
||||
<details>
|
||||
<summary><b>Does DiffDock predict the binding affinity of the ligand to the protein?</b> </summary>
|
||||
No, DiffDock does not predict the binding affinity of the ligand to the protein. It predicts the 3D structure of the complex and it outputs a confidence score. This latter is a measure of the quality of the prediction, i.e. the model's confidence in its prediction of the binding structure. Several of our collaborators have seen this to have some correlation with binding affinity (intuitively if a ligand does not bind there will be no good pose), but it is not a direct measure of it.
|
||||
|
||||
We are working on better affinity prediction models, but in the meantime we recommend combining DiffDock's prediction with other tools such as docking function (e.g. GNINA), MM/GBSA or absolute binding free energy calculations. For this we recommend to first relax the DiffDock's structure predictions with the tool/force field used for the affinity prediction.
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>Can I use DiffDock for protein-protein or protein-nucleic acid interactions?</b> </summary>
|
||||
While the program might not throw and error when fed with a large biomolecules as input, the model has only been designed, trained and tested for small molecule docking to proteins. Therefore, DiffDock is only likely to be able to deal with small peptides and nucleic acids as ligands, we do not recommend using DiffDock for the interactions of larger biomolecules. For other interactions we recommend looking at [DiffDock-PP](https://github.com/ketatam/DiffDock-PP) (rigid protein-protein interactions), [AlphaFold-Multimer](https://github.com/google-deepmind/alphafold) (flexible protein-protein interactions) or [RoseTTAFold2NA](https://github.com/uw-ipd/RoseTTAFold2NA) (protein-nucleic acid interactions).
|
||||
</details>
|
||||
|
||||
## Datasets <a name="datasets"></a>
|
||||
|
||||
The files in `data` contain the splits used for the various datasets. Below instructions for how to download each of the different datasets used for training and evaluation:
|
||||
|
||||
- **PDBBind:** download the processed complexes from [zenodo](https://zenodo.org/record/6408497), unzip the directory and place it into `data` such that you have the path `data/PDBBind_processed`.
|
||||
- **BindingMOAD:** download the processed complexes from [zenodo](https://zenodo.org/records/10656052) under `BindingMOAD_2020_processed.tar`, unzip the directory and place it into `data` such that you have the path `data/BindingMOAD_2020_processed`.
|
||||
- **DockGen:** to evaluate the performance of `DiffDock-L` with this repository you should use directly the data from BindingMOAD above. For other purposes you can download exclusively the complexes of the DockGen benchmark already processed (e.g. chain cutoff) from [zenodo](https://zenodo.org/records/10656052) downloading the `DockGen.tar` file.
|
||||
- **PoseBusters:** download the processed complexes from [zenodo](https://zenodo.org/records/8278563).
|
||||
- **van der Mers:** the protein structures used for the van der Mers data augmentation strategy were downloaded [here](https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz).
|
||||
|
||||
|
||||
## Replicate results <a name="replicate"></a>
|
||||
|
||||
If you are interested in replicating the results of the original DiffDock paper please checkout to the following commit:
|
||||
|
||||
git checkout v1.0
|
||||
|
||||
Otherwise download the data and place it as described in the "Dataset" section above.
|
||||
|
||||
### Generate the ESM2 embeddings for the proteins
|
||||
First run:
|
||||
To avoid having to compute ESM embeddings every time we evaluate on a dataset we first cache them and then run the evaluation script. Here the instructions for generating these for PDBBind but it also applies similarly to the other benchmarks. First run the following command to save the list of ESM embeddings:
|
||||
|
||||
python datasets/pdbbind_lm_embedding_preparation.py
|
||||
python datasets/esm_embedding_preparation.py
|
||||
|
||||
Use the generated file `data/pdbbind_sequences.fasta` to generate the ESM2 language model embeddings using the library https://github.com/facebookresearch/esm by installing their repository and executing the following in their repository:
|
||||
|
||||
@@ -79,65 +144,54 @@ Then run the command:
|
||||
|
||||
python datasets/esm_embeddings_to_pt.py
|
||||
|
||||
### Using the provided model weights for evaluation
|
||||
We first generate the language model embeddings for the testset, then run inference with DiffDock, and then evaluate the files that DiffDock produced:
|
||||
### Run DiffDock-L
|
||||
|
||||
python datasets/esm_embedding_preparation.py --protein_ligand_csv data/testset_csv.csv --out_file data/prepared_for_esm_testset.fasta
|
||||
git clone https://github.com/facebookresearch/esm
|
||||
cd esm
|
||||
pip install -e .
|
||||
cd ..
|
||||
HOME=esm/model_weights python esm/scripts/extract.py esm2_t33_650M_UR50D data/prepared_for_esm_testset.fasta data/esm2_output --repr_layers 33 --include per_tok
|
||||
python -m inference --protein_ligand_csv data/testset_csv.csv --out_dir results/user_predictions_testset --inference_steps 20 --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
|
||||
python evaluate_files.py --results_path results/user_predictions_testset --file_to_exclude rank1.sdf --num_predictions 40
|
||||
For PDBBind:
|
||||
|
||||
<!--
|
||||
To predict binding structures using the provided model weights run:
|
||||
python -m evaluate --config inference_args.yaml --split_path data/splits/timesplit_test --split_path data/splits/timesplit_test --batch_size 10 --esm_embeddings_path data/esm2_embeddings.pt --data_dir data/PDBBind_processed/ --tqdm --split test --chain_cutoff 10 --dataset pdbbind
|
||||
|
||||
python -m evaluate --model_dir workdir/paper_score_model --ckpt best_ema_inference_epoch_model.pt --confidence_ckpt best_model_epoch75.pt --confidence_model_dir workdir/paper_confidence_model --run_name DiffDockInference --inference_steps 20 --split_path data/splits/timesplit_test --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
|
||||
For DockGen:
|
||||
|
||||
To additionally save the .sdf files of the generated molecules, add the flag `--save_visualisation`
|
||||
-->
|
||||
### Training a model yourself and using those weights
|
||||
Train the large score model:
|
||||
python -m evaluate --config inference_args.yaml --dataset moad --data_dir data/BindingMOAD_2020_processed --unroll_clusters --tqdm --split test --esm_embeddings_path data/moad_esm2_embeddings.pt --min_ligand_size 2 --moad_esm_embeddings_sequences_path data/moad_sequences_to_id.fasta --chain_cutoff 10 --batch_size 10
|
||||
|
||||
python -m train --run_name big_score_model --test_sigma_intervals --esm_embeddings_path data/esm2_3billion_embeddings.pt --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 19 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --batch_size 16 --ns 48 --nv 10 --num_conv_layers 6 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --val_inference_freq 5 --num_inference_complexes 500 --use_ema --distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 --scheduler_patience 30 --n_epochs 850
|
||||
For PoseBusters:
|
||||
|
||||
The model weights are saved in the `workdir` directory.
|
||||
python -m evaluate --config inference_args.yaml --data_dir data/posebusters_benchmark_set --tqdm --dataset posebusters --split_path data/splits/posebusters_benchmark_set_ids.txt --esm_embeddings_path data/posebusters_ESM.pt --chain_cutoff 10 --batch_size 10 --protein_file protein --ligand_file ligands
|
||||
|
||||
Train a small score model with higher maximum translation sigma that will be used to generate the samples for training the confidence model:
|
||||
To additionally save the .sdf files of the generated molecules, add the flag `--save_visualisation`.
|
||||
|
||||
python -m train --run_name small_score_model --test_sigma_intervals --esm_embeddings_path data/esm2_3billion_embeddings.pt --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 34 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --batch_size 16 --ns 24 --nv 6 --num_conv_layers 5 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --val_inference_freq 5 --num_inference_complexes 500 --use_ema --scheduler_patience 30 --n_epochs 300
|
||||
Note: the notebook `data/apo_alignment.ipynb` contains the code used to align the ESMFold-generated apo-structures to the holo-structures.
|
||||
|
||||
In practice, you could also likely achieve the same or better results by using the first score model for creating the samples to train the confidence model, but this is what we did in the paper.
|
||||
The score model used to generate the samples to train the confidence model does not have to be the same as the score model that is used with that confidence model during inference.
|
||||
## Citations <a name="citations"></a>
|
||||
If you use this code or the models in your research, please cite the following paper:
|
||||
|
||||
Train the confidence model by running the following:
|
||||
```bibtex
|
||||
@inproceedings{corso2023diffdock,
|
||||
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
|
||||
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2023}
|
||||
}
|
||||
```
|
||||
|
||||
python -m confidence.confidence_train --original_model_dir workdir/small_score_model --run_name confidence_model --inference_steps 20 --samples_per_complex 7 --batch_size 16 --n_epochs 100 --lr 3e-4 --scheduler_patience 50 --ns 24 --nv 6 --num_conv_layers 5 --dynamic_max_cross --scale_by_sigma --dropout 0.1 --all_atoms --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --esm_embeddings_path data/esm2_3billion_embeddings.pt --main_metric loss --main_metric_goal min --best_model_save_frequency 5 --rmsd_classification_cutoff 2 --cache_creation_id 1 --cache_ids_to_combine 1 2 3 4
|
||||
If you use the latest version of our model, DiffDock-L, please also cite the following paper:
|
||||
|
||||
first with `--cache_creation_id 1` then `--cache_creation_id 2` etc. up to 4
|
||||
```bibtex
|
||||
@inproceedings{corso2024discovery,
|
||||
title={Deep Confident Steps to New Pockets: Strategies for Docking Generalization},
|
||||
author={Corso, Gabriele and Deng, Arthur and Polizzi, Nicholas and Barzilay, Regina and Jaakkola, Tommi},
|
||||
booktitle={International Conference on Learning Representations (ICLR)},
|
||||
year={2024}
|
||||
}
|
||||
```
|
||||
|
||||
Now everything is trained and you can run inference with:
|
||||
## License <a name="license"></a>
|
||||
The code and model weights are released under MIT license. See the [LICENSE](LICENSE) file for details.
|
||||
|
||||
python -m evaluate --model_dir workdir/big_score_model --ckpt best_ema_inference_epoch_model.pt --confidence_ckpt best_model_epoch75.pt --confidence_model_dir workdir/confidence_model --run_name DiffDockInference --inference_steps 20 --split_path data/splits/timesplit_test --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
|
||||
Components of the code of the [spyrmsd](spyrmsd) package by Rocco Meli (also MIT license) were integrated in the repo.
|
||||
|
||||
Note: the notebook `data/apo_alignment.ipynb` contains the code used to align the ESMFold-generated apo-structures to the holo-structures.
|
||||
|
||||
## Citation
|
||||
@article{corso2023diffdock,
|
||||
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
|
||||
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
|
||||
journal={International Conference on Learning Representations (ICLR)},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
## License
|
||||
MIT
|
||||
|
||||
## Acknowledgements
|
||||
|
||||
We thank Wei Lu and Rachel Wu for pointing out some issues with the code.
|
||||
|
||||
|
||||

|
||||
## Acknowledgements <a name="acknowledgements"></a>
|
||||
We sincerely thank:
|
||||
* Jacob Silterra for his help with the publishing and deployment of the code.
|
||||
* Arthur Deng, Nicholas Polizzi and Ben Fry for their critical contributions to part of the code in this repository.
|
||||
* Wei Lu and Rachel Wu for pointing out some issues with the code.
|
||||
@@ -1,219 +0,0 @@
|
||||
# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as initial pose
|
||||
import os
|
||||
|
||||
import plotly.express as px
|
||||
import time
|
||||
from argparse import FileType, ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import wandb
|
||||
from biopandas.pdb import PandasPdb
|
||||
from rdkit import Chem
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.pdbbind import read_mol
|
||||
from datasets.process_mols import read_molecule
|
||||
from utils.utils import read_strings_from_txt, get_symmetry_rmsd
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--config', type=FileType(mode='r'), default=None)
|
||||
parser.add_argument('--run_name', type=str, default='gnina_results', help='')
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--results_path', type=str, default='results/user_inference', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand.pdb', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
|
||||
parser.add_argument('--wandb', action='store_true', default=False, help='')
|
||||
parser.add_argument('--file_to_exclude', type=str, default=None, help='')
|
||||
parser.add_argument('--all_dirs_in_results', action='store_true', default=True, help='Evaluate all directories in the results path instead of using directly looking for the names')
|
||||
parser.add_argument('--num_predictions', type=int, default=10, help='')
|
||||
parser.add_argument('--no_id_in_filename', action='store_true', default=False, help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
print('Reading paths and names.')
|
||||
names = read_strings_from_txt(f'data/splits/timesplit_test')
|
||||
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
|
||||
results_path_containments = os.listdir(args.results_path)
|
||||
|
||||
if args.wandb:
|
||||
wandb.init(
|
||||
entity='coarse-graining-mit',
|
||||
settings=wandb.Settings(start_method="fork"),
|
||||
project=args.project,
|
||||
name=args.run_name,
|
||||
config=args
|
||||
)
|
||||
|
||||
all_times = []
|
||||
successful_names_list = []
|
||||
rmsds_list = []
|
||||
centroid_distances_list = []
|
||||
min_cross_distances_list = []
|
||||
min_self_distances_list = []
|
||||
without_rec_overlap_list = []
|
||||
start_time = time.time()
|
||||
for i, name in enumerate(tqdm(names)):
|
||||
mol = read_mol(args.data_dir, name, remove_hs=True)
|
||||
mol = Chem.RemoveAllHs(mol)
|
||||
orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
|
||||
|
||||
if args.all_dirs_in_results:
|
||||
directory_with_name = [directory for directory in results_path_containments if name in directory][0]
|
||||
ligand_pos = []
|
||||
for i in range(args.num_predictions):
|
||||
file_paths = os.listdir(os.path.join(args.results_path, directory_with_name))
|
||||
file_path = [path for path in file_paths if f'rank{i+1}' in path][0]
|
||||
if args.file_to_exclude is not None and args.file_to_exclude in file_path: continue
|
||||
mol_pred = read_molecule(os.path.join(args.results_path, directory_with_name, file_path),remove_hs=True, sanitize=True)
|
||||
mol_pred = Chem.RemoveAllHs(mol_pred)
|
||||
ligand_pos.append(mol_pred.GetConformer().GetPositions())
|
||||
ligand_pos = np.asarray(ligand_pos)
|
||||
else:
|
||||
if not os.path.exists(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}')): raise Exception('path did not exists:', os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'))
|
||||
mol_pred = read_molecule(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'), remove_hs=True, sanitize=True)
|
||||
if mol_pred == None:
|
||||
print("Skipping ", name, ' because RDKIT could not read it.')
|
||||
continue
|
||||
mol_pred = Chem.RemoveAllHs(mol_pred)
|
||||
ligand_pos = np.asarray([np.array(mol_pred.GetConformer(i).GetPositions()) for i in range(args.num_predictions)])
|
||||
try:
|
||||
rmsd = get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred)
|
||||
except Exception as e:
|
||||
print("Using non corrected RMSD because of the error:", e)
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
|
||||
|
||||
rmsds_list.append(rmsd)
|
||||
centroid_distances_list.append(np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1))
|
||||
|
||||
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
|
||||
if not os.path.exists(rec_path):
|
||||
rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
|
||||
rec = PandasPdb().read_pdb(rec_path)
|
||||
rec_df = rec.df['ATOM']
|
||||
receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
|
||||
receptor_pos = np.tile(receptor_pos, (args.num_predictions, 1, 1))
|
||||
|
||||
cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
|
||||
self_distances = np.linalg.norm(ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
|
||||
self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
|
||||
min_cross_distances_list.append(np.min(cross_distances, axis=(1,2)))
|
||||
min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
|
||||
successful_names_list.append(name)
|
||||
without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
|
||||
performance_metrics = {}
|
||||
for overlap in ['', 'no_overlap_']:
|
||||
if 'no_overlap_' == overlap:
|
||||
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
|
||||
rmsds = np.array(rmsds_list)[without_rec_overlap]
|
||||
centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
|
||||
min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
|
||||
min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
|
||||
successful_names = np.array(successful_names_list)[without_rec_overlap]
|
||||
else:
|
||||
rmsds = np.array(rmsds_list)
|
||||
centroid_distances = np.array(centroid_distances_list)
|
||||
min_cross_distances = np.array(min_cross_distances_list)
|
||||
min_self_distances = np.array(min_self_distances_list)
|
||||
successful_names = np.array(successful_names_list)
|
||||
|
||||
np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}names.npy'), successful_names)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), np.array(min_cross_distances))
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), np.array(min_self_distances))
|
||||
|
||||
performance_metrics.update({
|
||||
f'{overlap}steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / args.num_predictions).__round__(2),
|
||||
f'{overlap}self_intersect_fraction': (100 * (min_self_distances < 0.4).sum() / len(min_self_distances) / args.num_predictions).__round__(2),
|
||||
f'{overlap}mean_rmsd': rmsds[:,0].mean(),
|
||||
f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
|
||||
f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
|
||||
f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
|
||||
|
||||
f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
|
||||
f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
|
||||
f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
|
||||
f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
|
||||
})
|
||||
|
||||
top5_rmsds = np.min(rmsds[:, :5], axis=1)
|
||||
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
|
||||
f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
|
||||
f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
|
||||
top10_rmsds = np.min(rmsds[:, :10], axis=1)
|
||||
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
|
||||
f'{overlap}top10_steric_clash_fraction': ( 100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
|
||||
f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
|
||||
f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
|
||||
})
|
||||
for k in performance_metrics:
|
||||
print(k, performance_metrics[k])
|
||||
|
||||
if args.wandb:
|
||||
wandb.log(performance_metrics)
|
||||
histogram_metrics_list = [('rmsd', rmsds[:,0]),
|
||||
('centroid_distance', centroid_distances[:,0]),
|
||||
('mean_rmsd', rmsds[:,0]),
|
||||
('mean_centroid_distance', centroid_distances[:,0])]
|
||||
histogram_metrics_list.append(('top5_rmsds', top5_rmsds))
|
||||
histogram_metrics_list.append(('top5_centroid_distances', top5_centroid_distances))
|
||||
histogram_metrics_list.append(('top10_rmsds', top10_rmsds))
|
||||
histogram_metrics_list.append(('top10_centroid_distances', top10_centroid_distances))
|
||||
|
||||
os.makedirs(f'.plotly_cache/baseline_cache', exist_ok=True)
|
||||
images = []
|
||||
for metric_name, metric in histogram_metrics_list:
|
||||
d = {args.results_path: metric}
|
||||
df = pd.DataFrame(data=d)
|
||||
fig = px.ecdf(df, width=900, height=600, range_x=[0, 40])
|
||||
fig.add_vline(x=2, annotation_text='2 A;', annotation_font_size=20, annotation_position="top right",
|
||||
line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
|
||||
fig.add_vline(x=5, annotation_text='5 A;', annotation_font_size=20, annotation_position="top right",
|
||||
line_dash='dash', line_color='green', annotation_font_color='green')
|
||||
fig.update_xaxes(title=f'{metric_name} in Angstrom', title_font={"size": 20}, tickfont={"size": 20})
|
||||
fig.update_yaxes(title=f'Fraction of predictions with lower error', title_font={"size": 20},
|
||||
tickfont={"size": 20})
|
||||
fig.update_layout(autosize=False, margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, plot_bgcolor='white',
|
||||
paper_bgcolor='white', legend_title_text='Method', legend_title_font_size=17,
|
||||
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ), )
|
||||
fig.update_xaxes(showgrid=True, gridcolor='lightgrey')
|
||||
fig.update_yaxes(showgrid=True, gridcolor='lightgrey')
|
||||
|
||||
fig.write_image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'))
|
||||
wandb.log({metric_name: wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}")})
|
||||
images.append(wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}"))
|
||||
wandb.log({'images': images})
|
||||
@@ -1,175 +0,0 @@
|
||||
# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as
|
||||
# initial pose
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
import time
|
||||
from argparse import ArgumentParser, FileType
|
||||
from datetime import datetime
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from biopandas.pdb import PandasPdb
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import AllChem, MolToPDBFile
|
||||
from scipy.spatial.distance import cdist
|
||||
|
||||
from datasets.pdbbind import read_mol
|
||||
from utils.utils import read_strings_from_txt
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--results_path', type=str, default='results/gnina_predictions', help='')
|
||||
parser.add_argument('--complex_names_path', type=str, default='data/splits/timesplit_test', help='')
|
||||
parser.add_argument('--seed_molecules_path', type=str, default=None, help='Use the molecules at seed molecule path as initialization and only search around them')
|
||||
parser.add_argument('--seed_molecule_filename', type=str, default='equibind_corrected.sdf', help='Use the molecules at seed molecule path as initialization and only search around them')
|
||||
parser.add_argument('--smina', action='store_true', default=False, help='')
|
||||
parser.add_argument('--no_gpu', action='store_true', default=False, help='')
|
||||
parser.add_argument('--exhaustiveness', type=int, default=8, help='')
|
||||
parser.add_argument('--num_cpu', type=int, default=16, help='')
|
||||
parser.add_argument('--pocket_mode', action='store_true', default=False, help='')
|
||||
parser.add_argument('--pocket_cutoff', type=int, default=5, help='')
|
||||
parser.add_argument('--num_modes', type=int, default=10, help='')
|
||||
parser.add_argument('--autobox_add', type=int, default=4, help='')
|
||||
parser.add_argument('--use_p2rank_pocket', action='store_true', default=False, help='')
|
||||
parser.add_argument('--skip_p2rank', action='store_true', default=False, help='')
|
||||
parser.add_argument('--prank_path', type=str, default='/Users/hstark/projects/p2rank_2.3/prank', help='')
|
||||
parser.add_argument('--skip_existing', action='store_true', default=False, help='')
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
class Logger(object):
|
||||
def __init__(self, logpath, syspart=sys.stdout):
|
||||
self.terminal = syspart
|
||||
self.log = open(logpath, "a")
|
||||
|
||||
def write(self, message):
|
||||
self.terminal.write(message)
|
||||
self.log.write(message)
|
||||
self.log.flush()
|
||||
|
||||
def flush(self):
|
||||
# this flush method is needed for python 3 compatibility.
|
||||
# this handles the flush command by doing nothing.
|
||||
# you might want to specify some extra behavior here.
|
||||
pass
|
||||
|
||||
|
||||
def log(*args):
|
||||
print(f'[{datetime.now()}]', *args)
|
||||
|
||||
|
||||
# parameters
|
||||
names = read_strings_from_txt(args.complex_names_path)
|
||||
|
||||
if os.path.exists(args.results_path) and not args.skip_existing:
|
||||
shutil.rmtree(args.results_path)
|
||||
os.makedirs(args.results_path, exist_ok=True)
|
||||
sys.stdout = Logger(logpath=f'{args.results_path}/gnina.log', syspart=sys.stdout)
|
||||
sys.stderr = Logger(logpath=f'{args.results_path}/error.log', syspart=sys.stderr)
|
||||
|
||||
p2rank_cache_path = "results/.p2rank_cache"
|
||||
if args.use_p2rank_pocket and not args.skip_p2rank:
|
||||
os.makedirs(p2rank_cache_path, exist_ok=True)
|
||||
pdb_files_cache = os.path.join(p2rank_cache_path,'pdb_files')
|
||||
os.makedirs(pdb_files_cache, exist_ok=True)
|
||||
with open(f"{p2rank_cache_path}/pdb_list_p2rank.txt", "w") as out:
|
||||
for name in names:
|
||||
shutil.copy(os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb'), f'{pdb_files_cache}/{name}_protein_processed.pdb')
|
||||
out.write(os.path.join('pdb_files', f'{name}_protein_processed.pdb\n'))
|
||||
cmd = f"bash {args.prank_path} predict {p2rank_cache_path}/pdb_list_p2rank.txt -o {p2rank_cache_path}/p2rank_output -threads 4"
|
||||
os.system(cmd)
|
||||
|
||||
|
||||
all_times = []
|
||||
start_time = time.time()
|
||||
for i, name in enumerate(names):
|
||||
os.makedirs(os.path.join(args.results_path, name), exist_ok=True)
|
||||
log('\n')
|
||||
log(f'complex {i} of {len(names)}')
|
||||
# call gnina to find binding pose
|
||||
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
|
||||
prediction_output_name = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.pdb')
|
||||
log_path = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.log')
|
||||
if args.seed_molecules_path is not None: seed_mol_path = os.path.join(args.seed_molecules_path, name, f'{args.seed_molecule_filename}')
|
||||
if args.skip_existing and os.path.exists(prediction_output_name): continue
|
||||
|
||||
if args.pocket_mode:
|
||||
mol = read_mol(args.data_dir, name, remove_hs=False)
|
||||
rec = PandasPdb().read_pdb(rec_path)
|
||||
rec_df = rec.get(s='c-alpha')
|
||||
rec_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
|
||||
lig_pos = mol.GetConformer().GetPositions()
|
||||
d = cdist(rec_pos, lig_pos)
|
||||
label = np.any(d < args.pocket_cutoff, axis=1)
|
||||
|
||||
if np.any(label):
|
||||
center_pocket = rec_pos[label].mean(axis=0)
|
||||
else:
|
||||
print("No pocket residue below minimum distance ", args.pocket_cutoff, "taking closest at", np.min(d))
|
||||
center_pocket = rec_pos[np.argmin(np.min(d, axis=1)[0])]
|
||||
radius_pocket = np.max(np.linalg.norm(lig_pos - center_pocket[None, :], axis=1))
|
||||
diameter_pocket = radius_pocket * 2
|
||||
center_x = center_pocket[0]
|
||||
size_x = diameter_pocket + 8
|
||||
center_y = center_pocket[1]
|
||||
size_y = diameter_pocket + 8
|
||||
center_z = center_pocket[2]
|
||||
size_z = diameter_pocket + 8
|
||||
|
||||
|
||||
mol_rdkit = read_mol(args.data_dir, name, remove_hs=False)
|
||||
single_time = time.time()
|
||||
|
||||
mol_rdkit.RemoveAllConformers()
|
||||
ps = AllChem.ETKDGv2()
|
||||
id = AllChem.EmbedMolecule(mol_rdkit, ps)
|
||||
if id == -1:
|
||||
print('rdkit pos could not be generated without using random pos. using random pos now.')
|
||||
ps.useRandomCoords = True
|
||||
AllChem.EmbedMolecule(mol_rdkit, ps)
|
||||
AllChem.MMFFOptimizeMolecule(mol_rdkit, confId=0)
|
||||
rdkit_mol_path = os.path.join(args.data_dir, name, f'{name}_rdkit_ligand.pdb')
|
||||
MolToPDBFile(mol_rdkit, rdkit_mol_path)
|
||||
|
||||
fallback_without_p2rank = False
|
||||
if args.use_p2rank_pocket:
|
||||
df = pd.read_csv(f'{p2rank_cache_path}/p2rank_output/{name}_protein_processed.pdb_predictions.csv')
|
||||
rdkit_lig_pos = mol_rdkit.GetConformer().GetPositions()
|
||||
diameter_pocket = np.max(cdist(rdkit_lig_pos, rdkit_lig_pos))
|
||||
size_x = diameter_pocket + args.autobox_add * 2
|
||||
size_y = diameter_pocket + args.autobox_add * 2
|
||||
size_z = diameter_pocket + args.autobox_add * 2
|
||||
if df.empty:
|
||||
fallback_without_p2rank = True
|
||||
else:
|
||||
center_x = df.iloc[0][' center_x']
|
||||
center_y = df.iloc[0][' center_y']
|
||||
center_z = df.iloc[0][' center_z']
|
||||
|
||||
|
||||
|
||||
log(f'processing {rec_path}')
|
||||
if not args.pocket_mode and not args.use_p2rank_pocket or fallback_without_p2rank:
|
||||
return_code = subprocess.run(
|
||||
f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} -o {prediction_output_name} {'--no_gpu' if args.no_gpu else ''} --autobox_ligand {rec_path if args.seed_molecules_path is None else seed_mol_path} --autobox_add {args.autobox_add} --log {log_path} --exhaustiveness {args.exhaustiveness} --cpu {args.num_cpu} {'--cnn_scoring none' if args.smina else ''}",
|
||||
shell=True)
|
||||
else:
|
||||
return_code = subprocess.run(
|
||||
f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} -o {prediction_output_name} {'--no_gpu' if args.no_gpu else ''} --log {log_path} --exhaustiveness {args.exhaustiveness} --cpu {args.num_cpu} {'--cnn_scoring none' if args.smina else ''} --center_x {center_x} --center_y {center_y} --center_z {center_z} --size_x {size_x} --size_y {size_y} --size_z {size_z}",
|
||||
shell=True)
|
||||
log(return_code)
|
||||
all_times.append(time.time() - single_time)
|
||||
|
||||
log("single time: --- %s seconds ---" % (time.time() - single_time))
|
||||
log("time so far: --- %s seconds ---" % (time.time() - start_time))
|
||||
log('\n')
|
||||
log(all_times)
|
||||
log("--- %s seconds ---" % (time.time() - start_time))
|
||||
@@ -1,5 +0,0 @@
|
||||
for i in $(seq 0 15); do
|
||||
python baseline_tankbind_runtime.py --parallel_id $i --parallel_tot 16 --prank_path /data/rsg/nlp/hstark/TankBind/packages/p2rank_2.3/prank --data_dir /data/rsg/nlp/hstark/ligbind/data/PDBBind_processed --split_path /data/rsg/nlp/hstark/ligbind/data/splits/timesplit_test --results_path /data/rsg/nlp/hstark/ligbind/results/tankbind_16_worker_runtime --device cpu --skip_p2rank --num_workers 1 --skip_multiple_pocket_outputs &
|
||||
done
|
||||
wait
|
||||
|
||||
@@ -1,239 +0,0 @@
|
||||
|
||||
import copy
|
||||
import os
|
||||
|
||||
import plotly.express as px
|
||||
import time
|
||||
from argparse import FileType, ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import wandb
|
||||
from biopandas.pdb import PandasPdb
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import RemoveHs
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.pdbbind import read_mol
|
||||
from datasets.process_mols import read_molecule, read_sdf_or_mol2
|
||||
from utils.utils import read_strings_from_txt, get_symmetry_rmsd, remove_all_hs
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--config', type=FileType(mode='r'), default=None)
|
||||
parser.add_argument('--run_name', type=str, default='tankbind', help='')
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--renumbered_atoms_dir', type=str, default='../TankBind/examples/tankbind_pdb/renumber_atom_index_same_as_smiles', help='')
|
||||
parser.add_argument('--results_path', type=str, default='results/tankbind_top5', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
|
||||
parser.add_argument('--wandb', action='store_true', default=True, help='')
|
||||
parser.add_argument('--num_predictions', type=int, default=5, help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
names = read_strings_from_txt(f'data/splits/timesplit_test')
|
||||
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
|
||||
|
||||
if args.wandb:
|
||||
wandb.init(
|
||||
entity='coarse-graining-mit',
|
||||
settings=wandb.Settings(start_method="fork"),
|
||||
project=args.project,
|
||||
name=args.run_name,
|
||||
config=args
|
||||
)
|
||||
|
||||
all_times = []
|
||||
rmsds_list = []
|
||||
unsym_rmsds_list = []
|
||||
centroid_distances_list = []
|
||||
min_cross_distances_list = []
|
||||
min_self_distances_list = []
|
||||
made_prediction_list = []
|
||||
steric_clash_list = []
|
||||
without_rec_overlap_list = []
|
||||
|
||||
start_time = time.time()
|
||||
successful_names_list = []
|
||||
for i, name in enumerate(tqdm(names)):
|
||||
mol, _ = read_sdf_or_mol2(f"{args.renumbered_atoms_dir}/{name}.sdf", None)
|
||||
sm = Chem.MolToSmiles(mol)
|
||||
m_order = list(mol.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
|
||||
mol = Chem.RenumberAtoms(mol, m_order)
|
||||
mol = Chem.RemoveHs(mol)
|
||||
orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
|
||||
|
||||
assert(os.path.exists(os.path.join(args.results_path, name, f'{name}_tankbind_0.sdf')))
|
||||
ligand_pos = []
|
||||
for i in range(args.num_predictions):
|
||||
if not os.path.exists(os.path.join(args.results_path, name, f'{name}_tankbind_{i}.sdf')): break
|
||||
mol_pred, _ = read_sdf_or_mol2(os.path.join(args.results_path, name, f'{name}_tankbind_{i}.sdf'),None)
|
||||
sm = Chem.MolToSmiles(mol_pred)
|
||||
m_order = list(mol_pred.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
|
||||
mol_pred = Chem.RenumberAtoms(mol_pred, m_order)
|
||||
mol_pred = RemoveHs(mol_pred)
|
||||
ligand_pos.append(np.array(mol_pred.GetConformer().GetPositions()))
|
||||
ligand_pos = np.asarray(ligand_pos)
|
||||
|
||||
try:
|
||||
unsym_rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
|
||||
rmsd = np.array(get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred))
|
||||
except Exception as e:
|
||||
print("Using non corrected RMSD because of the error:", e)
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
|
||||
|
||||
num_pockets = len(ligand_pos)
|
||||
unsym_rmsds_list.append(np.lib.pad(unsym_rmsd, (0,10-len(unsym_rmsd)), 'constant', constant_values=(0)) )
|
||||
rmsds_list.append(np.lib.pad(rmsd, (0,10-len(rmsd)), 'constant', constant_values=(0)) )
|
||||
centroid_distance = np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1)
|
||||
centroid_distances_list.append(np.lib.pad(centroid_distance, (0,10-len(rmsd)), 'constant', constant_values=(0)) )
|
||||
|
||||
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
|
||||
if not os.path.exists(rec_path):
|
||||
rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
|
||||
rec = PandasPdb().read_pdb(rec_path)
|
||||
rec_df = rec.df['ATOM']
|
||||
receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
|
||||
receptor_pos = np.tile(receptor_pos, (10, 1, 1))
|
||||
|
||||
ligand_pos_padded = np.lib.pad(ligand_pos, ((0,10-len(ligand_pos)), (0,0), (0,0)), 'constant', constant_values=(np.inf))
|
||||
ligand_pos_padded_zero = np.lib.pad(ligand_pos, ((0, 10 - len(ligand_pos)), (0, 0), (0, 0)), 'constant',constant_values=0)
|
||||
cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos_padded[:, None, :, :], axis=-1)
|
||||
self_distances = np.linalg.norm(ligand_pos_padded_zero[:, :, None, :] - ligand_pos_padded_zero[:, None, :, :], axis=-1)
|
||||
self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
|
||||
min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
|
||||
min_cross_distance = np.min(cross_distances, axis=(1, 2))
|
||||
individual_made_prediction = np.lib.pad(np.ones(num_pockets), (0,10-len(rmsd)), 'constant', constant_values=(0))
|
||||
made_prediction_list.append(individual_made_prediction)
|
||||
min_cross_distances_list.append(min_cross_distance)
|
||||
successful_names_list.append(name)
|
||||
without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
|
||||
|
||||
performance_metrics = {}
|
||||
for overlap in ['', 'no_overlap_']:
|
||||
if 'no_overlap_' == overlap:
|
||||
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
|
||||
unsym_rmsds = np.array(unsym_rmsds_list)[without_rec_overlap]
|
||||
rmsds = np.array(rmsds_list)[without_rec_overlap]
|
||||
centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
|
||||
min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
|
||||
min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
|
||||
made_prediction = np.array(made_prediction_list)[without_rec_overlap]
|
||||
successful_names = np.array(successful_names_list)[without_rec_overlap]
|
||||
else:
|
||||
unsym_rmsds = np.array(unsym_rmsds_list)
|
||||
rmsds = np.array(rmsds_list)
|
||||
centroid_distances = np.array(centroid_distances_list)
|
||||
min_cross_distances = np.array(min_cross_distances_list)
|
||||
min_self_distances = np.array(min_self_distances_list)
|
||||
made_prediction = np.array(made_prediction_list)
|
||||
successful_names = np.array(successful_names_list)
|
||||
|
||||
inf_rmsds = copy.deepcopy(rmsds)
|
||||
inf_rmsds[~made_prediction.astype(bool)] = np.inf
|
||||
inf_centroid_distances = copy.deepcopy(centroid_distances)
|
||||
inf_centroid_distances[~made_prediction.astype(bool)] = np.inf
|
||||
|
||||
np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}names.npy'), np.array(successful_names))
|
||||
np.save(os.path.join(args.results_path, f'{overlap}centroid_distances.npy'), centroid_distances)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), min_cross_distances)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), min_self_distances)
|
||||
|
||||
performance_metrics.update({
|
||||
f'{overlap}self_intersect_fraction': (100 * (min_self_distances[:, 0] < 0.4).sum() / len(min_self_distances[:, 0])),
|
||||
f'{overlap}steric_clash_fraction': (100 * (min_cross_distances[:,0] < 0.4).sum() / len(min_cross_distances[:,0])),
|
||||
f'{overlap}mean_rmsd': rmsds[:,0].mean(),
|
||||
f'{overlap}unsym_rmsds_below_2': (100 * (unsym_rmsds[:,0] < 2).sum() / len(unsym_rmsds[:,0])),
|
||||
f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
|
||||
f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
|
||||
f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
|
||||
|
||||
f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
|
||||
f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
|
||||
f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
|
||||
f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
|
||||
})
|
||||
|
||||
top5_rmsds = np.min(inf_rmsds[:, :5], axis=1)
|
||||
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
|
||||
f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
|
||||
f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
|
||||
|
||||
|
||||
top10_rmsds = np.min(inf_rmsds[:, :10], axis=1)
|
||||
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top10_steric_clash_fraction': (100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
|
||||
f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
|
||||
f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
|
||||
})
|
||||
for k in performance_metrics:
|
||||
print(k, performance_metrics[k])
|
||||
|
||||
if args.wandb:
|
||||
wandb.log(performance_metrics)
|
||||
histogram_metrics_list = [('rmsd', rmsds[:,0]),
|
||||
('centroid_distance', centroid_distances[:,0]),
|
||||
('mean_rmsd', rmsds[:,0]),
|
||||
('mean_centroid_distance', centroid_distances[:,0])]
|
||||
histogram_metrics_list.append(('top5_rmsds', top5_rmsds))
|
||||
histogram_metrics_list.append(('top5_centroid_distances', top5_centroid_distances))
|
||||
histogram_metrics_list.append(('top10_rmsds', top10_rmsds))
|
||||
histogram_metrics_list.append(('top10_centroid_distances', top10_centroid_distances))
|
||||
|
||||
os.makedirs(f'.plotly_cache/baseline_cache', exist_ok=True)
|
||||
images = []
|
||||
for metric_name, metric in histogram_metrics_list:
|
||||
d = {args.results_path: metric}
|
||||
df = pd.DataFrame(data=d)
|
||||
fig = px.ecdf(df, width=900, height=600, range_x=[0, 40])
|
||||
fig.add_vline(x=2, annotation_text='2 A;', annotation_font_size=20, annotation_position="top right",
|
||||
line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
|
||||
fig.add_vline(x=5, annotation_text='5 A;', annotation_font_size=20, annotation_position="top right",
|
||||
line_dash='dash', line_color='green', annotation_font_color='green')
|
||||
fig.update_xaxes(title=f'{metric_name} in Angstrom', title_font={"size": 20}, tickfont={"size": 20})
|
||||
fig.update_yaxes(title=f'Fraction of predictions with lower error', title_font={"size": 20},
|
||||
tickfont={"size": 20})
|
||||
fig.update_layout(autosize=False, margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, plot_bgcolor='white',
|
||||
paper_bgcolor='white', legend_title_text='Method', legend_title_font_size=17,
|
||||
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ), )
|
||||
fig.update_xaxes(showgrid=True, gridcolor='lightgrey')
|
||||
fig.update_yaxes(showgrid=True, gridcolor='lightgrey')
|
||||
|
||||
fig.write_image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'))
|
||||
wandb.log({metric_name: wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}")})
|
||||
images.append(wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}"))
|
||||
wandb.log({'images': images})
|
||||
@@ -1,342 +0,0 @@
|
||||
# This file needs to be ran in the TANKBind repository together with baseline_run_tankbind_parallel.sh
|
||||
|
||||
import sys
|
||||
import time
|
||||
from multiprocessing import Pool
|
||||
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from rdkit.Chem import AllChem, RemoveHs
|
||||
|
||||
from feature_utils import save_cleaned_protein, read_mol
|
||||
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
|
||||
import logging
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm # pip install tqdm if fails.
|
||||
from model import get_model
|
||||
# from utils import *
|
||||
import torch
|
||||
|
||||
|
||||
from data import TankBind_prediction
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import rdkit.Chem as Chem
|
||||
from feature_utils import generate_sdf_from_smiles_using_rdkit
|
||||
from feature_utils import get_protein_feature
|
||||
from Bio.PDB import PDBParser
|
||||
from feature_utils import extract_torchdrug_feature_from_mol
|
||||
|
||||
|
||||
def read_strings_from_txt(path):
|
||||
# every line will be one element of the returned list
|
||||
with open(path) as file:
|
||||
lines = file.readlines()
|
||||
return [line.rstrip() for line in lines]
|
||||
|
||||
|
||||
def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
|
||||
if molecule_file.endswith('.mol2'):
|
||||
mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
|
||||
elif molecule_file.endswith('.sdf'):
|
||||
supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
|
||||
mol = supplier[0]
|
||||
elif molecule_file.endswith('.pdbqt'):
|
||||
with open(molecule_file) as file:
|
||||
pdbqt_data = file.readlines()
|
||||
pdb_block = ''
|
||||
for line in pdbqt_data:
|
||||
pdb_block += '{}\n'.format(line[:66])
|
||||
mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
|
||||
elif molecule_file.endswith('.pdb'):
|
||||
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
|
||||
else:
|
||||
return ValueError('Expect the format of the molecule_file to be '
|
||||
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
|
||||
try:
|
||||
if sanitize or calc_charges:
|
||||
Chem.SanitizeMol(mol)
|
||||
|
||||
if calc_charges:
|
||||
# Compute Gasteiger charges on the molecule.
|
||||
try:
|
||||
AllChem.ComputeGasteigerCharges(mol)
|
||||
except:
|
||||
warnings.warn('Unable to compute charges for the molecule.')
|
||||
|
||||
if remove_hs:
|
||||
mol = Chem.RemoveHs(mol, sanitize=sanitize)
|
||||
except:
|
||||
return None
|
||||
|
||||
return mol
|
||||
|
||||
|
||||
def parallel_save_prediction(arguments):
|
||||
dataset, y_pred_list, chosen,rdkit_mol_path, result_folder, name = arguments
|
||||
for idx, line in chosen.iterrows():
|
||||
pocket_name = line['pocket_name']
|
||||
compound_name = line['compound_name']
|
||||
ligandName = compound_name.split("_")[1]
|
||||
dataset_index = line['dataset_index']
|
||||
coords = dataset[dataset_index].coords.to('cpu')
|
||||
protein_nodes_xyz = dataset[dataset_index].node_xyz.to('cpu')
|
||||
n_compound = coords.shape[0]
|
||||
n_protein = protein_nodes_xyz.shape[0]
|
||||
y_pred = y_pred_list[dataset_index].reshape(n_protein, n_compound).to('cpu')
|
||||
compound_pair_dis_constraint = torch.cdist(coords, coords)
|
||||
mol = Chem.MolFromMolFile(rdkit_mol_path)
|
||||
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
|
||||
pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
|
||||
LAS_distance_constraint_mask=LAS_distance_constraint_mask,
|
||||
n_repeat=1, show_progress=False)
|
||||
|
||||
toFile = f'{result_folder}/{name}_tankbind_chosen.sdf'
|
||||
new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
|
||||
write_with_new_coords(mol, new_coords, toFile)
|
||||
|
||||
if __name__ == '__main__':
|
||||
tankbind_src_folder = "../tankbind"
|
||||
sys.path.insert(0, tankbind_src_folder)
|
||||
torch.set_num_threads(16)
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--data_dir', type=str, default='/Users/hstark/projects/ligbind/data/PDBBind_processed', help='')
|
||||
parser.add_argument('--split_path', type=str, default='/Users/hstark/projects/ligbind/data/splits/timesplit_test', help='')
|
||||
parser.add_argument('--prank_path', type=str, default='/Users/hstark/projects/p2rank_2.3/prank', help='')
|
||||
parser.add_argument('--results_path', type=str, default='results/tankbind_results', help='')
|
||||
parser.add_argument('--skip_existing', action='store_true', default=False, help='')
|
||||
parser.add_argument('--skip_p2rank', action='store_true', default=False, help='')
|
||||
parser.add_argument('--skip_multiple_pocket_outputs', action='store_true', default=False, help='')
|
||||
parser.add_argument('--device', type=str, default='cpu', help='')
|
||||
parser.add_argument('--num_workers', type=int, default=1, help='')
|
||||
parser.add_argument('--parallel_id', type=int, default=0, help='')
|
||||
parser.add_argument('--parallel_tot', type=int, default=1, help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
device = args.device
|
||||
cache_path = "tankbind_cache"
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
os.makedirs(args.results_path, exist_ok=True)
|
||||
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
model = get_model(0, logging, device)
|
||||
# re-dock model
|
||||
# modelFile = "../saved_models/re_dock.pt"
|
||||
# self-dock model
|
||||
modelFile = f"{tankbind_src_folder}/../saved_models/self_dock.pt"
|
||||
|
||||
model.load_state_dict(torch.load(modelFile, map_location=device))
|
||||
_ = model.eval()
|
||||
batch_size = 5
|
||||
names = read_strings_from_txt(args.split_path)
|
||||
if args.parallel_tot > 1:
|
||||
size = len(names) // args.parallel_tot + 1
|
||||
names = names[args.parallel_id*size:(args.parallel_id+1)*size]
|
||||
rmsds = []
|
||||
|
||||
forward_pass_time = []
|
||||
times_preprocess = []
|
||||
times_inference = []
|
||||
top_10_generation_time = []
|
||||
top_1_generation_time = []
|
||||
start_time = time.time()
|
||||
if not args.skip_p2rank:
|
||||
for name in names:
|
||||
if args.skip_existing and os.path.exists(f'{args.results_path}/{name}/{name}_tankbind_1.sdf'): continue
|
||||
print("Now processing: ", name)
|
||||
protein_path = f'{args.data_dir}/{name}/{name}_protein_processed.pdb'
|
||||
cleaned_protein_path = f"{cache_path}/{name}_protein_tankbind_cleaned.pdb" # if you change this you also need to change below
|
||||
parser = PDBParser(QUIET=True)
|
||||
s = parser.get_structure(name, protein_path)
|
||||
c = s[0]
|
||||
clean_res_list, ligand_list = save_cleaned_protein(c, cleaned_protein_path)
|
||||
|
||||
with open(f"{cache_path}/pdb_list_p2rank.txt", "w") as out:
|
||||
for name in names:
|
||||
out.write(f"{name}_protein_tankbind_cleaned.pdb\n")
|
||||
cmd = f"bash {args.prank_path} predict {cache_path}/pdb_list_p2rank.txt -o {cache_path}/p2rank -threads 4"
|
||||
os.system(cmd)
|
||||
times_preprocess.append(time.time() - start_time)
|
||||
p2_rank_time = time.time() - start_time
|
||||
|
||||
|
||||
|
||||
|
||||
list_to_parallelize = []
|
||||
for name in tqdm(names):
|
||||
single_preprocess_time = time.time()
|
||||
if args.skip_existing and os.path.exists(f'{args.results_path}/{name}/{name}_tankbind_1.sdf'): continue
|
||||
print("Now processing: ", name)
|
||||
protein_path = f'{args.data_dir}/{name}/{name}_protein_processed.pdb'
|
||||
ligand_path = f"{args.data_dir}/{name}/{name}_ligand.sdf"
|
||||
cleaned_protein_path = f"{cache_path}/{name}_protein_tankbind_cleaned.pdb" # if you change this you also need to change below
|
||||
rdkit_mol_path = f"{cache_path}/{name}_rdkit_ligand.sdf"
|
||||
|
||||
parser = PDBParser(QUIET=True)
|
||||
s = parser.get_structure(name, protein_path)
|
||||
c = s[0]
|
||||
clean_res_list, ligand_list = save_cleaned_protein(c, cleaned_protein_path)
|
||||
lig, _ = read_mol(f"{args.data_dir}/{name}/{name}_ligand.sdf", f"{args.data_dir}/{name}/{name}_ligand.mol2")
|
||||
|
||||
lig = RemoveHs(lig)
|
||||
smiles = Chem.MolToSmiles(lig)
|
||||
generate_sdf_from_smiles_using_rdkit(smiles, rdkit_mol_path, shift_dis=0)
|
||||
|
||||
parser = PDBParser(QUIET=True)
|
||||
s = parser.get_structure("x", cleaned_protein_path)
|
||||
res_list = list(s.get_residues())
|
||||
|
||||
protein_dict = {}
|
||||
protein_dict[name] = get_protein_feature(res_list)
|
||||
compound_dict = {}
|
||||
|
||||
mol = Chem.MolFromMolFile(rdkit_mol_path)
|
||||
compound_dict[name + f"_{name}" + "_rdkit"] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
|
||||
|
||||
info = []
|
||||
for compound_name in list(compound_dict.keys()):
|
||||
# use protein center as the block center.
|
||||
com = ",".join([str(a.round(3)) for a in protein_dict[name][0].mean(axis=0).numpy()])
|
||||
info.append([name, compound_name, "protein_center", com])
|
||||
|
||||
p2rankFile = f"{cache_path}/p2rank/{name}_protein_tankbind_cleaned.pdb_predictions.csv"
|
||||
pocket = pd.read_csv(p2rankFile)
|
||||
pocket.columns = pocket.columns.str.strip()
|
||||
pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values
|
||||
for ith_pocket, com in enumerate(pocket_coms):
|
||||
com = ",".join([str(a.round(3)) for a in com])
|
||||
info.append([name, compound_name, f"pocket_{ith_pocket + 1}", com])
|
||||
info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'pocket_name', 'pocket_com'])
|
||||
|
||||
dataset_path = f"{cache_path}/{name}_dataset/"
|
||||
os.system(f"rm -r {dataset_path}")
|
||||
os.system(f"mkdir -p {dataset_path}")
|
||||
dataset = TankBind_prediction(dataset_path, data=info, protein_dict=protein_dict, compound_dict=compound_dict)
|
||||
|
||||
# dataset = TankBind_prediction(dataset_path)
|
||||
times_preprocess.append(time.time() - single_preprocess_time)
|
||||
single_forward_pass_time = time.time()
|
||||
data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False,
|
||||
num_workers=0)
|
||||
affinity_pred_list = []
|
||||
y_pred_list = []
|
||||
for data in tqdm(data_loader):
|
||||
data = data.to(device)
|
||||
y_pred, affinity_pred = model(data)
|
||||
affinity_pred_list.append(affinity_pred.detach().cpu())
|
||||
for i in range(data.y_batch.max() + 1):
|
||||
y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())
|
||||
|
||||
affinity_pred_list = torch.cat(affinity_pred_list)
|
||||
forward_pass_time.append(time.time() - single_forward_pass_time)
|
||||
output_info = copy.deepcopy(dataset.data)
|
||||
output_info['affinity'] = affinity_pred_list
|
||||
output_info['dataset_index'] = range(len(output_info))
|
||||
output_info_sorted = output_info.sort_values('affinity', ascending=False)
|
||||
|
||||
|
||||
result_folder = f'{args.results_path}/{name}'
|
||||
os.makedirs(result_folder, exist_ok=True)
|
||||
output_info_sorted.to_csv(f"{result_folder}/output_info_sorted_by_affinity.csv")
|
||||
|
||||
if not args.skip_multiple_pocket_outputs:
|
||||
for idx, (dataframe_idx, line) in enumerate(copy.deepcopy(output_info_sorted).iterrows()):
|
||||
single_top10_generation_time = time.time()
|
||||
pocket_name = line['pocket_name']
|
||||
compound_name = line['compound_name']
|
||||
ligandName = compound_name.split("_")[1]
|
||||
coords = dataset[dataframe_idx].coords.to('cpu')
|
||||
protein_nodes_xyz = dataset[dataframe_idx].node_xyz.to('cpu')
|
||||
n_compound = coords.shape[0]
|
||||
n_protein = protein_nodes_xyz.shape[0]
|
||||
y_pred = y_pred_list[dataframe_idx].reshape(n_protein, n_compound).to('cpu')
|
||||
y = dataset[dataframe_idx].dis_map.reshape(n_protein, n_compound).to('cpu')
|
||||
compound_pair_dis_constraint = torch.cdist(coords, coords)
|
||||
mol = Chem.MolFromMolFile(rdkit_mol_path)
|
||||
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
|
||||
pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
|
||||
LAS_distance_constraint_mask=LAS_distance_constraint_mask,
|
||||
n_repeat=1, show_progress=False)
|
||||
|
||||
toFile = f'{result_folder}/{name}_tankbind_{idx}.sdf'
|
||||
new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
|
||||
write_with_new_coords(mol, new_coords, toFile)
|
||||
if idx < 10:
|
||||
top_10_generation_time.append(time.time() - single_top10_generation_time)
|
||||
if idx == 0:
|
||||
top_1_generation_time.append(time.time() - single_top10_generation_time)
|
||||
|
||||
output_info_chosen = copy.deepcopy(dataset.data)
|
||||
output_info_chosen['affinity'] = affinity_pred_list
|
||||
output_info_chosen['dataset_index'] = range(len(output_info_chosen))
|
||||
chosen = output_info_chosen.loc[
|
||||
output_info_chosen.groupby(['protein_name', 'compound_name'], sort=False)['affinity'].agg(
|
||||
'idxmax')].reset_index()
|
||||
|
||||
list_to_parallelize.append((dataset, y_pred_list, chosen, rdkit_mol_path, result_folder, name))
|
||||
|
||||
chosen_generation_start_time = time.time()
|
||||
if args.num_workers > 1:
|
||||
p = Pool(args.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(list_to_parallelize), desc=f'running optimization {i}/{len(list_to_parallelize)}') as pbar:
|
||||
map_fn = p.imap_unordered if args.num_workers > 1 else map
|
||||
for t in map_fn(parallel_save_prediction, list_to_parallelize):
|
||||
pbar.update()
|
||||
if args.num_workers > 1: p.__exit__(None, None, None)
|
||||
chosen_generation_time = time.time() - chosen_generation_start_time
|
||||
"""
|
||||
lig, _ = read_mol(f"{args.data_dir}/{name}/{name}_ligand.sdf", f"{args.data_dir}/{name}/{name}_ligand.mol2")
|
||||
sm = Chem.MolToSmiles(lig)
|
||||
m_order = list(lig.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
|
||||
lig = Chem.RenumberAtoms(lig, m_order)
|
||||
lig = Chem.RemoveAllHs(lig)
|
||||
lig = RemoveHs(lig)
|
||||
true_ligand_pos = np.array(lig.GetConformer().GetPositions())
|
||||
|
||||
toFile = f'{result_folder}/{name}_tankbind_chosen.sdf'
|
||||
mol_pred, _ = read_mol(toFile, None)
|
||||
sm = Chem.MolToSmiles(mol_pred)
|
||||
m_order = list(mol_pred.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
|
||||
mol_pred = Chem.RenumberAtoms(mol_pred, m_order)
|
||||
mol_pred = RemoveHs(mol_pred)
|
||||
mol_pred_pos = np.array(mol_pred.GetConformer().GetPositions())
|
||||
rmsds.append(np.sqrt(((true_ligand_pos - mol_pred_pos) ** 2).sum(axis=1).mean(axis=0)))
|
||||
print(np.sqrt(((true_ligand_pos - mol_pred_pos) ** 2).sum(axis=1).mean(axis=0)))
|
||||
"""
|
||||
forward_pass_time = np.array(forward_pass_time).sum()
|
||||
times_preprocess = np.array(times_preprocess).sum()
|
||||
times_inference = np.array(times_inference).sum()
|
||||
top_10_generation_time = np.array(top_10_generation_time).sum()
|
||||
top_1_generation_time = np.array(top_1_generation_time).sum()
|
||||
|
||||
rmsds = np.array(rmsds)
|
||||
|
||||
print(f'forward_pass_time: {forward_pass_time}')
|
||||
print(f'times_preprocess: {times_preprocess}')
|
||||
print(f'times_inference: {times_inference}')
|
||||
print(f'top_10_generation_time: {top_10_generation_time}')
|
||||
print(f'top_1_generation_time: {top_1_generation_time}')
|
||||
print(f'chosen_generation_time: {chosen_generation_time}')
|
||||
print(f'rmsds_below_2: {(100 * (rmsds < 2).sum() / len(rmsds))}')
|
||||
print(f'p2rank Time: {p2_rank_time}')
|
||||
print(
|
||||
f'total_time: '
|
||||
f'{forward_pass_time + times_preprocess + times_inference + top_10_generation_time + top_1_generation_time + p2_rank_time}')
|
||||
|
||||
with open(os.path.join(args.results_path, 'tankbind_log.log'), 'w') as file:
|
||||
file.write(f'forward_pass_time: {forward_pass_time}')
|
||||
file.write(f'times_preprocess: {times_preprocess}')
|
||||
file.write(f'times_inference: {times_inference}')
|
||||
file.write(f'top_10_generation_time: {top_10_generation_time}')
|
||||
file.write(f'top_1_generation_time: {top_1_generation_time}')
|
||||
file.write(f'rmsds_below_2: {(100 * (rmsds < 2).sum() / len(rmsds))}')
|
||||
file.write(f'p2rank Time: {p2_rank_time}')
|
||||
file.write(f'total_time: {forward_pass_time + times_preprocess + times_inference + top_10_generation_time + top_1_generation_time + p2_rank_time}')
|
||||
@@ -1,363 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "477e3cc1-4143-47c7-8278-80b0cabca0ab",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from tqdm import tqdm\n",
|
||||
"import shutil\n",
|
||||
"import sys\n",
|
||||
"import copy\n",
|
||||
"import warnings\n",
|
||||
"import numpy as np\n",
|
||||
"import pandas as pd\n",
|
||||
"from scipy import spatial as spa\n",
|
||||
"from scipy.optimize import minimize, Bounds\n",
|
||||
"from rdkit import Chem\n",
|
||||
"import Bio\n",
|
||||
"import scipy.spatial as spa\n",
|
||||
"from Bio.PDB import PDBParser\n",
|
||||
"from Bio.PDB.PDBExceptions import PDBConstructionWarning\n",
|
||||
"from rdkit.Chem.rdchem import BondType as BT\n",
|
||||
"from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs\n",
|
||||
"from rdkit.Geometry import Point3D\n",
|
||||
"from scipy import spatial\n",
|
||||
"from scipy.special import softmax"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "ccfaa84e-d6e6-4b24-a0a8-95136256127b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"biopython_parser = PDBParser()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def safe_index(l, e):\n",
|
||||
" \"\"\" Return index of element e in list l. If e is not present, return the last index \"\"\"\n",
|
||||
" try:\n",
|
||||
" return l.index(e)\n",
|
||||
" except:\n",
|
||||
" return len(l) - 1\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def parse_receptor(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):\n",
|
||||
" rec = parsePDB(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file)\n",
|
||||
" return rec\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def parsePDB(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):\n",
|
||||
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_processed.pdb')\n",
|
||||
" if not os.path.exists(rec_path) or use_full_size_file or use_original_protein_file:\n",
|
||||
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_obabel_reduce.pdb')\n",
|
||||
" if not os.path.exists(rec_path) or use_original_protein_file:\n",
|
||||
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein.pdb')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" return parse_pdb_from_path(rec_path)\n",
|
||||
"\n",
|
||||
"def parse_pdb_from_path(path):\n",
|
||||
" with warnings.catch_warnings():\n",
|
||||
" warnings.filterwarnings(\"ignore\", category=PDBConstructionWarning)\n",
|
||||
" structure = biopython_parser.get_structure('random_id', path)\n",
|
||||
" rec = structure[0]\n",
|
||||
" return rec\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):\n",
|
||||
" if molecule_file.endswith('.mol2'):\n",
|
||||
" mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)\n",
|
||||
" elif molecule_file.endswith('.sdf'):\n",
|
||||
" supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)\n",
|
||||
" mol = supplier[0]\n",
|
||||
" elif molecule_file.endswith('.pdbqt'):\n",
|
||||
" with open(molecule_file) as file:\n",
|
||||
" pdbqt_data = file.readlines()\n",
|
||||
" pdb_block = ''\n",
|
||||
" for line in pdbqt_data:\n",
|
||||
" pdb_block += '{}\\n'.format(line[:66])\n",
|
||||
" mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)\n",
|
||||
" elif molecule_file.endswith('.pdb'):\n",
|
||||
" mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)\n",
|
||||
" else:\n",
|
||||
" return ValueError('Expect the format of the molecule_file to be '\n",
|
||||
" 'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" if sanitize or calc_charges:\n",
|
||||
" Chem.SanitizeMol(mol)\n",
|
||||
"\n",
|
||||
" if calc_charges:\n",
|
||||
" # Compute Gasteiger charges on the molecule.\n",
|
||||
" try:\n",
|
||||
" AllChem.ComputeGasteigerCharges(mol)\n",
|
||||
" except:\n",
|
||||
" warnings.warn('Unable to compute charges for the molecule.')\n",
|
||||
"\n",
|
||||
" if remove_hs:\n",
|
||||
" mol = Chem.RemoveHs(mol, sanitize=sanitize)\n",
|
||||
" except:\n",
|
||||
" return None\n",
|
||||
"\n",
|
||||
" return mol\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def read_sdf_or_mol2(sdf_fileName):\n",
|
||||
"\n",
|
||||
" try:\n",
|
||||
" mol = read_molecule(sdf_fileName)\n",
|
||||
" except Exception as e:\n",
|
||||
" mol2_fileName = sdf_fileName[:-3] + \"mol2\"\n",
|
||||
" mol = read_molecule(mol2_fileName)\n",
|
||||
" return mol\n",
|
||||
"\n",
|
||||
"def read_mols(pdbbind_dir, name, remove_hs=False):\n",
|
||||
" ligs = []\n",
|
||||
" for file in os.listdir(os.path.join(pdbbind_dir, name)):\n",
|
||||
" if file.endswith(\".sdf\") and 'rdkit' not in file:\n",
|
||||
" lig = read_molecule(os.path.join(pdbbind_dir, name, file), remove_hs=remove_hs, sanitize=True)\n",
|
||||
" if lig is None and os.path.exists(os.path.join(pdbbind_dir, name, file[:-4] + \".mol2\")): # read mol2 file if sdf file cannot be sanitized\n",
|
||||
" #print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')\n",
|
||||
" lig = read_molecule(os.path.join(pdbbind_dir, name, file[:-4] + \".mol2\"), remove_hs=remove_hs, sanitize=True)\n",
|
||||
" if lig is not None:\n",
|
||||
" ligs.append(lig)\n",
|
||||
" return ligs\n",
|
||||
"\n",
|
||||
"def extract_receptor_structure(rec, lig, cutoff=10000, lm_embedding_chains=None):\n",
|
||||
" conf = lig.GetConformer()\n",
|
||||
" lig_coords = conf.GetPositions()\n",
|
||||
" min_distances = []\n",
|
||||
" coords = []\n",
|
||||
" c_alpha_coords = []\n",
|
||||
" n_coords = []\n",
|
||||
" c_coords = []\n",
|
||||
" valid_chain_ids = []\n",
|
||||
" lengths = []\n",
|
||||
" for i, chain in enumerate(rec):\n",
|
||||
" chain_coords = [] # num_residues, num_atoms, 3\n",
|
||||
" chain_c_alpha_coords = []\n",
|
||||
" chain_n_coords = []\n",
|
||||
" chain_c_coords = []\n",
|
||||
" count = 0\n",
|
||||
" invalid_res_ids = []\n",
|
||||
" for res_idx, residue in enumerate(chain):\n",
|
||||
" if residue.get_resname() == 'HOH':\n",
|
||||
" invalid_res_ids.append(residue.get_id())\n",
|
||||
" continue\n",
|
||||
" residue_coords = []\n",
|
||||
" c_alpha, n, c = None, None, None\n",
|
||||
" for atom in residue:\n",
|
||||
" if atom.name == 'CA':\n",
|
||||
" c_alpha = list(atom.get_vector())\n",
|
||||
" if atom.name == 'N':\n",
|
||||
" n = list(atom.get_vector())\n",
|
||||
" if atom.name == 'C':\n",
|
||||
" c = list(atom.get_vector())\n",
|
||||
" residue_coords.append(list(atom.get_vector()))\n",
|
||||
"\n",
|
||||
" if c_alpha != None and n != None and c != None:\n",
|
||||
" # only append residue if it is an amino acid and not some weird molecule that is part of the complex\n",
|
||||
" chain_c_alpha_coords.append(c_alpha)\n",
|
||||
" chain_n_coords.append(n)\n",
|
||||
" chain_c_coords.append(c)\n",
|
||||
" chain_coords.append(np.array(residue_coords))\n",
|
||||
" count += 1\n",
|
||||
" else:\n",
|
||||
" invalid_res_ids.append(residue.get_id())\n",
|
||||
" for res_id in invalid_res_ids:\n",
|
||||
" chain.detach_child(res_id)\n",
|
||||
" if len(chain_coords) > 0:\n",
|
||||
" all_chain_coords = np.concatenate(chain_coords, axis=0)\n",
|
||||
" distances = spatial.distance.cdist(lig_coords, all_chain_coords)\n",
|
||||
" min_distance = distances.min()\n",
|
||||
" else:\n",
|
||||
" min_distance = np.inf\n",
|
||||
"\n",
|
||||
" # this removes chains if they are not close enough to the ligand\n",
|
||||
" min_distances.append(min_distance)\n",
|
||||
" lengths.append(count)\n",
|
||||
" coords.append(chain_coords)\n",
|
||||
" c_alpha_coords.append(np.array(chain_c_alpha_coords))\n",
|
||||
" n_coords.append(np.array(chain_n_coords))\n",
|
||||
" c_coords.append(np.array(chain_c_coords))\n",
|
||||
" if min_distance < cutoff:\n",
|
||||
" valid_chain_ids.append(chain.get_id())\n",
|
||||
" min_distances = np.array(min_distances)\n",
|
||||
" if len(valid_chain_ids) == 0:\n",
|
||||
" valid_chain_ids.append(np.argmin(min_distances))\n",
|
||||
" valid_coords = []\n",
|
||||
" valid_c_alpha_coords = []\n",
|
||||
" valid_n_coords = []\n",
|
||||
" valid_c_coords = []\n",
|
||||
" valid_lengths = []\n",
|
||||
" invalid_chain_ids = []\n",
|
||||
" valid_lm_embeddings = []\n",
|
||||
" for i, chain in enumerate(rec):\n",
|
||||
" if chain.get_id() in valid_chain_ids:\n",
|
||||
" valid_coords.append(coords[i])\n",
|
||||
" valid_c_alpha_coords.append(c_alpha_coords[i])\n",
|
||||
" if lm_embedding_chains is not None:\n",
|
||||
" if i >= len(lm_embedding_chains):\n",
|
||||
" raise ValueError('Encountered valid chain id that was not present in the LM embeddings')\n",
|
||||
" valid_lm_embeddings.append(lm_embedding_chains[i])\n",
|
||||
" valid_n_coords.append(n_coords[i])\n",
|
||||
" valid_c_coords.append(c_coords[i])\n",
|
||||
" valid_lengths.append(lengths[i])\n",
|
||||
" else:\n",
|
||||
" invalid_chain_ids.append(chain.get_id())\n",
|
||||
" coords = [item for sublist in valid_coords for item in sublist] # list with n_residues arrays: [n_atoms, 3]\n",
|
||||
"\n",
|
||||
" c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0) # [n_residues, 3]\n",
|
||||
" n_coords = np.concatenate(valid_n_coords, axis=0) # [n_residues, 3]\n",
|
||||
" c_coords = np.concatenate(valid_c_coords, axis=0) # [n_residues, 3]\n",
|
||||
" lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None\n",
|
||||
" for invalid_id in invalid_chain_ids:\n",
|
||||
" rec.detach_child(invalid_id)\n",
|
||||
"\n",
|
||||
" assert len(c_alpha_coords) == len(n_coords)\n",
|
||||
" assert len(c_alpha_coords) == len(c_coords)\n",
|
||||
" assert sum(valid_lengths) == len(c_alpha_coords)\n",
|
||||
" return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"id": "ce502378-1912-4a65-940e-61eb545a4264",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def align_prediction(smoothing_factor, pdbbind_calpha_coords, omegafold_calpha_coords, pdbbind_ligand_coords, return_rotation=False):\n",
|
||||
" pdbbind_dists = spa.distance.cdist(pdbbind_calpha_coords, pdbbind_ligand_coords)\n",
|
||||
" weights = np.exp(-1 * smoothing_factor * np.amin(pdbbind_dists, axis=1))\n",
|
||||
" \n",
|
||||
" pdbbind_calpha_centroid = np.sum(np.expand_dims(weights, axis=1) * pdbbind_calpha_coords, axis=0) / np.sum(weights)\n",
|
||||
" omegafold_calpha_centroid = np.sum(np.expand_dims(weights, axis=1) * omegafold_calpha_coords, axis=0) / np.sum(weights)\n",
|
||||
" centered_pdbbind_calpha_coords = pdbbind_calpha_coords - pdbbind_calpha_centroid\n",
|
||||
" centered_omegafold_calpha_coords = omegafold_calpha_coords - omegafold_calpha_centroid\n",
|
||||
" centered_pdbbind_ligand_coords = pdbbind_ligand_coords - pdbbind_calpha_centroid\n",
|
||||
" \n",
|
||||
" rotation, rec_weighted_rmsd = spa.transform.Rotation.align_vectors(centered_pdbbind_calpha_coords, centered_omegafold_calpha_coords, weights)\n",
|
||||
" if return_rotation:\n",
|
||||
" return rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid\n",
|
||||
" \n",
|
||||
" aligned_omegafold_calpha_coords = rotation.apply(centered_omegafold_calpha_coords)\n",
|
||||
" aligned_omegafold_pdbbind_dists = spa.distance.cdist(aligned_omegafold_calpha_coords, centered_pdbbind_ligand_coords)\n",
|
||||
" inv_r_rmse = np.sqrt(np.mean(((1 / pdbbind_dists) - (1 / aligned_omegafold_pdbbind_dists)) ** 2))\n",
|
||||
" return inv_r_rmse\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"id": "7a20fbfe-a026-4815-9ed9-42c9bbd9cd08",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_alignment_rotation(pdb_id, pdbbind_protein_path, omegafold_protein_path, pdbbind_path):\n",
|
||||
" pdbbind_rec = parse_pdb_from_path(pdbbind_protein_path)\n",
|
||||
" omegafold_rec = parse_pdb_from_path(omegafold_protein_path)\n",
|
||||
" pdbbind_ligand = read_mols(pdbbind_path, pdb_id, remove_hs=True)[0]\n",
|
||||
" \n",
|
||||
" pdbbind_calpha_coords = extract_receptor_structure(pdbbind_rec, pdbbind_ligand)[2]\n",
|
||||
" omegafold_calpha_coords = extract_receptor_structure(omegafold_rec, pdbbind_ligand)[2]\n",
|
||||
" pdbbind_ligand_coords = pdbbind_ligand.GetConformer().GetPositions()\n",
|
||||
"\n",
|
||||
" if pdbbind_calpha_coords.shape != omegafold_calpha_coords.shape:\n",
|
||||
" print(f'Receptor structures differ for PDB ID {pdb_id} - Skipping', pdbbind_calpha_coords.shape, omegafold_calpha_coords.shape)\n",
|
||||
" return None, None, None\n",
|
||||
"\n",
|
||||
" res = minimize(\n",
|
||||
" align_prediction,\n",
|
||||
" [0.1],\n",
|
||||
" bounds=Bounds([0.0],[1.0]),\n",
|
||||
" args=(\n",
|
||||
" pdbbind_calpha_coords,\n",
|
||||
" omegafold_calpha_coords,\n",
|
||||
" pdbbind_ligand_coords\n",
|
||||
" ),\n",
|
||||
" tol=1e-8\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" smoothing_factor = res.x\n",
|
||||
" inv_r_rmse = res.fun\n",
|
||||
" rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid = align_prediction(\n",
|
||||
" smoothing_factor,\n",
|
||||
" pdbbind_calpha_coords,\n",
|
||||
" omegafold_calpha_coords,\n",
|
||||
" pdbbind_ligand_coords,\n",
|
||||
" True\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "befbf9d6-fb08-4e85-a326-ff903547e2f9",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from biopandas.pdb import PandasPdb\n",
|
||||
"\n",
|
||||
"for f in tqdm(os.listdir(\"data/esmfold_structures\")):\n",
|
||||
" pdb_id = f.split(\"_\")[0]\n",
|
||||
" \n",
|
||||
" omega_protein_filename = f\"data/esmfold_structures/{pdb_id}_protein_esmfold.pdb\"\n",
|
||||
" omega_protein_output_filename = f\"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_esmfold_aligned_tr.pdb\"\n",
|
||||
" \n",
|
||||
" rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid = get_alignment_rotation(pdb_id, f\"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_processed.pdb\", \n",
|
||||
" omega_protein_filename, \"data/PDBBind_processed/\")\n",
|
||||
" \n",
|
||||
" if rotation is None:\n",
|
||||
" continue\n",
|
||||
" \n",
|
||||
" ppdb_omegafold = PandasPdb().read_pdb(omega_protein_filename)\n",
|
||||
" ppdb_omegafold_pre_rot = ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)\n",
|
||||
" ppdb_omegafold_aligned = rotation.apply(ppdb_omegafold_pre_rot - omegafold_calpha_centroid) + pdbbind_calpha_centroid\n",
|
||||
" \n",
|
||||
" \n",
|
||||
" ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ppdb_omegafold_aligned\n",
|
||||
" ppdb_omegafold.to_pdb(path=omega_protein_output_filename, records=['ATOM'], gz=False)\n",
|
||||
" "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "445b0f84-302f-4023-bafb-05780d8967c0",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
BIN
data/splits/MOAD_generalisation_splits.pkl
Normal file
BIN
data/splits/MOAD_generalisation_splits.pkl
Normal file
Binary file not shown.
225095
data/splits/pdbids_2019
Normal file
225095
data/splits/pdbids_2019
Normal file
File diff suppressed because it is too large
Load Diff
428
data/splits/posebusters_benchmark_set_ids.txt
Normal file
428
data/splits/posebusters_benchmark_set_ids.txt
Normal file
@@ -0,0 +1,428 @@
|
||||
5S8I_2LY
|
||||
5SAK_ZRY
|
||||
5SB2_1K2
|
||||
5SD5_HWI
|
||||
5SIS_JSM
|
||||
6M2B_EZO
|
||||
6M73_FNR
|
||||
6T88_MWQ
|
||||
6TW5_9M2
|
||||
6TW7_NZB
|
||||
6VS3_R6V
|
||||
6VTA_AKN
|
||||
6W59_SZD
|
||||
6WTN_RXT
|
||||
6X8D_ARA
|
||||
6XAF_GDP
|
||||
6XBO_5MC
|
||||
6XCT_478
|
||||
6XG5_TOP
|
||||
6XHT_V2V
|
||||
6XM9_V55
|
||||
6XUM_30L
|
||||
6Y7L_QMG
|
||||
6YDY_K73
|
||||
6YJA_2BA
|
||||
6YMS_OZH
|
||||
6YQV_8K2
|
||||
6YQW_82I
|
||||
6YR2_T1C
|
||||
6YRV_PJ8
|
||||
6YSP_PAL
|
||||
6YT6_PKE
|
||||
6YYO_Q1K
|
||||
6Z0R_Q4H
|
||||
6Z14_Q4Z
|
||||
6Z1C_7EY
|
||||
6Z2C_Q5E
|
||||
6Z4N_Q7B
|
||||
6Z5Z_BDF
|
||||
6ZAE_ACV
|
||||
6ZC3_JOR
|
||||
6ZCY_QF8
|
||||
6ZK5_IMH
|
||||
6ZPB_3D1
|
||||
6ZR8_QOZ
|
||||
6ZT2_QPK
|
||||
6ZX3_QRZ
|
||||
6ZXQ_IMO
|
||||
7A1P_QW2
|
||||
7A9E_R4W
|
||||
7A9H_TPP
|
||||
7AA0_R6B
|
||||
7AFX_R9K
|
||||
7AKL_RK5
|
||||
7AMC_73B
|
||||
7AN5_RDH
|
||||
7AS1_21G
|
||||
7AVI_S2Q
|
||||
7B0E_C2E
|
||||
7B2C_TP7
|
||||
7B94_ANP
|
||||
7BA0_T5H
|
||||
7BCP_GCO
|
||||
7BHX_TO5
|
||||
7BJ6_TVK
|
||||
7BJJ_TVW
|
||||
7BKA_4JC
|
||||
7BLA_WCS
|
||||
7BLG_GAL
|
||||
7BMI_U4B
|
||||
7BNH_BEZ
|
||||
7BTT_F8R
|
||||
7C0U_FGO
|
||||
7C3U_AZG
|
||||
7C6P_SQH
|
||||
7C8Q_DSG
|
||||
7CD9_FVR
|
||||
7CIJ_G0C
|
||||
7CL8_TES
|
||||
7CNQ_G8X
|
||||
7CNS_PMV
|
||||
7CTM_BDP
|
||||
7CUO_PHB
|
||||
7D0P_1VU
|
||||
7D5C_GV6
|
||||
7D6O_MTE
|
||||
7D8Q_GZF
|
||||
7D9L_GSF
|
||||
7DIN_MPO
|
||||
7DKT_GLF
|
||||
7DQL_4CL
|
||||
7DUA_HJ0
|
||||
7E2S_BLA
|
||||
7E4L_MDN
|
||||
7EBG_J0L
|
||||
7ECR_SIN
|
||||
7ED2_A3P
|
||||
7ELT_TYM
|
||||
7EN7_J79
|
||||
7EPV_FDA
|
||||
7ES1_UDP
|
||||
7F51_BA7
|
||||
7F5D_EUO
|
||||
7F8T_FAD
|
||||
7FB7_8NF
|
||||
7FHA_ADX
|
||||
7FRX_O88
|
||||
7FT9_4MB
|
||||
7JG0_GAR
|
||||
7JGW_V9S
|
||||
7JHQ_VAJ
|
||||
7JMV_4NC
|
||||
7JNB_A2G
|
||||
7JR8_VH7
|
||||
7JUD_MMA
|
||||
7JXX_VP7
|
||||
7JY3_VUD
|
||||
7K0V_VQP
|
||||
7K41_VUA
|
||||
7KB1_WBJ
|
||||
7KC5_BJZ
|
||||
7KFO_IAC
|
||||
7KLX_WOV
|
||||
7KM8_WPD
|
||||
7KP6_WTP
|
||||
7KQU_YOF
|
||||
7KRU_ATP
|
||||
7KZ9_XN7
|
||||
7L00_XCJ
|
||||
7L03_F9F
|
||||
7L5F_XNG
|
||||
7L6D_BMF
|
||||
7L7C_XQ1
|
||||
7L81_UD4
|
||||
7LB3_XXS
|
||||
7LCU_XTA
|
||||
7LEV_0JO
|
||||
7LJN_GTP
|
||||
7LMO_NYO
|
||||
7LOE_Y84
|
||||
7LOU_IFM
|
||||
7LT0_ONJ
|
||||
7LZD_YHY
|
||||
7LZQ_YJV
|
||||
7M31_TDR
|
||||
7M3H_YPV
|
||||
7M41_YQG
|
||||
7M6K_YRJ
|
||||
7MAE_XUS
|
||||
7MEU_MGP
|
||||
7MFP_Z7P
|
||||
7MGT_ZD4
|
||||
7MGY_ZD1
|
||||
7MMH_ZJY
|
||||
7MOI_HPS
|
||||
7MRH_ZMJ
|
||||
7MS7_ZQ1
|
||||
7MSR_DCA
|
||||
7MWN_WI5
|
||||
7MWU_ZPM
|
||||
7MY1_IPE
|
||||
7MYU_ZR7
|
||||
7MZS_GLA
|
||||
7N03_ZRP
|
||||
7N4N_0BK
|
||||
7N4W_P4V
|
||||
7N6F_0I1
|
||||
7N7B_T3F
|
||||
7N7H_CTP
|
||||
7NA4_1I9
|
||||
7NB4_U6Q
|
||||
7NF0_BYN
|
||||
7NF3_4LU
|
||||
7NFB_GEN
|
||||
7NGW_UAW
|
||||
7NLK_UHK
|
||||
7NLV_UJE
|
||||
7NML_I7B
|
||||
7NP6_UK8
|
||||
7NPL_UKZ
|
||||
7NR6_UO8
|
||||
7NR8_UOE
|
||||
7NSW_HC4
|
||||
7NTG_F6R
|
||||
7NU0_DCL
|
||||
7NUT_GLP
|
||||
7NXO_UU8
|
||||
7O0N_CDP
|
||||
7O1T_5X8
|
||||
7OCB_V88
|
||||
7ODX_DGP
|
||||
7ODY_DGI
|
||||
7OEO_V9Z
|
||||
7OFF_VCB
|
||||
7OFK_VCH
|
||||
7OKC_VFE
|
||||
7OKF_VH5
|
||||
7OLI_8HG
|
||||
7OLT_58J
|
||||
7OMJ_GCP
|
||||
7OMX_CNA
|
||||
7OP9_06K
|
||||
7OPG_06N
|
||||
7ORW_7WA
|
||||
7OSO_0V1
|
||||
7OU8_1XI
|
||||
7OZ9_NGK
|
||||
7OZC_G6S
|
||||
7P1F_KFN
|
||||
7P1M_4IU
|
||||
7P2I_MFU
|
||||
7P2W_4QR
|
||||
7P4C_5OV
|
||||
7P4J_5JK
|
||||
7P4V_DAT
|
||||
7P5T_5YG
|
||||
7P85_5ZG
|
||||
7PA4_C
|
||||
7PGX_FMN
|
||||
7PIH_7QW
|
||||
7PJQ_OWH
|
||||
7PK0_BYC
|
||||
7PL1_SFG
|
||||
7POM_7VZ
|
||||
7PRI_7TI
|
||||
7PRM_81I
|
||||
7PT3_3KK
|
||||
7PUV_84Z
|
||||
7Q19_DSM
|
||||
7Q25_8J9
|
||||
7Q27_8KC
|
||||
7Q2B_M6H
|
||||
7Q5I_I0F
|
||||
7QE4_NGA
|
||||
7QF4_RBF
|
||||
7QFM_AY3
|
||||
7QGP_DJ8
|
||||
7QHG_T3B
|
||||
7QHL_D5P
|
||||
7QK0_EBL
|
||||
7QPP_VDX
|
||||
7QSW_CAP
|
||||
7QTA_URI
|
||||
7R3D_APR
|
||||
7R59_I5F
|
||||
7R6J_2I7
|
||||
7R7R_AWJ
|
||||
7R9N_F97
|
||||
7RC3_SAH
|
||||
7REE_4LY
|
||||
7RH3_59O
|
||||
7RH8_UTP
|
||||
7RKW_5TV
|
||||
7RNI_60I
|
||||
7ROR_69X
|
||||
7ROU_66I
|
||||
7RPZ_6IC
|
||||
7RSV_7IQ
|
||||
7RUI_7QZ
|
||||
7RWO_7WN
|
||||
7RWS_4UR
|
||||
7RZL_NPO
|
||||
7S45_ACO
|
||||
7S9H_7PP
|
||||
7SCW_GSP
|
||||
7SDD_4IP
|
||||
7SED_8VD
|
||||
7SFO_98L
|
||||
7SGV_L30
|
||||
7SIU_9ID
|
||||
7SNE_9XR
|
||||
7SSM_B7L
|
||||
7SUC_COM
|
||||
7SZA_DUI
|
||||
7T0D_FPP
|
||||
7T0U_E3I
|
||||
7T1D_E7K
|
||||
7T2I_E9F
|
||||
7T3E_SLB
|
||||
7T3F_EM0
|
||||
7T9O_GEI
|
||||
7TB0_UD1
|
||||
7TBU_S3P
|
||||
7TE8_P0T
|
||||
7TH4_FFO
|
||||
7THI_PGA
|
||||
7TM6_GPJ
|
||||
7TOM_5AD
|
||||
7TS6_KMI
|
||||
7TSF_H4B
|
||||
7TUO_KL9
|
||||
7TWC_CXS
|
||||
7TXK_LW8
|
||||
7TXP_0FX
|
||||
7TYP_KUR
|
||||
7U0U_FK5
|
||||
7U3J_L6U
|
||||
7UAS_MBU
|
||||
7UAW_MF6
|
||||
7UEY_N0R
|
||||
7UF2_5SP
|
||||
7UJ4_OQ4
|
||||
7UJ5_DGL
|
||||
7UJF_R3V
|
||||
7ULC_56B
|
||||
7UMV_NUU
|
||||
7UMW_NAD
|
||||
7UP3_NZ0
|
||||
7UQ3_O2U
|
||||
7USH_82V
|
||||
7UTW_NAI
|
||||
7UXS_OJC
|
||||
7UY4_SMI
|
||||
7UYB_OK0
|
||||
7V14_ORU
|
||||
7V3N_AKG
|
||||
7V3S_5I9
|
||||
7V43_C4O
|
||||
7V8Z_5YH
|
||||
7VB8_STL
|
||||
7VBU_6I4
|
||||
7VC5_9SF
|
||||
7VJT_7IJ
|
||||
7VKZ_NOJ
|
||||
7VQ9_ISY
|
||||
7VWF_K55
|
||||
7VYJ_CA0
|
||||
7W05_GMP
|
||||
7W06_ITN
|
||||
7W6F_8I6
|
||||
7WCF_ACP
|
||||
7WDT_NGS
|
||||
7WJB_BGC
|
||||
7WKL_CAQ
|
||||
7WL4_JFU
|
||||
7WN5_JGL
|
||||
7WPW_F15
|
||||
7WQQ_5Z6
|
||||
7WUX_6OI
|
||||
7WUY_76N
|
||||
7WY1_D0L
|
||||
7X5N_5M5
|
||||
7X9K_8OG
|
||||
7XBV_APC
|
||||
7XEK_9YX
|
||||
7XFA_D9J
|
||||
7XG5_PLP
|
||||
7XI7_4RI
|
||||
7XIJ_EJ3
|
||||
7XJN_NSD
|
||||
7XPO_UPG
|
||||
7XQZ_FPF
|
||||
7XRL_FWK
|
||||
7YZU_DO7
|
||||
7Z1Q_NIO
|
||||
7Z2O_IAJ
|
||||
7Z7F_IF3
|
||||
7ZCC_OGA
|
||||
7ZDY_6MJ
|
||||
7ZF0_DHR
|
||||
7ZHP_IQY
|
||||
7ZL5_IWE
|
||||
7ZOC_T8E
|
||||
7ZTL_BCN
|
||||
7ZU2_DHT
|
||||
7ZXV_45D
|
||||
7ZXZ_K9R
|
||||
7ZYS_KNR
|
||||
7ZZB_KGX
|
||||
7ZZW_KKW
|
||||
8A1H_DLZ
|
||||
8A2D_KXY
|
||||
8AAU_LH0
|
||||
8ACL_LQL
|
||||
8AEM_LVF
|
||||
8AEU_M0L
|
||||
8AIE_M7L
|
||||
8AIJ_M9I
|
||||
8AJX_FUM
|
||||
8AP0_PRP
|
||||
8AQL_PLG
|
||||
8AUH_L9I
|
||||
8AY3_OE3
|
||||
8B8H_OJQ
|
||||
8BN6_R53
|
||||
8BOM_QU6
|
||||
8BPL_CP
|
||||
8BRO_R7E
|
||||
8BTI_RFO
|
||||
8C3N_ADP
|
||||
8C5D_GTB
|
||||
8C5M_MTA
|
||||
8C7Y_TXV
|
||||
8CGC_LMR
|
||||
8CI0_8EL
|
||||
8CNH_V6U
|
||||
8CSD_C5P
|
||||
8D19_GSH
|
||||
8D39_QDB
|
||||
8D5D_5DK
|
||||
8DHG_T78
|
||||
8DKO_TFB
|
||||
8DP2_UMA
|
||||
8DSC_NCA
|
||||
8DW5_FQ7
|
||||
8DZT_G4P
|
||||
8E77_ULP
|
||||
8EAB_VN2
|
||||
8EAD_UY0
|
||||
8ERS_WQO
|
||||
8EX2_Q2Q
|
||||
8EXL_799
|
||||
8EYE_X4I
|
||||
8F4J_PHO
|
||||
8F8E_XJI
|
||||
8FAV_4Y5
|
||||
8FLN_Y7W
|
||||
8FLV_ZB9
|
||||
8FO5_Y4U
|
||||
8FV9_80J
|
||||
8G0V_YHT
|
||||
8G43_ZU6
|
||||
8G6P_API
|
||||
8GFD_ZHR
|
||||
8H0M_2EH
|
||||
8HFN_XGC
|
||||
8HO0_3ZI
|
||||
8SLG_G5A
|
||||
BIN
data/splits/self_distillation_splits.pkl
Normal file
BIN
data/splits/self_distillation_splits.pkl
Normal file
Binary file not shown.
585
data/splits/timesplit_val_filter
Normal file
585
data/splits/timesplit_val_filter
Normal file
@@ -0,0 +1,585 @@
|
||||
4f49
|
||||
4z88
|
||||
3q8d
|
||||
5yvx
|
||||
2w2i
|
||||
4dcd
|
||||
4nuf
|
||||
4xxh
|
||||
4wk2
|
||||
3odl
|
||||
1ui0
|
||||
4ris
|
||||
4d3h
|
||||
3k3g
|
||||
6drt
|
||||
4b4q
|
||||
2p59
|
||||
4u3f
|
||||
2o9a
|
||||
3m94
|
||||
2y3p
|
||||
3rdv
|
||||
1ols
|
||||
2yiv
|
||||
1jgl
|
||||
6bj2
|
||||
2ke1
|
||||
5ah2
|
||||
2xtk
|
||||
6hly
|
||||
4wk1
|
||||
1fpy
|
||||
3iw4
|
||||
1olu
|
||||
4xqu
|
||||
1i3z
|
||||
5zla
|
||||
2cfd
|
||||
5zku
|
||||
4dj7
|
||||
3evf
|
||||
1iup
|
||||
6e5x
|
||||
5j5d
|
||||
3n2c
|
||||
3odi
|
||||
4c94
|
||||
2x7x
|
||||
4rww
|
||||
5x33
|
||||
1olx
|
||||
4mfe
|
||||
1wvc
|
||||
2wzm
|
||||
4pxf
|
||||
1v1m
|
||||
1v11
|
||||
6hm2
|
||||
5uxf
|
||||
3rde
|
||||
5syn
|
||||
2qta
|
||||
2puy
|
||||
2za5
|
||||
4nrt
|
||||
3rg2
|
||||
2jc0
|
||||
4xtt
|
||||
1v16
|
||||
1bq4
|
||||
1uef
|
||||
5o0j
|
||||
6hlz
|
||||
2p4s
|
||||
2o4h
|
||||
1njj
|
||||
2cfg
|
||||
3ucj
|
||||
16pk
|
||||
1u9l
|
||||
4lkj
|
||||
4p5z
|
||||
1mpl
|
||||
4u7t
|
||||
1jr1
|
||||
5wbf
|
||||
3ess
|
||||
2n7b
|
||||
6hlx
|
||||
5j5t
|
||||
2qw1
|
||||
4wn5
|
||||
1xt3
|
||||
5ks7
|
||||
1ff1
|
||||
4up5
|
||||
4tx6
|
||||
4tsz
|
||||
1dpu
|
||||
1oqp
|
||||
5bml
|
||||
6fdt
|
||||
4ie6
|
||||
3juo
|
||||
2xvn
|
||||
2xye
|
||||
4olh
|
||||
4udb
|
||||
1my2
|
||||
4n9e
|
||||
4dht
|
||||
3azb
|
||||
6eqp
|
||||
4eb9
|
||||
4okg
|
||||
4q18
|
||||
5a3w
|
||||
2mip
|
||||
4yp1
|
||||
2wnc
|
||||
3skf
|
||||
3zm9
|
||||
1tsm
|
||||
3wzp
|
||||
3qqk
|
||||
1pmv
|
||||
3q4l
|
||||
4x6p
|
||||
5xij
|
||||
4wkt
|
||||
4g31
|
||||
5y3n
|
||||
1cw2
|
||||
4tk0
|
||||
2y34
|
||||
4a4l
|
||||
3dga
|
||||
4igr
|
||||
5fnt
|
||||
2d06
|
||||
6gjy
|
||||
4j84
|
||||
5icy
|
||||
5oui
|
||||
3rj7
|
||||
6ccy
|
||||
4xu3
|
||||
5e0l
|
||||
6fjm
|
||||
2igx
|
||||
4odl
|
||||
4gr8
|
||||
4bhf
|
||||
3iux
|
||||
5v82
|
||||
4c6z
|
||||
4yje
|
||||
4rz1
|
||||
5hv1
|
||||
2vc7
|
||||
1jyc
|
||||
2lgf
|
||||
4clj
|
||||
5aom
|
||||
1l5r
|
||||
1nvq
|
||||
5ukl
|
||||
6b96
|
||||
1xmu
|
||||
2bcd
|
||||
2i0a
|
||||
4yrr
|
||||
4w9g
|
||||
3lxl
|
||||
1sqo
|
||||
3wk6
|
||||
3t1l
|
||||
3p1d
|
||||
2eg7
|
||||
5o0s
|
||||
5jn8
|
||||
2wap
|
||||
5u9i
|
||||
6min
|
||||
3eyh
|
||||
4p4e
|
||||
6mr5
|
||||
1sqp
|
||||
1f8a
|
||||
1d4w
|
||||
5oa6
|
||||
3kx1
|
||||
5myr
|
||||
2iwu
|
||||
1ero
|
||||
4xnw
|
||||
4bch
|
||||
5tob
|
||||
4uxl
|
||||
1jd6
|
||||
3hfj
|
||||
4de3
|
||||
2brp
|
||||
4lkg
|
||||
2fme
|
||||
2y5g
|
||||
4glr
|
||||
1k08
|
||||
1hyv
|
||||
3ske
|
||||
4hv7
|
||||
3fun
|
||||
3gwu
|
||||
4xsx
|
||||
4ec4
|
||||
5zmq
|
||||
3zlw
|
||||
3fcl
|
||||
5f3t
|
||||
3mfv
|
||||
3bzf
|
||||
4jbo
|
||||
2ow2
|
||||
5l2z
|
||||
2qci
|
||||
1hiy
|
||||
6f6n
|
||||
6fgg
|
||||
1z2b
|
||||
1szm
|
||||
3n9r
|
||||
5dhr
|
||||
1jvu
|
||||
1jd0
|
||||
4a9r
|
||||
2vgc
|
||||
5y5t
|
||||
4iax
|
||||
5l4m
|
||||
2vnp
|
||||
3n3l
|
||||
1qxz
|
||||
1tok
|
||||
4aqh
|
||||
1wbg
|
||||
5mkx
|
||||
2oqs
|
||||
6dar
|
||||
5v49
|
||||
3ft2
|
||||
3o2m
|
||||
3g4l
|
||||
4bgy
|
||||
5gsw
|
||||
1fkw
|
||||
2gg8
|
||||
3tv5
|
||||
4at5
|
||||
4kn4
|
||||
6cvw
|
||||
1x7b
|
||||
6fsd
|
||||
3hig
|
||||
5ewa
|
||||
5vij
|
||||
3d4f
|
||||
5vyy
|
||||
3v66
|
||||
3tzd
|
||||
4xkc
|
||||
4y4g
|
||||
2w0x
|
||||
5mft
|
||||
3p78
|
||||
2xn3
|
||||
4x3h
|
||||
4d8z
|
||||
4r17
|
||||
6bik
|
||||
1sqc
|
||||
4jkw
|
||||
4m7c
|
||||
4kpx
|
||||
1a52
|
||||
4qew
|
||||
3zlk
|
||||
5np8
|
||||
1qpe
|
||||
4wmx
|
||||
1z4u
|
||||
5j7q
|
||||
4xt9
|
||||
4tpm
|
||||
5wlg
|
||||
5fot
|
||||
4bcc
|
||||
5l3e
|
||||
3nti
|
||||
1xr9
|
||||
1zpa
|
||||
5we9
|
||||
5avi
|
||||
4mbj
|
||||
3cwj
|
||||
6ck6
|
||||
4ir5
|
||||
5l8n
|
||||
3bz3
|
||||
1azm
|
||||
5a82
|
||||
3u15
|
||||
5wbl
|
||||
5mli
|
||||
1ujj
|
||||
3wp1
|
||||
3obu
|
||||
3vzd
|
||||
4rj5
|
||||
5d6p
|
||||
2xkf
|
||||
4ib5
|
||||
3jsw
|
||||
2zv9
|
||||
1okz
|
||||
3ff6
|
||||
2gz8
|
||||
6bg5
|
||||
2wk2
|
||||
4wet
|
||||
4l53
|
||||
3jzg
|
||||
1g3d
|
||||
3pka
|
||||
5yg4
|
||||
4few
|
||||
1z9g
|
||||
5evd
|
||||
2a4q
|
||||
5ehg
|
||||
3max
|
||||
6msy
|
||||
1fcy
|
||||
4nhy
|
||||
1k4g
|
||||
5lz8
|
||||
4kqr
|
||||
1t1r
|
||||
2z78
|
||||
3dz6
|
||||
4ygf
|
||||
4ht6
|
||||
2v7a
|
||||
3wax
|
||||
4ivc
|
||||
3u7l
|
||||
3hvi
|
||||
5o0e
|
||||
2c97
|
||||
4b6f
|
||||
5mmn
|
||||
6ee3
|
||||
3fmz
|
||||
5wzv
|
||||
5znr
|
||||
5ct2
|
||||
4fai
|
||||
5uit
|
||||
2gkl
|
||||
2vpg
|
||||
1jtq
|
||||
3veh
|
||||
4jit
|
||||
4bgg
|
||||
2ce9
|
||||
5aqu
|
||||
6hth
|
||||
4m5k
|
||||
1aj6
|
||||
3h22
|
||||
5vnd
|
||||
1q4l
|
||||
2j7x
|
||||
3acx
|
||||
5l6p
|
||||
5m0m
|
||||
5oh9
|
||||
2a5c
|
||||
5npc
|
||||
4qk4
|
||||
1fzj
|
||||
5bqi
|
||||
5c4l
|
||||
1fq4
|
||||
2v5x
|
||||
4k63
|
||||
3uik
|
||||
1me8
|
||||
5uq9
|
||||
4nrq
|
||||
2r1y
|
||||
3ldw
|
||||
4bh4
|
||||
4pl5
|
||||
1mfg
|
||||
3c5u
|
||||
3iqh
|
||||
6awo
|
||||
5d1t
|
||||
4urn
|
||||
3dnj
|
||||
4csy
|
||||
1caq
|
||||
4oem
|
||||
4wke
|
||||
3pb7
|
||||
1qwu
|
||||
5g3m
|
||||
2buc
|
||||
4i9h
|
||||
5epr
|
||||
5ivy
|
||||
3krj
|
||||
3f9w
|
||||
3rwj
|
||||
4bo4
|
||||
4amx
|
||||
4wvl
|
||||
5x54
|
||||
2a58
|
||||
6ma1
|
||||
1i41
|
||||
2ew5
|
||||
1y57
|
||||
3ovn
|
||||
5g1n
|
||||
2bba
|
||||
3iej
|
||||
6eip
|
||||
3iaf
|
||||
2uy4
|
||||
830c
|
||||
1aqc
|
||||
2wot
|
||||
1bmq
|
||||
4nal
|
||||
3kbz
|
||||
1gt1
|
||||
4pp9
|
||||
2aq9
|
||||
1b7h
|
||||
3lp1
|
||||
3exf
|
||||
6g9a
|
||||
4gzw
|
||||
1p06
|
||||
4own
|
||||
3qo2
|
||||
5emk
|
||||
1gvk
|
||||
6chn
|
||||
4r1v
|
||||
1q95
|
||||
4x5q
|
||||
2x6x
|
||||
4lh6
|
||||
5mb1
|
||||
4zzz
|
||||
4ty9
|
||||
2hkf
|
||||
1n3i
|
||||
4b6p
|
||||
5vo2
|
||||
2ovq
|
||||
1vcj
|
||||
3kjn
|
||||
4zz1
|
||||
1f73
|
||||
6baw
|
||||
6arv
|
||||
2yxj
|
||||
2xo8
|
||||
3hqz
|
||||
4li6
|
||||
1uze
|
||||
1ih0
|
||||
4uwh
|
||||
5l6h
|
||||
3g0w
|
||||
5v3r
|
||||
6ccu
|
||||
2g78
|
||||
3eqr
|
||||
5d10
|
||||
2vte
|
||||
4e3j
|
||||
2a5u
|
||||
5duc
|
||||
5a0e
|
||||
4zzx
|
||||
6ezq
|
||||
6e4w
|
||||
6m8w
|
||||
3mo0
|
||||
3eyd
|
||||
1pl0
|
||||
4bhz
|
||||
2wxn
|
||||
3vb7
|
||||
3daj
|
||||
3ara
|
||||
4b4m
|
||||
2y59
|
||||
3tcp
|
||||
5hjd
|
||||
4wym
|
||||
4qo9
|
||||
4yz9
|
||||
2jew
|
||||
5tq4
|
||||
4qsh
|
||||
4djp
|
||||
3rqg
|
||||
4ddh
|
||||
6bj3
|
||||
3czv
|
||||
4jok
|
||||
5n2d
|
||||
2afw
|
||||
4h2o
|
||||
2vd7
|
||||
2gvf
|
||||
2yhd
|
||||
3drg
|
||||
1zc9
|
||||
5khi
|
||||
5aix
|
||||
2y2i
|
||||
4yax
|
||||
5d75
|
||||
4m3b
|
||||
5hjc
|
||||
4u0b
|
||||
4n7j
|
||||
4b6s
|
||||
3lq4
|
||||
1lhu
|
||||
5ur5
|
||||
5fcz
|
||||
5fyx
|
||||
3ml5
|
||||
4mb9
|
||||
4nzn
|
||||
5wpb
|
||||
4omd
|
||||
5t2d
|
||||
1bzh
|
||||
5idn
|
||||
4f9v
|
||||
2m3o
|
||||
6cw4
|
||||
4i6f
|
||||
3t2t
|
||||
4ep2
|
||||
5x4o
|
||||
3vqh
|
||||
4itj
|
||||
1p93
|
||||
4b5w
|
||||
5ech
|
||||
2ltw
|
||||
2hrp
|
||||
5lpm
|
||||
5twh
|
||||
2gbg
|
||||
4jjq
|
||||
4kil
|
||||
2r7g
|
||||
5dgu
|
||||
1hbj
|
||||
5vsj
|
||||
6e05
|
||||
1jq9
|
||||
4gto
|
||||
4gii
|
||||
2wc4
|
||||
@@ -1,364 +1,364 @@
|
||||
complex_name,protein_path,ligand_description,protein_sequence
|
||||
0,data/PDBBind_processed/6qqw/6qqw_protein_processed.pdb,data/PDBBind_processed/6qqw/6qqw_ligand.mol2,
|
||||
1,data/PDBBind_processed/6d08/6d08_protein_processed.pdb,data/PDBBind_processed/6d08/6d08_ligand.sdf,
|
||||
2,data/PDBBind_processed/6jap/6jap_protein_processed.pdb,data/PDBBind_processed/6jap/6jap_ligand.sdf,
|
||||
3,data/PDBBind_processed/6np2/6np2_protein_processed.pdb,data/PDBBind_processed/6np2/6np2_ligand.sdf,
|
||||
4,data/PDBBind_processed/6uvp/6uvp_protein_processed.pdb,data/PDBBind_processed/6uvp/6uvp_ligand.sdf,
|
||||
5,data/PDBBind_processed/6oxq/6oxq_protein_processed.pdb,data/PDBBind_processed/6oxq/6oxq_ligand.sdf,
|
||||
6,data/PDBBind_processed/6jsn/6jsn_protein_processed.pdb,data/PDBBind_processed/6jsn/6jsn_ligand.sdf,
|
||||
7,data/PDBBind_processed/6hzb/6hzb_protein_processed.pdb,data/PDBBind_processed/6hzb/6hzb_ligand.sdf,
|
||||
8,data/PDBBind_processed/6qrc/6qrc_protein_processed.pdb,data/PDBBind_processed/6qrc/6qrc_ligand.mol2,
|
||||
9,data/PDBBind_processed/6oio/6oio_protein_processed.pdb,data/PDBBind_processed/6oio/6oio_ligand.sdf,
|
||||
10,data/PDBBind_processed/6jag/6jag_protein_processed.pdb,data/PDBBind_processed/6jag/6jag_ligand.sdf,
|
||||
11,data/PDBBind_processed/6moa/6moa_protein_processed.pdb,data/PDBBind_processed/6moa/6moa_ligand.mol2,
|
||||
12,data/PDBBind_processed/6hld/6hld_protein_processed.pdb,data/PDBBind_processed/6hld/6hld_ligand.sdf,
|
||||
13,data/PDBBind_processed/6i9a/6i9a_protein_processed.pdb,data/PDBBind_processed/6i9a/6i9a_ligand.sdf,
|
||||
14,data/PDBBind_processed/6e4c/6e4c_protein_processed.pdb,data/PDBBind_processed/6e4c/6e4c_ligand.sdf,
|
||||
15,data/PDBBind_processed/6g24/6g24_protein_processed.pdb,data/PDBBind_processed/6g24/6g24_ligand.sdf,
|
||||
16,data/PDBBind_processed/6jb4/6jb4_protein_processed.pdb,data/PDBBind_processed/6jb4/6jb4_ligand.sdf,
|
||||
17,data/PDBBind_processed/6s55/6s55_protein_processed.pdb,data/PDBBind_processed/6s55/6s55_ligand.sdf,
|
||||
18,data/PDBBind_processed/6seo/6seo_protein_processed.pdb,data/PDBBind_processed/6seo/6seo_ligand.sdf,
|
||||
19,data/PDBBind_processed/6dyz/6dyz_protein_processed.pdb,data/PDBBind_processed/6dyz/6dyz_ligand.mol2,
|
||||
20,data/PDBBind_processed/5zk5/5zk5_protein_processed.pdb,data/PDBBind_processed/5zk5/5zk5_ligand.sdf,
|
||||
21,data/PDBBind_processed/6jid/6jid_protein_processed.pdb,data/PDBBind_processed/6jid/6jid_ligand.sdf,
|
||||
22,data/PDBBind_processed/5ze6/5ze6_protein_processed.pdb,data/PDBBind_processed/5ze6/5ze6_ligand.sdf,
|
||||
23,data/PDBBind_processed/6qlu/6qlu_protein_processed.pdb,data/PDBBind_processed/6qlu/6qlu_ligand.sdf,
|
||||
24,data/PDBBind_processed/6a6k/6a6k_protein_processed.pdb,data/PDBBind_processed/6a6k/6a6k_ligand.sdf,
|
||||
25,data/PDBBind_processed/6qgf/6qgf_protein_processed.pdb,data/PDBBind_processed/6qgf/6qgf_ligand.sdf,
|
||||
26,data/PDBBind_processed/6e3z/6e3z_protein_processed.pdb,data/PDBBind_processed/6e3z/6e3z_ligand.sdf,
|
||||
27,data/PDBBind_processed/6te6/6te6_protein_processed.pdb,data/PDBBind_processed/6te6/6te6_ligand.sdf,
|
||||
28,data/PDBBind_processed/6pka/6pka_protein_processed.pdb,data/PDBBind_processed/6pka/6pka_ligand.sdf,
|
||||
29,data/PDBBind_processed/6g2o/6g2o_protein_processed.pdb,data/PDBBind_processed/6g2o/6g2o_ligand.sdf,
|
||||
30,data/PDBBind_processed/6jsf/6jsf_protein_processed.pdb,data/PDBBind_processed/6jsf/6jsf_ligand.sdf,
|
||||
31,data/PDBBind_processed/5zxk/5zxk_protein_processed.pdb,data/PDBBind_processed/5zxk/5zxk_ligand.sdf,
|
||||
32,data/PDBBind_processed/6qxd/6qxd_protein_processed.pdb,data/PDBBind_processed/6qxd/6qxd_ligand.sdf,
|
||||
33,data/PDBBind_processed/6n97/6n97_protein_processed.pdb,data/PDBBind_processed/6n97/6n97_ligand.sdf,
|
||||
34,data/PDBBind_processed/6jt3/6jt3_protein_processed.pdb,data/PDBBind_processed/6jt3/6jt3_ligand.sdf,
|
||||
35,data/PDBBind_processed/6qtr/6qtr_protein_processed.pdb,data/PDBBind_processed/6qtr/6qtr_ligand.sdf,
|
||||
36,data/PDBBind_processed/6oy1/6oy1_protein_processed.pdb,data/PDBBind_processed/6oy1/6oy1_ligand.sdf,
|
||||
37,data/PDBBind_processed/6n96/6n96_protein_processed.pdb,data/PDBBind_processed/6n96/6n96_ligand.sdf,
|
||||
38,data/PDBBind_processed/6qzh/6qzh_protein_processed.pdb,data/PDBBind_processed/6qzh/6qzh_ligand.sdf,
|
||||
39,data/PDBBind_processed/6qqz/6qqz_protein_processed.pdb,data/PDBBind_processed/6qqz/6qqz_ligand.mol2,
|
||||
40,data/PDBBind_processed/6qmt/6qmt_protein_processed.pdb,data/PDBBind_processed/6qmt/6qmt_ligand.sdf,
|
||||
41,data/PDBBind_processed/6ibx/6ibx_protein_processed.pdb,data/PDBBind_processed/6ibx/6ibx_ligand.sdf,
|
||||
42,data/PDBBind_processed/6hmt/6hmt_protein_processed.pdb,data/PDBBind_processed/6hmt/6hmt_ligand.sdf,
|
||||
43,data/PDBBind_processed/5zk7/5zk7_protein_processed.pdb,data/PDBBind_processed/5zk7/5zk7_ligand.sdf,
|
||||
44,data/PDBBind_processed/6k3l/6k3l_protein_processed.pdb,data/PDBBind_processed/6k3l/6k3l_ligand.sdf,
|
||||
45,data/PDBBind_processed/6cjs/6cjs_protein_processed.pdb,data/PDBBind_processed/6cjs/6cjs_ligand.sdf,
|
||||
46,data/PDBBind_processed/6n9l/6n9l_protein_processed.pdb,data/PDBBind_processed/6n9l/6n9l_ligand.sdf,
|
||||
47,data/PDBBind_processed/6ibz/6ibz_protein_processed.pdb,data/PDBBind_processed/6ibz/6ibz_ligand.sdf,
|
||||
48,data/PDBBind_processed/6ott/6ott_protein_processed.pdb,data/PDBBind_processed/6ott/6ott_ligand.sdf,
|
||||
49,data/PDBBind_processed/6gge/6gge_protein_processed.pdb,data/PDBBind_processed/6gge/6gge_ligand.sdf,
|
||||
50,data/PDBBind_processed/6hot/6hot_protein_processed.pdb,data/PDBBind_processed/6hot/6hot_ligand.sdf,
|
||||
51,data/PDBBind_processed/6e3p/6e3p_protein_processed.pdb,data/PDBBind_processed/6e3p/6e3p_ligand.mol2,
|
||||
52,data/PDBBind_processed/6md6/6md6_protein_processed.pdb,data/PDBBind_processed/6md6/6md6_ligand.sdf,
|
||||
53,data/PDBBind_processed/6hlb/6hlb_protein_processed.pdb,data/PDBBind_processed/6hlb/6hlb_ligand.sdf,
|
||||
54,data/PDBBind_processed/6fe5/6fe5_protein_processed.pdb,data/PDBBind_processed/6fe5/6fe5_ligand.sdf,
|
||||
55,data/PDBBind_processed/6uwp/6uwp_protein_processed.pdb,data/PDBBind_processed/6uwp/6uwp_ligand.sdf,
|
||||
56,data/PDBBind_processed/6npp/6npp_protein_processed.pdb,data/PDBBind_processed/6npp/6npp_ligand.sdf,
|
||||
57,data/PDBBind_processed/6g2f/6g2f_protein_processed.pdb,data/PDBBind_processed/6g2f/6g2f_ligand.sdf,
|
||||
58,data/PDBBind_processed/6mo7/6mo7_protein_processed.pdb,data/PDBBind_processed/6mo7/6mo7_ligand.sdf,
|
||||
59,data/PDBBind_processed/6bqd/6bqd_protein_processed.pdb,data/PDBBind_processed/6bqd/6bqd_ligand.mol2,
|
||||
60,data/PDBBind_processed/6nsv/6nsv_protein_processed.pdb,data/PDBBind_processed/6nsv/6nsv_ligand.mol2,
|
||||
61,data/PDBBind_processed/6i76/6i76_protein_processed.pdb,data/PDBBind_processed/6i76/6i76_ligand.sdf,
|
||||
62,data/PDBBind_processed/6n53/6n53_protein_processed.pdb,data/PDBBind_processed/6n53/6n53_ligand.sdf,
|
||||
63,data/PDBBind_processed/6g2c/6g2c_protein_processed.pdb,data/PDBBind_processed/6g2c/6g2c_ligand.sdf,
|
||||
64,data/PDBBind_processed/6eeb/6eeb_protein_processed.pdb,data/PDBBind_processed/6eeb/6eeb_ligand.mol2,
|
||||
65,data/PDBBind_processed/6n0m/6n0m_protein_processed.pdb,data/PDBBind_processed/6n0m/6n0m_ligand.sdf,
|
||||
66,data/PDBBind_processed/6uvy/6uvy_protein_processed.pdb,data/PDBBind_processed/6uvy/6uvy_ligand.sdf,
|
||||
67,data/PDBBind_processed/6ovz/6ovz_protein_processed.pdb,data/PDBBind_processed/6ovz/6ovz_ligand.sdf,
|
||||
68,data/PDBBind_processed/6olx/6olx_protein_processed.pdb,data/PDBBind_processed/6olx/6olx_ligand.sdf,
|
||||
69,data/PDBBind_processed/6v5l/6v5l_protein_processed.pdb,data/PDBBind_processed/6v5l/6v5l_ligand.mol2,
|
||||
70,data/PDBBind_processed/6hhg/6hhg_protein_processed.pdb,data/PDBBind_processed/6hhg/6hhg_ligand.sdf,
|
||||
71,data/PDBBind_processed/5zcu/5zcu_protein_processed.pdb,data/PDBBind_processed/5zcu/5zcu_ligand.sdf,
|
||||
72,data/PDBBind_processed/6dz2/6dz2_protein_processed.pdb,data/PDBBind_processed/6dz2/6dz2_ligand.mol2,
|
||||
73,data/PDBBind_processed/6mjq/6mjq_protein_processed.pdb,data/PDBBind_processed/6mjq/6mjq_ligand.sdf,
|
||||
74,data/PDBBind_processed/6efk/6efk_protein_processed.pdb,data/PDBBind_processed/6efk/6efk_ligand.sdf,
|
||||
75,data/PDBBind_processed/6s9w/6s9w_protein_processed.pdb,data/PDBBind_processed/6s9w/6s9w_ligand.sdf,
|
||||
76,data/PDBBind_processed/6gdy/6gdy_protein_processed.pdb,data/PDBBind_processed/6gdy/6gdy_ligand.sdf,
|
||||
77,data/PDBBind_processed/6kqi/6kqi_protein_processed.pdb,data/PDBBind_processed/6kqi/6kqi_ligand.sdf,
|
||||
78,data/PDBBind_processed/6ueg/6ueg_protein_processed.pdb,data/PDBBind_processed/6ueg/6ueg_ligand.sdf,
|
||||
79,data/PDBBind_processed/6oxt/6oxt_protein_processed.pdb,data/PDBBind_processed/6oxt/6oxt_ligand.sdf,
|
||||
80,data/PDBBind_processed/6oy0/6oy0_protein_processed.pdb,data/PDBBind_processed/6oy0/6oy0_ligand.sdf,
|
||||
81,data/PDBBind_processed/6qr7/6qr7_protein_processed.pdb,data/PDBBind_processed/6qr7/6qr7_ligand.mol2,
|
||||
82,data/PDBBind_processed/6i41/6i41_protein_processed.pdb,data/PDBBind_processed/6i41/6i41_ligand.sdf,
|
||||
83,data/PDBBind_processed/6cyg/6cyg_protein_processed.pdb,data/PDBBind_processed/6cyg/6cyg_ligand.sdf,
|
||||
84,data/PDBBind_processed/6qmr/6qmr_protein_processed.pdb,data/PDBBind_processed/6qmr/6qmr_ligand.sdf,
|
||||
85,data/PDBBind_processed/6g27/6g27_protein_processed.pdb,data/PDBBind_processed/6g27/6g27_ligand.sdf,
|
||||
86,data/PDBBind_processed/6ggb/6ggb_protein_processed.pdb,data/PDBBind_processed/6ggb/6ggb_ligand.sdf,
|
||||
87,data/PDBBind_processed/6g3c/6g3c_protein_processed.pdb,data/PDBBind_processed/6g3c/6g3c_ligand.sdf,
|
||||
88,data/PDBBind_processed/6n4e/6n4e_protein_processed.pdb,data/PDBBind_processed/6n4e/6n4e_ligand.sdf,
|
||||
89,data/PDBBind_processed/6fcj/6fcj_protein_processed.pdb,data/PDBBind_processed/6fcj/6fcj_ligand.sdf,
|
||||
90,data/PDBBind_processed/6quv/6quv_protein_processed.pdb,data/PDBBind_processed/6quv/6quv_ligand.sdf,
|
||||
91,data/PDBBind_processed/6iql/6iql_protein_processed.pdb,data/PDBBind_processed/6iql/6iql_ligand.mol2,
|
||||
92,data/PDBBind_processed/6i74/6i74_protein_processed.pdb,data/PDBBind_processed/6i74/6i74_ligand.sdf,
|
||||
93,data/PDBBind_processed/6qr4/6qr4_protein_processed.pdb,data/PDBBind_processed/6qr4/6qr4_ligand.mol2,
|
||||
94,data/PDBBind_processed/6rnu/6rnu_protein_processed.pdb,data/PDBBind_processed/6rnu/6rnu_ligand.sdf,
|
||||
95,data/PDBBind_processed/6jib/6jib_protein_processed.pdb,data/PDBBind_processed/6jib/6jib_ligand.sdf,
|
||||
96,data/PDBBind_processed/6izq/6izq_protein_processed.pdb,data/PDBBind_processed/6izq/6izq_ligand.sdf,
|
||||
97,data/PDBBind_processed/6qw8/6qw8_protein_processed.pdb,data/PDBBind_processed/6qw8/6qw8_ligand.sdf,
|
||||
98,data/PDBBind_processed/6qto/6qto_protein_processed.pdb,data/PDBBind_processed/6qto/6qto_ligand.sdf,
|
||||
99,data/PDBBind_processed/6qrd/6qrd_protein_processed.pdb,data/PDBBind_processed/6qrd/6qrd_ligand.mol2,
|
||||
100,data/PDBBind_processed/6hza/6hza_protein_processed.pdb,data/PDBBind_processed/6hza/6hza_ligand.sdf,
|
||||
101,data/PDBBind_processed/6e5s/6e5s_protein_processed.pdb,data/PDBBind_processed/6e5s/6e5s_ligand.sdf,
|
||||
102,data/PDBBind_processed/6dz3/6dz3_protein_processed.pdb,data/PDBBind_processed/6dz3/6dz3_ligand.mol2,
|
||||
103,data/PDBBind_processed/6e6w/6e6w_protein_processed.pdb,data/PDBBind_processed/6e6w/6e6w_ligand.mol2,
|
||||
104,data/PDBBind_processed/6cyh/6cyh_protein_processed.pdb,data/PDBBind_processed/6cyh/6cyh_ligand.sdf,
|
||||
105,data/PDBBind_processed/5zlf/5zlf_protein_processed.pdb,data/PDBBind_processed/5zlf/5zlf_ligand.sdf,
|
||||
106,data/PDBBind_processed/6om4/6om4_protein_processed.pdb,data/PDBBind_processed/6om4/6om4_ligand.sdf,
|
||||
107,data/PDBBind_processed/6gga/6gga_protein_processed.pdb,data/PDBBind_processed/6gga/6gga_ligand.sdf,
|
||||
108,data/PDBBind_processed/6pgp/6pgp_protein_processed.pdb,data/PDBBind_processed/6pgp/6pgp_ligand.sdf,
|
||||
109,data/PDBBind_processed/6qqv/6qqv_protein_processed.pdb,data/PDBBind_processed/6qqv/6qqv_ligand.mol2,
|
||||
110,data/PDBBind_processed/6qtq/6qtq_protein_processed.pdb,data/PDBBind_processed/6qtq/6qtq_ligand.sdf,
|
||||
111,data/PDBBind_processed/6gj6/6gj6_protein_processed.pdb,data/PDBBind_processed/6gj6/6gj6_ligand.mol2,
|
||||
112,data/PDBBind_processed/6os5/6os5_protein_processed.pdb,data/PDBBind_processed/6os5/6os5_ligand.mol2,
|
||||
113,data/PDBBind_processed/6s07/6s07_protein_processed.pdb,data/PDBBind_processed/6s07/6s07_ligand.sdf,
|
||||
114,data/PDBBind_processed/6i77/6i77_protein_processed.pdb,data/PDBBind_processed/6i77/6i77_ligand.sdf,
|
||||
115,data/PDBBind_processed/6hhj/6hhj_protein_processed.pdb,data/PDBBind_processed/6hhj/6hhj_ligand.sdf,
|
||||
116,data/PDBBind_processed/6ahs/6ahs_protein_processed.pdb,data/PDBBind_processed/6ahs/6ahs_ligand.sdf,
|
||||
117,data/PDBBind_processed/6oxx/6oxx_protein_processed.pdb,data/PDBBind_processed/6oxx/6oxx_ligand.sdf,
|
||||
118,data/PDBBind_processed/6mjj/6mjj_protein_processed.pdb,data/PDBBind_processed/6mjj/6mjj_ligand.sdf,
|
||||
119,data/PDBBind_processed/6hor/6hor_protein_processed.pdb,data/PDBBind_processed/6hor/6hor_ligand.sdf,
|
||||
120,data/PDBBind_processed/6jb0/6jb0_protein_processed.pdb,data/PDBBind_processed/6jb0/6jb0_ligand.sdf,
|
||||
121,data/PDBBind_processed/6i68/6i68_protein_processed.pdb,data/PDBBind_processed/6i68/6i68_ligand.sdf,
|
||||
122,data/PDBBind_processed/6pz4/6pz4_protein_processed.pdb,data/PDBBind_processed/6pz4/6pz4_ligand.sdf,
|
||||
123,data/PDBBind_processed/6mhb/6mhb_protein_processed.pdb,data/PDBBind_processed/6mhb/6mhb_ligand.sdf,
|
||||
124,data/PDBBind_processed/6uim/6uim_protein_processed.pdb,data/PDBBind_processed/6uim/6uim_ligand.sdf,
|
||||
125,data/PDBBind_processed/6jsg/6jsg_protein_processed.pdb,data/PDBBind_processed/6jsg/6jsg_ligand.sdf,
|
||||
126,data/PDBBind_processed/6i78/6i78_protein_processed.pdb,data/PDBBind_processed/6i78/6i78_ligand.sdf,
|
||||
127,data/PDBBind_processed/6oxy/6oxy_protein_processed.pdb,data/PDBBind_processed/6oxy/6oxy_ligand.sdf,
|
||||
128,data/PDBBind_processed/6gbw/6gbw_protein_processed.pdb,data/PDBBind_processed/6gbw/6gbw_ligand.sdf,
|
||||
129,data/PDBBind_processed/6mo0/6mo0_protein_processed.pdb,data/PDBBind_processed/6mo0/6mo0_ligand.sdf,
|
||||
130,data/PDBBind_processed/6ggf/6ggf_protein_processed.pdb,data/PDBBind_processed/6ggf/6ggf_ligand.sdf,
|
||||
131,data/PDBBind_processed/6qge/6qge_protein_processed.pdb,data/PDBBind_processed/6qge/6qge_ligand.sdf,
|
||||
132,data/PDBBind_processed/6cjr/6cjr_protein_processed.pdb,data/PDBBind_processed/6cjr/6cjr_ligand.sdf,
|
||||
133,data/PDBBind_processed/6oxp/6oxp_protein_processed.pdb,data/PDBBind_processed/6oxp/6oxp_ligand.sdf,
|
||||
134,data/PDBBind_processed/6d07/6d07_protein_processed.pdb,data/PDBBind_processed/6d07/6d07_ligand.sdf,
|
||||
135,data/PDBBind_processed/6i63/6i63_protein_processed.pdb,data/PDBBind_processed/6i63/6i63_ligand.sdf,
|
||||
136,data/PDBBind_processed/6ten/6ten_protein_processed.pdb,data/PDBBind_processed/6ten/6ten_ligand.sdf,
|
||||
137,data/PDBBind_processed/6uii/6uii_protein_processed.pdb,data/PDBBind_processed/6uii/6uii_ligand.sdf,
|
||||
138,data/PDBBind_processed/6qlr/6qlr_protein_processed.pdb,data/PDBBind_processed/6qlr/6qlr_ligand.sdf,
|
||||
139,data/PDBBind_processed/6sen/6sen_protein_processed.pdb,data/PDBBind_processed/6sen/6sen_ligand.mol2,
|
||||
140,data/PDBBind_processed/6oxv/6oxv_protein_processed.pdb,data/PDBBind_processed/6oxv/6oxv_ligand.sdf,
|
||||
141,data/PDBBind_processed/6g2b/6g2b_protein_processed.pdb,data/PDBBind_processed/6g2b/6g2b_ligand.sdf,
|
||||
142,data/PDBBind_processed/5zr3/5zr3_protein_processed.pdb,data/PDBBind_processed/5zr3/5zr3_ligand.sdf,
|
||||
143,data/PDBBind_processed/6kjf/6kjf_protein_processed.pdb,data/PDBBind_processed/6kjf/6kjf_ligand.sdf,
|
||||
144,data/PDBBind_processed/6qr9/6qr9_protein_processed.pdb,data/PDBBind_processed/6qr9/6qr9_ligand.mol2,
|
||||
145,data/PDBBind_processed/6g9f/6g9f_protein_processed.pdb,data/PDBBind_processed/6g9f/6g9f_ligand.sdf,
|
||||
146,data/PDBBind_processed/6e6v/6e6v_protein_processed.pdb,data/PDBBind_processed/6e6v/6e6v_ligand.sdf,
|
||||
147,data/PDBBind_processed/5zk9/5zk9_protein_processed.pdb,data/PDBBind_processed/5zk9/5zk9_ligand.sdf,
|
||||
148,data/PDBBind_processed/6pnn/6pnn_protein_processed.pdb,data/PDBBind_processed/6pnn/6pnn_ligand.sdf,
|
||||
149,data/PDBBind_processed/6nri/6nri_protein_processed.pdb,data/PDBBind_processed/6nri/6nri_ligand.sdf,
|
||||
150,data/PDBBind_processed/6uwv/6uwv_protein_processed.pdb,data/PDBBind_processed/6uwv/6uwv_ligand.sdf,
|
||||
151,data/PDBBind_processed/6ooz/6ooz_protein_processed.pdb,data/PDBBind_processed/6ooz/6ooz_ligand.sdf,
|
||||
152,data/PDBBind_processed/6npi/6npi_protein_processed.pdb,data/PDBBind_processed/6npi/6npi_ligand.sdf,
|
||||
153,data/PDBBind_processed/6oip/6oip_protein_processed.pdb,data/PDBBind_processed/6oip/6oip_ligand.sdf,
|
||||
154,data/PDBBind_processed/6miv/6miv_protein_processed.pdb,data/PDBBind_processed/6miv/6miv_ligand.sdf,
|
||||
155,data/PDBBind_processed/6s57/6s57_protein_processed.pdb,data/PDBBind_processed/6s57/6s57_ligand.sdf,
|
||||
156,data/PDBBind_processed/6p8x/6p8x_protein_processed.pdb,data/PDBBind_processed/6p8x/6p8x_ligand.sdf,
|
||||
157,data/PDBBind_processed/6hoq/6hoq_protein_processed.pdb,data/PDBBind_processed/6hoq/6hoq_ligand.sdf,
|
||||
158,data/PDBBind_processed/6qts/6qts_protein_processed.pdb,data/PDBBind_processed/6qts/6qts_ligand.sdf,
|
||||
159,data/PDBBind_processed/6ggd/6ggd_protein_processed.pdb,data/PDBBind_processed/6ggd/6ggd_ligand.sdf,
|
||||
160,data/PDBBind_processed/6pnm/6pnm_protein_processed.pdb,data/PDBBind_processed/6pnm/6pnm_ligand.sdf,
|
||||
161,data/PDBBind_processed/6oy2/6oy2_protein_processed.pdb,data/PDBBind_processed/6oy2/6oy2_ligand.sdf,
|
||||
162,data/PDBBind_processed/6oi8/6oi8_protein_processed.pdb,data/PDBBind_processed/6oi8/6oi8_ligand.sdf,
|
||||
163,data/PDBBind_processed/6mhd/6mhd_protein_processed.pdb,data/PDBBind_processed/6mhd/6mhd_ligand.sdf,
|
||||
164,data/PDBBind_processed/6agt/6agt_protein_processed.pdb,data/PDBBind_processed/6agt/6agt_ligand.sdf,
|
||||
165,data/PDBBind_processed/6i5p/6i5p_protein_processed.pdb,data/PDBBind_processed/6i5p/6i5p_ligand.sdf,
|
||||
166,data/PDBBind_processed/6hhr/6hhr_protein_processed.pdb,data/PDBBind_processed/6hhr/6hhr_ligand.sdf,
|
||||
167,data/PDBBind_processed/6p8z/6p8z_protein_processed.pdb,data/PDBBind_processed/6p8z/6p8z_ligand.sdf,
|
||||
168,data/PDBBind_processed/6c85/6c85_protein_processed.pdb,data/PDBBind_processed/6c85/6c85_ligand.sdf,
|
||||
169,data/PDBBind_processed/6g5u/6g5u_protein_processed.pdb,data/PDBBind_processed/6g5u/6g5u_ligand.sdf,
|
||||
170,data/PDBBind_processed/6j06/6j06_protein_processed.pdb,data/PDBBind_processed/6j06/6j06_ligand.sdf,
|
||||
171,data/PDBBind_processed/6qsz/6qsz_protein_processed.pdb,data/PDBBind_processed/6qsz/6qsz_ligand.sdf,
|
||||
172,data/PDBBind_processed/6jbb/6jbb_protein_processed.pdb,data/PDBBind_processed/6jbb/6jbb_ligand.sdf,
|
||||
173,data/PDBBind_processed/6hhp/6hhp_protein_processed.pdb,data/PDBBind_processed/6hhp/6hhp_ligand.sdf,
|
||||
174,data/PDBBind_processed/6np5/6np5_protein_processed.pdb,data/PDBBind_processed/6np5/6np5_ligand.sdf,
|
||||
175,data/PDBBind_processed/6nlj/6nlj_protein_processed.pdb,data/PDBBind_processed/6nlj/6nlj_ligand.sdf,
|
||||
176,data/PDBBind_processed/6qlp/6qlp_protein_processed.pdb,data/PDBBind_processed/6qlp/6qlp_ligand.sdf,
|
||||
177,data/PDBBind_processed/6n94/6n94_protein_processed.pdb,data/PDBBind_processed/6n94/6n94_ligand.sdf,
|
||||
178,data/PDBBind_processed/6e13/6e13_protein_processed.pdb,data/PDBBind_processed/6e13/6e13_ligand.sdf,
|
||||
179,data/PDBBind_processed/6qls/6qls_protein_processed.pdb,data/PDBBind_processed/6qls/6qls_ligand.sdf,
|
||||
180,data/PDBBind_processed/6uil/6uil_protein_processed.pdb,data/PDBBind_processed/6uil/6uil_ligand.sdf,
|
||||
181,data/PDBBind_processed/6st3/6st3_protein_processed.pdb,data/PDBBind_processed/6st3/6st3_ligand.sdf,
|
||||
182,data/PDBBind_processed/6n92/6n92_protein_processed.pdb,data/PDBBind_processed/6n92/6n92_ligand.sdf,
|
||||
183,data/PDBBind_processed/6s56/6s56_protein_processed.pdb,data/PDBBind_processed/6s56/6s56_ligand.sdf,
|
||||
184,data/PDBBind_processed/6hzd/6hzd_protein_processed.pdb,data/PDBBind_processed/6hzd/6hzd_ligand.sdf,
|
||||
185,data/PDBBind_processed/6uhv/6uhv_protein_processed.pdb,data/PDBBind_processed/6uhv/6uhv_ligand.sdf,
|
||||
186,data/PDBBind_processed/6k05/6k05_protein_processed.pdb,data/PDBBind_processed/6k05/6k05_ligand.sdf,
|
||||
187,data/PDBBind_processed/6q36/6q36_protein_processed.pdb,data/PDBBind_processed/6q36/6q36_ligand.mol2,
|
||||
188,data/PDBBind_processed/6ic0/6ic0_protein_processed.pdb,data/PDBBind_processed/6ic0/6ic0_ligand.sdf,
|
||||
189,data/PDBBind_processed/6hhi/6hhi_protein_processed.pdb,data/PDBBind_processed/6hhi/6hhi_ligand.sdf,
|
||||
190,data/PDBBind_processed/6e3m/6e3m_protein_processed.pdb,data/PDBBind_processed/6e3m/6e3m_ligand.sdf,
|
||||
191,data/PDBBind_processed/6qtx/6qtx_protein_processed.pdb,data/PDBBind_processed/6qtx/6qtx_ligand.sdf,
|
||||
192,data/PDBBind_processed/6jse/6jse_protein_processed.pdb,data/PDBBind_processed/6jse/6jse_ligand.sdf,
|
||||
193,data/PDBBind_processed/5zjy/5zjy_protein_processed.pdb,data/PDBBind_processed/5zjy/5zjy_ligand.sdf,
|
||||
194,data/PDBBind_processed/6o3y/6o3y_protein_processed.pdb,data/PDBBind_processed/6o3y/6o3y_ligand.sdf,
|
||||
195,data/PDBBind_processed/6rpg/6rpg_protein_processed.pdb,data/PDBBind_processed/6rpg/6rpg_ligand.sdf,
|
||||
196,data/PDBBind_processed/6rr0/6rr0_protein_processed.pdb,data/PDBBind_processed/6rr0/6rr0_ligand.sdf,
|
||||
197,data/PDBBind_processed/6gzy/6gzy_protein_processed.pdb,data/PDBBind_processed/6gzy/6gzy_ligand.sdf,
|
||||
198,data/PDBBind_processed/6qlt/6qlt_protein_processed.pdb,data/PDBBind_processed/6qlt/6qlt_ligand.sdf,
|
||||
199,data/PDBBind_processed/6ufo/6ufo_protein_processed.pdb,data/PDBBind_processed/6ufo/6ufo_ligand.sdf,
|
||||
200,data/PDBBind_processed/6o0h/6o0h_protein_processed.pdb,data/PDBBind_processed/6o0h/6o0h_ligand.sdf,
|
||||
201,data/PDBBind_processed/6o3x/6o3x_protein_processed.pdb,data/PDBBind_processed/6o3x/6o3x_ligand.sdf,
|
||||
202,data/PDBBind_processed/5zjz/5zjz_protein_processed.pdb,data/PDBBind_processed/5zjz/5zjz_ligand.mol2,
|
||||
203,data/PDBBind_processed/6i8t/6i8t_protein_processed.pdb,data/PDBBind_processed/6i8t/6i8t_ligand.sdf,
|
||||
204,data/PDBBind_processed/6ooy/6ooy_protein_processed.pdb,data/PDBBind_processed/6ooy/6ooy_ligand.sdf,
|
||||
205,data/PDBBind_processed/6oiq/6oiq_protein_processed.pdb,data/PDBBind_processed/6oiq/6oiq_ligand.sdf,
|
||||
206,data/PDBBind_processed/6od6/6od6_protein_processed.pdb,data/PDBBind_processed/6od6/6od6_ligand.sdf,
|
||||
207,data/PDBBind_processed/6nrh/6nrh_protein_processed.pdb,data/PDBBind_processed/6nrh/6nrh_ligand.sdf,
|
||||
208,data/PDBBind_processed/6qra/6qra_protein_processed.pdb,data/PDBBind_processed/6qra/6qra_ligand.mol2,
|
||||
209,data/PDBBind_processed/6hhh/6hhh_protein_processed.pdb,data/PDBBind_processed/6hhh/6hhh_ligand.sdf,
|
||||
210,data/PDBBind_processed/6m7h/6m7h_protein_processed.pdb,data/PDBBind_processed/6m7h/6m7h_ligand.sdf,
|
||||
211,data/PDBBind_processed/6ufn/6ufn_protein_processed.pdb,data/PDBBind_processed/6ufn/6ufn_ligand.sdf,
|
||||
212,data/PDBBind_processed/6qr0/6qr0_protein_processed.pdb,data/PDBBind_processed/6qr0/6qr0_ligand.mol2,
|
||||
213,data/PDBBind_processed/6o5u/6o5u_protein_processed.pdb,data/PDBBind_processed/6o5u/6o5u_ligand.sdf,
|
||||
214,data/PDBBind_processed/6h14/6h14_protein_processed.pdb,data/PDBBind_processed/6h14/6h14_ligand.sdf,
|
||||
215,data/PDBBind_processed/6jwa/6jwa_protein_processed.pdb,data/PDBBind_processed/6jwa/6jwa_ligand.sdf,
|
||||
216,data/PDBBind_processed/6ny0/6ny0_protein_processed.pdb,data/PDBBind_processed/6ny0/6ny0_ligand.sdf,
|
||||
217,data/PDBBind_processed/6jan/6jan_protein_processed.pdb,data/PDBBind_processed/6jan/6jan_ligand.sdf,
|
||||
218,data/PDBBind_processed/6ftf/6ftf_protein_processed.pdb,data/PDBBind_processed/6ftf/6ftf_ligand.sdf,
|
||||
219,data/PDBBind_processed/6oxw/6oxw_protein_processed.pdb,data/PDBBind_processed/6oxw/6oxw_ligand.sdf,
|
||||
220,data/PDBBind_processed/6jon/6jon_protein_processed.pdb,data/PDBBind_processed/6jon/6jon_ligand.sdf,
|
||||
221,data/PDBBind_processed/6cf7/6cf7_protein_processed.pdb,data/PDBBind_processed/6cf7/6cf7_ligand.sdf,
|
||||
222,data/PDBBind_processed/6rtn/6rtn_protein_processed.pdb,data/PDBBind_processed/6rtn/6rtn_ligand.mol2,
|
||||
223,data/PDBBind_processed/6jsz/6jsz_protein_processed.pdb,data/PDBBind_processed/6jsz/6jsz_ligand.sdf,
|
||||
224,data/PDBBind_processed/6o9c/6o9c_protein_processed.pdb,data/PDBBind_processed/6o9c/6o9c_ligand.sdf,
|
||||
225,data/PDBBind_processed/6mo8/6mo8_protein_processed.pdb,data/PDBBind_processed/6mo8/6mo8_ligand.sdf,
|
||||
226,data/PDBBind_processed/6qln/6qln_protein_processed.pdb,data/PDBBind_processed/6qln/6qln_ligand.sdf,
|
||||
227,data/PDBBind_processed/6qqu/6qqu_protein_processed.pdb,data/PDBBind_processed/6qqu/6qqu_ligand.mol2,
|
||||
228,data/PDBBind_processed/6i66/6i66_protein_processed.pdb,data/PDBBind_processed/6i66/6i66_ligand.sdf,
|
||||
229,data/PDBBind_processed/6mja/6mja_protein_processed.pdb,data/PDBBind_processed/6mja/6mja_ligand.sdf,
|
||||
230,data/PDBBind_processed/6gwe/6gwe_protein_processed.pdb,data/PDBBind_processed/6gwe/6gwe_ligand.mol2,
|
||||
231,data/PDBBind_processed/6d3z/6d3z_protein_processed.pdb,data/PDBBind_processed/6d3z/6d3z_ligand.sdf,
|
||||
232,data/PDBBind_processed/6oxr/6oxr_protein_processed.pdb,data/PDBBind_processed/6oxr/6oxr_ligand.sdf,
|
||||
233,data/PDBBind_processed/6r4k/6r4k_protein_processed.pdb,data/PDBBind_processed/6r4k/6r4k_ligand.sdf,
|
||||
234,data/PDBBind_processed/6hle/6hle_protein_processed.pdb,data/PDBBind_processed/6hle/6hle_ligand.sdf,
|
||||
235,data/PDBBind_processed/6h9v/6h9v_protein_processed.pdb,data/PDBBind_processed/6h9v/6h9v_ligand.sdf,
|
||||
236,data/PDBBind_processed/6hou/6hou_protein_processed.pdb,data/PDBBind_processed/6hou/6hou_ligand.sdf,
|
||||
237,data/PDBBind_processed/6nv9/6nv9_protein_processed.pdb,data/PDBBind_processed/6nv9/6nv9_ligand.sdf,
|
||||
238,data/PDBBind_processed/6py0/6py0_protein_processed.pdb,data/PDBBind_processed/6py0/6py0_ligand.sdf,
|
||||
239,data/PDBBind_processed/6qlq/6qlq_protein_processed.pdb,data/PDBBind_processed/6qlq/6qlq_ligand.sdf,
|
||||
240,data/PDBBind_processed/6nv7/6nv7_protein_processed.pdb,data/PDBBind_processed/6nv7/6nv7_ligand.sdf,
|
||||
241,data/PDBBind_processed/6n4b/6n4b_protein_processed.pdb,data/PDBBind_processed/6n4b/6n4b_ligand.sdf,
|
||||
242,data/PDBBind_processed/6jaq/6jaq_protein_processed.pdb,data/PDBBind_processed/6jaq/6jaq_ligand.sdf,
|
||||
243,data/PDBBind_processed/6i8m/6i8m_protein_processed.pdb,data/PDBBind_processed/6i8m/6i8m_ligand.sdf,
|
||||
244,data/PDBBind_processed/6dz0/6dz0_protein_processed.pdb,data/PDBBind_processed/6dz0/6dz0_ligand.mol2,
|
||||
245,data/PDBBind_processed/6oxs/6oxs_protein_processed.pdb,data/PDBBind_processed/6oxs/6oxs_ligand.sdf,
|
||||
246,data/PDBBind_processed/6k2n/6k2n_protein_processed.pdb,data/PDBBind_processed/6k2n/6k2n_ligand.sdf,
|
||||
247,data/PDBBind_processed/6cjj/6cjj_protein_processed.pdb,data/PDBBind_processed/6cjj/6cjj_ligand.sdf,
|
||||
248,data/PDBBind_processed/6ffg/6ffg_protein_processed.pdb,data/PDBBind_processed/6ffg/6ffg_ligand.sdf,
|
||||
249,data/PDBBind_processed/6a73/6a73_protein_processed.pdb,data/PDBBind_processed/6a73/6a73_ligand.sdf,
|
||||
250,data/PDBBind_processed/6qqt/6qqt_protein_processed.pdb,data/PDBBind_processed/6qqt/6qqt_ligand.mol2,
|
||||
251,data/PDBBind_processed/6a1c/6a1c_protein_processed.pdb,data/PDBBind_processed/6a1c/6a1c_ligand.sdf,
|
||||
252,data/PDBBind_processed/6oxu/6oxu_protein_processed.pdb,data/PDBBind_processed/6oxu/6oxu_ligand.sdf,
|
||||
253,data/PDBBind_processed/6qre/6qre_protein_processed.pdb,data/PDBBind_processed/6qre/6qre_ligand.mol2,
|
||||
254,data/PDBBind_processed/6qtw/6qtw_protein_processed.pdb,data/PDBBind_processed/6qtw/6qtw_ligand.sdf,
|
||||
255,data/PDBBind_processed/6np4/6np4_protein_processed.pdb,data/PDBBind_processed/6np4/6np4_ligand.sdf,
|
||||
256,data/PDBBind_processed/6hv2/6hv2_protein_processed.pdb,data/PDBBind_processed/6hv2/6hv2_ligand.sdf,
|
||||
257,data/PDBBind_processed/6n55/6n55_protein_processed.pdb,data/PDBBind_processed/6n55/6n55_ligand.sdf,
|
||||
258,data/PDBBind_processed/6e3o/6e3o_protein_processed.pdb,data/PDBBind_processed/6e3o/6e3o_ligand.sdf,
|
||||
259,data/PDBBind_processed/6kjd/6kjd_protein_processed.pdb,data/PDBBind_processed/6kjd/6kjd_ligand.sdf,
|
||||
260,data/PDBBind_processed/6sfc/6sfc_protein_processed.pdb,data/PDBBind_processed/6sfc/6sfc_ligand.sdf,
|
||||
261,data/PDBBind_processed/6qi7/6qi7_protein_processed.pdb,data/PDBBind_processed/6qi7/6qi7_ligand.sdf,
|
||||
262,data/PDBBind_processed/6hzc/6hzc_protein_processed.pdb,data/PDBBind_processed/6hzc/6hzc_ligand.sdf,
|
||||
263,data/PDBBind_processed/6k04/6k04_protein_processed.pdb,data/PDBBind_processed/6k04/6k04_ligand.sdf,
|
||||
264,data/PDBBind_processed/6op0/6op0_protein_processed.pdb,data/PDBBind_processed/6op0/6op0_ligand.sdf,
|
||||
265,data/PDBBind_processed/6q38/6q38_protein_processed.pdb,data/PDBBind_processed/6q38/6q38_ligand.mol2,
|
||||
266,data/PDBBind_processed/6n8x/6n8x_protein_processed.pdb,data/PDBBind_processed/6n8x/6n8x_ligand.sdf,
|
||||
267,data/PDBBind_processed/6np3/6np3_protein_processed.pdb,data/PDBBind_processed/6np3/6np3_ligand.sdf,
|
||||
268,data/PDBBind_processed/6uvv/6uvv_protein_processed.pdb,data/PDBBind_processed/6uvv/6uvv_ligand.sdf,
|
||||
269,data/PDBBind_processed/6pgo/6pgo_protein_processed.pdb,data/PDBBind_processed/6pgo/6pgo_ligand.sdf,
|
||||
270,data/PDBBind_processed/6jbe/6jbe_protein_processed.pdb,data/PDBBind_processed/6jbe/6jbe_ligand.sdf,
|
||||
271,data/PDBBind_processed/6i75/6i75_protein_processed.pdb,data/PDBBind_processed/6i75/6i75_ligand.sdf,
|
||||
272,data/PDBBind_processed/6qqq/6qqq_protein_processed.pdb,data/PDBBind_processed/6qqq/6qqq_ligand.mol2,
|
||||
273,data/PDBBind_processed/6i62/6i62_protein_processed.pdb,data/PDBBind_processed/6i62/6i62_ligand.sdf,
|
||||
274,data/PDBBind_processed/6j9y/6j9y_protein_processed.pdb,data/PDBBind_processed/6j9y/6j9y_ligand.sdf,
|
||||
275,data/PDBBind_processed/6g29/6g29_protein_processed.pdb,data/PDBBind_processed/6g29/6g29_ligand.sdf,
|
||||
276,data/PDBBind_processed/6h7d/6h7d_protein_processed.pdb,data/PDBBind_processed/6h7d/6h7d_ligand.sdf,
|
||||
277,data/PDBBind_processed/6mo9/6mo9_protein_processed.pdb,data/PDBBind_processed/6mo9/6mo9_ligand.sdf,
|
||||
278,data/PDBBind_processed/6jao/6jao_protein_processed.pdb,data/PDBBind_processed/6jao/6jao_ligand.sdf,
|
||||
279,data/PDBBind_processed/6jmf/6jmf_protein_processed.pdb,data/PDBBind_processed/6jmf/6jmf_ligand.sdf,
|
||||
280,data/PDBBind_processed/6hmy/6hmy_protein_processed.pdb,data/PDBBind_processed/6hmy/6hmy_ligand.sdf,
|
||||
281,data/PDBBind_processed/6qfe/6qfe_protein_processed.pdb,data/PDBBind_processed/6qfe/6qfe_ligand.mol2,
|
||||
282,data/PDBBind_processed/5zml/5zml_protein_processed.pdb,data/PDBBind_processed/5zml/5zml_ligand.sdf,
|
||||
283,data/PDBBind_processed/6i65/6i65_protein_processed.pdb,data/PDBBind_processed/6i65/6i65_ligand.sdf,
|
||||
284,data/PDBBind_processed/6e7m/6e7m_protein_processed.pdb,data/PDBBind_processed/6e7m/6e7m_ligand.sdf,
|
||||
285,data/PDBBind_processed/6i61/6i61_protein_processed.pdb,data/PDBBind_processed/6i61/6i61_ligand.sdf,
|
||||
286,data/PDBBind_processed/6rz6/6rz6_protein_processed.pdb,data/PDBBind_processed/6rz6/6rz6_ligand.sdf,
|
||||
287,data/PDBBind_processed/6qtm/6qtm_protein_processed.pdb,data/PDBBind_processed/6qtm/6qtm_ligand.sdf,
|
||||
288,data/PDBBind_processed/6qlo/6qlo_protein_processed.pdb,data/PDBBind_processed/6qlo/6qlo_ligand.sdf,
|
||||
289,data/PDBBind_processed/6oie/6oie_protein_processed.pdb,data/PDBBind_processed/6oie/6oie_ligand.sdf,
|
||||
290,data/PDBBind_processed/6miy/6miy_protein_processed.pdb,data/PDBBind_processed/6miy/6miy_ligand.sdf,
|
||||
291,data/PDBBind_processed/6nrf/6nrf_protein_processed.pdb,data/PDBBind_processed/6nrf/6nrf_ligand.mol2,
|
||||
292,data/PDBBind_processed/6gj5/6gj5_protein_processed.pdb,data/PDBBind_processed/6gj5/6gj5_ligand.mol2,
|
||||
293,data/PDBBind_processed/6jad/6jad_protein_processed.pdb,data/PDBBind_processed/6jad/6jad_ligand.sdf,
|
||||
294,data/PDBBind_processed/6mj4/6mj4_protein_processed.pdb,data/PDBBind_processed/6mj4/6mj4_ligand.sdf,
|
||||
295,data/PDBBind_processed/6h12/6h12_protein_processed.pdb,data/PDBBind_processed/6h12/6h12_ligand.sdf,
|
||||
296,data/PDBBind_processed/6d3y/6d3y_protein_processed.pdb,data/PDBBind_processed/6d3y/6d3y_ligand.sdf,
|
||||
297,data/PDBBind_processed/6qr2/6qr2_protein_processed.pdb,data/PDBBind_processed/6qr2/6qr2_ligand.mol2,
|
||||
298,data/PDBBind_processed/6qxa/6qxa_protein_processed.pdb,data/PDBBind_processed/6qxa/6qxa_ligand.mol2,
|
||||
299,data/PDBBind_processed/6o9b/6o9b_protein_processed.pdb,data/PDBBind_processed/6o9b/6o9b_ligand.sdf,
|
||||
300,data/PDBBind_processed/6ckl/6ckl_protein_processed.pdb,data/PDBBind_processed/6ckl/6ckl_ligand.sdf,
|
||||
301,data/PDBBind_processed/6oir/6oir_protein_processed.pdb,data/PDBBind_processed/6oir/6oir_ligand.sdf,
|
||||
302,data/PDBBind_processed/6d40/6d40_protein_processed.pdb,data/PDBBind_processed/6d40/6d40_ligand.sdf,
|
||||
303,data/PDBBind_processed/6e6j/6e6j_protein_processed.pdb,data/PDBBind_processed/6e6j/6e6j_ligand.mol2,
|
||||
304,data/PDBBind_processed/6i7a/6i7a_protein_processed.pdb,data/PDBBind_processed/6i7a/6i7a_ligand.sdf,
|
||||
305,data/PDBBind_processed/6g25/6g25_protein_processed.pdb,data/PDBBind_processed/6g25/6g25_ligand.mol2,
|
||||
306,data/PDBBind_processed/6oin/6oin_protein_processed.pdb,data/PDBBind_processed/6oin/6oin_ligand.sdf,
|
||||
307,data/PDBBind_processed/6jam/6jam_protein_processed.pdb,data/PDBBind_processed/6jam/6jam_ligand.sdf,
|
||||
308,data/PDBBind_processed/6oxz/6oxz_protein_processed.pdb,data/PDBBind_processed/6oxz/6oxz_ligand.sdf,
|
||||
309,data/PDBBind_processed/6hop/6hop_protein_processed.pdb,data/PDBBind_processed/6hop/6hop_ligand.sdf,
|
||||
310,data/PDBBind_processed/6rot/6rot_protein_processed.pdb,data/PDBBind_processed/6rot/6rot_ligand.sdf,
|
||||
311,data/PDBBind_processed/6uhu/6uhu_protein_processed.pdb,data/PDBBind_processed/6uhu/6uhu_ligand.mol2,
|
||||
312,data/PDBBind_processed/6mji/6mji_protein_processed.pdb,data/PDBBind_processed/6mji/6mji_ligand.sdf,
|
||||
313,data/PDBBind_processed/6nrj/6nrj_protein_processed.pdb,data/PDBBind_processed/6nrj/6nrj_ligand.mol2,
|
||||
314,data/PDBBind_processed/6nt2/6nt2_protein_processed.pdb,data/PDBBind_processed/6nt2/6nt2_ligand.mol2,
|
||||
315,data/PDBBind_processed/6op9/6op9_protein_processed.pdb,data/PDBBind_processed/6op9/6op9_ligand.sdf,
|
||||
316,data/PDBBind_processed/6pno/6pno_protein_processed.pdb,data/PDBBind_processed/6pno/6pno_ligand.sdf,
|
||||
317,data/PDBBind_processed/6e4v/6e4v_protein_processed.pdb,data/PDBBind_processed/6e4v/6e4v_ligand.sdf,
|
||||
318,data/PDBBind_processed/6k1s/6k1s_protein_processed.pdb,data/PDBBind_processed/6k1s/6k1s_ligand.sdf,
|
||||
319,data/PDBBind_processed/6a87/6a87_protein_processed.pdb,data/PDBBind_processed/6a87/6a87_ligand.sdf,
|
||||
320,data/PDBBind_processed/6oim/6oim_protein_processed.pdb,data/PDBBind_processed/6oim/6oim_ligand.sdf,
|
||||
321,data/PDBBind_processed/6cjp/6cjp_protein_processed.pdb,data/PDBBind_processed/6cjp/6cjp_ligand.sdf,
|
||||
322,data/PDBBind_processed/6pyb/6pyb_protein_processed.pdb,data/PDBBind_processed/6pyb/6pyb_ligand.sdf,
|
||||
323,data/PDBBind_processed/6h13/6h13_protein_processed.pdb,data/PDBBind_processed/6h13/6h13_ligand.sdf,
|
||||
324,data/PDBBind_processed/6qrf/6qrf_protein_processed.pdb,data/PDBBind_processed/6qrf/6qrf_ligand.mol2,
|
||||
325,data/PDBBind_processed/6mhc/6mhc_protein_processed.pdb,data/PDBBind_processed/6mhc/6mhc_ligand.sdf,
|
||||
326,data/PDBBind_processed/6j9w/6j9w_protein_processed.pdb,data/PDBBind_processed/6j9w/6j9w_ligand.sdf,
|
||||
327,data/PDBBind_processed/6nrg/6nrg_protein_processed.pdb,data/PDBBind_processed/6nrg/6nrg_ligand.mol2,
|
||||
328,data/PDBBind_processed/6fff/6fff_protein_processed.pdb,data/PDBBind_processed/6fff/6fff_ligand.sdf,
|
||||
329,data/PDBBind_processed/6n93/6n93_protein_processed.pdb,data/PDBBind_processed/6n93/6n93_ligand.sdf,
|
||||
330,data/PDBBind_processed/6jut/6jut_protein_processed.pdb,data/PDBBind_processed/6jut/6jut_ligand.mol2,
|
||||
331,data/PDBBind_processed/6g2e/6g2e_protein_processed.pdb,data/PDBBind_processed/6g2e/6g2e_ligand.sdf,
|
||||
332,data/PDBBind_processed/6nd3/6nd3_protein_processed.pdb,data/PDBBind_processed/6nd3/6nd3_ligand.sdf,
|
||||
333,data/PDBBind_processed/6os6/6os6_protein_processed.pdb,data/PDBBind_processed/6os6/6os6_ligand.mol2,
|
||||
334,data/PDBBind_processed/6dql/6dql_protein_processed.pdb,data/PDBBind_processed/6dql/6dql_ligand.mol2,
|
||||
335,data/PDBBind_processed/6inz/6inz_protein_processed.pdb,data/PDBBind_processed/6inz/6inz_ligand.sdf,
|
||||
336,data/PDBBind_processed/6i67/6i67_protein_processed.pdb,data/PDBBind_processed/6i67/6i67_ligand.sdf,
|
||||
337,data/PDBBind_processed/6quw/6quw_protein_processed.pdb,data/PDBBind_processed/6quw/6quw_ligand.sdf,
|
||||
338,data/PDBBind_processed/6qwi/6qwi_protein_processed.pdb,data/PDBBind_processed/6qwi/6qwi_ligand.sdf,
|
||||
339,data/PDBBind_processed/6npm/6npm_protein_processed.pdb,data/PDBBind_processed/6npm/6npm_ligand.sdf,
|
||||
340,data/PDBBind_processed/6i64/6i64_protein_processed.pdb,data/PDBBind_processed/6i64/6i64_ligand.sdf,
|
||||
341,data/PDBBind_processed/6e3n/6e3n_protein_processed.pdb,data/PDBBind_processed/6e3n/6e3n_ligand.sdf,
|
||||
342,data/PDBBind_processed/6qrg/6qrg_protein_processed.pdb,data/PDBBind_processed/6qrg/6qrg_ligand.mol2,
|
||||
343,data/PDBBind_processed/6nxz/6nxz_protein_processed.pdb,data/PDBBind_processed/6nxz/6nxz_ligand.sdf,
|
||||
344,data/PDBBind_processed/6iby/6iby_protein_processed.pdb,data/PDBBind_processed/6iby/6iby_ligand.sdf,
|
||||
345,data/PDBBind_processed/6gj7/6gj7_protein_processed.pdb,data/PDBBind_processed/6gj7/6gj7_ligand.mol2,
|
||||
346,data/PDBBind_processed/6qr3/6qr3_protein_processed.pdb,data/PDBBind_processed/6qr3/6qr3_ligand.mol2,
|
||||
347,data/PDBBind_processed/6qr1/6qr1_protein_processed.pdb,data/PDBBind_processed/6qr1/6qr1_ligand.mol2,
|
||||
348,data/PDBBind_processed/6s9x/6s9x_protein_processed.pdb,data/PDBBind_processed/6s9x/6s9x_ligand.sdf,
|
||||
349,data/PDBBind_processed/6q4q/6q4q_protein_processed.pdb,data/PDBBind_processed/6q4q/6q4q_ligand.mol2,
|
||||
350,data/PDBBind_processed/6hbn/6hbn_protein_processed.pdb,data/PDBBind_processed/6hbn/6hbn_ligand.sdf,
|
||||
351,data/PDBBind_processed/6nw3/6nw3_protein_processed.pdb,data/PDBBind_processed/6nw3/6nw3_ligand.sdf,
|
||||
352,data/PDBBind_processed/6tel/6tel_protein_processed.pdb,data/PDBBind_processed/6tel/6tel_ligand.sdf,
|
||||
353,data/PDBBind_processed/6p8y/6p8y_protein_processed.pdb,data/PDBBind_processed/6p8y/6p8y_ligand.sdf,
|
||||
354,data/PDBBind_processed/6d5w/6d5w_protein_processed.pdb,data/PDBBind_processed/6d5w/6d5w_ligand.sdf,
|
||||
355,data/PDBBind_processed/6t6a/6t6a_protein_processed.pdb,data/PDBBind_processed/6t6a/6t6a_ligand.mol2,
|
||||
356,data/PDBBind_processed/6o5g/6o5g_protein_processed.pdb,data/PDBBind_processed/6o5g/6o5g_ligand.mol2,
|
||||
357,data/PDBBind_processed/6r7d/6r7d_protein_processed.pdb,data/PDBBind_processed/6r7d/6r7d_ligand.sdf,
|
||||
358,data/PDBBind_processed/6pya/6pya_protein_processed.pdb,data/PDBBind_processed/6pya/6pya_ligand.mol2,
|
||||
359,data/PDBBind_processed/6ffe/6ffe_protein_processed.pdb,data/PDBBind_processed/6ffe/6ffe_ligand.sdf,
|
||||
360,data/PDBBind_processed/6d3x/6d3x_protein_processed.pdb,data/PDBBind_processed/6d3x/6d3x_ligand.sdf,
|
||||
361,data/PDBBind_processed/6gj8/6gj8_protein_processed.pdb,data/PDBBind_processed/6gj8/6gj8_ligand.mol2,
|
||||
362,data/PDBBind_processed/6mo2/6mo2_protein_processed.pdb,data/PDBBind_processed/6mo2/6mo2_ligand.mol2,
|
||||
,protein_path,ligand
|
||||
0,data/PDBBind_processed/6qqw/6qqw_protein_processed.pdb,data/PDBBind_processed/6qqw/6qqw_ligand.mol2
|
||||
1,data/PDBBind_processed/6d08/6d08_protein_processed.pdb,data/PDBBind_processed/6d08/6d08_ligand.sdf
|
||||
2,data/PDBBind_processed/6jap/6jap_protein_processed.pdb,data/PDBBind_processed/6jap/6jap_ligand.sdf
|
||||
3,data/PDBBind_processed/6np2/6np2_protein_processed.pdb,data/PDBBind_processed/6np2/6np2_ligand.sdf
|
||||
4,data/PDBBind_processed/6uvp/6uvp_protein_processed.pdb,data/PDBBind_processed/6uvp/6uvp_ligand.sdf
|
||||
5,data/PDBBind_processed/6oxq/6oxq_protein_processed.pdb,data/PDBBind_processed/6oxq/6oxq_ligand.sdf
|
||||
6,data/PDBBind_processed/6jsn/6jsn_protein_processed.pdb,data/PDBBind_processed/6jsn/6jsn_ligand.sdf
|
||||
7,data/PDBBind_processed/6hzb/6hzb_protein_processed.pdb,data/PDBBind_processed/6hzb/6hzb_ligand.sdf
|
||||
8,data/PDBBind_processed/6qrc/6qrc_protein_processed.pdb,data/PDBBind_processed/6qrc/6qrc_ligand.mol2
|
||||
9,data/PDBBind_processed/6oio/6oio_protein_processed.pdb,data/PDBBind_processed/6oio/6oio_ligand.sdf
|
||||
10,data/PDBBind_processed/6jag/6jag_protein_processed.pdb,data/PDBBind_processed/6jag/6jag_ligand.sdf
|
||||
11,data/PDBBind_processed/6moa/6moa_protein_processed.pdb,data/PDBBind_processed/6moa/6moa_ligand.mol2
|
||||
12,data/PDBBind_processed/6hld/6hld_protein_processed.pdb,data/PDBBind_processed/6hld/6hld_ligand.sdf
|
||||
13,data/PDBBind_processed/6i9a/6i9a_protein_processed.pdb,data/PDBBind_processed/6i9a/6i9a_ligand.sdf
|
||||
14,data/PDBBind_processed/6e4c/6e4c_protein_processed.pdb,data/PDBBind_processed/6e4c/6e4c_ligand.sdf
|
||||
15,data/PDBBind_processed/6g24/6g24_protein_processed.pdb,data/PDBBind_processed/6g24/6g24_ligand.sdf
|
||||
16,data/PDBBind_processed/6jb4/6jb4_protein_processed.pdb,data/PDBBind_processed/6jb4/6jb4_ligand.sdf
|
||||
17,data/PDBBind_processed/6s55/6s55_protein_processed.pdb,data/PDBBind_processed/6s55/6s55_ligand.sdf
|
||||
18,data/PDBBind_processed/6seo/6seo_protein_processed.pdb,data/PDBBind_processed/6seo/6seo_ligand.sdf
|
||||
19,data/PDBBind_processed/6dyz/6dyz_protein_processed.pdb,data/PDBBind_processed/6dyz/6dyz_ligand.mol2
|
||||
20,data/PDBBind_processed/5zk5/5zk5_protein_processed.pdb,data/PDBBind_processed/5zk5/5zk5_ligand.sdf
|
||||
21,data/PDBBind_processed/6jid/6jid_protein_processed.pdb,data/PDBBind_processed/6jid/6jid_ligand.sdf
|
||||
22,data/PDBBind_processed/5ze6/5ze6_protein_processed.pdb,data/PDBBind_processed/5ze6/5ze6_ligand.sdf
|
||||
23,data/PDBBind_processed/6qlu/6qlu_protein_processed.pdb,data/PDBBind_processed/6qlu/6qlu_ligand.sdf
|
||||
24,data/PDBBind_processed/6a6k/6a6k_protein_processed.pdb,data/PDBBind_processed/6a6k/6a6k_ligand.sdf
|
||||
25,data/PDBBind_processed/6qgf/6qgf_protein_processed.pdb,data/PDBBind_processed/6qgf/6qgf_ligand.sdf
|
||||
26,data/PDBBind_processed/6e3z/6e3z_protein_processed.pdb,data/PDBBind_processed/6e3z/6e3z_ligand.sdf
|
||||
27,data/PDBBind_processed/6te6/6te6_protein_processed.pdb,data/PDBBind_processed/6te6/6te6_ligand.sdf
|
||||
28,data/PDBBind_processed/6pka/6pka_protein_processed.pdb,data/PDBBind_processed/6pka/6pka_ligand.sdf
|
||||
29,data/PDBBind_processed/6g2o/6g2o_protein_processed.pdb,data/PDBBind_processed/6g2o/6g2o_ligand.sdf
|
||||
30,data/PDBBind_processed/6jsf/6jsf_protein_processed.pdb,data/PDBBind_processed/6jsf/6jsf_ligand.sdf
|
||||
31,data/PDBBind_processed/5zxk/5zxk_protein_processed.pdb,data/PDBBind_processed/5zxk/5zxk_ligand.sdf
|
||||
32,data/PDBBind_processed/6qxd/6qxd_protein_processed.pdb,data/PDBBind_processed/6qxd/6qxd_ligand.sdf
|
||||
33,data/PDBBind_processed/6n97/6n97_protein_processed.pdb,data/PDBBind_processed/6n97/6n97_ligand.sdf
|
||||
34,data/PDBBind_processed/6jt3/6jt3_protein_processed.pdb,data/PDBBind_processed/6jt3/6jt3_ligand.sdf
|
||||
35,data/PDBBind_processed/6qtr/6qtr_protein_processed.pdb,data/PDBBind_processed/6qtr/6qtr_ligand.sdf
|
||||
36,data/PDBBind_processed/6oy1/6oy1_protein_processed.pdb,data/PDBBind_processed/6oy1/6oy1_ligand.sdf
|
||||
37,data/PDBBind_processed/6n96/6n96_protein_processed.pdb,data/PDBBind_processed/6n96/6n96_ligand.sdf
|
||||
38,data/PDBBind_processed/6qzh/6qzh_protein_processed.pdb,data/PDBBind_processed/6qzh/6qzh_ligand.sdf
|
||||
39,data/PDBBind_processed/6qqz/6qqz_protein_processed.pdb,data/PDBBind_processed/6qqz/6qqz_ligand.mol2
|
||||
40,data/PDBBind_processed/6qmt/6qmt_protein_processed.pdb,data/PDBBind_processed/6qmt/6qmt_ligand.sdf
|
||||
41,data/PDBBind_processed/6ibx/6ibx_protein_processed.pdb,data/PDBBind_processed/6ibx/6ibx_ligand.sdf
|
||||
42,data/PDBBind_processed/6hmt/6hmt_protein_processed.pdb,data/PDBBind_processed/6hmt/6hmt_ligand.sdf
|
||||
43,data/PDBBind_processed/5zk7/5zk7_protein_processed.pdb,data/PDBBind_processed/5zk7/5zk7_ligand.sdf
|
||||
44,data/PDBBind_processed/6k3l/6k3l_protein_processed.pdb,data/PDBBind_processed/6k3l/6k3l_ligand.sdf
|
||||
45,data/PDBBind_processed/6cjs/6cjs_protein_processed.pdb,data/PDBBind_processed/6cjs/6cjs_ligand.sdf
|
||||
46,data/PDBBind_processed/6n9l/6n9l_protein_processed.pdb,data/PDBBind_processed/6n9l/6n9l_ligand.sdf
|
||||
47,data/PDBBind_processed/6ibz/6ibz_protein_processed.pdb,data/PDBBind_processed/6ibz/6ibz_ligand.sdf
|
||||
48,data/PDBBind_processed/6ott/6ott_protein_processed.pdb,data/PDBBind_processed/6ott/6ott_ligand.sdf
|
||||
49,data/PDBBind_processed/6gge/6gge_protein_processed.pdb,data/PDBBind_processed/6gge/6gge_ligand.sdf
|
||||
50,data/PDBBind_processed/6hot/6hot_protein_processed.pdb,data/PDBBind_processed/6hot/6hot_ligand.sdf
|
||||
51,data/PDBBind_processed/6e3p/6e3p_protein_processed.pdb,data/PDBBind_processed/6e3p/6e3p_ligand.mol2
|
||||
52,data/PDBBind_processed/6md6/6md6_protein_processed.pdb,data/PDBBind_processed/6md6/6md6_ligand.sdf
|
||||
53,data/PDBBind_processed/6hlb/6hlb_protein_processed.pdb,data/PDBBind_processed/6hlb/6hlb_ligand.sdf
|
||||
54,data/PDBBind_processed/6fe5/6fe5_protein_processed.pdb,data/PDBBind_processed/6fe5/6fe5_ligand.sdf
|
||||
55,data/PDBBind_processed/6uwp/6uwp_protein_processed.pdb,data/PDBBind_processed/6uwp/6uwp_ligand.sdf
|
||||
56,data/PDBBind_processed/6npp/6npp_protein_processed.pdb,data/PDBBind_processed/6npp/6npp_ligand.sdf
|
||||
57,data/PDBBind_processed/6g2f/6g2f_protein_processed.pdb,data/PDBBind_processed/6g2f/6g2f_ligand.sdf
|
||||
58,data/PDBBind_processed/6mo7/6mo7_protein_processed.pdb,data/PDBBind_processed/6mo7/6mo7_ligand.sdf
|
||||
59,data/PDBBind_processed/6bqd/6bqd_protein_processed.pdb,data/PDBBind_processed/6bqd/6bqd_ligand.mol2
|
||||
60,data/PDBBind_processed/6nsv/6nsv_protein_processed.pdb,data/PDBBind_processed/6nsv/6nsv_ligand.mol2
|
||||
61,data/PDBBind_processed/6i76/6i76_protein_processed.pdb,data/PDBBind_processed/6i76/6i76_ligand.sdf
|
||||
62,data/PDBBind_processed/6n53/6n53_protein_processed.pdb,data/PDBBind_processed/6n53/6n53_ligand.sdf
|
||||
63,data/PDBBind_processed/6g2c/6g2c_protein_processed.pdb,data/PDBBind_processed/6g2c/6g2c_ligand.sdf
|
||||
64,data/PDBBind_processed/6eeb/6eeb_protein_processed.pdb,data/PDBBind_processed/6eeb/6eeb_ligand.mol2
|
||||
65,data/PDBBind_processed/6n0m/6n0m_protein_processed.pdb,data/PDBBind_processed/6n0m/6n0m_ligand.sdf
|
||||
66,data/PDBBind_processed/6uvy/6uvy_protein_processed.pdb,data/PDBBind_processed/6uvy/6uvy_ligand.sdf
|
||||
67,data/PDBBind_processed/6ovz/6ovz_protein_processed.pdb,data/PDBBind_processed/6ovz/6ovz_ligand.sdf
|
||||
68,data/PDBBind_processed/6olx/6olx_protein_processed.pdb,data/PDBBind_processed/6olx/6olx_ligand.sdf
|
||||
69,data/PDBBind_processed/6v5l/6v5l_protein_processed.pdb,data/PDBBind_processed/6v5l/6v5l_ligand.mol2
|
||||
70,data/PDBBind_processed/6hhg/6hhg_protein_processed.pdb,data/PDBBind_processed/6hhg/6hhg_ligand.sdf
|
||||
71,data/PDBBind_processed/5zcu/5zcu_protein_processed.pdb,data/PDBBind_processed/5zcu/5zcu_ligand.sdf
|
||||
72,data/PDBBind_processed/6dz2/6dz2_protein_processed.pdb,data/PDBBind_processed/6dz2/6dz2_ligand.mol2
|
||||
73,data/PDBBind_processed/6mjq/6mjq_protein_processed.pdb,data/PDBBind_processed/6mjq/6mjq_ligand.sdf
|
||||
74,data/PDBBind_processed/6efk/6efk_protein_processed.pdb,data/PDBBind_processed/6efk/6efk_ligand.sdf
|
||||
75,data/PDBBind_processed/6s9w/6s9w_protein_processed.pdb,data/PDBBind_processed/6s9w/6s9w_ligand.sdf
|
||||
76,data/PDBBind_processed/6gdy/6gdy_protein_processed.pdb,data/PDBBind_processed/6gdy/6gdy_ligand.sdf
|
||||
77,data/PDBBind_processed/6kqi/6kqi_protein_processed.pdb,data/PDBBind_processed/6kqi/6kqi_ligand.sdf
|
||||
78,data/PDBBind_processed/6ueg/6ueg_protein_processed.pdb,data/PDBBind_processed/6ueg/6ueg_ligand.sdf
|
||||
79,data/PDBBind_processed/6oxt/6oxt_protein_processed.pdb,data/PDBBind_processed/6oxt/6oxt_ligand.sdf
|
||||
80,data/PDBBind_processed/6oy0/6oy0_protein_processed.pdb,data/PDBBind_processed/6oy0/6oy0_ligand.sdf
|
||||
81,data/PDBBind_processed/6qr7/6qr7_protein_processed.pdb,data/PDBBind_processed/6qr7/6qr7_ligand.mol2
|
||||
82,data/PDBBind_processed/6i41/6i41_protein_processed.pdb,data/PDBBind_processed/6i41/6i41_ligand.sdf
|
||||
83,data/PDBBind_processed/6cyg/6cyg_protein_processed.pdb,data/PDBBind_processed/6cyg/6cyg_ligand.sdf
|
||||
84,data/PDBBind_processed/6qmr/6qmr_protein_processed.pdb,data/PDBBind_processed/6qmr/6qmr_ligand.sdf
|
||||
85,data/PDBBind_processed/6g27/6g27_protein_processed.pdb,data/PDBBind_processed/6g27/6g27_ligand.sdf
|
||||
86,data/PDBBind_processed/6ggb/6ggb_protein_processed.pdb,data/PDBBind_processed/6ggb/6ggb_ligand.sdf
|
||||
87,data/PDBBind_processed/6g3c/6g3c_protein_processed.pdb,data/PDBBind_processed/6g3c/6g3c_ligand.sdf
|
||||
88,data/PDBBind_processed/6n4e/6n4e_protein_processed.pdb,data/PDBBind_processed/6n4e/6n4e_ligand.sdf
|
||||
89,data/PDBBind_processed/6fcj/6fcj_protein_processed.pdb,data/PDBBind_processed/6fcj/6fcj_ligand.sdf
|
||||
90,data/PDBBind_processed/6quv/6quv_protein_processed.pdb,data/PDBBind_processed/6quv/6quv_ligand.sdf
|
||||
91,data/PDBBind_processed/6iql/6iql_protein_processed.pdb,data/PDBBind_processed/6iql/6iql_ligand.mol2
|
||||
92,data/PDBBind_processed/6i74/6i74_protein_processed.pdb,data/PDBBind_processed/6i74/6i74_ligand.sdf
|
||||
93,data/PDBBind_processed/6qr4/6qr4_protein_processed.pdb,data/PDBBind_processed/6qr4/6qr4_ligand.mol2
|
||||
94,data/PDBBind_processed/6rnu/6rnu_protein_processed.pdb,data/PDBBind_processed/6rnu/6rnu_ligand.sdf
|
||||
95,data/PDBBind_processed/6jib/6jib_protein_processed.pdb,data/PDBBind_processed/6jib/6jib_ligand.sdf
|
||||
96,data/PDBBind_processed/6izq/6izq_protein_processed.pdb,data/PDBBind_processed/6izq/6izq_ligand.sdf
|
||||
97,data/PDBBind_processed/6qw8/6qw8_protein_processed.pdb,data/PDBBind_processed/6qw8/6qw8_ligand.sdf
|
||||
98,data/PDBBind_processed/6qto/6qto_protein_processed.pdb,data/PDBBind_processed/6qto/6qto_ligand.sdf
|
||||
99,data/PDBBind_processed/6qrd/6qrd_protein_processed.pdb,data/PDBBind_processed/6qrd/6qrd_ligand.mol2
|
||||
100,data/PDBBind_processed/6hza/6hza_protein_processed.pdb,data/PDBBind_processed/6hza/6hza_ligand.sdf
|
||||
101,data/PDBBind_processed/6e5s/6e5s_protein_processed.pdb,data/PDBBind_processed/6e5s/6e5s_ligand.sdf
|
||||
102,data/PDBBind_processed/6dz3/6dz3_protein_processed.pdb,data/PDBBind_processed/6dz3/6dz3_ligand.mol2
|
||||
103,data/PDBBind_processed/6e6w/6e6w_protein_processed.pdb,data/PDBBind_processed/6e6w/6e6w_ligand.mol2
|
||||
104,data/PDBBind_processed/6cyh/6cyh_protein_processed.pdb,data/PDBBind_processed/6cyh/6cyh_ligand.sdf
|
||||
105,data/PDBBind_processed/5zlf/5zlf_protein_processed.pdb,data/PDBBind_processed/5zlf/5zlf_ligand.sdf
|
||||
106,data/PDBBind_processed/6om4/6om4_protein_processed.pdb,data/PDBBind_processed/6om4/6om4_ligand.sdf
|
||||
107,data/PDBBind_processed/6gga/6gga_protein_processed.pdb,data/PDBBind_processed/6gga/6gga_ligand.sdf
|
||||
108,data/PDBBind_processed/6pgp/6pgp_protein_processed.pdb,data/PDBBind_processed/6pgp/6pgp_ligand.sdf
|
||||
109,data/PDBBind_processed/6qqv/6qqv_protein_processed.pdb,data/PDBBind_processed/6qqv/6qqv_ligand.mol2
|
||||
110,data/PDBBind_processed/6qtq/6qtq_protein_processed.pdb,data/PDBBind_processed/6qtq/6qtq_ligand.sdf
|
||||
111,data/PDBBind_processed/6gj6/6gj6_protein_processed.pdb,data/PDBBind_processed/6gj6/6gj6_ligand.mol2
|
||||
112,data/PDBBind_processed/6os5/6os5_protein_processed.pdb,data/PDBBind_processed/6os5/6os5_ligand.mol2
|
||||
113,data/PDBBind_processed/6s07/6s07_protein_processed.pdb,data/PDBBind_processed/6s07/6s07_ligand.sdf
|
||||
114,data/PDBBind_processed/6i77/6i77_protein_processed.pdb,data/PDBBind_processed/6i77/6i77_ligand.sdf
|
||||
115,data/PDBBind_processed/6hhj/6hhj_protein_processed.pdb,data/PDBBind_processed/6hhj/6hhj_ligand.sdf
|
||||
116,data/PDBBind_processed/6ahs/6ahs_protein_processed.pdb,data/PDBBind_processed/6ahs/6ahs_ligand.sdf
|
||||
117,data/PDBBind_processed/6oxx/6oxx_protein_processed.pdb,data/PDBBind_processed/6oxx/6oxx_ligand.sdf
|
||||
118,data/PDBBind_processed/6mjj/6mjj_protein_processed.pdb,data/PDBBind_processed/6mjj/6mjj_ligand.sdf
|
||||
119,data/PDBBind_processed/6hor/6hor_protein_processed.pdb,data/PDBBind_processed/6hor/6hor_ligand.sdf
|
||||
120,data/PDBBind_processed/6jb0/6jb0_protein_processed.pdb,data/PDBBind_processed/6jb0/6jb0_ligand.sdf
|
||||
121,data/PDBBind_processed/6i68/6i68_protein_processed.pdb,data/PDBBind_processed/6i68/6i68_ligand.sdf
|
||||
122,data/PDBBind_processed/6pz4/6pz4_protein_processed.pdb,data/PDBBind_processed/6pz4/6pz4_ligand.sdf
|
||||
123,data/PDBBind_processed/6mhb/6mhb_protein_processed.pdb,data/PDBBind_processed/6mhb/6mhb_ligand.sdf
|
||||
124,data/PDBBind_processed/6uim/6uim_protein_processed.pdb,data/PDBBind_processed/6uim/6uim_ligand.sdf
|
||||
125,data/PDBBind_processed/6jsg/6jsg_protein_processed.pdb,data/PDBBind_processed/6jsg/6jsg_ligand.sdf
|
||||
126,data/PDBBind_processed/6i78/6i78_protein_processed.pdb,data/PDBBind_processed/6i78/6i78_ligand.sdf
|
||||
127,data/PDBBind_processed/6oxy/6oxy_protein_processed.pdb,data/PDBBind_processed/6oxy/6oxy_ligand.sdf
|
||||
128,data/PDBBind_processed/6gbw/6gbw_protein_processed.pdb,data/PDBBind_processed/6gbw/6gbw_ligand.sdf
|
||||
129,data/PDBBind_processed/6mo0/6mo0_protein_processed.pdb,data/PDBBind_processed/6mo0/6mo0_ligand.sdf
|
||||
130,data/PDBBind_processed/6ggf/6ggf_protein_processed.pdb,data/PDBBind_processed/6ggf/6ggf_ligand.sdf
|
||||
131,data/PDBBind_processed/6qge/6qge_protein_processed.pdb,data/PDBBind_processed/6qge/6qge_ligand.sdf
|
||||
132,data/PDBBind_processed/6cjr/6cjr_protein_processed.pdb,data/PDBBind_processed/6cjr/6cjr_ligand.sdf
|
||||
133,data/PDBBind_processed/6oxp/6oxp_protein_processed.pdb,data/PDBBind_processed/6oxp/6oxp_ligand.sdf
|
||||
134,data/PDBBind_processed/6d07/6d07_protein_processed.pdb,data/PDBBind_processed/6d07/6d07_ligand.sdf
|
||||
135,data/PDBBind_processed/6i63/6i63_protein_processed.pdb,data/PDBBind_processed/6i63/6i63_ligand.sdf
|
||||
136,data/PDBBind_processed/6ten/6ten_protein_processed.pdb,data/PDBBind_processed/6ten/6ten_ligand.sdf
|
||||
137,data/PDBBind_processed/6uii/6uii_protein_processed.pdb,data/PDBBind_processed/6uii/6uii_ligand.sdf
|
||||
138,data/PDBBind_processed/6qlr/6qlr_protein_processed.pdb,data/PDBBind_processed/6qlr/6qlr_ligand.sdf
|
||||
139,data/PDBBind_processed/6sen/6sen_protein_processed.pdb,data/PDBBind_processed/6sen/6sen_ligand.mol2
|
||||
140,data/PDBBind_processed/6oxv/6oxv_protein_processed.pdb,data/PDBBind_processed/6oxv/6oxv_ligand.sdf
|
||||
141,data/PDBBind_processed/6g2b/6g2b_protein_processed.pdb,data/PDBBind_processed/6g2b/6g2b_ligand.sdf
|
||||
142,data/PDBBind_processed/5zr3/5zr3_protein_processed.pdb,data/PDBBind_processed/5zr3/5zr3_ligand.sdf
|
||||
143,data/PDBBind_processed/6kjf/6kjf_protein_processed.pdb,data/PDBBind_processed/6kjf/6kjf_ligand.sdf
|
||||
144,data/PDBBind_processed/6qr9/6qr9_protein_processed.pdb,data/PDBBind_processed/6qr9/6qr9_ligand.mol2
|
||||
145,data/PDBBind_processed/6g9f/6g9f_protein_processed.pdb,data/PDBBind_processed/6g9f/6g9f_ligand.sdf
|
||||
146,data/PDBBind_processed/6e6v/6e6v_protein_processed.pdb,data/PDBBind_processed/6e6v/6e6v_ligand.sdf
|
||||
147,data/PDBBind_processed/5zk9/5zk9_protein_processed.pdb,data/PDBBind_processed/5zk9/5zk9_ligand.sdf
|
||||
148,data/PDBBind_processed/6pnn/6pnn_protein_processed.pdb,data/PDBBind_processed/6pnn/6pnn_ligand.sdf
|
||||
149,data/PDBBind_processed/6nri/6nri_protein_processed.pdb,data/PDBBind_processed/6nri/6nri_ligand.sdf
|
||||
150,data/PDBBind_processed/6uwv/6uwv_protein_processed.pdb,data/PDBBind_processed/6uwv/6uwv_ligand.sdf
|
||||
151,data/PDBBind_processed/6ooz/6ooz_protein_processed.pdb,data/PDBBind_processed/6ooz/6ooz_ligand.sdf
|
||||
152,data/PDBBind_processed/6npi/6npi_protein_processed.pdb,data/PDBBind_processed/6npi/6npi_ligand.sdf
|
||||
153,data/PDBBind_processed/6oip/6oip_protein_processed.pdb,data/PDBBind_processed/6oip/6oip_ligand.sdf
|
||||
154,data/PDBBind_processed/6miv/6miv_protein_processed.pdb,data/PDBBind_processed/6miv/6miv_ligand.sdf
|
||||
155,data/PDBBind_processed/6s57/6s57_protein_processed.pdb,data/PDBBind_processed/6s57/6s57_ligand.sdf
|
||||
156,data/PDBBind_processed/6p8x/6p8x_protein_processed.pdb,data/PDBBind_processed/6p8x/6p8x_ligand.sdf
|
||||
157,data/PDBBind_processed/6hoq/6hoq_protein_processed.pdb,data/PDBBind_processed/6hoq/6hoq_ligand.sdf
|
||||
158,data/PDBBind_processed/6qts/6qts_protein_processed.pdb,data/PDBBind_processed/6qts/6qts_ligand.sdf
|
||||
159,data/PDBBind_processed/6ggd/6ggd_protein_processed.pdb,data/PDBBind_processed/6ggd/6ggd_ligand.sdf
|
||||
160,data/PDBBind_processed/6pnm/6pnm_protein_processed.pdb,data/PDBBind_processed/6pnm/6pnm_ligand.sdf
|
||||
161,data/PDBBind_processed/6oy2/6oy2_protein_processed.pdb,data/PDBBind_processed/6oy2/6oy2_ligand.sdf
|
||||
162,data/PDBBind_processed/6oi8/6oi8_protein_processed.pdb,data/PDBBind_processed/6oi8/6oi8_ligand.sdf
|
||||
163,data/PDBBind_processed/6mhd/6mhd_protein_processed.pdb,data/PDBBind_processed/6mhd/6mhd_ligand.sdf
|
||||
164,data/PDBBind_processed/6agt/6agt_protein_processed.pdb,data/PDBBind_processed/6agt/6agt_ligand.sdf
|
||||
165,data/PDBBind_processed/6i5p/6i5p_protein_processed.pdb,data/PDBBind_processed/6i5p/6i5p_ligand.sdf
|
||||
166,data/PDBBind_processed/6hhr/6hhr_protein_processed.pdb,data/PDBBind_processed/6hhr/6hhr_ligand.sdf
|
||||
167,data/PDBBind_processed/6p8z/6p8z_protein_processed.pdb,data/PDBBind_processed/6p8z/6p8z_ligand.sdf
|
||||
168,data/PDBBind_processed/6c85/6c85_protein_processed.pdb,data/PDBBind_processed/6c85/6c85_ligand.sdf
|
||||
169,data/PDBBind_processed/6g5u/6g5u_protein_processed.pdb,data/PDBBind_processed/6g5u/6g5u_ligand.sdf
|
||||
170,data/PDBBind_processed/6j06/6j06_protein_processed.pdb,data/PDBBind_processed/6j06/6j06_ligand.sdf
|
||||
171,data/PDBBind_processed/6qsz/6qsz_protein_processed.pdb,data/PDBBind_processed/6qsz/6qsz_ligand.sdf
|
||||
172,data/PDBBind_processed/6jbb/6jbb_protein_processed.pdb,data/PDBBind_processed/6jbb/6jbb_ligand.sdf
|
||||
173,data/PDBBind_processed/6hhp/6hhp_protein_processed.pdb,data/PDBBind_processed/6hhp/6hhp_ligand.sdf
|
||||
174,data/PDBBind_processed/6np5/6np5_protein_processed.pdb,data/PDBBind_processed/6np5/6np5_ligand.sdf
|
||||
175,data/PDBBind_processed/6nlj/6nlj_protein_processed.pdb,data/PDBBind_processed/6nlj/6nlj_ligand.sdf
|
||||
176,data/PDBBind_processed/6qlp/6qlp_protein_processed.pdb,data/PDBBind_processed/6qlp/6qlp_ligand.sdf
|
||||
177,data/PDBBind_processed/6n94/6n94_protein_processed.pdb,data/PDBBind_processed/6n94/6n94_ligand.sdf
|
||||
178,data/PDBBind_processed/6e13/6e13_protein_processed.pdb,data/PDBBind_processed/6e13/6e13_ligand.sdf
|
||||
179,data/PDBBind_processed/6qls/6qls_protein_processed.pdb,data/PDBBind_processed/6qls/6qls_ligand.sdf
|
||||
180,data/PDBBind_processed/6uil/6uil_protein_processed.pdb,data/PDBBind_processed/6uil/6uil_ligand.sdf
|
||||
181,data/PDBBind_processed/6st3/6st3_protein_processed.pdb,data/PDBBind_processed/6st3/6st3_ligand.sdf
|
||||
182,data/PDBBind_processed/6n92/6n92_protein_processed.pdb,data/PDBBind_processed/6n92/6n92_ligand.sdf
|
||||
183,data/PDBBind_processed/6s56/6s56_protein_processed.pdb,data/PDBBind_processed/6s56/6s56_ligand.sdf
|
||||
184,data/PDBBind_processed/6hzd/6hzd_protein_processed.pdb,data/PDBBind_processed/6hzd/6hzd_ligand.sdf
|
||||
185,data/PDBBind_processed/6uhv/6uhv_protein_processed.pdb,data/PDBBind_processed/6uhv/6uhv_ligand.sdf
|
||||
186,data/PDBBind_processed/6k05/6k05_protein_processed.pdb,data/PDBBind_processed/6k05/6k05_ligand.sdf
|
||||
187,data/PDBBind_processed/6q36/6q36_protein_processed.pdb,data/PDBBind_processed/6q36/6q36_ligand.mol2
|
||||
188,data/PDBBind_processed/6ic0/6ic0_protein_processed.pdb,data/PDBBind_processed/6ic0/6ic0_ligand.sdf
|
||||
189,data/PDBBind_processed/6hhi/6hhi_protein_processed.pdb,data/PDBBind_processed/6hhi/6hhi_ligand.sdf
|
||||
190,data/PDBBind_processed/6e3m/6e3m_protein_processed.pdb,data/PDBBind_processed/6e3m/6e3m_ligand.sdf
|
||||
191,data/PDBBind_processed/6qtx/6qtx_protein_processed.pdb,data/PDBBind_processed/6qtx/6qtx_ligand.sdf
|
||||
192,data/PDBBind_processed/6jse/6jse_protein_processed.pdb,data/PDBBind_processed/6jse/6jse_ligand.sdf
|
||||
193,data/PDBBind_processed/5zjy/5zjy_protein_processed.pdb,data/PDBBind_processed/5zjy/5zjy_ligand.sdf
|
||||
194,data/PDBBind_processed/6o3y/6o3y_protein_processed.pdb,data/PDBBind_processed/6o3y/6o3y_ligand.sdf
|
||||
195,data/PDBBind_processed/6rpg/6rpg_protein_processed.pdb,data/PDBBind_processed/6rpg/6rpg_ligand.sdf
|
||||
196,data/PDBBind_processed/6rr0/6rr0_protein_processed.pdb,data/PDBBind_processed/6rr0/6rr0_ligand.sdf
|
||||
197,data/PDBBind_processed/6gzy/6gzy_protein_processed.pdb,data/PDBBind_processed/6gzy/6gzy_ligand.sdf
|
||||
198,data/PDBBind_processed/6qlt/6qlt_protein_processed.pdb,data/PDBBind_processed/6qlt/6qlt_ligand.sdf
|
||||
199,data/PDBBind_processed/6ufo/6ufo_protein_processed.pdb,data/PDBBind_processed/6ufo/6ufo_ligand.sdf
|
||||
200,data/PDBBind_processed/6o0h/6o0h_protein_processed.pdb,data/PDBBind_processed/6o0h/6o0h_ligand.sdf
|
||||
201,data/PDBBind_processed/6o3x/6o3x_protein_processed.pdb,data/PDBBind_processed/6o3x/6o3x_ligand.sdf
|
||||
202,data/PDBBind_processed/5zjz/5zjz_protein_processed.pdb,data/PDBBind_processed/5zjz/5zjz_ligand.mol2
|
||||
203,data/PDBBind_processed/6i8t/6i8t_protein_processed.pdb,data/PDBBind_processed/6i8t/6i8t_ligand.sdf
|
||||
204,data/PDBBind_processed/6ooy/6ooy_protein_processed.pdb,data/PDBBind_processed/6ooy/6ooy_ligand.sdf
|
||||
205,data/PDBBind_processed/6oiq/6oiq_protein_processed.pdb,data/PDBBind_processed/6oiq/6oiq_ligand.sdf
|
||||
206,data/PDBBind_processed/6od6/6od6_protein_processed.pdb,data/PDBBind_processed/6od6/6od6_ligand.sdf
|
||||
207,data/PDBBind_processed/6nrh/6nrh_protein_processed.pdb,data/PDBBind_processed/6nrh/6nrh_ligand.sdf
|
||||
208,data/PDBBind_processed/6qra/6qra_protein_processed.pdb,data/PDBBind_processed/6qra/6qra_ligand.mol2
|
||||
209,data/PDBBind_processed/6hhh/6hhh_protein_processed.pdb,data/PDBBind_processed/6hhh/6hhh_ligand.sdf
|
||||
210,data/PDBBind_processed/6m7h/6m7h_protein_processed.pdb,data/PDBBind_processed/6m7h/6m7h_ligand.sdf
|
||||
211,data/PDBBind_processed/6ufn/6ufn_protein_processed.pdb,data/PDBBind_processed/6ufn/6ufn_ligand.sdf
|
||||
212,data/PDBBind_processed/6qr0/6qr0_protein_processed.pdb,data/PDBBind_processed/6qr0/6qr0_ligand.mol2
|
||||
213,data/PDBBind_processed/6o5u/6o5u_protein_processed.pdb,data/PDBBind_processed/6o5u/6o5u_ligand.sdf
|
||||
214,data/PDBBind_processed/6h14/6h14_protein_processed.pdb,data/PDBBind_processed/6h14/6h14_ligand.sdf
|
||||
215,data/PDBBind_processed/6jwa/6jwa_protein_processed.pdb,data/PDBBind_processed/6jwa/6jwa_ligand.sdf
|
||||
216,data/PDBBind_processed/6ny0/6ny0_protein_processed.pdb,data/PDBBind_processed/6ny0/6ny0_ligand.sdf
|
||||
217,data/PDBBind_processed/6jan/6jan_protein_processed.pdb,data/PDBBind_processed/6jan/6jan_ligand.sdf
|
||||
218,data/PDBBind_processed/6ftf/6ftf_protein_processed.pdb,data/PDBBind_processed/6ftf/6ftf_ligand.sdf
|
||||
219,data/PDBBind_processed/6oxw/6oxw_protein_processed.pdb,data/PDBBind_processed/6oxw/6oxw_ligand.sdf
|
||||
220,data/PDBBind_processed/6jon/6jon_protein_processed.pdb,data/PDBBind_processed/6jon/6jon_ligand.sdf
|
||||
221,data/PDBBind_processed/6cf7/6cf7_protein_processed.pdb,data/PDBBind_processed/6cf7/6cf7_ligand.sdf
|
||||
222,data/PDBBind_processed/6rtn/6rtn_protein_processed.pdb,data/PDBBind_processed/6rtn/6rtn_ligand.mol2
|
||||
223,data/PDBBind_processed/6jsz/6jsz_protein_processed.pdb,data/PDBBind_processed/6jsz/6jsz_ligand.sdf
|
||||
224,data/PDBBind_processed/6o9c/6o9c_protein_processed.pdb,data/PDBBind_processed/6o9c/6o9c_ligand.sdf
|
||||
225,data/PDBBind_processed/6mo8/6mo8_protein_processed.pdb,data/PDBBind_processed/6mo8/6mo8_ligand.sdf
|
||||
226,data/PDBBind_processed/6qln/6qln_protein_processed.pdb,data/PDBBind_processed/6qln/6qln_ligand.sdf
|
||||
227,data/PDBBind_processed/6qqu/6qqu_protein_processed.pdb,data/PDBBind_processed/6qqu/6qqu_ligand.mol2
|
||||
228,data/PDBBind_processed/6i66/6i66_protein_processed.pdb,data/PDBBind_processed/6i66/6i66_ligand.sdf
|
||||
229,data/PDBBind_processed/6mja/6mja_protein_processed.pdb,data/PDBBind_processed/6mja/6mja_ligand.sdf
|
||||
230,data/PDBBind_processed/6gwe/6gwe_protein_processed.pdb,data/PDBBind_processed/6gwe/6gwe_ligand.mol2
|
||||
231,data/PDBBind_processed/6d3z/6d3z_protein_processed.pdb,data/PDBBind_processed/6d3z/6d3z_ligand.sdf
|
||||
232,data/PDBBind_processed/6oxr/6oxr_protein_processed.pdb,data/PDBBind_processed/6oxr/6oxr_ligand.sdf
|
||||
233,data/PDBBind_processed/6r4k/6r4k_protein_processed.pdb,data/PDBBind_processed/6r4k/6r4k_ligand.sdf
|
||||
234,data/PDBBind_processed/6hle/6hle_protein_processed.pdb,data/PDBBind_processed/6hle/6hle_ligand.sdf
|
||||
235,data/PDBBind_processed/6h9v/6h9v_protein_processed.pdb,data/PDBBind_processed/6h9v/6h9v_ligand.sdf
|
||||
236,data/PDBBind_processed/6hou/6hou_protein_processed.pdb,data/PDBBind_processed/6hou/6hou_ligand.sdf
|
||||
237,data/PDBBind_processed/6nv9/6nv9_protein_processed.pdb,data/PDBBind_processed/6nv9/6nv9_ligand.sdf
|
||||
238,data/PDBBind_processed/6py0/6py0_protein_processed.pdb,data/PDBBind_processed/6py0/6py0_ligand.sdf
|
||||
239,data/PDBBind_processed/6qlq/6qlq_protein_processed.pdb,data/PDBBind_processed/6qlq/6qlq_ligand.sdf
|
||||
240,data/PDBBind_processed/6nv7/6nv7_protein_processed.pdb,data/PDBBind_processed/6nv7/6nv7_ligand.sdf
|
||||
241,data/PDBBind_processed/6n4b/6n4b_protein_processed.pdb,data/PDBBind_processed/6n4b/6n4b_ligand.sdf
|
||||
242,data/PDBBind_processed/6jaq/6jaq_protein_processed.pdb,data/PDBBind_processed/6jaq/6jaq_ligand.sdf
|
||||
243,data/PDBBind_processed/6i8m/6i8m_protein_processed.pdb,data/PDBBind_processed/6i8m/6i8m_ligand.sdf
|
||||
244,data/PDBBind_processed/6dz0/6dz0_protein_processed.pdb,data/PDBBind_processed/6dz0/6dz0_ligand.mol2
|
||||
245,data/PDBBind_processed/6oxs/6oxs_protein_processed.pdb,data/PDBBind_processed/6oxs/6oxs_ligand.sdf
|
||||
246,data/PDBBind_processed/6k2n/6k2n_protein_processed.pdb,data/PDBBind_processed/6k2n/6k2n_ligand.sdf
|
||||
247,data/PDBBind_processed/6cjj/6cjj_protein_processed.pdb,data/PDBBind_processed/6cjj/6cjj_ligand.sdf
|
||||
248,data/PDBBind_processed/6ffg/6ffg_protein_processed.pdb,data/PDBBind_processed/6ffg/6ffg_ligand.sdf
|
||||
249,data/PDBBind_processed/6a73/6a73_protein_processed.pdb,data/PDBBind_processed/6a73/6a73_ligand.sdf
|
||||
250,data/PDBBind_processed/6qqt/6qqt_protein_processed.pdb,data/PDBBind_processed/6qqt/6qqt_ligand.mol2
|
||||
251,data/PDBBind_processed/6a1c/6a1c_protein_processed.pdb,data/PDBBind_processed/6a1c/6a1c_ligand.sdf
|
||||
252,data/PDBBind_processed/6oxu/6oxu_protein_processed.pdb,data/PDBBind_processed/6oxu/6oxu_ligand.sdf
|
||||
253,data/PDBBind_processed/6qre/6qre_protein_processed.pdb,data/PDBBind_processed/6qre/6qre_ligand.mol2
|
||||
254,data/PDBBind_processed/6qtw/6qtw_protein_processed.pdb,data/PDBBind_processed/6qtw/6qtw_ligand.sdf
|
||||
255,data/PDBBind_processed/6np4/6np4_protein_processed.pdb,data/PDBBind_processed/6np4/6np4_ligand.sdf
|
||||
256,data/PDBBind_processed/6hv2/6hv2_protein_processed.pdb,data/PDBBind_processed/6hv2/6hv2_ligand.sdf
|
||||
257,data/PDBBind_processed/6n55/6n55_protein_processed.pdb,data/PDBBind_processed/6n55/6n55_ligand.sdf
|
||||
258,data/PDBBind_processed/6e3o/6e3o_protein_processed.pdb,data/PDBBind_processed/6e3o/6e3o_ligand.sdf
|
||||
259,data/PDBBind_processed/6kjd/6kjd_protein_processed.pdb,data/PDBBind_processed/6kjd/6kjd_ligand.sdf
|
||||
260,data/PDBBind_processed/6sfc/6sfc_protein_processed.pdb,data/PDBBind_processed/6sfc/6sfc_ligand.sdf
|
||||
261,data/PDBBind_processed/6qi7/6qi7_protein_processed.pdb,data/PDBBind_processed/6qi7/6qi7_ligand.sdf
|
||||
262,data/PDBBind_processed/6hzc/6hzc_protein_processed.pdb,data/PDBBind_processed/6hzc/6hzc_ligand.sdf
|
||||
263,data/PDBBind_processed/6k04/6k04_protein_processed.pdb,data/PDBBind_processed/6k04/6k04_ligand.sdf
|
||||
264,data/PDBBind_processed/6op0/6op0_protein_processed.pdb,data/PDBBind_processed/6op0/6op0_ligand.sdf
|
||||
265,data/PDBBind_processed/6q38/6q38_protein_processed.pdb,data/PDBBind_processed/6q38/6q38_ligand.mol2
|
||||
266,data/PDBBind_processed/6n8x/6n8x_protein_processed.pdb,data/PDBBind_processed/6n8x/6n8x_ligand.sdf
|
||||
267,data/PDBBind_processed/6np3/6np3_protein_processed.pdb,data/PDBBind_processed/6np3/6np3_ligand.sdf
|
||||
268,data/PDBBind_processed/6uvv/6uvv_protein_processed.pdb,data/PDBBind_processed/6uvv/6uvv_ligand.sdf
|
||||
269,data/PDBBind_processed/6pgo/6pgo_protein_processed.pdb,data/PDBBind_processed/6pgo/6pgo_ligand.sdf
|
||||
270,data/PDBBind_processed/6jbe/6jbe_protein_processed.pdb,data/PDBBind_processed/6jbe/6jbe_ligand.sdf
|
||||
271,data/PDBBind_processed/6i75/6i75_protein_processed.pdb,data/PDBBind_processed/6i75/6i75_ligand.sdf
|
||||
272,data/PDBBind_processed/6qqq/6qqq_protein_processed.pdb,data/PDBBind_processed/6qqq/6qqq_ligand.mol2
|
||||
273,data/PDBBind_processed/6i62/6i62_protein_processed.pdb,data/PDBBind_processed/6i62/6i62_ligand.sdf
|
||||
274,data/PDBBind_processed/6j9y/6j9y_protein_processed.pdb,data/PDBBind_processed/6j9y/6j9y_ligand.sdf
|
||||
275,data/PDBBind_processed/6g29/6g29_protein_processed.pdb,data/PDBBind_processed/6g29/6g29_ligand.sdf
|
||||
276,data/PDBBind_processed/6h7d/6h7d_protein_processed.pdb,data/PDBBind_processed/6h7d/6h7d_ligand.sdf
|
||||
277,data/PDBBind_processed/6mo9/6mo9_protein_processed.pdb,data/PDBBind_processed/6mo9/6mo9_ligand.sdf
|
||||
278,data/PDBBind_processed/6jao/6jao_protein_processed.pdb,data/PDBBind_processed/6jao/6jao_ligand.sdf
|
||||
279,data/PDBBind_processed/6jmf/6jmf_protein_processed.pdb,data/PDBBind_processed/6jmf/6jmf_ligand.sdf
|
||||
280,data/PDBBind_processed/6hmy/6hmy_protein_processed.pdb,data/PDBBind_processed/6hmy/6hmy_ligand.sdf
|
||||
281,data/PDBBind_processed/6qfe/6qfe_protein_processed.pdb,data/PDBBind_processed/6qfe/6qfe_ligand.mol2
|
||||
282,data/PDBBind_processed/5zml/5zml_protein_processed.pdb,data/PDBBind_processed/5zml/5zml_ligand.sdf
|
||||
283,data/PDBBind_processed/6i65/6i65_protein_processed.pdb,data/PDBBind_processed/6i65/6i65_ligand.sdf
|
||||
284,data/PDBBind_processed/6e7m/6e7m_protein_processed.pdb,data/PDBBind_processed/6e7m/6e7m_ligand.sdf
|
||||
285,data/PDBBind_processed/6i61/6i61_protein_processed.pdb,data/PDBBind_processed/6i61/6i61_ligand.sdf
|
||||
286,data/PDBBind_processed/6rz6/6rz6_protein_processed.pdb,data/PDBBind_processed/6rz6/6rz6_ligand.sdf
|
||||
287,data/PDBBind_processed/6qtm/6qtm_protein_processed.pdb,data/PDBBind_processed/6qtm/6qtm_ligand.sdf
|
||||
288,data/PDBBind_processed/6qlo/6qlo_protein_processed.pdb,data/PDBBind_processed/6qlo/6qlo_ligand.sdf
|
||||
289,data/PDBBind_processed/6oie/6oie_protein_processed.pdb,data/PDBBind_processed/6oie/6oie_ligand.sdf
|
||||
290,data/PDBBind_processed/6miy/6miy_protein_processed.pdb,data/PDBBind_processed/6miy/6miy_ligand.sdf
|
||||
291,data/PDBBind_processed/6nrf/6nrf_protein_processed.pdb,data/PDBBind_processed/6nrf/6nrf_ligand.mol2
|
||||
292,data/PDBBind_processed/6gj5/6gj5_protein_processed.pdb,data/PDBBind_processed/6gj5/6gj5_ligand.mol2
|
||||
293,data/PDBBind_processed/6jad/6jad_protein_processed.pdb,data/PDBBind_processed/6jad/6jad_ligand.sdf
|
||||
294,data/PDBBind_processed/6mj4/6mj4_protein_processed.pdb,data/PDBBind_processed/6mj4/6mj4_ligand.sdf
|
||||
295,data/PDBBind_processed/6h12/6h12_protein_processed.pdb,data/PDBBind_processed/6h12/6h12_ligand.sdf
|
||||
296,data/PDBBind_processed/6d3y/6d3y_protein_processed.pdb,data/PDBBind_processed/6d3y/6d3y_ligand.sdf
|
||||
297,data/PDBBind_processed/6qr2/6qr2_protein_processed.pdb,data/PDBBind_processed/6qr2/6qr2_ligand.mol2
|
||||
298,data/PDBBind_processed/6qxa/6qxa_protein_processed.pdb,data/PDBBind_processed/6qxa/6qxa_ligand.mol2
|
||||
299,data/PDBBind_processed/6o9b/6o9b_protein_processed.pdb,data/PDBBind_processed/6o9b/6o9b_ligand.sdf
|
||||
300,data/PDBBind_processed/6ckl/6ckl_protein_processed.pdb,data/PDBBind_processed/6ckl/6ckl_ligand.sdf
|
||||
301,data/PDBBind_processed/6oir/6oir_protein_processed.pdb,data/PDBBind_processed/6oir/6oir_ligand.sdf
|
||||
302,data/PDBBind_processed/6d40/6d40_protein_processed.pdb,data/PDBBind_processed/6d40/6d40_ligand.sdf
|
||||
303,data/PDBBind_processed/6e6j/6e6j_protein_processed.pdb,data/PDBBind_processed/6e6j/6e6j_ligand.mol2
|
||||
304,data/PDBBind_processed/6i7a/6i7a_protein_processed.pdb,data/PDBBind_processed/6i7a/6i7a_ligand.sdf
|
||||
305,data/PDBBind_processed/6g25/6g25_protein_processed.pdb,data/PDBBind_processed/6g25/6g25_ligand.mol2
|
||||
306,data/PDBBind_processed/6oin/6oin_protein_processed.pdb,data/PDBBind_processed/6oin/6oin_ligand.sdf
|
||||
307,data/PDBBind_processed/6jam/6jam_protein_processed.pdb,data/PDBBind_processed/6jam/6jam_ligand.sdf
|
||||
308,data/PDBBind_processed/6oxz/6oxz_protein_processed.pdb,data/PDBBind_processed/6oxz/6oxz_ligand.sdf
|
||||
309,data/PDBBind_processed/6hop/6hop_protein_processed.pdb,data/PDBBind_processed/6hop/6hop_ligand.sdf
|
||||
310,data/PDBBind_processed/6rot/6rot_protein_processed.pdb,data/PDBBind_processed/6rot/6rot_ligand.sdf
|
||||
311,data/PDBBind_processed/6uhu/6uhu_protein_processed.pdb,data/PDBBind_processed/6uhu/6uhu_ligand.mol2
|
||||
312,data/PDBBind_processed/6mji/6mji_protein_processed.pdb,data/PDBBind_processed/6mji/6mji_ligand.sdf
|
||||
313,data/PDBBind_processed/6nrj/6nrj_protein_processed.pdb,data/PDBBind_processed/6nrj/6nrj_ligand.mol2
|
||||
314,data/PDBBind_processed/6nt2/6nt2_protein_processed.pdb,data/PDBBind_processed/6nt2/6nt2_ligand.mol2
|
||||
315,data/PDBBind_processed/6op9/6op9_protein_processed.pdb,data/PDBBind_processed/6op9/6op9_ligand.sdf
|
||||
316,data/PDBBind_processed/6pno/6pno_protein_processed.pdb,data/PDBBind_processed/6pno/6pno_ligand.sdf
|
||||
317,data/PDBBind_processed/6e4v/6e4v_protein_processed.pdb,data/PDBBind_processed/6e4v/6e4v_ligand.sdf
|
||||
318,data/PDBBind_processed/6k1s/6k1s_protein_processed.pdb,data/PDBBind_processed/6k1s/6k1s_ligand.sdf
|
||||
319,data/PDBBind_processed/6a87/6a87_protein_processed.pdb,data/PDBBind_processed/6a87/6a87_ligand.sdf
|
||||
320,data/PDBBind_processed/6oim/6oim_protein_processed.pdb,data/PDBBind_processed/6oim/6oim_ligand.sdf
|
||||
321,data/PDBBind_processed/6cjp/6cjp_protein_processed.pdb,data/PDBBind_processed/6cjp/6cjp_ligand.sdf
|
||||
322,data/PDBBind_processed/6pyb/6pyb_protein_processed.pdb,data/PDBBind_processed/6pyb/6pyb_ligand.sdf
|
||||
323,data/PDBBind_processed/6h13/6h13_protein_processed.pdb,data/PDBBind_processed/6h13/6h13_ligand.sdf
|
||||
324,data/PDBBind_processed/6qrf/6qrf_protein_processed.pdb,data/PDBBind_processed/6qrf/6qrf_ligand.mol2
|
||||
325,data/PDBBind_processed/6mhc/6mhc_protein_processed.pdb,data/PDBBind_processed/6mhc/6mhc_ligand.sdf
|
||||
326,data/PDBBind_processed/6j9w/6j9w_protein_processed.pdb,data/PDBBind_processed/6j9w/6j9w_ligand.sdf
|
||||
327,data/PDBBind_processed/6nrg/6nrg_protein_processed.pdb,data/PDBBind_processed/6nrg/6nrg_ligand.mol2
|
||||
328,data/PDBBind_processed/6fff/6fff_protein_processed.pdb,data/PDBBind_processed/6fff/6fff_ligand.sdf
|
||||
329,data/PDBBind_processed/6n93/6n93_protein_processed.pdb,data/PDBBind_processed/6n93/6n93_ligand.sdf
|
||||
330,data/PDBBind_processed/6jut/6jut_protein_processed.pdb,data/PDBBind_processed/6jut/6jut_ligand.mol2
|
||||
331,data/PDBBind_processed/6g2e/6g2e_protein_processed.pdb,data/PDBBind_processed/6g2e/6g2e_ligand.sdf
|
||||
332,data/PDBBind_processed/6nd3/6nd3_protein_processed.pdb,data/PDBBind_processed/6nd3/6nd3_ligand.sdf
|
||||
333,data/PDBBind_processed/6os6/6os6_protein_processed.pdb,data/PDBBind_processed/6os6/6os6_ligand.mol2
|
||||
334,data/PDBBind_processed/6dql/6dql_protein_processed.pdb,data/PDBBind_processed/6dql/6dql_ligand.mol2
|
||||
335,data/PDBBind_processed/6inz/6inz_protein_processed.pdb,data/PDBBind_processed/6inz/6inz_ligand.sdf
|
||||
336,data/PDBBind_processed/6i67/6i67_protein_processed.pdb,data/PDBBind_processed/6i67/6i67_ligand.sdf
|
||||
337,data/PDBBind_processed/6quw/6quw_protein_processed.pdb,data/PDBBind_processed/6quw/6quw_ligand.sdf
|
||||
338,data/PDBBind_processed/6qwi/6qwi_protein_processed.pdb,data/PDBBind_processed/6qwi/6qwi_ligand.sdf
|
||||
339,data/PDBBind_processed/6npm/6npm_protein_processed.pdb,data/PDBBind_processed/6npm/6npm_ligand.sdf
|
||||
340,data/PDBBind_processed/6i64/6i64_protein_processed.pdb,data/PDBBind_processed/6i64/6i64_ligand.sdf
|
||||
341,data/PDBBind_processed/6e3n/6e3n_protein_processed.pdb,data/PDBBind_processed/6e3n/6e3n_ligand.sdf
|
||||
342,data/PDBBind_processed/6qrg/6qrg_protein_processed.pdb,data/PDBBind_processed/6qrg/6qrg_ligand.mol2
|
||||
343,data/PDBBind_processed/6nxz/6nxz_protein_processed.pdb,data/PDBBind_processed/6nxz/6nxz_ligand.sdf
|
||||
344,data/PDBBind_processed/6iby/6iby_protein_processed.pdb,data/PDBBind_processed/6iby/6iby_ligand.sdf
|
||||
345,data/PDBBind_processed/6gj7/6gj7_protein_processed.pdb,data/PDBBind_processed/6gj7/6gj7_ligand.mol2
|
||||
346,data/PDBBind_processed/6qr3/6qr3_protein_processed.pdb,data/PDBBind_processed/6qr3/6qr3_ligand.mol2
|
||||
347,data/PDBBind_processed/6qr1/6qr1_protein_processed.pdb,data/PDBBind_processed/6qr1/6qr1_ligand.mol2
|
||||
348,data/PDBBind_processed/6s9x/6s9x_protein_processed.pdb,data/PDBBind_processed/6s9x/6s9x_ligand.sdf
|
||||
349,data/PDBBind_processed/6q4q/6q4q_protein_processed.pdb,data/PDBBind_processed/6q4q/6q4q_ligand.mol2
|
||||
350,data/PDBBind_processed/6hbn/6hbn_protein_processed.pdb,data/PDBBind_processed/6hbn/6hbn_ligand.sdf
|
||||
351,data/PDBBind_processed/6nw3/6nw3_protein_processed.pdb,data/PDBBind_processed/6nw3/6nw3_ligand.sdf
|
||||
352,data/PDBBind_processed/6tel/6tel_protein_processed.pdb,data/PDBBind_processed/6tel/6tel_ligand.sdf
|
||||
353,data/PDBBind_processed/6p8y/6p8y_protein_processed.pdb,data/PDBBind_processed/6p8y/6p8y_ligand.sdf
|
||||
354,data/PDBBind_processed/6d5w/6d5w_protein_processed.pdb,data/PDBBind_processed/6d5w/6d5w_ligand.sdf
|
||||
355,data/PDBBind_processed/6t6a/6t6a_protein_processed.pdb,data/PDBBind_processed/6t6a/6t6a_ligand.mol2
|
||||
356,data/PDBBind_processed/6o5g/6o5g_protein_processed.pdb,data/PDBBind_processed/6o5g/6o5g_ligand.mol2
|
||||
357,data/PDBBind_processed/6r7d/6r7d_protein_processed.pdb,data/PDBBind_processed/6r7d/6r7d_ligand.sdf
|
||||
358,data/PDBBind_processed/6pya/6pya_protein_processed.pdb,data/PDBBind_processed/6pya/6pya_ligand.mol2
|
||||
359,data/PDBBind_processed/6ffe/6ffe_protein_processed.pdb,data/PDBBind_processed/6ffe/6ffe_ligand.sdf
|
||||
360,data/PDBBind_processed/6d3x/6d3x_protein_processed.pdb,data/PDBBind_processed/6d3x/6d3x_ligand.sdf
|
||||
361,data/PDBBind_processed/6gj8/6gj8_protein_processed.pdb,data/PDBBind_processed/6gj8/6gj8_ligand.mol2
|
||||
362,data/PDBBind_processed/6mo2/6mo2_protein_processed.pdb,data/PDBBind_processed/6mo2/6mo2_ligand.mol2
|
||||
|
||||
|
@@ -58,7 +58,7 @@ class OptimizeConformer:
|
||||
def score_conformation(self, values):
|
||||
for i, r in enumerate(self.rotable_bonds):
|
||||
SetDihedral(self.mol.GetConformer(self.probe_id), r, values[i])
|
||||
return RMSD(self.mol, self.true_mol, self.probe_id, self.ref_id)
|
||||
return AllChem.AlignMol(self.mol, self.true_mol, self.probe_id, self.ref_id)
|
||||
|
||||
|
||||
def get_torsion_angles(mol):
|
||||
@@ -83,114 +83,3 @@ def get_torsion_angles(mol):
|
||||
)
|
||||
return torsions_list
|
||||
|
||||
|
||||
# GeoMol
|
||||
def get_torsions(mol_list):
|
||||
print('USING GEOMOL GET TORSIONS FUNCTION')
|
||||
atom_counter = 0
|
||||
torsionList = []
|
||||
for m in mol_list:
|
||||
torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]'
|
||||
torsionQuery = Chem.MolFromSmarts(torsionSmarts)
|
||||
matches = m.GetSubstructMatches(torsionQuery)
|
||||
for match in matches:
|
||||
idx2 = match[0]
|
||||
idx3 = match[1]
|
||||
bond = m.GetBondBetweenAtoms(idx2, idx3)
|
||||
jAtom = m.GetAtomWithIdx(idx2)
|
||||
kAtom = m.GetAtomWithIdx(idx3)
|
||||
for b1 in jAtom.GetBonds():
|
||||
if (b1.GetIdx() == bond.GetIdx()):
|
||||
continue
|
||||
idx1 = b1.GetOtherAtomIdx(idx2)
|
||||
for b2 in kAtom.GetBonds():
|
||||
if ((b2.GetIdx() == bond.GetIdx())
|
||||
or (b2.GetIdx() == b1.GetIdx())):
|
||||
continue
|
||||
idx4 = b2.GetOtherAtomIdx(idx3)
|
||||
# skip 3-membered rings
|
||||
if (idx4 == idx1):
|
||||
continue
|
||||
if m.GetAtomWithIdx(idx4).IsInRing():
|
||||
torsionList.append(
|
||||
(idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter))
|
||||
break
|
||||
else:
|
||||
torsionList.append(
|
||||
(idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter))
|
||||
break
|
||||
break
|
||||
|
||||
atom_counter += m.GetNumAtoms()
|
||||
return torsionList
|
||||
|
||||
|
||||
def A_transpose_matrix(alpha):
|
||||
return np.array([[np.cos(alpha), np.sin(alpha)], [-np.sin(alpha), np.cos(alpha)]], dtype=np.double)
|
||||
|
||||
|
||||
def S_vec(alpha):
|
||||
return np.array([[np.cos(alpha)], [np.sin(alpha)]], dtype=np.double)
|
||||
|
||||
|
||||
def GetDihedralFromPointCloud(Z, atom_idx):
|
||||
p = Z[list(atom_idx)]
|
||||
b = p[:-1] - p[1:]
|
||||
b[0] *= -1
|
||||
v = np.array([v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])
|
||||
# Normalize vectors
|
||||
v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
|
||||
b1 = b[1] / np.linalg.norm(b[1])
|
||||
x = np.dot(v[0], v[1])
|
||||
m = np.cross(v[0], b1)
|
||||
y = np.dot(m, v[1])
|
||||
return np.arctan2(y, x)
|
||||
|
||||
|
||||
def get_dihedral_vonMises(mol, conf, atom_idx, Z):
|
||||
Z = np.array(Z)
|
||||
v = np.zeros((2, 1))
|
||||
iAtom = mol.GetAtomWithIdx(atom_idx[1])
|
||||
jAtom = mol.GetAtomWithIdx(atom_idx[2])
|
||||
k_0 = atom_idx[0]
|
||||
i = atom_idx[1]
|
||||
j = atom_idx[2]
|
||||
l_0 = atom_idx[3]
|
||||
for b1 in iAtom.GetBonds():
|
||||
k = b1.GetOtherAtomIdx(i)
|
||||
if k == j:
|
||||
continue
|
||||
for b2 in jAtom.GetBonds():
|
||||
l = b2.GetOtherAtomIdx(j)
|
||||
if l == i:
|
||||
continue
|
||||
assert k != l
|
||||
s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l)))
|
||||
a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l)))
|
||||
v = v + np.matmul(a_mat, s_star)
|
||||
v = v / np.linalg.norm(v)
|
||||
v = v.reshape(-1)
|
||||
return np.arctan2(v[1], v[0])
|
||||
|
||||
|
||||
def get_von_mises_rms(mol, mol_rdkit, rotable_bonds, conf_id):
|
||||
new_dihedrals = np.zeros(len(rotable_bonds))
|
||||
for idx, r in enumerate(rotable_bonds):
|
||||
new_dihedrals[idx] = get_dihedral_vonMises(mol_rdkit,
|
||||
mol_rdkit.GetConformer(conf_id), r,
|
||||
mol.GetConformer().GetPositions())
|
||||
mol_rdkit = apply_changes(mol_rdkit, new_dihedrals, rotable_bonds, conf_id)
|
||||
return RMSD(mol_rdkit, mol, conf_id)
|
||||
|
||||
|
||||
def mmff_func(mol):
|
||||
mol_mmff = copy.deepcopy(mol)
|
||||
AllChem.MMFFOptimizeMoleculeConfs(mol_mmff, mmffVariant='MMFF94s')
|
||||
for i in range(mol.GetNumConformers()):
|
||||
coords = mol_mmff.GetConformers()[i].GetPositions()
|
||||
for j in range(coords.shape[0]):
|
||||
mol.GetConformer(i).SetAtomPosition(j,
|
||||
Geometry.Point3D(*coords[j]))
|
||||
|
||||
|
||||
RMSD = AllChem.AlignMol
|
||||
|
||||
179
datasets/constants.py
Normal file
179
datasets/constants.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Significant contribution from Ben Fry and Nick Polizzi
|
||||
|
||||
three_to_one = {'ALA': 'A',
|
||||
'ARG': 'R',
|
||||
'ASN': 'N',
|
||||
'ASP': 'D',
|
||||
'CYS': 'C',
|
||||
'GLN': 'Q',
|
||||
'GLU': 'E',
|
||||
'GLY': 'G',
|
||||
'HIS': 'H',
|
||||
'ILE': 'I',
|
||||
'LEU': 'L',
|
||||
'LYS': 'K',
|
||||
'MET': 'M',
|
||||
'MSE': 'M', # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
|
||||
'PHE': 'F',
|
||||
'PRO': 'P',
|
||||
'PYL': 'O',
|
||||
'SER': 'S',
|
||||
'SEC': 'U',
|
||||
'THR': 'T',
|
||||
'TRP': 'W',
|
||||
'TYR': 'Y',
|
||||
'VAL': 'V',
|
||||
'ASX': 'B',
|
||||
'GLX': 'Z',
|
||||
'XAA': 'X',
|
||||
'XLE': 'J'}
|
||||
|
||||
|
||||
aa_name2aa_idx = {'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4, 'GLU': 5, 'GLN': 6, 'GLY': 7,
|
||||
'HIS': 8, 'ILE': 9, 'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
|
||||
'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19, 'MSE': 12}
|
||||
|
||||
|
||||
aa_short2long = {'C': 'CYS', 'D': 'ASP', 'S': 'SER', 'Q': 'GLN', 'K': 'LYS', 'I': 'ILE',
|
||||
'P': 'PRO', 'T': 'THR', 'F': 'PHE', 'N': 'ASN', 'G': 'GLY', 'H': 'HIS',
|
||||
'L': 'LEU', 'R': 'ARG', 'W': 'TRP', 'A': 'ALA', 'V': 'VAL', 'E': 'GLU',
|
||||
'Y': 'TYR', 'M': 'MET'}
|
||||
|
||||
|
||||
aa_short2aa_idx = {aa_short: aa_name2aa_idx[aa_long] for aa_short, aa_long in aa_short2long.items()}
|
||||
aa_idx2aa_short = {aa_idx: aa_short for aa_short, aa_idx in aa_short2aa_idx.items()}
|
||||
aa_long2short = {aa_long: aa_short for aa_short, aa_long in aa_short2long.items()}
|
||||
aa_long2short['MSE'] = 'M'
|
||||
|
||||
chi = { 'C' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'SG' ) },
|
||||
'D' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'OD1'), },
|
||||
'E' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD' ),
|
||||
3: ('CB' , 'CG' , 'CD' , 'OE1'), },
|
||||
'F' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
|
||||
'H' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'ND1'), },
|
||||
'I' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG1'),
|
||||
2: ('CA' , 'CB' , 'CG1', 'CD1'), },
|
||||
'K' :
|
||||
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' ,'CD' ),
|
||||
3: ('CB' , 'CG' , 'CD' ,'CE' ),
|
||||
4: ('CG' , 'CD' , 'CE' ,'NZ' ), },
|
||||
'L' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
|
||||
'M' :
|
||||
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' ,'SD' ),
|
||||
3: ('CB' , 'CG' , 'SD' ,'CE' ), },
|
||||
'N' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'OD1'), },
|
||||
'P' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD' ), },
|
||||
'Q' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD' ),
|
||||
3: ('CB' , 'CG' , 'CD' , 'OE1'), },
|
||||
'R' :
|
||||
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' ,'CD' ),
|
||||
3: ('CB' , 'CG' , 'CD' ,'NE' ),
|
||||
4: ('CG' , 'CD' , 'NE' ,'CZ' ), },
|
||||
'S' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'OG' ), },
|
||||
'T' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'OG1'), },
|
||||
'V' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG1'), },
|
||||
'W' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
|
||||
'Y' :
|
||||
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
|
||||
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
|
||||
}
|
||||
|
||||
|
||||
atom_order = {'G': ['N', 'CA', 'C', 'O'],
|
||||
'A': ['N', 'CA', 'C', 'O', 'CB'],
|
||||
'S': ['N', 'CA', 'C', 'O', 'CB', 'OG'],
|
||||
'C': ['N', 'CA', 'C', 'O', 'CB', 'SG'],
|
||||
'T': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2'],
|
||||
'P': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD'],
|
||||
'V': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2'],
|
||||
'M': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE'],
|
||||
'N': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2'],
|
||||
'I': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1'],
|
||||
'L': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2'],
|
||||
'D': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
|
||||
'E': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
|
||||
'K': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ'],
|
||||
'Q': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2'],
|
||||
'H': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'],
|
||||
'F': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
|
||||
'R': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
|
||||
'Y': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'],
|
||||
'W': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'NE1', 'CZ2', 'CZ3', 'CH2'],
|
||||
'X': ['N', 'CA', 'C', 'O']} # unknown amino acid
|
||||
|
||||
|
||||
|
||||
amino_acid_smiles = {
|
||||
'PHE': '[NH3+]CC(=O)N[C@@H](Cc1ccccc1)C(=O)NCC(=O)O',
|
||||
'MET': 'CSCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'TYR': '[NH3+]CC(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)NCC(=O)O',
|
||||
'ILE': 'CC[C@H](C)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'TRP': '[NH3+]CC(=O)N[C@@H](Cc1c[nH]c2ccccc12)C(=O)NCC(=O)O',
|
||||
'THR': 'C[C@@H](O)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'CYS': '[NH3+]CC(=O)N[C@@H](CS)C(=O)NCC(=O)O',
|
||||
'ALA': 'C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'LYS': '[NH3+]CCCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'PRO': '[NH3+]CC(=O)N1CCC[C@H]1C(=O)NCC(=O)O',
|
||||
'LEU': 'CC(C)C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'GLY': '[NH3+]CC(=O)NCC(=O)NCC(=O)O',
|
||||
'ASP': '[NH3+]CC(=O)N[C@@H](CC(=O)O)C(=O)NCC(=O)O',
|
||||
'HIS': '[NH3+]CC(=O)N[C@@H](Cc1c[nH]c[nH+]1)C(=O)NCC(=O)O',
|
||||
'VAL': 'CC(C)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'SER': '[NH3+]CC(=O)N[C@@H](CO)C(=O)NCC(=O)O',
|
||||
'ARG': 'NC(=[NH2+])NCCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'GLU': '[NH3+]CC(=O)N[C@@H](CCC(=O)O)C(=O)NCC(=O)O',
|
||||
'GLN': 'NC(=O)CC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
'ASN': 'NC(=O)C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
|
||||
}
|
||||
|
||||
cg_rdkit_indices = {
|
||||
'PHE': {4: 'N', 5: 'CA', 13: 'C', 14: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 12: 'CD2', 9: 'CE1', 11: 'CE2', 10: 'CZ'},
|
||||
'MET': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 2: 'CG', 1: 'SD', 0: 'CE'},
|
||||
'TYR': {4: 'N', 5: 'CA', 14: 'C', 15: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 13: 'CD2', 9: 'CE1', 12: 'CE2', 10: 'CZ', 11: 'OH'},
|
||||
'ILE': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 2: 'CB', 1: 'CG1', 3: 'CG2', 0: 'CD1'},
|
||||
'TRP': {4: 'N', 5: 'CA', 16: 'C', 17: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 15: 'CD2', 9: 'NE1', 10: 'CE2', 14: 'CE3', 11: 'CZ2', 13: 'CZ3', 12: 'CH2'},
|
||||
'THR': {4: 'N', 3: 'CA', 9: 'C', 10: 'O', 1: 'CB', 2: 'OG1', 0: 'CG2'},
|
||||
'CYS': {4: 'N', 5: 'CA', 8: 'C', 9: 'O', 6: 'CB', 7: 'SG'},
|
||||
'ALA': {2: 'N', 1: 'CA', 7: 'C', 8: 'O', 0: 'CB'},
|
||||
'LYS': {6: 'N', 5: 'CA', 11: 'C', 12: 'O', 4: 'CB', 3: 'CG', 2: 'CD', 1: 'CE', 0: 'NZ'},
|
||||
'PRO': {4: 'N', 8: 'CA', 9: 'C', 10: 'O', 7: 'CB', 6: 'CG', 5: 'CD'},
|
||||
'LEU': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 1: 'CG', 0: 'CD1', 2: 'CD2'},
|
||||
'GLY': {4: 'N', 5: 'CA', 6: 'C', 7: 'O'},
|
||||
'ASP': {4: 'N', 5: 'CA', 10: 'C', 11: 'O', 6: 'CB', 7: 'CG', 8: 'OD1', 9: 'OD2'},
|
||||
'HIS': {4: 'N', 5: 'CA', 12: 'C', 13: 'O', 6: 'CB', 7: 'CG', 11: 'ND1', 8: 'CD2', 10: 'CE1', 9: 'NE2'},
|
||||
'VAL': {4: 'N', 3: 'CA', 9: 'C', 10: 'O', 1: 'CB', 0: 'CG1', 2: 'CG2'},
|
||||
'SER': {4: 'N', 5: 'CA', 8: 'C', 9: 'O', 6: 'CB', 7: 'OG'},
|
||||
'ARG': {8: 'N', 7: 'CA', 13: 'C', 14: 'O', 6: 'CB', 5: 'CG', 4: 'CD', 3: 'NE', 1: 'CZ', 0: 'NH1', 2: 'NH2'},
|
||||
'GLU': {4: 'N', 5: 'CA', 11: 'C', 12: 'O', 6: 'CB', 7: 'CG', 8: 'CD', 9: 'OE1', 10: 'OE2'},
|
||||
'GLN': {6: 'N', 5: 'CA', 11: 'C', 12: 'O', 4: 'CB', 3: 'CG', 1: 'CD', 2: 'OE1', 0: 'NE2'},
|
||||
'ASN': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 1: 'CG', 2: 'OD1', 0: 'ND2'}
|
||||
}
|
||||
|
||||
aa_to_cg_indices = {aa_long2short[x]: [atom_order[aa_long2short[x]].index(aname) for aname in index_dict.values()] for x, index_dict in cg_rdkit_indices.items()}
|
||||
|
||||
101
datasets/dataloader.py
Normal file
101
datasets/dataloader.py
Normal file
@@ -0,0 +1,101 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch.utils.data
|
||||
from torch.utils.data.dataloader import default_collate
|
||||
|
||||
from torch_geometric.data import Batch, Dataset
|
||||
from torch_geometric.data.data import BaseData
|
||||
|
||||
|
||||
class Collater:
|
||||
def __init__(self, follow_batch, exclude_keys):
|
||||
self.follow_batch = follow_batch
|
||||
self.exclude_keys = exclude_keys
|
||||
|
||||
def __call__(self, batch):
|
||||
batch = [x for x in batch if x is not None]
|
||||
elem = batch[0]
|
||||
if isinstance(elem, BaseData):
|
||||
return Batch.from_data_list(batch, self.follow_batch,
|
||||
self.exclude_keys)
|
||||
elif isinstance(elem, torch.Tensor):
|
||||
return default_collate(batch)
|
||||
elif isinstance(elem, float):
|
||||
return torch.tensor(batch, dtype=torch.float)
|
||||
elif isinstance(elem, int):
|
||||
return torch.tensor(batch)
|
||||
elif isinstance(elem, str):
|
||||
return batch
|
||||
elif isinstance(elem, Mapping):
|
||||
return {key: self([data[key] for data in batch]) for key in elem}
|
||||
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
|
||||
return type(elem)(*(self(s) for s in zip(*batch)))
|
||||
elif isinstance(elem, Sequence) and not isinstance(elem, str):
|
||||
return [self(s) for s in zip(*batch)]
|
||||
|
||||
raise TypeError(f'DataLoader found invalid type: {type(elem)}')
|
||||
|
||||
def collate(self, batch): # Deprecated...
|
||||
return self(batch)
|
||||
|
||||
|
||||
class DataLoader(torch.utils.data.DataLoader):
|
||||
r"""A data loader which merges data objects from a
|
||||
:class:`torch_geometric.data.Dataset` to a mini-batch.
|
||||
Data objects can be either of type :class:`~torch_geometric.data.Data` or
|
||||
:class:`~torch_geometric.data.HeteroData`.
|
||||
|
||||
Args:
|
||||
dataset (Dataset): The dataset from which to load the data.
|
||||
batch_size (int, optional): How many samples per batch to load.
|
||||
(default: :obj:`1`)
|
||||
shuffle (bool, optional): If set to :obj:`True`, the data will be
|
||||
reshuffled at every epoch. (default: :obj:`False`)
|
||||
follow_batch (List[str], optional): Creates assignment batch
|
||||
vectors for each key in the list. (default: :obj:`None`)
|
||||
exclude_keys (List[str], optional): Will exclude each key in the
|
||||
list. (default: :obj:`None`)
|
||||
**kwargs (optional): Additional arguments of
|
||||
:class:`torch.utils.data.DataLoader`.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Union[Dataset, List[BaseData]],
|
||||
batch_size: int = 1,
|
||||
shuffle: bool = False,
|
||||
follow_batch: Optional[List[str]] = None,
|
||||
exclude_keys: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
if 'collate_fn' in kwargs:
|
||||
del kwargs['collate_fn']
|
||||
|
||||
# Save for PyTorch Lightning:
|
||||
self.follow_batch = follow_batch
|
||||
self.exclude_keys = exclude_keys
|
||||
|
||||
super().__init__(
|
||||
dataset,
|
||||
batch_size,
|
||||
shuffle,
|
||||
collate_fn=Collater(follow_batch, exclude_keys),
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn(data_list):
|
||||
data_list = [x for x in data_list if x is not None]
|
||||
return data_list
|
||||
|
||||
|
||||
class DataListLoader(torch.utils.data.DataLoader):
|
||||
def __init__(self, dataset: Union[Dataset, List[BaseData]],
|
||||
batch_size: int = 1, shuffle: bool = False, **kwargs):
|
||||
if 'collate_fn' in kwargs:
|
||||
del kwargs['collate_fn']
|
||||
|
||||
super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
|
||||
collate_fn=collate_fn, **kwargs)
|
||||
|
||||
@@ -1,60 +1,27 @@
|
||||
import os
|
||||
from argparse import FileType, ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pickle
|
||||
from argparse import ArgumentParser
|
||||
from Bio.PDB import PDBParser
|
||||
from Bio.Seq import Seq
|
||||
from Bio.SeqRecord import SeqRecord
|
||||
from tqdm import tqdm
|
||||
from Bio import SeqIO
|
||||
|
||||
from datasets.constants import three_to_one
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--out_file', type=str, default="data/prepared_for_esm.fasta")
|
||||
parser.add_argument('--protein_ligand_csv', type=str, default='data/protein_ligand_example_csv.csv', help='Path to a .csv specifying the input as described in the main README')
|
||||
parser.add_argument('--protein_path', type=str, default=None, help='Path to a single PDB file. If this is not None then it will be used instead of the --protein_ligand_csv')
|
||||
parser.add_argument('--dataset', type=str, default="pdbbind")
|
||||
parser.add_argument('--data_dir', type=str, default='../data/BindingMOAD_2020_ab_processed_biounit/pdb_protein/', help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
biopython_parser = PDBParser()
|
||||
|
||||
three_to_one = {'ALA': 'A',
|
||||
'ARG': 'R',
|
||||
'ASN': 'N',
|
||||
'ASP': 'D',
|
||||
'CYS': 'C',
|
||||
'GLN': 'Q',
|
||||
'GLU': 'E',
|
||||
'GLY': 'G',
|
||||
'HIS': 'H',
|
||||
'ILE': 'I',
|
||||
'LEU': 'L',
|
||||
'LYS': 'K',
|
||||
'MET': 'M',
|
||||
'MSE': 'M', # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
|
||||
'PHE': 'F',
|
||||
'PRO': 'P',
|
||||
'PYL': 'O',
|
||||
'SER': 'S',
|
||||
'SEC': 'U',
|
||||
'THR': 'T',
|
||||
'TRP': 'W',
|
||||
'TYR': 'Y',
|
||||
'VAL': 'V',
|
||||
'ASX': 'B',
|
||||
'GLX': 'Z',
|
||||
'XAA': 'X',
|
||||
'XLE': 'J'}
|
||||
|
||||
if args.protein_path is not None:
|
||||
file_paths = [args.protein_path]
|
||||
else:
|
||||
df = pd.read_csv(args.protein_ligand_csv)
|
||||
file_paths = list(set(df['protein_path'].tolist()))
|
||||
sequences = []
|
||||
ids = []
|
||||
for file_path in tqdm(file_paths):
|
||||
def get_structure_from_file(file_path):
|
||||
structure = biopython_parser.get_structure('random_id', file_path)
|
||||
structure = structure[0]
|
||||
l = []
|
||||
for i, chain in enumerate(structure):
|
||||
seq = ''
|
||||
for res_idx, residue in enumerate(chain):
|
||||
@@ -75,13 +42,48 @@ for file_path in tqdm(file_paths):
|
||||
except Exception as e:
|
||||
seq += '-'
|
||||
print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', file_path, '. Replacing it with a dash - .')
|
||||
sequences.append(seq)
|
||||
ids.append(f'{os.path.basename(file_path)}_chain_{i}')
|
||||
records = []
|
||||
for (index, seq) in zip(ids,sequences):
|
||||
record = SeqRecord(Seq(seq), str(index))
|
||||
record.description = ''
|
||||
records.append(record)
|
||||
SeqIO.write(records, args.out_file, "fasta")
|
||||
l.append(seq)
|
||||
return l
|
||||
|
||||
data_dir = args.data_dir
|
||||
names = os.listdir(data_dir)
|
||||
|
||||
if args.dataset == 'pdbbind':
|
||||
sequences = []
|
||||
ids = []
|
||||
|
||||
for name in tqdm(names):
|
||||
if name == '.DS_Store': continue
|
||||
if os.path.exists(os.path.join(data_dir, name, f'{name}_protein_processed.pdb')):
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein_processed.pdb')
|
||||
else:
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
|
||||
l = get_structure_from_file(rec_path)
|
||||
for i, seq in enumerate(l):
|
||||
sequences.append(seq)
|
||||
ids.append(f'{name}_chain_{i}')
|
||||
records = []
|
||||
for (index, seq) in zip(ids, sequences):
|
||||
record = SeqRecord(Seq(seq), str(index))
|
||||
record.description = ''
|
||||
records.append(record)
|
||||
SeqIO.write(records, args.out_file, "fasta")
|
||||
|
||||
elif args.dataset == 'moad':
|
||||
names = [n[:6] for n in names]
|
||||
name_to_sequence = {}
|
||||
|
||||
for name in tqdm(names):
|
||||
if name == '.DS_Store': continue
|
||||
if not os.path.exists(os.path.join(data_dir, f'{name}_protein.pdb')):
|
||||
print(f"We are skipping {name} because there was no {name}_protein.pdb")
|
||||
continue
|
||||
rec_path = os.path.join(data_dir, f'{name}_protein.pdb')
|
||||
l = get_structure_from_file(rec_path)
|
||||
for i, seq in enumerate(l):
|
||||
name_to_sequence[name + '_chain_' + str(i)] = seq
|
||||
|
||||
# save to file
|
||||
with open(args.out_file, 'wb') as f:
|
||||
pickle.dump(name_to_sequence, f)
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
@@ -7,8 +6,8 @@ from tqdm import tqdm
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--esm_embeddings_path', type=str, default='data/embeddings_output', help='')
|
||||
parser.add_argument('--output_path', type=str, default='data/esm2_3billion_embeddings.pt', help='')
|
||||
parser.add_argument('--esm_embeddings_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new', help='')
|
||||
parser.add_argument('--output_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new.pt', help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
dict = {}
|
||||
|
||||
123
datasets/loader.py
Normal file
123
datasets/loader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import torch
|
||||
from torch_geometric.data import Dataset
|
||||
|
||||
from datasets.dataloader import DataLoader, DataListLoader
|
||||
from datasets.moad import MOAD
|
||||
from datasets.pdb import PDBSidechain
|
||||
from datasets.pdbbind import NoiseTransform, PDBBind
|
||||
from utils.utils import read_strings_from_txt
|
||||
|
||||
|
||||
class CombineDatasets(Dataset):
|
||||
def __init__(self, dataset1, dataset2):
|
||||
super(CombineDatasets, self).__init__()
|
||||
self.dataset1 = dataset1
|
||||
self.dataset2 = dataset2
|
||||
|
||||
def len(self):
|
||||
return len(self.dataset1) + len(self.dataset2)
|
||||
|
||||
def get(self, idx):
|
||||
if idx < len(self.dataset1):
|
||||
return self.dataset1[idx]
|
||||
else:
|
||||
return self.dataset2[idx - len(self.dataset1)]
|
||||
|
||||
def add_complexes(self, new_complex_list):
|
||||
self.dataset1.add_complexes(new_complex_list)
|
||||
|
||||
|
||||
def construct_loader(args, t_to_sigma, device):
|
||||
val_dataset2 = None
|
||||
transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
|
||||
all_atom=args.all_atoms, alpha=args.sampling_alpha, beta=args.sampling_beta,
|
||||
include_miscellaneous_atoms=False if not hasattr(args, 'include_miscellaneous_atoms') else args.include_miscellaneous_atoms,
|
||||
crop_beyond_cutoff=args.crop_beyond)
|
||||
if args.triple_training: assert args.combined_training
|
||||
|
||||
sequences_to_embeddings = None
|
||||
if args.dataset == 'pdbsidechain' or args.triple_training:
|
||||
if args.pdbsidechain_esm_embeddings_path is not None:
|
||||
print('Loading ESM embeddings')
|
||||
id_to_embeddings = torch.load(args.pdbsidechain_esm_embeddings_path)
|
||||
sequences_list = read_strings_from_txt(args.pdbsidechain_esm_embeddings_sequences_path)
|
||||
sequences_to_embeddings = {}
|
||||
for i, seq in enumerate(sequences_list):
|
||||
if str(i) in id_to_embeddings:
|
||||
sequences_to_embeddings[seq] = id_to_embeddings[str(i)]
|
||||
|
||||
if args.dataset == 'pdbsidechain' or args.triple_training:
|
||||
|
||||
common_args = {'root': args.pdbsidechain_dir, 'transform': transform, 'limit_complexes': args.limit_complexes,
|
||||
'receptor_radius': args.receptor_radius,
|
||||
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
|
||||
'remove_hs': args.remove_hs, 'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
|
||||
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
|
||||
'knn_only_graph': not args.not_knn_only_graph, 'sequences_to_embeddings': sequences_to_embeddings,
|
||||
'vandermers_max_dist': args.vandermers_max_dist,
|
||||
'vandermers_buffer_residue_num': args.vandermers_buffer_residue_num,
|
||||
'vandermers_min_contacts': args.vandermers_min_contacts,
|
||||
'remove_second_segment': args.remove_second_segment,
|
||||
'merge_clusters': args.merge_clusters}
|
||||
train_dataset3 = PDBSidechain(cache_path=args.cache_path, split='train', multiplicity=args.train_multiplicity, **common_args)
|
||||
|
||||
if args.dataset == 'pdbsidechain':
|
||||
train_dataset = train_dataset3
|
||||
val_dataset = PDBSidechain(cache_path=args.cache_path, split='val', multiplicity=args.val_multiplicity, **common_args)
|
||||
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
|
||||
|
||||
if args.dataset in ['pdbbind', 'moad', 'generalisation', 'distillation']:
|
||||
common_args = {'transform': transform, 'limit_complexes': args.limit_complexes,
|
||||
'chain_cutoff': args.chain_cutoff, 'receptor_radius': args.receptor_radius,
|
||||
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
|
||||
'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
|
||||
'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
|
||||
'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
|
||||
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
|
||||
'knn_only_graph': False if not hasattr(args, 'not_knn_only_graph') else not args.not_knn_only_graph,
|
||||
'include_miscellaneous_atoms': False if not hasattr(args, 'include_miscellaneous_atoms') else args.include_miscellaneous_atoms,
|
||||
'matching_tries': args.matching_tries}
|
||||
|
||||
if args.dataset == 'pdbbind' or args.dataset == 'generalisation' or args.combined_training:
|
||||
train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
|
||||
num_conformers=args.num_conformers, root=args.pdbbind_dir,
|
||||
esm_embeddings_path=args.pdbbind_esm_embeddings_path,
|
||||
protein_file=args.protein_file, **common_args)
|
||||
|
||||
if args.dataset == 'moad' or args.combined_training:
|
||||
train_dataset2 = MOAD(cache_path=args.cache_path, split='train', keep_original=True,
|
||||
num_conformers=args.num_conformers, max_receptor_size=args.max_receptor_size,
|
||||
remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
|
||||
multiplicity= args.train_multiplicity, unroll_clusters=args.unroll_clusters,
|
||||
esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
|
||||
root=args.moad_dir, esm_embeddings_path=args.moad_esm_embeddings_path,
|
||||
enforce_timesplit=args.enforce_timesplit, **common_args)
|
||||
|
||||
if args.combined_training:
|
||||
train_dataset = CombineDatasets(train_dataset2, train_dataset)
|
||||
if args.triple_training:
|
||||
train_dataset = CombineDatasets(train_dataset, train_dataset3)
|
||||
else:
|
||||
train_dataset = train_dataset2
|
||||
|
||||
if args.dataset == 'pdbbind' or args.double_val:
|
||||
val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True,
|
||||
esm_embeddings_path=args.pdbbind_esm_embeddings_path, root=args.pdbbind_dir,
|
||||
protein_file=args.protein_file, require_ligand=True, **common_args)
|
||||
if args.double_val:
|
||||
val_dataset2 = val_dataset
|
||||
|
||||
if args.dataset == 'moad' or args.dataset == 'generalisation':
|
||||
val_dataset = MOAD(cache_path=args.cache_path, split='val', keep_original=True,
|
||||
multiplicity= args.val_multiplicity, max_receptor_size=args.max_receptor_size,
|
||||
remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
|
||||
esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
|
||||
unroll_clusters=args.unroll_clusters, root=args.moad_dir,
|
||||
esm_embeddings_path=args.moad_esm_embeddings_path, require_ligand=True, **common_args)
|
||||
|
||||
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
|
||||
|
||||
train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
|
||||
val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=False, pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
|
||||
return train_loader, val_loader, val_dataset2
|
||||
|
||||
547
datasets/moad.py
Normal file
547
datasets/moad.py
Normal file
@@ -0,0 +1,547 @@
|
||||
import os
|
||||
import pickle
|
||||
from multiprocessing import Pool
|
||||
import random
|
||||
import copy
|
||||
from torch_geometric.data import Batch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from prody import confProDy
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import RemoveHs
|
||||
from torch_geometric.data import Dataset, HeteroData
|
||||
from torch_geometric.utils import subgraph
|
||||
from tqdm import tqdm
|
||||
confProDy(verbosity='none')
|
||||
from datasets.process_mols import get_lig_graph_with_matching, moad_extract_receptor_structure
|
||||
from utils.utils import read_strings_from_txt
|
||||
|
||||
class MOAD(Dataset):
|
||||
def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0, chain_cutoff=None,
|
||||
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
|
||||
matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
|
||||
atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, esm_embeddings_sequences_path=None, require_ligand=False,
|
||||
include_miscellaneous_atoms=False, keep_local_structures=False,
|
||||
min_ligand_size=0, knn_only_graph=False, matching_tries=1, multiplicity=1,
|
||||
max_receptor_size=None, remove_promiscuous_targets=None, unroll_clusters=False, remove_pdbbind=False,
|
||||
enforce_timesplit=False, no_randomness=False, single_cluster_name=None, total_dataset_size=None, skip_matching=False):
|
||||
|
||||
super(MOAD, self).__init__(root, transform)
|
||||
self.moad_dir = root
|
||||
self.include_miscellaneous_atoms = include_miscellaneous_atoms
|
||||
self.max_lig_size = max_lig_size
|
||||
self.split = split
|
||||
self.limit_complexes = limit_complexes
|
||||
self.receptor_radius = receptor_radius
|
||||
self.num_workers = num_workers
|
||||
self.c_alpha_max_neighbors = c_alpha_max_neighbors
|
||||
self.remove_hs = remove_hs
|
||||
self.require_ligand = require_ligand
|
||||
self.esm_embeddings_path = esm_embeddings_path
|
||||
self.esm_embeddings_sequences_path = esm_embeddings_sequences_path
|
||||
self.keep_local_structures = keep_local_structures
|
||||
self.knn_only_graph = knn_only_graph
|
||||
self.matching_tries = matching_tries
|
||||
self.all_atoms = all_atoms
|
||||
self.multiplicity = multiplicity
|
||||
self.chain_cutoff = chain_cutoff
|
||||
self.no_randomness = no_randomness
|
||||
self.total_dataset_size = total_dataset_size
|
||||
self.skip_matching = skip_matching
|
||||
|
||||
self.prot_cache_path = os.path.join(cache_path, f'MOAD12_limit{self.limit_complexes}_INDEX{self.split}'
|
||||
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
|
||||
+ (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
|
||||
+ ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
|
||||
+ ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
|
||||
+ ('' if not self.knn_only_graph else '_knnOnly'))
|
||||
|
||||
self.lig_cache_path = os.path.join(cache_path, f'MOAD12_limit{self.limit_complexes}_INDEX{self.split}'
|
||||
f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
|
||||
+ ('' if not matching else f'_matching')
|
||||
+ ('' if not skip_matching else f'skip')
|
||||
+ (''if not matching or num_conformers == 1 else f'_confs{num_conformers}')
|
||||
+ ('' if not keep_local_structures else f'_keptLocalStruct')
|
||||
+ ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
|
||||
|
||||
self.popsize, self.maxiter = popsize, maxiter
|
||||
self.matching, self.keep_original = matching, keep_original
|
||||
self.num_conformers = num_conformers
|
||||
self.single_cluster_name = single_cluster_name
|
||||
if split == 'train':
|
||||
split = 'PDBBind'
|
||||
|
||||
with open("./data/splits/MOAD_generalisation_splits.pkl", "rb") as f:
|
||||
self.split_clusters = pickle.load(f)[split]
|
||||
|
||||
clustes_path = os.path.join(self.moad_dir, "new_cluster_to_ligands.pkl")
|
||||
with open(clustes_path, "rb") as f:
|
||||
self.cluster_to_ligands = pickle.load(f)
|
||||
#self.cluster_to_ligands = {k: [s.split('.')[0] for s in v] for k, v in self.cluster_to_ligands.items()}
|
||||
|
||||
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
||||
if not self.check_all_receptors():
|
||||
os.makedirs(self.prot_cache_path, exist_ok=True)
|
||||
self.preprocessing_receptors()
|
||||
|
||||
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
||||
if not os.path.exists(os.path.join(self.lig_cache_path, "ligands.pkl")):
|
||||
os.makedirs(self.lig_cache_path, exist_ok=True)
|
||||
self.preprocessing_ligands()
|
||||
|
||||
print('loading ligands from memory: ', os.path.join(self.lig_cache_path, "ligands.pkl"))
|
||||
with open(os.path.join(self.lig_cache_path, "ligands.pkl"), 'rb') as f:
|
||||
self.ligands = pickle.load(f)
|
||||
|
||||
if require_ligand:
|
||||
with open(os.path.join(self.lig_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
|
||||
self.rdkit_ligands = pickle.load(f)
|
||||
self.rdkit_ligands = {lig.name:mol for mol, lig in zip(self.rdkit_ligands, self.ligands)}
|
||||
|
||||
len_before = len(self.ligands)
|
||||
if not self.single_cluster_name is None:
|
||||
self.ligands = [lig for lig in self.ligands if lig.name in self.cluster_to_ligands[self.single_cluster_name]]
|
||||
print('Kept', len(self.ligands), f'ligands in {self.single_cluster_name} out of', len_before)
|
||||
|
||||
len_before = len(self.ligands)
|
||||
self.ligands = {lig.name: lig for lig in self.ligands if min_ligand_size == 0 or lig['ligand'].x.shape[0] >= min_ligand_size}
|
||||
print('removed', len_before - len(self.ligands), 'ligands below minimum size out of', len_before)
|
||||
|
||||
receptors_names = set([lig.name[:6] for lig in self.ligands.values()])
|
||||
self.collect_receptors(receptors_names, max_receptor_size, remove_promiscuous_targets)
|
||||
|
||||
# filter ligands for which the receptor failed
|
||||
tot_before = len(self.ligands)
|
||||
self.ligands = {k:v for k, v in self.ligands.items() if k[:6] in self.receptors}
|
||||
print('removed', tot_before - len(self.ligands), 'ligands with no receptor out of', tot_before)
|
||||
|
||||
if remove_pdbbind:
|
||||
complexes_pdbbind = read_strings_from_txt('data/splits/timesplit_no_lig_overlap_train') + read_strings_from_txt('data/splits/timesplit_no_lig_overlap_val')
|
||||
with open('data/BindingMOAD_2020_ab_processed_biounit/ecod_t_group_binding_site_assignment_dict_major_domain.pkl', 'rb') as f:
|
||||
pdbbind_to_cluster = pickle.load(f)
|
||||
clusters_pdbbind = set([pdbbind_to_cluster[c] for c in complexes_pdbbind])
|
||||
self.split_clusters = [c for c in self.split_clusters if c not in clusters_pdbbind]
|
||||
self.cluster_to_ligands = {k: v for k, v in self.cluster_to_ligands.items() if k not in clusters_pdbbind}
|
||||
ligand_accepted = []
|
||||
for c, ligands in self.cluster_to_ligands.items():
|
||||
ligand_accepted += ligands
|
||||
ligand_accepted = set(ligand_accepted)
|
||||
tot_before = len(self.ligands)
|
||||
self.ligands = {k: v for k, v in self.ligands.items() if k in ligand_accepted}
|
||||
print('removed', tot_before - len(self.ligands), 'ligands in overlap with PDBBind out of', tot_before)
|
||||
|
||||
if enforce_timesplit:
|
||||
with open("data/splits/pdbids_2019", "r") as f:
|
||||
lines = f.readlines()
|
||||
pdbids_from2019 = []
|
||||
for i in range(6, len(lines), 4):
|
||||
pdbids_from2019.append(lines[i][18:22])
|
||||
|
||||
pdbids_from2019 = set(pdbids_from2019)
|
||||
len_before = len(self.ligands)
|
||||
self.ligands = {k: v for k, v in self.ligands.items() if k[:4].upper() not in pdbids_from2019}
|
||||
print('removed', len_before - len(self.ligands), 'ligands from 2019 out of', len_before)
|
||||
|
||||
if unroll_clusters:
|
||||
rec_keys = set([k[:6] for k in self.ligands.keys()])
|
||||
self.cluster_to_ligands = {k:[k2 for k2 in self.ligands.keys() if k2[:6] == k] for k in rec_keys}
|
||||
self.split_clusters = list(rec_keys)
|
||||
else:
|
||||
for c in self.cluster_to_ligands.keys():
|
||||
self.cluster_to_ligands[c] = [v for v in self.cluster_to_ligands[c] if v in self.ligands]
|
||||
self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_ligands[c])>0]
|
||||
|
||||
print_statistics(self)
|
||||
list_names = [name for cluster in self.split_clusters for name in self.cluster_to_ligands[cluster]]
|
||||
with open(os.path.join(self.prot_cache_path, f'moad_{self.split}_names.txt'), 'w') as f:
|
||||
f.write('\n'.join(list_names))
|
||||
|
||||
def len(self):
|
||||
return len(self.split_clusters) * self.multiplicity if self.total_dataset_size is None else self.total_dataset_size
|
||||
|
||||
def get_by_name(self, ligand_name, cluster):
|
||||
ligand_graph = copy.deepcopy(self.ligands[ligand_name])
|
||||
complex_graph = copy.deepcopy(self.receptors[ligand_name[:6]])
|
||||
|
||||
if False and self.keep_original and hasattr(ligand_graph['ligand'], 'orig_pos'):
|
||||
lig_path = os.path.join(self.moad_dir, 'pdb_superligand', ligand_name + '.pdb')
|
||||
lig = Chem.MolFromPDBFile(lig_path)
|
||||
formula = np.asarray([atom.GetSymbol() for atom in lig.GetAtoms()])
|
||||
|
||||
# check for same receptor/ligand pair with a different binding position
|
||||
for ligand_comp in self.cluster_to_ligands[cluster]:
|
||||
if ligand_comp == ligand_name or ligand_comp[:6] != ligand_name[:6]:
|
||||
continue
|
||||
|
||||
lig_path_comp = os.path.join(self.moad_dir, 'pdb_superligand', ligand_comp + '.pdb')
|
||||
if not os.path.exists(lig_path_comp):
|
||||
continue
|
||||
|
||||
lig_comp = Chem.MolFromPDBFile(lig_path_comp)
|
||||
formula_comp = np.asarray([atom.GetSymbol() for atom in lig_comp.GetAtoms()])
|
||||
|
||||
if formula.shape == formula_comp.shape and np.all(formula == formula_comp) and hasattr(
|
||||
self.ligands[ligand_comp], 'orig_pos'):
|
||||
print(f'Found complex {ligand_comp} to have the same complex/ligand pair, adding it into orig_pos')
|
||||
# add the orig_pos of the binding position
|
||||
if not isinstance(ligand_graph['ligand'].orig_pos, list):
|
||||
ligand_graph['ligand'].orig_pos = [ligand_graph['ligand'].orig_pos]
|
||||
ligand_graph['ligand'].orig_pos.append(self.ligands[ligand_comp].orig_pos)
|
||||
|
||||
for type in ligand_graph.node_types + ligand_graph.edge_types:
|
||||
for key, value in ligand_graph[type].items():
|
||||
complex_graph[type][key] = value
|
||||
complex_graph.name = ligand_graph.name
|
||||
if isinstance(complex_graph['ligand'].pos, list):
|
||||
for i in range(len(complex_graph['ligand'].pos)):
|
||||
complex_graph['ligand'].pos[i] -= complex_graph.original_center
|
||||
else:
|
||||
complex_graph['ligand'].pos -= complex_graph.original_center
|
||||
if self.require_ligand:
|
||||
complex_graph.mol = copy.deepcopy(self.rdkit_ligands[ligand_name])
|
||||
|
||||
if self.chain_cutoff:
|
||||
distances = torch.norm(
|
||||
(torch.from_numpy(complex_graph['ligand'].orig_pos[0]) - complex_graph.original_center).unsqueeze(1) - complex_graph['receptor'].pos.unsqueeze(0), dim=2)
|
||||
distances = distances.min(dim=0)[0]
|
||||
if torch.min(distances) >= self.chain_cutoff:
|
||||
print('minimum distance', torch.min(distances), 'too large', ligand_name,
|
||||
'skipping and returning random. Number of chains',
|
||||
torch.max(complex_graph['receptor'].chain_ids) + 1)
|
||||
return self.get(random.randint(0, self.len()))
|
||||
|
||||
within_cutoff = distances < self.chain_cutoff
|
||||
chains_within_cutoff = torch.zeros(torch.max(complex_graph['receptor'].chain_ids) + 1)
|
||||
chains_within_cutoff.index_add_(0, complex_graph['receptor'].chain_ids, within_cutoff.float())
|
||||
chains_within_cutoff_bool = chains_within_cutoff > 0
|
||||
residues_to_keep = chains_within_cutoff_bool[complex_graph['receptor'].chain_ids]
|
||||
|
||||
if self.all_atoms:
|
||||
atom_to_res_mapping = complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index[1]
|
||||
atoms_to_keep = residues_to_keep[atom_to_res_mapping]
|
||||
rec_remapper = (torch.cumsum(residues_to_keep.long(), dim=0) - 1)
|
||||
atom_to_res_new_mapping = rec_remapper[atom_to_res_mapping][atoms_to_keep]
|
||||
atom_res_edge_index = torch.stack([torch.arange(len(atom_to_res_new_mapping)), atom_to_res_new_mapping])
|
||||
|
||||
complex_graph['atom'].x = complex_graph['atom'].x[atoms_to_keep]
|
||||
complex_graph['atom'].pos = complex_graph['atom'].pos[atoms_to_keep]
|
||||
complex_graph['atom', 'atom_contact', 'atom'].edge_index = \
|
||||
subgraph(atoms_to_keep, complex_graph['atom', 'atom_contact', 'atom'].edge_index,
|
||||
relabel_nodes=True)[0]
|
||||
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
|
||||
|
||||
complex_graph['receptor'].pos = complex_graph['receptor'].pos[residues_to_keep]
|
||||
complex_graph['receptor'].x = complex_graph['receptor'].x[residues_to_keep]
|
||||
complex_graph['receptor'].side_chain_vecs = complex_graph['receptor'].side_chain_vecs[residues_to_keep]
|
||||
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
|
||||
subgraph(residues_to_keep, complex_graph['receptor', 'rec_contact', 'receptor'].edge_index,
|
||||
relabel_nodes=True)[0]
|
||||
|
||||
extra_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
complex_graph['receptor'].pos -= extra_center
|
||||
if isinstance(complex_graph['ligand'].pos, list):
|
||||
for i in range(len(complex_graph['ligand'].pos)):
|
||||
complex_graph['ligand'].pos[i] -= extra_center
|
||||
else:
|
||||
complex_graph['ligand'].pos -= extra_center
|
||||
complex_graph.original_center += extra_center
|
||||
|
||||
complex_graph['receptor'].pop('chain_ids')
|
||||
|
||||
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq',
|
||||
'to_keep', 'chain_ids']:
|
||||
if hasattr(complex_graph, a):
|
||||
delattr(complex_graph, a)
|
||||
if hasattr(complex_graph['receptor'], a):
|
||||
delattr(complex_graph['receptor'], a)
|
||||
|
||||
return complex_graph
|
||||
|
||||
def get(self, idx):
|
||||
if self.total_dataset_size is not None:
|
||||
idx = random.randint(0, len(self.split_clusters) - 1)
|
||||
|
||||
idx = idx % len(self.split_clusters)
|
||||
cluster = self.split_clusters[idx]
|
||||
|
||||
if self.no_randomness:
|
||||
ligand_name = sorted(self.cluster_to_ligands[cluster])[0]
|
||||
else:
|
||||
ligand_name = random.choice(self.cluster_to_ligands[cluster])
|
||||
|
||||
complex_graph = self.get_by_name(ligand_name, cluster)
|
||||
|
||||
if self.total_dataset_size is not None:
|
||||
complex_graph = Batch.from_data_list([complex_graph])
|
||||
|
||||
return complex_graph
|
||||
|
||||
def get_all_complexes(self):
|
||||
complexes = {}
|
||||
for cluster in self.split_clusters:
|
||||
for ligand_name in self.cluster_to_ligands[cluster]:
|
||||
complexes[ligand_name] = self.get_by_name(ligand_name, cluster)
|
||||
return complexes
|
||||
|
||||
def preprocessing_receptors(self):
|
||||
print(f'Processing receptors from [{self.split}] and saving it to [{self.prot_cache_path}]')
|
||||
|
||||
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
|
||||
receptor_names_all = [l[:6] for l in complex_names_all]
|
||||
receptor_names_all = sorted(list(dict.fromkeys(receptor_names_all)))
|
||||
print(f'Loading {len(receptor_names_all)} receptors.')
|
||||
|
||||
if self.esm_embeddings_path is not None:
|
||||
id_to_embeddings = torch.load(self.esm_embeddings_path)
|
||||
sequences_list = read_strings_from_txt(self.esm_embeddings_sequences_path)
|
||||
sequences_to_embeddings = {}
|
||||
for i, seq in enumerate(sequences_list):
|
||||
sequences_to_embeddings[seq] = id_to_embeddings[str(i)]
|
||||
|
||||
else:
|
||||
sequences_to_embeddings = None
|
||||
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
list_indices = list(range(len(receptor_names_all)//1000+1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
if os.path.exists(os.path.join(self.prot_cache_path, f"receptors{i}.pkl")):
|
||||
continue
|
||||
receptor_names = receptor_names_all[1000*i:1000*(i+1)]
|
||||
receptor_graphs = []
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(receptor_names), desc=f'loading receptors {i}/{len(receptor_names_all)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_receptor, zip(receptor_names, [sequences_to_embeddings]*len(receptor_names))):
|
||||
if t is not None:
|
||||
print(len(receptor_graphs))
|
||||
receptor_graphs.append(t)
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
print('Number of receptors: ', len(receptor_graphs))
|
||||
with open(os.path.join(self.prot_cache_path, f"receptors{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((receptor_graphs), f)
|
||||
return receptor_names_all
|
||||
|
||||
def check_all_receptors(self):
|
||||
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
receptor_names_all = [l[:6] for l in complex_names_all]
|
||||
receptor_names_all = list(dict.fromkeys(receptor_names_all))
|
||||
for i in range(len(receptor_names_all)//1000+1):
|
||||
if not os.path.exists(os.path.join(self.prot_cache_path, f"receptors{i}.pkl")):
|
||||
return False
|
||||
return True
|
||||
|
||||
def collect_receptors(self, receptors_to_keep=None, max_receptor_size=None, remove_promiscuous_targets=None):
|
||||
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
receptor_names_all = [l[:6] for l in complex_names_all]
|
||||
receptor_names_all = sorted(list(dict.fromkeys(receptor_names_all)))
|
||||
|
||||
receptor_graphs_all = []
|
||||
total_recovered = 0
|
||||
print(f'Loading {len(receptor_names_all)} receptors to keep {len(receptors_to_keep)}.')
|
||||
for i in range(len(receptor_names_all)//1000+1):
|
||||
print(f'prot path: {os.path.join(self.prot_cache_path, f"receptors{i}.pkl")}')
|
||||
with open(os.path.join(self.prot_cache_path, f"receptors{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
total_recovered += len(l)
|
||||
if receptors_to_keep is not None:
|
||||
l = [t for t in l if t['receptor_name'] in receptors_to_keep]
|
||||
receptor_graphs_all.extend(l)
|
||||
|
||||
cur_len = len(receptor_graphs_all)
|
||||
print(f"Kept {len(receptor_graphs_all)} receptors out of {len(receptor_names_all)} total and recovered {total_recovered}")
|
||||
|
||||
if max_receptor_size is not None:
|
||||
receptor_graphs_all = [rec for rec in receptor_graphs_all if rec["receptor"].pos.shape[0] <= max_receptor_size]
|
||||
print(f"Kept {len(receptor_graphs_all)} receptors out of {cur_len} after filtering by size")
|
||||
cur_len = len(receptor_graphs_all)
|
||||
|
||||
if remove_promiscuous_targets is not None:
|
||||
promiscuous_targets = set()
|
||||
for name in complex_names_all:
|
||||
l = name.split('_')
|
||||
if int(l[3]) > remove_promiscuous_targets:
|
||||
promiscuous_targets.add(name[:6])
|
||||
receptor_graphs_all = [rec for rec in receptor_graphs_all if rec["receptor_name"] not in promiscuous_targets]
|
||||
print(f"Kept {len(receptor_graphs_all)} receptors out of {cur_len} after removing promiscuous targets")
|
||||
|
||||
self.receptors = {}
|
||||
for r in receptor_graphs_all:
|
||||
self.receptors[r['receptor_name']] = r
|
||||
return
|
||||
|
||||
def get_receptor(self, par):
|
||||
name, sequences_to_embeddings = par
|
||||
rec_path = os.path.join(self.moad_dir, 'pdb_protein', name + '_protein.pdb')
|
||||
if not os.path.exists(rec_path):
|
||||
print("Receptor not found", name, rec_path)
|
||||
return None
|
||||
|
||||
complex_graph = HeteroData()
|
||||
complex_graph['receptor_name'] = name
|
||||
try:
|
||||
moad_extract_receptor_structure(path=rec_path, complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
|
||||
max_neighbors=self.c_alpha_max_neighbors, sequences_to_embeddings=sequences_to_embeddings,
|
||||
knn_only_graph=self.knn_only_graph, all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
|
||||
atom_max_neighbors=self.atom_max_neighbors)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
print(e)
|
||||
return None
|
||||
|
||||
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
complex_graph['receptor'].pos -= protein_center
|
||||
if self.all_atoms:
|
||||
complex_graph['atom'].pos -= protein_center
|
||||
complex_graph.original_center = protein_center
|
||||
return complex_graph
|
||||
|
||||
|
||||
def preprocessing_ligands(self):
|
||||
print(f'Processing complexes from [{self.split}] and saving it to [{self.lig_cache_path}]')
|
||||
|
||||
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
print(f'Loading {len(complex_names_all)} ligands.')
|
||||
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
list_indices = list(range(len(complex_names_all)//1000+1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
if os.path.exists(os.path.join(self.lig_cache_path, f"ligands{i}.pkl")):
|
||||
continue
|
||||
complex_names = complex_names_all[1000*i:1000*(i+1)]
|
||||
ligand_graphs, rdkit_ligands = [], []
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_ligand, complex_names):
|
||||
if t is not None:
|
||||
ligand_graphs.append(t[0])
|
||||
rdkit_ligands.append(t[1])
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.lig_cache_path, f"ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((ligand_graphs), f)
|
||||
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands), f)
|
||||
|
||||
ligand_graphs_all = []
|
||||
for i in range(len(complex_names_all)//1000+1):
|
||||
with open(os.path.join(self.lig_cache_path, f"ligands{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
ligand_graphs_all.extend(l)
|
||||
with open(os.path.join(self.lig_cache_path, f"ligands.pkl"), 'wb') as f:
|
||||
pickle.dump((ligand_graphs_all), f)
|
||||
|
||||
rdkit_ligands_all = []
|
||||
for i in range(len(complex_names_all) // 1000 + 1):
|
||||
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
rdkit_ligands_all.extend(l)
|
||||
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands_all), f)
|
||||
|
||||
def get_ligand(self, name):
|
||||
if self.split == 'train':
|
||||
lig_path = os.path.join(self.moad_dir, 'pdb_superligand', name + '.pdb')
|
||||
else:
|
||||
lig_path = os.path.join(self.moad_dir, 'pdb_ligand', name + '.pdb')
|
||||
|
||||
if not os.path.exists(lig_path):
|
||||
print("Ligand not found", name, lig_path)
|
||||
return None
|
||||
|
||||
# read pickle
|
||||
lig = Chem.MolFromPDBFile(lig_path)
|
||||
|
||||
if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
|
||||
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
|
||||
return None
|
||||
|
||||
try:
|
||||
if self.matching:
|
||||
smile = Chem.MolToSmiles(lig)
|
||||
if '.' in smile:
|
||||
print(f'Ligand {name} has multiple fragments and we are doing matching. Not including {name} in preprocessed data.')
|
||||
return None
|
||||
|
||||
complex_graph = HeteroData()
|
||||
complex_graph['name'] = name
|
||||
|
||||
Chem.SanitizeMol(lig)
|
||||
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
|
||||
self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries, skip_matching=self.skip_matching)
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
print(e)
|
||||
return None
|
||||
|
||||
if self.split != 'train':
|
||||
other_positions = [complex_graph['ligand'].orig_pos]
|
||||
nsplit = name.split('_')
|
||||
for i in range(100):
|
||||
new_file = os.path.join(self.moad_dir, 'pdb_ligand', f'{nsplit[0]}_{nsplit[1]}_{nsplit[2]}_{i}.pdb')
|
||||
if os.path.exists(new_file):
|
||||
if i != int(nsplit[3]):
|
||||
lig = Chem.MolFromPDBFile(new_file)
|
||||
lig = RemoveHs(lig, sanitize=True)
|
||||
other_positions.append(lig.GetConformer().GetPositions())
|
||||
else:
|
||||
break
|
||||
complex_graph['ligand'].orig_pos = np.asarray(other_positions)
|
||||
|
||||
return complex_graph, lig
|
||||
|
||||
|
||||
def print_statistics(dataset):
|
||||
statistics = ([], [], [], [], [], [])
|
||||
receptor_sizes = []
|
||||
|
||||
for i in range(len(dataset)):
|
||||
complex_graph = dataset[i]
|
||||
lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
|
||||
receptor_sizes.append(complex_graph['receptor'].pos.shape[0])
|
||||
radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
|
||||
molecule_center = torch.mean(lig_pos, dim=0)
|
||||
radius_molecule = torch.max(
|
||||
torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1))
|
||||
distance_center = torch.linalg.vector_norm(molecule_center)
|
||||
statistics[0].append(radius_protein)
|
||||
statistics[1].append(radius_molecule)
|
||||
statistics[2].append(distance_center)
|
||||
if "rmsd_matching" in complex_graph:
|
||||
statistics[3].append(complex_graph.rmsd_matching)
|
||||
else:
|
||||
statistics[3].append(0)
|
||||
statistics[4].append(int(complex_graph.random_coords) if "random_coords" in complex_graph else -1)
|
||||
if "random_coords" in complex_graph and complex_graph.random_coords and "rmsd_matching" in complex_graph:
|
||||
statistics[5].append(complex_graph.rmsd_matching)
|
||||
|
||||
if len(statistics[5]) == 0:
|
||||
statistics[5].append(-1)
|
||||
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
|
||||
print('Number of complexes: ', len(dataset))
|
||||
for i in range(len(name)):
|
||||
array = np.asarray(statistics[i])
|
||||
print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
|
||||
|
||||
return
|
||||
|
||||
146
datasets/parse_chi.py
Normal file
146
datasets/parse_chi.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# From Nick Polizzi
|
||||
import numpy as np
|
||||
from collections import defaultdict
|
||||
import prody as pr
|
||||
import os
|
||||
|
||||
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short
|
||||
|
||||
|
||||
def get_dihedral_indices(resname, chi_num):
|
||||
"""Return the atom indices for the specified dihedral angle.
|
||||
"""
|
||||
if resname not in chi:
|
||||
return np.array([np.nan]*4)
|
||||
if chi_num not in chi[resname]:
|
||||
return np.array([np.nan]*4)
|
||||
return np.array([atom_order[resname].index(x) for x in chi[resname][chi_num]])
|
||||
|
||||
|
||||
dihedral_indices = defaultdict(list)
|
||||
for aa in atom_order.keys():
|
||||
for i in range(1, 5):
|
||||
inds = get_dihedral_indices(aa, i)
|
||||
dihedral_indices[aa].append(inds)
|
||||
dihedral_indices[aa] = np.array(dihedral_indices[aa])
|
||||
|
||||
|
||||
def vector_batch(a, b):
|
||||
return a - b
|
||||
|
||||
|
||||
def unit_vector_batch(v):
|
||||
return v / np.linalg.norm(v, axis=1, keepdims=True)
|
||||
|
||||
|
||||
def dihedral_angle_batch(p):
|
||||
b0 = vector_batch(p[:, 0], p[:, 1])
|
||||
b1 = vector_batch(p[:, 1], p[:, 2])
|
||||
b2 = vector_batch(p[:, 2], p[:, 3])
|
||||
|
||||
n1 = np.cross(b0, b1)
|
||||
n2 = np.cross(b1, b2)
|
||||
|
||||
m1 = np.cross(n1, b1 / np.linalg.norm(b1, axis=1, keepdims=True))
|
||||
|
||||
x = np.sum(n1 * n2, axis=1)
|
||||
y = np.sum(m1 * n2, axis=1)
|
||||
|
||||
deg = np.degrees(np.arctan2(y, x))
|
||||
|
||||
deg[deg < 0] += 360
|
||||
|
||||
return deg
|
||||
|
||||
|
||||
def batch_compute_dihedral_angles(sidechains):
|
||||
sidechains_np = np.array(sidechains)
|
||||
dihedral_angles = dihedral_angle_batch(sidechains_np)
|
||||
return dihedral_angles
|
||||
|
||||
|
||||
def get_coords(prody_pdb):
|
||||
resindices = sorted(set(prody_pdb.ca.getResindices()))
|
||||
coords = np.zeros((len(resindices), 14, 3))
|
||||
for i, resind in enumerate(resindices):
|
||||
sel = prody_pdb.select(f'resindex {resind}')
|
||||
resname = sel.getResnames()[0]
|
||||
for j, name in enumerate(atom_order[aa_long2short[resname] if resname in aa_long2short else 'X']):
|
||||
sel_resnum_name = sel.select(f'name {name}')
|
||||
if sel_resnum_name is not None:
|
||||
coords[i, j, :] = sel_resnum_name.getCoords()[0]
|
||||
else:
|
||||
coords[i, j, :] = [np.nan, np.nan, np.nan]
|
||||
return coords
|
||||
|
||||
|
||||
def get_onehot_sequence(seq):
|
||||
onehot = np.zeros((len(seq), 20))
|
||||
for i, aa in enumerate(seq):
|
||||
idx = aa_short2aa_idx[aa] if aa in aa_short2aa_idx else 7 # 7 is the index for GLY
|
||||
onehot[i, idx] = 1
|
||||
return onehot
|
||||
|
||||
|
||||
def get_dihedral_indices(onehot_sequence):
|
||||
return np.array([dihedral_indices[aa_idx2aa_short[aa_idx]] for aa_idx in np.where(onehot_sequence)[1]])
|
||||
|
||||
|
||||
def _get_chi_angles(coords, indices):
|
||||
X = coords
|
||||
Y = indices.astype(int)
|
||||
N = coords.shape[0]
|
||||
mask = np.isnan(indices)
|
||||
Y[mask] = 0
|
||||
Z = X[np.arange(N)[:, None, None], Y, :]
|
||||
Z[mask] = np.nan
|
||||
chi_angles = batch_compute_dihedral_angles(Z.reshape(-1, 4, 3)).reshape(N, 4)
|
||||
return chi_angles
|
||||
|
||||
|
||||
def get_chi_angles(coords, seq, return_onehot=False):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
prody_pdb : prody.AtomGroup
|
||||
prody pdb object or selection
|
||||
return_coords : bool, optional
|
||||
return coordinates of prody_pdb in (N, 14, 3) array format, by default False
|
||||
return_onehot : bool, optional
|
||||
return one-hot sequence of prody_pdb, by default False
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy array of shape (N, 4)
|
||||
Array contains chi angles of sidechains in row-order of residue indices in prody_pdb.
|
||||
If a chi angle is not defined for a residue, due to missing atoms or GLY / ALA, it is set to np.nan.
|
||||
"""
|
||||
onehot = get_onehot_sequence(seq)
|
||||
dihedral_indices = get_dihedral_indices(onehot)
|
||||
if return_onehot:
|
||||
return _get_chi_angles(coords, dihedral_indices), onehot
|
||||
return _get_chi_angles(coords, dihedral_indices)
|
||||
|
||||
|
||||
def test_get_chi_angles(print_chi_angles=False):
|
||||
# need internet connection of '6w70.pdb' in working directory
|
||||
pdb = pr.parsePDB('6w70')
|
||||
prody_pdb = pdb.select('chain A')
|
||||
chi_angles = get_chi_angles(prody_pdb)
|
||||
assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
|
||||
assert chi_angles[0,0] < 56.0 and chi_angles[0,0] > 55.0
|
||||
print('test_get_chi_angles passed')
|
||||
try:
|
||||
os.remove('6w70.pdb.gz')
|
||||
except:
|
||||
pass
|
||||
if print_chi_angles:
|
||||
print(chi_angles)
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_get_chi_angles(print_chi_angles=True)
|
||||
|
||||
|
||||
536
datasets/pdb.py
Normal file
536
datasets/pdb.py
Normal file
@@ -0,0 +1,536 @@
|
||||
# Significant contribution from Ben Fry
|
||||
|
||||
import copy
|
||||
import os.path
|
||||
import pickle
|
||||
import random
|
||||
from multiprocessing import Pool
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import AllChem, MolFromSmiles
|
||||
from scipy.spatial.distance import pdist, squareform
|
||||
from torch_geometric.data import Dataset, HeteroData
|
||||
from torch_geometric.utils import subgraph
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.constants import aa_to_cg_indices, amino_acid_smiles, cg_rdkit_indices
|
||||
from datasets.parse_chi import aa_long2short, atom_order
|
||||
from datasets.process_mols import new_extract_receptor_structure, get_lig_graph, generate_conformer
|
||||
from utils.torsion import get_transformation_mask
|
||||
|
||||
|
||||
def read_strings_from_txt(path):
|
||||
# every line will be one element of the returned list
|
||||
with open(path) as file:
|
||||
lines = file.readlines()
|
||||
return [line.rstrip() for line in lines]
|
||||
|
||||
|
||||
def compute_num_ca_neighbors(coords, cg_coords, idx, is_valid_bb_node, max_dist=5, buffer_residue_num=7):
|
||||
"""
|
||||
Counts number of residues with heavy atoms within max_dist (Angstroms) of this sidechain that are not
|
||||
residues within +/- buffer_residue_num in primary sequence.
|
||||
From Ben's code
|
||||
Note: Gabriele removed the chain_index
|
||||
"""
|
||||
|
||||
# Extract coordinates of all residues in the protein.
|
||||
bb_coords = coords
|
||||
|
||||
# Compute the indices that we should not consider interactions.
|
||||
excluded_neighbors = [idx - x for x in reversed(range(0, buffer_residue_num+1)) if (idx - x) >= 0]
|
||||
excluded_neighbors.extend([idx + x for x in range(1, buffer_residue_num+1)])
|
||||
|
||||
# Create indices of an N x M distance matrix where N is num BB nodes and M is num CG nodes.
|
||||
e_idx = torch.stack([
|
||||
torch.arange(bb_coords.shape[0]).unsqueeze(-1).expand((-1, cg_coords.shape[0])).flatten(),
|
||||
torch.arange(cg_coords.shape[0]).unsqueeze(0).expand((bb_coords.shape[0], -1)).flatten()
|
||||
])
|
||||
|
||||
# Expand bb_coords and cg_coords into the same dimensionality.
|
||||
bb_coords_exp = bb_coords[e_idx[0]]
|
||||
cg_coords_exp = cg_coords[e_idx[1]].unsqueeze(1)
|
||||
|
||||
# Every row is distance of chemical group to each atom in backbone coordinate frame.
|
||||
bb_exp_idces, _ = (torch.cdist(bb_coords_exp, cg_coords_exp).squeeze(-1) < max_dist).nonzero(as_tuple=True)
|
||||
bb_idces_within_thresh = torch.unique(e_idx[0][bb_exp_idces])
|
||||
|
||||
# Only count residues that are not adjacent or origin in primary sequence and are valid backbone residues (fully resolved coordinate frame).
|
||||
bb_idces_within_thresh = bb_idces_within_thresh[~torch.isin(bb_idces_within_thresh, torch.tensor(excluded_neighbors)) & is_valid_bb_node[bb_idces_within_thresh]]
|
||||
|
||||
return len(bb_idces_within_thresh)
|
||||
|
||||
|
||||
def identify_valid_vandermers(args):
|
||||
"""
|
||||
Constructs a tensor containing all the number of contacts for each residue that can be sampled from for chemical groups.
|
||||
By using every sidechain as a chemical group, we will load the actual chemical groups at training time.
|
||||
These can be used to sample as probabilities once divided by the sum.
|
||||
"""
|
||||
complex_graph, max_dist, buffer_residue_num = args
|
||||
|
||||
# Constructs a mask tracking whether index is a valid coordinate frame / residue label to train over.
|
||||
#is_in_residue_vocabulary = torch.tensor([x in aa_short2long for x in data['seq']]).bool()
|
||||
coords, seq = complex_graph.coords, complex_graph.seq
|
||||
is_valid_bb_node = (coords[:, :4].isnan().sum(dim=(1,2)) == 0).bool() #* is_in_residue_vocabulary
|
||||
|
||||
valid_cg_idces = []
|
||||
for idx, aa in enumerate(seq):
|
||||
|
||||
if aa not in aa_to_cg_indices:
|
||||
valid_cg_idces.append(0)
|
||||
else:
|
||||
indices = aa_to_cg_indices[aa]
|
||||
cg_coordinates = coords[idx][indices]
|
||||
|
||||
# remove chemical group residues that aren't fully resolved.
|
||||
if torch.any(cg_coordinates.isnan()).item():
|
||||
valid_cg_idces.append(0)
|
||||
continue
|
||||
|
||||
nbr_count = compute_num_ca_neighbors(coords, cg_coordinates, idx, is_valid_bb_node,
|
||||
max_dist=max_dist, buffer_residue_num=buffer_residue_num)
|
||||
valid_cg_idces.append(nbr_count)
|
||||
|
||||
return complex_graph.name, torch.tensor(valid_cg_idces)
|
||||
|
||||
|
||||
def fast_identify_valid_vandermers(coords, seq, max_dist=5, buffer_residue_num=7):
|
||||
|
||||
offset = 10000 + max_dist
|
||||
R = coords.shape[0]
|
||||
|
||||
coords = coords.numpy().reshape(-1, 3)
|
||||
pdist_mat = squareform(pdist(coords))
|
||||
pdist_mat = pdist_mat.reshape((R, 14, R, 14))
|
||||
pdist_mat = np.nan_to_num(pdist_mat, nan=offset)
|
||||
pdist_mat = np.min(pdist_mat, axis=(1, 3))
|
||||
|
||||
# compute pairwise distances
|
||||
pdist_mat = pdist_mat + np.diag(np.ones(len(seq)) * offset)
|
||||
for i in range(1, buffer_residue_num+1):
|
||||
pdist_mat += np.diag(np.ones(len(seq)-i) * offset, k=i) + np.diag(np.ones(len(seq)-i) * offset, k=-i)
|
||||
|
||||
# get number of residues that are within max_dist of each other
|
||||
nbr_count = np.sum(pdist_mat < max_dist, axis=1)
|
||||
return torch.tensor(nbr_count)
|
||||
|
||||
|
||||
def compute_cg_features(aa, aa_smile):
|
||||
"""
|
||||
Given an amino acid and a smiles string returns the stacked tensor of chemical group atom encodings.
|
||||
The order of the output tensor rows corresponds to the index the atoms appear in aa_to_cg_indices from constants.
|
||||
"""
|
||||
|
||||
# Handle any residues that we don't have chemical groups for (ex: GLY if not using bb_cnh and bb_cco)
|
||||
aa_short = aa_long2short[aa]
|
||||
if aa_short not in aa_to_cg_indices:
|
||||
return None
|
||||
|
||||
# Create rdkit molecule from smiles string.
|
||||
mol = Chem.MolFromSmiles(aa_smile)
|
||||
|
||||
complex_graph = HeteroData()
|
||||
get_lig_graph(mol, complex_graph)
|
||||
|
||||
atoms_to_keep = torch.tensor([i for i, _ in cg_rdkit_indices[aa].items()]).long()
|
||||
complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr = \
|
||||
subgraph(atoms_to_keep, complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr, relabel_nodes=True)
|
||||
complex_graph['ligand'].x = complex_graph['ligand'].x[atoms_to_keep]
|
||||
|
||||
edge_mask, mask_rotate = get_transformation_mask(complex_graph)
|
||||
complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
|
||||
complex_graph['ligand'].mask_rotate = mask_rotate
|
||||
return complex_graph
|
||||
|
||||
|
||||
class PDBSidechain(Dataset):
|
||||
def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0,
|
||||
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, remove_hs=True, all_atoms=False,
|
||||
atom_radius=5, atom_max_neighbors=None, sequences_to_embeddings=None,
|
||||
knn_only_graph=True, multiplicity=1, vandermers_max_dist=5, vandermers_buffer_residue_num=7,
|
||||
vandermers_min_contacts=5, remove_second_segment=False, merge_clusters=1, vandermers_extraction=True,
|
||||
add_random_ligand=False):
|
||||
|
||||
super(PDBSidechain, self).__init__(root, transform)
|
||||
assert remove_hs == True, "not implemented yet"
|
||||
self.root = root
|
||||
self.split = split
|
||||
self.limit_complexes = limit_complexes
|
||||
self.receptor_radius = receptor_radius
|
||||
self.knn_only_graph = knn_only_graph
|
||||
self.multiplicity = multiplicity
|
||||
self.c_alpha_max_neighbors = c_alpha_max_neighbors
|
||||
self.num_workers = num_workers
|
||||
self.sequences_to_embeddings = sequences_to_embeddings
|
||||
self.remove_second_segment = remove_second_segment
|
||||
self.merge_clusters = merge_clusters
|
||||
self.vandermers_extraction = vandermers_extraction
|
||||
self.add_random_ligand = add_random_ligand
|
||||
self.all_atoms = all_atoms
|
||||
self.atom_radius = atom_radius
|
||||
self.atom_max_neighbors = atom_max_neighbors
|
||||
|
||||
if vandermers_extraction:
|
||||
self.cg_node_feature_lookup_dict = {aa_long2short[aa]: compute_cg_features(aa, aa_smile) for aa, aa_smile in
|
||||
amino_acid_smiles.items()}
|
||||
|
||||
self.cache_path = os.path.join(cache_path, f'PDB3_limit{self.limit_complexes}_INDEX{self.split}'
|
||||
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
|
||||
+ (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
|
||||
+ ('' if not self.knn_only_graph else '_knnOnly'))
|
||||
self.read_split()
|
||||
|
||||
if not self.check_all_proteins():
|
||||
os.makedirs(self.cache_path, exist_ok=True)
|
||||
self.preprocess()
|
||||
|
||||
self.vandermers_max_dist = vandermers_max_dist
|
||||
self.vandermers_buffer_residue_num = vandermers_buffer_residue_num
|
||||
self.vandermers_min_contacts = vandermers_min_contacts
|
||||
self.collect_proteins()
|
||||
|
||||
filtered_proteins = []
|
||||
if vandermers_extraction:
|
||||
for complex_graph in tqdm(self.protein_graphs):
|
||||
if complex_graph.name in self.vandermers and torch.any(self.vandermers[complex_graph.name] >= 10):
|
||||
filtered_proteins.append(complex_graph)
|
||||
print(f"Computed vandermers and kept {len(filtered_proteins)} proteins out of {len(self.protein_graphs)}")
|
||||
else:
|
||||
filtered_proteins = self.protein_graphs
|
||||
|
||||
second_filter = []
|
||||
for complex_graph in tqdm(filtered_proteins):
|
||||
if sequences_to_embeddings is None or complex_graph.orig_seq in sequences_to_embeddings:
|
||||
second_filter.append(complex_graph)
|
||||
print(f"Checked embeddings available and kept {len(second_filter)} proteins out of {len(filtered_proteins)}")
|
||||
|
||||
self.protein_graphs = second_filter
|
||||
|
||||
# filter clusters that have no protein graphs
|
||||
self.split_clusters = list(set([g.cluster for g in self.protein_graphs]))
|
||||
self.cluster_to_complexes = {c: [] for c in self.split_clusters}
|
||||
for p in self.protein_graphs:
|
||||
self.cluster_to_complexes[p['cluster']].append(p)
|
||||
self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_complexes[c]) > 0]
|
||||
print("Total elements in set", len(self.split_clusters) * self.multiplicity // self.merge_clusters)
|
||||
|
||||
self.name_to_complex = {p.name: p for p in self.protein_graphs}
|
||||
self.define_probabilities()
|
||||
|
||||
if self.add_random_ligand:
|
||||
# read csv with all smiles
|
||||
with open('data/smiles_list.csv', 'r') as f:
|
||||
self.smiles_list = f.readlines()
|
||||
self.smiles_list = [s.split(',')[0] for s in self.smiles_list]
|
||||
|
||||
def define_probabilities(self):
|
||||
if not self.vandermers_extraction:
|
||||
return
|
||||
|
||||
if self.vandermers_min_contacts is not None:
|
||||
self.probabilities = torch.arange(1000) - self.vandermers_min_contacts + 1
|
||||
self.probabilities[:self.vandermers_min_contacts] = 0
|
||||
else:
|
||||
with open('data/pdbbind_counts.pkl', 'rb') as f:
|
||||
pdbbind_counts = pickle.load(f)
|
||||
|
||||
pdb_counts = torch.ones(1000)
|
||||
for contacts in self.vandermers.values():
|
||||
pdb_counts.index_add_(0, contacts, torch.ones(contacts.shape))
|
||||
print(pdbbind_counts[:30])
|
||||
print(pdb_counts[:30])
|
||||
|
||||
self.probabilities = pdbbind_counts / pdb_counts
|
||||
self.probabilities[:7] = 0
|
||||
|
||||
def len(self):
|
||||
return len(self.split_clusters) * self.multiplicity // self.merge_clusters
|
||||
|
||||
def get(self, idx=None, protein=None, smiles=None):
|
||||
assert idx is not None or (protein is not None and smiles is not None), "provide idx or protein or smile"
|
||||
|
||||
if protein is None or smiles is None:
|
||||
idx = idx % len(self.split_clusters)
|
||||
if self.merge_clusters > 1:
|
||||
idx = idx * self.merge_clusters
|
||||
idx = idx + random.randint(0, self.merge_clusters - 1)
|
||||
idx = min(idx, len(self.split_clusters) - 1)
|
||||
cluster = self.split_clusters[idx]
|
||||
protein_graph = copy.deepcopy(random.choice(self.cluster_to_complexes[cluster]))
|
||||
else:
|
||||
protein_graph = copy.deepcopy(self.name_to_complex[protein])
|
||||
|
||||
if self.sequences_to_embeddings is not None:
|
||||
#print(self.sequences_to_embeddings[protein_graph.orig_seq].shape, len(protein_graph.orig_seq), protein_graph.to_keep.shape)
|
||||
if len(protein_graph.orig_seq) != len(self.sequences_to_embeddings[protein_graph.orig_seq]):
|
||||
print('problem with ESM embeddings')
|
||||
return self.get(random.randint(0, self.len()))
|
||||
|
||||
lm_embeddings = self.sequences_to_embeddings[protein_graph.orig_seq][protein_graph.to_keep]
|
||||
protein_graph['receptor'].x = torch.cat([protein_graph['receptor'].x, lm_embeddings], dim=1)
|
||||
|
||||
if self.vandermers_extraction:
|
||||
# select sidechain to remove
|
||||
vandermers_contacts = self.vandermers[protein_graph.name]
|
||||
vandermers_probs = self.probabilities[vandermers_contacts].numpy()
|
||||
|
||||
if not np.any(vandermers_contacts.numpy() >= 10):
|
||||
print('no vandarmers >= 10 retrying with new one')
|
||||
return self.get(random.randint(0, self.len()))
|
||||
|
||||
sidechain_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
|
||||
|
||||
# remove part of the sequence
|
||||
residues_to_keep = np.ones(len(protein_graph.seq), dtype=bool)
|
||||
residues_to_keep[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
|
||||
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
|
||||
|
||||
if self.remove_second_segment:
|
||||
pos_idx = protein_graph['receptor'].pos[sidechain_idx]
|
||||
limit_closeness = 10
|
||||
far_enough = torch.sum((protein_graph['receptor'].pos - pos_idx[None, :]) ** 2, dim=-1) > limit_closeness ** 2
|
||||
vandermers_probs = vandermers_probs * far_enough.float().numpy()
|
||||
vandermers_probs[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
|
||||
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = 0
|
||||
if np.all(vandermers_probs<=0):
|
||||
print('no second vandermer available retrying with new one')
|
||||
return self.get(random.randint(0, self.len()))
|
||||
sc2_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
|
||||
|
||||
residues_to_keep[max(0, sc2_idx - self.vandermers_buffer_residue_num):
|
||||
min(sc2_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
|
||||
|
||||
residues_to_keep = torch.from_numpy(residues_to_keep)
|
||||
protein_graph['receptor'].pos = protein_graph['receptor'].pos[residues_to_keep]
|
||||
protein_graph['receptor'].x = protein_graph['receptor'].x[residues_to_keep]
|
||||
protein_graph['receptor'].side_chain_vecs = protein_graph['receptor'].side_chain_vecs[residues_to_keep]
|
||||
protein_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
|
||||
subgraph(residues_to_keep, protein_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
|
||||
|
||||
# create the sidechain ligand
|
||||
sidechain_aa = protein_graph.seq[sidechain_idx]
|
||||
ligand_graph = self.cg_node_feature_lookup_dict[sidechain_aa]
|
||||
ligand_graph['ligand'].pos = protein_graph.coords[sidechain_idx][protein_graph.mask[sidechain_idx]]
|
||||
|
||||
for type in ligand_graph.node_types + ligand_graph.edge_types:
|
||||
for key, value in ligand_graph[type].items():
|
||||
protein_graph[type][key] = value
|
||||
|
||||
protein_graph['ligand'].orig_pos = protein_graph['ligand'].pos.numpy()
|
||||
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
|
||||
protein_graph['ligand'].pos = protein_graph['ligand'].pos - protein_center
|
||||
protein_graph.original_center = protein_center
|
||||
protein_graph['receptor_name'] = protein_graph.name
|
||||
else:
|
||||
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
|
||||
protein_graph.original_center = protein_center
|
||||
protein_graph['receptor_name'] = protein_graph.name
|
||||
|
||||
if self.add_random_ligand:
|
||||
if smiles is not None:
|
||||
mol = MolFromSmiles(smiles)
|
||||
try:
|
||||
generate_conformer(mol)
|
||||
except Exception as e:
|
||||
print("failed to generate the given ligand returning None", e)
|
||||
return None
|
||||
else:
|
||||
success = False
|
||||
while not success:
|
||||
smiles = random.choice(self.smiles_list)
|
||||
mol = MolFromSmiles(smiles)
|
||||
try:
|
||||
success = not generate_conformer(mol)
|
||||
except Exception as e:
|
||||
print(e, "changing ligand")
|
||||
|
||||
lig_graph = HeteroData()
|
||||
get_lig_graph(mol, lig_graph)
|
||||
|
||||
edge_mask, mask_rotate = get_transformation_mask(lig_graph)
|
||||
lig_graph['ligand'].edge_mask = torch.tensor(edge_mask)
|
||||
lig_graph['ligand'].mask_rotate = mask_rotate
|
||||
lig_graph['ligand'].smiles = smiles
|
||||
lig_graph['ligand'].pos = lig_graph['ligand'].pos - torch.mean(lig_graph['ligand'].pos, dim=0, keepdim=True)
|
||||
|
||||
for type in lig_graph.node_types + lig_graph.edge_types:
|
||||
for key, value in lig_graph[type].items():
|
||||
protein_graph[type][key] = value
|
||||
|
||||
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
|
||||
if hasattr(protein_graph, a):
|
||||
delattr(protein_graph, a)
|
||||
if hasattr(protein_graph['receptor'], a):
|
||||
delattr(protein_graph['receptor'], a)
|
||||
|
||||
return protein_graph
|
||||
|
||||
def read_split(self):
|
||||
# read CSV file
|
||||
df = pd.read_csv(self.root + "/list.csv")
|
||||
print("Loaded list CSV file")
|
||||
|
||||
# get clusters and filter by split
|
||||
if self.split == "train":
|
||||
val_clusters = set(read_strings_from_txt(self.root + "/valid_clusters.txt"))
|
||||
test_clusters = set(read_strings_from_txt(self.root + "/test_clusters.txt"))
|
||||
clusters = df["CLUSTER"].unique()
|
||||
clusters = [int(c) for c in clusters if c not in val_clusters and c not in test_clusters]
|
||||
elif self.split == "val":
|
||||
clusters = [int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt")]
|
||||
elif self.split == "test":
|
||||
clusters = [int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt")]
|
||||
else:
|
||||
raise ValueError("Split must be train, val or test")
|
||||
print(self.split, "clusters", len(clusters))
|
||||
clusters = set(clusters)
|
||||
|
||||
self.chains_in_cluster = []
|
||||
complexes_in_cluster = set()
|
||||
for chain, cluster in zip(df["CHAINID"], df["CLUSTER"]):
|
||||
if cluster not in clusters:
|
||||
continue
|
||||
# limit to one chain per complex
|
||||
if chain[:4] not in complexes_in_cluster:
|
||||
self.chains_in_cluster.append((chain, cluster))
|
||||
complexes_in_cluster.add(chain[:4])
|
||||
print("Filtered chains in cluster", len(self.chains_in_cluster))
|
||||
|
||||
if self.limit_complexes > 0:
|
||||
self.chains_in_cluster = self.chains_in_cluster[:self.limit_complexes]
|
||||
|
||||
def check_all_proteins(self):
|
||||
for i in range(len(self.chains_in_cluster)//10000+1):
|
||||
if not os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
|
||||
return False
|
||||
return True
|
||||
|
||||
def collect_proteins(self):
|
||||
self.protein_graphs = []
|
||||
self.vandermers = {}
|
||||
total_recovered = 0
|
||||
print(f'Loading {len(self.chains_in_cluster)} protein graphs.')
|
||||
list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'rb') as f:
|
||||
print(i)
|
||||
l = pickle.load(f)
|
||||
total_recovered += len(l)
|
||||
self.protein_graphs.extend(l)
|
||||
|
||||
if not self.vandermers_extraction:
|
||||
continue
|
||||
|
||||
if os.path.exists(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl')):
|
||||
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'rb') as f:
|
||||
vandermers = pickle.load(f)
|
||||
self.vandermers.update(vandermers)
|
||||
continue
|
||||
|
||||
vandermers = {}
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(l), desc=f'computing vandermers {i}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
arguments = zip(l, [self.vandermers_max_dist] * len(l),
|
||||
[self.vandermers_buffer_residue_num] * len(l))
|
||||
for t in map_fn(identify_valid_vandermers, arguments):
|
||||
if t is not None:
|
||||
vandermers[t[0]] = t[1]
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'wb') as f:
|
||||
pickle.dump(vandermers, f)
|
||||
self.vandermers.update(vandermers)
|
||||
|
||||
print(f"Kept {len(self.protein_graphs)} proteins out of {len(self.chains_in_cluster)} total")
|
||||
return
|
||||
|
||||
def preprocess(self):
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 10000 proteins
|
||||
list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
if os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
|
||||
continue
|
||||
chains_names = self.chains_in_cluster[10000 * i:10000 * (i + 1)]
|
||||
protein_graphs = []
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(chains_names),
|
||||
desc=f'loading protein batch {i}/{len(self.chains_in_cluster) // 10000 + 1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.load_chain, chains_names):
|
||||
if t is not None:
|
||||
protein_graphs.append(t)
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'wb') as f:
|
||||
pickle.dump(protein_graphs, f)
|
||||
|
||||
print("Finished preprocessing and saving protein graphs")
|
||||
|
||||
def load_chain(self, c):
|
||||
chain, cluster = c
|
||||
if not os.path.exists(self.root + f"/pdb/{chain[1:3]}/{chain}.pt"):
|
||||
print("File not found", chain)
|
||||
return None
|
||||
|
||||
data = torch.load(self.root + f"/pdb/{chain[1:3]}/{chain}.pt")
|
||||
complex_graph = HeteroData()
|
||||
complex_graph['name'] = chain
|
||||
orig_seq = data["seq"]
|
||||
coords = data["xyz"]
|
||||
mask = data["mask"].bool()
|
||||
|
||||
# remove residues with NaN backbone coordinates
|
||||
to_keep = torch.logical_not(torch.any(torch.isnan(coords[:, :4, 0]), dim=1))
|
||||
coords = coords[to_keep]
|
||||
seq = ''.join(np.asarray(list(orig_seq))[to_keep.numpy()].tolist())
|
||||
mask = mask[to_keep]
|
||||
|
||||
if len(coords) == 0:
|
||||
print("All coords were NaN", chain)
|
||||
return None
|
||||
|
||||
try:
|
||||
new_extract_receptor_structure(seq, coords.numpy(), complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
|
||||
max_neighbors=self.c_alpha_max_neighbors, knn_only_graph=self.knn_only_graph,
|
||||
all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
|
||||
atom_max_neighbors=self.atom_max_neighbors)
|
||||
except Exception as e:
|
||||
print("Error in extracting receptor", chain)
|
||||
print(e)
|
||||
return None
|
||||
|
||||
if torch.any(torch.isnan(complex_graph['receptor'].pos)):
|
||||
print("NaN in pos receptor", chain)
|
||||
return None
|
||||
|
||||
complex_graph.coords = coords
|
||||
complex_graph.seq = seq
|
||||
complex_graph.mask = mask
|
||||
complex_graph.cluster = cluster
|
||||
complex_graph.orig_seq = orig_seq
|
||||
complex_graph.to_keep = to_keep
|
||||
return complex_graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataset = PDBSidechain(root="data/pdb_2021aug02_sample", split="train", multiplicity=1, limit_complexes=150)
|
||||
print(len(dataset))
|
||||
print(dataset[0])
|
||||
for p in dataset:
|
||||
print(p)
|
||||
pass
|
||||
@@ -1,125 +1,208 @@
|
||||
import binascii
|
||||
import glob
|
||||
import hashlib
|
||||
import os
|
||||
import pickle
|
||||
from collections import defaultdict
|
||||
from multiprocessing import Pool
|
||||
import random
|
||||
import copy
|
||||
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torch
|
||||
from rdkit.Chem import MolToSmiles, MolFromSmiles, AddHs
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import MolFromSmiles, AddHs
|
||||
from torch_geometric.data import Dataset, HeteroData
|
||||
from torch_geometric.loader import DataLoader, DataListLoader
|
||||
from torch_geometric.transforms import BaseTransform
|
||||
from tqdm import tqdm
|
||||
from rdkit.Chem import RemoveAllHs
|
||||
|
||||
from datasets.process_mols import read_molecule, get_rec_graph, generate_conformer, \
|
||||
get_lig_graph_with_matching, extract_receptor_structure, parse_receptor, parse_pdb_from_path
|
||||
from datasets.process_mols import read_molecule, get_lig_graph_with_matching, generate_conformer, moad_extract_receptor_structure
|
||||
from utils.diffusion_utils import modify_conformer, set_time
|
||||
from utils.utils import read_strings_from_txt
|
||||
from utils.utils import read_strings_from_txt, crop_beyond
|
||||
from utils import so3, torus
|
||||
|
||||
|
||||
class NoiseTransform(BaseTransform):
|
||||
def __init__(self, t_to_sigma, no_torsion, all_atom):
|
||||
def __init__(self, t_to_sigma, no_torsion, all_atom, alpha=1, beta=1,
|
||||
include_miscellaneous_atoms=False, crop_beyond_cutoff=None, time_independent=False, rmsd_cutoff=0,
|
||||
minimum_t=0, sampling_mixing_coeff=0):
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.no_torsion = no_torsion
|
||||
self.all_atom = all_atom
|
||||
self.include_miscellaneous_atoms = include_miscellaneous_atoms
|
||||
self.minimum_t = minimum_t
|
||||
self.mixing_coeff = sampling_mixing_coeff
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.crop_beyond_cutoff = crop_beyond_cutoff
|
||||
self.rmsd_cutoff = rmsd_cutoff
|
||||
self.time_independent = time_independent
|
||||
|
||||
def __call__(self, data):
|
||||
t = np.random.uniform()
|
||||
t_tr, t_rot, t_tor = t, t, t
|
||||
return self.apply_noise(data, t_tr, t_rot, t_tor)
|
||||
t_tr, t_rot, t_tor, t = self.get_time()
|
||||
return self.apply_noise(data, t_tr, t_rot, t_tor, t)
|
||||
|
||||
def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update = None, rot_update=None, torsion_updates=None):
|
||||
def get_time(self):
|
||||
if self.time_independent:
|
||||
t = np.random.beta(self.alpha, self.beta)
|
||||
t_tr, t_rot, t_tor = t,t,t
|
||||
else:
|
||||
t = None
|
||||
if self.mixing_coeff == 0:
|
||||
t = np.random.beta(self.alpha, self.beta)
|
||||
t = self.minimum_t + t * (1 - self.minimum_t)
|
||||
else:
|
||||
choice = np.random.binomial(1, self.mixing_coeff)
|
||||
t1 = np.random.beta(self.alpha, self.beta)
|
||||
t1 = t1 * self.minimum_t
|
||||
t2 = np.random.beta(self.alpha, self.beta)
|
||||
t2 = self.minimum_t + t2 * (1 - self.minimum_t)
|
||||
t = choice * t1 + (1 - choice) * t2
|
||||
|
||||
t_tr, t_rot, t_tor = t,t,t
|
||||
return t_tr, t_rot, t_tor, t
|
||||
|
||||
def apply_noise(self, data, t_tr, t_rot, t_tor, t, tr_update = None, rot_update=None, torsion_updates=None):
|
||||
if not torch.is_tensor(data['ligand'].pos):
|
||||
data['ligand'].pos = random.choice(data['ligand'].pos)
|
||||
|
||||
if self.time_independent:
|
||||
orig_complex_graph = copy.deepcopy(data)
|
||||
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
|
||||
set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
|
||||
|
||||
if self.time_independent:
|
||||
set_time(data, 0, 0, 0, 0, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
|
||||
else:
|
||||
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
|
||||
|
||||
tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update
|
||||
rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
|
||||
torsion_updates = np.random.normal(loc=0.0, scale=tor_sigma, size=data['ligand'].edge_mask.sum()) if torsion_updates is None else torsion_updates
|
||||
torsion_updates = None if self.no_torsion else torsion_updates
|
||||
modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
|
||||
try:
|
||||
modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
|
||||
except Exception as e:
|
||||
print("failed modify conformer")
|
||||
print(e)
|
||||
|
||||
if self.time_independent:
|
||||
if self.no_torsion:
|
||||
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())
|
||||
|
||||
filterHs = torch.not_equal(data['ligand'].x[:, 0], 0).cpu().numpy()
|
||||
if isinstance(orig_complex_graph['ligand'].orig_pos, list):
|
||||
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
|
||||
ligand_pos = data['ligand'].pos.cpu().numpy()[filterHs]
|
||||
orig_ligand_pos = orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy()
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=1).mean(axis=0))
|
||||
data.y = torch.tensor(rmsd < self.rmsd_cutoff).float().unsqueeze(0)
|
||||
data.atom_y = data.y
|
||||
return data
|
||||
|
||||
data.tr_score = -tr_update / tr_sigma ** 2
|
||||
data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0)
|
||||
data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
|
||||
data.tor_sigma_edge = None if self.no_torsion else np.ones(data['ligand'].edge_mask.sum()) * tor_sigma
|
||||
|
||||
if data['ligand'].pos.shape[0] == 1:
|
||||
# if the ligand is a single atom, the rotational score is always 0
|
||||
data.rot_score = data.rot_score * 0
|
||||
|
||||
if self.crop_beyond_cutoff is not None:
|
||||
crop_beyond(data, tr_sigma * 3 + self.crop_beyond_cutoff, self.all_atom)
|
||||
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
|
||||
return data
|
||||
|
||||
|
||||
class PDBBind(Dataset):
|
||||
def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0,
|
||||
def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0, chain_cutoff=10,
|
||||
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
|
||||
matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
|
||||
atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, require_ligand=False,
|
||||
ligands_list=None, protein_path_list=None, ligand_descriptions=None, keep_local_structures=False):
|
||||
include_miscellaneous_atoms=False,
|
||||
protein_path_list=None, ligand_descriptions=None, keep_local_structures=False,
|
||||
protein_file="protein_processed", ligand_file="ligand",
|
||||
knn_only_graph=False, matching_tries=1, dataset='PDBBind'):
|
||||
|
||||
super(PDBBind, self).__init__(root, transform)
|
||||
self.pdbbind_dir = root
|
||||
self.include_miscellaneous_atoms = include_miscellaneous_atoms
|
||||
self.max_lig_size = max_lig_size
|
||||
self.split_path = split_path
|
||||
self.limit_complexes = limit_complexes
|
||||
self.chain_cutoff = chain_cutoff
|
||||
self.receptor_radius = receptor_radius
|
||||
self.num_workers = num_workers
|
||||
self.c_alpha_max_neighbors = c_alpha_max_neighbors
|
||||
self.remove_hs = remove_hs
|
||||
self.esm_embeddings_path = esm_embeddings_path
|
||||
self.use_old_wrong_embedding_order = False
|
||||
self.require_ligand = require_ligand
|
||||
self.protein_path_list = protein_path_list
|
||||
self.ligand_descriptions = ligand_descriptions
|
||||
self.keep_local_structures = keep_local_structures
|
||||
self.protein_file = protein_file
|
||||
self.fixed_knn_radius_graph = True
|
||||
self.knn_only_graph = knn_only_graph
|
||||
self.matching_tries = matching_tries
|
||||
self.ligand_file = ligand_file
|
||||
self.dataset = dataset
|
||||
assert knn_only_graph or (not all_atoms)
|
||||
self.all_atoms = all_atoms
|
||||
if matching or protein_path_list is not None and ligand_descriptions is not None:
|
||||
cache_path += '_torsion'
|
||||
if all_atoms:
|
||||
cache_path += '_allatoms'
|
||||
self.full_cache_path = os.path.join(cache_path, f'limit{self.limit_complexes}'
|
||||
self.full_cache_path = os.path.join(cache_path, f'{dataset}3_limit{self.limit_complexes}'
|
||||
f'_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}'
|
||||
f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
|
||||
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
|
||||
+ ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
|
||||
+ ('' if not matching or num_conformers == 1 else f'_confs{num_conformers}')
|
||||
f'_chainCutoff{self.chain_cutoff if self.chain_cutoff is None else int(self.chain_cutoff)}'
|
||||
+ (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
|
||||
+ (''if not matching or num_conformers == 1 else f'_confs{num_conformers}')
|
||||
+ ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
|
||||
+ '_full'
|
||||
+ ('' if not keep_local_structures else f'_keptLocalStruct')
|
||||
+ ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode()))))
|
||||
+ ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode())))
|
||||
+ ('' if protein_file == "protein_processed" else '_' + protein_file)
|
||||
+ ('' if not self.fixed_knn_radius_graph else (f'_fixedKNN' if not self.knn_only_graph else '_fixedKNNonly'))
|
||||
+ ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
|
||||
+ ('' if self.use_old_wrong_embedding_order else '_chainOrd')
|
||||
+ ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
|
||||
self.popsize, self.maxiter = popsize, maxiter
|
||||
self.matching, self.keep_original = matching, keep_original
|
||||
self.num_conformers = num_conformers
|
||||
self.all_atoms = all_atoms
|
||||
|
||||
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
||||
if not os.path.exists(os.path.join(self.full_cache_path, "heterographs.pkl"))\
|
||||
or (require_ligand and not os.path.exists(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"))):
|
||||
if not self.check_all_complexes():
|
||||
os.makedirs(self.full_cache_path, exist_ok=True)
|
||||
if protein_path_list is None or ligand_descriptions is None:
|
||||
self.preprocessing()
|
||||
else:
|
||||
self.inference_preprocessing()
|
||||
|
||||
print('loading data from memory: ', os.path.join(self.full_cache_path, "heterographs.pkl"))
|
||||
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
|
||||
self.complex_graphs = pickle.load(f)
|
||||
if require_ligand:
|
||||
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
|
||||
self.rdkit_ligands = pickle.load(f)
|
||||
|
||||
self.complex_graphs, self.rdkit_ligands = self.collect_all_complexes()
|
||||
print_statistics(self.complex_graphs)
|
||||
list_names = [complex['name'] for complex in self.complex_graphs]
|
||||
with open(os.path.join(self.full_cache_path, f'pdbbind_{os.path.splitext(os.path.basename(self.split_path))[0][:3]}_names.txt'), 'w') as f:
|
||||
f.write('\n'.join(list_names))
|
||||
|
||||
def len(self):
|
||||
return len(self.complex_graphs)
|
||||
|
||||
def get(self, idx):
|
||||
complex_graph = copy.deepcopy(self.complex_graphs[idx])
|
||||
if self.require_ligand:
|
||||
complex_graph = copy.deepcopy(self.complex_graphs[idx])
|
||||
complex_graph.mol = copy.deepcopy(self.rdkit_ligands[idx])
|
||||
return complex_graph
|
||||
else:
|
||||
return copy.deepcopy(self.complex_graphs[idx])
|
||||
complex_graph.mol = RemoveAllHs(copy.deepcopy(self.rdkit_ligands[idx]))
|
||||
|
||||
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
|
||||
if hasattr(complex_graph, a):
|
||||
delattr(complex_graph, a)
|
||||
if hasattr(complex_graph['receptor'], a):
|
||||
delattr(complex_graph['receptor'], a)
|
||||
|
||||
return complex_graph
|
||||
|
||||
def preprocessing(self):
|
||||
print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')
|
||||
@@ -132,94 +215,63 @@ class PDBBind(Dataset):
|
||||
if self.esm_embeddings_path is not None:
|
||||
id_to_embeddings = torch.load(self.esm_embeddings_path)
|
||||
chain_embeddings_dictlist = defaultdict(list)
|
||||
chain_indices_dictlist = defaultdict(list)
|
||||
for key, embedding in id_to_embeddings.items():
|
||||
key_name = key.split('_')[0]
|
||||
key_name = key.split('_chain_')[0]
|
||||
if key_name in complex_names_all:
|
||||
chain_embeddings_dictlist[key_name].append(embedding)
|
||||
chain_indices_dictlist[key_name].append(int(key.split('_chain_')[1]))
|
||||
lm_embeddings_chains_all = []
|
||||
for name in complex_names_all:
|
||||
lm_embeddings_chains_all.append(chain_embeddings_dictlist[name])
|
||||
complex_chains_embeddings = chain_embeddings_dictlist[name]
|
||||
complex_chains_indices = chain_indices_dictlist[name]
|
||||
chain_reorder_idx = np.argsort(complex_chains_indices)
|
||||
reordered_chains = [complex_chains_embeddings[i] for i in chain_reorder_idx]
|
||||
lm_embeddings_chains_all.append(reordered_chains)
|
||||
else:
|
||||
lm_embeddings_chains_all = [None] * len(complex_names_all)
|
||||
|
||||
if self.num_workers > 1:
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
for i in range(len(complex_names_all)//1000+1):
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
|
||||
continue
|
||||
complex_names = complex_names_all[1000*i:1000*(i+1)]
|
||||
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
|
||||
complex_graphs, rdkit_ligands = [], []
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
|
||||
complex_graphs.extend(t[0])
|
||||
rdkit_ligands.extend(t[1])
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs), f)
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands), f)
|
||||
|
||||
complex_graphs_all = []
|
||||
for i in range(len(complex_names_all)//1000+1):
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
complex_graphs_all.extend(l)
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs_all), f)
|
||||
|
||||
rdkit_ligands_all = []
|
||||
for i in range(len(complex_names_all) // 1000 + 1):
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
rdkit_ligands_all.extend(l)
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands_all), f)
|
||||
else:
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
list_indices = list(range(len(complex_names_all)//1000+1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
|
||||
continue
|
||||
complex_names = complex_names_all[1000*i:1000*(i+1)]
|
||||
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
|
||||
complex_graphs, rdkit_ligands = [], []
|
||||
with tqdm(total=len(complex_names_all), desc='loading complexes') as pbar:
|
||||
for t in map(self.get_complex, zip(complex_names_all, lm_embeddings_chains_all, [None] * len(complex_names_all), [None] * len(complex_names_all))):
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
|
||||
complex_graphs.extend(t[0])
|
||||
rdkit_ligands.extend(t[1])
|
||||
pbar.update()
|
||||
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs), f)
|
||||
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands), f)
|
||||
|
||||
def inference_preprocessing(self):
|
||||
ligands_list = []
|
||||
print('Reading molecules and generating local structures with RDKit (unless --keep_local_structures is turned on).')
|
||||
failed_ligand_indices = []
|
||||
for idx, ligand_description in tqdm(enumerate(self.ligand_descriptions)):
|
||||
try:
|
||||
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
|
||||
if mol is not None:
|
||||
print('Reading molecules and generating local structures with RDKit')
|
||||
for ligand_description in tqdm(self.ligand_descriptions):
|
||||
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
|
||||
if mol is not None:
|
||||
mol = AddHs(mol)
|
||||
generate_conformer(mol)
|
||||
ligands_list.append(mol)
|
||||
else:
|
||||
mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
|
||||
if not self.keep_local_structures:
|
||||
mol.RemoveAllConformers()
|
||||
mol = AddHs(mol)
|
||||
generate_conformer(mol)
|
||||
ligands_list.append(mol)
|
||||
else:
|
||||
mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
|
||||
if mol is None:
|
||||
raise Exception('RDKit could not read the molecule ', ligand_description)
|
||||
if not self.keep_local_structures:
|
||||
mol.RemoveAllConformers()
|
||||
mol = AddHs(mol)
|
||||
generate_conformer(mol)
|
||||
ligands_list.append(mol)
|
||||
except Exception as e:
|
||||
|
||||
print('Failed to read molecule ', ligand_description, ' We are skipping it. The reason is the exception: ', e)
|
||||
failed_ligand_indices.append(idx)
|
||||
for index in sorted(failed_ligand_indices, reverse=True):
|
||||
del self.protein_path_list[index]
|
||||
del self.ligand_descriptions[index]
|
||||
ligands_list.append(mol)
|
||||
|
||||
if self.esm_embeddings_path is not None:
|
||||
print('Reading language model embeddings.')
|
||||
@@ -235,129 +287,144 @@ class PDBBind(Dataset):
|
||||
lm_embeddings_chains_all = [None] * len(self.protein_path_list)
|
||||
|
||||
print('Generating graphs for ligands and proteins')
|
||||
if self.num_workers > 1:
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
for i in range(len(self.protein_path_list)//1000+1):
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
|
||||
continue
|
||||
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
|
||||
ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
|
||||
ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
|
||||
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
|
||||
complex_graphs, rdkit_ligands = [], []
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
|
||||
complex_graphs.extend(t[0])
|
||||
rdkit_ligands.extend(t[1])
|
||||
pbar.update()
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs), f)
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands), f)
|
||||
|
||||
complex_graphs_all = []
|
||||
for i in range(len(self.protein_path_list)//1000+1):
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
complex_graphs_all.extend(l)
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs_all), f)
|
||||
|
||||
rdkit_ligands_all = []
|
||||
for i in range(len(self.protein_path_list) // 1000 + 1):
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
rdkit_ligands_all.extend(l)
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands_all), f)
|
||||
else:
|
||||
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
|
||||
list_indices = list(range(len(self.protein_path_list)//1000+1))
|
||||
random.shuffle(list_indices)
|
||||
for i in list_indices:
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
|
||||
continue
|
||||
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
|
||||
ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
|
||||
ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
|
||||
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
|
||||
complex_graphs, rdkit_ligands = [], []
|
||||
with tqdm(total=len(self.protein_path_list), desc='loading complexes') as pbar:
|
||||
for t in map(self.get_complex, zip(self.protein_path_list, lm_embeddings_chains_all, ligands_list, self.ligand_descriptions)):
|
||||
if self.num_workers > 1:
|
||||
p = Pool(self.num_workers, maxtasksperchild=1)
|
||||
p.__enter__()
|
||||
with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
|
||||
map_fn = p.imap_unordered if self.num_workers > 1 else map
|
||||
for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
|
||||
complex_graphs.extend(t[0])
|
||||
rdkit_ligands.extend(t[1])
|
||||
pbar.update()
|
||||
if complex_graphs == []: raise Exception('Preprocessing did not succeed for any complex')
|
||||
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
|
||||
if self.num_workers > 1: p.__exit__(None, None, None)
|
||||
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((complex_graphs), f)
|
||||
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
|
||||
pickle.dump((rdkit_ligands), f)
|
||||
|
||||
def check_all_complexes(self):
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
|
||||
return True
|
||||
|
||||
complex_names_all = read_strings_from_txt(self.split_path)
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
for i in range(len(complex_names_all) // 1000 + 1):
|
||||
if not os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
|
||||
return False
|
||||
return True
|
||||
|
||||
def collect_all_complexes(self):
|
||||
print('Collecting all complexes from cache', self.full_cache_path)
|
||||
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
|
||||
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
|
||||
complex_graphs = pickle.load(f)
|
||||
if self.require_ligand:
|
||||
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
|
||||
rdkit_ligands = pickle.load(f)
|
||||
else:
|
||||
rdkit_ligands = None
|
||||
return complex_graphs, rdkit_ligands
|
||||
|
||||
complex_names_all = read_strings_from_txt(self.split_path)
|
||||
if self.limit_complexes is not None and self.limit_complexes != 0:
|
||||
complex_names_all = complex_names_all[:self.limit_complexes]
|
||||
complex_graphs_all = []
|
||||
for i in range(len(complex_names_all) // 1000 + 1):
|
||||
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
|
||||
print(i)
|
||||
l = pickle.load(f)
|
||||
complex_graphs_all.extend(l)
|
||||
|
||||
rdkit_ligands_all = []
|
||||
for i in range(len(complex_names_all) // 1000 + 1):
|
||||
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
|
||||
l = pickle.load(f)
|
||||
rdkit_ligands_all.extend(l)
|
||||
|
||||
return complex_graphs_all, rdkit_ligands_all
|
||||
|
||||
def get_complex(self, par):
|
||||
name, lm_embedding_chains, ligand, ligand_description = par
|
||||
if not os.path.exists(os.path.join(self.pdbbind_dir, name)) and ligand is None:
|
||||
print("Folder not found", name)
|
||||
return [], []
|
||||
|
||||
if ligand is not None:
|
||||
rec_model = parse_pdb_from_path(name)
|
||||
name = f'{name}____{ligand_description}'
|
||||
ligs = [ligand]
|
||||
else:
|
||||
try:
|
||||
rec_model = parse_receptor(name, self.pdbbind_dir)
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
print(e)
|
||||
try:
|
||||
|
||||
lig = read_mol(self.pdbbind_dir, name, suffix=self.ligand_file, remove_hs=False)
|
||||
if self.max_lig_size != None and lig.GetNumHeavyAtoms() > self.max_lig_size:
|
||||
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
|
||||
return [], []
|
||||
|
||||
ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
|
||||
complex_graphs = []
|
||||
failed_indices = []
|
||||
for i, lig in enumerate(ligs):
|
||||
if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
|
||||
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
|
||||
continue
|
||||
complex_graph = HeteroData()
|
||||
complex_graph['name'] = name
|
||||
try:
|
||||
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
|
||||
self.num_conformers, remove_hs=self.remove_hs)
|
||||
rec, rec_coords, c_alpha_coords, n_coords, c_coords, lm_embeddings = extract_receptor_structure(copy.deepcopy(rec_model), lig, lm_embedding_chains=lm_embedding_chains)
|
||||
if lm_embeddings is not None and len(c_alpha_coords) != len(lm_embeddings):
|
||||
print(f'LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}.')
|
||||
failed_indices.append(i)
|
||||
continue
|
||||
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
|
||||
self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries)
|
||||
|
||||
get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius=self.receptor_radius,
|
||||
c_alpha_max_neighbors=self.c_alpha_max_neighbors, all_atoms=self.all_atoms,
|
||||
atom_radius=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors, remove_hs=self.remove_hs, lm_embeddings=lm_embeddings)
|
||||
moad_extract_receptor_structure(path=os.path.join(self.pdbbind_dir, name, f'{name}_{self.protein_file}.pdb'),
|
||||
complex_graph=complex_graph,
|
||||
neighbor_cutoff=self.receptor_radius,
|
||||
max_neighbors=self.c_alpha_max_neighbors,
|
||||
lm_embeddings=lm_embedding_chains,
|
||||
knn_only_graph=self.knn_only_graph,
|
||||
all_atoms=self.all_atoms,
|
||||
atom_cutoff=self.atom_radius,
|
||||
atom_max_neighbors=self.atom_max_neighbors)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
print(e)
|
||||
failed_indices.append(i)
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
print(e)
|
||||
return [], []
|
||||
|
||||
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
complex_graph['receptor'].pos -= protein_center
|
||||
if self.all_atoms:
|
||||
complex_graph['atom'].pos -= protein_center
|
||||
if self.dataset == 'posebusters':
|
||||
other_positions = []
|
||||
all_mol_file = os.path.join(self.pdbbind_dir, name, f'{name}_ligands.sdf')
|
||||
supplier = Chem.SDMolSupplier(all_mol_file, sanitize=False, removeHs=False)
|
||||
for mol in supplier:
|
||||
Chem.SanitizeMol(mol)
|
||||
all_mol = RemoveAllHs(mol)
|
||||
for conf in all_mol.GetConformers():
|
||||
other_positions.append(conf.GetPositions())
|
||||
|
||||
if (not self.matching) or self.num_conformers == 1:
|
||||
complex_graph['ligand'].pos -= protein_center
|
||||
else:
|
||||
for p in complex_graph['ligand'].pos:
|
||||
p -= protein_center
|
||||
print(f'Found {len(other_positions)} alternative poses for {name}')
|
||||
complex_graph['ligand'].orig_pos = np.asarray(other_positions)
|
||||
|
||||
complex_graph.original_center = protein_center
|
||||
complex_graphs.append(complex_graph)
|
||||
for idx_to_delete in sorted(failed_indices, reverse=True):
|
||||
del ligs[idx_to_delete]
|
||||
return complex_graphs, ligs
|
||||
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
|
||||
complex_graph['receptor'].pos -= protein_center
|
||||
if self.all_atoms:
|
||||
complex_graph['atom'].pos -= protein_center
|
||||
|
||||
if (not self.matching) or self.num_conformers == 1:
|
||||
complex_graph['ligand'].pos -= protein_center
|
||||
else:
|
||||
for p in complex_graph['ligand'].pos:
|
||||
p -= protein_center
|
||||
|
||||
complex_graph.original_center = protein_center
|
||||
complex_graph['receptor_name'] = name
|
||||
return [complex_graph], [lig]
|
||||
|
||||
|
||||
def print_statistics(complex_graphs):
|
||||
statistics = ([], [], [], [])
|
||||
statistics = ([], [], [], [], [], [])
|
||||
receptor_sizes = []
|
||||
|
||||
for complex_graph in complex_graphs:
|
||||
lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
|
||||
receptor_sizes.append(complex_graph['receptor'].pos.shape[0])
|
||||
radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
|
||||
molecule_center = torch.mean(lig_pos, dim=0)
|
||||
radius_molecule = torch.max(
|
||||
@@ -370,43 +437,25 @@ def print_statistics(complex_graphs):
|
||||
statistics[3].append(complex_graph.rmsd_matching)
|
||||
else:
|
||||
statistics[3].append(0)
|
||||
statistics[4].append(int(complex_graph.random_coords) if "random_coords" in complex_graph else -1)
|
||||
if "random_coords" in complex_graph and complex_graph.random_coords and "rmsd_matching" in complex_graph:
|
||||
statistics[5].append(complex_graph.rmsd_matching)
|
||||
|
||||
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching']
|
||||
if len(statistics[5]) == 0:
|
||||
statistics[5].append(-1)
|
||||
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
|
||||
print('Number of complexes: ', len(complex_graphs))
|
||||
for i in range(4):
|
||||
for i in range(len(name)):
|
||||
array = np.asarray(statistics[i])
|
||||
print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
|
||||
|
||||
|
||||
def construct_loader(args, t_to_sigma):
|
||||
transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
|
||||
all_atom=args.all_atoms)
|
||||
|
||||
common_args = {'transform': transform, 'root': args.data_dir, 'limit_complexes': args.limit_complexes,
|
||||
'receptor_radius': args.receptor_radius,
|
||||
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
|
||||
'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
|
||||
'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
|
||||
'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
|
||||
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
|
||||
'esm_embeddings_path': args.esm_embeddings_path}
|
||||
|
||||
train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
|
||||
num_conformers=args.num_conformers, **common_args)
|
||||
val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True, **common_args)
|
||||
|
||||
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
|
||||
train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
|
||||
val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
|
||||
|
||||
return train_loader, val_loader
|
||||
return
|
||||
|
||||
|
||||
def read_mol(pdbbind_dir, name, remove_hs=False):
|
||||
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.sdf'), remove_hs=remove_hs, sanitize=True)
|
||||
def read_mol(pdbbind_dir, name, suffix='ligand', remove_hs=False):
|
||||
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.sdf'), remove_hs=remove_hs, sanitize=True)
|
||||
if lig is None: # read mol2 file if sdf file cannot be sanitized
|
||||
print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
|
||||
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.mol2'), remove_hs=remove_hs, sanitize=True)
|
||||
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.mol2'), remove_hs=remove_hs, sanitize=True)
|
||||
return lig
|
||||
|
||||
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
import os
|
||||
from argparse import FileType, ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
from Bio.PDB import PDBParser
|
||||
from Bio.Seq import Seq
|
||||
from Bio.SeqRecord import SeqRecord
|
||||
from tqdm import tqdm
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--chain_cutoff', type=int, default=10, help='')
|
||||
parser.add_argument('--out_file', type=str, default="data/pdbbind_sequences.fasta")
|
||||
args = parser.parse_args()
|
||||
|
||||
cutoff = args.chain_cutoff
|
||||
data_dir = args.data_dir
|
||||
names = os.listdir(data_dir)
|
||||
#%%
|
||||
from Bio import SeqIO
|
||||
biopython_parser = PDBParser()
|
||||
|
||||
three_to_one = {'ALA': 'A',
|
||||
'ARG': 'R',
|
||||
'ASN': 'N',
|
||||
'ASP': 'D',
|
||||
'CYS': 'C',
|
||||
'GLN': 'Q',
|
||||
'GLU': 'E',
|
||||
'GLY': 'G',
|
||||
'HIS': 'H',
|
||||
'ILE': 'I',
|
||||
'LEU': 'L',
|
||||
'LYS': 'K',
|
||||
'MET': 'M',
|
||||
'MSE': 'M', # this is almost the same AA as MET. The sulfur is just replaced by Selen
|
||||
'PHE': 'F',
|
||||
'PRO': 'P',
|
||||
'PYL': 'O',
|
||||
'SER': 'S',
|
||||
'SEC': 'U',
|
||||
'THR': 'T',
|
||||
'TRP': 'W',
|
||||
'TYR': 'Y',
|
||||
'VAL': 'V',
|
||||
'ASX': 'B',
|
||||
'GLX': 'Z',
|
||||
'XAA': 'X',
|
||||
'XLE': 'J'}
|
||||
|
||||
sequences = []
|
||||
ids = []
|
||||
for name in tqdm(names):
|
||||
if name == '.DS_Store': continue
|
||||
if os.path.exists(os.path.join(data_dir, name, f'{name}_protein_processed.pdb')):
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein_processed.pdb')
|
||||
elif os.path.exists(os.path.join(data_dir, name, f'{name}_protein.pdb')):
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
|
||||
else:
|
||||
continue
|
||||
if cutoff > 10:
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein_obabel_reduce.pdb')
|
||||
if not os.path.exists(rec_path):
|
||||
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
|
||||
structure = biopython_parser.get_structure('random_id', rec_path)
|
||||
structure = structure[0]
|
||||
for i, chain in enumerate(structure):
|
||||
seq = ''
|
||||
for res_idx, residue in enumerate(chain):
|
||||
if residue.get_resname() == 'HOH':
|
||||
continue
|
||||
residue_coords = []
|
||||
c_alpha, n, c = None, None, None
|
||||
for atom in residue:
|
||||
if atom.name == 'CA':
|
||||
c_alpha = list(atom.get_vector())
|
||||
if atom.name == 'N':
|
||||
n = list(atom.get_vector())
|
||||
if atom.name == 'C':
|
||||
c = list(atom.get_vector())
|
||||
if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid and not
|
||||
try:
|
||||
seq += three_to_one[residue.get_resname()]
|
||||
except Exception as e:
|
||||
seq += '-'
|
||||
print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', name, '. Replacing it with a dash - .')
|
||||
sequences.append(seq)
|
||||
ids.append(f'{name}_chain_{i}')
|
||||
records = []
|
||||
for (index, seq) in zip(ids,sequences):
|
||||
record = SeqRecord(Seq(seq), str(index))
|
||||
record.description = ''
|
||||
records.append(record)
|
||||
SeqIO.write(records, args.out_file, "fasta")
|
||||
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
import copy
|
||||
import os
|
||||
import warnings
|
||||
|
||||
import numpy as np
|
||||
import scipy.spatial as spa
|
||||
import torch
|
||||
from Bio.PDB import PDBParser
|
||||
from Bio.PDB.PDBExceptions import PDBConstructionWarning
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem.rdchem import BondType as BT
|
||||
from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs
|
||||
from rdkit.Geometry import Point3D
|
||||
from scipy import spatial
|
||||
from scipy.special import softmax
|
||||
from torch_cluster import radius_graph
|
||||
|
||||
from torch import cdist
|
||||
from torch_cluster import knn_graph
|
||||
import prody as pr
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from datasets.conformer_matching import get_torsion_angles, optimize_rotatable_bonds
|
||||
from datasets.constants import aa_short2long, atom_order, three_to_one
|
||||
from datasets.parse_chi import get_chi_angles, get_coords, aa_idx2aa_short, get_onehot_sequence
|
||||
from utils.torsion import get_transformation_mask
|
||||
|
||||
|
||||
biopython_parser = PDBParser()
|
||||
periodic_table = GetPeriodicTable()
|
||||
allowable_features = {
|
||||
'possible_atomic_num_list': list(range(1, 119)) + ['misc'],
|
||||
@@ -94,9 +90,13 @@ def lig_atom_featurizer(mol):
|
||||
ringinfo = mol.GetRingInfo()
|
||||
atom_features_list = []
|
||||
for idx, atom in enumerate(mol.GetAtoms()):
|
||||
chiral_tag = str(atom.GetChiralTag())
|
||||
if chiral_tag in ['CHI_SQUAREPLANAR', 'CHI_TRIGONALBIPYRAMIDAL', 'CHI_OCTAHEDRAL']:
|
||||
chiral_tag = 'CHI_OTHER'
|
||||
|
||||
atom_features_list.append([
|
||||
safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
|
||||
allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())),
|
||||
allowable_features['possible_chirality_list'].index(str(chiral_tag)),
|
||||
safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
|
||||
safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()),
|
||||
safe_index(allowable_features['possible_implicit_valence_list'], atom.GetImplicitValence()),
|
||||
@@ -111,18 +111,11 @@ def lig_atom_featurizer(mol):
|
||||
allowable_features['possible_is_in_ring6_list'].index(ringinfo.IsAtomInRingOfSize(idx, 6)),
|
||||
allowable_features['possible_is_in_ring7_list'].index(ringinfo.IsAtomInRingOfSize(idx, 7)),
|
||||
allowable_features['possible_is_in_ring8_list'].index(ringinfo.IsAtomInRingOfSize(idx, 8)),
|
||||
#g_charge if not np.isnan(g_charge) and not np.isinf(g_charge) else 0.
|
||||
])
|
||||
|
||||
return torch.tensor(atom_features_list)
|
||||
|
||||
|
||||
def rec_residue_featurizer(rec):
|
||||
feature_list = []
|
||||
for residue in rec.get_residues():
|
||||
feature_list.append([safe_index(allowable_features['possible_amino_acids'], residue.get_resname())])
|
||||
return torch.tensor(feature_list, dtype=torch.float32) # (N_res, 1)
|
||||
|
||||
|
||||
def safe_index(l, e):
|
||||
""" Return index of element e in list l. If e is not present, return the last index """
|
||||
try:
|
||||
@@ -131,122 +124,158 @@ def safe_index(l, e):
|
||||
return len(l) - 1
|
||||
|
||||
|
||||
def moad_extract_receptor_structure(path, complex_graph, neighbor_cutoff=20, max_neighbors=None, sequences_to_embeddings=None,
|
||||
knn_only_graph=False, lm_embeddings=None, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
|
||||
# load the entire pdb file
|
||||
pdb = pr.parsePDB(path)
|
||||
seq = pdb.ca.getSequence()
|
||||
coords = get_coords(pdb)
|
||||
one_hot = get_onehot_sequence(seq)
|
||||
|
||||
def parse_receptor(pdbid, pdbbind_dir):
|
||||
rec = parsePDB(pdbid, pdbbind_dir)
|
||||
return rec
|
||||
chain_ids = np.zeros(len(one_hot))
|
||||
res_chain_ids = pdb.ca.getChids()
|
||||
res_seg_ids = pdb.ca.getSegnames()
|
||||
res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
|
||||
ids = np.unique(res_chain_ids)
|
||||
sequences = []
|
||||
lm_embeddings = lm_embeddings if sequences_to_embeddings is None else []
|
||||
|
||||
for i, id in enumerate(ids):
|
||||
chain_ids[res_chain_ids == id] = i
|
||||
|
||||
s = np.argmax(one_hot[res_chain_ids == id], axis=1)
|
||||
s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s])
|
||||
sequences.append(s)
|
||||
if sequences_to_embeddings is not None:
|
||||
lm_embeddings.append(sequences_to_embeddings[s])
|
||||
|
||||
complex_graph['receptor'].sequence = sequences
|
||||
complex_graph['receptor'].chain_ids = torch.from_numpy(np.asarray(chain_ids)).long()
|
||||
|
||||
new_extract_receptor_structure(seq, coords, complex_graph, neighbor_cutoff=neighbor_cutoff, max_neighbors=max_neighbors,
|
||||
lm_embeddings=lm_embeddings, knn_only_graph=knn_only_graph, all_atoms=all_atoms,
|
||||
atom_cutoff=atom_cutoff, atom_max_neighbors=atom_max_neighbors)
|
||||
|
||||
|
||||
def parsePDB(pdbid, pdbbind_dir):
|
||||
rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_processed.pdb')
|
||||
return parse_pdb_from_path(rec_path)
|
||||
def new_extract_receptor_structure(seq, all_coords, complex_graph, neighbor_cutoff=20, max_neighbors=None, lm_embeddings=None,
|
||||
knn_only_graph=False, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
|
||||
chi_angles, one_hot = get_chi_angles(all_coords, seq, return_onehot=True)
|
||||
n_rel_pos, c_rel_pos = all_coords[:, 0, :] - all_coords[:, 1, :], all_coords[:, 2, :] - all_coords[:, 1, :]
|
||||
side_chain_vecs = torch.from_numpy(np.concatenate([chi_angles / 360, n_rel_pos, c_rel_pos], axis=1))
|
||||
|
||||
def parse_pdb_from_path(path):
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore", category=PDBConstructionWarning)
|
||||
structure = biopython_parser.get_structure('random_id', path)
|
||||
rec = structure[0]
|
||||
return rec
|
||||
# Build the k-NN graph
|
||||
coords = torch.tensor(all_coords[:, 1, :], dtype=torch.float)
|
||||
if len(coords) > 3000:
|
||||
raise ValueError(f'The receptor is too large {len(coords)}')
|
||||
if knn_only_graph:
|
||||
edge_index = knn_graph(coords, k=max_neighbors if max_neighbors else 32)
|
||||
else:
|
||||
distances = cdist(coords, coords)
|
||||
src_list = []
|
||||
dst_list = []
|
||||
for i in range(len(coords)):
|
||||
dst = list(np.where(distances[i, :] < neighbor_cutoff)[0])
|
||||
dst.remove(i)
|
||||
max_neighbors = max_neighbors if max_neighbors else 1000
|
||||
if max_neighbors != None and len(dst) > max_neighbors:
|
||||
dst = list(np.argsort(distances[i, :]))[1: max_neighbors + 1]
|
||||
if len(dst) == 0:
|
||||
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
|
||||
print(
|
||||
f'The cutoff {neighbor_cutoff} was too small for one atom such that it had no neighbors. '
|
||||
f'So we connected it to the closest other atom')
|
||||
assert i not in dst
|
||||
src = [i] * len(dst)
|
||||
src_list.extend(src)
|
||||
dst_list.extend(dst)
|
||||
edge_index = torch.from_numpy(np.asarray([dst_list, src_list]))
|
||||
|
||||
res_names_list = [aa_short2long[seq[i]] if seq[i] in aa_short2long else 'misc' for i in range(len(seq))]
|
||||
feature_list = [[safe_index(allowable_features['possible_amino_acids'], res)] for res in res_names_list]
|
||||
node_feat = torch.tensor(feature_list, dtype=torch.float32)
|
||||
|
||||
lm_embeddings = torch.tensor(np.concatenate(lm_embeddings, axis=0)) if lm_embeddings is not None else None
|
||||
complex_graph['receptor'].x = torch.cat([node_feat, lm_embeddings], axis=1) if lm_embeddings is not None else node_feat
|
||||
complex_graph['receptor'].pos = coords
|
||||
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
|
||||
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = edge_index
|
||||
if all_atoms:
|
||||
atom_coords = all_coords.reshape(-1, 3)
|
||||
atom_coords = torch.from_numpy(atom_coords[~np.any(np.isnan(atom_coords), axis=1)]).float()
|
||||
|
||||
if knn_only_graph:
|
||||
atoms_edge_index = knn_graph(atom_coords, k=atom_max_neighbors if atom_max_neighbors else 1000)
|
||||
else:
|
||||
atoms_distances = cdist(atom_coords, atom_coords)
|
||||
atom_src_list = []
|
||||
atom_dst_list = []
|
||||
for i in range(len(atom_coords)):
|
||||
dst = list(np.where(atoms_distances[i, :] < atom_cutoff)[0])
|
||||
dst.remove(i)
|
||||
max_neighbors = atom_max_neighbors if atom_max_neighbors else 1000
|
||||
if max_neighbors != None and len(dst) > max_neighbors:
|
||||
dst = list(np.argsort(atoms_distances[i, :]))[1: max_neighbors + 1]
|
||||
if len(dst) == 0:
|
||||
dst = list(np.argsort(atoms_distances[i, :]))[1:2] # choose second because first is i itself
|
||||
print(
|
||||
f'The atom_cutoff {atom_cutoff} was too small for one atom such that it had no neighbors. '
|
||||
f'So we connected it to the closest other atom')
|
||||
assert i not in dst
|
||||
src = [i] * len(dst)
|
||||
atom_src_list.extend(src)
|
||||
atom_dst_list.extend(dst)
|
||||
atoms_edge_index = torch.from_numpy(np.asarray([atom_dst_list, atom_src_list]))
|
||||
|
||||
feats = [get_moad_atom_feats(res, all_coords[i]) for i, res in enumerate(seq)]
|
||||
atom_feat = torch.from_numpy(np.concatenate(feats, axis=0)).float()
|
||||
c_alpha_idx = np.concatenate([np.zeros(len(f)) + i for i, f in enumerate(feats)])
|
||||
np_array = np.stack([np.arange(len(atom_feat)), c_alpha_idx])
|
||||
atom_res_edge_index = torch.from_numpy(np_array).long()
|
||||
complex_graph['atom'].x = atom_feat
|
||||
complex_graph['atom'].pos = atom_coords
|
||||
assert len(complex_graph['atom'].x) == len(complex_graph['atom'].pos)
|
||||
complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
|
||||
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
|
||||
|
||||
return
|
||||
|
||||
|
||||
def extract_receptor_structure(rec, lig, lm_embedding_chains=None):
|
||||
conf = lig.GetConformer()
|
||||
lig_coords = conf.GetPositions()
|
||||
min_distances = []
|
||||
coords = []
|
||||
c_alpha_coords = []
|
||||
n_coords = []
|
||||
c_coords = []
|
||||
valid_chain_ids = []
|
||||
lengths = []
|
||||
for i, chain in enumerate(rec):
|
||||
chain_coords = [] # num_residues, num_atoms, 3
|
||||
chain_c_alpha_coords = []
|
||||
chain_n_coords = []
|
||||
chain_c_coords = []
|
||||
count = 0
|
||||
invalid_res_ids = []
|
||||
for res_idx, residue in enumerate(chain):
|
||||
if residue.get_resname() == 'HOH':
|
||||
invalid_res_ids.append(residue.get_id())
|
||||
continue
|
||||
residue_coords = []
|
||||
c_alpha, n, c = None, None, None
|
||||
for atom in residue:
|
||||
if atom.name == 'CA':
|
||||
c_alpha = list(atom.get_vector())
|
||||
if atom.name == 'N':
|
||||
n = list(atom.get_vector())
|
||||
if atom.name == 'C':
|
||||
c = list(atom.get_vector())
|
||||
residue_coords.append(list(atom.get_vector()))
|
||||
|
||||
if c_alpha != None and n != None and c != None:
|
||||
# only append residue if it is an amino acid and not some weird molecule that is part of the complex
|
||||
chain_c_alpha_coords.append(c_alpha)
|
||||
chain_n_coords.append(n)
|
||||
chain_c_coords.append(c)
|
||||
chain_coords.append(np.array(residue_coords))
|
||||
count += 1
|
||||
def get_moad_atom_feats(res, coords):
|
||||
feats = []
|
||||
res_long = aa_short2long[res]
|
||||
res_order = atom_order[res]
|
||||
for i, c in enumerate(coords):
|
||||
if np.any(np.isnan(c)):
|
||||
continue
|
||||
atom_feats = []
|
||||
if res == '-':
|
||||
atom_feats = [safe_index(allowable_features['possible_amino_acids'], 'misc'),
|
||||
safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_3'], 'misc')]
|
||||
else:
|
||||
atom_feats.append(safe_index(allowable_features['possible_amino_acids'], res_long))
|
||||
if i >= len(res_order):
|
||||
atom_feats.extend([safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_3'], 'misc')])
|
||||
else:
|
||||
invalid_res_ids.append(residue.get_id())
|
||||
for res_id in invalid_res_ids:
|
||||
chain.detach_child(res_id)
|
||||
if len(chain_coords) > 0:
|
||||
all_chain_coords = np.concatenate(chain_coords, axis=0)
|
||||
distances = spatial.distance.cdist(lig_coords, all_chain_coords)
|
||||
min_distance = distances.min()
|
||||
else:
|
||||
min_distance = np.inf
|
||||
atom_name = res_order[i]
|
||||
try:
|
||||
atomic_num = periodic_table.GetAtomicNumber(atom_name[:1])
|
||||
except:
|
||||
print("element", res_order[i][:1], 'not found')
|
||||
atomic_num = -1
|
||||
|
||||
min_distances.append(min_distance)
|
||||
lengths.append(count)
|
||||
coords.append(chain_coords)
|
||||
c_alpha_coords.append(np.array(chain_c_alpha_coords))
|
||||
n_coords.append(np.array(chain_n_coords))
|
||||
c_coords.append(np.array(chain_c_coords))
|
||||
if not count == 0: valid_chain_ids.append(chain.get_id())
|
||||
|
||||
min_distances = np.array(min_distances)
|
||||
if len(valid_chain_ids) == 0:
|
||||
valid_chain_ids.append(np.argmin(min_distances))
|
||||
valid_coords = []
|
||||
valid_c_alpha_coords = []
|
||||
valid_n_coords = []
|
||||
valid_c_coords = []
|
||||
valid_lengths = []
|
||||
invalid_chain_ids = []
|
||||
valid_lm_embeddings = []
|
||||
for i, chain in enumerate(rec):
|
||||
if chain.get_id() in valid_chain_ids:
|
||||
valid_coords.append(coords[i])
|
||||
valid_c_alpha_coords.append(c_alpha_coords[i])
|
||||
if lm_embedding_chains is not None:
|
||||
if i >= len(lm_embedding_chains):
|
||||
raise ValueError('Encountered valid chain id that was not present in the LM embeddings')
|
||||
valid_lm_embeddings.append(lm_embedding_chains[i])
|
||||
valid_n_coords.append(n_coords[i])
|
||||
valid_c_coords.append(c_coords[i])
|
||||
valid_lengths.append(lengths[i])
|
||||
else:
|
||||
invalid_chain_ids.append(chain.get_id())
|
||||
coords = [item for sublist in valid_coords for item in sublist] # list with n_residues arrays: [n_atoms, 3]
|
||||
|
||||
c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0) # [n_residues, 3]
|
||||
n_coords = np.concatenate(valid_n_coords, axis=0) # [n_residues, 3]
|
||||
c_coords = np.concatenate(valid_c_coords, axis=0) # [n_residues, 3]
|
||||
lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None
|
||||
for invalid_id in invalid_chain_ids:
|
||||
rec.detach_child(invalid_id)
|
||||
|
||||
assert len(c_alpha_coords) == len(n_coords)
|
||||
assert len(c_alpha_coords) == len(c_coords)
|
||||
assert sum(valid_lengths) == len(c_alpha_coords)
|
||||
return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings
|
||||
atom_feats.extend([safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
|
||||
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
|
||||
safe_index(allowable_features['possible_atom_type_3'], atom_name)])
|
||||
feats.append(atom_feats)
|
||||
feats = np.asarray(feats)
|
||||
return feats
|
||||
|
||||
|
||||
def get_lig_graph(mol, complex_graph):
|
||||
lig_coords = torch.from_numpy(mol.GetConformer().GetPositions()).float()
|
||||
atom_feats = lig_atom_featurizer(mol)
|
||||
|
||||
row, col, edge_type = [], [], []
|
||||
@@ -261,52 +290,77 @@ def get_lig_graph(mol, complex_graph):
|
||||
edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)
|
||||
|
||||
complex_graph['ligand'].x = atom_feats
|
||||
complex_graph['ligand'].pos = lig_coords
|
||||
complex_graph['ligand', 'lig_bond', 'ligand'].edge_index = edge_index
|
||||
complex_graph['ligand', 'lig_bond', 'ligand'].edge_attr = edge_attr
|
||||
|
||||
if mol.GetNumConformers() > 0:
|
||||
lig_coords = torch.from_numpy(mol.GetConformer().GetPositions()).float()
|
||||
complex_graph['ligand'].pos = lig_coords
|
||||
|
||||
return
|
||||
|
||||
|
||||
def generate_conformer(mol):
|
||||
ps = AllChem.ETKDGv2()
|
||||
id = AllChem.EmbedMolecule(mol, ps)
|
||||
failures, id = 0, -1
|
||||
while failures < 3 and id == -1:
|
||||
if failures > 0:
|
||||
print(f'rdkit coords could not be generated. trying again {failures}.')
|
||||
id = AllChem.EmbedMolecule(mol, ps)
|
||||
failures += 1
|
||||
if id == -1:
|
||||
print('rdkit coords could not be generated without using random coords. using random coords now.')
|
||||
ps.useRandomCoords = True
|
||||
AllChem.EmbedMolecule(mol, ps)
|
||||
AllChem.MMFFOptimizeMolecule(mol, confId=0)
|
||||
# else:
|
||||
# AllChem.MMFFOptimizeMolecule(mol_rdkit, confId=0)
|
||||
return True
|
||||
#else:
|
||||
# AllChem.MMFFOptimizeMolecule(mol, confId=0)
|
||||
return False
|
||||
|
||||
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs):
|
||||
|
||||
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs, tries=10, skip_matching=False):
|
||||
if matching:
|
||||
mol_maybe_noh = copy.deepcopy(mol_)
|
||||
if remove_hs:
|
||||
mol_maybe_noh = RemoveHs(mol_maybe_noh, sanitize=True)
|
||||
mol_maybe_noh = AllChem.RemoveAllHs(mol_maybe_noh)
|
||||
if keep_original:
|
||||
complex_graph['ligand'].orig_pos = mol_maybe_noh.GetConformer().GetPositions()
|
||||
positions = []
|
||||
for conf in mol_maybe_noh.GetConformers():
|
||||
positions.append(conf.GetPositions())
|
||||
complex_graph['ligand'].orig_pos = np.asarray(positions) if len(positions) > 1 else positions[0]
|
||||
|
||||
rotable_bonds = get_torsion_angles(mol_maybe_noh)
|
||||
if not rotable_bonds: print("no_rotable_bonds but still using it")
|
||||
#if not rotable_bonds: print("no_rotable_bonds but still using it")
|
||||
|
||||
for i in range(num_conformers):
|
||||
mol_rdkit = copy.deepcopy(mol_)
|
||||
mols, rmsds = [], []
|
||||
for _ in range(tries):
|
||||
mol_rdkit = copy.deepcopy(mol_)
|
||||
|
||||
mol_rdkit.RemoveAllConformers()
|
||||
mol_rdkit = AllChem.AddHs(mol_rdkit)
|
||||
generate_conformer(mol_rdkit)
|
||||
if remove_hs:
|
||||
mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
|
||||
mol = copy.deepcopy(mol_maybe_noh)
|
||||
if rotable_bonds:
|
||||
optimize_rotatable_bonds(mol_rdkit, mol, rotable_bonds, popsize=popsize, maxiter=maxiter)
|
||||
mol.AddConformer(mol_rdkit.GetConformer())
|
||||
rms_list = []
|
||||
AllChem.AlignMolConformers(mol, RMSlist=rms_list)
|
||||
mol_rdkit.RemoveAllConformers()
|
||||
mol_rdkit.AddConformer(mol.GetConformers()[1])
|
||||
mol_rdkit.RemoveAllConformers()
|
||||
mol_rdkit = AllChem.AddHs(mol_rdkit)
|
||||
generate_conformer(mol_rdkit)
|
||||
if remove_hs:
|
||||
mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
|
||||
mol_rdkit = AllChem.RemoveAllHs(mol_rdkit)
|
||||
mol = AllChem.RemoveAllHs(copy.deepcopy(mol_maybe_noh))
|
||||
if rotable_bonds and not skip_matching:
|
||||
optimize_rotatable_bonds(mol_rdkit, mol, rotable_bonds, popsize=popsize, maxiter=maxiter)
|
||||
mol.AddConformer(mol_rdkit.GetConformer())
|
||||
rms_list = []
|
||||
AllChem.AlignMolConformers(mol, RMSlist=rms_list)
|
||||
mol_rdkit.RemoveAllConformers()
|
||||
mol_rdkit.AddConformer(mol.GetConformers()[1])
|
||||
mols.append(mol_rdkit)
|
||||
rmsds.append(rms_list[0])
|
||||
|
||||
# select molecule with lowest rmsd
|
||||
#print("mean std min max", np.mean(rmsds), np.std(rmsds), np.min(rmsds), np.max(rmsds))
|
||||
mol_rdkit = mols[np.argmin(rmsds)]
|
||||
if i == 0:
|
||||
complex_graph.rmsd_matching = rms_list[0]
|
||||
complex_graph.rmsd_matching = min(rmsds)
|
||||
get_lig_graph(mol_rdkit, complex_graph)
|
||||
else:
|
||||
if torch.is_tensor(complex_graph['ligand'].pos):
|
||||
@@ -325,157 +379,34 @@ def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching,
|
||||
return
|
||||
|
||||
|
||||
def get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, cutoff=20, max_neighbor=None, lm_embeddings=None):
|
||||
n_rel_pos = n_coords - c_alpha_coords
|
||||
c_rel_pos = c_coords - c_alpha_coords
|
||||
num_residues = len(c_alpha_coords)
|
||||
if num_residues <= 1:
|
||||
raise ValueError(f"rec contains only 1 residue!")
|
||||
|
||||
# Build the k-NN graph
|
||||
distances = spa.distance.cdist(c_alpha_coords, c_alpha_coords)
|
||||
src_list = []
|
||||
dst_list = []
|
||||
mean_norm_list = []
|
||||
for i in range(num_residues):
|
||||
dst = list(np.where(distances[i, :] < cutoff)[0])
|
||||
dst.remove(i)
|
||||
if max_neighbor != None and len(dst) > max_neighbor:
|
||||
dst = list(np.argsort(distances[i, :]))[1: max_neighbor + 1]
|
||||
if len(dst) == 0:
|
||||
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
|
||||
print(f'The c_alpha_cutoff {cutoff} was too small for one c_alpha such that it had no neighbors. '
|
||||
f'So we connected it to the closest other c_alpha')
|
||||
assert i not in dst
|
||||
src = [i] * len(dst)
|
||||
src_list.extend(src)
|
||||
dst_list.extend(dst)
|
||||
valid_dist = list(distances[i, dst])
|
||||
valid_dist_np = distances[i, dst]
|
||||
sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
|
||||
weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num)
|
||||
assert weights[0].sum() > 1 - 1e-2 and weights[0].sum() < 1.01
|
||||
diff_vecs = c_alpha_coords[src, :] - c_alpha_coords[dst, :] # (neigh_num, 3)
|
||||
mean_vec = weights.dot(diff_vecs) # (sigma_num, 3)
|
||||
denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,)
|
||||
mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,)
|
||||
mean_norm_list.append(mean_vec_ratio_norm)
|
||||
assert len(src_list) == len(dst_list)
|
||||
|
||||
node_feat = rec_residue_featurizer(rec)
|
||||
mu_r_norm = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
|
||||
side_chain_vecs = torch.from_numpy(
|
||||
np.concatenate([np.expand_dims(n_rel_pos, axis=1), np.expand_dims(c_rel_pos, axis=1)], axis=1))
|
||||
|
||||
complex_graph['receptor'].x = torch.cat([node_feat, torch.tensor(lm_embeddings)], axis=1) if lm_embeddings is not None else node_feat
|
||||
complex_graph['receptor'].pos = torch.from_numpy(c_alpha_coords).float()
|
||||
complex_graph['receptor'].mu_r_norm = mu_r_norm
|
||||
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
|
||||
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = torch.from_numpy(np.asarray([src_list, dst_list]))
|
||||
|
||||
return
|
||||
|
||||
|
||||
def rec_atom_featurizer(rec):
|
||||
atom_feats = []
|
||||
for i, atom in enumerate(rec.get_atoms()):
|
||||
atom_name, element = atom.name, atom.element
|
||||
if element == 'CD':
|
||||
element = 'C'
|
||||
assert not element == ''
|
||||
try:
|
||||
atomic_num = periodic_table.GetAtomicNumber(element)
|
||||
except:
|
||||
atomic_num = -1
|
||||
atom_feat = [safe_index(allowable_features['possible_amino_acids'], atom.get_parent().get_resname()),
|
||||
safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
|
||||
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
|
||||
safe_index(allowable_features['possible_atom_type_3'], atom_name)]
|
||||
atom_feats.append(atom_feat)
|
||||
|
||||
return atom_feats
|
||||
|
||||
|
||||
def get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius, c_alpha_max_neighbors=None, all_atoms=False,
|
||||
atom_radius=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
|
||||
if all_atoms:
|
||||
return get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph,
|
||||
c_alpha_cutoff=rec_radius, c_alpha_max_neighbors=c_alpha_max_neighbors,
|
||||
atom_cutoff=atom_radius, atom_max_neighbors=atom_max_neighbors, remove_hs=remove_hs,lm_embeddings=lm_embeddings)
|
||||
def get_rec_misc_atom_feat(bio_atom=None, atom_name=None, element=None, get_misc_features=False):
|
||||
if get_misc_features:
|
||||
return [safe_index(allowable_features['possible_amino_acids'], 'misc'),
|
||||
safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
|
||||
safe_index(allowable_features['possible_atom_type_3'], 'misc')]
|
||||
if atom_name is not None:
|
||||
atom_name = atom_name
|
||||
else:
|
||||
return get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius, c_alpha_max_neighbors,lm_embeddings=lm_embeddings)
|
||||
atom_name = bio_atom.name
|
||||
if element is not None:
|
||||
element = element
|
||||
else:
|
||||
element = bio_atom.element
|
||||
if element == 'CD':
|
||||
element = 'C'
|
||||
assert not element == ''
|
||||
try:
|
||||
atomic_num = periodic_table.GetAtomicNumber(element.lower().capitalize())
|
||||
except:
|
||||
atomic_num = -1
|
||||
|
||||
atom_feat = [safe_index(allowable_features['possible_amino_acids'], bio_atom.get_parent().get_resname()),
|
||||
safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
|
||||
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
|
||||
safe_index(allowable_features['possible_atom_type_3'], atom_name)]
|
||||
return atom_feat
|
||||
|
||||
def get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, c_alpha_cutoff=20,
|
||||
c_alpha_max_neighbors=None, atom_cutoff=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
|
||||
# builds the receptor graph with both residues and atoms
|
||||
|
||||
n_rel_pos = n_coords - c_alpha_coords
|
||||
c_rel_pos = c_coords - c_alpha_coords
|
||||
num_residues = len(c_alpha_coords)
|
||||
if num_residues <= 1:
|
||||
raise ValueError(f"rec contains only 1 residue!")
|
||||
|
||||
# Build the k-NN graph of residues
|
||||
distances = spa.distance.cdist(c_alpha_coords, c_alpha_coords)
|
||||
src_list = []
|
||||
dst_list = []
|
||||
mean_norm_list = []
|
||||
for i in range(num_residues):
|
||||
dst = list(np.where(distances[i, :] < c_alpha_cutoff)[0])
|
||||
dst.remove(i)
|
||||
if c_alpha_max_neighbors != None and len(dst) > c_alpha_max_neighbors:
|
||||
dst = list(np.argsort(distances[i, :]))[1: c_alpha_max_neighbors + 1]
|
||||
if len(dst) == 0:
|
||||
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
|
||||
print(f'The c_alpha_cutoff {c_alpha_cutoff} was too small for one c_alpha such that it had no neighbors. '
|
||||
f'So we connected it to the closest other c_alpha')
|
||||
assert i not in dst
|
||||
src = [i] * len(dst)
|
||||
src_list.extend(src)
|
||||
dst_list.extend(dst)
|
||||
valid_dist = list(distances[i, dst])
|
||||
valid_dist_np = distances[i, dst]
|
||||
sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
|
||||
weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num)
|
||||
assert 1 - 1e-2 < weights[0].sum() < 1.01
|
||||
diff_vecs = c_alpha_coords[src, :] - c_alpha_coords[dst, :] # (neigh_num, 3)
|
||||
mean_vec = weights.dot(diff_vecs) # (sigma_num, 3)
|
||||
denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,)
|
||||
mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,)
|
||||
mean_norm_list.append(mean_vec_ratio_norm)
|
||||
assert len(src_list) == len(dst_list)
|
||||
|
||||
node_feat = rec_residue_featurizer(rec)
|
||||
mu_r_norm = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
|
||||
side_chain_vecs = torch.from_numpy(
|
||||
np.concatenate([np.expand_dims(n_rel_pos, axis=1), np.expand_dims(c_rel_pos, axis=1)], axis=1))
|
||||
|
||||
complex_graph['receptor'].x = torch.cat([node_feat, torch.tensor(lm_embeddings)], axis=1) if lm_embeddings is not None else node_feat
|
||||
complex_graph['receptor'].pos = torch.from_numpy(c_alpha_coords).float()
|
||||
complex_graph['receptor'].mu_r_norm = mu_r_norm
|
||||
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
|
||||
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = torch.from_numpy(np.asarray([src_list, dst_list]))
|
||||
|
||||
src_c_alpha_idx = np.concatenate([np.asarray([i]*len(l)) for i, l in enumerate(rec_coords)])
|
||||
atom_feat = torch.from_numpy(np.asarray(rec_atom_featurizer(rec)))
|
||||
atom_coords = torch.from_numpy(np.concatenate(rec_coords, axis=0)).float()
|
||||
|
||||
if remove_hs:
|
||||
not_hs = (atom_feat[:, 1] != 0)
|
||||
src_c_alpha_idx = src_c_alpha_idx[not_hs]
|
||||
atom_feat = atom_feat[not_hs]
|
||||
atom_coords = atom_coords[not_hs]
|
||||
|
||||
atoms_edge_index = radius_graph(atom_coords, atom_cutoff, max_num_neighbors=atom_max_neighbors if atom_max_neighbors else 1000)
|
||||
atom_res_edge_index = torch.from_numpy(np.asarray([np.arange(len(atom_feat)), src_c_alpha_idx])).long()
|
||||
|
||||
complex_graph['atom'].x = atom_feat
|
||||
complex_graph['atom'].pos = atom_coords
|
||||
complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
|
||||
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
|
||||
|
||||
return
|
||||
|
||||
def write_mol_with_coords(mol, new_coords, path):
|
||||
w = Chem.SDWriter(path)
|
||||
@@ -502,8 +433,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
|
||||
elif molecule_file.endswith('.pdb'):
|
||||
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
|
||||
else:
|
||||
raise ValueError('Expect the format of the molecule_file to be '
|
||||
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
|
||||
return ValueError('Expect the format of the molecule_file to be '
|
||||
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
|
||||
|
||||
try:
|
||||
if sanitize or calc_charges:
|
||||
@@ -518,30 +449,7 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
|
||||
|
||||
if remove_hs:
|
||||
mol = Chem.RemoveHs(mol, sanitize=sanitize)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
print("RDKit was unable to read the molecule.")
|
||||
except:
|
||||
return None
|
||||
|
||||
return mol
|
||||
|
||||
|
||||
def read_sdf_or_mol2(sdf_fileName, mol2_fileName):
|
||||
|
||||
mol = Chem.MolFromMolFile(sdf_fileName, sanitize=False)
|
||||
problem = False
|
||||
try:
|
||||
Chem.SanitizeMol(mol)
|
||||
mol = Chem.RemoveHs(mol)
|
||||
except Exception as e:
|
||||
problem = True
|
||||
if problem:
|
||||
mol = Chem.MolFromMol2File(mol2_fileName, sanitize=False)
|
||||
try:
|
||||
Chem.SanitizeMol(mol)
|
||||
mol = Chem.RemoveHs(mol)
|
||||
problem = False
|
||||
except Exception as e:
|
||||
problem = True
|
||||
|
||||
return mol, problem
|
||||
|
||||
39
datasets/sidechain_esm_embeddings_to_pt.py
Normal file
39
datasets/sidechain_esm_embeddings_to_pt.py
Normal file
@@ -0,0 +1,39 @@
|
||||
|
||||
import os
|
||||
import pickle
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--esm_embeddings_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new', help='')
|
||||
parser.add_argument('--output_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new.pt', help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
dic = {}
|
||||
|
||||
# read text file with all sequences
|
||||
with open('data/pdb_2021aug02/sequences_to_id.fasta') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# read sequences
|
||||
with open('data/pdb_2021aug02/useful_sequences.pkl', 'rb') as f:
|
||||
sequences = pickle.load(f)
|
||||
|
||||
ids = set()
|
||||
dict_seq_id = {seq[:-1]: str(id) for id, seq in enumerate(lines)}
|
||||
|
||||
for i, seq in tqdm(enumerate(sequences)):
|
||||
ids.add(dict_seq_id[seq])
|
||||
if i == 20000: break
|
||||
|
||||
print("total", len(ids), "out of", len(os.listdir(args.esm_embeddings_path)))
|
||||
|
||||
available = set([filename.split('.')[0] for filename in os.listdir(args.esm_embeddings_path)])
|
||||
final = available.intersection(ids)
|
||||
|
||||
for idp in tqdm(final):
|
||||
dic[idp] = torch.load(os.path.join(args.esm_embeddings_path, idp+'.pt'))['representations'][33]
|
||||
torch.save(dic,args.output_path)
|
||||
225
environment.yml
225
environment.yml
@@ -1,102 +1,201 @@
|
||||
name: diffdock
|
||||
channels:
|
||||
- pytorch
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- blas=1.0
|
||||
- brotlipy=0.7.0
|
||||
- boost=1.74.0
|
||||
- boost-cpp=1.74.0
|
||||
- brotli=1.0.9
|
||||
- brotli-bin=1.0.9
|
||||
- bzip2=1.0.8
|
||||
- ca-certificates=2022.07.19
|
||||
- certifi=2022.9.14
|
||||
- cffi=1.15.1
|
||||
- charset-normalizer=2.0.4
|
||||
- cryptography=37.0.1
|
||||
- ffmpeg=4.3
|
||||
- freetype=2.11.0
|
||||
- gettext=0.21.0
|
||||
- ca-certificates=2022.6.15
|
||||
- cairo=1.16.0
|
||||
- certifi=2022.6.15
|
||||
- cycler=0.11.0
|
||||
- expat=2.4.8
|
||||
- font-ttf-dejavu-sans-mono=2.37
|
||||
- font-ttf-inconsolata=3.000
|
||||
- font-ttf-source-code-pro=2.038
|
||||
- font-ttf-ubuntu=0.83
|
||||
- fontconfig=2.14.0
|
||||
- fonts-conda-ecosystem=1
|
||||
- fonts-conda-forge=1
|
||||
- fonttools=4.33.3
|
||||
- freetype=2.10.4
|
||||
- gettext=0.19.8.1
|
||||
- giflib=5.2.1
|
||||
- gmp=6.2.1
|
||||
- gnutls=3.6.15
|
||||
- icu=58.2
|
||||
- idna=3.3
|
||||
- intel-openmp=2021.4.0
|
||||
- greenlet=1.1.2
|
||||
- icu=70.1
|
||||
- jpeg=9e
|
||||
- lame=3.100
|
||||
- kiwisolver=1.4.3
|
||||
- lcms2=2.12
|
||||
- lerc=3.0
|
||||
- libblas=3.9.0
|
||||
- libbrotlicommon=1.0.9
|
||||
- libbrotlidec=1.0.9
|
||||
- libbrotlienc=1.0.9
|
||||
- libcblas=3.9.0
|
||||
- libcxx=14.0.6
|
||||
- libdeflate=1.8
|
||||
- libffi=3.3
|
||||
- libdeflate=1.12
|
||||
- libffi=3.4.2
|
||||
- libgfortran=5.0.0
|
||||
- libgfortran5=9.3.0
|
||||
- libglib=2.70.2
|
||||
- libiconv=1.16
|
||||
- libidn2=2.3.2
|
||||
- liblapack=3.9.0
|
||||
- libopenblas=0.3.20
|
||||
- libpng=1.6.37
|
||||
- libtasn1=4.16.0
|
||||
- libtiff=4.4.0
|
||||
- libunistring=0.9.10
|
||||
- libwebp=1.2.2
|
||||
- libwebp-base=1.2.2
|
||||
- libxcb=1.13
|
||||
- libxml2=2.9.14
|
||||
- llvm-openmp=14.0.6
|
||||
- libzlib=1.2.12
|
||||
- llvm-openmp=14.0.4
|
||||
- lz4-c=1.9.3
|
||||
- mkl=2021.4.0
|
||||
- mkl-service=2.4.0
|
||||
- mkl_fft=1.3.1
|
||||
- mkl_random=1.2.2
|
||||
- matplotlib=3.5.2
|
||||
- matplotlib-base=3.5.2
|
||||
- munkres=1.1.4
|
||||
- ncurses=6.3
|
||||
- nettle=3.7.3
|
||||
- numpy=1.23.1
|
||||
- numpy-base=1.23.1
|
||||
- openh264=2.1.1
|
||||
- openssl=1.1.1q
|
||||
- pillow=9.2.0
|
||||
- pip=22.2.2
|
||||
- pycparser=2.21
|
||||
- pyopenssl=22.0.0
|
||||
- pysocks=1.7.1
|
||||
- nomkl=3.0
|
||||
- numpy=1.23.0
|
||||
- openbabel=3.1.1
|
||||
- openjpeg=2.4.0
|
||||
- openssl=3.0.5
|
||||
- packaging=21.3
|
||||
- pandas=1.4.3
|
||||
- pcre=8.45
|
||||
- pillow=9.1.1
|
||||
- pip=22.1.2
|
||||
- pixman=0.40.0
|
||||
- pthread-stubs=0.4
|
||||
- pyaml=21.10.1
|
||||
- pycairo=1.21.0
|
||||
- pyparsing=3.0.9
|
||||
- python=3.9.13
|
||||
- pytorch=1.12.1
|
||||
- python-dateutil=2.8.2
|
||||
- python_abi=3.9
|
||||
- pytz=2022.1
|
||||
- pyyaml=6.0
|
||||
- rdkit=2022.03.3
|
||||
- readline=8.1.2
|
||||
- requests=2.28.1
|
||||
- setuptools=63.4.1
|
||||
- setuptools=62.6.0
|
||||
- six=1.16.0
|
||||
- sqlite=3.39.3
|
||||
- spyrmsd=0.5.2
|
||||
- sqlalchemy=1.4.39
|
||||
- sqlite=3.39.0
|
||||
- tk=8.6.12
|
||||
- torchaudio=0.12.1
|
||||
- torchvision=0.13.1
|
||||
- typing_extensions=4.3.0
|
||||
- tzdata=2022c
|
||||
- urllib3=1.26.11
|
||||
- tornado=6.1
|
||||
- tzdata=2022a
|
||||
- unicodedata2=14.0.0
|
||||
- wheel=0.37.1
|
||||
- xz=5.2.6
|
||||
- xorg-libxau=1.0.9
|
||||
- xorg-libxdmcp=1.1.3
|
||||
- xz=5.2.5
|
||||
- yaml=0.2.5
|
||||
- zlib=1.2.12
|
||||
- zstd=1.5.2
|
||||
- pip:
|
||||
- appnope==0.1.3
|
||||
- argon2-cffi==21.3.0
|
||||
- argon2-cffi-bindings==21.2.0
|
||||
- asttokens==2.0.5
|
||||
- attrs==21.4.0
|
||||
- backcall==0.2.0
|
||||
- beautifulsoup4==4.11.1
|
||||
- biopandas==0.4.1
|
||||
- biopython==1.79
|
||||
- bleach==5.0.1
|
||||
- cffi==1.15.1
|
||||
- charset-normalizer==2.1.0
|
||||
- click==8.1.3
|
||||
- debugpy==1.6.2
|
||||
- decorator==5.1.1
|
||||
- defusedxml==0.7.1
|
||||
- docker-pycreds==0.4.0
|
||||
- e3nn==0.5.0
|
||||
- entrypoints==0.4
|
||||
- executing==0.8.3
|
||||
- fastjsonschema==2.15.3
|
||||
- gitdb==4.0.9
|
||||
- gitpython==3.1.27
|
||||
- h5py==3.7.0
|
||||
- idna==3.3
|
||||
- ipykernel==6.15.1
|
||||
- ipython==8.4.0
|
||||
- ipython-genutils==0.2.0
|
||||
- ipywidgets==7.7.1
|
||||
- jedi==0.18.1
|
||||
- jinja2==3.1.2
|
||||
- joblib==1.2.0
|
||||
- joblib==1.1.0
|
||||
- jsonschema==4.7.1
|
||||
- jupyter==1.0.0
|
||||
- jupyter-client==7.3.4
|
||||
- jupyter-console==6.4.4
|
||||
- jupyter-core==4.11.1
|
||||
- jupyterlab-pygments==0.2.2
|
||||
- jupyterlab-widgets==1.1.1
|
||||
- kaleido==0.2.1
|
||||
- markupsafe==2.1.1
|
||||
- matplotlib-inline==0.1.3
|
||||
- mistune==0.8.4
|
||||
- mpmath==1.2.1
|
||||
- networkx==2.8.7
|
||||
- nbclient==0.6.6
|
||||
- nbconvert==6.5.0
|
||||
- nbformat==5.4.0
|
||||
- nest-asyncio==1.5.5
|
||||
- networkx==2.8.4
|
||||
- notebook==6.4.12
|
||||
- opt-einsum==3.3.0
|
||||
- opt-einsum-fx==0.1.4
|
||||
- packaging==21.3
|
||||
- pandas==1.5.0
|
||||
- pyaml==21.10.1
|
||||
- pyparsing==3.0.9
|
||||
- python-dateutil==2.8.2
|
||||
- pytz==2022.4
|
||||
- pyyaml==6.0
|
||||
- rdkit-pypi==2022.3.5
|
||||
- scikit-learn==1.1.2
|
||||
- scipy==1.9.1
|
||||
- spyrmsd==0.5.2
|
||||
- sympy==1.11.1
|
||||
- pandocfilters==1.5.0
|
||||
- parso==0.8.3
|
||||
- pathtools==0.1.2
|
||||
- pexpect==4.8.0
|
||||
- pickleshare==0.7.5
|
||||
- plotly==5.9.0
|
||||
- prometheus-client==0.14.1
|
||||
- promise==2.3
|
||||
- prompt-toolkit==3.0.30
|
||||
- protobuf==3.20.1
|
||||
- psutil==5.9.1
|
||||
- ptyprocess==0.7.0
|
||||
- pure-eval==0.2.2
|
||||
- pycparser==2.21
|
||||
- pygments==2.12.0
|
||||
- pyrsistent==0.18.1
|
||||
- pyzmq==23.2.0
|
||||
- qtconsole==5.3.1
|
||||
- qtpy==2.1.0
|
||||
- requests==2.28.1
|
||||
- scikit-learn==1.1.1
|
||||
- scipy==1.8.1
|
||||
- send2trash==1.8.0
|
||||
- sentry-sdk==1.6.0
|
||||
- setproctitle==1.2.3
|
||||
- shortuuid==1.0.9
|
||||
- smmap==5.0.0
|
||||
- soupsieve==2.3.2.post1
|
||||
- stack-data==0.3.0
|
||||
- sympy==1.10.1
|
||||
- tenacity==8.0.1
|
||||
- terminado==0.15.0
|
||||
- threadpoolctl==3.1.0
|
||||
- tinycss2==1.1.1
|
||||
- torch==1.11.0
|
||||
- torch-cluster==1.6.0
|
||||
- torch-geometric==2.1.0.post1
|
||||
- torch-geometric==2.0.4
|
||||
- torch-scatter==2.0.9
|
||||
- torch-sparse==0.6.15
|
||||
- torch-sparse==0.6.14
|
||||
- torch-spline-conv==1.2.1
|
||||
- tqdm==4.64.1
|
||||
- torchaudio==0.11.0
|
||||
- torchvision==0.12.0
|
||||
- tqdm==4.64.0
|
||||
- traitlets==5.3.0
|
||||
- typing-extensions==4.3.0
|
||||
- urllib3==1.26.9
|
||||
- wandb==0.12.20
|
||||
- wcwidth==0.2.5
|
||||
- webencodings==0.5.1
|
||||
- widgetsnbextension==3.6.1
|
||||
|
||||
1203
evaluate.py
1203
evaluate.py
File diff suppressed because it is too large
Load Diff
@@ -1,361 +0,0 @@
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
|
||||
import pandas as pd
|
||||
import plotly.express as px
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
from utils.utils import read_strings_from_txt
|
||||
|
||||
parser = ArgumentParser()
|
||||
|
||||
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--results_path', type=str, default='inference_out_dir_not_specified/TEST_top40_epoch75_FILTER_restart_cacheNewRestart_big_ema_ESM2emb_tr34_WITH_fixedSamples28_id1_FILTERFROM_temp_restart_ema_ESM2emb_tr34', help='')
|
||||
parser.add_argument('--gnina_results_path', type=str, default='results/gnina_rosetta13', help='')
|
||||
parser.add_argument('--smina_results_path', type=str, default='results/smina_rosetta13', help='')
|
||||
parser.add_argument('--glide_results_path', type=str, default='results/glide', help='')
|
||||
parser.add_argument('--qvinaw_results_path', type=str, default='results/qvinaw', help='')
|
||||
parser.add_argument('--tankbind_results_path', type=str, default='results/tankbind_top5', help='')
|
||||
parser.add_argument('--equibind_results_path', type=str, default='results/equibind_paper', help='')
|
||||
parser.add_argument('--no_rec_overlap', action='store_true', default=False, help='')
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
|
||||
min_cross_distances = np.load(f'{args.results_path}/min_cross_distances.npy')
|
||||
#min_self_distances = np.load(f'{args.results_path}/min_self_distances.npy')
|
||||
base_min_cross_distances = np.load(f'{args.results_path}/base_min_cross_distances.npy')
|
||||
rmsds = np.load(f'{args.results_path}/rmsds.npy')
|
||||
centroid_distances = np.load(f'{args.results_path}/centroid_distances.npy')
|
||||
confidences = np.load(f'{args.results_path}/confidences.npy')
|
||||
#complex_names = np.load(f'{args.results_path}/complex_names.npy')
|
||||
complex_names = read_strings_from_txt('data/splits/timesplit_test')
|
||||
if args.no_rec_overlap:
|
||||
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
|
||||
without_rec_overlap_list = []
|
||||
for name in complex_names:
|
||||
if name in names_no_rec_overlap:
|
||||
without_rec_overlap_list.append(1)
|
||||
else:
|
||||
without_rec_overlap_list.append(0)
|
||||
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
|
||||
rmsds = np.array(rmsds)[without_rec_overlap]
|
||||
#min_self_distances = np.array(min_self_distances)[without_rec_overlap]
|
||||
centroid_distances = np.array(centroid_distances)[without_rec_overlap]
|
||||
confidences = np.array(confidences)[without_rec_overlap]
|
||||
min_cross_distances = np.array(min_cross_distances)[without_rec_overlap]
|
||||
base_min_cross_distances = np.array(base_min_cross_distances)[without_rec_overlap]
|
||||
complex_names = names_no_rec_overlap
|
||||
|
||||
|
||||
|
||||
|
||||
N = rmsds.shape[1]
|
||||
performance_metrics = {
|
||||
'steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / N).__round__(2),
|
||||
'mean_rmsd': rmsds.mean(),
|
||||
'rmsds_below_2': (100 * (rmsds < 2).sum() / len(rmsds) / N),
|
||||
'rmsds_below_5': (100 * (rmsds < 5).sum() / len(rmsds) / N),
|
||||
'rmsds_percentile_25': np.percentile(rmsds, 25).round(2),
|
||||
'rmsds_percentile_50': np.percentile(rmsds, 50).round(2),
|
||||
'rmsds_percentile_75': np.percentile(rmsds, 75).round(2),
|
||||
|
||||
'mean_centroid': centroid_distances.mean().__round__(2),
|
||||
'centroid_below_2': (100 * (centroid_distances < 2).sum() / len(centroid_distances) / N).__round__(2),
|
||||
'centroid_below_5': (100 * (centroid_distances < 5).sum() / len(centroid_distances) / N).__round__(2),
|
||||
'centroid_percentile_25': np.percentile(centroid_distances, 25).round(2),
|
||||
'centroid_percentile_50': np.percentile(centroid_distances, 50).round(2),
|
||||
'centroid_percentile_75': np.percentile(centroid_distances, 75).round(2),
|
||||
}
|
||||
|
||||
if N >= 5:
|
||||
top5_rmsds = np.min(rmsds[:, :5], axis=1)
|
||||
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :5], axis=1)][ :, 0]
|
||||
top5_min_cross_distances = min_cross_distances[ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :5], axis=1)][:, 0]
|
||||
performance_metrics.update({
|
||||
'top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
|
||||
'top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
|
||||
'top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
|
||||
'top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
|
||||
'top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
|
||||
'top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
|
||||
|
||||
'top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
'top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
'top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
|
||||
'top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
|
||||
'top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
if N >= 10:
|
||||
top10_rmsds = np.min(rmsds[:, :10], axis=1)
|
||||
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :10], axis=1)][:, 0]
|
||||
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :10], axis=1)][:, 0]
|
||||
performance_metrics.update({
|
||||
'top10_steric_clash_fraction': (100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
|
||||
'top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
|
||||
'top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
|
||||
'top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
|
||||
'top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
|
||||
'top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
|
||||
|
||||
'top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
'top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
'top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
|
||||
'top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
|
||||
'top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
|
||||
confidence_ordering = np.argsort(confidences,axis=1)[:,::-1]
|
||||
filtered_rmsds = rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,0]
|
||||
filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,0]
|
||||
filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, 0]
|
||||
performance_metrics.update({
|
||||
'filtered_steric_clash_fraction': (100 * (filtered_min_cross_distances < 0.4).sum() / len(filtered_min_cross_distances)).__round__(2),
|
||||
'filtered_rmsds_below_2': (100 * (filtered_rmsds < 2).sum() / len(filtered_rmsds)).__round__(2),
|
||||
'filtered_rmsds_below_5': (100 * (filtered_rmsds < 5).sum() / len(filtered_rmsds)).__round__(2),
|
||||
'filtered_rmsds_percentile_25': np.percentile(filtered_rmsds, 25).round(2),
|
||||
'filtered_rmsds_percentile_50': np.percentile(filtered_rmsds, 50).round(2),
|
||||
'filtered_rmsds_percentile_75': np.percentile(filtered_rmsds, 75).round(2),
|
||||
|
||||
'filtered_centroid_below_2': (100 * (filtered_centroid_distances < 2).sum() / len(filtered_centroid_distances)).__round__(2),
|
||||
'filtered_centroid_below_5': (100 * (filtered_centroid_distances < 5).sum() / len(filtered_centroid_distances)).__round__(2),
|
||||
'filtered_centroid_percentile_25': np.percentile(filtered_centroid_distances, 25).round(2),
|
||||
'filtered_centroid_percentile_50': np.percentile(filtered_centroid_distances, 50).round(2),
|
||||
'filtered_centroid_percentile_75': np.percentile(filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
if N >= 5:
|
||||
top5_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:5], axis=1)
|
||||
top5_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:5][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :5], axis=1)][:, 0]
|
||||
top5_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, :5][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :5], axis=1)][:, 0]
|
||||
performance_metrics.update({
|
||||
'top5_filtered_steric_clash_fraction': (100 * (top5_filtered_min_cross_distances < 0.4).sum() / len(top5_filtered_min_cross_distances)).__round__(2),
|
||||
'top5_filtered_rmsds_below_2': (100 * (top5_filtered_rmsds < 2).sum() / len(top5_filtered_rmsds)).__round__(2),
|
||||
'top5_filtered_rmsds_below_5': (100 * (top5_filtered_rmsds < 5).sum() / len(top5_filtered_rmsds)).__round__(2),
|
||||
'top5_filtered_rmsds_percentile_25': np.percentile(top5_filtered_rmsds, 25).round(2),
|
||||
'top5_filtered_rmsds_percentile_50': np.percentile(top5_filtered_rmsds, 50).round(2),
|
||||
'top5_filtered_rmsds_percentile_75': np.percentile(top5_filtered_rmsds, 75).round(2),
|
||||
|
||||
'top5_filtered_centroid_below_2': (100 * (top5_filtered_centroid_distances < 2).sum() / len(top5_filtered_centroid_distances)).__round__(2),
|
||||
'top5_filtered_centroid_below_5': (100 * (top5_filtered_centroid_distances < 5).sum() / len(top5_filtered_centroid_distances)).__round__(2),
|
||||
'top5_filtered_centroid_percentile_25': np.percentile(top5_filtered_centroid_distances, 25).round(2),
|
||||
'top5_filtered_centroid_percentile_50': np.percentile(top5_filtered_centroid_distances, 50).round(2),
|
||||
'top5_filtered_centroid_percentile_75': np.percentile(top5_filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
if N >= 10:
|
||||
top10_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:10], axis=1)
|
||||
top10_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:10][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :10], axis=1)][:, 0]
|
||||
top10_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, :10][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :10], axis=1)][:, 0]
|
||||
performance_metrics.update({
|
||||
'top10_filtered_steric_clash_fraction': (100 * (top10_filtered_min_cross_distances < 0.4).sum() / len(top10_filtered_min_cross_distances)).__round__(2),
|
||||
'top10_filtered_rmsds_below_2': (100 * (top10_filtered_rmsds < 2).sum() / len(top10_filtered_rmsds)).__round__(2),
|
||||
'top10_filtered_rmsds_below_5': (100 * (top10_filtered_rmsds < 5).sum() / len(top10_filtered_rmsds)).__round__(2),
|
||||
'top10_filtered_rmsds_percentile_25': np.percentile(top10_filtered_rmsds, 25).round(2),
|
||||
'top10_filtered_rmsds_percentile_50': np.percentile(top10_filtered_rmsds, 50).round(2),
|
||||
'top10_filtered_rmsds_percentile_75': np.percentile(top10_filtered_rmsds, 75).round(2),
|
||||
|
||||
'top10_filtered_centroid_below_2': (100 * (top10_filtered_centroid_distances < 2).sum() / len(top10_filtered_centroid_distances)).__round__(2),
|
||||
'top10_filtered_centroid_below_5': (100 * (top10_filtered_centroid_distances < 5).sum() / len(top10_filtered_centroid_distances)).__round__(2),
|
||||
'top10_filtered_centroid_percentile_25': np.percentile(top10_filtered_centroid_distances, 25).round(2),
|
||||
'top10_filtered_centroid_percentile_50': np.percentile(top10_filtered_centroid_distances, 50).round(2),
|
||||
'top10_filtered_centroid_percentile_75': np.percentile(top10_filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
reverse_confidence_ordering = np.argsort(confidences,axis=1)
|
||||
reverse_filtered_rmsds = rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
|
||||
reverse_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
|
||||
reverse_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
|
||||
performance_metrics.update({
|
||||
'reversefiltered_steric_clash_fraction': (100 * (reverse_filtered_min_cross_distances < 0.4).sum() / len(reverse_filtered_min_cross_distances)).__round__(2),
|
||||
'reversefiltered_rmsds_below_2': (100 * (reverse_filtered_rmsds < 2).sum() / len(reverse_filtered_rmsds)).__round__(2),
|
||||
'reversefiltered_rmsds_below_5': (100 * (reverse_filtered_rmsds < 5).sum() / len(reverse_filtered_rmsds)).__round__(2),
|
||||
'reversefiltered_rmsds_percentile_25': np.percentile(reverse_filtered_rmsds, 25).round(2),
|
||||
'reversefiltered_rmsds_percentile_50': np.percentile(reverse_filtered_rmsds, 50).round(2),
|
||||
'reversefiltered_rmsds_percentile_75': np.percentile(reverse_filtered_rmsds, 75).round(2),
|
||||
|
||||
'reversefiltered_centroid_below_2': (100 * (reverse_filtered_centroid_distances < 2).sum() / len(reverse_filtered_centroid_distances)).__round__(2),
|
||||
'reversefiltered_centroid_below_5': (100 * (reverse_filtered_centroid_distances < 5).sum() / len(reverse_filtered_centroid_distances)).__round__(2),
|
||||
'reversefiltered_centroid_percentile_25': np.percentile(reverse_filtered_centroid_distances, 25).round(2),
|
||||
'reversefiltered_centroid_percentile_50': np.percentile(reverse_filtered_centroid_distances, 50).round(2),
|
||||
'reversefiltered_centroid_percentile_75': np.percentile(reverse_filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
if N >= 5:
|
||||
top5_reverse_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
|
||||
top5_reverse_filtered_centroid_distances = np.min(centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
|
||||
top5_reverse_filtered_min_cross_distances = np.max(min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
|
||||
performance_metrics.update({
|
||||
'top5_reverse_filtered_steric_clash_fraction': (100 * (top5_reverse_filtered_min_cross_distances < 0.4).sum() / len(top5_reverse_filtered_min_cross_distances)).__round__(2),
|
||||
'top5_reversefiltered_rmsds_below_2': (100 * (top5_reverse_filtered_rmsds < 2).sum() / len(top5_reverse_filtered_rmsds)).__round__(2),
|
||||
'top5_reversefiltered_rmsds_below_5': (100 * (top5_reverse_filtered_rmsds < 5).sum() / len(top5_reverse_filtered_rmsds)).__round__(2),
|
||||
'top5_reversefiltered_rmsds_percentile_25': np.percentile(top5_reverse_filtered_rmsds, 25).round(2),
|
||||
'top5_reversefiltered_rmsds_percentile_50': np.percentile(top5_reverse_filtered_rmsds, 50).round(2),
|
||||
'top5_reversefiltered_rmsds_percentile_75': np.percentile(top5_reverse_filtered_rmsds, 75).round(2),
|
||||
|
||||
'top5_reversefiltered_centroid_below_2': (100 * (top5_reverse_filtered_centroid_distances < 2).sum() / len(top5_reverse_filtered_centroid_distances)).__round__(2),
|
||||
'top5_reversefiltered_centroid_below_5': (100 * (top5_reverse_filtered_centroid_distances < 5).sum() / len(top5_reverse_filtered_centroid_distances)).__round__(2),
|
||||
'top5_reversefiltered_centroid_percentile_25': np.percentile(top5_reverse_filtered_centroid_distances, 25).round(2),
|
||||
'top5_reversefiltered_centroid_percentile_50': np.percentile(top5_reverse_filtered_centroid_distances, 50).round(2),
|
||||
'top5_reversefiltered_centroid_percentile_75': np.percentile(top5_reverse_filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
if N >= 10:
|
||||
top10_reverse_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
|
||||
top10_reverse_filtered_centroid_distances = np.min(centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
|
||||
top10_reverse_filtered_min_cross_distances = np.max(min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
|
||||
performance_metrics.update({
|
||||
'top10_reverse_filtered_steric_clash_fraction': (100 * (top10_reverse_filtered_min_cross_distances < 0.4).sum() / len(top10_reverse_filtered_min_cross_distances)).__round__(2),
|
||||
'top10_reversefiltered_rmsds_below_2': (100 * (top10_reverse_filtered_rmsds < 2).sum() / len(top10_reverse_filtered_rmsds)).__round__(2),
|
||||
'top10_reversefiltered_rmsds_below_5': (100 * (top10_reverse_filtered_rmsds < 5).sum() / len(top10_reverse_filtered_rmsds)).__round__(2),
|
||||
'top10_reversefiltered_rmsds_percentile_25': np.percentile(top10_reverse_filtered_rmsds, 25).round(2),
|
||||
'top10_reversefiltered_rmsds_percentile_50': np.percentile(top10_reverse_filtered_rmsds, 50).round(2),
|
||||
'top10_reversefiltered_rmsds_percentile_75': np.percentile(top10_reverse_filtered_rmsds, 75).round(2),
|
||||
|
||||
'top10_reversefiltered_centroid_below_2': (100 * (top10_reverse_filtered_centroid_distances < 2).sum() / len(top10_reverse_filtered_centroid_distances)).__round__(2),
|
||||
'top10_reversefiltered_centroid_below_5': (100 * (top10_reverse_filtered_centroid_distances < 5).sum() / len(top10_reverse_filtered_centroid_distances)).__round__(2),
|
||||
'top10_reversefiltered_centroid_percentile_25': np.percentile(top10_reverse_filtered_centroid_distances, 25).round(2),
|
||||
'top10_reversefiltered_centroid_percentile_50': np.percentile(top10_reverse_filtered_centroid_distances, 50).round(2),
|
||||
'top10_reversefiltered_centroid_percentile_75': np.percentile(top10_reverse_filtered_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
filtered_confidences = confidences[np.arange(confidences.shape[0])[:,None],confidence_ordering][:,0]
|
||||
|
||||
confident_mask = filtered_confidences > 0
|
||||
confident_rmsds = filtered_rmsds[confident_mask]
|
||||
confident_centroid_distances = filtered_centroid_distances[confident_mask]
|
||||
confident_min_cross_distances = filtered_min_cross_distances[confident_mask]
|
||||
|
||||
performance_metrics.update({
|
||||
'fraction_confident_predictions': (100 * len(confident_rmsds) / len(rmsds)).__round__(2),
|
||||
'confident_steric_clash_fraction': (100 * (confident_min_cross_distances < 0.4).sum() / len(confident_min_cross_distances)).__round__(2),
|
||||
'confident_rmsds_below_2': (100 * (confident_rmsds < 2).sum() / len(confident_rmsds)).__round__(2),
|
||||
'confident_rmsds_below_5': (100 * (confident_rmsds < 5).sum() / len(confident_rmsds)).__round__(2),
|
||||
'confident_rmsds_percentile_25': np.percentile(confident_rmsds, 25).round(2),
|
||||
'confident_rmsds_percentile_50': np.percentile(confident_rmsds, 50).round(2),
|
||||
'confident_rmsds_percentile_75': np.percentile(confident_rmsds, 75).round(2),
|
||||
|
||||
'confident_centroid_below_2': (100 * (confident_centroid_distances < 2).sum() / len(confident_centroid_distances)).__round__(2),
|
||||
'confident_centroid_below_5': (100 * (confident_centroid_distances < 5).sum() / len(confident_centroid_distances)).__round__(2),
|
||||
'confident_centroid_percentile_25': np.percentile(confident_centroid_distances, 25).round(2),
|
||||
'confident_centroid_percentile_50': np.percentile(confident_centroid_distances, 50).round(2),
|
||||
'confident_centroid_percentile_75': np.percentile(confident_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
for k in performance_metrics:
|
||||
print(k, performance_metrics[k])
|
||||
|
||||
fraction_dataset_rmsds_below_2 = []
|
||||
perfect_calibration = []
|
||||
no_calibration = []
|
||||
for dataset_percentage in range(100):
|
||||
dataset_percentage += 1
|
||||
dataset_fraction = (dataset_percentage)/100
|
||||
num_samples = round(len(rmsds)*dataset_fraction)
|
||||
per_complex_confidence_ordering = np.argsort(filtered_confidences)[::-1]
|
||||
confident_complexes_rmsds = filtered_rmsds[per_complex_confidence_ordering][:num_samples]
|
||||
confident_complexes_centroid_distances = filtered_centroid_distances[per_complex_confidence_ordering][:num_samples]
|
||||
confident_complexes_min_cross_distances = filtered_min_cross_distances[per_complex_confidence_ordering][:num_samples]
|
||||
confident_complexes_metrics = {
|
||||
'fraction_confident_complexes_predictions': (100 * len(confident_complexes_rmsds) / len(rmsds)).__round__(2),
|
||||
'confident_complexes_steric_clash_fraction': (100 * (confident_complexes_min_cross_distances < 0.4).sum() / len(confident_complexes_min_cross_distances)).__round__(2),
|
||||
'confident_complexes_rmsds_below_2': (100 * (confident_complexes_rmsds < 2).sum() / len(confident_complexes_rmsds)).__round__(2),
|
||||
'confident_complexes_rmsds_below_5': (100 * (confident_complexes_rmsds < 5).sum() / len(confident_complexes_rmsds)).__round__(2),
|
||||
'confident_complexes_rmsds_percentile_25': np.percentile(confident_complexes_rmsds, 25).round(2),
|
||||
'confident_complexes_rmsds_percentile_50': np.percentile(confident_complexes_rmsds, 50).round(2),
|
||||
'confident_complexes_rmsds_percentile_75': np.percentile(confident_complexes_rmsds, 75).round(2),
|
||||
|
||||
'confident_complexes_centroid_below_2': (100 * (confident_complexes_centroid_distances < 2).sum() / len(confident_complexes_centroid_distances)).__round__(2),
|
||||
'confident_complexes_centroid_below_5': (100 * (confident_complexes_centroid_distances < 5).sum() / len(confident_complexes_centroid_distances)).__round__(2),
|
||||
'confident_complexes_centroid_percentile_25': np.percentile(confident_complexes_centroid_distances, 25).round(2),
|
||||
'confident_complexes_centroid_percentile_50': np.percentile(confident_complexes_centroid_distances, 50).round(2),
|
||||
'confident_complexes_centroid_percentile_75': np.percentile(confident_complexes_centroid_distances, 75).round(2),
|
||||
}
|
||||
fraction_dataset_rmsds_below_2.append(confident_complexes_metrics['confident_complexes_rmsds_below_2'])
|
||||
perfect_calibration.append((100 * (np.sort(filtered_rmsds)[:num_samples] < 2).sum() / len(confident_complexes_rmsds)).__round__(2))
|
||||
no_calibration.append(performance_metrics['filtered_rmsds_below_2'])
|
||||
#print('percentage: ',dataset_percentage)
|
||||
#print(confident_complexes_metrics['confident_complexes_rmsds_below_2'])
|
||||
|
||||
print(scipy.stats.spearmanr(filtered_rmsds, filtered_confidences))
|
||||
df = {'conf': filtered_confidences, 'rmsd': filtered_rmsds}
|
||||
fig = px.scatter(df, x='rmsd',y='conf').update_layout(
|
||||
xaxis_title="Percentage of datapoints that may be abstained", yaxis_title="Percentage of predictions with RMSD < 2A"
|
||||
)
|
||||
fig.update_layout(margin={'l': 0, 'r': 0, 't': 20, 'b': 100}, plot_bgcolor='white',
|
||||
paper_bgcolor='white', legend_title_text='', legend_title_font_size=1,
|
||||
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ),
|
||||
)
|
||||
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
|
||||
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
|
||||
fig.show()
|
||||
|
||||
df = {'Confidence Model': reversed(fraction_dataset_rmsds_below_2),'No Calibration': reversed(no_calibration),'Perfect Calibration': reversed(perfect_calibration),}
|
||||
fig = px.line(df, y=list(df.keys())).update_layout(
|
||||
xaxis_title="Percentage of datapoints that may be abstained", yaxis_title="Percentage of predictions with RMSD < 2A"
|
||||
)
|
||||
fig.update_yaxes(range = [0,103])
|
||||
fig.update_layout(margin={'l': 0, 'r': 0, 't': 20, 'b': 100}, plot_bgcolor='white',
|
||||
paper_bgcolor='white', legend_title_text='', legend_title_font_size=1,
|
||||
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ),
|
||||
)
|
||||
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
|
||||
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
|
||||
fig.write_image('results/confidence_calibration.pdf')
|
||||
fig.show()
|
||||
|
||||
def filter_by_names(method_names, method_array, names_to_keep):
|
||||
output_array = []
|
||||
output_names = []
|
||||
for method_name, array_element in zip(method_names,method_array):
|
||||
if method_name in names_to_keep:
|
||||
output_array.append(array_element)
|
||||
output_names.append(method_name)
|
||||
return np.array(output_array), np.array(output_names)
|
||||
|
||||
qvinaw_rmsds = np.load(os.path.join(args.qvinaw_results_path, 'rmsds.npy'))
|
||||
qvinaw_names = np.load(os.path.join(args.qvinaw_results_path, 'names.npy'))
|
||||
qvinaw_rmsds, qvinaw_names = filter_by_names(qvinaw_names, qvinaw_rmsds, complex_names)
|
||||
qvinaw_rmsds = np.concatenate([qvinaw_rmsds, np.random.choice(qvinaw_rmsds, size=len(complex_names) - len(qvinaw_rmsds))])
|
||||
|
||||
glide_rmsds = np.load(os.path.join(args.glide_results_path, 'rmsds.npy'))
|
||||
glide_names = np.load(os.path.join(args.glide_results_path, 'names.npy')).tolist()
|
||||
glide_rmsds, glide_names = filter_by_names(glide_names, glide_rmsds, complex_names)
|
||||
glide_rmsds = np.concatenate([glide_rmsds, np.random.choice(glide_rmsds, size=len(complex_names) - len(glide_rmsds))])
|
||||
|
||||
smina_rmsds = np.load(os.path.join(args.smina_results_path, 'rmsds.npy'))[:,0]
|
||||
smina_names = np.load(os.path.join(args.smina_results_path, 'names.npy'))
|
||||
smina_rmsds, smina_names = filter_by_names(smina_names, smina_rmsds, complex_names)
|
||||
smina_rmsds = np.concatenate([smina_rmsds, np.random.choice(smina_rmsds, size=len(complex_names) - len(smina_rmsds))])
|
||||
|
||||
gnina_rmsds = np.load(os.path.join(args.gnina_results_path, 'rmsds.npy'))[:,0]
|
||||
gnina_names = np.load(os.path.join(args.gnina_results_path, 'names.npy'))
|
||||
gnina_rmsds, gnina_names = filter_by_names(gnina_names, gnina_rmsds, complex_names)
|
||||
gnina_rmsds = np.concatenate([gnina_rmsds, np.random.choice(gnina_rmsds, size=len(complex_names) - len(gnina_rmsds))])
|
||||
|
||||
tankbind_rmsds = np.load(os.path.join(args.tankbind_results_path, 'rmsds.npy'))[:,0]
|
||||
tankbind_names = np.load(os.path.join(args.tankbind_results_path, 'names.npy'))
|
||||
tankbind_rmsds, tankbind_names = filter_by_names(tankbind_names, tankbind_rmsds, complex_names)
|
||||
|
||||
equibind_rmsds = np.load(os.path.join(args.equibind_results_path, 'rmsds.npy'))
|
||||
equibind_names = np.load(os.path.join(args.equibind_results_path, 'names.npy'))
|
||||
equibind_rmsds, equibind_names = filter_by_names(equibind_names, equibind_rmsds, complex_names)
|
||||
|
||||
|
||||
df = {'DiffDock': filtered_rmsds, 'GLIDE': glide_rmsds, 'GNINA': gnina_rmsds, 'SMINA': smina_rmsds, 'QVinaW':qvinaw_rmsds, 'TANKBind': tankbind_rmsds, 'EquiBind': equibind_rmsds}
|
||||
fig = px.ecdf(df, range_x=[0, 5], range_y=[0.001, 0.75], width=600, height=400)
|
||||
fig.add_vline(x=2, annotation_text='', annotation_font_size=20, annotation_position="top right",
|
||||
line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
|
||||
fig.update_xaxes(title=f'RMSD (Å)')
|
||||
fig.update_yaxes(title=f'Fraction with lower RMSD')
|
||||
fig.update_layout(autosize=False, margin={'l': 65, 'r': 5, 't': 5, 'b': 60}, plot_bgcolor='white',
|
||||
paper_bgcolor='white', legend_title_text='', legend_title_font_size=18,
|
||||
legend=dict(yanchor="top", y=0.995, xanchor="left", x=0.02, font=dict(size=18, color='black'), ), )
|
||||
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=23, color='black'),mirror=True,ticks='outside',showline=True, linewidth=1, linecolor='black', tickfont = dict(size = 18, color='black'))
|
||||
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=23, color='black'),mirror=True,ticks='outside',showline=True, linewidth=1, linecolor='black', tickfont = dict(size = 18, color='black'))
|
||||
fig.update_traces(line=dict(width=3))
|
||||
fig.write_image('results/rmsds_nooverlap.pdf')
|
||||
fig.show()
|
||||
@@ -1,180 +0,0 @@
|
||||
# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as initial pose
|
||||
import os
|
||||
import time
|
||||
from argparse import FileType, ArgumentParser
|
||||
|
||||
import numpy as np
|
||||
from biopandas.pdb import PandasPdb
|
||||
from rdkit import Chem
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from datasets.pdbbind import read_mol
|
||||
from datasets.process_mols import read_molecule
|
||||
from utils.utils import read_strings_from_txt, get_symmetry_rmsd
|
||||
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--config', type=FileType(mode='r'), default=None)
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
|
||||
parser.add_argument('--results_path', type=str, default='results/user_predictions_testset', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand.pdb', help='Path to folder with trained model and hyperparameters')
|
||||
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
|
||||
parser.add_argument('--file_to_exclude', type=str, default=None, help='')
|
||||
parser.add_argument('--all_dirs_in_results', action='store_true', default=True, help='Evaluate all directories in the results path instead of using directly looking for the names')
|
||||
parser.add_argument('--num_predictions', type=int, default=10, help='')
|
||||
parser.add_argument('--no_id_in_filename', action='store_true', default=False, help='')
|
||||
parser.add_argument('--test_names_path', type=str, default='data/splits/timesplit_test', help='Path to text file with the folder names in the test set')
|
||||
parser.add_argument('--no_overlap_names_path', type=str, default='data/splits/timesplit_test_no_rec_overlap', help='Path text file with the folder names in the test set that have no receptor overlap with the train set')
|
||||
args = parser.parse_args()
|
||||
|
||||
print('Reading paths and names.')
|
||||
names = read_strings_from_txt(args.test_names_path)
|
||||
names_no_rec_overlap = read_strings_from_txt(args.no_overlap_names_path)
|
||||
results_path_containments = os.listdir(args.results_path)
|
||||
|
||||
all_times = []
|
||||
successful_names_list = []
|
||||
rmsds_list = []
|
||||
centroid_distances_list = []
|
||||
min_cross_distances_list = []
|
||||
min_self_distances_list = []
|
||||
without_rec_overlap_list = []
|
||||
start_time = time.time()
|
||||
for i, name in enumerate(tqdm(names)):
|
||||
mol = read_mol(args.data_dir, name, remove_hs=True)
|
||||
mol = Chem.RemoveAllHs(mol)
|
||||
orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
|
||||
|
||||
if args.all_dirs_in_results:
|
||||
directory_with_name_list = [directory for directory in results_path_containments if name in directory]
|
||||
if directory_with_name_list == []:
|
||||
print('Did not find a directory for ', name, '. We are skipping that complex')
|
||||
continue
|
||||
else:
|
||||
directory_with_name = directory_with_name_list[0]
|
||||
ligand_pos = []
|
||||
debug_paths = []
|
||||
for i in range(args.num_predictions):
|
||||
file_paths = sorted(os.listdir(os.path.join(args.results_path, directory_with_name)))
|
||||
if args.file_to_exclude is not None:
|
||||
file_paths = [path for path in file_paths if not args.file_to_exclude in path]
|
||||
file_path = [path for path in file_paths if f'rank{i+1}_' in path][0]
|
||||
mol_pred = read_molecule(os.path.join(args.results_path, directory_with_name, file_path),remove_hs=True, sanitize=True)
|
||||
mol_pred = Chem.RemoveAllHs(mol_pred)
|
||||
ligand_pos.append(mol_pred.GetConformer().GetPositions())
|
||||
debug_paths.append(file_path)
|
||||
ligand_pos = np.asarray(ligand_pos)
|
||||
else:
|
||||
if not os.path.exists(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}')): raise Exception('path did not exists:', os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'))
|
||||
mol_pred = read_molecule(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'), remove_hs=True, sanitize=True)
|
||||
if mol_pred == None:
|
||||
print("Skipping ", name, ' because RDKIT could not read it.')
|
||||
continue
|
||||
mol_pred = Chem.RemoveAllHs(mol_pred)
|
||||
ligand_pos = np.asarray([np.array(mol_pred.GetConformer(i).GetPositions()) for i in range(args.num_predictions)])
|
||||
try:
|
||||
rmsd = get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred)
|
||||
except Exception as e:
|
||||
print("Using non corrected RMSD because of the error:", e)
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
|
||||
|
||||
rmsds_list.append(rmsd)
|
||||
centroid_distances_list.append(np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1))
|
||||
|
||||
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
|
||||
if not os.path.exists(rec_path):
|
||||
rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
|
||||
rec = PandasPdb().read_pdb(rec_path)
|
||||
rec_df = rec.df['ATOM']
|
||||
receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
|
||||
receptor_pos = np.tile(receptor_pos, (args.num_predictions, 1, 1))
|
||||
|
||||
cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
|
||||
self_distances = np.linalg.norm(ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
|
||||
self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
|
||||
min_cross_distances_list.append(np.min(cross_distances, axis=(1,2)))
|
||||
min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
|
||||
successful_names_list.append(name)
|
||||
without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
|
||||
performance_metrics = {}
|
||||
for overlap in ['', 'no_overlap_']:
|
||||
if 'no_overlap_' == overlap:
|
||||
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
|
||||
rmsds = np.array(rmsds_list)[without_rec_overlap]
|
||||
centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
|
||||
min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
|
||||
min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
|
||||
successful_names = np.array(successful_names_list)[without_rec_overlap]
|
||||
else:
|
||||
rmsds = np.array(rmsds_list)
|
||||
centroid_distances = np.array(centroid_distances_list)
|
||||
min_cross_distances = np.array(min_cross_distances_list)
|
||||
min_self_distances = np.array(min_self_distances_list)
|
||||
successful_names = np.array(successful_names_list)
|
||||
|
||||
np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}names.npy'), successful_names)
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), np.array(min_cross_distances))
|
||||
np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), np.array(min_self_distances))
|
||||
|
||||
performance_metrics.update({
|
||||
f'{overlap}steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / args.num_predictions).__round__(2),
|
||||
f'{overlap}self_intersect_fraction': (100 * (min_self_distances < 0.4).sum() / len(min_self_distances) / args.num_predictions).__round__(2),
|
||||
f'{overlap}mean_rmsd': rmsds[:,0].mean(),
|
||||
f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
|
||||
f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
|
||||
f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
|
||||
f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
|
||||
|
||||
f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
|
||||
f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
|
||||
f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
|
||||
f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
|
||||
f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
|
||||
})
|
||||
|
||||
top5_rmsds = np.min(rmsds[:, :5], axis=1)
|
||||
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
|
||||
f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
|
||||
f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
|
||||
f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
|
||||
f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
|
||||
f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
|
||||
})
|
||||
|
||||
|
||||
top10_rmsds = np.min(rmsds[:, :10], axis=1)
|
||||
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
|
||||
performance_metrics.update({
|
||||
f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
|
||||
f'{overlap}top10_steric_clash_fraction': ( 100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
|
||||
f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
|
||||
f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
|
||||
|
||||
f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
|
||||
f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
|
||||
f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
|
||||
f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
|
||||
})
|
||||
for k in performance_metrics:
|
||||
print(k, performance_metrics[k])
|
||||
|
||||
87
inference.py
87
inference.py
@@ -1,13 +1,13 @@
|
||||
import copy
|
||||
import os
|
||||
import torch
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from rdkit.Chem import RemoveHs
|
||||
from argparse import ArgumentParser, Namespace, FileType
|
||||
from functools import partial
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from rdkit import RDLogger
|
||||
from torch_geometric.loader import DataLoader
|
||||
from rdkit.Chem import RemoveAllHs
|
||||
|
||||
from datasets.process_mols import write_mol_with_coords
|
||||
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
|
||||
@@ -20,7 +20,8 @@ from tqdm import tqdm
|
||||
RDLogger.DisableLog('rdApp.*')
|
||||
import yaml
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--protein_ligand_csv', type=str, default=None, help='Path to a .csv file specifying the input as described in the README. If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters')
|
||||
parser.add_argument('--config', type=FileType(mode='r'), default='inference_args.yaml')
|
||||
parser.add_argument('--protein_ligand_csv', type=str, default="data/protein_ligand_example_csv.csv", help='Path to a .csv file specifying the input as described in the README. If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters')
|
||||
parser.add_argument('--complex_name', type=str, default='1a0q', help='Name that the complex will be saved with')
|
||||
parser.add_argument('--protein_path', type=str, default=None, help='Path to the protein file')
|
||||
parser.add_argument('--protein_sequence', type=str, default=None, help='Sequence of the protein for ESMFold, this is ignored if --protein_path is not None')
|
||||
@@ -30,17 +31,51 @@ parser.add_argument('--out_dir', type=str, default='results/user_inference', hel
|
||||
parser.add_argument('--save_visualisation', action='store_true', default=False, help='Save a pdb file with all of the steps of the reverse diffusion')
|
||||
parser.add_argument('--samples_per_complex', type=int, default=10, help='Number of samples to generate')
|
||||
|
||||
parser.add_argument('--model_dir', type=str, default='workdir/paper_score_model', help='Path to folder with trained score model and hyperparameters')
|
||||
parser.add_argument('--model_dir', type=str, default=None, help='Path to folder with trained score model and hyperparameters')
|
||||
parser.add_argument('--ckpt', type=str, default='best_ema_inference_epoch_model.pt', help='Checkpoint to use for the score model')
|
||||
parser.add_argument('--confidence_model_dir', type=str, default='workdir/paper_confidence_model', help='Path to folder with trained confidence model and hyperparameters')
|
||||
parser.add_argument('--confidence_ckpt', type=str, default='best_model_epoch75.pt', help='Checkpoint to use for the confidence model')
|
||||
parser.add_argument('--confidence_model_dir', type=str, default=None, help='Path to folder with trained confidence model and hyperparameters')
|
||||
parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', help='Checkpoint to use for the confidence model')
|
||||
|
||||
parser.add_argument('--batch_size', type=int, default=32, help='')
|
||||
parser.add_argument('--no_final_step_noise', action='store_true', default=False, help='Use no noise in the final step of the reverse diffusion')
|
||||
parser.add_argument('--batch_size', type=int, default=10, help='')
|
||||
parser.add_argument('--no_final_step_noise', action='store_true', default=True, help='Use no noise in the final step of the reverse diffusion')
|
||||
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps')
|
||||
parser.add_argument('--actual_steps', type=int, default=None, help='Number of denoising steps that are actually performed')
|
||||
|
||||
parser.add_argument('--old_score_model', action='store_true', default=False, help='')
|
||||
parser.add_argument('--old_confidence_model', action='store_true', default=True, help='')
|
||||
parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, help='Initial noise std proportion')
|
||||
parser.add_argument('--choose_residue', action='store_true', default=False, help='')
|
||||
|
||||
parser.add_argument('--temp_sampling_tr', type=float, default=1.0)
|
||||
parser.add_argument('--temp_psi_tr', type=float, default=0.0)
|
||||
parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5)
|
||||
parser.add_argument('--temp_sampling_rot', type=float, default=1.0)
|
||||
parser.add_argument('--temp_psi_rot', type=float, default=0.0)
|
||||
parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5)
|
||||
parser.add_argument('--temp_sampling_tor', type=float, default=1.0)
|
||||
parser.add_argument('--temp_psi_tor', type=float, default=0.0)
|
||||
parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5)
|
||||
|
||||
parser.add_argument('--gnina_minimize', action='store_true', default=False, help='')
|
||||
parser.add_argument('--gnina_path', type=str, default='gnina', help='')
|
||||
parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', help='') # To redirect gnina subprocesses stdouts from the terminal window
|
||||
parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='')
|
||||
parser.add_argument('--gnina_autobox_add', type=float, default=4.0)
|
||||
parser.add_argument('--gnina_poses_to_optimize', type=int, default=1)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.config:
|
||||
config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
|
||||
arg_dict = args.__dict__
|
||||
for key, value in config_dict.items():
|
||||
if isinstance(value, list):
|
||||
for v in value:
|
||||
arg_dict[key].append(v)
|
||||
else:
|
||||
arg_dict[key] = value
|
||||
# TODO check that the args are actually updated
|
||||
|
||||
os.makedirs(args.out_dir, exist_ok=True)
|
||||
with open(f'{args.model_dir}/model_parameters.yml') as f:
|
||||
score_model_args = Namespace(**yaml.full_load(f))
|
||||
@@ -70,11 +105,12 @@ for name in complex_name_list:
|
||||
# preprocessing of complexes into geometric graphs
|
||||
test_dataset = InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list,
|
||||
ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list,
|
||||
lm_embeddings=score_model_args.esm_embeddings_path is not None,
|
||||
lm_embeddings=True,
|
||||
receptor_radius=score_model_args.receptor_radius, remove_hs=score_model_args.remove_hs,
|
||||
c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
|
||||
all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius,
|
||||
atom_max_neighbors=score_model_args.atom_max_neighbors)
|
||||
atom_max_neighbors=score_model_args.atom_max_neighbors,
|
||||
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph)
|
||||
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
|
||||
|
||||
if args.confidence_model_dir is not None and not confidence_args.use_original_model_cache:
|
||||
@@ -83,25 +119,27 @@ if args.confidence_model_dir is not None and not confidence_args.use_original_mo
|
||||
confidence_test_dataset = \
|
||||
InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list,
|
||||
ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list,
|
||||
lm_embeddings=confidence_args.esm_embeddings_path is not None,
|
||||
lm_embeddings=True,
|
||||
receptor_radius=confidence_args.receptor_radius, remove_hs=confidence_args.remove_hs,
|
||||
c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
|
||||
all_atoms=confidence_args.all_atoms, atom_radius=confidence_args.atom_radius,
|
||||
atom_max_neighbors=confidence_args.atom_max_neighbors,
|
||||
precomputed_lm_embeddings=test_dataset.lm_embeddings)
|
||||
precomputed_lm_embeddings=test_dataset.lm_embeddings,
|
||||
knn_only_graph=False if not hasattr(score_model_args, 'not_knn_only_graph') else not score_model_args.not_knn_only_graph)
|
||||
else:
|
||||
confidence_test_dataset = None
|
||||
|
||||
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)
|
||||
|
||||
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)
|
||||
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model)
|
||||
state_dict = torch.load(f'{args.model_dir}/{args.ckpt}', map_location=torch.device('cpu'))
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model = model.to(device)
|
||||
model.eval()
|
||||
|
||||
if args.confidence_model_dir is not None:
|
||||
confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, confidence_mode=True)
|
||||
confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True,
|
||||
confidence_mode=True, old=args.old_confidence_model)
|
||||
state_dict = torch.load(f'{args.confidence_model_dir}/{args.confidence_ckpt}', map_location=torch.device('cpu'))
|
||||
confidence_model.load_state_dict(state_dict, strict=True)
|
||||
confidence_model = confidence_model.to(device)
|
||||
@@ -110,7 +148,7 @@ else:
|
||||
confidence_model = None
|
||||
confidence_args = None
|
||||
|
||||
tr_schedule = get_t_schedule(inference_steps=args.inference_steps)
|
||||
tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta')
|
||||
|
||||
failures, skipped = 0, 0
|
||||
N = args.samples_per_complex
|
||||
@@ -131,7 +169,10 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
|
||||
else:
|
||||
confidence_data_list = None
|
||||
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
|
||||
randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max)
|
||||
randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max,
|
||||
initial_noise_std_proportion=args.initial_noise_std_proportion,
|
||||
choose_residue=args.choose_residue)
|
||||
|
||||
lig = orig_complex_graph.mol[0]
|
||||
|
||||
# initialize visualisation
|
||||
@@ -154,7 +195,13 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
|
||||
device=device, t_to_sigma=t_to_sigma, model_args=score_model_args,
|
||||
visualization_list=visualization_list, confidence_model=confidence_model,
|
||||
confidence_data_list=confidence_data_list, confidence_model_args=confidence_args,
|
||||
batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise)
|
||||
batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise,
|
||||
temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot,
|
||||
args.temp_sampling_tor],
|
||||
temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor],
|
||||
temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot,
|
||||
args.temp_sigma_data_tor])
|
||||
|
||||
ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for complex_graph in data_list])
|
||||
|
||||
# reorder predictions based on confidence output
|
||||
@@ -170,7 +217,7 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
|
||||
write_dir = f'{args.out_dir}/{complex_name_list[idx]}'
|
||||
for rank, pos in enumerate(ligand_pos):
|
||||
mol_pred = copy.deepcopy(lig)
|
||||
if score_model_args.remove_hs: mol_pred = RemoveHs(mol_pred)
|
||||
if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred)
|
||||
if rank == 0: write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}.sdf'))
|
||||
write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}_confidence{confidence[rank]:.2f}.sdf'))
|
||||
|
||||
@@ -189,6 +236,4 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
|
||||
|
||||
print(f'Failed for {failures} complexes')
|
||||
print(f'Skipped {skipped} complexes')
|
||||
print(f'Results are in {args.out_dir}')
|
||||
|
||||
|
||||
print(f'Results are in {args.out_dir}')
|
||||
30
inference_args.yaml
Normal file
30
inference_args.yaml
Normal file
@@ -0,0 +1,30 @@
|
||||
actual_steps: 19
|
||||
ckpt: best_ema_inference_epoch_model.pt
|
||||
confidence_ckpt: best_model_epoch75.pt
|
||||
confidence_model_dir: workdir/paper_confidence_model
|
||||
different_schedules: false
|
||||
inf_sched_alpha: 1
|
||||
inf_sched_beta: 1
|
||||
inference_steps: 20
|
||||
initial_noise_std_proportion: 1.4601642460337794
|
||||
limit_failures: 5
|
||||
model_dir: workdir/diffdockL
|
||||
no_final_step_noise: true
|
||||
no_model: false
|
||||
no_random: false
|
||||
no_random_pocket: false
|
||||
ode: false
|
||||
old_filtering_model: true
|
||||
old_score_model: false
|
||||
resample_rdkit: false
|
||||
samples_per_complex: 10
|
||||
sigma_schedule: expbeta
|
||||
temp_psi_rot: 0.9022615585677628
|
||||
temp_psi_tor: 0.5946212391366862
|
||||
temp_psi_tr: 0.727287304570729
|
||||
temp_sampling_rot: 2.06391612594481
|
||||
temp_sampling_tor: 7.044261621607846
|
||||
temp_sampling_tr: 1.170050527854316
|
||||
temp_sigma_data_rot: 0.7464326999906034
|
||||
temp_sigma_data_tor: 0.6943254174849822
|
||||
temp_sigma_data_tr: 0.9299802531572672
|
||||
667
models/aa_model.py
Normal file
667
models/aa_model.py
Normal file
@@ -0,0 +1,667 @@
|
||||
from e3nn import o3
|
||||
import torch
|
||||
from esm.pretrained import load_model_and_alphabet
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch_cluster import radius, radius_graph
|
||||
from torch_geometric.utils import subgraph
|
||||
from torch_scatter import scatter_mean
|
||||
import numpy as np
|
||||
|
||||
from models.layers import GaussianSmearing, AtomEncoder
|
||||
from models.tensor_layers import get_irrep_seq, TensorProductConvLayer
|
||||
from utils import so3, torus
|
||||
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
|
||||
|
||||
AGGREGATORS = {"mean": lambda x: torch.mean(x, dim=1),
|
||||
"max": lambda x: torch.max(x, dim=1)[0],
|
||||
"min": lambda x: torch.min(x, dim=1)[0],
|
||||
"std": lambda x: torch.std(x, dim=1)}
|
||||
|
||||
|
||||
class AAModel(torch.nn.Module):
|
||||
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
|
||||
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
|
||||
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
|
||||
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
|
||||
separate_noise_schedule=False, lm_embedding_type=False, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm = False,
|
||||
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
|
||||
parallel_aggregators="mean max min std", num_confidence_outputs=1, atom_num_confidence_outputs=1, fixed_center_conv=False,
|
||||
no_aminoacid_identities=False, include_miscellaneous_atoms=False,
|
||||
differentiate_convolutions=True, tp_weights_layers=2, num_prot_emb_layers=0,
|
||||
reduce_pseudoscalars=False, embed_also_ligand=False, atom_confidence=False, sidechain_pred=False,
|
||||
depthwise_convolution=False, crop_beyond=None):
|
||||
super(AAModel, self).__init__()
|
||||
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
|
||||
assert not sidechain_pred, "sidechain prediction not implemented/makes sense for all atom model"
|
||||
assert not depthwise_convolution, "depthwise convolution not implemented for all atom model"
|
||||
if parallel > 1: assert affinity_prediction
|
||||
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.in_lig_edge_features = in_lig_edge_features
|
||||
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
|
||||
self.sigma_embed_dim = sigma_embed_dim
|
||||
self.lig_max_radius = lig_max_radius
|
||||
self.rec_max_radius = rec_max_radius
|
||||
self.cross_max_distance = cross_max_distance
|
||||
self.dynamic_max_cross = dynamic_max_cross
|
||||
self.center_max_distance = center_max_distance
|
||||
self.distance_embed_dim = distance_embed_dim
|
||||
self.cross_distance_embed_dim = cross_distance_embed_dim
|
||||
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
|
||||
self.ns, self.nv = ns, nv
|
||||
self.scale_by_sigma = scale_by_sigma
|
||||
self.norm_by_sigma = norm_by_sigma
|
||||
self.device = device
|
||||
self.no_torsion = no_torsion
|
||||
self.smooth_edges = smooth_edges
|
||||
self.odd_parity = odd_parity
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.timestep_emb_func = timestep_emb_func
|
||||
self.separate_noise_schedule = separate_noise_schedule
|
||||
self.confidence_mode = confidence_mode
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.num_prot_emb_layers = num_prot_emb_layers
|
||||
self.asyncronous_noise_schedule = asyncronous_noise_schedule
|
||||
self.affinity_prediction = affinity_prediction
|
||||
self.parallel, self.parallel_aggregators = parallel, parallel_aggregators.split(' ')
|
||||
self.fixed_center_conv = fixed_center_conv
|
||||
self.no_aminoacid_identities = no_aminoacid_identities
|
||||
self.differentiate_convolutions = differentiate_convolutions
|
||||
self.reduce_pseudoscalars = reduce_pseudoscalars
|
||||
self.atom_confidence = atom_confidence
|
||||
self.atom_num_confidence_outputs = atom_num_confidence_outputs
|
||||
self.crop_beyond = crop_beyond
|
||||
|
||||
self.lm_embedding_type = lm_embedding_type
|
||||
if lm_embedding_type is None:
|
||||
lm_embedding_dim = 0
|
||||
elif lm_embedding_type == "precomputed":
|
||||
lm_embedding_dim=1280
|
||||
else:
|
||||
lm, alphabet = load_model_and_alphabet(lm_embedding_type)
|
||||
self.batch_converter = alphabet.get_batch_converter()
|
||||
lm.lm_head = torch.nn.Identity()
|
||||
lm.contact_head = torch.nn.Identity()
|
||||
lm_embedding_dim = lm.embed_dim
|
||||
self.lm = lm
|
||||
|
||||
# embedding layers
|
||||
atom_encoder_class = AtomEncoder
|
||||
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.rec_sigma_embedding = nn.Sequential(nn.Linear(sigma_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=0, lm_embedding_dim=lm_embedding_dim)
|
||||
self.rec_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=0)
|
||||
self.atom_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
|
||||
self.lr_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
self.ar_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
|
||||
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
|
||||
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
|
||||
|
||||
irrep_seq = get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars)
|
||||
assert not include_miscellaneous_atoms, "currently not supported"
|
||||
|
||||
rec_emb_layers = []
|
||||
for i in range(num_prot_emb_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1 if not differentiate_convolutions else 4,
|
||||
)
|
||||
rec_emb_layers.append(layer)
|
||||
self.rec_emb_layers = nn.ModuleList(rec_emb_layers)
|
||||
|
||||
self.embed_also_ligand = embed_also_ligand
|
||||
if embed_also_ligand:
|
||||
lig_emb_layers = []
|
||||
for i in range(num_prot_emb_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1,
|
||||
)
|
||||
lig_emb_layers.append(layer)
|
||||
self.lig_emb_layers = nn.ModuleList(lig_emb_layers)
|
||||
|
||||
# convolutional layers
|
||||
conv_layers = []
|
||||
for i in range(num_prot_emb_layers, num_prot_emb_layers + num_conv_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1 if not differentiate_convolutions else (3 if i == num_prot_emb_layers + num_conv_layers - 1 else 9),
|
||||
)
|
||||
conv_layers.append(layer)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
|
||||
# confidence and affinity prediction layers
|
||||
if self.confidence_mode:
|
||||
if self.affinity_prediction:
|
||||
if self.parallel > 1:
|
||||
output_confidence_dim = 1 + ns
|
||||
else:
|
||||
output_confidence_dim = num_confidence_outputs + 1
|
||||
else:
|
||||
output_confidence_dim = num_confidence_outputs
|
||||
|
||||
input_size = ns + (nv if reduce_pseudoscalars else ns) if num_conv_layers + num_prot_emb_layers >= 3 else ns
|
||||
|
||||
if self.atom_confidence:
|
||||
self.atom_confidence_predictor = nn.Sequential(
|
||||
nn.Linear(input_size, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, atom_num_confidence_outputs + ns)
|
||||
)
|
||||
input_size = ns
|
||||
|
||||
self.confidence_predictor = nn.Sequential(
|
||||
nn.Linear(input_size, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, output_confidence_dim)
|
||||
)
|
||||
|
||||
if self.parallel > 1:
|
||||
self.affinity_predictor = nn.Sequential(
|
||||
nn.Linear(len(self.parallel_aggregators) * ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, 1)
|
||||
)
|
||||
|
||||
else:
|
||||
# convolution for translational and rotational scores
|
||||
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
|
||||
self.center_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
|
||||
self.final_conv = TensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
|
||||
n_edge_features=2 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
|
||||
if not no_torsion:
|
||||
# convolution for torsional score
|
||||
self.final_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
|
||||
self.tor_bond_conv = TensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.final_tp_tor.irreps_out,
|
||||
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
|
||||
n_edge_features=3 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tor_final_layer = nn.Sequential(
|
||||
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, 1, bias=False)
|
||||
)
|
||||
|
||||
def embedding(self, data):
|
||||
if not hasattr(data['receptor'], "rec_node_attr"):
|
||||
if self.lm_embedding_type not in [None, 'precomputed']:
|
||||
sequences = [s for l in data['receptor'].sequence for s in l]
|
||||
if isinstance(sequences[0], list):
|
||||
sequences = [s for l in sequences for s in l]
|
||||
sequences = [(i, s) for i, s in enumerate(sequences)]
|
||||
batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
|
||||
out = self.lm(batch_tokens.to(data['receptor'].x.device), repr_layers=[self.lm.num_layers], return_contacts=False)
|
||||
rec_lm_emb = torch.cat([t[:len(sequences[i][1])] for i, t in enumerate(out['representations'][self.lm.num_layers])], dim=0)
|
||||
data['receptor'].x = torch.cat([data['receptor'].x, rec_lm_emb], dim=-1)
|
||||
|
||||
rec_node_attr, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
|
||||
rec_node_attr = self.rec_node_embedding(rec_node_attr)
|
||||
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
|
||||
|
||||
atom_node_attr, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_atom_conv_graph(data)
|
||||
atom_node_attr = self.atom_node_embedding(atom_node_attr)
|
||||
atom_edge_attr = self.atom_edge_embedding(atom_edge_attr)
|
||||
|
||||
ar_edge_attr, ar_edge_sh, ar_edge_weight = self.build_cross_rec_conv_graph(data)
|
||||
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
|
||||
|
||||
rec_edge_index = data['receptor', 'receptor'].edge_index.clone()
|
||||
atom_edge_index = data['atom', 'atom'].edge_index.clone()
|
||||
ar_edge_index = data['atom', 'receptor'].edge_index.clone()
|
||||
|
||||
node_attr = torch.cat([rec_node_attr, atom_node_attr], dim=0)
|
||||
ar_edge_index[0] = ar_edge_index[0] + len(rec_node_attr)
|
||||
edge_index = torch.cat([rec_edge_index, ar_edge_index, atom_edge_index + len(rec_node_attr), torch.flip(ar_edge_index, dims=[0])], dim=1)
|
||||
edge_attr = torch.cat([rec_edge_attr, ar_edge_attr, atom_edge_attr, ar_edge_attr], dim=0)
|
||||
edge_sh = torch.cat([rec_edge_sh, ar_edge_sh, atom_edge_sh, ar_edge_sh], dim=0)
|
||||
edge_weight = torch.cat([rec_edge_weight, ar_edge_weight, atom_edge_weight, ar_edge_weight], dim=0) \
|
||||
if torch.is_tensor(rec_edge_weight) else torch.ones((len(edge_index[0]), 1), device=edge_index.device)
|
||||
s1, s2, s3 = len(rec_edge_index[0]), len(rec_edge_index[0]) + len(ar_edge_index[0]), len(rec_edge_index[0]) + len(ar_edge_index[0]) + len(atom_edge_index[0])
|
||||
|
||||
for l in range(len(self.rec_emb_layers)):
|
||||
edge_attr_ = torch.cat(
|
||||
[edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
|
||||
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:]]
|
||||
node_attr = self.rec_emb_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
|
||||
|
||||
|
||||
data['receptor'].rec_node_attr = node_attr[:len(rec_node_attr)]
|
||||
data['receptor', 'receptor'].rec_edge_attr = rec_edge_attr
|
||||
data['receptor', 'receptor'].edge_sh = rec_edge_sh
|
||||
data['receptor', 'receptor'].edge_weight = rec_edge_weight
|
||||
|
||||
data['atom'].atom_node_attr = node_attr[len(rec_node_attr):]
|
||||
data['atom', 'atom'].atom_edge_attr = atom_edge_attr
|
||||
data['atom', 'atom'].edge_sh = atom_edge_sh
|
||||
data['atom', 'atom'].edge_weight = atom_edge_weight
|
||||
|
||||
data['atom', 'receptor'].edge_attr = ar_edge_attr
|
||||
data['atom', 'receptor'].edge_sh = ar_edge_sh
|
||||
data['atom', 'receptor'].edge_weight = ar_edge_weight
|
||||
|
||||
# receptor embedding
|
||||
rec_sigma_emb = self.rec_sigma_embedding(self.timestep_emb_func(data.complex_t['tr']))
|
||||
rec_node_attr = data['receptor'].rec_node_attr + 0
|
||||
rec_node_attr[:, :self.ns] = rec_node_attr[:, :self.ns] + rec_sigma_emb[data['receptor'].batch]
|
||||
rec_edge_attr = data['receptor', 'receptor'].rec_edge_attr + rec_sigma_emb[data['receptor'].batch[data['receptor', 'receptor'].edge_index[0]]]
|
||||
|
||||
# atom embedding
|
||||
atom_node_attr = data['atom'].atom_node_attr + 0
|
||||
atom_node_attr[:, :self.ns] = atom_node_attr[:, :self.ns] + rec_sigma_emb[data['atom'].batch]
|
||||
atom_edge_attr = data['atom', 'atom'].atom_edge_attr + rec_sigma_emb[data['atom'].batch[data['atom', 'atom'].edge_index[0]]]
|
||||
|
||||
# atom-receptor embedding
|
||||
ar_edge_attr = data['atom', 'receptor'].edge_attr + rec_sigma_emb[data['atom'].batch[data['atom', 'receptor'].edge_index[0]]]
|
||||
|
||||
# ligand embedding
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
|
||||
lig_node_attr = self.lig_node_embedding(lig_node_attr)
|
||||
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
|
||||
|
||||
if self.embed_also_ligand:
|
||||
for l in range(len(self.lig_emb_layers)):
|
||||
edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns], lig_node_attr[lig_edge_index[1], :self.ns]], -1)
|
||||
lig_node_attr = self.lig_emb_layers[l](lig_node_attr, lig_edge_index, edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
|
||||
|
||||
else:
|
||||
lig_node_attr = F.pad(lig_node_attr, (0, rec_node_attr.shape[-1] - lig_node_attr.shape[-1]))
|
||||
|
||||
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, \
|
||||
rec_node_attr, data['receptor', 'receptor'].edge_index, rec_edge_attr, data['receptor', 'receptor'].edge_sh, data['receptor', 'receptor'].edge_weight, \
|
||||
atom_node_attr, data['atom', 'atom'].edge_index, atom_edge_attr, data['atom', 'atom'].edge_sh, data['atom', 'atom'].edge_weight, \
|
||||
data['atom', 'receptor'].edge_index, ar_edge_attr, data['atom', 'receptor'].edge_sh, data['atom', 'receptor'].edge_weight
|
||||
|
||||
def forward(self, data):
|
||||
if self.crop_beyond is not None:
|
||||
# TODO missing filtering atoms
|
||||
raise NotImplementedError
|
||||
ligand_pos = data['ligand'].pos
|
||||
receptor_pos = data['receptor'].pos
|
||||
residues_to_keep = torch.any(torch.sum((ligand_pos.unsqueeze(0) - receptor_pos.unsqueeze(1)) ** 2, -1) < self.crop_beyond ** 2, dim=1)
|
||||
|
||||
data['receptor'].pos = data['receptor'].pos[residues_to_keep]
|
||||
data['receptor'].x = data['receptor'].x[residues_to_keep]
|
||||
data['receptor'].side_chain_vecs = data['receptor'].side_chain_vecs[residues_to_keep]
|
||||
data['receptor', 'rec_contact', 'receptor'].edge_index = subgraph(residues_to_keep, data['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
|
||||
|
||||
if self.no_aminoacid_identities:
|
||||
data['receptor'].x = data['receptor'].x * 0
|
||||
|
||||
if not self.confidence_mode:
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
|
||||
else:
|
||||
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
|
||||
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, rec_node_attr, \
|
||||
rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight,\
|
||||
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight, \
|
||||
ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = self.embedding(data)
|
||||
|
||||
# build lig cross graph
|
||||
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1) if self.dynamic_max_cross else self.cross_max_distance
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight = self.build_cross_lig_conv_graph(data, cross_cutoff)
|
||||
lr_edge_attr= self.lr_edge_embedding(lr_edge_attr)
|
||||
la_edge_attr = self.la_edge_embedding(la_edge_attr)
|
||||
|
||||
n_lig, n_rec = len(lig_node_attr), len(rec_node_attr)
|
||||
|
||||
node_attr = torch.cat([lig_node_attr, rec_node_attr, atom_node_attr], dim=0)
|
||||
rec_edge_index, atom_edge_index, lr_edge_index, la_edge_index, ar_edge_index = rec_edge_index.clone(), atom_edge_index.clone(), lr_edge_index.clone(), la_edge_index.clone(), ar_edge_index.clone()
|
||||
rec_edge_index[0], rec_edge_index[1] = rec_edge_index[0] + n_lig, rec_edge_index[1] + n_lig
|
||||
atom_edge_index[0], atom_edge_index[1] = atom_edge_index[0] + n_lig + n_rec, atom_edge_index[1] + n_lig + n_rec
|
||||
lr_edge_index[1] = lr_edge_index[1] + n_lig
|
||||
la_edge_index[1] = la_edge_index[1] + n_lig + n_rec
|
||||
ar_edge_index[0], ar_edge_index[1] = ar_edge_index[0] + n_lig + n_rec, ar_edge_index[1] + n_lig
|
||||
|
||||
edge_index = torch.cat([lig_edge_index, lr_edge_index, la_edge_index, rec_edge_index,
|
||||
torch.flip(lr_edge_index, dims=[0]), torch.flip(ar_edge_index, dims=[0]),
|
||||
atom_edge_index, torch.flip(la_edge_index, dims=[0]), ar_edge_index], dim=1)
|
||||
edge_attr = torch.cat([lig_edge_attr, lr_edge_attr, la_edge_attr, rec_edge_attr, lr_edge_attr,
|
||||
ar_edge_attr, atom_edge_attr, la_edge_attr, ar_edge_attr], dim=0)
|
||||
edge_sh = torch.cat([lig_edge_sh, lr_edge_sh, la_edge_sh, rec_edge_sh, lr_edge_sh, ar_edge_sh,
|
||||
atom_edge_sh, la_edge_sh, ar_edge_sh], dim=0)
|
||||
edge_weight = torch.cat([lig_edge_weight, lr_edge_weight, la_edge_weight, rec_edge_weight, lr_edge_weight,
|
||||
ar_edge_weight, atom_edge_weight, la_edge_weight, ar_edge_weight],
|
||||
dim=0) if torch.is_tensor(lig_edge_weight) else torch.ones((len(edge_index[0]), 1),
|
||||
device=edge_index.device)
|
||||
s1, s2, s3, s4, s5, s6, s7, s8, _ = tuple(np.cumsum(list(map(len, [lig_edge_attr, lr_edge_attr, la_edge_attr,
|
||||
rec_edge_attr, lr_edge_attr, ar_edge_attr, atom_edge_attr, la_edge_attr, ar_edge_attr]))).tolist())
|
||||
|
||||
for l in range(len(self.conv_layers)):
|
||||
if l < len(self.conv_layers) - 1:
|
||||
edge_attr_ = torch.cat([edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
|
||||
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:s4],
|
||||
edge_attr_[s4:s5], edge_attr_[s5:s6], edge_attr_[s6:s7], edge_attr_[s7:s8], edge_attr_[s8:]]
|
||||
node_attr = self.conv_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
|
||||
else:
|
||||
edge_attr_ = torch.cat([edge_attr[:s3], node_attr[edge_index[0, :s3], :self.ns], node_attr[edge_index[1, :s3], :self.ns]], -1)
|
||||
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3]]
|
||||
node_attr = self.conv_layers[l](node_attr, edge_index[:, :s3], edge_attr_, edge_sh[:s3], edge_weight=edge_weight[:s3])
|
||||
|
||||
lig_node_attr = node_attr[:len(lig_node_attr)]
|
||||
|
||||
# confidence and affinity prediction
|
||||
if self.confidence_mode:
|
||||
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns], lig_node_attr[:,-(self.nv if self.reduce_pseudoscalars else self.ns):] ], dim=1) \
|
||||
if self.num_conv_layers + self.num_prot_emb_layers >= 3 else lig_node_attr[:,:self.ns]
|
||||
|
||||
if self.atom_confidence:
|
||||
scalar_lig_attr = self.atom_confidence_predictor(scalar_lig_attr)
|
||||
atom_confidence = scalar_lig_attr[:, :self.atom_num_confidence_outputs]
|
||||
scalar_lig_attr = scalar_lig_attr[:, self.atom_num_confidence_outputs:]
|
||||
else:
|
||||
atom_confidence = torch.zeros((len(lig_node_attr),), device=lig_node_attr.device)
|
||||
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
|
||||
|
||||
if self.parallel > 1:
|
||||
confidence, affinity = confidence[:, 0], confidence[:, 1:]
|
||||
confidence = confidence.reshape(data.num_graphs, self.parallel)
|
||||
affinity = affinity.reshape(data.num_graphs, self.parallel, -1)
|
||||
affinity = torch.cat([AGGREGATORS[agg](affinity) for agg in self.parallel_aggregators], dim=-1)
|
||||
affinity = self.affinity_predictor(affinity).squeeze(dim=-1)
|
||||
confidence = confidence, affinity
|
||||
return confidence, atom_confidence
|
||||
assert self.parallel == 1
|
||||
|
||||
# compute translational and rotational score vectors
|
||||
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
|
||||
center_edge_attr = self.center_edge_embedding(center_edge_attr)
|
||||
if self.fixed_center_conv:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
else:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
|
||||
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
|
||||
|
||||
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
|
||||
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
|
||||
|
||||
if self.separate_noise_schedule:
|
||||
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
|
||||
else: # tr rot and tor noise is all the same in this case
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# adjust the magniture of the score vectors
|
||||
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
|
||||
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
|
||||
|
||||
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
|
||||
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
|
||||
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
|
||||
|
||||
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0,device=self.device), None
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
|
||||
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
|
||||
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
|
||||
.to(data['ligand'].x.device))
|
||||
return tr_pred, rot_pred, tor_pred, None
|
||||
|
||||
def get_edge_weight(self, edge_vec, max_norm):
|
||||
if self.smooth_edges:
|
||||
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
|
||||
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
|
||||
return 1.0
|
||||
|
||||
def build_lig_conv_graph(self, data):
|
||||
# build the graph between ligand atoms
|
||||
if self.separate_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = torch.cat(
|
||||
[self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
|
||||
dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
|
||||
else:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(
|
||||
data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
|
||||
if self.parallel == 1:
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
else:
|
||||
batches = torch.zeros(data.num_graphs, device=data['ligand'].x.device).long()
|
||||
batches = batches.index_add(0, data['ligand'].batch, torch.ones(len(data['ligand'].batch), device=data['ligand'].x.device).long())
|
||||
outer_batches = data.num_graphs
|
||||
b = [torch.ones(batches[i].item()//self.parallel, device=data['ligand'].x.device).long() * (self.parallel * i + j)
|
||||
for i in range(outer_batches) for j in range(self.parallel)]
|
||||
data['ligand'].batch_parallel = torch.cat(b)
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch_parallel)
|
||||
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
|
||||
edge_attr = torch.cat([
|
||||
data['ligand', 'ligand'].edge_attr,
|
||||
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
|
||||
], 0)
|
||||
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_rec_conv_graph(self, data):
|
||||
# build the graph between receptor residues
|
||||
node_attr = data['receptor'].x
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['receptor', 'receptor'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
|
||||
|
||||
edge_attr = self.rec_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
|
||||
|
||||
return node_attr, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_atom_conv_graph(self, data):
|
||||
# build the graph between receptor atoms
|
||||
node_attr = data['atom'].x
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['atom', 'atom'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['atom'].pos[dst.long()] - data['atom'].pos[src.long()]
|
||||
|
||||
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_cross_lig_conv_graph(self, data, lr_cross_distance_cutoff):
|
||||
# build the cross edges between ligand atoms and receptor residues + atoms
|
||||
|
||||
# LIGAND to RECEPTOR
|
||||
if torch.is_tensor(lr_cross_distance_cutoff):
|
||||
# different cutoff for every graph
|
||||
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
|
||||
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
|
||||
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
|
||||
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
|
||||
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
|
||||
|
||||
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
|
||||
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
|
||||
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
|
||||
|
||||
# LIGAND to ATOM
|
||||
la_edge_index = radius(data['atom'].pos, data['ligand'].pos, self.lig_max_radius,
|
||||
data['atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
la_edge_vec = data['atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
|
||||
la_edge_length_emb = self.lig_distance_expansion(la_edge_vec.norm(dim=-1))
|
||||
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
|
||||
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
|
||||
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
|
||||
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
|
||||
|
||||
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight
|
||||
|
||||
def build_cross_rec_conv_graph(self, data):
|
||||
# build the cross edges between ligan atoms, receptor residues and receptor atoms
|
||||
|
||||
# ATOM to RECEPTOR
|
||||
ar_edge_index = data['atom', 'receptor'].edge_index
|
||||
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['atom'].pos[ar_edge_index[0].long()]
|
||||
ar_edge_attr = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
|
||||
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
|
||||
ar_edge_weight = 1
|
||||
|
||||
return ar_edge_attr, ar_edge_sh, ar_edge_weight
|
||||
|
||||
def build_center_conv_graph(self, data):
|
||||
# build the filter for the convolution of the center with the ligand atoms
|
||||
# for translational and rotational score
|
||||
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
|
||||
|
||||
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
|
||||
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
|
||||
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
|
||||
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
return edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_bond_conv_graph(self, data):
|
||||
# build graph for the pseudotorque layer
|
||||
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
|
||||
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
|
||||
bond_batch = data['ligand'].batch[bonds[0]]
|
||||
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
|
||||
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = self.final_edge_embedding(edge_attr)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return bonds, edge_index, edge_attr, edge_sh, edge_weight
|
||||
640
models/cg_model.py
Normal file
640
models/cg_model.py
Normal file
@@ -0,0 +1,640 @@
|
||||
import math
|
||||
|
||||
from e3nn import o3
|
||||
import torch
|
||||
from e3nn.o3 import Linear
|
||||
from esm.pretrained import load_model_and_alphabet
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch_cluster import radius, radius_graph
|
||||
from torch_scatter import scatter, scatter_mean
|
||||
import numpy as np
|
||||
|
||||
from models.layers import GaussianSmearing, AtomEncoder
|
||||
from models.tensor_layers import TensorProductConvLayer, get_irrep_seq
|
||||
from utils import so3, torus
|
||||
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
|
||||
|
||||
|
||||
class CGModel(torch.nn.Module):
|
||||
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
|
||||
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
|
||||
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
|
||||
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
|
||||
separate_noise_schedule=False, lm_embedding_type=None, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm=False,
|
||||
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
|
||||
parallel_aggregators="mean max min std", num_confidence_outputs=1, atom_num_confidence_outputs=1, fixed_center_conv=False,
|
||||
no_aminoacid_identities=False, include_miscellaneous_atoms=False,
|
||||
differentiate_convolutions=True, tp_weights_layers=2, num_prot_emb_layers=0, reduce_pseudoscalars=False,
|
||||
embed_also_ligand=False, atom_confidence=False, sidechain_pred=False, depthwise_convolution=False):
|
||||
super(CGModel, self).__init__()
|
||||
assert parallel == 1, "not implemented"
|
||||
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.in_lig_edge_features = in_lig_edge_features
|
||||
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
|
||||
self.sigma_embed_dim = sigma_embed_dim
|
||||
self.lig_max_radius = lig_max_radius
|
||||
self.rec_max_radius = rec_max_radius
|
||||
self.include_miscellaneous_atoms = include_miscellaneous_atoms
|
||||
self.cross_max_distance = cross_max_distance
|
||||
self.dynamic_max_cross = dynamic_max_cross
|
||||
self.center_max_distance = center_max_distance
|
||||
self.distance_embed_dim = distance_embed_dim
|
||||
self.cross_distance_embed_dim = cross_distance_embed_dim
|
||||
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
|
||||
self.ns, self.nv = ns, nv
|
||||
self.scale_by_sigma = scale_by_sigma
|
||||
self.norm_by_sigma = norm_by_sigma
|
||||
self.device = device
|
||||
self.no_torsion = no_torsion
|
||||
self.smooth_edges = smooth_edges
|
||||
self.odd_parity = odd_parity
|
||||
self.timestep_emb_func = timestep_emb_func
|
||||
self.separate_noise_schedule = separate_noise_schedule
|
||||
self.confidence_mode = confidence_mode
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.num_prot_emb_layers = num_prot_emb_layers
|
||||
self.asyncronous_noise_schedule = asyncronous_noise_schedule
|
||||
self.affinity_prediction = affinity_prediction
|
||||
self.fixed_center_conv = fixed_center_conv
|
||||
self.no_aminoacid_identities = no_aminoacid_identities
|
||||
self.differentiate_convolutions = differentiate_convolutions
|
||||
self.reduce_pseudoscalars = reduce_pseudoscalars
|
||||
self.atom_confidence = atom_confidence
|
||||
self.atom_num_confidence_outputs = atom_num_confidence_outputs
|
||||
self.sidechain_pred = sidechain_pred
|
||||
|
||||
self.lm_embedding_type = lm_embedding_type
|
||||
if lm_embedding_type is None:
|
||||
lm_embedding_dim = 0
|
||||
elif lm_embedding_type == "precomputed":
|
||||
lm_embedding_dim=1280
|
||||
else:
|
||||
lm, alphabet = load_model_and_alphabet(lm_embedding_type)
|
||||
self.batch_converter = alphabet.get_batch_converter()
|
||||
lm.lm_head = torch.nn.Identity()
|
||||
lm.contact_head = torch.nn.Identity()
|
||||
lm_embedding_dim = lm.embed_dim
|
||||
self.lm = lm
|
||||
|
||||
atom_encoder_class = AtomEncoder
|
||||
|
||||
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=0, lm_embedding_dim=lm_embedding_dim)
|
||||
self.rec_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.rec_sigma_embedding = nn.Sequential(nn.Linear(sigma_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
|
||||
if self.include_miscellaneous_atoms:
|
||||
self.misc_atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.misc_atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.ar_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
|
||||
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
|
||||
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
|
||||
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
|
||||
|
||||
irrep_seq = get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars)
|
||||
|
||||
assert not self.include_miscellaneous_atoms, "currently not supported"
|
||||
|
||||
rec_emb_layers = []
|
||||
for i in range(num_prot_emb_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1,
|
||||
depthwise=depthwise_convolution
|
||||
)
|
||||
rec_emb_layers.append(layer)
|
||||
self.rec_emb_layers = nn.ModuleList(rec_emb_layers)
|
||||
|
||||
self.embed_also_ligand = embed_also_ligand
|
||||
if embed_also_ligand:
|
||||
lig_emb_layers = []
|
||||
for i in range(num_prot_emb_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1,
|
||||
depthwise=depthwise_convolution
|
||||
)
|
||||
lig_emb_layers.append(layer)
|
||||
self.lig_emb_layers = nn.ModuleList(lig_emb_layers)
|
||||
|
||||
conv_layers = []
|
||||
for i in range(num_prot_emb_layers, num_prot_emb_layers + num_conv_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
layer = TensorProductConvLayer(
|
||||
in_irreps=in_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=out_irreps,
|
||||
n_edge_features=3 * ns,
|
||||
hidden_features=3 * ns,
|
||||
residual=True,
|
||||
batch_norm=batch_norm,
|
||||
dropout=dropout,
|
||||
faster=sh_lmax == 1 and not use_second_order_repr,
|
||||
tp_weights_layers=tp_weights_layers,
|
||||
edge_groups=1 if not differentiate_convolutions else (2 if i == num_prot_emb_layers + num_conv_layers - 1 else 4),
|
||||
depthwise=depthwise_convolution
|
||||
)
|
||||
conv_layers.append(layer)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
|
||||
if sidechain_pred:
|
||||
self.sidechain_predictor = Linear(
|
||||
irreps_in=irrep_seq[min(num_prot_emb_layers + num_conv_layers, len(irrep_seq) - 1)],
|
||||
irreps_out='4x0e + 2x1e + 4x0o + 2x1o',
|
||||
internal_weights=True,
|
||||
shared_weights=True,
|
||||
)
|
||||
|
||||
if self.confidence_mode:
|
||||
input_size = ns + (nv if reduce_pseudoscalars else ns) if num_conv_layers + num_prot_emb_layers >= 3 else ns
|
||||
|
||||
if self.atom_confidence:
|
||||
self.atom_confidence_predictor = nn.Sequential(
|
||||
nn.Linear(input_size, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, atom_num_confidence_outputs + ns)
|
||||
)
|
||||
input_size = ns
|
||||
|
||||
self.confidence_predictor = nn.Sequential(
|
||||
nn.Linear(input_size, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, num_confidence_outputs + (1 if self.affinity_prediction else 0))
|
||||
)
|
||||
|
||||
else:
|
||||
# center of mass translation and rotation components
|
||||
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
|
||||
self.center_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
|
||||
self.final_conv = TensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
|
||||
n_edge_features=2 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
|
||||
if not no_torsion:
|
||||
# torsion angles components
|
||||
self.final_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
|
||||
self.tor_bond_conv = TensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.final_tp_tor.irreps_out,
|
||||
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
|
||||
n_edge_features=3 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tor_final_layer = nn.Sequential(
|
||||
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, 1, bias=False)
|
||||
)
|
||||
|
||||
def ligand_embedding(self, data):
|
||||
# ligand embedding
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
|
||||
lig_node_attr = self.lig_node_embedding(lig_node_attr)
|
||||
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
|
||||
|
||||
assert self.embed_also_ligand, "otherwise reimplement padding"
|
||||
for l in range(len(self.lig_emb_layers)):
|
||||
edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns],
|
||||
lig_node_attr[lig_edge_index[1], :self.ns]], -1)
|
||||
lig_node_attr = self.lig_emb_layers[l](lig_node_attr, lig_edge_index, edge_attr_, lig_edge_sh,
|
||||
edge_weight=lig_edge_weight)
|
||||
|
||||
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight
|
||||
|
||||
def embedding(self, data):
|
||||
if not hasattr(data['receptor'], "rec_node_attr"):
|
||||
if self.lm_embedding_type not in [None, 'precomputed']:
|
||||
sequences = [s for l in data['receptor'].sequence for s in l]
|
||||
if isinstance(sequences[0], list):
|
||||
sequences = [s for l in sequences for s in l]
|
||||
sequences = [(i, s) for i, s in enumerate(sequences)]
|
||||
batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
|
||||
out = self.lm(batch_tokens.to(data['receptor'].x.device), repr_layers=[self.lm.num_layers], return_contacts=False)
|
||||
rec_lm_emb = torch.cat([t[:len(sequences[i][1])] for i, t in enumerate(out['representations'][self.lm.num_layers])], dim=0)
|
||||
data['receptor'].x = torch.cat([data['receptor'].x, rec_lm_emb], dim=-1)
|
||||
|
||||
rec_node_attr, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
|
||||
rec_node_attr = self.rec_node_embedding(rec_node_attr)
|
||||
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
|
||||
|
||||
for l in range(len(self.rec_emb_layers)):
|
||||
edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[data['receptor', 'receptor'].edge_index[0], :self.ns], rec_node_attr[data['receptor', 'receptor'].edge_index[1], :self.ns]], -1)
|
||||
rec_node_attr = self.rec_emb_layers[l](rec_node_attr, data['receptor', 'receptor'].edge_index, edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
|
||||
|
||||
data['receptor'].rec_node_attr = rec_node_attr
|
||||
data['receptor', 'receptor'].rec_edge_attr = rec_edge_attr
|
||||
data['receptor', 'receptor'].edge_sh = rec_edge_sh
|
||||
data['receptor', 'receptor'].edge_weight = rec_edge_weight
|
||||
|
||||
# receptor embedding
|
||||
rec_sigma_emb = self.rec_sigma_embedding(self.timestep_emb_func(data.complex_t['tr']))
|
||||
rec_node_attr = data['receptor'].rec_node_attr + 0
|
||||
rec_node_attr[:, :self.ns] = rec_node_attr[:, :self.ns] + rec_sigma_emb[data['receptor'].batch]
|
||||
rec_edge_attr = data['receptor', 'receptor'].rec_edge_attr + rec_sigma_emb[data['receptor'].batch[data['receptor', 'receptor'].edge_index[0]]]
|
||||
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.ligand_embedding(data)
|
||||
|
||||
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, \
|
||||
rec_node_attr, data['receptor', 'receptor'].edge_index, rec_edge_attr, data['receptor', 'receptor'].edge_sh, data['receptor', 'receptor'].edge_weight
|
||||
|
||||
def forward(self, data):
|
||||
if self.no_aminoacid_identities:
|
||||
data['receptor'].x = data['receptor'].x * 0
|
||||
|
||||
if not self.confidence_mode:
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
|
||||
else:
|
||||
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
|
||||
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, rec_node_attr, \
|
||||
rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.embedding(data)
|
||||
|
||||
# build cross graph
|
||||
if self.dynamic_max_cross:
|
||||
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
|
||||
else:
|
||||
cross_cutoff = self.cross_max_distance
|
||||
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, rev_lr_edge_sh, lr_edge_weight = self.build_cross_conv_graph(data, cross_cutoff)
|
||||
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
|
||||
|
||||
node_attr = torch.cat([lig_node_attr, rec_node_attr], dim=0)
|
||||
lr_edge_index[1] = lr_edge_index[1] + len(lig_node_attr)
|
||||
edge_index = torch.cat([lig_edge_index, lr_edge_index, rec_edge_index + len(lig_node_attr),
|
||||
torch.flip(lr_edge_index, dims=[0])], dim=1)
|
||||
edge_attr = torch.cat([lig_edge_attr, lr_edge_attr, rec_edge_attr, lr_edge_attr], dim=0)
|
||||
edge_sh = torch.cat([lig_edge_sh, lr_edge_sh, rec_edge_sh, rev_lr_edge_sh], dim=0)
|
||||
edge_weight = torch.cat([lig_edge_weight, lr_edge_weight, rec_edge_weight, lr_edge_weight],
|
||||
dim=0) if torch.is_tensor(lig_edge_weight) else torch.ones((len(edge_index[0]), 1),
|
||||
device=edge_index.device)
|
||||
s1, s2, s3 = len(lig_edge_index[0]), len(lig_edge_index[0]) + len(lr_edge_index[0]), len(lig_edge_index[0]) + len(lr_edge_index[0]) + len(rec_edge_index[0])
|
||||
|
||||
for l in range(len(self.conv_layers)):
|
||||
if l < len(self.conv_layers) - 1:
|
||||
edge_attr_ = torch.cat(
|
||||
[edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
|
||||
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:]]
|
||||
node_attr = self.conv_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
|
||||
else:
|
||||
edge_attr_ = torch.cat([edge_attr[:s2], node_attr[edge_index[0, :s2], :self.ns], node_attr[edge_index[1, :s2], :self.ns]], -1)
|
||||
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2]]
|
||||
node_attr = self.conv_layers[l](node_attr, edge_index[:, :s2], edge_attr_, edge_sh[:s2], edge_weight=edge_weight[:s2])
|
||||
|
||||
lig_node_attr = node_attr[:len(lig_node_attr)]
|
||||
|
||||
# compute confidence score
|
||||
if self.confidence_mode:
|
||||
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns], lig_node_attr[:,-(self.nv if self.reduce_pseudoscalars else self.ns):] ], dim=1) \
|
||||
if self.num_conv_layers + self.num_prot_emb_layers >= 3 else lig_node_attr[:,:self.ns]
|
||||
|
||||
if self.atom_confidence:
|
||||
scalar_lig_attr = self.atom_confidence_predictor(scalar_lig_attr)
|
||||
atom_confidence = scalar_lig_attr[:, :self.atom_num_confidence_outputs]
|
||||
scalar_lig_attr = scalar_lig_attr[:, self.atom_num_confidence_outputs:]
|
||||
else:
|
||||
atom_confidence = torch.zeros((len(lig_node_attr),), device=lig_node_attr.device)
|
||||
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
|
||||
return confidence, atom_confidence
|
||||
|
||||
# compute translational and rotational score vectors
|
||||
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
|
||||
center_edge_attr = self.center_edge_embedding(center_edge_attr)
|
||||
if self.fixed_center_conv:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
else:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
|
||||
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
|
||||
|
||||
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
|
||||
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
|
||||
|
||||
if self.separate_noise_schedule:
|
||||
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
|
||||
else: # tr rot and tor noise is all the same in this case
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# fix the magnitude of translational and rotational score vectors
|
||||
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
|
||||
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
|
||||
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
|
||||
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
|
||||
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
|
||||
|
||||
# predict sidechain orientation
|
||||
sidechain_pred = None
|
||||
if self.sidechain_pred:
|
||||
rec_node_attr = node_attr[len(lig_node_attr):]
|
||||
sidechain_pred = self.sidechain_predictor(rec_node_attr)
|
||||
sidechain_pred = sidechain_pred[:, :10] + sidechain_pred[:, 10:] # sum even and odd components
|
||||
|
||||
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device), sidechain_pred
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
|
||||
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
|
||||
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
|
||||
.to(data['ligand'].x.device))
|
||||
return tr_pred, rot_pred, tor_pred, sidechain_pred
|
||||
|
||||
def torsional_forward(self, data):
|
||||
tor_sigma = self.t_to_sigma(data.complex_t['tor'])
|
||||
|
||||
# build ligand graph
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.ligand_embedding(data)
|
||||
|
||||
if self.separate_noise_schedule:
|
||||
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
|
||||
else: # tr rot and tor noise is all the same in this case
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
|
||||
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
|
||||
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
|
||||
.to(data['ligand'].x.device))
|
||||
return 0, 0, tor_pred, 0
|
||||
|
||||
def get_edge_weight(self, edge_vec, max_norm):
|
||||
# computes weights for edges that are decreasing with the distance
|
||||
# it has an effect only if smooth edges is true
|
||||
if self.smooth_edges:
|
||||
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
|
||||
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
|
||||
return 1.0
|
||||
|
||||
def build_lig_conv_graph(self, data):
|
||||
# builds the ligand graph edges and initial node and edge features
|
||||
if self.separate_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
|
||||
else:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
|
||||
# compute edges
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
|
||||
edge_attr = torch.cat([
|
||||
data['ligand', 'ligand'].edge_attr,
|
||||
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
|
||||
], 0)
|
||||
|
||||
# compute initial features
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_rec_conv_graph(self, data):
|
||||
# builds the receptor initial node and edge embeddings
|
||||
assert not self.separate_noise_schedule or self.asyncronous_noise_schedule, "removed support in this function"
|
||||
node_attr = data['receptor'].x
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['receptor', 'receptor'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_attr = edge_length_emb
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
|
||||
|
||||
return node_attr, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_misc_atom_conv_graph(self, data):
|
||||
# build the graph between receptor misc_atoms
|
||||
if self.separate_noise_schedule:
|
||||
data['misc_atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['misc_atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['t'])
|
||||
else:
|
||||
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['misc_atom'].x, data['misc_atom'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['misc_atom', 'misc_atom'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['misc_atom'].pos[dst.long()] - data['misc_atom'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['misc_atom'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_cross_conv_graph(self, data, cross_distance_cutoff):
|
||||
# builds the cross edges between ligand and receptor
|
||||
if torch.is_tensor(cross_distance_cutoff):
|
||||
# different cutoff for every graph (depends on the diffusion time)
|
||||
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
rev_edge_sh = o3.spherical_harmonics(self.sh_irreps, -edge_vec, normalize=True, normalization='component')
|
||||
|
||||
cutoff_d = cross_distance_cutoff[data['ligand'].batch[src]].squeeze() if torch.is_tensor(cross_distance_cutoff) else cross_distance_cutoff
|
||||
edge_weight = self.get_edge_weight(edge_vec, cutoff_d)
|
||||
|
||||
return edge_index, edge_attr, edge_sh, rev_edge_sh, edge_weight
|
||||
|
||||
def build_misc_cross_conv_graph(self, data, lr_cross_distance_cutoff):
|
||||
# build the cross edges between ligan atoms, receptor residues and receptor atoms
|
||||
|
||||
# LIGAND to RECEPTOR
|
||||
if torch.is_tensor(lr_cross_distance_cutoff):
|
||||
# different cutoff for every graph
|
||||
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
|
||||
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
|
||||
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
|
||||
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
|
||||
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
|
||||
|
||||
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
|
||||
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
|
||||
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
|
||||
|
||||
# LIGAND to ATOM
|
||||
la_edge_index = radius(data['misc_atom'].pos, data['ligand'].pos, self.lig_max_radius,
|
||||
data['misc_atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
la_edge_vec = data['misc_atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
|
||||
la_edge_length_emb = self.cross_distance_expansion(la_edge_vec.norm(dim=-1))
|
||||
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
|
||||
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
|
||||
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
|
||||
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
|
||||
|
||||
# ATOM to RECEPTOR
|
||||
ar_edge_index = data['misc_atom', 'receptor'].edge_index
|
||||
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['misc_atom'].pos[ar_edge_index[0].long()]
|
||||
ar_edge_length_emb = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
|
||||
ar_edge_sigma_emb = data['misc_atom'].node_sigma_emb[ar_edge_index[0].long()]
|
||||
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
|
||||
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
|
||||
ar_edge_weight = 1
|
||||
|
||||
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
|
||||
|
||||
def build_center_conv_graph(self, data):
|
||||
# builds the filter and edges for the convolution generating translational and rotational scores
|
||||
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
|
||||
|
||||
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
|
||||
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
|
||||
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
|
||||
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
return edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_bond_conv_graph(self, data):
|
||||
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
|
||||
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
|
||||
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
|
||||
bond_batch = data['ligand'].batch[bonds[0]]
|
||||
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
|
||||
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = self.final_edge_embedding(edge_attr)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return bonds, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
94
models/layers.py
Normal file
94
models/layers.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
ACTIVATIONS = {
|
||||
'relu': nn.ReLU,
|
||||
'silu': nn.SiLU
|
||||
}
|
||||
|
||||
def FCBlock(in_dim, hidden_dim, out_dim, layers, dropout, activation='relu'):
|
||||
activation = ACTIVATIONS[activation]
|
||||
assert layers >= 2
|
||||
sequential = [nn.Linear(in_dim, hidden_dim), activation(), nn.Dropout(dropout)]
|
||||
for i in range(layers - 2):
|
||||
sequential += [nn.Linear(hidden_dim, hidden_dim), activation(), nn.Dropout(dropout)]
|
||||
sequential += [nn.Linear(hidden_dim, out_dim)]
|
||||
return nn.Sequential(*sequential)
|
||||
|
||||
|
||||
class GaussianSmearing(torch.nn.Module):
|
||||
# used to embed the edge distances
|
||||
def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
|
||||
super().__init__()
|
||||
offset = torch.linspace(start, stop, num_gaussians)
|
||||
self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
|
||||
self.register_buffer('offset', offset)
|
||||
|
||||
def forward(self, dist):
|
||||
dist = dist.view(-1, 1) - self.offset.view(1, -1)
|
||||
return torch.exp(self.coeff * torch.pow(dist, 2))
|
||||
|
||||
|
||||
|
||||
class AtomEncoder(torch.nn.Module):
|
||||
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_dim=0):
|
||||
# first element of feature_dims tuple is a list with the lenght of each categorical feature and the second is the number of scalar features
|
||||
super(AtomEncoder, self).__init__()
|
||||
self.atom_embedding_list = torch.nn.ModuleList()
|
||||
self.num_categorical_features = len(feature_dims[0])
|
||||
self.additional_features_dim = feature_dims[1] + sigma_embed_dim + lm_embedding_dim
|
||||
for i, dim in enumerate(feature_dims[0]):
|
||||
emb = torch.nn.Embedding(dim, emb_dim)
|
||||
torch.nn.init.xavier_uniform_(emb.weight.data)
|
||||
self.atom_embedding_list.append(emb)
|
||||
|
||||
if self.additional_features_dim > 0:
|
||||
self.additional_features_embedder = torch.nn.Linear(self.additional_features_dim + emb_dim, emb_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x_embedding = 0
|
||||
assert x.shape[1] == self.num_categorical_features + self.additional_features_dim
|
||||
for i in range(self.num_categorical_features):
|
||||
x_embedding += self.atom_embedding_list[i](x[:, i].long())
|
||||
|
||||
if self.additional_features_dim > 0:
|
||||
x_embedding = self.additional_features_embedder(torch.cat([x_embedding, x[:, self.num_categorical_features:]], axis=1))
|
||||
return x_embedding
|
||||
|
||||
|
||||
class OldAtomEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type= None):
|
||||
# first element of feature_dims tuple is a list with the lenght of each categorical feature and the second is the number of scalar features
|
||||
super(OldAtomEncoder, self).__init__()
|
||||
self.atom_embedding_list = torch.nn.ModuleList()
|
||||
self.num_categorical_features = len(feature_dims[0])
|
||||
self.num_scalar_features = feature_dims[1] + sigma_embed_dim
|
||||
self.lm_embedding_type = lm_embedding_type
|
||||
for i, dim in enumerate(feature_dims[0]):
|
||||
emb = torch.nn.Embedding(dim, emb_dim)
|
||||
torch.nn.init.xavier_uniform_(emb.weight.data)
|
||||
self.atom_embedding_list.append(emb)
|
||||
|
||||
if self.num_scalar_features > 0:
|
||||
self.linear = torch.nn.Linear(self.num_scalar_features, emb_dim)
|
||||
if self.lm_embedding_type is not None:
|
||||
if self.lm_embedding_type == 'esm':
|
||||
self.lm_embedding_dim = 1280
|
||||
else: raise ValueError('LM Embedding type was not correctly determined. LM embedding type: ', self.lm_embedding_type)
|
||||
self.lm_embedding_layer = torch.nn.Linear(self.lm_embedding_dim + emb_dim, emb_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x_embedding = 0
|
||||
if self.lm_embedding_type is not None:
|
||||
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features + self.lm_embedding_dim
|
||||
else:
|
||||
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features
|
||||
for i in range(self.num_categorical_features):
|
||||
x_embedding += self.atom_embedding_list[i](x[:, i].long())
|
||||
|
||||
if self.num_scalar_features > 0:
|
||||
x_embedding += self.linear(x[:, self.num_categorical_features:self.num_categorical_features + self.num_scalar_features])
|
||||
if self.lm_embedding_type is not None:
|
||||
x_embedding = self.lm_embedding_layer(torch.cat([x_embedding, x[:, -self.lm_embedding_dim:]], axis=1))
|
||||
return x_embedding
|
||||
@@ -3,24 +3,38 @@ import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch_cluster import radius, radius_graph
|
||||
from torch_scatter import scatter_mean
|
||||
from torch_scatter import scatter, scatter_mean
|
||||
import numpy as np
|
||||
from e3nn.nn import BatchNorm
|
||||
|
||||
from models.score_model import AtomEncoder, TensorProductConvLayer, GaussianSmearing
|
||||
from models.layers import GaussianSmearing, OldAtomEncoder, AtomEncoder
|
||||
from models.tensor_layers import OldTensorProductConvLayer
|
||||
from utils import so3, torus
|
||||
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
|
||||
|
||||
AGGREGATORS = {"mean": lambda x: torch.mean(x, dim=1),
|
||||
"max": lambda x: torch.max(x, dim=1)[0],
|
||||
"min": lambda x: torch.min(x, dim=1)[0],
|
||||
"std": lambda x: torch.std(x, dim=1)}
|
||||
|
||||
class TensorProductScoreModel(torch.nn.Module):
|
||||
|
||||
class AAOldModel(torch.nn.Module):
|
||||
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
|
||||
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
|
||||
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
|
||||
scale_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, lm_embedding_type=False, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm=False, num_confidence_outputs=1):
|
||||
super(TensorProductScoreModel, self).__init__()
|
||||
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
|
||||
separate_noise_schedule=False, lm_embedding_type=False, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm = False,
|
||||
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
|
||||
parallel_aggregators="mean max min std", num_confidence_outputs=1, fixed_center_conv=False,
|
||||
no_aminoacid_identities=False, include_miscellaneous_atoms=False, use_old_atom_encoder=False):
|
||||
super(AAOldModel, self).__init__()
|
||||
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
|
||||
if parallel > 1: assert affinity_prediction
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.in_lig_edge_features = in_lig_edge_features
|
||||
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
|
||||
self.sigma_embed_dim = sigma_embed_dim
|
||||
self.lig_max_radius = lig_max_radius
|
||||
self.rec_max_radius = rec_max_radius
|
||||
@@ -32,21 +46,31 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
|
||||
self.ns, self.nv = ns, nv
|
||||
self.scale_by_sigma = scale_by_sigma
|
||||
self.norm_by_sigma = norm_by_sigma
|
||||
self.device = device
|
||||
self.no_torsion = no_torsion
|
||||
self.smooth_edges = smooth_edges
|
||||
self.odd_parity = odd_parity
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.timestep_emb_func = timestep_emb_func
|
||||
self.separate_noise_schedule = separate_noise_schedule
|
||||
self.confidence_mode = confidence_mode
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.asyncronous_noise_schedule = asyncronous_noise_schedule
|
||||
self.affinity_prediction = affinity_prediction
|
||||
self.parallel, self.parallel_aggregators = parallel, parallel_aggregators.split(' ')
|
||||
self.fixed_center_conv = fixed_center_conv
|
||||
self.no_aminoacid_identities = no_aminoacid_identities
|
||||
|
||||
# embedding layers
|
||||
self.lig_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
atom_encoder_class = OldAtomEncoder if use_old_atom_encoder else AtomEncoder
|
||||
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.rec_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
|
||||
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
|
||||
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.atom_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.lr_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
@@ -88,13 +112,19 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
}
|
||||
|
||||
for _ in range(9): # 3 intra & 6 inter per each layer
|
||||
conv_layers.append(TensorProductConvLayer(**parameters))
|
||||
conv_layers.append(OldTensorProductConvLayer(**parameters))
|
||||
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
|
||||
# confidence and affinity prediction layers
|
||||
if self.confidence_mode:
|
||||
output_confidence_dim = num_confidence_outputs
|
||||
if self.affinity_prediction:
|
||||
if self.parallel > 1:
|
||||
output_confidence_dim = 1 + ns
|
||||
else:
|
||||
output_confidence_dim = num_confidence_outputs +1
|
||||
else:
|
||||
output_confidence_dim = num_confidence_outputs
|
||||
|
||||
self.confidence_predictor = nn.Sequential(
|
||||
nn.Linear(2 * self.ns if num_conv_layers >= 3 else self.ns, ns),
|
||||
@@ -108,6 +138,19 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
nn.Linear(ns, output_confidence_dim)
|
||||
)
|
||||
|
||||
if self.parallel > 1:
|
||||
self.affinity_predictor = nn.Sequential(
|
||||
nn.Linear(len(self.parallel_aggregators) * ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, 1)
|
||||
)
|
||||
|
||||
else:
|
||||
# convolution for translational and rotational scores
|
||||
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
|
||||
@@ -118,18 +161,18 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
|
||||
self.final_conv = TensorProductConvLayer(
|
||||
self.final_conv = OldTensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=f'2x1o + 2x1e',
|
||||
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
|
||||
n_edge_features=2 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns), nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns), nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
|
||||
if not no_torsion:
|
||||
# convolution for torsional score
|
||||
@@ -140,47 +183,51 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
|
||||
self.tor_bond_conv = TensorProductConvLayer(
|
||||
self.tor_bond_conv = OldTensorProductConvLayer(
|
||||
in_irreps=self.conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.final_tp_tor.irreps_out,
|
||||
out_irreps=f'{ns}x0o + {ns}x0e',
|
||||
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
|
||||
n_edge_features=3 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tor_final_layer = nn.Sequential(
|
||||
nn.Linear(2 * ns, ns, bias=False),
|
||||
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, data):
|
||||
if self.no_aminoacid_identities:
|
||||
data['receptor'].x = data['receptor'].x * 0
|
||||
|
||||
if not self.confidence_mode:
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
|
||||
else:
|
||||
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
|
||||
|
||||
# build ligand graph
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh = self.build_lig_conv_graph(data)
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
|
||||
lig_node_attr = self.lig_node_embedding(lig_node_attr)
|
||||
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
|
||||
|
||||
# build receptor graph
|
||||
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh = self.build_rec_conv_graph(data)
|
||||
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
|
||||
rec_node_attr = self.rec_node_embedding(rec_node_attr)
|
||||
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
|
||||
|
||||
# build atom graph
|
||||
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh = self.build_atom_conv_graph(data)
|
||||
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_atom_conv_graph(data)
|
||||
atom_node_attr = self.atom_node_embedding(atom_node_attr)
|
||||
atom_edge_attr = self.atom_edge_embedding(atom_edge_attr)
|
||||
|
||||
# build cross graph
|
||||
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1) if self.dynamic_max_cross else self.cross_max_distance
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, ar_edge_index, ar_edge_attr, ar_edge_sh = self.build_cross_conv_graph(data, cross_cutoff)
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = \
|
||||
self.build_cross_conv_graph(data, cross_cutoff)
|
||||
lr_edge_attr= self.lr_edge_embedding(lr_edge_attr)
|
||||
la_edge_attr = self.la_edge_embedding(la_edge_attr)
|
||||
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
|
||||
@@ -188,47 +235,47 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
for l in range(self.num_conv_layers):
|
||||
# LIGAND updates
|
||||
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns], lig_node_attr[lig_edge_index[1], :self.ns]], -1)
|
||||
lig_update = self.conv_layers[9*l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh)
|
||||
lig_update = self.conv_layers[9*l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
|
||||
|
||||
lr_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[lr_edge_index[0], :self.ns], rec_node_attr[lr_edge_index[1], :self.ns]], -1)
|
||||
lr_update = self.conv_layers[9*l+1](rec_node_attr, lr_edge_index, lr_edge_attr_, lr_edge_sh,
|
||||
out_nodes=lig_node_attr.shape[0])
|
||||
out_nodes=lig_node_attr.shape[0], edge_weight=lr_edge_weight)
|
||||
|
||||
la_edge_attr_ = torch.cat([la_edge_attr, lig_node_attr[la_edge_index[0], :self.ns], atom_node_attr[la_edge_index[1], :self.ns]], -1)
|
||||
la_update = self.conv_layers[9*l+2](atom_node_attr, la_edge_index, la_edge_attr_, la_edge_sh,
|
||||
out_nodes=lig_node_attr.shape[0])
|
||||
out_nodes=lig_node_attr.shape[0], edge_weight=la_edge_weight)
|
||||
|
||||
if l != self.num_conv_layers-1: # last layer optimisation
|
||||
|
||||
# ATOM UPDATES
|
||||
atom_edge_attr_ = torch.cat([atom_edge_attr, atom_node_attr[atom_edge_index[0], :self.ns], atom_node_attr[atom_edge_index[1], :self.ns]], -1)
|
||||
atom_update = self.conv_layers[9*l+3](atom_node_attr, atom_edge_index, atom_edge_attr_, atom_edge_sh)
|
||||
atom_update = self.conv_layers[9*l+3](atom_node_attr, atom_edge_index, atom_edge_attr_, atom_edge_sh, edge_weight=atom_edge_weight)
|
||||
|
||||
al_edge_attr_ = torch.cat([la_edge_attr, atom_node_attr[la_edge_index[1], :self.ns], lig_node_attr[la_edge_index[0], :self.ns]], -1)
|
||||
al_update = self.conv_layers[9*l+4](lig_node_attr, torch.flip(la_edge_index, dims=[0]), al_edge_attr_,
|
||||
la_edge_sh, out_nodes=atom_node_attr.shape[0])
|
||||
la_edge_sh, out_nodes=atom_node_attr.shape[0], edge_weight=la_edge_weight)
|
||||
|
||||
ar_edge_attr_ = torch.cat([ar_edge_attr, atom_node_attr[ar_edge_index[0], :self.ns], rec_node_attr[ar_edge_index[1], :self.ns]],-1)
|
||||
ar_update = self.conv_layers[9*l+5](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh, out_nodes=atom_node_attr.shape[0])
|
||||
ar_update = self.conv_layers[9*l+5](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh, out_nodes=atom_node_attr.shape[0], edge_weight=ar_edge_weight)
|
||||
|
||||
# RECEPTOR updates
|
||||
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_edge_index[0], :self.ns], rec_node_attr[rec_edge_index[1], :self.ns]], -1)
|
||||
rec_update = self.conv_layers[9*l+6](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh)
|
||||
rec_update = self.conv_layers[9*l+6](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
|
||||
|
||||
rl_edge_attr_ = torch.cat([lr_edge_attr, rec_node_attr[lr_edge_index[1], :self.ns], lig_node_attr[lr_edge_index[0], :self.ns]], -1)
|
||||
rl_update = self.conv_layers[9*l+7](lig_node_attr, torch.flip(lr_edge_index, dims=[0]), rl_edge_attr_,
|
||||
lr_edge_sh, out_nodes=rec_node_attr.shape[0])
|
||||
lr_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=lr_edge_weight)
|
||||
|
||||
ra_edge_attr_ = torch.cat([ar_edge_attr, rec_node_attr[ar_edge_index[1], :self.ns], atom_node_attr[ar_edge_index[0], :self.ns]], -1)
|
||||
ra_update = self.conv_layers[9*l+8](atom_node_attr, torch.flip(ar_edge_index, dims=[0]), ra_edge_attr_,
|
||||
ar_edge_sh, out_nodes=rec_node_attr.shape[0])
|
||||
ar_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=ar_edge_weight)
|
||||
|
||||
# padding original features and update features with residual updates
|
||||
lig_node_attr = F.pad(lig_node_attr, (0, lig_update.shape[-1] - lig_node_attr.shape[-1]))
|
||||
lig_node_attr = lig_node_attr + lig_update + la_update + lr_update
|
||||
|
||||
if l != self.num_conv_layers - 1: # last layer optimisation
|
||||
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - rec_node_attr.shape[-1]))
|
||||
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - atom_node_attr.shape[-1]))
|
||||
atom_node_attr = atom_node_attr + atom_update + al_update + ar_update
|
||||
rec_node_attr = F.pad(rec_node_attr, (0, rec_update.shape[-1] - rec_node_attr.shape[-1]))
|
||||
rec_node_attr = rec_node_attr + rec_update + ra_update + rl_update
|
||||
@@ -236,18 +283,36 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
# confidence and affinity prediction
|
||||
if self.confidence_mode:
|
||||
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:]], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch if self.parallel == 1 else data['ligand'].batch_parallel, dim=0)).squeeze(dim=-1)
|
||||
|
||||
if self.parallel > 1:
|
||||
confidence, affinity = confidence[:, 0], confidence[:, 1:]
|
||||
confidence = confidence.reshape(data.num_graphs, self.parallel)
|
||||
affinity = affinity.reshape(data.num_graphs, self.parallel, -1)
|
||||
affinity = torch.cat([AGGREGATORS[agg](affinity) for agg in self.parallel_aggregators], dim=-1)
|
||||
affinity = self.affinity_predictor(affinity).squeeze(dim=-1)
|
||||
confidence = confidence, affinity
|
||||
return confidence
|
||||
assert self.parallel == 1
|
||||
|
||||
# compute translational and rotational score vectors
|
||||
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
|
||||
center_edge_attr = self.center_edge_embedding(center_edge_attr)
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
if self.fixed_center_conv:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
else:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
|
||||
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
|
||||
|
||||
tr_pred = global_pred[:, :3] + global_pred[:, 6:9]
|
||||
rot_pred = global_pred[:, 3:6] + global_pred[:, 9:]
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
|
||||
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
|
||||
|
||||
if self.separate_noise_schedule:
|
||||
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
|
||||
else: # tr rot and tor noise is all the same in this case
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# adjust the magniture of the score vectors
|
||||
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
|
||||
@@ -263,7 +328,7 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0,device=self.device)
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh = self.build_bond_conv_graph(data)
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
@@ -273,7 +338,7 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean')
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
@@ -282,11 +347,34 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
.to(data['ligand'].x.device))
|
||||
return tr_pred, rot_pred, tor_pred
|
||||
|
||||
def get_edge_weight(self, edge_vec, max_norm):
|
||||
if self.smooth_edges:
|
||||
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
|
||||
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
|
||||
return 1.0
|
||||
|
||||
def build_lig_conv_graph(self, data):
|
||||
# build the graph between ligand atoms
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr'])
|
||||
if self.separate_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = torch.cat(
|
||||
[self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
|
||||
dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
|
||||
else:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(
|
||||
data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
if self.parallel == 1:
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
else:
|
||||
batches = torch.zeros(data.num_graphs, device=data['ligand'].x.device).long()
|
||||
batches = batches.index_add(0, data['ligand'].batch, torch.ones(len(data['ligand'].batch), device=data['ligand'].x.device).long())
|
||||
outer_batches = data.num_graphs
|
||||
b = [torch.ones(batches[i].item()//self.parallel, device=data['ligand'].x.device).long() * (self.parallel * i + j)
|
||||
for i in range(outer_batches) for j in range(self.parallel)]
|
||||
data['ligand'].batch_parallel = torch.cat(b)
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch_parallel)
|
||||
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
|
||||
edge_attr = torch.cat([
|
||||
data['ligand', 'ligand'].edge_attr,
|
||||
@@ -303,29 +391,45 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
|
||||
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_rec_conv_graph(self, data):
|
||||
# build the graph between receptor residues
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr'])
|
||||
if self.separate_noise_schedule:
|
||||
data['receptor'].node_sigma_emb = torch.cat(
|
||||
[self.timestep_emb_func(data['receptor'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
|
||||
dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['t'])
|
||||
else:
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(
|
||||
data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['receptor', 'receptor'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
|
||||
#assert torch.all(edge_vec.norm(dim=-1) < self.rec_max_radius)
|
||||
|
||||
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_atom_conv_graph(self, data):
|
||||
# build the graph between receptor atoms
|
||||
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['tr'])
|
||||
if self.separate_noise_schedule:
|
||||
data['atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['t'])
|
||||
else:
|
||||
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['atom'].x, data['atom'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
@@ -337,8 +441,9 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
edge_sigma_emb = data['atom'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_cross_conv_graph(self, data, lr_cross_distance_cutoff):
|
||||
# build the cross edges between ligan atoms, receptor residues and receptor atoms
|
||||
@@ -361,6 +466,7 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
|
||||
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
|
||||
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
|
||||
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
|
||||
|
||||
# LIGAND to ATOM
|
||||
la_edge_index = radius(data['atom'].pos, data['ligand'].pos, self.lig_max_radius,
|
||||
@@ -371,6 +477,7 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
|
||||
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
|
||||
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
|
||||
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
|
||||
|
||||
# ATOM to RECEPTOR
|
||||
ar_edge_index = data['atom', 'receptor'].edge_index
|
||||
@@ -379,9 +486,10 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
ar_edge_sigma_emb = data['atom'].node_sigma_emb[ar_edge_index[0].long()]
|
||||
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
|
||||
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
|
||||
ar_edge_weight = 1
|
||||
|
||||
return lr_edge_index, lr_edge_attr, lr_edge_sh, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, ar_edge_index, ar_edge_attr, ar_edge_sh
|
||||
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
|
||||
|
||||
def build_center_conv_graph(self, data):
|
||||
# build the filter for the convolution of the center with the ligand atoms
|
||||
@@ -411,5 +519,6 @@ class TensorProductScoreModel(torch.nn.Module):
|
||||
|
||||
edge_attr = self.final_edge_embedding(edge_attr)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return bonds, edge_index, edge_attr, edge_sh
|
||||
return bonds, edge_index, edge_attr, edge_sh, edge_weight
|
||||
538
models/old_cg_model.py
Normal file
538
models/old_cg_model.py
Normal file
@@ -0,0 +1,538 @@
|
||||
import math
|
||||
|
||||
from e3nn import o3
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch_cluster import radius, radius_graph
|
||||
from torch_scatter import scatter, scatter_mean
|
||||
import numpy as np
|
||||
from e3nn.nn import BatchNorm
|
||||
|
||||
from models.layers import OldAtomEncoder, AtomEncoder, GaussianSmearing
|
||||
from models.tensor_layers import OldTensorProductConvLayer
|
||||
from utils import so3, torus
|
||||
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
|
||||
|
||||
|
||||
class CGOldModel(torch.nn.Module):
|
||||
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
|
||||
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
|
||||
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
|
||||
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
|
||||
separate_noise_schedule=False, lm_embedding_type=None, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm=False,
|
||||
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
|
||||
parallel_aggregators="mean max min std", num_confidence_outputs=1, fixed_center_conv=False,
|
||||
no_aminoacid_identities=False, include_miscellaneous_atoms=False, use_old_atom_encoder=False):
|
||||
super(CGOldModel, self).__init__()
|
||||
assert parallel == 1, "not implemented"
|
||||
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.in_lig_edge_features = in_lig_edge_features
|
||||
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
|
||||
self.sigma_embed_dim = sigma_embed_dim
|
||||
self.lig_max_radius = lig_max_radius
|
||||
self.rec_max_radius = rec_max_radius
|
||||
self.include_miscellaneous_atoms = include_miscellaneous_atoms
|
||||
self.cross_max_distance = cross_max_distance
|
||||
self.dynamic_max_cross = dynamic_max_cross
|
||||
self.center_max_distance = center_max_distance
|
||||
self.distance_embed_dim = distance_embed_dim
|
||||
self.cross_distance_embed_dim = cross_distance_embed_dim
|
||||
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
|
||||
self.ns, self.nv = ns, nv
|
||||
self.scale_by_sigma = scale_by_sigma
|
||||
self.norm_by_sigma = norm_by_sigma
|
||||
self.device = device
|
||||
self.no_torsion = no_torsion
|
||||
self.smooth_edges = smooth_edges
|
||||
self.odd_parity = odd_parity
|
||||
self.timestep_emb_func = timestep_emb_func
|
||||
self.separate_noise_schedule = separate_noise_schedule
|
||||
self.confidence_mode = confidence_mode
|
||||
self.num_conv_layers = num_conv_layers
|
||||
self.asyncronous_noise_schedule = asyncronous_noise_schedule
|
||||
self.affinity_prediction = affinity_prediction
|
||||
self.fixed_center_conv = fixed_center_conv
|
||||
self.no_aminoacid_identities = no_aminoacid_identities
|
||||
|
||||
atom_encoder_class = OldAtomEncoder if use_old_atom_encoder else AtomEncoder
|
||||
|
||||
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
|
||||
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
if self.include_miscellaneous_atoms:
|
||||
self.misc_atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.misc_atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.ar_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
|
||||
|
||||
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
|
||||
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
|
||||
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
|
||||
|
||||
if use_second_order_repr:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {ns}x0o'
|
||||
]
|
||||
else:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o'
|
||||
]
|
||||
|
||||
lig_conv_layers, rec_conv_layers, lig_to_rec_conv_layers, rec_to_lig_conv_layers = [], [], [], []
|
||||
if self.include_miscellaneous_atoms:
|
||||
misc_conv_layers, la_conv_layers, ra_conv_layers, al_conv_layers, ar_conv_layers = [], [], [], [], []
|
||||
for i in range(num_conv_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
parameters = {
|
||||
'in_irreps': in_irreps,
|
||||
'sh_irreps': self.sh_irreps,
|
||||
'out_irreps': out_irreps,
|
||||
'n_edge_features': 3 * ns,
|
||||
'hidden_features': 3 * ns,
|
||||
'residual': False,
|
||||
'batch_norm': batch_norm,
|
||||
'dropout': dropout
|
||||
}
|
||||
|
||||
lig_layer = OldTensorProductConvLayer(**parameters)
|
||||
lig_conv_layers.append(lig_layer)
|
||||
rec_layer = OldTensorProductConvLayer(**parameters)
|
||||
rec_conv_layers.append(rec_layer)
|
||||
lig_to_rec_layer = OldTensorProductConvLayer(**parameters)
|
||||
lig_to_rec_conv_layers.append(lig_to_rec_layer)
|
||||
rec_to_lig_layer = OldTensorProductConvLayer(**parameters)
|
||||
rec_to_lig_conv_layers.append(rec_to_lig_layer)
|
||||
if self.include_miscellaneous_atoms:
|
||||
misc_conv_layer = OldTensorProductConvLayer(**parameters)
|
||||
la_conv_layer = OldTensorProductConvLayer(**parameters)
|
||||
ra_conv_layer = OldTensorProductConvLayer(**parameters)
|
||||
al_conv_layer = OldTensorProductConvLayer(**parameters)
|
||||
ar_conv_layer = OldTensorProductConvLayer(**parameters)
|
||||
misc_conv_layers.append(misc_conv_layer)
|
||||
la_conv_layers.append(la_conv_layer)
|
||||
ra_conv_layers.append(ra_conv_layer)
|
||||
al_conv_layers.append(al_conv_layer)
|
||||
ar_conv_layers.append(ar_conv_layer)
|
||||
|
||||
self.lig_conv_layers = nn.ModuleList(lig_conv_layers)
|
||||
self.rec_conv_layers = nn.ModuleList(rec_conv_layers)
|
||||
self.lig_to_rec_conv_layers = nn.ModuleList(lig_to_rec_conv_layers)
|
||||
self.rec_to_lig_conv_layers = nn.ModuleList(rec_to_lig_conv_layers)
|
||||
if self.include_miscellaneous_atoms:
|
||||
self.misc_conv_layers = nn.ModuleList(misc_conv_layers)
|
||||
self.la_conv_layers = nn.ModuleList(la_conv_layers)
|
||||
self.ra_conv_layers = nn.ModuleList(ra_conv_layers)
|
||||
self.al_conv_layers = nn.ModuleList(al_conv_layers)
|
||||
self.ar_conv_layers = nn.ModuleList(ar_conv_layers)
|
||||
|
||||
if self.confidence_mode:
|
||||
self.confidence_predictor = nn.Sequential(
|
||||
nn.Linear(2*self.ns if num_conv_layers >= 3 else self.ns,ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, 2 if self.affinity_prediction else 1)
|
||||
)
|
||||
else:
|
||||
# center of mass translation and rotation components
|
||||
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
|
||||
self.center_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
|
||||
self.final_conv = OldTensorProductConvLayer(
|
||||
in_irreps=self.lig_conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
|
||||
n_edge_features=2 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
|
||||
if not no_torsion:
|
||||
# torsion angles components
|
||||
self.final_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
|
||||
self.tor_bond_conv = OldTensorProductConvLayer(
|
||||
in_irreps=self.lig_conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.final_tp_tor.irreps_out,
|
||||
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
|
||||
n_edge_features=3 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tor_final_layer = nn.Sequential(
|
||||
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, data):
|
||||
if self.no_aminoacid_identities:
|
||||
data['receptor'].x = data['receptor'].x * 0
|
||||
|
||||
if not self.confidence_mode:
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
|
||||
else:
|
||||
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
|
||||
|
||||
# build ligand graph
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
|
||||
lig_src, lig_dst = lig_edge_index
|
||||
lig_node_attr = self.lig_node_embedding(lig_node_attr)
|
||||
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
|
||||
|
||||
# build receptor graph
|
||||
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
|
||||
rec_src, rec_dst = rec_edge_index
|
||||
rec_node_attr = self.rec_node_embedding(rec_node_attr)
|
||||
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
|
||||
if self.include_miscellaneous_atoms:
|
||||
# build misc_atom graph
|
||||
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_misc_atom_conv_graph(data)
|
||||
atom_node_attr = self.misc_atom_node_embedding(atom_node_attr)
|
||||
atom_edge_attr = self.misc_atom_edge_embedding(atom_edge_attr)
|
||||
|
||||
# build cross graph
|
||||
if self.dynamic_max_cross:
|
||||
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
|
||||
else:
|
||||
cross_cutoff = self.cross_max_distance
|
||||
if self.include_miscellaneous_atoms:
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = \
|
||||
self.build_misc_cross_conv_graph(data, cross_cutoff)
|
||||
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
|
||||
la_edge_attr = self.la_edge_embedding(la_edge_attr)
|
||||
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
|
||||
cross_lig, cross_rec = lr_edge_index
|
||||
else:
|
||||
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight = self.build_cross_conv_graph(data, cross_cutoff)
|
||||
cross_lig, cross_rec = lr_edge_index
|
||||
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
|
||||
|
||||
for l in range(len(self.lig_conv_layers)):
|
||||
# intra graph message passing
|
||||
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_src, :self.ns], lig_node_attr[lig_dst, :self.ns]], -1)
|
||||
lig_intra_update = self.lig_conv_layers[l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
|
||||
|
||||
# inter graph message passing
|
||||
rec_to_lig_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
|
||||
lig_inter_update = self.rec_to_lig_conv_layers[l](rec_node_attr, lr_edge_index, rec_to_lig_edge_attr_, lr_edge_sh,
|
||||
out_nodes=lig_node_attr.shape[0], edge_weight=lr_edge_weight)
|
||||
if self.include_miscellaneous_atoms:
|
||||
la_edge_attr_ = torch.cat([la_edge_attr, lig_node_attr[la_edge_index[0], :self.ns],atom_node_attr[la_edge_index[1], :self.ns]], -1)
|
||||
la_update = self.la_conv_layers[l](atom_node_attr, la_edge_index, la_edge_attr_, la_edge_sh,out_nodes=lig_node_attr.shape[0], edge_weight=la_edge_weight)
|
||||
|
||||
if l != len(self.lig_conv_layers) - 1:
|
||||
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_src, :self.ns], rec_node_attr[rec_dst, :self.ns]], -1)
|
||||
rec_intra_update = self.rec_conv_layers[l](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
|
||||
|
||||
lig_to_rec_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
|
||||
rl_update = self.lig_to_rec_conv_layers[l](lig_node_attr, torch.flip(lr_edge_index, dims=[0]),lig_to_rec_edge_attr_,lr_edge_sh, out_nodes=rec_node_attr.shape[0],edge_weight=lr_edge_weight)
|
||||
if self.include_miscellaneous_atoms:
|
||||
# ATOM UPDATES
|
||||
atom_edge_attr_ = torch.cat([atom_edge_attr, atom_node_attr[atom_edge_index[0], :self.ns],atom_node_attr[atom_edge_index[1], :self.ns]], -1)
|
||||
atom_update = self.misc_conv_layers[l](atom_node_attr, atom_edge_index, atom_edge_attr_,atom_edge_sh, edge_weight=atom_edge_weight)
|
||||
|
||||
al_edge_attr_ = torch.cat([la_edge_attr, atom_node_attr[la_edge_index[1], :self.ns],lig_node_attr[la_edge_index[0], :self.ns]], -1)
|
||||
al_update = self.al_conv_layers[l](lig_node_attr, torch.flip(la_edge_index, dims=[0]),al_edge_attr_,la_edge_sh, out_nodes=atom_node_attr.shape[0],edge_weight=la_edge_weight)
|
||||
|
||||
ar_edge_attr_ = torch.cat([ar_edge_attr, atom_node_attr[ar_edge_index[0], :self.ns],rec_node_attr[ar_edge_index[1], :self.ns]], -1)
|
||||
ar_update = self.ar_conv_layers[l](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh,out_nodes=atom_node_attr.shape[0],edge_weight=ar_edge_weight)
|
||||
|
||||
ra_edge_attr_ = torch.cat([ar_edge_attr, rec_node_attr[ar_edge_index[1], :self.ns],atom_node_attr[ar_edge_index[0], :self.ns]], -1)
|
||||
ra_update = self.ra_conv_layers[l](atom_node_attr, torch.flip(ar_edge_index, dims=[0]), ra_edge_attr_, ar_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=ar_edge_weight)
|
||||
|
||||
# padding original features
|
||||
lig_node_attr = F.pad(lig_node_attr, (0, lig_intra_update.shape[-1] - lig_node_attr.shape[-1]))
|
||||
|
||||
# update features with residual updates
|
||||
lig_node_attr = lig_node_attr + lig_intra_update + lig_inter_update
|
||||
if self.include_miscellaneous_atoms:
|
||||
lig_node_attr += la_update
|
||||
|
||||
if l != len(self.lig_conv_layers) - 1:
|
||||
rec_node_attr = F.pad(rec_node_attr, (0, rec_intra_update.shape[-1] - rec_node_attr.shape[-1]))
|
||||
rec_node_attr = rec_node_attr + rec_intra_update + rl_update
|
||||
if self.include_miscellaneous_atoms:
|
||||
rec_node_attr += ra_update
|
||||
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - atom_node_attr.shape[-1]))
|
||||
atom_node_attr = atom_node_attr + atom_update + al_update + ar_update
|
||||
|
||||
# compute confidence score
|
||||
if self.confidence_mode:
|
||||
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:] ], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
|
||||
return confidence
|
||||
|
||||
# compute translational and rotational score vectors
|
||||
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
|
||||
center_edge_attr = self.center_edge_embedding(center_edge_attr)
|
||||
if self.fixed_center_conv:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
else:
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
|
||||
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
|
||||
|
||||
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
|
||||
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
|
||||
|
||||
if self.separate_noise_schedule:
|
||||
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
|
||||
else: # tr rot and tor noise is all the same in this case
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# fix the magnitude of translational and rotational score vectors
|
||||
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
|
||||
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
|
||||
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
|
||||
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
|
||||
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
|
||||
|
||||
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device)
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
|
||||
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
|
||||
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
|
||||
.to(data['ligand'].x.device))
|
||||
return tr_pred, rot_pred, tor_pred
|
||||
|
||||
def get_edge_weight(self, edge_vec, max_norm):
|
||||
# computes weights for edges that are decreasing with the distance
|
||||
# it has an effect only if smooth edges is true
|
||||
if self.smooth_edges:
|
||||
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
|
||||
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
|
||||
return 1.0
|
||||
|
||||
def build_lig_conv_graph(self, data):
|
||||
# builds the ligand graph edges and initial node and edge features
|
||||
if self.separate_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
|
||||
else:
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
|
||||
# compute edges
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
|
||||
edge_attr = torch.cat([
|
||||
data['ligand', 'ligand'].edge_attr,
|
||||
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
|
||||
], 0)
|
||||
|
||||
# compute initial features
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_rec_conv_graph(self, data):
|
||||
# builds the receptor initial node and edge embeddings
|
||||
if self.separate_noise_schedule:
|
||||
data['receptor'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['receptor'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['t'])
|
||||
else:
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['receptor', 'receptor'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_misc_atom_conv_graph(self, data):
|
||||
# build the graph between receptor misc_atoms
|
||||
if self.separate_noise_schedule:
|
||||
data['misc_atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['misc_atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
|
||||
elif self.asyncronous_noise_schedule:
|
||||
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['t'])
|
||||
else:
|
||||
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['misc_atom'].x, data['misc_atom'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['misc_atom', 'misc_atom'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['misc_atom'].pos[dst.long()] - data['misc_atom'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['misc_atom'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_cross_conv_graph(self, data, cross_distance_cutoff):
|
||||
# builds the cross edges between ligand and receptor
|
||||
if torch.is_tensor(cross_distance_cutoff):
|
||||
# different cutoff for every graph (depends on the diffusion time)
|
||||
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
|
||||
cutoff_d = cross_distance_cutoff[data['ligand'].batch[src]].squeeze() if torch.is_tensor(cross_distance_cutoff) else cross_distance_cutoff
|
||||
edge_weight = self.get_edge_weight(edge_vec, cutoff_d)
|
||||
|
||||
return edge_index, edge_attr, edge_sh, edge_weight
|
||||
|
||||
def build_misc_cross_conv_graph(self, data, lr_cross_distance_cutoff):
|
||||
# build the cross edges between ligan atoms, receptor residues and receptor atoms
|
||||
|
||||
# LIGAND to RECEPTOR
|
||||
if torch.is_tensor(lr_cross_distance_cutoff):
|
||||
# different cutoff for every graph
|
||||
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
|
||||
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
|
||||
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
|
||||
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
|
||||
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
|
||||
|
||||
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
|
||||
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
|
||||
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
|
||||
|
||||
# LIGAND to ATOM
|
||||
la_edge_index = radius(data['misc_atom'].pos, data['ligand'].pos, self.lig_max_radius,
|
||||
data['misc_atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
la_edge_vec = data['misc_atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
|
||||
la_edge_length_emb = self.cross_distance_expansion(la_edge_vec.norm(dim=-1))
|
||||
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
|
||||
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
|
||||
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
|
||||
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
|
||||
|
||||
# ATOM to RECEPTOR
|
||||
ar_edge_index = data['misc_atom', 'receptor'].edge_index
|
||||
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['misc_atom'].pos[ar_edge_index[0].long()]
|
||||
ar_edge_length_emb = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
|
||||
ar_edge_sigma_emb = data['misc_atom'].node_sigma_emb[ar_edge_index[0].long()]
|
||||
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
|
||||
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
|
||||
ar_edge_weight = 1
|
||||
|
||||
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
|
||||
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
|
||||
|
||||
def build_center_conv_graph(self, data):
|
||||
# builds the filter and edges for the convolution generating translational and rotational scores
|
||||
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
|
||||
|
||||
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
|
||||
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
|
||||
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
|
||||
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
return edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_bond_conv_graph(self, data):
|
||||
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
|
||||
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
|
||||
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
|
||||
bond_batch = data['ligand'].batch[bonds[0]]
|
||||
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
|
||||
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = self.final_edge_embedding(edge_attr)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
|
||||
|
||||
return bonds, edge_index, edge_attr, edge_sh, edge_weight
|
||||
@@ -1,442 +0,0 @@
|
||||
import math
|
||||
|
||||
from e3nn import o3
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torch_cluster import radius, radius_graph
|
||||
from torch_scatter import scatter, scatter_mean
|
||||
import numpy as np
|
||||
from e3nn.nn import BatchNorm
|
||||
|
||||
from utils import so3, torus
|
||||
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims
|
||||
|
||||
|
||||
class AtomEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type= None):
|
||||
# first element of feature_dims tuple is a list with the lenght of each categorical feature and the second is the number of scalar features
|
||||
super(AtomEncoder, self).__init__()
|
||||
self.atom_embedding_list = torch.nn.ModuleList()
|
||||
self.num_categorical_features = len(feature_dims[0])
|
||||
self.num_scalar_features = feature_dims[1] + sigma_embed_dim
|
||||
self.lm_embedding_type = lm_embedding_type
|
||||
for i, dim in enumerate(feature_dims[0]):
|
||||
emb = torch.nn.Embedding(dim, emb_dim)
|
||||
torch.nn.init.xavier_uniform_(emb.weight.data)
|
||||
self.atom_embedding_list.append(emb)
|
||||
|
||||
if self.num_scalar_features > 0:
|
||||
self.linear = torch.nn.Linear(self.num_scalar_features, emb_dim)
|
||||
if self.lm_embedding_type is not None:
|
||||
if self.lm_embedding_type == 'esm':
|
||||
self.lm_embedding_dim = 1280
|
||||
else: raise ValueError('LM Embedding type was not correctly determined. LM embedding type: ', self.lm_embedding_type)
|
||||
self.lm_embedding_layer = torch.nn.Linear(self.lm_embedding_dim + emb_dim, emb_dim)
|
||||
|
||||
def forward(self, x):
|
||||
x_embedding = 0
|
||||
if self.lm_embedding_type is not None:
|
||||
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features + self.lm_embedding_dim
|
||||
else:
|
||||
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features
|
||||
for i in range(self.num_categorical_features):
|
||||
x_embedding += self.atom_embedding_list[i](x[:, i].long())
|
||||
|
||||
if self.num_scalar_features > 0:
|
||||
x_embedding += self.linear(x[:, self.num_categorical_features:self.num_categorical_features + self.num_scalar_features])
|
||||
if self.lm_embedding_type is not None:
|
||||
x_embedding = self.lm_embedding_layer(torch.cat([x_embedding, x[:, -self.lm_embedding_dim:]], axis=1))
|
||||
return x_embedding
|
||||
|
||||
|
||||
class TensorProductConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
|
||||
hidden_features=None):
|
||||
super(TensorProductConvLayer, self).__init__()
|
||||
self.in_irreps = in_irreps
|
||||
self.out_irreps = out_irreps
|
||||
self.sh_irreps = sh_irreps
|
||||
self.residual = residual
|
||||
if hidden_features is None:
|
||||
hidden_features = n_edge_features
|
||||
|
||||
self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(n_edge_features, hidden_features),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_features, tp.weight_numel)
|
||||
)
|
||||
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
|
||||
|
||||
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean'):
|
||||
|
||||
edge_src, edge_dst = edge_index
|
||||
tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr))
|
||||
|
||||
out_nodes = out_nodes or node_attr.shape[0]
|
||||
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
|
||||
|
||||
if self.residual:
|
||||
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
|
||||
out = out + padded
|
||||
|
||||
if self.batch_norm:
|
||||
out = self.batch_norm(out)
|
||||
return out
|
||||
|
||||
|
||||
class TensorProductScoreModel(torch.nn.Module):
|
||||
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
|
||||
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
|
||||
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
|
||||
scale_by_sigma=True, use_second_order_repr=False, batch_norm=True,
|
||||
dynamic_max_cross=False, dropout=0.0, lm_embedding_type=None, confidence_mode=False,
|
||||
confidence_dropout=0, confidence_no_batchnorm=False, num_confidence_outputs=1):
|
||||
super(TensorProductScoreModel, self).__init__()
|
||||
self.t_to_sigma = t_to_sigma
|
||||
self.in_lig_edge_features = in_lig_edge_features
|
||||
self.sigma_embed_dim = sigma_embed_dim
|
||||
self.lig_max_radius = lig_max_radius
|
||||
self.rec_max_radius = rec_max_radius
|
||||
self.cross_max_distance = cross_max_distance
|
||||
self.dynamic_max_cross = dynamic_max_cross
|
||||
self.center_max_distance = center_max_distance
|
||||
self.distance_embed_dim = distance_embed_dim
|
||||
self.cross_distance_embed_dim = cross_distance_embed_dim
|
||||
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
|
||||
self.ns, self.nv = ns, nv
|
||||
self.scale_by_sigma = scale_by_sigma
|
||||
self.device = device
|
||||
self.no_torsion = no_torsion
|
||||
self.timestep_emb_func = timestep_emb_func
|
||||
self.confidence_mode = confidence_mode
|
||||
self.num_conv_layers = num_conv_layers
|
||||
|
||||
self.lig_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
|
||||
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.rec_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
|
||||
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
|
||||
|
||||
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
|
||||
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
|
||||
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
|
||||
|
||||
if use_second_order_repr:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {ns}x0o'
|
||||
]
|
||||
else:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o'
|
||||
]
|
||||
|
||||
lig_conv_layers, rec_conv_layers, lig_to_rec_conv_layers, rec_to_lig_conv_layers = [], [], [], []
|
||||
for i in range(num_conv_layers):
|
||||
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
|
||||
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
|
||||
parameters = {
|
||||
'in_irreps': in_irreps,
|
||||
'sh_irreps': self.sh_irreps,
|
||||
'out_irreps': out_irreps,
|
||||
'n_edge_features': 3 * ns,
|
||||
'hidden_features': 3 * ns,
|
||||
'residual': False,
|
||||
'batch_norm': batch_norm,
|
||||
'dropout': dropout
|
||||
}
|
||||
|
||||
lig_layer = TensorProductConvLayer(**parameters)
|
||||
lig_conv_layers.append(lig_layer)
|
||||
rec_layer = TensorProductConvLayer(**parameters)
|
||||
rec_conv_layers.append(rec_layer)
|
||||
lig_to_rec_layer = TensorProductConvLayer(**parameters)
|
||||
lig_to_rec_conv_layers.append(lig_to_rec_layer)
|
||||
rec_to_lig_layer = TensorProductConvLayer(**parameters)
|
||||
rec_to_lig_conv_layers.append(rec_to_lig_layer)
|
||||
|
||||
self.lig_conv_layers = nn.ModuleList(lig_conv_layers)
|
||||
self.rec_conv_layers = nn.ModuleList(rec_conv_layers)
|
||||
self.lig_to_rec_conv_layers = nn.ModuleList(lig_to_rec_conv_layers)
|
||||
self.rec_to_lig_conv_layers = nn.ModuleList(rec_to_lig_conv_layers)
|
||||
|
||||
if self.confidence_mode:
|
||||
self.confidence_predictor = nn.Sequential(
|
||||
nn.Linear(2*self.ns if num_conv_layers >= 3 else self.ns,ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, ns),
|
||||
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(confidence_dropout),
|
||||
nn.Linear(ns, num_confidence_outputs)
|
||||
)
|
||||
else:
|
||||
# center of mass translation and rotation components
|
||||
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
|
||||
self.center_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
|
||||
self.final_conv = TensorProductConvLayer(
|
||||
in_irreps=self.lig_conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.sh_irreps,
|
||||
out_irreps=f'2x1o + 2x1e',
|
||||
n_edge_features=2 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
|
||||
|
||||
if not no_torsion:
|
||||
# torsion angles components
|
||||
self.final_edge_embedding = nn.Sequential(
|
||||
nn.Linear(distance_embed_dim, ns),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, ns)
|
||||
)
|
||||
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
|
||||
self.tor_bond_conv = TensorProductConvLayer(
|
||||
in_irreps=self.lig_conv_layers[-1].out_irreps,
|
||||
sh_irreps=self.final_tp_tor.irreps_out,
|
||||
out_irreps=f'{ns}x0o + {ns}x0e',
|
||||
n_edge_features=3 * ns,
|
||||
residual=False,
|
||||
dropout=dropout,
|
||||
batch_norm=batch_norm
|
||||
)
|
||||
self.tor_final_layer = nn.Sequential(
|
||||
nn.Linear(2 * ns, ns, bias=False),
|
||||
nn.Tanh(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(ns, 1, bias=False)
|
||||
)
|
||||
|
||||
def forward(self, data):
|
||||
if not self.confidence_mode:
|
||||
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
|
||||
else:
|
||||
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
|
||||
|
||||
# build ligand graph
|
||||
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh = self.build_lig_conv_graph(data)
|
||||
lig_src, lig_dst = lig_edge_index
|
||||
lig_node_attr = self.lig_node_embedding(lig_node_attr)
|
||||
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
|
||||
|
||||
# build receptor graph
|
||||
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh = self.build_rec_conv_graph(data)
|
||||
rec_src, rec_dst = rec_edge_index
|
||||
rec_node_attr = self.rec_node_embedding(rec_node_attr)
|
||||
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
|
||||
|
||||
# build cross graph
|
||||
if self.dynamic_max_cross:
|
||||
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
|
||||
else:
|
||||
cross_cutoff = self.cross_max_distance
|
||||
cross_edge_index, cross_edge_attr, cross_edge_sh = self.build_cross_conv_graph(data, cross_cutoff)
|
||||
cross_lig, cross_rec = cross_edge_index
|
||||
cross_edge_attr = self.cross_edge_embedding(cross_edge_attr)
|
||||
|
||||
for l in range(len(self.lig_conv_layers)):
|
||||
# intra graph message passing
|
||||
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_src, :self.ns], lig_node_attr[lig_dst, :self.ns]], -1)
|
||||
lig_intra_update = self.lig_conv_layers[l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh)
|
||||
|
||||
# inter graph message passing
|
||||
rec_to_lig_edge_attr_ = torch.cat([cross_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
|
||||
lig_inter_update = self.rec_to_lig_conv_layers[l](rec_node_attr, cross_edge_index, rec_to_lig_edge_attr_, cross_edge_sh,
|
||||
out_nodes=lig_node_attr.shape[0])
|
||||
|
||||
if l != len(self.lig_conv_layers) - 1:
|
||||
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_src, :self.ns], rec_node_attr[rec_dst, :self.ns]], -1)
|
||||
rec_intra_update = self.rec_conv_layers[l](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh)
|
||||
|
||||
lig_to_rec_edge_attr_ = torch.cat([cross_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
|
||||
rec_inter_update = self.lig_to_rec_conv_layers[l](lig_node_attr, torch.flip(cross_edge_index, dims=[0]), lig_to_rec_edge_attr_,
|
||||
cross_edge_sh, out_nodes=rec_node_attr.shape[0])
|
||||
|
||||
# padding original features
|
||||
lig_node_attr = F.pad(lig_node_attr, (0, lig_intra_update.shape[-1] - lig_node_attr.shape[-1]))
|
||||
|
||||
# update features with residual updates
|
||||
lig_node_attr = lig_node_attr + lig_intra_update + lig_inter_update
|
||||
|
||||
if l != len(self.lig_conv_layers) - 1:
|
||||
rec_node_attr = F.pad(rec_node_attr, (0, rec_intra_update.shape[-1] - rec_node_attr.shape[-1]))
|
||||
rec_node_attr = rec_node_attr + rec_intra_update + rec_inter_update
|
||||
|
||||
# compute confidence score
|
||||
if self.confidence_mode:
|
||||
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:] ], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
|
||||
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
|
||||
return confidence
|
||||
|
||||
# compute translational and rotational score vectors
|
||||
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
|
||||
center_edge_attr = self.center_edge_embedding(center_edge_attr)
|
||||
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
|
||||
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
|
||||
|
||||
tr_pred = global_pred[:, :3] + global_pred[:, 6:9]
|
||||
rot_pred = global_pred[:, 3:6] + global_pred[:, 9:]
|
||||
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
|
||||
|
||||
# fix the magnitude of translational and rotational score vectors
|
||||
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
|
||||
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
|
||||
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
|
||||
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
|
||||
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
|
||||
|
||||
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device)
|
||||
|
||||
# torsional components
|
||||
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh = self.build_bond_conv_graph(data)
|
||||
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
|
||||
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
|
||||
|
||||
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
|
||||
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
|
||||
|
||||
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
|
||||
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
|
||||
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
|
||||
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean')
|
||||
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
|
||||
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
|
||||
|
||||
if self.scale_by_sigma:
|
||||
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
|
||||
.to(data['ligand'].x.device))
|
||||
return tr_pred, rot_pred, tor_pred
|
||||
|
||||
def build_lig_conv_graph(self, data):
|
||||
# builds the ligand graph edges and initial node and edge features
|
||||
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr'])
|
||||
|
||||
# compute edges
|
||||
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
|
||||
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
|
||||
edge_attr = torch.cat([
|
||||
data['ligand', 'ligand'].edge_attr,
|
||||
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
|
||||
], 0)
|
||||
|
||||
# compute initial features
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_rec_conv_graph(self, data):
|
||||
# builds the receptor initial node and edge embeddings
|
||||
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
|
||||
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
|
||||
|
||||
# this assumes the edges were already created in preprocessing since protein's structure is fixed
|
||||
edge_index = data['receptor', 'receptor'].edge_index
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
|
||||
return node_attr, edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_cross_conv_graph(self, data, cross_distance_cutoff):
|
||||
# builds the cross edges between ligand and receptor
|
||||
if torch.is_tensor(cross_distance_cutoff):
|
||||
# different cutoff for every graph (depends on the diffusion time)
|
||||
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
|
||||
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
else:
|
||||
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
|
||||
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
|
||||
|
||||
src, dst = edge_index
|
||||
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
|
||||
|
||||
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
|
||||
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
|
||||
return edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_center_conv_graph(self, data):
|
||||
# builds the filter and edges for the convolution generating translational and rotational scores
|
||||
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
|
||||
|
||||
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
|
||||
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
|
||||
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
|
||||
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
|
||||
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
|
||||
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
return edge_index, edge_attr, edge_sh
|
||||
|
||||
def build_bond_conv_graph(self, data):
|
||||
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
|
||||
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
|
||||
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
|
||||
bond_batch = data['ligand'].batch[bonds[0]]
|
||||
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
|
||||
|
||||
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
|
||||
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
|
||||
|
||||
edge_attr = self.final_edge_embedding(edge_attr)
|
||||
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
|
||||
|
||||
return bonds, edge_index, edge_attr, edge_sh
|
||||
|
||||
|
||||
class GaussianSmearing(torch.nn.Module):
|
||||
# used to embed the edge distances
|
||||
def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
|
||||
super().__init__()
|
||||
offset = torch.linspace(start, stop, num_gaussians)
|
||||
self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
|
||||
self.register_buffer('offset', offset)
|
||||
|
||||
def forward(self, dist):
|
||||
dist = dist.view(-1, 1) - self.offset.view(1, -1)
|
||||
return torch.exp(self.coeff * torch.pow(dist, 2))
|
||||
255
models/tensor_layers.py
Normal file
255
models/tensor_layers.py
Normal file
@@ -0,0 +1,255 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from e3nn import o3
|
||||
from e3nn.nn import BatchNorm
|
||||
from e3nn.o3 import TensorProduct, Linear
|
||||
from torch_scatter import scatter, scatter_mean
|
||||
|
||||
from models.layers import FCBlock
|
||||
|
||||
def get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars):
|
||||
if use_second_order_repr:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {nv if reduce_pseudoscalars else ns}x0o'
|
||||
]
|
||||
else:
|
||||
irrep_seq = [
|
||||
f'{ns}x0e',
|
||||
f'{ns}x0e + {nv}x1o',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e',
|
||||
f'{ns}x0e + {nv}x1o + {nv}x1e + {nv if reduce_pseudoscalars else ns}x0o'
|
||||
]
|
||||
return irrep_seq
|
||||
|
||||
|
||||
def irrep_to_size(irrep):
|
||||
irreps = irrep.split(' + ')
|
||||
size = 0
|
||||
for ir in irreps:
|
||||
m, (l, p) = ir.split('x')
|
||||
size += int(m) * (2 * int(l) + 1)
|
||||
return size
|
||||
|
||||
|
||||
class FasterTensorProduct(torch.nn.Module):
|
||||
# Implemented by Bowen Jing
|
||||
def __init__(self, in_irreps, sh_irreps, out_irreps, **kwargs):
|
||||
super().__init__()
|
||||
#for ir in in_irreps:
|
||||
# m, (l, p) = ir
|
||||
# assert l in [0, 1], "Higher order in irreps are not supported"
|
||||
#for ir in out_irreps:
|
||||
# m, (l, p) = ir
|
||||
# assert l in [0, 1], "Higher order out irreps are not supported"
|
||||
assert o3.Irreps(sh_irreps) == o3.Irreps('1x0e+1x1o'), "sh_irreps don't look like 1st order spherical harmonics"
|
||||
self.in_irreps = o3.Irreps(in_irreps)
|
||||
self.out_irreps = o3.Irreps(out_irreps)
|
||||
|
||||
in_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
|
||||
out_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
|
||||
for (m, ir) in self.in_irreps: in_muls[str(ir)] = m
|
||||
for (m, ir) in self.out_irreps: out_muls[str(ir)] = m
|
||||
|
||||
self.weight_shapes = {
|
||||
'0e': (in_muls['0e'] + in_muls['1o'], out_muls['0e']),
|
||||
'1o': (in_muls['0e'] + in_muls['1o'] + in_muls['1e'], out_muls['1o']),
|
||||
'1e': (in_muls['1o'] + in_muls['1e'] + in_muls['0o'], out_muls['1e']),
|
||||
'0o': (in_muls['1e'] + in_muls['0o'], out_muls['0o'])
|
||||
}
|
||||
self.weight_numel = sum(a * b for (a, b) in self.weight_shapes.values())
|
||||
|
||||
def forward(self, in_, sh, weight):
|
||||
in_dict, out_dict = {}, {'0e': [], '1o': [], '1e': [], '0o': []}
|
||||
for (m, ir), sl in zip(self.in_irreps, self.in_irreps.slices()):
|
||||
in_dict[str(ir)] = in_[..., sl]
|
||||
if ir[0] == 1: in_dict[str(ir)] = in_dict[str(ir)].reshape(list(in_dict[str(ir)].shape)[:-1] + [-1, 3])
|
||||
sh_0e, sh_1o = sh[..., 0], sh[..., 1:]
|
||||
if '0e' in in_dict:
|
||||
out_dict['0e'].append(in_dict['0e'] * sh_0e.unsqueeze(-1))
|
||||
out_dict['1o'].append(in_dict['0e'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
|
||||
if '1o' in in_dict:
|
||||
out_dict['0e'].append((in_dict['1o'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
|
||||
out_dict['1o'].append(in_dict['1o'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
|
||||
out_dict['1e'].append(torch.linalg.cross(in_dict['1o'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
|
||||
if '1e' in in_dict:
|
||||
out_dict['1o'].append(torch.linalg.cross(in_dict['1e'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
|
||||
out_dict['1e'].append(in_dict['1e'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
|
||||
out_dict['0o'].append((in_dict['1e'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
|
||||
if '0o' in in_dict:
|
||||
out_dict['1e'].append(in_dict['0o'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
|
||||
out_dict['0o'].append(in_dict['0o'] * sh_0e.unsqueeze(-1))
|
||||
|
||||
weight_dict = {}
|
||||
start = 0
|
||||
for key in self.weight_shapes:
|
||||
in_, out = self.weight_shapes[key]
|
||||
weight_dict[key] = weight[..., start:start + in_ * out].reshape(
|
||||
list(weight.shape)[:-1] + [in_, out]) / np.sqrt(in_)
|
||||
start += in_ * out
|
||||
|
||||
if out_dict['0e']:
|
||||
out_dict['0e'] = torch.cat(out_dict['0e'], dim=-1)
|
||||
out_dict['0e'] = torch.matmul(out_dict['0e'].unsqueeze(-2), weight_dict['0e']).squeeze(-2)
|
||||
|
||||
if out_dict['1o']:
|
||||
out_dict['1o'] = torch.cat(out_dict['1o'], dim=-2)
|
||||
out_dict['1o'] = (out_dict['1o'].unsqueeze(-2) * weight_dict['1o'].unsqueeze(-1)).sum(-3)
|
||||
out_dict['1o'] = out_dict['1o'].reshape(list(out_dict['1o'].shape)[:-2] + [-1])
|
||||
|
||||
if out_dict['1e']:
|
||||
out_dict['1e'] = torch.cat(out_dict['1e'], dim=-2)
|
||||
out_dict['1e'] = (out_dict['1e'].unsqueeze(-2) * weight_dict['1e'].unsqueeze(-1)).sum(-3)
|
||||
out_dict['1e'] = out_dict['1e'].reshape(list(out_dict['1e'].shape)[:-2] + [-1])
|
||||
|
||||
if out_dict['0o']:
|
||||
out_dict['0o'] = torch.cat(out_dict['0o'], dim=-1)
|
||||
# out_dict['0o'] = (out_dict['0o'].unsqueeze(-1) * weight_dict['0o']).sum(-2)
|
||||
out_dict['0o'] = torch.matmul(out_dict['0o'].unsqueeze(-2), weight_dict['0o']).squeeze(-2)
|
||||
|
||||
out = []
|
||||
for _, ir in self.out_irreps:
|
||||
out.append(out_dict[str(ir)])
|
||||
return torch.cat(out, dim=-1)
|
||||
|
||||
|
||||
class TensorProductConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
|
||||
hidden_features=None, faster=False, edge_groups=1, tp_weights_layers=2, activation='relu', depthwise=False):
|
||||
super(TensorProductConvLayer, self).__init__()
|
||||
self.in_irreps = in_irreps
|
||||
self.out_irreps = out_irreps
|
||||
self.sh_irreps = sh_irreps
|
||||
self.residual = residual
|
||||
self.edge_groups = edge_groups
|
||||
self.out_size = irrep_to_size(out_irreps)
|
||||
self.depthwise = depthwise
|
||||
if hidden_features is None:
|
||||
hidden_features = n_edge_features
|
||||
|
||||
if depthwise:
|
||||
in_irreps = o3.Irreps(in_irreps)
|
||||
sh_irreps = o3.Irreps(sh_irreps)
|
||||
out_irreps = o3.Irreps(out_irreps)
|
||||
|
||||
irreps_mid = []
|
||||
instructions = []
|
||||
for i, (mul, ir_in) in enumerate(in_irreps):
|
||||
for j, (_, ir_edge) in enumerate(sh_irreps):
|
||||
for ir_out in ir_in * ir_edge:
|
||||
if ir_out in out_irreps:
|
||||
k = len(irreps_mid)
|
||||
irreps_mid.append((mul, ir_out))
|
||||
instructions.append((i, j, k, "uvu", True))
|
||||
|
||||
# We sort the output irreps of the tensor product so that we can simplify them
|
||||
# when they are provided to the second o3.Linear
|
||||
irreps_mid = o3.Irreps(irreps_mid)
|
||||
irreps_mid, p, _ = irreps_mid.sort()
|
||||
|
||||
# Permute the output indexes of the instructions to match the sorted irreps:
|
||||
instructions = [
|
||||
(i_in1, i_in2, p[i_out], mode, train)
|
||||
for i_in1, i_in2, i_out, mode, train in instructions
|
||||
]
|
||||
|
||||
self.tp = TensorProduct(
|
||||
in_irreps,
|
||||
sh_irreps,
|
||||
irreps_mid,
|
||||
instructions,
|
||||
shared_weights=False,
|
||||
internal_weights=False,
|
||||
)
|
||||
|
||||
self.linear_2 = Linear(
|
||||
# irreps_mid has uncoallesed irreps because of the uvu instructions,
|
||||
# but there's no reason to treat them seperately for the Linear
|
||||
# Note that normalization of o3.Linear changes if irreps are coallesed
|
||||
# (likely for the better)
|
||||
irreps_in=irreps_mid.simplify(),
|
||||
irreps_out=out_irreps,
|
||||
internal_weights=True,
|
||||
shared_weights=True,
|
||||
)
|
||||
|
||||
else:
|
||||
if faster:
|
||||
print("Faster Tensor Product")
|
||||
self.tp = FasterTensorProduct(in_irreps, sh_irreps, out_irreps)
|
||||
else:
|
||||
self.tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
|
||||
|
||||
if edge_groups == 1:
|
||||
self.fc = FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation)
|
||||
else:
|
||||
self.fc = [FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation) for _ in range(edge_groups)]
|
||||
self.fc = nn.ModuleList(self.fc)
|
||||
|
||||
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
|
||||
|
||||
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
|
||||
|
||||
if edge_index.shape[1] == 0:
|
||||
out = torch.zeros((node_attr.shape[0], self.out_size), dtype=node_attr.dtype, device=node_attr.device)
|
||||
else:
|
||||
edge_src, edge_dst = edge_index
|
||||
edge_attr_ = self.fc(edge_attr) if self.edge_groups == 1 else torch.cat(
|
||||
[self.fc[i](edge_attr[i]) for i in range(self.edge_groups)], dim=0).to(node_attr.device)
|
||||
tp = self.tp(node_attr[edge_dst], edge_sh, edge_attr_ * edge_weight)
|
||||
|
||||
out_nodes = out_nodes or node_attr.shape[0]
|
||||
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
|
||||
|
||||
if self.depthwise:
|
||||
out = self.linear_2(out)
|
||||
|
||||
if self.batch_norm:
|
||||
out = self.batch_norm(out)
|
||||
|
||||
if self.residual:
|
||||
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
|
||||
out = out + padded
|
||||
return out
|
||||
|
||||
|
||||
class OldTensorProductConvLayer(torch.nn.Module):
|
||||
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
|
||||
hidden_features=None):
|
||||
super(OldTensorProductConvLayer, self).__init__()
|
||||
self.in_irreps = in_irreps
|
||||
self.out_irreps = out_irreps
|
||||
self.sh_irreps = sh_irreps
|
||||
self.residual = residual
|
||||
if hidden_features is None:
|
||||
hidden_features = n_edge_features
|
||||
|
||||
self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
|
||||
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(n_edge_features, hidden_features),
|
||||
nn.ReLU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(hidden_features, tp.weight_numel)
|
||||
)
|
||||
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
|
||||
|
||||
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
|
||||
|
||||
edge_src, edge_dst = edge_index
|
||||
tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr) * edge_weight)
|
||||
|
||||
out_nodes = out_nodes or node_attr.shape[0]
|
||||
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
|
||||
|
||||
if self.residual:
|
||||
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
|
||||
out = out + padded
|
||||
|
||||
if self.batch_norm:
|
||||
out = self.batch_norm(out)
|
||||
return out
|
||||
|
Before Width: | Height: | Size: 334 KiB After Width: | Height: | Size: 334 KiB |
22
spyrmsd/LICENSE
Normal file
22
spyrmsd/LICENSE
Normal file
@@ -0,0 +1,22 @@
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2019-2021 Rocco Meli
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
24
spyrmsd/__init__.py
Normal file
24
spyrmsd/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Python RMSD tool with symmetry correction.
|
||||
"""
|
||||
|
||||
from ._version import get_versions
|
||||
from .due import Doi, due
|
||||
|
||||
versions = get_versions()
|
||||
__version__ = versions["version"]
|
||||
__git_revision__ = versions["full-revisionid"]
|
||||
del get_versions, versions
|
||||
|
||||
# This will print latest Zenodo version
|
||||
due.cite(
|
||||
Doi("10.5281/zenodo.3631876"),
|
||||
path="spyrmsd",
|
||||
description="spyrmsd",
|
||||
)
|
||||
|
||||
due.cite(
|
||||
Doi("10.1186/s13321-020-00455-2"),
|
||||
path="spyrmsd",
|
||||
description="spyrmsd",
|
||||
)
|
||||
63
spyrmsd/__main__.py
Normal file
63
spyrmsd/__main__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Symmetry-corrected RMSD calculations in Python
|
||||
"""
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse as ap
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
from spyrmsd import io
|
||||
from spyrmsd.rmsd import rmsdwrapper
|
||||
|
||||
parser = ap.ArgumentParser(
|
||||
prog="python -m spyrmsd",
|
||||
description="Symmetry-corrected RMSD calculations in Python.",
|
||||
)
|
||||
|
||||
parser.add_argument("reference", type=str, help="Reference file")
|
||||
parser.add_argument("molecules", type=str, nargs="+", help="Input file(s)")
|
||||
parser.add_argument("-m", "--minimize", action="store_true", help="Minimize (fit)")
|
||||
parser.add_argument(
|
||||
"-c", "--center", action="store_true", help="Center molecules at origin"
|
||||
)
|
||||
parser.add_argument("--hydrogens", action="store_true", help="Keep hydrogen atoms")
|
||||
parser.add_argument(
|
||||
"-n", "--nosymm", action="store_false", help="No graph isomorphism"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if (
|
||||
importlib.util.find_spec("openbabel") is None
|
||||
and importlib.util.find_spec("rdkit") is None
|
||||
):
|
||||
raise ImportError(
|
||||
"OpenBabel or RDKit not found. Please install OpenBabel or RDKit to use sPyRMSD as a standalone tool."
|
||||
)
|
||||
|
||||
try:
|
||||
ref = io.loadmol(args.reference)
|
||||
except OSError:
|
||||
print("ERROR: Reference file not found.", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
# Load all molecules
|
||||
try:
|
||||
mols = [mol for molfile in args.molecules for mol in io.loadallmols(molfile)]
|
||||
except OSError:
|
||||
print("ERROR: Molecule file(s) not found.", file=sys.stderr)
|
||||
exit(-1)
|
||||
|
||||
# Loop over molecules within fil
|
||||
RMSDlist = rmsdwrapper(
|
||||
ref,
|
||||
mols,
|
||||
symmetry=args.nosymm, # args.nosymm store False
|
||||
center=args.center,
|
||||
minimize=args.minimize,
|
||||
strip=not args.hydrogens,
|
||||
)
|
||||
|
||||
for RMSD in RMSDlist:
|
||||
print(f"{RMSD:.5f}")
|
||||
693
spyrmsd/_version.py
Normal file
693
spyrmsd/_version.py
Normal file
@@ -0,0 +1,693 @@
|
||||
# This file helps to compute a version number in source trees obtained from
|
||||
# git-archive tarball (such as those provided by githubs download-from-tag
|
||||
# feature). Distribution tarballs (built by setup.py sdist) and build
|
||||
# directories (produced by setup.py build) will contain a much shorter file
|
||||
# that just contains the computed version number.
|
||||
|
||||
# This file is released into the public domain.
|
||||
# Generated by versioneer-0.28
|
||||
# https://github.com/python-versioneer/python-versioneer
|
||||
|
||||
"""Git implementation of _version.py."""
|
||||
|
||||
import errno
|
||||
import functools
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import Callable, Dict
|
||||
|
||||
|
||||
def get_keywords():
|
||||
"""Get the keywords needed to look up the version information."""
|
||||
# these strings will be replaced by git during git-archive.
|
||||
# setup.py/versioneer.py will grep for the variable names, so they must
|
||||
# each be defined on a line of their own. _version.py will just call
|
||||
# get_keywords().
|
||||
git_refnames = " (HEAD -> develop, refs/pull/86/head)"
|
||||
git_full = "b5532dd1b677a686c9d31c7bf39446e0a66e0af8"
|
||||
git_date = "2023-09-08 19:57:37 +0200"
|
||||
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
|
||||
return keywords
|
||||
|
||||
|
||||
class VersioneerConfig:
|
||||
"""Container for Versioneer configuration parameters."""
|
||||
|
||||
|
||||
def get_config():
|
||||
"""Create, populate and return the VersioneerConfig() object."""
|
||||
# these strings are filled in when 'setup.py versioneer' creates
|
||||
# _version.py
|
||||
cfg = VersioneerConfig()
|
||||
cfg.VCS = "git"
|
||||
cfg.style = "pep440"
|
||||
cfg.tag_prefix = ""
|
||||
cfg.parentdir_prefix = "None"
|
||||
cfg.versionfile_source = "spyrmsd/_version.py"
|
||||
cfg.verbose = False
|
||||
return cfg
|
||||
|
||||
|
||||
class NotThisMethod(Exception):
|
||||
"""Exception raised if a method is not valid for the current scenario."""
|
||||
|
||||
|
||||
LONG_VERSION_PY: Dict[str, str] = {}
|
||||
HANDLERS: Dict[str, Dict[str, Callable]] = {}
|
||||
|
||||
|
||||
def register_vcs_handler(vcs, method): # decorator
|
||||
"""Create decorator to mark a method as the handler of a VCS."""
|
||||
|
||||
def decorate(f):
|
||||
"""Store f in HANDLERS[vcs][method]."""
|
||||
if vcs not in HANDLERS:
|
||||
HANDLERS[vcs] = {}
|
||||
HANDLERS[vcs][method] = f
|
||||
return f
|
||||
|
||||
return decorate
|
||||
|
||||
|
||||
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
|
||||
"""Call the given command(s)."""
|
||||
assert isinstance(commands, list)
|
||||
process = None
|
||||
|
||||
popen_kwargs = {}
|
||||
if sys.platform == "win32":
|
||||
# This hides the console window if pythonw.exe is used
|
||||
startupinfo = subprocess.STARTUPINFO()
|
||||
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
|
||||
popen_kwargs["startupinfo"] = startupinfo
|
||||
|
||||
for command in commands:
|
||||
try:
|
||||
dispcmd = str([command] + args)
|
||||
# remember shell=False, so use git.cmd on windows, not just git
|
||||
process = subprocess.Popen(
|
||||
[command] + args,
|
||||
cwd=cwd,
|
||||
env=env,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=(subprocess.PIPE if hide_stderr else None),
|
||||
**popen_kwargs,
|
||||
)
|
||||
break
|
||||
except OSError:
|
||||
e = sys.exc_info()[1]
|
||||
if e.errno == errno.ENOENT:
|
||||
continue
|
||||
if verbose:
|
||||
print("unable to run %s" % dispcmd)
|
||||
print(e)
|
||||
return None, None
|
||||
else:
|
||||
if verbose:
|
||||
print("unable to find command, tried %s" % (commands,))
|
||||
return None, None
|
||||
stdout = process.communicate()[0].strip().decode()
|
||||
if process.returncode != 0:
|
||||
if verbose:
|
||||
print("unable to run %s (error)" % dispcmd)
|
||||
print("stdout was %s" % stdout)
|
||||
return None, process.returncode
|
||||
return stdout, process.returncode
|
||||
|
||||
|
||||
def versions_from_parentdir(parentdir_prefix, root, verbose):
|
||||
"""Try to determine the version from the parent directory name.
|
||||
|
||||
Source tarballs conventionally unpack into a directory that includes both
|
||||
the project name and a version string. We will also support searching up
|
||||
two directory levels for an appropriately named parent directory
|
||||
"""
|
||||
rootdirs = []
|
||||
|
||||
for _ in range(3):
|
||||
dirname = os.path.basename(root)
|
||||
if dirname.startswith(parentdir_prefix):
|
||||
return {
|
||||
"version": dirname[len(parentdir_prefix) :],
|
||||
"full-revisionid": None,
|
||||
"dirty": False,
|
||||
"error": None,
|
||||
"date": None,
|
||||
}
|
||||
rootdirs.append(root)
|
||||
root = os.path.dirname(root) # up a level
|
||||
|
||||
if verbose:
|
||||
print(
|
||||
"Tried directories %s but none started with prefix %s"
|
||||
% (str(rootdirs), parentdir_prefix)
|
||||
)
|
||||
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
|
||||
|
||||
|
||||
@register_vcs_handler("git", "get_keywords")
|
||||
def git_get_keywords(versionfile_abs):
|
||||
"""Extract version information from the given file."""
|
||||
# the code embedded in _version.py can just fetch the value of these
|
||||
# keywords. When used from setup.py, we don't want to import _version.py,
|
||||
# so we do it with a regexp instead. This function is not used from
|
||||
# _version.py.
|
||||
keywords = {}
|
||||
try:
|
||||
with open(versionfile_abs, "r") as fobj:
|
||||
for line in fobj:
|
||||
if line.strip().startswith("git_refnames ="):
|
||||
mo = re.search(r'=\s*"(.*)"', line)
|
||||
if mo:
|
||||
keywords["refnames"] = mo.group(1)
|
||||
if line.strip().startswith("git_full ="):
|
||||
mo = re.search(r'=\s*"(.*)"', line)
|
||||
if mo:
|
||||
keywords["full"] = mo.group(1)
|
||||
if line.strip().startswith("git_date ="):
|
||||
mo = re.search(r'=\s*"(.*)"', line)
|
||||
if mo:
|
||||
keywords["date"] = mo.group(1)
|
||||
except OSError:
|
||||
pass
|
||||
return keywords
|
||||
|
||||
|
||||
@register_vcs_handler("git", "keywords")
|
||||
def git_versions_from_keywords(keywords, tag_prefix, verbose):
|
||||
"""Get version information from git keywords."""
|
||||
if "refnames" not in keywords:
|
||||
raise NotThisMethod("Short version file found")
|
||||
date = keywords.get("date")
|
||||
if date is not None:
|
||||
# Use only the last line. Previous lines may contain GPG signature
|
||||
# information.
|
||||
date = date.splitlines()[-1]
|
||||
|
||||
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
|
||||
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
|
||||
# -like" string, which we must then edit to make compliant), because
|
||||
# it's been around since git-1.5.3, and it's too difficult to
|
||||
# discover which version we're using, or to work around using an
|
||||
# older one.
|
||||
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
|
||||
refnames = keywords["refnames"].strip()
|
||||
if refnames.startswith("$Format"):
|
||||
if verbose:
|
||||
print("keywords are unexpanded, not using")
|
||||
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
|
||||
refs = {r.strip() for r in refnames.strip("()").split(",")}
|
||||
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
|
||||
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
|
||||
TAG = "tag: "
|
||||
tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
|
||||
if not tags:
|
||||
# Either we're using git < 1.8.3, or there really are no tags. We use
|
||||
# a heuristic: assume all version tags have a digit. The old git %d
|
||||
# expansion behaves like git log --decorate=short and strips out the
|
||||
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
|
||||
# between branches and tags. By ignoring refnames without digits, we
|
||||
# filter out many common branch names like "release" and
|
||||
# "stabilization", as well as "HEAD" and "master".
|
||||
tags = {r for r in refs if re.search(r"\d", r)}
|
||||
if verbose:
|
||||
print("discarding '%s', no digits" % ",".join(refs - tags))
|
||||
if verbose:
|
||||
print("likely tags: %s" % ",".join(sorted(tags)))
|
||||
for ref in sorted(tags):
|
||||
# sorting will prefer e.g. "2.0" over "2.0rc1"
|
||||
if ref.startswith(tag_prefix):
|
||||
r = ref[len(tag_prefix) :]
|
||||
# Filter out refs that exactly match prefix or that don't start
|
||||
# with a number once the prefix is stripped (mostly a concern
|
||||
# when prefix is '')
|
||||
if not re.match(r"\d", r):
|
||||
continue
|
||||
if verbose:
|
||||
print("picking %s" % r)
|
||||
return {
|
||||
"version": r,
|
||||
"full-revisionid": keywords["full"].strip(),
|
||||
"dirty": False,
|
||||
"error": None,
|
||||
"date": date,
|
||||
}
|
||||
# no suitable tags, so version is "0+unknown", but full hex is still there
|
||||
if verbose:
|
||||
print("no suitable tags, using unknown + full revision id")
|
||||
return {
|
||||
"version": "0+unknown",
|
||||
"full-revisionid": keywords["full"].strip(),
|
||||
"dirty": False,
|
||||
"error": "no suitable tags",
|
||||
"date": None,
|
||||
}
|
||||
|
||||
|
||||
@register_vcs_handler("git", "pieces_from_vcs")
|
||||
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
|
||||
"""Get version from 'git describe' in the root of the source tree.
|
||||
|
||||
This only gets called if the git-archive 'subst' keywords were *not*
|
||||
expanded, and _version.py hasn't already been rewritten with a short
|
||||
version string, meaning we're inside a checked out source tree.
|
||||
"""
|
||||
GITS = ["git"]
|
||||
if sys.platform == "win32":
|
||||
GITS = ["git.cmd", "git.exe"]
|
||||
|
||||
# GIT_DIR can interfere with correct operation of Versioneer.
|
||||
# It may be intended to be passed to the Versioneer-versioned project,
|
||||
# but that should not change where we get our version from.
|
||||
env = os.environ.copy()
|
||||
env.pop("GIT_DIR", None)
|
||||
runner = functools.partial(runner, env=env)
|
||||
|
||||
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose)
|
||||
if rc != 0:
|
||||
if verbose:
|
||||
print("Directory %s not under git control" % root)
|
||||
raise NotThisMethod("'git rev-parse --git-dir' returned error")
|
||||
|
||||
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
|
||||
# if there isn't one, this yields HEX[-dirty] (no NUM)
|
||||
describe_out, rc = runner(
|
||||
GITS,
|
||||
[
|
||||
"describe",
|
||||
"--tags",
|
||||
"--dirty",
|
||||
"--always",
|
||||
"--long",
|
||||
"--match",
|
||||
f"{tag_prefix}[[:digit:]]*",
|
||||
],
|
||||
cwd=root,
|
||||
)
|
||||
# --long was added in git-1.5.5
|
||||
if describe_out is None:
|
||||
raise NotThisMethod("'git describe' failed")
|
||||
describe_out = describe_out.strip()
|
||||
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
|
||||
if full_out is None:
|
||||
raise NotThisMethod("'git rev-parse' failed")
|
||||
full_out = full_out.strip()
|
||||
|
||||
pieces = {}
|
||||
pieces["long"] = full_out
|
||||
pieces["short"] = full_out[:7] # maybe improved later
|
||||
pieces["error"] = None
|
||||
|
||||
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
|
||||
# --abbrev-ref was added in git-1.6.3
|
||||
if rc != 0 or branch_name is None:
|
||||
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
|
||||
branch_name = branch_name.strip()
|
||||
|
||||
if branch_name == "HEAD":
|
||||
# If we aren't exactly on a branch, pick a branch which represents
|
||||
# the current commit. If all else fails, we are on a branchless
|
||||
# commit.
|
||||
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
|
||||
# --contains was added in git-1.5.4
|
||||
if rc != 0 or branches is None:
|
||||
raise NotThisMethod("'git branch --contains' returned error")
|
||||
branches = branches.split("\n")
|
||||
|
||||
# Remove the first line if we're running detached
|
||||
if "(" in branches[0]:
|
||||
branches.pop(0)
|
||||
|
||||
# Strip off the leading "* " from the list of branches.
|
||||
branches = [branch[2:] for branch in branches]
|
||||
if "master" in branches:
|
||||
branch_name = "master"
|
||||
elif not branches:
|
||||
branch_name = None
|
||||
else:
|
||||
# Pick the first branch that is returned. Good or bad.
|
||||
branch_name = branches[0]
|
||||
|
||||
pieces["branch"] = branch_name
|
||||
|
||||
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
|
||||
# TAG might have hyphens.
|
||||
git_describe = describe_out
|
||||
|
||||
# look for -dirty suffix
|
||||
dirty = git_describe.endswith("-dirty")
|
||||
pieces["dirty"] = dirty
|
||||
if dirty:
|
||||
git_describe = git_describe[: git_describe.rindex("-dirty")]
|
||||
|
||||
# now we have TAG-NUM-gHEX or HEX
|
||||
|
||||
if "-" in git_describe:
|
||||
# TAG-NUM-gHEX
|
||||
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
|
||||
if not mo:
|
||||
# unparsable. Maybe git-describe is misbehaving?
|
||||
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
|
||||
return pieces
|
||||
|
||||
# tag
|
||||
full_tag = mo.group(1)
|
||||
if not full_tag.startswith(tag_prefix):
|
||||
if verbose:
|
||||
fmt = "tag '%s' doesn't start with prefix '%s'"
|
||||
print(fmt % (full_tag, tag_prefix))
|
||||
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
|
||||
full_tag,
|
||||
tag_prefix,
|
||||
)
|
||||
return pieces
|
||||
pieces["closest-tag"] = full_tag[len(tag_prefix) :]
|
||||
|
||||
# distance: number of commits since tag
|
||||
pieces["distance"] = int(mo.group(2))
|
||||
|
||||
# commit: short hex revision ID
|
||||
pieces["short"] = mo.group(3)
|
||||
|
||||
else:
|
||||
# HEX: no tags
|
||||
pieces["closest-tag"] = None
|
||||
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
|
||||
pieces["distance"] = len(out.split()) # total number of commits
|
||||
|
||||
# commit date: see ISO-8601 comment in git_versions_from_keywords()
|
||||
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
|
||||
# Use only the last line. Previous lines may contain GPG signature
|
||||
# information.
|
||||
date = date.splitlines()[-1]
|
||||
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
|
||||
|
||||
return pieces
|
||||
|
||||
|
||||
def plus_or_dot(pieces):
|
||||
"""Return a + if we don't already have one, else return a ."""
|
||||
if "+" in pieces.get("closest-tag", ""):
|
||||
return "."
|
||||
return "+"
|
||||
|
||||
|
||||
def render_pep440(pieces):
|
||||
"""Build up version string, with post-release "local version identifier".
|
||||
|
||||
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
|
||||
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
|
||||
|
||||
Exceptions:
|
||||
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"] or pieces["dirty"]:
|
||||
rendered += plus_or_dot(pieces)
|
||||
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
return rendered
|
||||
|
||||
|
||||
def render_pep440_branch(pieces):
|
||||
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
|
||||
|
||||
The ".dev0" means not master branch. Note that .dev0 sorts backwards
|
||||
(a feature branch will appear "older" than the master branch).
|
||||
|
||||
Exceptions:
|
||||
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"] or pieces["dirty"]:
|
||||
if pieces["branch"] != "master":
|
||||
rendered += ".dev0"
|
||||
rendered += plus_or_dot(pieces)
|
||||
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0"
|
||||
if pieces["branch"] != "master":
|
||||
rendered += ".dev0"
|
||||
rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
return rendered
|
||||
|
||||
|
||||
def pep440_split_post(ver):
|
||||
"""Split pep440 version string at the post-release segment.
|
||||
|
||||
Returns the release segments before the post-release and the
|
||||
post-release version number (or -1 if no post-release segment is present).
|
||||
"""
|
||||
vc = str.split(ver, ".post")
|
||||
return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
|
||||
|
||||
|
||||
def render_pep440_pre(pieces):
|
||||
"""TAG[.postN.devDISTANCE] -- No -dirty.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. 0.post0.devDISTANCE
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
if pieces["distance"]:
|
||||
# update the post release segment
|
||||
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
|
||||
rendered = tag_version
|
||||
if post_version is not None:
|
||||
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
|
||||
else:
|
||||
rendered += ".post0.dev%d" % (pieces["distance"])
|
||||
else:
|
||||
# no commits, use the tag as the version
|
||||
rendered = pieces["closest-tag"]
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0.post0.dev%d" % pieces["distance"]
|
||||
return rendered
|
||||
|
||||
|
||||
def render_pep440_post(pieces):
|
||||
"""TAG[.postDISTANCE[.dev0]+gHEX] .
|
||||
|
||||
The ".dev0" means dirty. Note that .dev0 sorts backwards
|
||||
(a dirty tree will appear "older" than the corresponding clean one),
|
||||
but you shouldn't be releasing software with -dirty anyways.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. 0.postDISTANCE[.dev0]
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"] or pieces["dirty"]:
|
||||
rendered += ".post%d" % pieces["distance"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dev0"
|
||||
rendered += plus_or_dot(pieces)
|
||||
rendered += "g%s" % pieces["short"]
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0.post%d" % pieces["distance"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dev0"
|
||||
rendered += "+g%s" % pieces["short"]
|
||||
return rendered
|
||||
|
||||
|
||||
def render_pep440_post_branch(pieces):
|
||||
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
|
||||
|
||||
The ".dev0" means not master branch.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"] or pieces["dirty"]:
|
||||
rendered += ".post%d" % pieces["distance"]
|
||||
if pieces["branch"] != "master":
|
||||
rendered += ".dev0"
|
||||
rendered += plus_or_dot(pieces)
|
||||
rendered += "g%s" % pieces["short"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0.post%d" % pieces["distance"]
|
||||
if pieces["branch"] != "master":
|
||||
rendered += ".dev0"
|
||||
rendered += "+g%s" % pieces["short"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dirty"
|
||||
return rendered
|
||||
|
||||
|
||||
def render_pep440_old(pieces):
|
||||
"""TAG[.postDISTANCE[.dev0]] .
|
||||
|
||||
The ".dev0" means dirty.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. 0.postDISTANCE[.dev0]
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"] or pieces["dirty"]:
|
||||
rendered += ".post%d" % pieces["distance"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dev0"
|
||||
else:
|
||||
# exception #1
|
||||
rendered = "0.post%d" % pieces["distance"]
|
||||
if pieces["dirty"]:
|
||||
rendered += ".dev0"
|
||||
return rendered
|
||||
|
||||
|
||||
def render_git_describe(pieces):
|
||||
"""TAG[-DISTANCE-gHEX][-dirty].
|
||||
|
||||
Like 'git describe --tags --dirty --always'.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. HEX[-dirty] (note: no 'g' prefix)
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
if pieces["distance"]:
|
||||
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
|
||||
else:
|
||||
# exception #1
|
||||
rendered = pieces["short"]
|
||||
if pieces["dirty"]:
|
||||
rendered += "-dirty"
|
||||
return rendered
|
||||
|
||||
|
||||
def render_git_describe_long(pieces):
|
||||
"""TAG-DISTANCE-gHEX[-dirty].
|
||||
|
||||
Like 'git describe --tags --dirty --always -long'.
|
||||
The distance/hash is unconditional.
|
||||
|
||||
Exceptions:
|
||||
1: no tags. HEX[-dirty] (note: no 'g' prefix)
|
||||
"""
|
||||
if pieces["closest-tag"]:
|
||||
rendered = pieces["closest-tag"]
|
||||
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
|
||||
else:
|
||||
# exception #1
|
||||
rendered = pieces["short"]
|
||||
if pieces["dirty"]:
|
||||
rendered += "-dirty"
|
||||
return rendered
|
||||
|
||||
|
||||
def render(pieces, style):
|
||||
"""Render the given version pieces into the requested style."""
|
||||
if pieces["error"]:
|
||||
return {
|
||||
"version": "unknown",
|
||||
"full-revisionid": pieces.get("long"),
|
||||
"dirty": None,
|
||||
"error": pieces["error"],
|
||||
"date": None,
|
||||
}
|
||||
|
||||
if not style or style == "default":
|
||||
style = "pep440" # the default
|
||||
|
||||
if style == "pep440":
|
||||
rendered = render_pep440(pieces)
|
||||
elif style == "pep440-branch":
|
||||
rendered = render_pep440_branch(pieces)
|
||||
elif style == "pep440-pre":
|
||||
rendered = render_pep440_pre(pieces)
|
||||
elif style == "pep440-post":
|
||||
rendered = render_pep440_post(pieces)
|
||||
elif style == "pep440-post-branch":
|
||||
rendered = render_pep440_post_branch(pieces)
|
||||
elif style == "pep440-old":
|
||||
rendered = render_pep440_old(pieces)
|
||||
elif style == "git-describe":
|
||||
rendered = render_git_describe(pieces)
|
||||
elif style == "git-describe-long":
|
||||
rendered = render_git_describe_long(pieces)
|
||||
else:
|
||||
raise ValueError("unknown style '%s'" % style)
|
||||
|
||||
return {
|
||||
"version": rendered,
|
||||
"full-revisionid": pieces["long"],
|
||||
"dirty": pieces["dirty"],
|
||||
"error": None,
|
||||
"date": pieces.get("date"),
|
||||
}
|
||||
|
||||
|
||||
def get_versions():
|
||||
"""Get version information or return default if unable to do so."""
|
||||
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
|
||||
# __file__, we can work backwards from there to the root. Some
|
||||
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
|
||||
# case we can only use expanded keywords.
|
||||
|
||||
cfg = get_config()
|
||||
verbose = cfg.verbose
|
||||
|
||||
try:
|
||||
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
|
||||
except NotThisMethod:
|
||||
pass
|
||||
|
||||
try:
|
||||
root = os.path.realpath(__file__)
|
||||
# versionfile_source is the relative path from the top of the source
|
||||
# tree (where the .git directory might live) to this file. Invert
|
||||
# this to find the root from __file__.
|
||||
for _ in cfg.versionfile_source.split("/"):
|
||||
root = os.path.dirname(root)
|
||||
except NameError:
|
||||
return {
|
||||
"version": "0+unknown",
|
||||
"full-revisionid": None,
|
||||
"dirty": None,
|
||||
"error": "unable to find root of source tree",
|
||||
"date": None,
|
||||
}
|
||||
|
||||
try:
|
||||
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
|
||||
return render(pieces, cfg.style)
|
||||
except NotThisMethod:
|
||||
pass
|
||||
|
||||
try:
|
||||
if cfg.parentdir_prefix:
|
||||
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
|
||||
except NotThisMethod:
|
||||
pass
|
||||
|
||||
return {
|
||||
"version": "0+unknown",
|
||||
"full-revisionid": None,
|
||||
"dirty": None,
|
||||
"error": "unable to compute version",
|
||||
"date": None,
|
||||
}
|
||||
235
spyrmsd/constants.py
Normal file
235
spyrmsd/constants.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""
|
||||
Useful constants.
|
||||
|
||||
Notes
|
||||
-----
|
||||
Periodic table data (atomic masses and covalent radii) are extracted from
|
||||
QCElemental_
|
||||
|
||||
.. _QCElemental: http://docs.qcarchive.molssi.org/projects/qcelemental/en/latest/
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
connectivity_tolerance: float = 0.4
|
||||
|
||||
# Values from QCElemental
|
||||
anum_to_mass: Dict = {
|
||||
1: 1.00782503223,
|
||||
2: 4.00260325413,
|
||||
3: 7.0160034366,
|
||||
4: 9.012183065,
|
||||
5: 11.00930536,
|
||||
6: 12.0,
|
||||
7: 14.00307400443,
|
||||
8: 15.99491461957,
|
||||
9: 18.99840316273,
|
||||
10: 19.9924401762,
|
||||
11: 22.989769282,
|
||||
12: 23.985041697,
|
||||
13: 26.98153853,
|
||||
14: 27.97692653465,
|
||||
15: 30.97376199842,
|
||||
16: 31.9720711744,
|
||||
17: 34.968852682,
|
||||
18: 39.9623831237,
|
||||
19: 38.9637064864,
|
||||
20: 39.962590863,
|
||||
21: 44.95590828,
|
||||
22: 47.94794198,
|
||||
23: 50.94395704,
|
||||
24: 51.94050623,
|
||||
25: 54.93804391,
|
||||
26: 55.93493633,
|
||||
27: 58.93319429,
|
||||
28: 57.93534241,
|
||||
29: 62.92959772,
|
||||
30: 63.92914201,
|
||||
31: 68.9255735,
|
||||
32: 73.921177761,
|
||||
33: 74.92159457,
|
||||
34: 79.9165218,
|
||||
35: 78.9183376,
|
||||
36: 83.9114977282,
|
||||
37: 84.9117897379,
|
||||
38: 87.9056125,
|
||||
39: 88.9058403,
|
||||
40: 89.9046977,
|
||||
41: 92.906373,
|
||||
42: 97.90540482,
|
||||
43: 97.9072124,
|
||||
44: 101.9043441,
|
||||
45: 102.905498,
|
||||
46: 105.9034804,
|
||||
47: 106.9050916,
|
||||
48: 113.90336509,
|
||||
49: 114.903878776,
|
||||
50: 119.90220163,
|
||||
51: 120.903812,
|
||||
52: 129.906222748,
|
||||
53: 126.9044719,
|
||||
54: 131.9041550856,
|
||||
55: 132.905451961,
|
||||
56: 137.905247,
|
||||
57: 138.9063563,
|
||||
58: 139.9054431,
|
||||
59: 140.9076576,
|
||||
60: 141.907729,
|
||||
61: 144.9127559,
|
||||
62: 151.9197397,
|
||||
63: 152.921238,
|
||||
64: 157.9241123,
|
||||
65: 158.9253547,
|
||||
66: 163.9291819,
|
||||
67: 164.9303288,
|
||||
68: 165.9302995,
|
||||
69: 168.9342179,
|
||||
70: 173.9388664,
|
||||
71: 174.9407752,
|
||||
72: 179.946557,
|
||||
73: 180.9479958,
|
||||
74: 183.95093092,
|
||||
75: 186.9557501,
|
||||
76: 191.961477,
|
||||
77: 192.9629216,
|
||||
78: 194.9647917,
|
||||
79: 196.96656879,
|
||||
80: 201.9706434,
|
||||
81: 204.9744278,
|
||||
82: 207.9766525,
|
||||
83: 208.9803991,
|
||||
84: 208.9824308,
|
||||
85: 209.9871479,
|
||||
86: 222.0175782,
|
||||
87: 223.019736,
|
||||
88: 226.0254103,
|
||||
89: 227.0277523,
|
||||
90: 232.0380558,
|
||||
91: 231.0358842,
|
||||
92: 238.0507884,
|
||||
93: 237.0481736,
|
||||
94: 244.0642053,
|
||||
95: 243.0613813,
|
||||
96: 247.0703541,
|
||||
97: 247.0703073,
|
||||
98: 251.0795886,
|
||||
99: 252.08298,
|
||||
100: 257.0951061,
|
||||
101: 258.0984315,
|
||||
102: 259.10103,
|
||||
103: 266.11983,
|
||||
104: 267.12179,
|
||||
105: 268.12567,
|
||||
106: 271.13393,
|
||||
107: 270.13336,
|
||||
108: 269.13375,
|
||||
109: 278.15631,
|
||||
110: 281.16451,
|
||||
111: 282.16912,
|
||||
112: 285.17712,
|
||||
113: 286.18221,
|
||||
114: 289.19042,
|
||||
115: 289.19363,
|
||||
116: 293.20449,
|
||||
117: 294.21046,
|
||||
}
|
||||
|
||||
# Data from QCElemental (in angstroms)
|
||||
anum_to_covalentradius: Dict = {
|
||||
1: 0.31,
|
||||
2: 0.28,
|
||||
3: 1.28,
|
||||
4: 0.96,
|
||||
5: 0.84,
|
||||
6: 0.76,
|
||||
7: 0.71,
|
||||
8: 0.66,
|
||||
9: 0.57,
|
||||
10: 0.58,
|
||||
11: 1.66,
|
||||
12: 1.41,
|
||||
13: 1.21,
|
||||
14: 1.11,
|
||||
15: 1.07,
|
||||
16: 1.05,
|
||||
17: 1.02,
|
||||
18: 1.06,
|
||||
19: 2.03,
|
||||
20: 1.76,
|
||||
21: 1.7,
|
||||
22: 1.6,
|
||||
23: 1.53,
|
||||
24: 1.39,
|
||||
25: 1.61,
|
||||
26: 1.52,
|
||||
27: 1.5,
|
||||
28: 1.24,
|
||||
29: 1.32,
|
||||
30: 1.22,
|
||||
31: 1.22,
|
||||
32: 1.2,
|
||||
33: 1.19,
|
||||
34: 1.2,
|
||||
35: 1.2,
|
||||
36: 1.16,
|
||||
37: 2.2,
|
||||
38: 1.95,
|
||||
39: 1.9,
|
||||
40: 1.75,
|
||||
41: 1.64,
|
||||
42: 1.54,
|
||||
43: 1.47,
|
||||
44: 1.46,
|
||||
45: 1.42,
|
||||
46: 1.39,
|
||||
47: 1.45,
|
||||
48: 1.44,
|
||||
49: 1.42,
|
||||
50: 1.39,
|
||||
51: 1.39,
|
||||
52: 1.38,
|
||||
53: 1.39,
|
||||
54: 1.4,
|
||||
55: 2.44,
|
||||
56: 2.15,
|
||||
57: 2.07,
|
||||
58: 2.04,
|
||||
59: 2.03,
|
||||
60: 2.01,
|
||||
61: 1.99,
|
||||
62: 1.98,
|
||||
63: 1.98,
|
||||
64: 1.96,
|
||||
65: 1.94,
|
||||
66: 1.92,
|
||||
67: 1.92,
|
||||
68: 1.89,
|
||||
69: 1.9,
|
||||
70: 1.87,
|
||||
71: 1.87,
|
||||
72: 1.75,
|
||||
73: 1.7,
|
||||
74: 1.62,
|
||||
75: 1.51,
|
||||
76: 1.44,
|
||||
77: 1.41,
|
||||
78: 1.36,
|
||||
79: 1.36,
|
||||
80: 1.32,
|
||||
81: 1.45,
|
||||
82: 1.46,
|
||||
83: 1.48,
|
||||
84: 1.4,
|
||||
85: 1.5,
|
||||
86: 1.5,
|
||||
87: 2.6,
|
||||
88: 2.21,
|
||||
89: 2.15,
|
||||
90: 2.06,
|
||||
91: 2.0,
|
||||
92: 1.96,
|
||||
93: 1.9,
|
||||
94: 1.87,
|
||||
95: 1.8,
|
||||
96: 1.69,
|
||||
}
|
||||
79
spyrmsd/due.py
Normal file
79
spyrmsd/due.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# emacs: at the end of the file
|
||||
# ex: set sts=4 ts=4 sw=4 et:
|
||||
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### #
|
||||
"""
|
||||
|
||||
Stub file for a guaranteed safe import of duecredit constructs: if duecredit
|
||||
is not available.
|
||||
|
||||
To use it, place it into your project codebase to be imported, e.g. copy as
|
||||
|
||||
cp stub.py /path/tomodule/module/due.py
|
||||
|
||||
Note that it might be better to avoid naming it duecredit.py to avoid shadowing
|
||||
installed duecredit.
|
||||
|
||||
Then use in your code as
|
||||
|
||||
from .due import due, Doi, BibTeX, Text
|
||||
|
||||
See https://github.com/duecredit/duecredit/blob/master/README.md for examples.
|
||||
|
||||
Origin: Originally a part of the duecredit
|
||||
Copyright: 2015-2019 DueCredit developers
|
||||
License: BSD-2
|
||||
"""
|
||||
|
||||
__version__ = "0.0.8"
|
||||
|
||||
|
||||
class InactiveDueCreditCollector(object):
|
||||
"""Just a stub at the Collector which would not do anything"""
|
||||
|
||||
def _donothing(self, *args, **kwargs):
|
||||
"""Perform no good and no bad"""
|
||||
pass
|
||||
|
||||
def dcite(self, *args, **kwargs):
|
||||
"""If I could cite I would"""
|
||||
|
||||
def nondecorating_decorator(func):
|
||||
return func
|
||||
|
||||
return nondecorating_decorator
|
||||
|
||||
active = False
|
||||
activate = add = cite = dump = load = _donothing
|
||||
|
||||
def __repr__(self):
|
||||
return self.__class__.__name__ + "()"
|
||||
|
||||
|
||||
def _donothing_func(*args, **kwargs):
|
||||
"""Perform no good and no bad"""
|
||||
pass
|
||||
|
||||
|
||||
try:
|
||||
from duecredit import BibTeX, Doi, Text, Url, due
|
||||
|
||||
if "due" in locals() and not hasattr(due, "cite"):
|
||||
raise RuntimeError("Imported due lacks .cite. DueCredit is now disabled")
|
||||
except Exception as e:
|
||||
if not isinstance(e, ImportError):
|
||||
import logging
|
||||
|
||||
logging.getLogger("duecredit").error(
|
||||
"Failed to import duecredit due to %s" % str(e)
|
||||
)
|
||||
# Initiate due stub
|
||||
due = InactiveDueCreditCollector()
|
||||
BibTeX = Doi = Url = Text = _donothing_func
|
||||
|
||||
# Emacs mode definitions
|
||||
# Local Variables:
|
||||
# mode: python
|
||||
# py-indent-offset: 4
|
||||
# tab-width: 4
|
||||
# indent-tabs-mode: nil
|
||||
# End:
|
||||
6
spyrmsd/exceptions.py
Normal file
6
spyrmsd/exceptions.py
Normal file
@@ -0,0 +1,6 @@
|
||||
class NonIsomorphicGraphs(ValueError):
|
||||
"""
|
||||
Raised when graphs are not isomorphic
|
||||
"""
|
||||
|
||||
pass
|
||||
94
spyrmsd/graph.py
Normal file
94
spyrmsd/graph.py
Normal file
@@ -0,0 +1,94 @@
|
||||
try:
|
||||
from spyrmsd.graphs.gt import (
|
||||
cycle,
|
||||
graph_from_adjacency_matrix,
|
||||
lattice,
|
||||
match_graphs,
|
||||
num_edges,
|
||||
num_vertices,
|
||||
vertex_property,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
try:
|
||||
from spyrmsd.graphs.nx import (
|
||||
cycle,
|
||||
graph_from_adjacency_matrix,
|
||||
lattice,
|
||||
match_graphs,
|
||||
num_edges,
|
||||
num_vertices,
|
||||
vertex_property,
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError("graph_tool or NetworkX libraries not found.")
|
||||
|
||||
__all__ = [
|
||||
"graph_from_adjacency_matrix",
|
||||
"match_graphs",
|
||||
"vertex_property",
|
||||
"num_vertices",
|
||||
"num_edges",
|
||||
"lattice",
|
||||
"cycle",
|
||||
"adjacency_matrix_from_atomic_coordinates",
|
||||
]
|
||||
|
||||
import numpy as np
|
||||
|
||||
from spyrmsd import constants
|
||||
|
||||
|
||||
def adjacency_matrix_from_atomic_coordinates(
|
||||
aprops: np.ndarray, coordinates: np.ndarray
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Compute adjacency matrix from atomic coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
aprops: numpy.ndarray
|
||||
Atomic properties
|
||||
coordinates: numpy.ndarray
|
||||
Atomic coordinates
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.ndarray
|
||||
Adjacency matrix
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
This function is based on an automatic bond perception algorithm: two
|
||||
atoms are considered to be bonded when their distance is smaller than
|
||||
the sum of their covalent radii plus a tolerance value. [3]_
|
||||
|
||||
.. warning::
|
||||
The automatic bond perceptron rule implemented in this functions
|
||||
is very simple and only depends on atomic coordinates. Use
|
||||
with care!
|
||||
|
||||
.. [3] E. C. Meng and R. A. Lewis, *Determination of molecular topology and atomic
|
||||
hybridization states from heavy atom coordinates*, J. Comp. Chem. **12**, 891-898
|
||||
(1991).
|
||||
"""
|
||||
|
||||
n = len(aprops)
|
||||
|
||||
assert coordinates.shape == (n, 3)
|
||||
|
||||
A = np.zeros((n, n))
|
||||
|
||||
for i in range(n):
|
||||
r_i = constants.anum_to_covalentradius[aprops[i]]
|
||||
|
||||
for j in range(i + 1, n):
|
||||
r_j = constants.anum_to_covalentradius[aprops[j]]
|
||||
|
||||
distance = np.sqrt(np.sum((coordinates[i] - coordinates[j]) ** 2))
|
||||
|
||||
if distance < (r_i + r_j + constants.connectivity_tolerance):
|
||||
A[i, j] = A[j, i] = 1
|
||||
|
||||
return A
|
||||
0
spyrmsd/graphs/__init__.py
Normal file
0
spyrmsd/graphs/__init__.py
Normal file
8
spyrmsd/graphs/_common.py
Normal file
8
spyrmsd/graphs/_common.py
Normal file
@@ -0,0 +1,8 @@
|
||||
warn_disconnected_graph: str = "Disconnected graph detected. Is this expected?"
|
||||
warn_no_atomic_properties: str = (
|
||||
"No atomic property information stored on nodes. Node matching is not performed..."
|
||||
)
|
||||
|
||||
error_non_isomorphic_graphs: str = (
|
||||
"Graphs are not isomorphic.\nMake sure graphs have the same connectivity."
|
||||
)
|
||||
239
spyrmsd/graphs/gt.py
Normal file
239
spyrmsd/graphs/gt.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import graph_tool as gt
|
||||
import numpy as np
|
||||
from graph_tool import generation, topology
|
||||
|
||||
from spyrmsd.exceptions import NonIsomorphicGraphs
|
||||
from spyrmsd.graphs._common import (
|
||||
error_non_isomorphic_graphs,
|
||||
warn_disconnected_graph,
|
||||
warn_no_atomic_properties,
|
||||
)
|
||||
|
||||
|
||||
# TODO: Implement all graph-tool supported types
|
||||
def _c_type(numpy_dtype):
|
||||
"""
|
||||
Get C type compatible with graph-tool from numpy dtype
|
||||
|
||||
Parameters
|
||||
----------
|
||||
numpy_dtype: np.dtype
|
||||
Numpy dtype
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
C type
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the data type is not supported
|
||||
|
||||
Notes
|
||||
-----
|
||||
https://graph-tool.skewed.de/static/doc/quickstart.html#sec-property-maps
|
||||
"""
|
||||
name: str = numpy_dtype.name
|
||||
|
||||
if "int" in name:
|
||||
return "int"
|
||||
elif "float" in name:
|
||||
return "double"
|
||||
elif "str" in name:
|
||||
return "string"
|
||||
else:
|
||||
raise ValueError(f"Unsupported property type: {name}")
|
||||
|
||||
|
||||
def graph_from_adjacency_matrix(
|
||||
adjacency_matrix: Union[np.ndarray, List[List[int]]],
|
||||
aprops: Optional[Union[np.ndarray, List[Any]]] = None,
|
||||
):
|
||||
"""
|
||||
Graph from adjacency matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
adjacency_matrix: Union[np.ndarray, List[List[int]]]
|
||||
Adjacency matrix
|
||||
aprops: Union[np.ndarray, List[Any]], optional
|
||||
Atomic properties
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Molecular graph
|
||||
|
||||
Notes
|
||||
-----
|
||||
It the atomic numbers are passed, they are used as node attributes.
|
||||
"""
|
||||
|
||||
# Get upper triangular adjacency matrix
|
||||
adj = np.triu(adjacency_matrix)
|
||||
|
||||
assert adj.shape[0] == adj.shape[1]
|
||||
num_vertices = adj.shape[0]
|
||||
|
||||
G = gt.Graph(directed=False)
|
||||
G.add_vertex(n=num_vertices)
|
||||
G.add_edge_list(np.transpose(adj.nonzero()))
|
||||
|
||||
# Check if graph is connected, for warning
|
||||
cc, _ = topology.label_components(G)
|
||||
if set(cc.a) != {0}:
|
||||
warnings.warn(warn_disconnected_graph)
|
||||
|
||||
if aprops is not None:
|
||||
if not isinstance(aprops, np.ndarray):
|
||||
aprops = np.array(aprops)
|
||||
|
||||
assert aprops.shape[0] == num_vertices
|
||||
|
||||
ptype: str = _c_type(aprops.dtype) # Get C type
|
||||
vprop = G.new_vertex_property(ptype, vals=aprops) # Create property map
|
||||
G.vertex_properties["aprops"] = vprop # Set property map
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def match_graphs(G1, G2) -> List[Tuple[List[int], List[int]]]:
|
||||
"""
|
||||
Compute graph isomorphisms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G1:
|
||||
Graph 1
|
||||
G2:
|
||||
Graph 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Tuple[List[int],List[int]]]
|
||||
All possible mappings between nodes of graph 1 and graph 2 (isomorphisms)
|
||||
|
||||
Raises
|
||||
------
|
||||
NonIsomorphicGraphs
|
||||
If the graphs `G1` and `G2` are not isomorphic
|
||||
"""
|
||||
|
||||
try:
|
||||
maps = topology.subgraph_isomorphism(
|
||||
G1,
|
||||
G2,
|
||||
vertex_label=(
|
||||
G1.vertex_properties["aprops"],
|
||||
G2.vertex_properties["aprops"],
|
||||
),
|
||||
subgraph=False,
|
||||
)
|
||||
except KeyError:
|
||||
warnings.warn(warn_no_atomic_properties)
|
||||
|
||||
maps = topology.subgraph_isomorphism(G1, G2, subgraph=False)
|
||||
|
||||
# Check if graphs are actually isomorphic
|
||||
if len(maps) == 0:
|
||||
raise NonIsomorphicGraphs(error_non_isomorphic_graphs)
|
||||
|
||||
n = num_vertices(G1)
|
||||
|
||||
# Extract all isomorphisms in a list
|
||||
return [(np.arange(0, n, dtype=int), m.a) for m in maps]
|
||||
|
||||
|
||||
def vertex_property(G, vproperty: str, idx: int) -> Any:
|
||||
"""
|
||||
Get vertex (node) property from graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
vproperty: str
|
||||
Vertex property name
|
||||
idx: int
|
||||
Vertex index
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
Vertex property value
|
||||
"""
|
||||
return G.vertex_properties[vproperty][idx]
|
||||
|
||||
|
||||
def num_vertices(G) -> int:
|
||||
"""
|
||||
Number of vertices
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of vertices (nodes)
|
||||
"""
|
||||
return G.num_vertices()
|
||||
|
||||
|
||||
def num_edges(G) -> int:
|
||||
"""
|
||||
Number of edges
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of edges
|
||||
"""
|
||||
return G.num_edges()
|
||||
|
||||
|
||||
def lattice(n1: int, n2: int):
|
||||
"""
|
||||
Build 2D lattice graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n1: int
|
||||
Number of nodes in dimension 1
|
||||
n2: int
|
||||
Number of nodes in dimension 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Lattice graph
|
||||
"""
|
||||
return generation.lattice((n1, n2))
|
||||
|
||||
|
||||
def cycle(n):
|
||||
"""
|
||||
Build cycle graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n: int
|
||||
Number of nodes
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Cycle graph
|
||||
"""
|
||||
return generation.circular_graph(n)
|
||||
192
spyrmsd/graphs/nx.py
Normal file
192
spyrmsd/graphs/nx.py
Normal file
@@ -0,0 +1,192 @@
|
||||
import warnings
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
||||
from spyrmsd.exceptions import NonIsomorphicGraphs
|
||||
from spyrmsd.graphs._common import (
|
||||
error_non_isomorphic_graphs,
|
||||
warn_disconnected_graph,
|
||||
warn_no_atomic_properties,
|
||||
)
|
||||
|
||||
|
||||
def graph_from_adjacency_matrix(
|
||||
adjacency_matrix: Union[np.ndarray, List[List[int]]],
|
||||
aprops: Optional[Union[np.ndarray, List[Any]]] = None,
|
||||
) -> nx.Graph:
|
||||
"""
|
||||
Graph from adjacency matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
adjacency_matrix: Union[np.ndarray, List[List[int]]]
|
||||
Adjacency matrix
|
||||
aprops: Union[np.ndarray, List[Any]], optional
|
||||
Atomic properties
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Molecular graph
|
||||
|
||||
Notes
|
||||
-----
|
||||
It the atomic numbers are passed, they are used as node attributes.
|
||||
"""
|
||||
|
||||
G = nx.Graph(adjacency_matrix)
|
||||
|
||||
if not nx.is_connected(G):
|
||||
warnings.warn(warn_disconnected_graph)
|
||||
|
||||
if aprops is not None:
|
||||
attributes = {idx: aprops for idx, aprops in enumerate(aprops)}
|
||||
nx.set_node_attributes(G, attributes, "aprops")
|
||||
|
||||
return G
|
||||
|
||||
|
||||
def match_graphs(G1, G2) -> List[Tuple[List[int], List[int]]]:
|
||||
"""
|
||||
Compute graph isomorphisms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G1:
|
||||
Graph 1
|
||||
G2:
|
||||
Graph 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Tuple[List[int],List[int]]]
|
||||
All possible mappings between nodes of graph 1 and graph 2 (isomorphisms)
|
||||
|
||||
Raises
|
||||
------
|
||||
NonIsomorphicGraphs
|
||||
If the graphs `G1` and `G2` are not isomorphic
|
||||
"""
|
||||
|
||||
def match_aprops(node1, node2):
|
||||
"""
|
||||
Check if atomic properties for two nodes match.
|
||||
"""
|
||||
return node1["aprops"] == node2["aprops"]
|
||||
|
||||
if (
|
||||
nx.get_node_attributes(G1, "aprops") == {}
|
||||
or nx.get_node_attributes(G2, "aprops") == {}
|
||||
):
|
||||
# Nodes without atomic number information
|
||||
# No node-matching check
|
||||
node_match = None
|
||||
|
||||
warnings.warn(warn_no_atomic_properties)
|
||||
|
||||
else:
|
||||
node_match = match_aprops
|
||||
|
||||
GM = nx.algorithms.isomorphism.GraphMatcher(G1, G2, node_match)
|
||||
|
||||
# Check if graphs are actually isomorphic
|
||||
if not GM.is_isomorphic():
|
||||
raise NonIsomorphicGraphs(error_non_isomorphic_graphs)
|
||||
|
||||
return [
|
||||
(list(isomorphism.keys()), list(isomorphism.values()))
|
||||
for isomorphism in GM.isomorphisms_iter()
|
||||
]
|
||||
|
||||
|
||||
def vertex_property(G, vproperty: str, idx: int) -> Any:
|
||||
"""
|
||||
Get vertex (node) property from graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
vproperty: str
|
||||
Vertex property name
|
||||
idx: int
|
||||
Vertex index
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
Vertex property value
|
||||
"""
|
||||
return G.nodes[idx][vproperty]
|
||||
|
||||
|
||||
def num_vertices(G) -> int:
|
||||
"""
|
||||
Number of vertices
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of vertices (nodes)
|
||||
"""
|
||||
return G.number_of_nodes()
|
||||
|
||||
|
||||
def num_edges(G) -> int:
|
||||
"""
|
||||
Number of edges
|
||||
|
||||
Parameters
|
||||
----------
|
||||
G:
|
||||
Graph
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of edges
|
||||
"""
|
||||
return G.number_of_edges()
|
||||
|
||||
|
||||
def lattice(n1, n2):
|
||||
"""
|
||||
Build 2D lattice graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n1: int
|
||||
Number of nodes in dimension 1
|
||||
n2: int
|
||||
Number of nodes in dimension 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Lattice graph
|
||||
"""
|
||||
return nx.generators.lattice.grid_2d_graph(n1, n2)
|
||||
|
||||
|
||||
def cycle(n):
|
||||
"""
|
||||
Build cycle graph
|
||||
|
||||
Parameters
|
||||
----------
|
||||
n: int
|
||||
Number of nodes
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Cycle graph
|
||||
"""
|
||||
return nx.cycle_graph(n)
|
||||
120
spyrmsd/hungarian.py
Normal file
120
spyrmsd/hungarian.py
Normal file
@@ -0,0 +1,120 @@
|
||||
import numpy as np
|
||||
import scipy
|
||||
|
||||
from .due import Doi, due
|
||||
|
||||
due.cite(
|
||||
Doi("10.1021/ci400534h"),
|
||||
path="spyrmsd.hungarian",
|
||||
description="Hungarian method",
|
||||
)
|
||||
|
||||
|
||||
def cost_mtx(A: np.ndarray, B: np.ndarray):
|
||||
"""
|
||||
Compute the cost matrix for atom-atom assignment.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: numpy.ndarray
|
||||
Atomic coordinates of molecule A
|
||||
B: numpy.ndarray
|
||||
Atomic coordinates of molecule B
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Cost matrix of squared atomic distances between atoms of
|
||||
molecules A and B
|
||||
"""
|
||||
|
||||
return scipy.spatial.distance.cdist(A, B, metric="sqeuclidean")
|
||||
|
||||
|
||||
def optimal_assignment(A: np.ndarray, B: np.ndarray):
|
||||
"""
|
||||
Solve the optimal assignment problems between atomic coordinates of
|
||||
molecules A and B.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: numpy.ndarray
|
||||
Atomic coordinates of molecule A
|
||||
B: numpy.ndarray
|
||||
Atomic coordinates of molecule B
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, nd.array, nd.array]
|
||||
Cost of the optimal assignment, together with the row and column
|
||||
indices of said assignment
|
||||
"""
|
||||
|
||||
C = cost_mtx(A, B)
|
||||
|
||||
row_idx, col_idx = scipy.optimize.linear_sum_assignment(C)
|
||||
|
||||
# Compute assignment cost
|
||||
cost = C[row_idx, col_idx].sum()
|
||||
|
||||
return cost, row_idx, col_idx
|
||||
|
||||
|
||||
def hungarian_rmsd(
|
||||
A: np.ndarray, B: np.ndarray, apropsA: np.ndarray, apropsB: np.ndarray
|
||||
) -> float:
|
||||
"""
|
||||
Solve the optimal assignment problems between atomic coordinates of
|
||||
molecules A and B.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: numpy.ndarray
|
||||
Atomic coordinates of molecule A
|
||||
B: numpy.ndarray
|
||||
Atomic coordinates of molecule B
|
||||
apropsA: numpy.ndarray
|
||||
Atomic properties of molecule A
|
||||
apropsB: numpy.ndarray
|
||||
Atomic properties of molecule B
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
RMSD computed with the Hungarian method
|
||||
|
||||
Notes
|
||||
-----
|
||||
The Hungarian algorithm is used to solve the linear assignment problem, which is
|
||||
a minimum weight matching of the molecular graphs (bipartite).
|
||||
|
||||
The linear assignment problem is solved for every element separately. [1]_
|
||||
|
||||
.. [1] W. J. Allen and R. C. Rizzo, *Implementation of the Hungarian Algorithm to
|
||||
Account for Ligand Symmetry and Similarity in Structure-Based Design*,
|
||||
J. Chem. Inf. Model. **54**, 518-529 (2014)
|
||||
"""
|
||||
|
||||
assert A.shape == B.shape
|
||||
assert apropsA.shape == apropsB.shape
|
||||
|
||||
elements = set(apropsA)
|
||||
|
||||
total_cost: float = 0.0
|
||||
for t in elements:
|
||||
apropsA_idx = apropsA == t
|
||||
apropsB_idx = apropsB == t
|
||||
|
||||
assert apropsA_idx.shape == apropsB_idx.shape
|
||||
|
||||
cost, row_idx, col_idx = optimal_assignment(
|
||||
A[apropsA_idx, :], B[apropsB_idx, :]
|
||||
)
|
||||
|
||||
total_cost += cost
|
||||
|
||||
N = A.shape[0]
|
||||
|
||||
rmsd = np.sqrt(total_cost / N)
|
||||
|
||||
return rmsd
|
||||
87
spyrmsd/io.py
Normal file
87
spyrmsd/io.py
Normal file
@@ -0,0 +1,87 @@
|
||||
try:
|
||||
from spyrmsd.optional.obabel import (
|
||||
adjacency_matrix,
|
||||
bonds,
|
||||
load,
|
||||
loadall,
|
||||
numatoms,
|
||||
numbonds,
|
||||
to_molecule,
|
||||
)
|
||||
|
||||
except ImportError:
|
||||
try:
|
||||
from spyrmsd.optional.rdkit import (
|
||||
adjacency_matrix,
|
||||
bonds,
|
||||
load,
|
||||
loadall,
|
||||
numatoms,
|
||||
numbonds,
|
||||
to_molecule,
|
||||
)
|
||||
except ImportError:
|
||||
# Use sPyRMSD as standalone library
|
||||
__all__ = []
|
||||
else:
|
||||
# Avoid flake8 complaint "imported but unused"
|
||||
__all__ = [
|
||||
"load",
|
||||
"loadall",
|
||||
"adjacency_matrix",
|
||||
"to_molecule",
|
||||
"numatoms",
|
||||
"numbonds",
|
||||
"bonds",
|
||||
]
|
||||
else:
|
||||
# Avoid flake8 complaint "imported but unused"
|
||||
__all__ = [
|
||||
"load",
|
||||
"loadall",
|
||||
"adjacency_matrix",
|
||||
"to_molecule",
|
||||
"numatoms",
|
||||
"numbonds",
|
||||
"bonds",
|
||||
]
|
||||
|
||||
|
||||
def loadmol(fname: str, adjacency: bool = True):
|
||||
"""
|
||||
Load molecule from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
molecule.Molecule
|
||||
`spyrmsd` molecule
|
||||
"""
|
||||
|
||||
mol = load(fname)
|
||||
|
||||
return to_molecule(mol, adjacency=adjacency)
|
||||
|
||||
|
||||
def loadallmols(fname: str, adjacency: bool = True):
|
||||
"""
|
||||
Load molecules from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[molecule.Molecule]
|
||||
`spyrmsd` molecules
|
||||
"""
|
||||
|
||||
mols = loadall(fname)
|
||||
|
||||
return [to_molecule(mol, adjacency=adjacency) for mol in mols]
|
||||
266
spyrmsd/molecule.py
Normal file
266
spyrmsd/molecule.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from spyrmsd import constants, graph, utils
|
||||
|
||||
|
||||
class Molecule:
|
||||
def __init__(
|
||||
self,
|
||||
atomicnums: Union[np.ndarray, List[int]],
|
||||
coordinates: Union[np.ndarray, List[List[float]]],
|
||||
adjacency_matrix: Optional[Union[np.ndarray, List[List[int]]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Molecule initialisation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
atomicnums: Union[np.ndarray, List[int]]
|
||||
Atomic numbers
|
||||
coordinates: Union[np.ndarray, List[List[float]]]
|
||||
Atomic coordinates
|
||||
adjacency_matrix: Union[np.ndarray, List[List[int]]], optional
|
||||
Molecular graph adjacency matrix
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
A molecule is built from atomic numbers and atomic coordinates only.
|
||||
Optionally, a good representation of the molecular graph (obtained with
|
||||
OpenBabel or RDKit) can be stored as adjacency matrix.
|
||||
"""
|
||||
|
||||
atomicnums = np.asarray(atomicnums, dtype=int)
|
||||
coordinates = np.asarray(coordinates, dtype=float)
|
||||
|
||||
self.natoms: int = len(atomicnums)
|
||||
|
||||
assert atomicnums.shape == (self.natoms,)
|
||||
assert coordinates.shape == (self.natoms, 3)
|
||||
|
||||
self.atomicnums = atomicnums
|
||||
self.coordinates = coordinates
|
||||
|
||||
self.stripped: bool = bool(np.all(atomicnums != 1))
|
||||
|
||||
if adjacency_matrix is not None:
|
||||
self.adjacency_matrix: np.ndarray = np.asarray(adjacency_matrix, dtype=int)
|
||||
|
||||
# Molecular graph
|
||||
self.G = None
|
||||
|
||||
self.masses: Optional[List[float]] = None
|
||||
|
||||
@classmethod
|
||||
def from_obabel(cls, obmol, adjacency: bool = True):
|
||||
"""
|
||||
Constructor from OpenBabel molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
obmol:
|
||||
OpenBabel molecule
|
||||
adjacency:
|
||||
Flag to compute the adjacency matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
spyrmsd.molecule.Molecule
|
||||
:code:`spyrmsd` Molecule
|
||||
"""
|
||||
# TODO: Check if OpenBabel is available?
|
||||
from spyrmsd.optional import obabel as ob
|
||||
|
||||
return ob.to_molecule(obmol, adjacency=adjacency)
|
||||
|
||||
@classmethod
|
||||
def from_rdkit(cls, rdmol, adjacency: bool = True):
|
||||
"""
|
||||
Constructor from RDKit molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
rdmol:
|
||||
RDKit molecule
|
||||
adjacency:
|
||||
Flag to compute the adjacency matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
spyrmsd.molecule.Molecule
|
||||
:code:`spyrmsd` Molecule
|
||||
"""
|
||||
|
||||
# TODO: Check if RDKit is available?
|
||||
from spyrmsd.optional import rdkit as rd
|
||||
|
||||
return rd.to_molecule(rdmol, adjacency=adjacency)
|
||||
|
||||
def translate(self, vector: Union[np.ndarray, List[float]]) -> None:
|
||||
"""
|
||||
Translate molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
vector: np.ndarray
|
||||
Translation vector (in 3D)
|
||||
"""
|
||||
assert len(vector) == 3
|
||||
vector = np.asarray(vector)
|
||||
self.coordinates += vector
|
||||
|
||||
def rotate(
|
||||
self, angle: float, axis: Union[np.ndarray, List[float]], units: str = "rad"
|
||||
) -> None:
|
||||
"""
|
||||
Rotate molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
angle: float
|
||||
Rotation angle
|
||||
axis: np.ndarray
|
||||
Axis of rotation (in 3D)
|
||||
units: {"rad", "deg"}
|
||||
Units of the angle (radians `rad` or degrees `deg`)
|
||||
"""
|
||||
axis = np.asarray(axis)
|
||||
self.coordinates = utils.rotate(self.coordinates, angle, axis, units)
|
||||
|
||||
def center_of_mass(self) -> np.ndarray:
|
||||
"""
|
||||
Center of mass.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Center of mass
|
||||
|
||||
Notes
|
||||
-----
|
||||
Atomic masses are cached.
|
||||
"""
|
||||
|
||||
# Get masses and cache them
|
||||
if self.masses is None:
|
||||
self.masses = [constants.anum_to_mass[anum] for anum in self.atomicnums]
|
||||
|
||||
return np.average(self.coordinates, axis=0, weights=self.masses)
|
||||
|
||||
def center_of_geometry(self) -> np.ndarray:
|
||||
"""
|
||||
Center of geometry.
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Center of geometry
|
||||
"""
|
||||
return utils.center_of_geometry(self.coordinates)
|
||||
|
||||
# TODO: Change name (to stripH)
|
||||
def strip(self) -> None:
|
||||
"""
|
||||
Strip hydrogen atoms.
|
||||
"""
|
||||
|
||||
if not self.stripped:
|
||||
idx = self.atomicnums != 1 # Non-hydrogen atoms
|
||||
|
||||
# Strip
|
||||
self.atomicnums = self.atomicnums[idx]
|
||||
self.coordinates = self.coordinates[idx, :]
|
||||
|
||||
# Update number of atoms
|
||||
self.natoms = len(self.atomicnums)
|
||||
|
||||
# Update adjacency matrix
|
||||
if self.adjacency_matrix is not None:
|
||||
self.adjacency_matrix = self.adjacency_matrix[np.ix_(idx, idx)]
|
||||
|
||||
# Reset molecular graph when stripping
|
||||
self.G = None
|
||||
|
||||
self.stripped = True
|
||||
|
||||
def to_graph(self):
|
||||
"""
|
||||
Convert molecule to graph.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Graph
|
||||
Molecular graph.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If the molecule does not have an associated adjacency matrix, a simple
|
||||
bond perception is used.
|
||||
|
||||
The molecular graph is cached.
|
||||
"""
|
||||
if self.G is None:
|
||||
try:
|
||||
self.G = graph.graph_from_adjacency_matrix(
|
||||
self.adjacency_matrix, self.atomicnums
|
||||
)
|
||||
except AttributeError:
|
||||
warnings.warn(
|
||||
"Molecule was not initialized with an adjacency matrix. "
|
||||
+ "Using bond perception..."
|
||||
)
|
||||
|
||||
# Automatic bond perception (with very simple rule)
|
||||
self.adjacency_matrix = graph.adjacency_matrix_from_atomic_coordinates(
|
||||
self.atomicnums, self.coordinates
|
||||
)
|
||||
|
||||
self.G = graph.graph_from_adjacency_matrix(
|
||||
self.adjacency_matrix, self.atomicnums
|
||||
)
|
||||
|
||||
return self.G
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""
|
||||
Molecule size.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of atoms within the molecule
|
||||
"""
|
||||
return self.natoms
|
||||
|
||||
|
||||
def coords_from_molecule(mol: Molecule, center: bool = False) -> np.ndarray:
|
||||
"""
|
||||
Atomic coordinates from molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol: molecule.Molecule
|
||||
Molecule
|
||||
center: bool
|
||||
Center flag
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Atomic coordinates (possibly centred)
|
||||
|
||||
Notes
|
||||
-----
|
||||
Atomic coordinates are centred according to the center of geometry, not the center
|
||||
of mass.
|
||||
"""
|
||||
|
||||
if center:
|
||||
coords = mol.coordinates - mol.center_of_geometry()
|
||||
else:
|
||||
coords = mol.coordinates
|
||||
|
||||
return coords
|
||||
7
spyrmsd/optional/__init__.py
Normal file
7
spyrmsd/optional/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Handle versioneer
|
||||
from spyrmsd._version import get_versions
|
||||
|
||||
versions = get_versions()
|
||||
__version__ = versions["version"]
|
||||
__git_revision__ = versions["full-revisionid"]
|
||||
del get_versions, versions
|
||||
181
spyrmsd/optional/obabel.py
Normal file
181
spyrmsd/optional/obabel.py
Normal file
@@ -0,0 +1,181 @@
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from openbabel import openbabel as ob
|
||||
from openbabel import pybel
|
||||
|
||||
from spyrmsd import molecule, utils
|
||||
|
||||
|
||||
def load(fname: str):
|
||||
"""
|
||||
Load molecule from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
Molecule
|
||||
"""
|
||||
|
||||
fmt = utils.molformat(fname)
|
||||
|
||||
obmol = next(pybel.readfile(fmt, fname))
|
||||
|
||||
return obmol
|
||||
|
||||
|
||||
def loadall(fname: str):
|
||||
"""
|
||||
Load molecules from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
List of molecules
|
||||
"""
|
||||
|
||||
fmt = utils.molformat(fname)
|
||||
|
||||
obmols = [obmol for obmol in pybel.readfile(fmt, fname)]
|
||||
|
||||
# FIXME: Special handling for multi-model PDB files
|
||||
# See OpenBabel Issue #2097
|
||||
if fmt == "pdb":
|
||||
if len(obmols) > 1: # Multi-model PDB file
|
||||
obmols = obmols[:-1]
|
||||
|
||||
return obmols
|
||||
|
||||
|
||||
def adjacency_matrix(mol) -> np.ndarray:
|
||||
"""
|
||||
Adjacency matrix from OpenBabel molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Adjacency matrix of the molecule
|
||||
"""
|
||||
|
||||
n = len(mol.atoms)
|
||||
|
||||
# Pre-allocate memory for the adjacency matrix
|
||||
A = np.zeros((n, n), dtype=int)
|
||||
|
||||
# Loop over molecular bonds
|
||||
for bond in ob.OBMolBondIter(mol.OBMol):
|
||||
# Bonds are 1-indexed
|
||||
i: int = bond.GetBeginAtomIdx() - 1
|
||||
j: int = bond.GetEndAtomIdx() - 1
|
||||
|
||||
# A molecular graph is undirected
|
||||
A[i, j] = A[j, i] = 1
|
||||
|
||||
return A
|
||||
|
||||
|
||||
def to_molecule(mol, adjacency: bool = True):
|
||||
"""
|
||||
Transform molecule to `pyrmsd` molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
adjacency: boolean, optional
|
||||
Flag to decide wether to build the adjacency matrix from molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
pyrmsd.molecule.Molecule
|
||||
`pyrmsd` molecule
|
||||
"""
|
||||
|
||||
n = len(mol.atoms)
|
||||
|
||||
atomicnums = np.zeros(n, dtype=int)
|
||||
coordinates = np.zeros((n, 3))
|
||||
|
||||
for i, atom in enumerate(mol.atoms):
|
||||
atomicnums[i] = atom.atomicnum
|
||||
coordinates[i] = atom.coords
|
||||
|
||||
A: Optional[np.ndarray] = adjacency_matrix(mol) if adjacency else None
|
||||
|
||||
return molecule.Molecule(atomicnums, coordinates, A)
|
||||
|
||||
|
||||
def numatoms(mol) -> int:
|
||||
"""
|
||||
Number of atoms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of atoms
|
||||
"""
|
||||
return mol.OBMol.NumAtoms()
|
||||
|
||||
|
||||
def numbonds(mol) -> int:
|
||||
"""
|
||||
Number of bonds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of bonds
|
||||
"""
|
||||
return mol.OBMol.NumBonds()
|
||||
|
||||
|
||||
def bonds(mol) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
List of bonds
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Tuple[int, int]]
|
||||
List of bonds
|
||||
|
||||
Notes
|
||||
-----
|
||||
A bond is defined by a tuple of (0-based) indices of two atoms.
|
||||
"""
|
||||
b = []
|
||||
|
||||
for bond in ob.OBMolBondIter(mol.OBMol):
|
||||
i = bond.GetBeginAtomIdx() - 1
|
||||
j = bond.GetEndAtomIdx() - 1
|
||||
|
||||
b.append((i, j))
|
||||
|
||||
return b
|
||||
227
spyrmsd/optional/rdkit.py
Normal file
227
spyrmsd/optional/rdkit.py
Normal file
@@ -0,0 +1,227 @@
|
||||
import gzip
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import rdkit.Chem as Chem
|
||||
|
||||
from spyrmsd import molecule, utils
|
||||
|
||||
|
||||
def _load_block_gzipped(loader, fname: str):
|
||||
"""
|
||||
Load gzipped files using MolBlocks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
loader:
|
||||
RDKit MolBlock loader (MolFromMol2Block, MolFromPDBBlock, ...)
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
Molecule
|
||||
"""
|
||||
with gzip.open(fname, "r") as fgz:
|
||||
content = fgz.read()
|
||||
rdmol = loader(content, removeHs=False)
|
||||
|
||||
return rdmol
|
||||
|
||||
|
||||
def load(fname: str):
|
||||
"""
|
||||
Load molecule from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
Molecule
|
||||
"""
|
||||
|
||||
gzipped = os.path.splitext(fname)[-1] == ".gz"
|
||||
fmt = utils.molformat(fname)
|
||||
|
||||
if fmt == "mol2":
|
||||
if not gzipped:
|
||||
rdmol = Chem.MolFromMol2File(fname, removeHs=False)
|
||||
else:
|
||||
rdmol = _load_block_gzipped(Chem.MolFromMol2Block, fname)
|
||||
elif fmt == "sdf":
|
||||
if not gzipped:
|
||||
rdmol = next(Chem.SDMolSupplier(fname, removeHs=False))
|
||||
else:
|
||||
with gzip.open(fname, "r") as fgz:
|
||||
rdmol = next(Chem.ForwardSDMolSupplier(fgz, removeHs=False))
|
||||
elif fmt == "pdb":
|
||||
if not gzipped:
|
||||
rdmol = Chem.MolFromPDBFile(fname, removeHs=False)
|
||||
else:
|
||||
rdmol = _load_block_gzipped(Chem.MolFromPDBBlock, fname)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return rdmol
|
||||
|
||||
|
||||
def loadall(fname: str):
|
||||
"""
|
||||
Load molecules from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname: str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
List of molecules
|
||||
"""
|
||||
|
||||
gzipped = os.path.splitext(fname)[-1] == ".gz"
|
||||
fmt = utils.molformat(fname)
|
||||
|
||||
if fmt == "mol2":
|
||||
raise NotImplementedError # See RDKit Issue #415
|
||||
elif fmt == "sdf":
|
||||
if not gzipped:
|
||||
rdmols = Chem.SDMolSupplier(fname, removeHs=False)
|
||||
mols = [rdmol for rdmol in rdmols]
|
||||
else:
|
||||
with gzip.open(fname, "r") as fgz:
|
||||
rdmols = Chem.ForwardSDMolSupplier(fgz, removeHs=False)
|
||||
|
||||
# Load all molecules before closing file
|
||||
mols = [rdmol for rdmol in rdmols]
|
||||
elif fmt == "pdb":
|
||||
# TODO: Implement
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return mols
|
||||
|
||||
|
||||
def adjacency_matrix(mol) -> np.ndarray:
|
||||
"""
|
||||
Adjacency matrix from OpenBabel molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Adjacency matrix of the molecule
|
||||
"""
|
||||
|
||||
return Chem.rdmolops.GetAdjacencyMatrix(mol)
|
||||
|
||||
|
||||
def to_molecule(mol, adjacency: bool = True):
|
||||
"""
|
||||
Transform molecule to `pyrmsd` molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
adjacency: boolean, optional
|
||||
Flag to decide wether to build the adjacency matrix from molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
spyrmsd.molecule.Molecule
|
||||
`spyrmsd` molecule
|
||||
"""
|
||||
|
||||
if mol is None:
|
||||
# Propagate RDKit parsing failure
|
||||
return None
|
||||
|
||||
atoms = mol.GetAtoms()
|
||||
|
||||
n = len(atoms)
|
||||
|
||||
atomicnums = np.zeros(n, dtype=int)
|
||||
coordinates = np.zeros((n, 3))
|
||||
|
||||
conformer = mol.GetConformer()
|
||||
|
||||
for i, atom in enumerate(atoms):
|
||||
atomicnums[i] = atom.GetAtomicNum()
|
||||
|
||||
pos = conformer.GetAtomPosition(i)
|
||||
|
||||
coordinates[i] = np.array([pos.x, pos.y, pos.z])
|
||||
|
||||
A: Optional[np.ndarray] = adjacency_matrix(mol) if adjacency else None
|
||||
|
||||
return molecule.Molecule(atomicnums, coordinates, A)
|
||||
|
||||
|
||||
def numatoms(mol) -> int:
|
||||
"""
|
||||
Number of atoms.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of atoms
|
||||
"""
|
||||
return mol.GetNumAtoms()
|
||||
|
||||
|
||||
def numbonds(mol) -> int:
|
||||
"""
|
||||
Number of bonds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
Number of bonds
|
||||
"""
|
||||
return mol.GetNumBonds()
|
||||
|
||||
|
||||
def bonds(mol) -> List[Tuple[int, int]]:
|
||||
"""
|
||||
List of bonds.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
mol:
|
||||
Molecule
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[Tuple[int, int]]
|
||||
List of bonds
|
||||
|
||||
Notes
|
||||
-----
|
||||
A bond is defined by a tuple of (0-based) indices of two atoms.
|
||||
"""
|
||||
b = []
|
||||
|
||||
for bond in mol.GetBonds():
|
||||
b.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
|
||||
|
||||
return b
|
||||
288
spyrmsd/qcp.py
Normal file
288
spyrmsd/qcp.py
Normal file
@@ -0,0 +1,288 @@
|
||||
from typing import Tuple
|
||||
|
||||
import numpy as np
|
||||
from scipy import optimize
|
||||
|
||||
from .due import Doi, due
|
||||
|
||||
due.cite(
|
||||
Doi("10.1107/S0108767305015266"),
|
||||
path="spyrmsd.qcp",
|
||||
description="QCP method",
|
||||
)
|
||||
|
||||
|
||||
def M_mtx(A: np.ndarray, B: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Compute inner product between coordinate matrices.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: numpy.ndarray
|
||||
Coordinates `A`
|
||||
B: numpy.ndarray
|
||||
Coordinates `B`
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.ndarray
|
||||
Inner product of the coordinate matrices `A` and `B`
|
||||
|
||||
Notes
|
||||
-----
|
||||
The inner product of the coordinate matrices `A` and `B` corresponds to the matrix
|
||||
:math:`\\mathbf{M}`. [1]_
|
||||
|
||||
If :math:`S_{xy}` is defined as
|
||||
|
||||
.. math:: S_{xy} = \\sum_i^N x_{B,i} y_{A,i}
|
||||
|
||||
then :math:`\\mathbf{M}` is the :math:`3\\times 3` matrix given by
|
||||
|
||||
.. math::
|
||||
\\begin{pmatrix}
|
||||
S_{xx} & S_{xy} & S_{xz} \\\\
|
||||
S_{yx} & S_{yy} & S_{yz} \\\\
|
||||
S_{zx} & S_{zy} & S_{zz} \\\\
|
||||
\\end{pmatrix}
|
||||
|
||||
.. [1] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
|
||||
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
|
||||
"""
|
||||
return B.T @ A
|
||||
|
||||
|
||||
def K_mtx(M):
|
||||
"""
|
||||
Compute symmetric key matrix.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
M : numpy.ndarray
|
||||
Inner product between coordinate matrices
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.ndarray
|
||||
Symmetric key matrix
|
||||
|
||||
Notes
|
||||
-----
|
||||
The symmetric key matrix corresponds to the matrix :math:`\\mathbf{K}`. [2]_
|
||||
|
||||
If :math:`S_{xy}` is defined as
|
||||
|
||||
.. math:: S_{xy} = \\sum_i^N x_{B,i} y_{A,i}
|
||||
|
||||
then :math:`\\mathbf{K}` is the :math:`4\\times 4` symmetric matrix given by
|
||||
|
||||
.. math::
|
||||
\\begin{pmatrix}
|
||||
S_{xx} + S_{yy} + S_{zz} & S_{yz} - S_{zy} & S_{zx} - S_{xz} & S_{xy} - S_{yx} \\\\
|
||||
& S_{xx} - S_{yy} - S_{zz} & S_{xy} + S_{yx} & S_{zx} + S_{xz}\\\\
|
||||
& & -S_{xx} + S_{yy} - S_{zz} & S_{yz} - S_{zy} \\\\
|
||||
& & & -S_{xx} - S_{yy} + S_{zz} \\\\
|
||||
\\end{pmatrix}
|
||||
|
||||
.. [2] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
|
||||
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
|
||||
"""
|
||||
|
||||
assert M.shape == (3, 3)
|
||||
|
||||
S_xx = M[0, 0]
|
||||
S_xy = M[0, 1]
|
||||
S_xz = M[0, 2]
|
||||
S_yx = M[1, 0]
|
||||
S_yy = M[1, 1]
|
||||
S_yz = M[1, 2]
|
||||
S_zx = M[2, 0]
|
||||
S_zy = M[2, 1]
|
||||
S_zz = M[2, 2]
|
||||
|
||||
# p = plus, m = minus
|
||||
S_xx_yy_zz_ppp = S_xx + S_yy + S_zz
|
||||
S_yz_zy_pm = S_yz - S_zy
|
||||
S_zx_xz_pm = S_zx - S_xz
|
||||
S_xy_yx_pm = S_xy - S_yx
|
||||
S_xx_yy_zz_pmm = S_xx - S_yy - S_zz
|
||||
S_xy_yx_pp = S_xy + S_yx
|
||||
S_zx_xz_pp = S_zx + S_xz
|
||||
S_xx_yy_zz_mpm = -S_xx + S_yy - S_zz
|
||||
S_yz_zy_pp = S_yz + S_zy
|
||||
S_xx_yy_zz_mmp = -S_xx - S_yy + S_zz
|
||||
|
||||
return np.array(
|
||||
[
|
||||
[S_xx_yy_zz_ppp, S_yz_zy_pm, S_zx_xz_pm, S_xy_yx_pm],
|
||||
[S_yz_zy_pm, S_xx_yy_zz_pmm, S_xy_yx_pp, S_zx_xz_pp],
|
||||
[S_zx_xz_pm, S_xy_yx_pp, S_xx_yy_zz_mpm, S_yz_zy_pp],
|
||||
[S_xy_yx_pm, S_zx_xz_pp, S_yz_zy_pp, S_xx_yy_zz_mmp],
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def coefficients(M: np.ndarray, K: np.ndarray) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Compute quaternion polynomial coefficients.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
M : numpy.ndarray
|
||||
Inner product between coordinate matrices
|
||||
K: numpy.ndarray
|
||||
Symmetric key matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float, float]
|
||||
Quaternion polynomial coefficients
|
||||
|
||||
Notes
|
||||
_____
|
||||
Returns only :math:`\\mathbf{M}`- and :math:`\\mathbf{K}`-dependent coefficients
|
||||
are returned. :math:`c_4=1` and :math:`c_3=0` are not returned.
|
||||
|
||||
The :math:`\\mathbf{M}`- and :math:`\\mathbf{K}`-dependent quaternion polynomial
|
||||
coefficients are given by
|
||||
|
||||
.. math:: c_2 = -2 \\text{ tr}\\left(\\mathbf{M}^T\\mathbf{M}\\right)
|
||||
|
||||
.. math:: c_1 = -8 \\text{ det}(\\mathbf{M})
|
||||
|
||||
.. math:: c_0 = \\text{ det}(\\mathbf{K})
|
||||
|
||||
"""
|
||||
|
||||
c2 = -2 * np.trace(M.T @ M)
|
||||
c1 = -8 * np.linalg.det(M) # TODO: Slow?
|
||||
c0 = np.linalg.det(K) # TODO: Slow?
|
||||
|
||||
return c2, c1, c0
|
||||
|
||||
|
||||
def lambda_max(Ga: float, Gb: float, c2: float, c1: float, c0: float) -> float:
|
||||
"""
|
||||
Find largest root of the quaternion polynomial.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
Ga: float
|
||||
Inner product of structure A
|
||||
Gb:
|
||||
Inner product of structure B
|
||||
c2:
|
||||
Coefficient :math:`c_2` of the quaternion polynomial
|
||||
c1:
|
||||
Coefficient :math:`c_1` of the quaternion polynomial
|
||||
c0:
|
||||
Coefficient :math:`c_0` of the quaternion polynomial
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Largest root of the quaternion polynomial (:math:`\\lambda_\\text{max}`)
|
||||
"""
|
||||
|
||||
def P(x):
|
||||
"""
|
||||
Quaternion polynomial
|
||||
"""
|
||||
return x**4 + c2 * x**2 + c1 * x + c0
|
||||
|
||||
def dP(x):
|
||||
"""
|
||||
Fist derivative of the quaternion polynomial
|
||||
"""
|
||||
return 4 * x**3 + 2 * c2 * x + c1
|
||||
|
||||
x0 = (Ga + Gb) * 0.5
|
||||
|
||||
lmax = optimize.newton(P, x0, fprime=dP)
|
||||
|
||||
return lmax
|
||||
|
||||
|
||||
def _lambda_max_eig(K: np.ndarray) -> float:
|
||||
"""
|
||||
Find largest eigenvalue of :math:`K`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
K: np.ndarray
|
||||
Symmetric key matrix
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Largest eigenvalue of :math:`K`, :math:`\\lambda_\\text{max}`
|
||||
"""
|
||||
e, _ = np.linalg.eig(K)
|
||||
|
||||
return max(e)
|
||||
|
||||
|
||||
def qcp_rmsd(A: np.ndarray, B: np.ndarray, atol: float = 1e-9) -> float:
|
||||
"""
|
||||
Compute RMSD using the quaternion polynomial method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
A: numpy.ndarray
|
||||
Coordinates of structure A
|
||||
B: numpy.ndarray
|
||||
Coordinates of structure B
|
||||
atol: float
|
||||
Absolute tolerance parameter (see notes)
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
RMSD between structures `A` and `B`
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If the shape of structures `A` and `B` is different
|
||||
|
||||
Notes
|
||||
-----
|
||||
If the structures `A` and `B` can be superimposed exactly (i.e. they differ only
|
||||
by center-of-mass translations and rotations), we have
|
||||
|
||||
.. math:: G_a + G_b = 2 \\lambda_\\text{max}
|
||||
|
||||
This means that :math:`s = G_a + G_bb - 2 * \\lambda_\\text{max}` can become
|
||||
negative because of numerical errors and therefore :math:`\\sqrt{s}` fails.
|
||||
In order to avoid this problem, the final RMSD is set to :math:`0`
|
||||
if :math:`|s| < atol`.
|
||||
"""
|
||||
|
||||
assert A.shape == B.shape
|
||||
|
||||
N = A.shape[0]
|
||||
|
||||
Ga = np.trace(A.T @ A)
|
||||
Gb = np.trace(B.T @ B)
|
||||
|
||||
M = M_mtx(A, B)
|
||||
K = K_mtx(M)
|
||||
|
||||
c2, c1, c0 = coefficients(M, K)
|
||||
|
||||
try:
|
||||
# Fast calculation of the largest eigenvalue of K as root of the characteristic
|
||||
# polynomial.
|
||||
l_max = lambda_max(Ga, Gb, c2, c1, c0)
|
||||
except RuntimeError: # Newton method fails to converge; see GitHub Issue #35
|
||||
# Fallback to (slower) explicit calculation of the largest eigenvalue of K
|
||||
l_max = _lambda_max_eig(K)
|
||||
|
||||
s = Ga + Gb - 2 * l_max
|
||||
|
||||
if abs(s) < atol: # Avoid numerical errors when Ga + Gb = 2 * l_max
|
||||
rmsd = 0.0
|
||||
else:
|
||||
rmsd = np.sqrt(s / N)
|
||||
|
||||
return rmsd
|
||||
382
spyrmsd/rmsd.py
Normal file
382
spyrmsd/rmsd.py
Normal file
@@ -0,0 +1,382 @@
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from spyrmsd import graph, hungarian, molecule, qcp, utils
|
||||
|
||||
|
||||
def rmsd(
|
||||
coords1: np.ndarray,
|
||||
coords2: np.ndarray,
|
||||
atomicn1: np.ndarray,
|
||||
atomicn2: np.ndarray,
|
||||
center: bool = False,
|
||||
minimize: bool = False,
|
||||
atol: float = 1e-9,
|
||||
) -> float:
|
||||
"""
|
||||
Compute RMSD
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coords1: np.ndarray
|
||||
Coordinate of molecule 1
|
||||
coords2: np.ndarray
|
||||
Coordinates of molecule 2
|
||||
atomicn1: np.ndarray
|
||||
Atomic numbers for molecule 1
|
||||
atomicn2: np.ndarray
|
||||
Atomic numbers for molecule 2
|
||||
center: bool
|
||||
Center molecules at origin
|
||||
minimize: bool
|
||||
Compute minimum RMSD (with QCP method)
|
||||
atol: float
|
||||
Absolute tolerance parameter for QCP method (see :func:`qcp_rmsd`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
RMSD
|
||||
|
||||
Notes
|
||||
-----
|
||||
When `minimize=True`, the QCP method is used. [1]_ The molecules are
|
||||
centred at the origin according to the center of geometry and superimposed
|
||||
in order to minimize the RMSD.
|
||||
|
||||
.. [1] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
|
||||
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
|
||||
"""
|
||||
|
||||
assert np.all(atomicn1 == atomicn2)
|
||||
assert coords1.shape == coords2.shape
|
||||
|
||||
# Center coordinates if required
|
||||
c1 = utils.center(coords1) if center or minimize else coords1
|
||||
c2 = utils.center(coords2) if center or minimize else coords2
|
||||
|
||||
if minimize:
|
||||
rmsd = qcp.qcp_rmsd(c1, c2, atol)
|
||||
else:
|
||||
n = coords1.shape[0]
|
||||
|
||||
rmsd = np.sqrt(np.sum((c1 - c2) ** 2) / n)
|
||||
|
||||
return rmsd
|
||||
|
||||
|
||||
def hrmsd(
|
||||
coords1: np.ndarray,
|
||||
coords2: np.ndarray,
|
||||
atomicn1: np.ndarray,
|
||||
atomicn2: np.ndarray,
|
||||
center=False,
|
||||
):
|
||||
"""
|
||||
Compute minimum RMSD using the Hungarian method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coords1: np.ndarray
|
||||
Coordinate of molecule 1
|
||||
coords2: np.ndarray
|
||||
Coordinates of molecule 2
|
||||
atomicn1: np.ndarray
|
||||
Atomic numbers for molecule 1
|
||||
atomicn2: np.ndarray
|
||||
Atomic numbers for molecule 2
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Minimum RMSD (after assignment)
|
||||
|
||||
Notes
|
||||
-----
|
||||
The Hungarian algorithm is used to solve the linear assignment problem, which is
|
||||
a minimum weight matching of the molecular graphs (bipartite). [2]_
|
||||
|
||||
The linear assignment problem is solved for every element separately.
|
||||
|
||||
.. [2] W. J. Allen and R. C. Rizzo, *Implementation of the Hungarian Algorithm to
|
||||
Account for Ligand Symmetry and Similarity in Structure-Based Design*,
|
||||
J. Chem. Inf. Model. **54**, 518-529 (2014)
|
||||
"""
|
||||
|
||||
assert atomicn1.shape == atomicn2.shape
|
||||
assert coords1.shape == coords2.shape
|
||||
|
||||
# Center coordinates if required
|
||||
c1 = utils.center(coords1) if center else coords1
|
||||
c2 = utils.center(coords2) if center else coords2
|
||||
|
||||
return hungarian.hungarian_rmsd(c1, c2, atomicn1, atomicn2)
|
||||
|
||||
|
||||
def _rmsd_isomorphic_core(
|
||||
coords1: np.ndarray,
|
||||
coords2: np.ndarray,
|
||||
aprops1: np.ndarray,
|
||||
aprops2: np.ndarray,
|
||||
am1: np.ndarray,
|
||||
am2: np.ndarray,
|
||||
center: bool = False,
|
||||
minimize: bool = False,
|
||||
isomorphisms: Optional[List[Tuple[List[int], List[int]]]] = None,
|
||||
atol: float = 1e-9,
|
||||
) -> Tuple[float, List[Tuple[List[int], List[int]]], Tuple[List[int], List[int]]]:
|
||||
"""
|
||||
Compute RMSD using graph isomorphism.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coords1: np.ndarray
|
||||
Coordinate of molecule 1
|
||||
coords2: np.ndarray
|
||||
Coordinates of molecule 2
|
||||
aprops1: np.ndarray
|
||||
Atomic properties for molecule 1
|
||||
aprops2: np.ndarray
|
||||
Atomic properties for molecule 2
|
||||
am1: np.ndarray
|
||||
Adjacency matrix for molecule 1
|
||||
am2: np.ndarray
|
||||
Adjacency matrix for molecule 2
|
||||
center: bool
|
||||
Centering flag
|
||||
minimize: bool
|
||||
Compute minized RMSD
|
||||
isomorphisms: Optional[List[Dict[int,int]]]
|
||||
Previously computed graph isomorphism
|
||||
atol: float
|
||||
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, List[Dict[int, int]]]
|
||||
RMSD (after graph matching) and graph isomorphisms
|
||||
"""
|
||||
|
||||
assert coords1.shape == coords2.shape
|
||||
|
||||
n = coords1.shape[0]
|
||||
|
||||
# Center coordinates if required
|
||||
c1 = utils.center(coords1) if center or minimize else coords1
|
||||
c2 = utils.center(coords2) if center or minimize else coords2
|
||||
|
||||
# No cached isomorphisms
|
||||
if isomorphisms is None:
|
||||
# Convert molecules to graphs
|
||||
G1 = graph.graph_from_adjacency_matrix(am1, aprops1)
|
||||
G2 = graph.graph_from_adjacency_matrix(am2, aprops2)
|
||||
|
||||
# Get all the possible graph isomorphisms
|
||||
isomorphisms = graph.match_graphs(G1, G2)
|
||||
|
||||
# Minimum result
|
||||
# Squared displacement (not minimize) or RMSD (minimize)
|
||||
min_result = np.inf
|
||||
min_isomorphisms = None
|
||||
|
||||
# Loop over all graph isomorphisms to find the lowest RMSD
|
||||
for idx1, idx2 in isomorphisms:
|
||||
# Use the isomorphism to shuffle coordinates around (from original order)
|
||||
c1i = c1[idx1, :]
|
||||
c2i = c2[idx2, :]
|
||||
|
||||
if not minimize:
|
||||
# Compute square displacement
|
||||
# Avoid dividing by n and an expensive sqrt() operation
|
||||
result = np.sum((c1i - c2i) ** 2)
|
||||
else:
|
||||
# Compute minimized RMSD using QCP
|
||||
result = qcp.qcp_rmsd(c1i, c2i, atol)
|
||||
|
||||
if result < min_result:
|
||||
min_result = result
|
||||
min_isomorphisms = (idx1, idx2)
|
||||
|
||||
if not minimize:
|
||||
# Compute actual RMSD from square displacement
|
||||
min_result = np.sqrt(min_result / n)
|
||||
|
||||
# Return the actual RMSD
|
||||
return min_result, isomorphisms, min_isomorphisms
|
||||
|
||||
|
||||
def symmrmsd(
|
||||
coordsref: np.ndarray,
|
||||
coords: Union[np.ndarray, List[np.ndarray]],
|
||||
apropsref: np.ndarray,
|
||||
aprops: np.ndarray,
|
||||
amref: np.ndarray,
|
||||
am: np.ndarray,
|
||||
center: bool = False,
|
||||
minimize: bool = False,
|
||||
cache: bool = True,
|
||||
atol: float = 1e-9,
|
||||
return_permutation: bool = False,
|
||||
) -> Any:
|
||||
"""
|
||||
Compute RMSD using graph isomorphism for multiple coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coordsref: np.ndarray
|
||||
Coordinate of reference molecule
|
||||
coords: List[np.ndarray]
|
||||
Coordinates of other molecule
|
||||
apropsref: np.ndarray
|
||||
Atomic properties for reference
|
||||
aprops: np.ndarray
|
||||
Atomic properties for other molecule
|
||||
amref: np.ndarray
|
||||
Adjacency matrix for reference molecule
|
||||
am: np.ndarray
|
||||
Adjacency matrix for other molecule
|
||||
center: bool
|
||||
Centering flag
|
||||
minimize: bool
|
||||
Minimum RMSD
|
||||
cache: bool
|
||||
Cache graph isomorphisms
|
||||
atol: float
|
||||
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
float: Union[float, List[float]]
|
||||
Symmetry-corrected RMSD(s) and graph isomorphisms
|
||||
|
||||
Notes
|
||||
-----
|
||||
|
||||
Graph isomorphism is introduced for symmetry corrections. However, it is also
|
||||
useful when two molecules do not have the atoms in the same order since atom
|
||||
matching according to atomic numbers and the molecular connectivity is
|
||||
performed. If atoms are in the same order and there is no symmetry, use the
|
||||
`rmsd` function.
|
||||
"""
|
||||
|
||||
if isinstance(coords, list): # Multiple RMSD calculations
|
||||
RMSD: Any = []
|
||||
isomorphism = None
|
||||
min_iso = []
|
||||
|
||||
for c in coords:
|
||||
if not cache:
|
||||
# Reset isomorphism
|
||||
isomorphism = None
|
||||
|
||||
srmsd, isomorphism, min_i = _rmsd_isomorphic_core(
|
||||
coordsref,
|
||||
c,
|
||||
apropsref,
|
||||
aprops,
|
||||
amref,
|
||||
am,
|
||||
center=center,
|
||||
minimize=minimize,
|
||||
isomorphisms=isomorphism,
|
||||
atol=atol,
|
||||
)
|
||||
min_iso.append(min_i)
|
||||
RMSD.append(srmsd)
|
||||
|
||||
else: # Single RMSD calculation
|
||||
RMSD, isomorphism, min_iso = _rmsd_isomorphic_core(
|
||||
coordsref,
|
||||
coords,
|
||||
apropsref,
|
||||
aprops,
|
||||
amref,
|
||||
am,
|
||||
center=center,
|
||||
minimize=minimize,
|
||||
isomorphisms=None,
|
||||
atol=atol,
|
||||
)
|
||||
|
||||
if return_permutation:
|
||||
return RMSD, min_iso
|
||||
return RMSD
|
||||
|
||||
|
||||
def rmsdwrapper(
|
||||
molref: molecule.Molecule,
|
||||
mols: Union[molecule.Molecule, List[molecule.Molecule]],
|
||||
symmetry: bool = True,
|
||||
center: bool = False,
|
||||
minimize: bool = False,
|
||||
strip: bool = True,
|
||||
cache: bool = True,
|
||||
) -> Any:
|
||||
"""
|
||||
Compute RMSD between two molecule.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
molref: molecule.Molecule
|
||||
Reference molecule
|
||||
mols: Union[molecule.Molecule, List[molecule.Molecule]]
|
||||
Molecules to compare to reference molecule
|
||||
symmetry: bool, optional
|
||||
Symmetry-corrected RMSD (using graph isomorphism)
|
||||
center: bool, optional
|
||||
Center molecules at origin
|
||||
minimize: bool, optional
|
||||
Minimised RMSD (using the quaternion polynomial method)
|
||||
strip: bool, optional
|
||||
Strip hydrogen atoms
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[float]
|
||||
RMSDs
|
||||
"""
|
||||
|
||||
if not isinstance(mols, list):
|
||||
mols = [mols]
|
||||
|
||||
if strip:
|
||||
molref.strip()
|
||||
|
||||
for mol in mols:
|
||||
mol.strip()
|
||||
|
||||
if minimize:
|
||||
center = True
|
||||
|
||||
cref = molecule.coords_from_molecule(molref, center)
|
||||
cmols = [molecule.coords_from_molecule(mol, center) for mol in mols]
|
||||
|
||||
RMSDlist = []
|
||||
|
||||
if symmetry:
|
||||
RMSDlist = symmrmsd(
|
||||
cref,
|
||||
cmols,
|
||||
molref.atomicnums,
|
||||
mols[0].atomicnums,
|
||||
molref.adjacency_matrix,
|
||||
mols[0].adjacency_matrix,
|
||||
center=center,
|
||||
minimize=minimize,
|
||||
cache=cache,
|
||||
)
|
||||
else: # No symmetry
|
||||
for c in cmols:
|
||||
RMSDlist.append(
|
||||
rmsd(
|
||||
cref,
|
||||
c,
|
||||
molref.atomicnums,
|
||||
mols[0].atomicnums,
|
||||
center=center,
|
||||
minimize=minimize,
|
||||
)
|
||||
)
|
||||
|
||||
return RMSDlist
|
||||
176
spyrmsd/utils.py
Normal file
176
spyrmsd/utils.py
Normal file
@@ -0,0 +1,176 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def format(fname: str) -> str:
|
||||
"""
|
||||
Extract format extension from file name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File extension
|
||||
|
||||
Notes
|
||||
-----
|
||||
The file extension is returned without the `.` character, i.e. for the file
|
||||
`path/filename.ext` the string `ext` is returned.
|
||||
|
||||
If a file is compressed, the `.gz` extension is ignored.
|
||||
"""
|
||||
name, ext = os.path.splitext(fname)
|
||||
|
||||
if ext == ".gz":
|
||||
_, ext = os.path.splitext(name)
|
||||
|
||||
return ext[1:] # Remove "."
|
||||
|
||||
|
||||
def molformat(fname: str) -> str:
|
||||
"""
|
||||
Extract an OpenBabel-friendly format from file name.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
fname : str
|
||||
File name
|
||||
|
||||
Returns
|
||||
-------
|
||||
str
|
||||
File extension in an OpenBabel-friendly format
|
||||
|
||||
Notes
|
||||
-----
|
||||
File types in OpenBabel do not always correspond to the file extension. This
|
||||
function converts the file extension to an OpenBabel file type.
|
||||
|
||||
The following table shows the different conversions performed by this function:
|
||||
|
||||
========= =========
|
||||
Extension File Type
|
||||
--------- ---------
|
||||
xyz XYZ
|
||||
========= =========
|
||||
"""
|
||||
|
||||
ext = format(fname)
|
||||
|
||||
if ext == "xyz":
|
||||
# xyz files in OpenBabel are called XYZ
|
||||
ext = "XYZ"
|
||||
|
||||
return ext
|
||||
|
||||
|
||||
def deg_to_rad(angle: float) -> float:
|
||||
"""
|
||||
Convert angle in degrees to angle in radians.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
angle : float
|
||||
Angle (in degrees)
|
||||
|
||||
Returns
|
||||
-------
|
||||
float
|
||||
Angle (in radians)
|
||||
"""
|
||||
|
||||
return angle * np.pi / 180.0
|
||||
|
||||
|
||||
def rotate(
|
||||
v: np.ndarray, angle: float, axis: np.ndarray, units: str = "rad"
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Rotate vector.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
v: numpy.array
|
||||
3D vector to be rotated
|
||||
angle : float
|
||||
Angle of rotation (in `units`)
|
||||
axis : numpy.array
|
||||
3D axis of rotation
|
||||
units: {"rad", "deg"}
|
||||
Units of `angle` (in radians `rad` or degrees `deg`)
|
||||
|
||||
Returns
|
||||
-------
|
||||
numpy.array
|
||||
Rotated vector
|
||||
|
||||
Raises
|
||||
------
|
||||
AssertionError
|
||||
If the axis of rotation is not a 3D vector
|
||||
ValueError
|
||||
If `units` is not `rad` or `deg`
|
||||
"""
|
||||
|
||||
assert len(axis) == 3
|
||||
|
||||
# Ensure rotation axis is normalised
|
||||
axis = axis / np.linalg.norm(axis)
|
||||
|
||||
if units.lower() == "rad":
|
||||
pass
|
||||
elif units.lower() == "deg":
|
||||
angle = deg_to_rad(angle)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Units {units} for angle is not supported. Use 'deg' or 'rad' instead."
|
||||
)
|
||||
|
||||
t1 = np.outer(axis, np.inner(axis, v)).T
|
||||
t2 = np.cos(angle) * np.cross(np.cross(axis, v), axis)
|
||||
t3 = np.sin(angle) * np.cross(axis, v)
|
||||
|
||||
return t1 + t2 + t3
|
||||
|
||||
|
||||
def center_of_geometry(coordinates: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Center of geometry.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coordinates: np.ndarray
|
||||
Coordinates
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Center of geometry
|
||||
"""
|
||||
|
||||
assert coordinates.shape[1] == 3
|
||||
|
||||
return np.mean(coordinates, axis=0)
|
||||
|
||||
|
||||
def center(coordinates: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Center coordinates.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
coordinates: np.ndarray
|
||||
Coordinates
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
Centred coordinates
|
||||
"""
|
||||
|
||||
return coordinates - center_of_geometry(coordinates)
|
||||
139
train.py
139
train.py
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
from functools import partial
|
||||
|
||||
import wandb
|
||||
@@ -12,46 +13,88 @@ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
|
||||
resource.setrlimit(resource.RLIMIT_NOFILE, (64000, rlimit[1]))
|
||||
|
||||
import yaml
|
||||
|
||||
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl
|
||||
from datasets.pdbbind import construct_loader
|
||||
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, t_to_sigma_individual
|
||||
from datasets.loader import construct_loader
|
||||
from utils.parsing import parse_train_args
|
||||
from utils.training import train_epoch, test_epoch, loss_function, inference_epoch
|
||||
from utils.training import train_epoch, test_epoch, loss_function, inference_epoch_fix
|
||||
from utils.utils import save_yaml_file, get_optimizer_and_scheduler, get_model, ExponentialMovingAverage
|
||||
|
||||
|
||||
def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir):
|
||||
def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir, val_dataset2):
|
||||
|
||||
loss_fn = partial(loss_function, tr_weight=args.tr_weight, rot_weight=args.rot_weight,
|
||||
tor_weight=args.tor_weight, no_torsion=args.no_torsion, backbone_weight=args.backbone_loss_weight,
|
||||
sidechain_weight=args.sidechain_loss_weight)
|
||||
|
||||
best_val_loss = math.inf
|
||||
best_val_inference_value = math.inf if args.inference_earlystop_goal == 'min' else 0
|
||||
best_val_secondary_value = math.inf if args.inference_earlystop_goal == 'min' else 0
|
||||
best_epoch = 0
|
||||
best_val_inference_epoch = 0
|
||||
loss_fn = partial(loss_function, tr_weight=args.tr_weight, rot_weight=args.rot_weight,
|
||||
tor_weight=args.tor_weight, no_torsion=args.no_torsion)
|
||||
|
||||
freeze_params = 0
|
||||
scheduler_mode = args.inference_earlystop_goal if args.val_inference_freq is not None else 'min'
|
||||
if args.scheduler == 'layer_linear_warmup':
|
||||
freeze_params = args.warmup_dur * (args.num_conv_layers + 2) - 1
|
||||
print("Freezing some parameters until epoch {}".format(freeze_params))
|
||||
|
||||
print("Starting training...")
|
||||
for epoch in range(args.n_epochs):
|
||||
if epoch % 5 == 0: print("Run name: ", args.run_name)
|
||||
logs = {}
|
||||
train_losses = train_epoch(model, train_loader, optimizer, device, t_to_sigma, loss_fn, ema_weights)
|
||||
print("Epoch {}: Training loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f}"
|
||||
.format(epoch, train_losses['loss'], train_losses['tr_loss'], train_losses['rot_loss'],
|
||||
train_losses['tor_loss']))
|
||||
|
||||
ema_weights.store(model.parameters())
|
||||
if args.use_ema: ema_weights.copy_to(model.parameters()) # load ema parameters into model for running validation and inference
|
||||
if args.scheduler == 'layer_linear_warmup' and (epoch+1) % args.warmup_dur == 0:
|
||||
step = (epoch+1) // args.warmup_dur
|
||||
if step < args.num_conv_layers + 2:
|
||||
print("New unfreezing step")
|
||||
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=step, scheduler_mode=scheduler_mode)
|
||||
elif step == args.num_conv_layers + 2:
|
||||
print("Unfreezing all parameters")
|
||||
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=step, scheduler_mode=scheduler_mode)
|
||||
ema_weights = ExponentialMovingAverage(model.parameters(), decay=args.ema_rate)
|
||||
elif args.scheduler == 'linear_warmup' and epoch == args.warmup_dur:
|
||||
print("Moving to plateu scheduler")
|
||||
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=1, scheduler_mode=scheduler_mode,
|
||||
optimizer=optimizer)
|
||||
|
||||
logs = {}
|
||||
train_losses = train_epoch(model, train_loader, optimizer, device, t_to_sigma, loss_fn, ema_weights if epoch > freeze_params else None)
|
||||
print("Epoch {}: Training loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f} sc {:.4f} lr {:.4f}"
|
||||
.format(epoch, train_losses['loss'], train_losses['tr_loss'], train_losses['rot_loss'],
|
||||
train_losses['tor_loss'], train_losses['sidechain_loss'], optimizer.param_groups[0]['lr']))
|
||||
|
||||
if epoch > freeze_params:
|
||||
ema_weights.store(model.parameters())
|
||||
if args.use_ema: ema_weights.copy_to(model.parameters()) # load ema parameters into model for running validation and inference
|
||||
val_losses = test_epoch(model, val_loader, device, t_to_sigma, loss_fn, args.test_sigma_intervals)
|
||||
print("Epoch {}: Validation loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f}"
|
||||
.format(epoch, val_losses['loss'], val_losses['tr_loss'], val_losses['rot_loss'], val_losses['tor_loss']))
|
||||
print("Epoch {}: Validation loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f} sc {:.4f}"
|
||||
.format(epoch, val_losses['loss'], val_losses['tr_loss'], val_losses['rot_loss'], val_losses['tor_loss'], val_losses['sidechain_loss']))
|
||||
|
||||
if args.val_inference_freq != None and (epoch + 1) % args.val_inference_freq == 0:
|
||||
inf_metrics = inference_epoch(model, val_loader.dataset.complex_graphs[:args.num_inference_complexes], device, t_to_sigma, args)
|
||||
print("Epoch {}: Val inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f}"
|
||||
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5']))
|
||||
inf_dataset = [val_loader.dataset.get(i) for i in range(min(args.num_inference_complexes, val_loader.dataset.__len__()))]
|
||||
inf_metrics = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
|
||||
print("Epoch {}: Val inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
|
||||
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5'], inf_metrics['min_rmsds_lt2'], inf_metrics['min_rmsds_lt5']))
|
||||
logs.update({'valinf_' + k: v for k, v in inf_metrics.items()}, step=epoch + 1)
|
||||
|
||||
if not args.use_ema: ema_weights.copy_to(model.parameters())
|
||||
ema_state_dict = copy.deepcopy(model.module.state_dict() if device.type == 'cuda' else model.state_dict())
|
||||
ema_weights.restore(model.parameters())
|
||||
if args.double_val and args.val_inference_freq != None and (epoch + 1) % args.val_inference_freq == 0:
|
||||
inf_dataset = [val_dataset2.get(i) for i in range(min(args.num_inference_complexes, val_dataset2.__len__()))]
|
||||
inf_metrics2 = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
|
||||
print("Epoch {}: Val inference on second validation rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
|
||||
.format(epoch, inf_metrics2['rmsds_lt2'], inf_metrics2['rmsds_lt5'], inf_metrics2['min_rmsds_lt2'], inf_metrics2['min_rmsds_lt5']))
|
||||
logs.update({'valinf2_' + k: v for k, v in inf_metrics2.items()}, step=epoch + 1)
|
||||
logs.update({'valinfcomb_' + k: (v + inf_metrics[k])/2 for k, v in inf_metrics2.items()}, step=epoch + 1)
|
||||
|
||||
if args.train_inference_freq != None and (epoch + 1) % args.train_inference_freq == 0:
|
||||
inf_dataset = [train_loader.dataset.get(i) for i in range(min(min(args.num_inference_complexes, 300), train_loader.dataset.__len__()))]
|
||||
inf_metrics = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
|
||||
print("Epoch {}: Train inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
|
||||
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5'], inf_metrics['min_rmsds_lt2'], inf_metrics['min_rmsds_lt5']))
|
||||
logs.update({'traininf_' + k: v for k, v in inf_metrics.items()}, step=epoch + 1)
|
||||
|
||||
if epoch > freeze_params:
|
||||
if not args.use_ema: ema_weights.copy_to(model.parameters())
|
||||
ema_state_dict = copy.deepcopy(model.module.state_dict() if device.type == 'cuda' else model.state_dict())
|
||||
ema_weights.restore(model.parameters())
|
||||
|
||||
if args.wandb:
|
||||
logs.update({'train_' + k: v for k, v in train_losses.items()})
|
||||
@@ -66,15 +109,31 @@ def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_load
|
||||
best_val_inference_value = logs[args.inference_earlystop_metric]
|
||||
best_val_inference_epoch = epoch
|
||||
torch.save(state_dict, os.path.join(run_dir, 'best_inference_epoch_model.pt'))
|
||||
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_inference_epoch_model.pt'))
|
||||
if epoch > freeze_params:
|
||||
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_inference_epoch_model.pt'))
|
||||
|
||||
if args.inference_secondary_metric is not None and args.inference_secondary_metric in logs.keys() and \
|
||||
(args.inference_earlystop_goal == 'min' and logs[args.inference_secondary_metric] <= best_val_secondary_value or
|
||||
args.inference_earlystop_goal == 'max' and logs[args.inference_secondary_metric] >= best_val_secondary_value):
|
||||
best_val_secondary_value = logs[args.inference_secondary_metric]
|
||||
if epoch > freeze_params:
|
||||
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_secondary_epoch_model.pt'))
|
||||
|
||||
if val_losses['loss'] <= best_val_loss:
|
||||
best_val_loss = val_losses['loss']
|
||||
best_epoch = epoch
|
||||
torch.save(state_dict, os.path.join(run_dir, 'best_model.pt'))
|
||||
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_model.pt'))
|
||||
if epoch > freeze_params:
|
||||
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_model.pt'))
|
||||
|
||||
if args.save_model_freq is not None and (epoch + 1) % args.save_model_freq == 0:
|
||||
shutil.copyfile(os.path.join(run_dir, 'best_model.pt'),
|
||||
os.path.join(run_dir, f'epoch{epoch+1}_best_model.pt'))
|
||||
|
||||
if scheduler:
|
||||
if args.val_inference_freq is not None:
|
||||
if epoch < freeze_params or (args.scheduler == 'linear_warmup' and epoch < args.warmup_dur):
|
||||
scheduler.step()
|
||||
elif args.val_inference_freq is not None:
|
||||
scheduler.step(best_val_inference_value)
|
||||
else:
|
||||
scheduler.step(val_losses['loss'])
|
||||
@@ -108,17 +167,26 @@ def main_function():
|
||||
if args.cudnn_benchmark:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
if args.wandb:
|
||||
wandb.init(
|
||||
entity='',
|
||||
settings=wandb.Settings(start_method="fork"),
|
||||
project=args.project,
|
||||
name=args.run_name,
|
||||
config=args
|
||||
)
|
||||
|
||||
# construct loader
|
||||
t_to_sigma = partial(t_to_sigma_compl, args=args)
|
||||
train_loader, val_loader = construct_loader(args, t_to_sigma)
|
||||
|
||||
train_loader, val_loader, val_dataset2 = construct_loader(args, t_to_sigma, device)
|
||||
|
||||
model = get_model(args, device, t_to_sigma=t_to_sigma)
|
||||
optimizer, scheduler = get_optimizer_and_scheduler(args, model, scheduler_mode=args.inference_earlystop_goal if args.val_inference_freq is not None else 'min')
|
||||
ema_weights = ExponentialMovingAverage(model.parameters(),decay=args.ema_rate)
|
||||
|
||||
if args.restart_dir:
|
||||
try:
|
||||
dict = torch.load(f'{args.restart_dir}/last_model.pt', map_location=torch.device('cpu'))
|
||||
dict = torch.load(f'{args.restart_dir}/{args.restart_ckpt}.pt', map_location=torch.device('cpu'))
|
||||
if args.restart_lr is not None: dict['optimizer']['param_groups'][0]['lr'] = args.restart_lr
|
||||
optimizer.load_state_dict(dict['optimizer'])
|
||||
model.module.load_state_dict(dict['model'], strict=True)
|
||||
@@ -130,18 +198,15 @@ def main_function():
|
||||
dict = torch.load(f'{args.restart_dir}/best_model.pt', map_location=torch.device('cpu'))
|
||||
model.module.load_state_dict(dict, strict=True)
|
||||
print("Due to exception had to take the best epoch and no optimiser")
|
||||
elif args.pretrain_dir:
|
||||
dict = torch.load(f'{args.pretrain_dir}/{args.pretrain_ckpt}.pt', map_location=torch.device('cpu'))
|
||||
model.module.load_state_dict(dict, strict=True)
|
||||
print("Using pretrained model", f'{args.pretrain_dir}/{args.pretrain_ckpt}.pt')
|
||||
|
||||
numel = sum([p.numel() for p in model.parameters()])
|
||||
print('Model with', numel, 'parameters')
|
||||
|
||||
if args.wandb:
|
||||
wandb.init(
|
||||
entity='entity',
|
||||
settings=wandb.Settings(start_method="fork"),
|
||||
project=args.project,
|
||||
name=args.run_name,
|
||||
config=args
|
||||
)
|
||||
wandb.log({'numel': numel})
|
||||
|
||||
# record parameters
|
||||
@@ -150,9 +215,9 @@ def main_function():
|
||||
save_yaml_file(yaml_file_name, args.__dict__)
|
||||
args.device = device
|
||||
|
||||
train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir)
|
||||
train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir, val_dataset2)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
|
||||
main_function()
|
||||
main_function()
|
||||
|
||||
@@ -5,8 +5,24 @@ import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from scipy.stats import beta
|
||||
|
||||
from utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch
|
||||
from utils.torsion import modify_conformer_torsion_angles
|
||||
from utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch, rigid_transform_Kabsch_3D_torch_batch
|
||||
from utils.torsion import modify_conformer_torsion_angles, modify_conformer_torsion_angles_batch
|
||||
|
||||
|
||||
def sigmoid(t):
|
||||
return 1 / (1 + np.e**(-t))
|
||||
|
||||
|
||||
def sigmoid_schedule(t, k=10, m=0.5):
|
||||
s = lambda t: sigmoid(k*(t-m))
|
||||
return (s(t)-s(0))/(s(1)-s(0))
|
||||
|
||||
|
||||
def t_to_sigma_individual(t, schedule_type, sigma_min, sigma_max, schedule_k=10, schedule_m=0.4):
|
||||
if schedule_type == "exponential":
|
||||
return sigma_min ** (1 - t) * sigma_max ** t
|
||||
elif schedule_type == 'sigmoid':
|
||||
return sigmoid_schedule(t, k=schedule_k, m=schedule_m) * (sigma_max - sigma_min) + sigma_min
|
||||
|
||||
|
||||
def t_to_sigma(t_tr, t_rot, t_tor, args):
|
||||
@@ -16,7 +32,7 @@ def t_to_sigma(t_tr, t_rot, t_tor, args):
|
||||
return tr_sigma, rot_sigma, tor_sigma
|
||||
|
||||
|
||||
def modify_conformer(data, tr_update, rot_update, torsion_updates):
|
||||
def modify_conformer(data, tr_update, rot_update, torsion_updates, pivot=None):
|
||||
lig_center = torch.mean(data['ligand'].pos, dim=0, keepdim=True)
|
||||
rot_mat = axis_angle_to_matrix(rot_update.squeeze())
|
||||
rigid_new_pos = (data['ligand'].pos - lig_center) @ rot_mat.T + tr_update + lig_center
|
||||
@@ -26,14 +42,60 @@ def modify_conformer(data, tr_update, rot_update, torsion_updates):
|
||||
data['ligand', 'ligand'].edge_index.T[data['ligand'].edge_mask],
|
||||
data['ligand'].mask_rotate if isinstance(data['ligand'].mask_rotate, np.ndarray) else data['ligand'].mask_rotate[0],
|
||||
torsion_updates).to(rigid_new_pos.device)
|
||||
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
|
||||
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
|
||||
if pivot is None:
|
||||
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
|
||||
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
|
||||
else:
|
||||
R1, t1 = rigid_transform_Kabsch_3D_torch(pivot.T, rigid_new_pos.T)
|
||||
R2, t2 = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, pivot.T)
|
||||
|
||||
aligned_flexible_pos = (flexible_new_pos @ R2.T + t2.T) @ R1.T + t1.T
|
||||
|
||||
data['ligand'].pos = aligned_flexible_pos
|
||||
else:
|
||||
data['ligand'].pos = rigid_new_pos
|
||||
return data
|
||||
|
||||
|
||||
def modify_conformer_batch(orig_pos, data, tr_update, rot_update, torsion_updates, mask_rotate):
|
||||
B = data.num_graphs
|
||||
N, M, R = data['ligand'].num_nodes // B, data['ligand', 'ligand'].num_edges // B, data['ligand'].edge_mask.sum().item() // B
|
||||
|
||||
pos, edge_index, edge_mask = orig_pos.reshape(B, N, 3) + 0, data['ligand', 'ligand'].edge_index[:, :M], data['ligand'].edge_mask[:M]
|
||||
torsion_updates = torsion_updates.reshape(B, -1) if torsion_updates is not None else None
|
||||
|
||||
lig_center = torch.mean(pos, dim=1, keepdim=True)
|
||||
rot_mat = axis_angle_to_matrix(rot_update)
|
||||
rigid_new_pos = torch.bmm(pos - lig_center, rot_mat.permute(0, 2, 1)) + tr_update.unsqueeze(1) + lig_center
|
||||
|
||||
if torsion_updates is not None:
|
||||
flexible_new_pos = modify_conformer_torsion_angles_batch(rigid_new_pos, edge_index.T[edge_mask], mask_rotate, torsion_updates)
|
||||
R, t = rigid_transform_Kabsch_3D_torch_batch(flexible_new_pos, rigid_new_pos)
|
||||
aligned_flexible_pos = torch.bmm(flexible_new_pos, R.transpose(1, 2)) + t.transpose(1, 2)
|
||||
final_pos = aligned_flexible_pos.reshape(-1, 3)
|
||||
else:
|
||||
final_pos = rigid_new_pos.reshape(-1, 3)
|
||||
return final_pos
|
||||
|
||||
|
||||
def modify_conformer_coordinates(pos, tr_update, rot_update, torsion_updates, edge_mask, mask_rotate, edge_index):
|
||||
# Made this function which does the same as modify_conformer because passing a graph would require
|
||||
# creating a new heterograph for reach graph when unbatching a batch of graphs
|
||||
lig_center = torch.mean(pos, dim=0, keepdim=True)
|
||||
rot_mat = axis_angle_to_matrix(rot_update.squeeze())
|
||||
rigid_new_pos = (pos - lig_center) @ rot_mat.T + tr_update + lig_center
|
||||
|
||||
if torsion_updates is not None:
|
||||
flexible_new_pos = modify_conformer_torsion_angles(rigid_new_pos,edge_index.T[edge_mask],mask_rotate \
|
||||
if isinstance(mask_rotate, np.ndarray) else mask_rotate[0], torsion_updates).to(rigid_new_pos.device)
|
||||
|
||||
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
|
||||
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
|
||||
return aligned_flexible_pos
|
||||
else:
|
||||
return rigid_new_pos
|
||||
|
||||
|
||||
def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000):
|
||||
""" from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py """
|
||||
assert len(timesteps.shape) == 1
|
||||
@@ -73,11 +135,15 @@ def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000)
|
||||
return emb_func
|
||||
|
||||
|
||||
def get_t_schedule(inference_steps):
|
||||
return np.linspace(1, 0, inference_steps + 1)[:-1]
|
||||
def get_t_schedule(sigma_schedule, inference_steps, inf_sched_alpha=1, inf_sched_beta=1, t_max=1):
|
||||
if sigma_schedule == 'expbeta':
|
||||
lin_max = beta.cdf(t_max, a=inf_sched_alpha, b=inf_sched_beta)
|
||||
c = np.linspace(lin_max, 0, inference_steps + 1)[:-1]
|
||||
return beta.ppf(c, a=inf_sched_alpha, b=inf_sched_beta)
|
||||
raise Exception()
|
||||
|
||||
|
||||
def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device):
|
||||
def set_time(complex_graphs, t, t_tr, t_rot, t_tor, batchsize, all_atoms, device, include_miscellaneous_atoms=False):
|
||||
complex_graphs['ligand'].node_t = {
|
||||
'tr': t_tr * torch.ones(complex_graphs['ligand'].num_nodes).to(device),
|
||||
'rot': t_rot * torch.ones(complex_graphs['ligand'].num_nodes).to(device),
|
||||
@@ -93,4 +159,10 @@ def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device):
|
||||
complex_graphs['atom'].node_t = {
|
||||
'tr': t_tr * torch.ones(complex_graphs['atom'].num_nodes).to(device),
|
||||
'rot': t_rot * torch.ones(complex_graphs['atom'].num_nodes).to(device),
|
||||
'tor': t_tor * torch.ones(complex_graphs['atom'].num_nodes).to(device)}
|
||||
'tor': t_tor * torch.ones(complex_graphs['atom'].num_nodes).to(device)}
|
||||
|
||||
if include_miscellaneous_atoms and not all_atoms:
|
||||
complex_graphs['misc_atom'].node_t = {
|
||||
'tr': t_tr * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device),
|
||||
'rot': t_rot * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device),
|
||||
'tor': t_tor * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device)}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@@ -85,6 +86,126 @@ def axis_angle_to_matrix(axis_angle):
|
||||
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Returns torch.sqrt(torch.max(0, x))
|
||||
but with a zero subgradient where x is 0.
|
||||
"""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
return ret
|
||||
|
||||
|
||||
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as rotation matrices to quaternions.
|
||||
|
||||
Args:
|
||||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||
|
||||
Returns:
|
||||
quaternions with real part first, as tensor of shape (..., 4).
|
||||
"""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
batch_dim = matrix.shape[:-2]
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||
)
|
||||
|
||||
q_abs = _sqrt_positive_part(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
# we produce the desired quaternion multiplied by each of r, i, j, k
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
|
||||
# `int`.
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
|
||||
# the candidate won't be picked.
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
|
||||
# forall i; we pick the best-conditioned one (with the largest denominator)
|
||||
|
||||
return quat_candidates[
|
||||
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
|
||||
].reshape(batch_dim + (4,))
|
||||
|
||||
|
||||
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as quaternions to axis/angle.
|
||||
|
||||
Args:
|
||||
quaternions: quaternions with real part first,
|
||||
as tensor of shape (..., 4).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
"""
|
||||
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
||||
half_angles = torch.atan2(norms, quaternions[..., :1])
|
||||
angles = 2 * half_angles
|
||||
eps = 1e-6
|
||||
small_angles = angles.abs() < eps
|
||||
sin_half_angles_over_angles = torch.empty_like(angles)
|
||||
sin_half_angles_over_angles[~small_angles] = (
|
||||
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
||||
)
|
||||
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
||||
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
||||
sin_half_angles_over_angles[small_angles] = (
|
||||
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
||||
)
|
||||
return quaternions[..., 1:] / sin_half_angles_over_angles
|
||||
|
||||
|
||||
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert rotations given as rotation matrices to axis/angle.
|
||||
|
||||
Args:
|
||||
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
||||
|
||||
Returns:
|
||||
Rotations given as a vector in axis angle form, as a tensor
|
||||
of shape (..., 3), where the magnitude is the angle
|
||||
turned anticlockwise in radians around the vector's
|
||||
direction.
|
||||
"""
|
||||
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
||||
|
||||
|
||||
def rigid_transform_Kabsch_3D_torch(A, B):
|
||||
# R = 3x3 rotation matrix, t = 3x1 column vector
|
||||
# This already takes residue identity into account.
|
||||
@@ -97,7 +218,6 @@ def rigid_transform_Kabsch_3D_torch(A, B):
|
||||
if num_rows != 3:
|
||||
raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
|
||||
|
||||
|
||||
# find mean column wise: 3 x 1
|
||||
centroid_A = torch.mean(A, axis=1, keepdims=True)
|
||||
centroid_B = torch.mean(B, axis=1, keepdims=True)
|
||||
@@ -121,3 +241,78 @@ def rigid_transform_Kabsch_3D_torch(A, B):
|
||||
|
||||
t = -R @ centroid_A + centroid_B
|
||||
return R, t
|
||||
|
||||
|
||||
def rigid_transform_Kabsch_3D_torch_batch(A, B):
|
||||
# R = Bx3x3 rotation matrix, t = Bx3x1 column vector
|
||||
|
||||
assert A.shape == B.shape
|
||||
_, N, M = A.shape
|
||||
if M != 3:
|
||||
raise Exception(f"matrix A and B should be BxNx3")
|
||||
|
||||
A, B = A.permute(0, 2, 1), B.permute(0, 2, 1)
|
||||
|
||||
# find mean column wise: 3 x 1
|
||||
centroid_A = torch.mean(A, axis=2, keepdims=True)
|
||||
centroid_B = torch.mean(B, axis=2, keepdims=True)
|
||||
|
||||
# subtract mean
|
||||
Am = A - centroid_A
|
||||
Bm = B - centroid_B
|
||||
H = torch.bmm(Am, Bm.transpose(1, 2))
|
||||
|
||||
# find rotation
|
||||
U, S, Vt = torch.linalg.svd(H)
|
||||
R = torch.bmm(Vt.transpose(1, 2), U.transpose(1, 2))
|
||||
|
||||
# reflection case
|
||||
SS = torch.diag(torch.tensor([1., 1., -1.], device=A.device))
|
||||
Rm = torch.bmm(Vt.transpose(1,2) @ SS, U.transpose(1, 2))
|
||||
R = torch.where(torch.linalg.det(R)[:, None, None] < 0, Rm, R)
|
||||
assert torch.all(torch.abs(torch.linalg.det(R) - 1) < 3e-3) # note I had to change this error bound to be higher
|
||||
|
||||
t = torch.bmm(-R, centroid_A) + centroid_B
|
||||
return R, t
|
||||
|
||||
|
||||
def rigid_transform_Kabsch_independent_torch(A, B):
|
||||
# R = 3x3 rotation matrix, t = 3x1 column vector
|
||||
# This already takes residue identity into account.
|
||||
|
||||
assert A.shape[1] == B.shape[1]
|
||||
num_rows, num_cols = A.shape
|
||||
if num_rows != 3:
|
||||
raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
|
||||
num_rows, num_cols = B.shape
|
||||
if num_rows != 3:
|
||||
raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
|
||||
|
||||
# find mean column wise: 3 x 1
|
||||
centroid_A = torch.mean(A, axis=1, keepdims=True)
|
||||
centroid_B = torch.mean(B, axis=1, keepdims=True)
|
||||
|
||||
# subtract mean
|
||||
Am = A - centroid_A
|
||||
Bm = B - centroid_B
|
||||
|
||||
H = Am @ Bm.T
|
||||
|
||||
# find rotation
|
||||
U, S, Vt = torch.linalg.svd(H)
|
||||
|
||||
R = Vt.T @ U.T
|
||||
# special reflection case
|
||||
if torch.linalg.det(R) < 0:
|
||||
# print("det(R) < R, reflection detected!, correcting for it ...")
|
||||
SS = torch.diag(torch.tensor([1.,1.,-1.], device=A.device))
|
||||
R = (Vt.T @ SS) @ U.T
|
||||
assert math.fabs(torch.linalg.det(R) - 1) < 3e-3 # note I had to change this error bound to be higher
|
||||
|
||||
t = - centroid_A + centroid_B # note does not change rotation
|
||||
R_vec = matrix_to_axis_angle(R)
|
||||
return t, R_vec
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
89
utils/gnina_utils.py
Normal file
89
utils/gnina_utils.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
import numpy as np
|
||||
from rdkit.Chem import AllChem, RemoveHs, RemoveAllHs
|
||||
|
||||
from datasets.process_mols import write_mol_with_coords, read_molecule
|
||||
import re
|
||||
|
||||
from utils.utils import remove_all_hs
|
||||
|
||||
|
||||
def read_gnina_metrics(gnina_sdf_path):
|
||||
with open(gnina_sdf_path, 'r') as f:
|
||||
pattern = re.compile(r'> <(.*?)>\n(.*?)\n')
|
||||
content = f.read()
|
||||
matches = pattern.findall(content)
|
||||
metrics = {k: v for k, v in matches}
|
||||
return metrics
|
||||
|
||||
|
||||
def read_gnina_score(gnina_sdf_path):
|
||||
with open(gnina_sdf_path, 'r') as f:
|
||||
pattern = re.compile(r'> <CNNscore>\n(.*?)\n')
|
||||
content = f.read()
|
||||
matches = pattern.findall(content)
|
||||
return float(matches[0])
|
||||
|
||||
|
||||
def invert_permutation(p):
|
||||
"""Return an array s with which np.array_equal(arr[p][s], arr) is True.
|
||||
The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
|
||||
"""
|
||||
p = np.asanyarray(p) # in case p is a tuple, etc.
|
||||
s = np.empty_like(p)
|
||||
s[p] = np.arange(p.size)
|
||||
return s
|
||||
|
||||
|
||||
def get_gnina_poses(args, mol, pos, orig_center, name, folder, gnina_path, thread_id=0):
|
||||
#folder = "data/MOAD_new_test_processed" if args.split == 'test' else "data/MOAD_new_val_processed"
|
||||
out_dir = args.out_dir if hasattr(args, 'out_dir') else args.inference_out_dir
|
||||
rec_path = os.path.join(folder, name[:6] + '_protein_chain_removed.pdb')
|
||||
pred_lig_path = os.path.join(out_dir, f'pred_{name}_tid{thread_id}_lig.sdf')
|
||||
|
||||
if not os.path.exists(os.path.dirname(pred_lig_path)):
|
||||
os.mkdir(os.path.dirname(pred_lig_path))
|
||||
|
||||
print(f'Ligand path {pred_lig_path}')
|
||||
write_mol_with_coords(mol, pos + orig_center, pred_lig_path)
|
||||
gnina_pred_path = os.path.join(out_dir, f'gnina_{name}_tid{thread_id}_lig.sdf')
|
||||
|
||||
gnina_logs_dir = os.path.join(out_dir, "gnina_logs")
|
||||
|
||||
with open(os.path.join(gnina_logs_dir, f'{name}'), "w+") as f:
|
||||
if args.gnina_full_dock:
|
||||
return_code = subprocess.run(
|
||||
f'{gnina_path} -r {rec_path} -l "{pred_lig_path}" --autobox_ligand "{pred_lig_path}" -o "{gnina_pred_path}" --no_gpu --autobox_add {args.gnina_autobox_add}',
|
||||
shell=True, stdout=f, stderr=f)
|
||||
else:
|
||||
return_code = subprocess.run(
|
||||
f'{gnina_path} --receptor {rec_path} --ligand "{pred_lig_path}" --minimize -o "{gnina_pred_path}"',
|
||||
shell=True, stdout=f, stderr=f)
|
||||
|
||||
# print(f'gnina return code: {return_code}')
|
||||
try:
|
||||
gnina_mol = RemoveAllHs(read_molecule(gnina_pred_path, remove_hs=True, sanitize=True))
|
||||
gnina_minimized_ligand_pos = np.array(gnina_mol.GetConformer(0).GetPositions())
|
||||
gnina_atoms = np.array([atom.GetSymbol() for atom in gnina_mol.GetAtoms()])
|
||||
gnina_filter_Hs = np.where(gnina_atoms != 'H')
|
||||
gnina_ligand_pos = gnina_minimized_ligand_pos[gnina_filter_Hs] - orig_center
|
||||
|
||||
try:
|
||||
gnina_score = read_gnina_score(gnina_pred_path)
|
||||
if gnina_score is None:
|
||||
gnina_score = 0
|
||||
except Exception as e:
|
||||
print(f'Error reading gnina score: {e}')
|
||||
gnina_score = 0
|
||||
|
||||
except Exception as e:
|
||||
print(f'Error when running gnina with {name} to minimize energy')
|
||||
print('Error:', e)
|
||||
print('Using score model output pos instead.')
|
||||
gnina_ligand_pos = pos
|
||||
gnina_mol = RemoveAllHs(mol)
|
||||
gnina_score = 0
|
||||
|
||||
return gnina_ligand_pos, gnina_mol, gnina_score
|
||||
@@ -1,4 +1,6 @@
|
||||
import copy
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
from Bio.PDB import PDBParser
|
||||
@@ -7,38 +9,10 @@ from rdkit.Chem import AddHs, MolFromSmiles
|
||||
from torch_geometric.data import Dataset, HeteroData
|
||||
import esm
|
||||
|
||||
from datasets.process_mols import parse_pdb_from_path, generate_conformer, read_molecule, get_lig_graph_with_matching, \
|
||||
extract_receptor_structure, get_rec_graph
|
||||
from datasets.constants import three_to_one
|
||||
from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
|
||||
|
||||
|
||||
three_to_one = {'ALA': 'A',
|
||||
'ARG': 'R',
|
||||
'ASN': 'N',
|
||||
'ASP': 'D',
|
||||
'CYS': 'C',
|
||||
'GLN': 'Q',
|
||||
'GLU': 'E',
|
||||
'GLY': 'G',
|
||||
'HIS': 'H',
|
||||
'ILE': 'I',
|
||||
'LEU': 'L',
|
||||
'LYS': 'K',
|
||||
'MET': 'M',
|
||||
'MSE': 'M', # MSE this is almost the same AA as MET. The sulfur is just replaced by Selen
|
||||
'PHE': 'F',
|
||||
'PRO': 'P',
|
||||
'PYL': 'O',
|
||||
'SER': 'S',
|
||||
'SEC': 'U',
|
||||
'THR': 'T',
|
||||
'TRP': 'W',
|
||||
'TYR': 'Y',
|
||||
'VAL': 'V',
|
||||
'ASX': 'B',
|
||||
'GLX': 'Z',
|
||||
'XAA': 'X',
|
||||
'XLE': 'J'}
|
||||
|
||||
def get_sequences_from_pdbfile(file_path):
|
||||
biopython_parser = PDBParser()
|
||||
structure = biopython_parser.get_structure('random_id', file_path)
|
||||
@@ -153,7 +127,7 @@ def generate_ESM_structure(model, filename, sequence):
|
||||
class InferenceDataset(Dataset):
|
||||
def __init__(self, out_dir, complex_names, protein_files, ligand_descriptions, protein_sequences, lm_embeddings,
|
||||
receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
|
||||
remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None):
|
||||
remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None, knn_only_graph=False):
|
||||
|
||||
super(InferenceDataset, self).__init__()
|
||||
self.receptor_radius = receptor_radius
|
||||
@@ -161,6 +135,7 @@ class InferenceDataset(Dataset):
|
||||
self.remove_hs = remove_hs
|
||||
self.all_atoms = all_atoms
|
||||
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
|
||||
self.knn_only_graph = knn_only_graph
|
||||
|
||||
self.complex_names = complex_names
|
||||
self.protein_files = protein_files
|
||||
@@ -242,18 +217,19 @@ class InferenceDataset(Dataset):
|
||||
|
||||
try:
|
||||
# parse the receptor from the pdb file
|
||||
rec_model = parse_pdb_from_path(protein_file)
|
||||
get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False, keep_original=False,
|
||||
num_conformers=1, remove_hs=self.remove_hs)
|
||||
rec, rec_coords, c_alpha_coords, n_coords, c_coords, lm_embeddings = extract_receptor_structure(rec_model, mol, lm_embedding_chains=lm_embedding)
|
||||
if lm_embeddings is not None and len(c_alpha_coords) != len(lm_embeddings):
|
||||
print(f'LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}.')
|
||||
complex_graph['success'] = False
|
||||
return complex_graph
|
||||
|
||||
get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius=self.receptor_radius,
|
||||
c_alpha_max_neighbors=self.c_alpha_max_neighbors, all_atoms=self.all_atoms,
|
||||
atom_radius=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors, remove_hs=self.remove_hs, lm_embeddings=lm_embeddings)
|
||||
moad_extract_receptor_structure(
|
||||
path=os.path.join(protein_file),
|
||||
complex_graph=complex_graph,
|
||||
neighbor_cutoff=self.receptor_radius,
|
||||
max_neighbors=self.c_alpha_max_neighbors,
|
||||
lm_embeddings=lm_embedding,
|
||||
knn_only_graph=self.knn_only_graph,
|
||||
all_atoms=self.all_atoms,
|
||||
atom_cutoff=self.atom_radius,
|
||||
atom_max_neighbors=self.atom_max_neighbors)
|
||||
|
||||
except Exception as e:
|
||||
print(f'Skipping {name} because of the error:')
|
||||
|
||||
39
utils/molecules_utils.py
Normal file
39
utils/molecules_utils.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from spyrmsd import rmsd, molecule
|
||||
|
||||
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None, return_permutation=False):
|
||||
with time_limit(10):
|
||||
mol = molecule.Molecule.from_rdkit(mol)
|
||||
mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2
|
||||
mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums
|
||||
mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix
|
||||
RMSD = rmsd.symmrmsd(
|
||||
coords1,
|
||||
coords2,
|
||||
mol.atomicnums,
|
||||
mol2_atomicnums,
|
||||
mol.adjacency_matrix,
|
||||
mol2_adjacency_matrix,
|
||||
return_permutation=return_permutation
|
||||
)
|
||||
return RMSD
|
||||
|
||||
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class TimeoutException(Exception): pass
|
||||
|
||||
|
||||
@contextmanager
|
||||
def time_limit(seconds):
|
||||
def signal_handler(signum, frame):
|
||||
raise TimeoutException("Timed out!")
|
||||
|
||||
signal.signal(signal.SIGALRM, signal_handler)
|
||||
signal.alarm(seconds)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
signal.alarm(0)
|
||||
|
||||
@@ -6,32 +6,46 @@ def parse_train_args():
|
||||
# General arguments
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument('--config', type=FileType(mode='r'), default=None)
|
||||
parser.add_argument('--log_dir', type=str, default='workdir', help='Folder in which to save model and logs')
|
||||
parser.add_argument('--log_dir', type=str, default='workdir/test_score', help='Folder in which to save model and logs')
|
||||
parser.add_argument('--restart_dir', type=str, help='Folder of previous training model from which to restart')
|
||||
parser.add_argument('--restart_ckpt', type=str, default='last_model', help='')
|
||||
parser.add_argument('--pretrain_dir', type=str, help='Folder of pretrained model from which to restart')
|
||||
parser.add_argument('--pretrain_ckpt', type=str, help='')
|
||||
parser.add_argument('--freeze_params', type=int, default=0, help='')
|
||||
parser.add_argument('--cache_path', type=str, default='data/cache', help='Folder from where to load/restore cached dataset')
|
||||
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed/', help='Folder containing original structures')
|
||||
parser.add_argument('--moad_dir', type=str, default='data/BindingMOAD_2020_processed/', help='Folder containing original structures')
|
||||
parser.add_argument('--pdbbind_dir', type=str, default='data/PDBBind_processed/', help='Folder containing original structures')
|
||||
parser.add_argument('--dataset', type=str, default='pdbbind', help='Folder containing original structures')
|
||||
parser.add_argument('--split_train', type=str, default='data/splits/timesplit_no_lig_overlap_train', help='Path of file defining the split')
|
||||
parser.add_argument('--split_val', type=str, default='data/splits/timesplit_no_lig_overlap_val', help='Path of file defining the split')
|
||||
parser.add_argument('--split_test', type=str, default='data/splits/timesplit_test', help='Path of file defining the split')
|
||||
parser.add_argument('--test_sigma_intervals', action='store_true', default=False, help='Whether to log loss per noise interval')
|
||||
parser.add_argument('--val_inference_freq', type=int, default=5, help='Frequency of epochs for which to run expensive inference on val data')
|
||||
parser.add_argument('--save_model_freq', type=int, default=None, help='')
|
||||
parser.add_argument('--inference_samples', type=int, default=1, help='')
|
||||
parser.add_argument('--train_inference_freq', type=int, default=None, help='Frequency of epochs for which to run expensive inference on train data')
|
||||
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps for inference on val')
|
||||
parser.add_argument('--num_inference_complexes', type=int, default=100, help='Number of complexes for which inference is run every val/train_inference_freq epochs (None will run it on all)')
|
||||
parser.add_argument('--inference_earlystop_metric', type=str, default='valinf_rmsds_lt2', help='This is the metric that is addionally used when val_inference_freq is not None')
|
||||
parser.add_argument('--inference_earlystop_metric', type=str, default='valinf_min_rmsds_lt2', help='This is the metric that is addionally used when val_inference_freq is not None')
|
||||
parser.add_argument('--inference_secondary_metric', type=str, default=None, help='')
|
||||
parser.add_argument('--inference_earlystop_goal', type=str, default='max', help='Whether to maximize or minimize metric')
|
||||
parser.add_argument('--wandb', action='store_true', default=False, help='')
|
||||
parser.add_argument('--project', type=str, default='difdock_train', help='')
|
||||
parser.add_argument('--project', type=str, default='diffdock', help='')
|
||||
parser.add_argument('--run_name', type=str, default='', help='')
|
||||
parser.add_argument('--cudnn_benchmark', action='store_true', default=False, help='CUDA optimization parameter for faster training')
|
||||
parser.add_argument('--num_dataloader_workers', type=int, default=0, help='Number of workers for dataloader')
|
||||
parser.add_argument('--pin_memory', action='store_true', default=False, help='pin_memory arg of dataloader')
|
||||
parser.add_argument('--dataloader_drop_last', action='store_true', default=False, help='drop_last arg of dataloader')
|
||||
parser.add_argument('--double_val', action='store_true', default=False, help='')
|
||||
parser.add_argument('--combined_training', action='store_true', default=False, help='')
|
||||
|
||||
# Training arguments
|
||||
parser.add_argument('--n_epochs', type=int, default=400, help='Number of epochs for training')
|
||||
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
|
||||
parser.add_argument('--scheduler', type=str, default=None, help='LR scheduler')
|
||||
parser.add_argument('--scheduler_patience', type=int, default=20, help='Patience of the LR scheduler')
|
||||
parser.add_argument('--lr_start_factor', type=float, default=0.001, help='')
|
||||
parser.add_argument('--warmup_dur', type=int, default=4, help='')
|
||||
parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
|
||||
parser.add_argument('--restart_lr', type=float, default=None, help='If this is not none, the lr of the optimizer will be overwritten with this value when restarting from a checkpoint.')
|
||||
parser.add_argument('--w_decay', type=float, default=0.0, help='Weight decay added to loss')
|
||||
@@ -40,23 +54,42 @@ def parse_train_args():
|
||||
parser.add_argument('--ema_rate', type=float, default=0.999, help='decay rate for the exponential moving average model parameters ')
|
||||
|
||||
# Dataset
|
||||
parser.add_argument('--limit_complexes', type=int, default=0, help='If positive, the number of training and validation complexes is capped')
|
||||
parser.add_argument('--limit_complexes', type=int, default=5, help='If positive, the number of training and validation complexes is capped') # TODO change
|
||||
parser.add_argument('--all_atoms', action='store_true', default=False, help='Whether to use the all atoms model')
|
||||
parser.add_argument('--chain_cutoff', type=float, default=None, help='Cutoff on whether to include non-interacting chains')
|
||||
parser.add_argument('--receptor_radius', type=float, default=30, help='Cutoff on distances for receptor edges')
|
||||
parser.add_argument('--c_alpha_max_neighbors', type=int, default=10, help='Maximum number of neighbors for each residue')
|
||||
parser.add_argument('--atom_radius', type=float, default=5, help='Cutoff on distances for atom connections')
|
||||
parser.add_argument('--atom_max_neighbors', type=int, default=8, help='Maximum number of atom neighbours for receptor')
|
||||
parser.add_argument('--matching_popsize', type=int, default=20, help='Differential evolution popsize parameter in matching')
|
||||
parser.add_argument('--matching_maxiter', type=int, default=20, help='Differential evolution maxiter parameter in matching')
|
||||
parser.add_argument('--matching_tries', type=int, default=1, help='')
|
||||
parser.add_argument('--max_lig_size', type=int, default=None, help='Maximum number of heavy atoms in ligand')
|
||||
parser.add_argument('--remove_hs', action='store_true', default=False, help='remove Hs')
|
||||
parser.add_argument('--remove_hs', action='store_true', default=True, help='remove Hs')
|
||||
parser.add_argument('--num_conformers', type=int, default=1, help='Number of conformers to match to each ligand')
|
||||
parser.add_argument('--esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
|
||||
parser.add_argument('--moad_esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
|
||||
parser.add_argument('--pdbbind_esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
|
||||
parser.add_argument('--moad_esm_embeddings_sequences_path', type=str, default=None, help='')
|
||||
parser.add_argument('--esm_embeddings_model', type=str, default=None, help='')
|
||||
parser.add_argument('--not_fixed_knn_radius_graph', action='store_true', default=False, help='Use knn graph and radius graph with closest neighbors instead of random ones as with radius_graph')
|
||||
parser.add_argument('--not_knn_only_graph', action='store_true', default=False, help='Use knn graph only and not restrict to a specific radius')
|
||||
parser.add_argument('--include_miscellaneous_atoms', action='store_true', default=False, help='include non amino acid atoms for the receptor')
|
||||
parser.add_argument('--train_multiplicity', type=int, default=1, help='')
|
||||
parser.add_argument('--val_multiplicity', type=int, default=1, help='')
|
||||
parser.add_argument('--max_receptor_size', type=int, default=None, help='')
|
||||
parser.add_argument('--remove_promiscuous_targets', type=int, default=None, help='')
|
||||
parser.add_argument('--min_ligand_size', type=int, default=0, help='')
|
||||
parser.add_argument('--unroll_clusters', action='store_true', default=False, help='')
|
||||
parser.add_argument('--enforce_timesplit', action='store_true', default=False, help='')
|
||||
parser.add_argument('--merge_clusters', type=int, default=1, help='')
|
||||
parser.add_argument('--triple_training', action='store_true', default=False, help='')
|
||||
parser.add_argument('--crop_beyond', type=float, default=20, help='')
|
||||
|
||||
# Diffusion
|
||||
parser.add_argument('--tr_weight', type=float, default=0.33, help='Weight of translation loss')
|
||||
parser.add_argument('--rot_weight', type=float, default=0.33, help='Weight of rotation loss')
|
||||
parser.add_argument('--tor_weight', type=float, default=0.33, help='Weight of torsional loss')
|
||||
parser.add_argument('--confidence_weight', type=float, default=0.33, help='Weight of confidence loss')
|
||||
parser.add_argument('--rot_sigma_min', type=float, default=0.1, help='Minimum sigma for rotational component')
|
||||
parser.add_argument('--rot_sigma_max', type=float, default=1.65, help='Maximum sigma for rotational component')
|
||||
parser.add_argument('--tr_sigma_min', type=float, default=0.1, help='Minimum sigma for translational component')
|
||||
@@ -64,11 +97,17 @@ def parse_train_args():
|
||||
parser.add_argument('--tor_sigma_min', type=float, default=0.0314, help='Minimum sigma for torsional component')
|
||||
parser.add_argument('--tor_sigma_max', type=float, default=3.14, help='Maximum sigma for torsional component')
|
||||
parser.add_argument('--no_torsion', action='store_true', default=False, help='If set only rigid matching')
|
||||
parser.add_argument('--sampling_alpha', type=float, default=1, help='Alpha parameter of beta distribution for sampling t')
|
||||
parser.add_argument('--sampling_beta', type=float, default=1, help='Beta parameter of beta distribution for sampling t')
|
||||
parser.add_argument('--bootstrap_alpha', type=float, default=1, help='Alpha parameter of beta distribution for sampling t in bootstrapping')
|
||||
parser.add_argument('--bootstrap_beta', type=float, default=1, help='Beta parameter of beta distribution for sampling t in bootstrapping')
|
||||
parser.add_argument('--bootstrap_tmin', type=float, default=0, help='')
|
||||
|
||||
# Model
|
||||
parser.add_argument('--num_conv_layers', type=int, default=2, help='Number of interaction layers')
|
||||
parser.add_argument('--max_radius', type=float, default=5.0, help='Radius cutoff for geometric graph')
|
||||
parser.add_argument('--scale_by_sigma', action='store_true', default=True, help='Whether to normalise the score')
|
||||
parser.add_argument('--norm_by_sigma', action='store_true', default=False, help='Whether to normalise the score')
|
||||
parser.add_argument('--ns', type=int, default=16, help='Number of hidden features per node of order 0')
|
||||
parser.add_argument('--nv', type=int, default=4, help='Number of hidden features per node of order >0')
|
||||
parser.add_argument('--distance_embed_dim', type=int, default=32, help='Embedding size for the distance')
|
||||
@@ -78,9 +117,36 @@ def parse_train_args():
|
||||
parser.add_argument('--cross_max_distance', type=float, default=80, help='Maximum cross distance in case not dynamic')
|
||||
parser.add_argument('--dynamic_max_cross', action='store_true', default=False, help='Whether to use the dynamic distance cutoff')
|
||||
parser.add_argument('--dropout', type=float, default=0.0, help='MLP dropout')
|
||||
parser.add_argument('--smooth_edges', action='store_true', default=False, help='Whether to apply additional smoothing weight to edges')
|
||||
parser.add_argument('--odd_parity', action='store_true', default=False, help='Whether to impose odd parity in output')
|
||||
parser.add_argument('--embedding_type', type=str, default="sinusoidal", help='Type of diffusion time embedding')
|
||||
parser.add_argument('--sigma_embed_dim', type=int, default=32, help='Size of the embedding of the diffusion time')
|
||||
parser.add_argument('--embedding_scale', type=int, default=1000, help='Parameter of the diffusion time embedding')
|
||||
parser.add_argument('--use_old_atom_encoder', action='store_true', default=False, help='option to use old atom encoder for backward compatibility')
|
||||
parser.add_argument('--depthwise_convolution', action='store_true', default=False, help='')
|
||||
|
||||
parser.add_argument('--protein_file', type=str, default='protein_processed', help='')
|
||||
parser.add_argument('--no_aminoacid_identities', action='store_true', default=False, help='')
|
||||
parser.add_argument('--sh_lmax', type=int, default=2, help='Size of the embedding of the diffusion time')
|
||||
parser.add_argument('--no_differentiate_convolutions', action='store_true', default=False, help='')
|
||||
parser.add_argument('--tp_weights_layers', type=int, default=2, help='')
|
||||
parser.add_argument('--num_prot_emb_layers', type=int, default=0, help='')
|
||||
parser.add_argument('--reduce_pseudoscalars', action='store_true', default=False, help='')
|
||||
parser.add_argument('--embed_also_ligand', action='store_true', default=True, help='')
|
||||
parser.add_argument('--sidechain_loss_weight', type=float, default=0, help='')
|
||||
parser.add_argument('--backbone_loss_weight', type=float, default=0, help='')
|
||||
|
||||
# pdb sidechain training
|
||||
parser.add_argument('--pdbsidechain_dir', type=str, default='data/pdb_2021aug02_sample', help='')
|
||||
parser.add_argument('--pdbsidechain_esm_embeddings_path', type=str, default=None, help='')
|
||||
parser.add_argument('--pdbsidechain_esm_embeddings_sequences_path', type=str, default=None, help='')
|
||||
parser.add_argument('--vandermers_max_dist', type=int, default=5, help='')
|
||||
parser.add_argument('--vandermers_buffer_residue_num', type=int, default=7, help='')
|
||||
parser.add_argument('--vandermers_min_contacts', type=int, default=None, help='')
|
||||
parser.add_argument('--remove_second_segment', action='store_true', default=False, help='')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
assert (not args.dynamic_max_cross) or (args.tr_sigma_max * 3 + 20 < args.cross_max_distance)
|
||||
assert args.esm_embeddings_model is None or args.esm_embeddings_path is None
|
||||
return args
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
import copy
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch_geometric.data import Batch
|
||||
from torch_geometric.loader import DataLoader
|
||||
|
||||
from utils.diffusion_utils import modify_conformer, set_time
|
||||
from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
|
||||
from utils.torsion import modify_conformer_torsion_angles
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from utils.utils import crop_beyond
|
||||
|
||||
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max):
|
||||
|
||||
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
|
||||
initial_noise_std_proportion=-1.0, choose_residue=False):
|
||||
# in place modification of the list
|
||||
center_pocket = data_list[0]['receptor'].pos.mean(dim=0)
|
||||
if pocket_knowledge:
|
||||
complex = data_list[0]
|
||||
d = torch.cdist(complex['receptor'].pos, torch.from_numpy(complex['ligand'].orig_pos[0]).float() - complex.original_center)
|
||||
label = torch.any(d < pocket_cutoff, dim=1)
|
||||
|
||||
if torch.any(label):
|
||||
center_pocket = complex['receptor'].pos[label].mean(dim=0)
|
||||
else:
|
||||
print("No pocket residue below minimum distance ", pocket_cutoff, "taking closest at", torch.min(d))
|
||||
center_pocket = complex['receptor'].pos[torch.argmin(torch.min(d, dim=1)[0])]
|
||||
|
||||
if not no_torsion:
|
||||
# randomize torsion angles
|
||||
for complex_graph in data_list:
|
||||
@@ -23,92 +42,212 @@ def randomize_position(data_list, no_torsion, no_random, tr_sigma_max):
|
||||
# randomize position
|
||||
molecule_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
|
||||
random_rotation = torch.from_numpy(R.random().as_matrix()).float()
|
||||
complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T
|
||||
complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T + center_pocket
|
||||
# base_rmsd = np.sqrt(np.sum((complex_graph['ligand'].pos.cpu().numpy() - orig_complex_graph['ligand'].pos.numpy()) ** 2, axis=1).mean())
|
||||
|
||||
if not no_random: # note for now the torsion angles are still randomised
|
||||
tr_update = torch.normal(mean=0, std=tr_sigma_max, size=(1, 3))
|
||||
if choose_residue:
|
||||
idx = random.randint(0, len(complex_graph['receptor'].pos)-1)
|
||||
tr_update = torch.normal(mean=complex_graph['receptor'].pos[idx:idx+1], std=0.01)
|
||||
elif initial_noise_std_proportion >= 0.0:
|
||||
std_rec = torch.sqrt(torch.mean(torch.sum(complex_graph['receptor'].pos ** 2, dim=1)))
|
||||
tr_update = torch.normal(mean=0, std=std_rec * initial_noise_std_proportion / 1.73, size=(1, 3))
|
||||
else:
|
||||
# if initial_noise_std_proportion < 0.0, we use the tr_sigma_max multiplied by -initial_noise_std_proportion
|
||||
tr_update = torch.normal(mean=0, std=-initial_noise_std_proportion * tr_sigma_max, size=(1, 3))
|
||||
complex_graph['ligand'].pos += tr_update
|
||||
|
||||
|
||||
def is_iterable(arr):
|
||||
try:
|
||||
some_object_iterator = iter(arr)
|
||||
return True
|
||||
except TypeError as te:
|
||||
return False
|
||||
|
||||
|
||||
def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_schedule, device, t_to_sigma, model_args,
|
||||
no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None,
|
||||
confidence_model_args=None, batch_size=32, no_final_step_noise=False):
|
||||
no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None, confidence_model_args=None,
|
||||
t_schedule=None, batch_size=32, no_final_step_noise=False, pivot=None, return_full_trajectory=False,
|
||||
temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False):
|
||||
N = len(data_list)
|
||||
trajectory = []
|
||||
if return_features:
|
||||
lig_features, rec_features = [], []
|
||||
assert batch_size >= N, "Not implemented yet"
|
||||
|
||||
for t_idx in range(inference_steps):
|
||||
t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
|
||||
dt_tr = tr_schedule[t_idx] - tr_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tr_schedule[t_idx]
|
||||
dt_rot = rot_schedule[t_idx] - rot_schedule[t_idx + 1] if t_idx < inference_steps - 1 else rot_schedule[t_idx]
|
||||
dt_tor = tor_schedule[t_idx] - tor_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tor_schedule[t_idx]
|
||||
loader = DataLoader(data_list, batch_size=batch_size)
|
||||
assert not (return_full_trajectory or return_features or pivot), "Not implemented yet in new inference version"
|
||||
|
||||
loader = DataLoader(data_list, batch_size=batch_size)
|
||||
new_data_list = []
|
||||
mask_rotate = torch.from_numpy(data_list[0]['ligand'].mask_rotate[0]).to(device)
|
||||
|
||||
for complex_graph_batch in loader:
|
||||
b = complex_graph_batch.num_graphs
|
||||
complex_graph_batch = complex_graph_batch.to(device)
|
||||
|
||||
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)
|
||||
set_time(complex_graph_batch, t_tr, t_rot, t_tor, b, model_args.all_atoms, device)
|
||||
|
||||
with torch.no_grad():
|
||||
tr_score, rot_score, tor_score = model(complex_graph_batch)
|
||||
|
||||
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
|
||||
rot_g = 2 * rot_sigma * torch.sqrt(torch.tensor(np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
|
||||
|
||||
if ode:
|
||||
tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score.cpu()).cpu()
|
||||
rot_perturb = (0.5 * rot_score.cpu() * dt_rot * rot_g ** 2).cpu()
|
||||
else:
|
||||
tr_z = torch.zeros((b, 3)) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=(b, 3))
|
||||
tr_perturb = (tr_g ** 2 * dt_tr * tr_score.cpu() + tr_g * np.sqrt(dt_tr) * tr_z).cpu()
|
||||
|
||||
rot_z = torch.zeros((b, 3)) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=(b, 3))
|
||||
rot_perturb = (rot_score.cpu() * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z).cpu()
|
||||
|
||||
if not model_args.no_torsion:
|
||||
tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
|
||||
if ode:
|
||||
tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score.cpu()).numpy()
|
||||
else:
|
||||
tor_z = torch.zeros(tor_score.shape) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=tor_score.shape)
|
||||
tor_perturb = (tor_g ** 2 * dt_tor * tor_score.cpu() + tor_g * np.sqrt(dt_tor) * tor_z).numpy()
|
||||
torsions_per_molecule = tor_perturb.shape[0] // b
|
||||
else:
|
||||
tor_perturb = None
|
||||
|
||||
# Apply noise
|
||||
new_data_list.extend([modify_conformer(complex_graph, tr_perturb[i:i + 1], rot_perturb[i:i + 1].squeeze(0),
|
||||
tor_perturb[i * torsions_per_molecule:(i + 1) * torsions_per_molecule] if not model_args.no_torsion else None)
|
||||
for i, complex_graph in enumerate(complex_graph_batch.to('cpu').to_data_list())])
|
||||
data_list = new_data_list
|
||||
|
||||
if visualization_list is not None:
|
||||
for idx, visualization in enumerate(visualization_list):
|
||||
visualization.add((data_list[idx]['ligand'].pos + data_list[idx].original_center).detach().cpu(),
|
||||
part=1, order=t_idx + 2)
|
||||
confidence = None
|
||||
if confidence_model is not None:
|
||||
confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
|
||||
confidence = []
|
||||
|
||||
with torch.no_grad():
|
||||
if confidence_model is not None:
|
||||
loader = DataLoader(data_list, batch_size=batch_size)
|
||||
confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
|
||||
confidence = []
|
||||
for complex_graph_batch in loader:
|
||||
complex_graph_batch = complex_graph_batch.to(device)
|
||||
if confidence_data_list is not None:
|
||||
confidence_complex_graph_batch = next(confidence_loader).to(device)
|
||||
confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos
|
||||
set_time(confidence_complex_graph_batch, 0, 0, 0, N, confidence_model_args.all_atoms, device)
|
||||
confidence.append(confidence_model(confidence_complex_graph_batch))
|
||||
for batch_id, complex_graph_batch in enumerate(loader):
|
||||
b = complex_graph_batch.num_graphs
|
||||
n = len(complex_graph_batch['ligand'].pos) // b
|
||||
complex_graph_batch = complex_graph_batch.to(device)
|
||||
|
||||
for t_idx in range(inference_steps):
|
||||
t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
|
||||
dt_tr = tr_schedule[t_idx] - tr_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tr_schedule[t_idx]
|
||||
dt_rot = rot_schedule[t_idx] - rot_schedule[t_idx + 1] if t_idx < inference_steps - 1 else rot_schedule[t_idx]
|
||||
dt_tor = tor_schedule[t_idx] - tor_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tor_schedule[t_idx]
|
||||
|
||||
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)
|
||||
|
||||
if hasattr(model_args, 'crop_beyond') and model_args.crop_beyond is not None:
|
||||
#print('Cropping beyond', tr_sigma * 3 + model_args.crop_beyond, 'for score model')
|
||||
mod_complex_graph_batch = copy.deepcopy(complex_graph_batch).to_data_list()
|
||||
for batch in mod_complex_graph_batch:
|
||||
crop_beyond(batch, tr_sigma * 3 + model_args.crop_beyond, model_args.all_atoms)
|
||||
mod_complex_graph_batch = Batch.from_data_list(mod_complex_graph_batch)
|
||||
else:
|
||||
confidence.append(confidence_model(complex_graph_batch))
|
||||
confidence = torch.cat(confidence, dim=0)
|
||||
else:
|
||||
confidence = None
|
||||
mod_complex_graph_batch = complex_graph_batch
|
||||
|
||||
set_time(mod_complex_graph_batch, t_schedule[t_idx] if t_schedule is not None else None, t_tr, t_rot, t_tor, b,
|
||||
'all_atoms' in model_args and model_args.all_atoms, device)
|
||||
|
||||
tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]
|
||||
|
||||
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
|
||||
rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
|
||||
|
||||
if ode:
|
||||
tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score)
|
||||
rot_perturb = (0.5 * rot_score * dt_rot * rot_g ** 2)
|
||||
else:
|
||||
tr_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
|
||||
tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z)
|
||||
|
||||
rot_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
|
||||
rot_perturb = (rot_score * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z)
|
||||
|
||||
if not model_args.no_torsion:
|
||||
tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
|
||||
if ode:
|
||||
tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score)
|
||||
else:
|
||||
tor_z = torch.zeros(tor_score.shape, device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
|
||||
else torch.normal(mean=0, std=1, size=tor_score.shape, device=device)
|
||||
tor_perturb = (tor_g ** 2 * dt_tor * tor_score + tor_g * np.sqrt(dt_tor) * tor_z)
|
||||
torsions_per_molecule = tor_perturb.shape[0] // b
|
||||
else:
|
||||
tor_perturb = None
|
||||
|
||||
if not is_iterable(temp_sampling):
|
||||
temp_sampling = [temp_sampling] * 3
|
||||
if not is_iterable(temp_psi):
|
||||
temp_psi = [temp_psi] * 3
|
||||
|
||||
if not is_iterable(temp_sampling): temp_sampling = [temp_sampling] * 3
|
||||
if not is_iterable(temp_psi): temp_psi = [temp_psi] * 3
|
||||
if not is_iterable(temp_sigma_data): temp_sigma_data = [temp_sigma_data] * 3
|
||||
|
||||
assert len(temp_sampling) == 3
|
||||
assert len(temp_psi) == 3
|
||||
assert len(temp_sigma_data) == 3
|
||||
|
||||
if temp_sampling[0] != 1.0:
|
||||
tr_sigma_data = np.exp(temp_sigma_data[0] * np.log(model_args.tr_sigma_max) + (1 - temp_sigma_data[0]) * np.log(model_args.tr_sigma_min))
|
||||
lambda_tr = (tr_sigma_data + tr_sigma) / (tr_sigma_data + tr_sigma / temp_sampling[0])
|
||||
tr_perturb = (tr_g ** 2 * dt_tr * (lambda_tr + temp_sampling[0] * temp_psi[0] / 2) * tr_score + tr_g * np.sqrt(dt_tr * (1 + temp_psi[0])) * tr_z)
|
||||
|
||||
if temp_sampling[1] != 1.0:
|
||||
rot_sigma_data = np.exp(temp_sigma_data[1] * np.log(model_args.rot_sigma_max) + (1 - temp_sigma_data[1]) * np.log(model_args.rot_sigma_min))
|
||||
lambda_rot = (rot_sigma_data + rot_sigma) / (rot_sigma_data + rot_sigma / temp_sampling[1])
|
||||
rot_perturb = (rot_g ** 2 * dt_rot * (lambda_rot + temp_sampling[1] * temp_psi[1] / 2) * rot_score + rot_g * np.sqrt(dt_rot * (1 + temp_psi[1])) * rot_z)
|
||||
|
||||
if temp_sampling[2] != 1.0:
|
||||
tor_sigma_data = np.exp(temp_sigma_data[2] * np.log(model_args.tor_sigma_max) + (1 - temp_sigma_data[2]) * np.log(model_args.tor_sigma_min))
|
||||
lambda_tor = (tor_sigma_data + tor_sigma) / (tor_sigma_data + tor_sigma / temp_sampling[2])
|
||||
tor_perturb = (tor_g ** 2 * dt_tor * (lambda_tor + temp_sampling[2] * temp_psi[2] / 2) * tor_score + tor_g * np.sqrt(dt_tor * (1 + temp_psi[2])) * tor_z)
|
||||
|
||||
# Apply noise
|
||||
complex_graph_batch['ligand'].pos = \
|
||||
modify_conformer_batch(complex_graph_batch['ligand'].pos, complex_graph_batch, tr_perturb, rot_perturb,
|
||||
tor_perturb if not model_args.no_torsion else None, mask_rotate)
|
||||
|
||||
if visualization_list is not None:
|
||||
for idx_b in range(b):
|
||||
visualization_list[batch_id * batch_size + idx_b].add((
|
||||
complex_graph_batch['ligand'].pos[idx_b*n:n*(idx_b+1)].detach().cpu() +
|
||||
data_list[batch_id * batch_size + idx_b].original_center.detach().cpu()),
|
||||
part=1, order=t_idx + 2)
|
||||
|
||||
for i in range(b):
|
||||
data_list[batch_id * batch_size + i]['ligand'].pos = complex_graph_batch['ligand'].pos[i*n:n*(i+1)]
|
||||
|
||||
if visualization_list is not None:
|
||||
for idx, visualization in enumerate(visualization_list):
|
||||
visualization.add((data_list[idx]['ligand'].pos.detach().cpu() + data_list[idx].original_center.detach().cpu()),
|
||||
part=1, order=2)
|
||||
|
||||
if confidence_model is not None:
|
||||
if confidence_data_list is not None:
|
||||
confidence_complex_graph_batch = next(confidence_loader)
|
||||
confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos.cpu()
|
||||
|
||||
if hasattr(confidence_model_args, 'crop_beyond') and confidence_model_args.crop_beyond is not None:
|
||||
confidence_complex_graph_batch = confidence_complex_graph_batch.to_data_list()
|
||||
for batch in confidence_complex_graph_batch:
|
||||
crop_beyond(batch, confidence_model_args.crop_beyond, confidence_model_args.all_atoms)
|
||||
confidence_complex_graph_batch = Batch.from_data_list(confidence_complex_graph_batch)
|
||||
|
||||
confidence_complex_graph_batch = confidence_complex_graph_batch.to(device)
|
||||
set_time(confidence_complex_graph_batch, 0, 0, 0, 0, b, confidence_model_args.all_atoms, device)
|
||||
out = confidence_model(confidence_complex_graph_batch)
|
||||
else:
|
||||
out = confidence_model(complex_graph_batch)
|
||||
|
||||
if type(out) is tuple:
|
||||
out = out[0]
|
||||
confidence.append(out)
|
||||
|
||||
if confidence_model is not None:
|
||||
confidence = torch.cat(confidence, dim=0)
|
||||
confidence = torch.nan_to_num(confidence, nan=-1000)
|
||||
|
||||
if return_full_trajectory:
|
||||
return data_list, confidence, trajectory
|
||||
elif return_features:
|
||||
lig_features = torch.cat(lig_features, dim=0)
|
||||
rec_features = torch.cat(rec_features, dim=0)
|
||||
return data_list, confidence, lig_features, rec_features
|
||||
|
||||
return data_list, confidence
|
||||
|
||||
|
||||
def compute_affinity(data_list, affinity_model, affinity_data_list, device, parallel, all_atoms, include_miscellaneous_atoms):
|
||||
|
||||
with torch.no_grad():
|
||||
if affinity_model is not None:
|
||||
assert parallel <= len(data_list)
|
||||
loader = DataLoader(data_list, batch_size=parallel)
|
||||
complex_graph_batch = next(iter(loader)).to(device)
|
||||
positions = complex_graph_batch['ligand'].pos
|
||||
|
||||
assert affinity_data_list is not None
|
||||
complex_graph = affinity_data_list[0]
|
||||
N = complex_graph['ligand'].num_nodes
|
||||
complex_graph['ligand'].x = complex_graph['ligand'].x.repeat(parallel, 1)
|
||||
complex_graph['ligand'].edge_mask = complex_graph['ligand'].edge_mask.repeat(parallel)
|
||||
complex_graph['ligand', 'ligand'].edge_index = torch.cat(
|
||||
[N * i + complex_graph['ligand', 'ligand'].edge_index for i in range(parallel)], dim=1)
|
||||
complex_graph['ligand', 'ligand'].edge_attr = complex_graph['ligand', 'ligand'].edge_attr.repeat(parallel, 1)
|
||||
complex_graph['ligand'].pos = positions
|
||||
|
||||
affinity_loader = DataLoader([complex_graph], batch_size=1)
|
||||
affinity_batch = next(iter(affinity_loader)).to(device)
|
||||
set_time(affinity_batch, 0, 0, 0, 0, 1, all_atoms, device, include_miscellaneous_atoms=include_miscellaneous_atoms)
|
||||
_, affinity = affinity_model(affinity_batch)
|
||||
else:
|
||||
affinity = None
|
||||
|
||||
return affinity
|
||||
25
utils/so3.py
25
utils/so3.py
@@ -3,7 +3,7 @@ import numpy as np
|
||||
import torch
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000
|
||||
MIN_EPS, MAX_EPS, N_EPS = 0.0005, 4, 2000
|
||||
X_N = 2000
|
||||
|
||||
"""
|
||||
@@ -21,7 +21,7 @@ def _compose(r1, r2): # R1 @ R2 but for Euler vecs
|
||||
def _expansion(omega, eps, L=2000): # the summation term only
|
||||
p = 0
|
||||
for l in range(L):
|
||||
p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2)
|
||||
p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2)
|
||||
return p
|
||||
|
||||
|
||||
@@ -39,17 +39,16 @@ def _score(exp, omega, eps, L=2000): # score of density over SO(3)
|
||||
dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2))
|
||||
lo = np.sin(omega / 2)
|
||||
dlo = 1 / 2 * np.cos(omega / 2)
|
||||
dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo ** 2
|
||||
dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * (lo * dhi - hi * dlo) / lo ** 2
|
||||
return dSigma / exp
|
||||
|
||||
|
||||
if os.path.exists('.so3_omegas_array2.npy'):
|
||||
_omegas_array = np.load('.so3_omegas_array2.npy')
|
||||
_cdf_vals = np.load('.so3_cdf_vals2.npy')
|
||||
_score_norms = np.load('.so3_score_norms2.npy')
|
||||
_exp_score_norms = np.load('.so3_exp_score_norms2.npy')
|
||||
if os.path.exists('.so3_omegas_array4.npy'):
|
||||
_omegas_array = np.load('.so3_omegas_array4.npy')
|
||||
_cdf_vals = np.load('.so3_cdf_vals4.npy')
|
||||
_score_norms = np.load('.so3_score_norms4.npy')
|
||||
_exp_score_norms = np.load('.so3_exp_score_norms4.npy')
|
||||
else:
|
||||
print("Precomputing and saving to cache SO(3) distribution table")
|
||||
_eps_array = 10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)
|
||||
_omegas_array = np.linspace(0, np.pi, X_N + 1)[1:]
|
||||
|
||||
@@ -60,10 +59,10 @@ else:
|
||||
|
||||
_exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi)
|
||||
|
||||
np.save('.so3_omegas_array2.npy', _omegas_array)
|
||||
np.save('.so3_cdf_vals2.npy', _cdf_vals)
|
||||
np.save('.so3_score_norms2.npy', _score_norms)
|
||||
np.save('.so3_exp_score_norms2.npy', _exp_score_norms)
|
||||
np.save('.so3_omegas_array4.npy', _omegas_array)
|
||||
np.save('.so3_cdf_vals4.npy', _cdf_vals)
|
||||
np.save('.so3_score_norms4.npy', _score_norms)
|
||||
np.save('.so3_exp_score_norms4.npy', _exp_score_norms)
|
||||
|
||||
|
||||
def sample(eps):
|
||||
|
||||
@@ -5,6 +5,8 @@ from scipy.spatial.transform import Rotation as R
|
||||
from torch_geometric.utils import to_networkx
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from utils.geometry import rigid_transform_Kabsch_independent_torch, axis_angle_to_matrix
|
||||
|
||||
"""
|
||||
Preprocessing and computation for torsional updates to conformers
|
||||
"""
|
||||
@@ -35,7 +37,7 @@ def get_transformation_mask(pyg_data):
|
||||
mask_edges = np.asarray([0 if len(l) == 0 else 1 for l in to_rotate], dtype=bool)
|
||||
mask_rotate = np.zeros((np.sum(mask_edges), len(G.nodes())), dtype=bool)
|
||||
idx = 0
|
||||
for i in range(len(G.edges())):
|
||||
for i in range(min(edges.shape[0], len(G.edges()))):
|
||||
if mask_edges[i]:
|
||||
mask_rotate[idx][np.asarray(to_rotate[i], dtype=int)] = True
|
||||
idx += 1
|
||||
@@ -46,15 +48,19 @@ def get_transformation_mask(pyg_data):
|
||||
def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False):
|
||||
pos = copy.deepcopy(pos)
|
||||
if type(pos) != np.ndarray: pos = pos.cpu().numpy()
|
||||
|
||||
|
||||
if type(mask_rotate) == list: mask_rotate = mask_rotate[0]
|
||||
|
||||
for idx_edge, e in enumerate(edge_index.cpu().numpy()):
|
||||
if torsion_updates[idx_edge] == 0:
|
||||
continue
|
||||
u, v = e[0], e[1]
|
||||
|
||||
# check if need to reverse the edge, v should be connected to the part that gets rotated
|
||||
assert not mask_rotate[idx_edge, u]
|
||||
assert mask_rotate[idx_edge, v]
|
||||
if mask_rotate[idx_edge, u] or (not mask_rotate[idx_edge, v]):
|
||||
print("mask rotate exception")
|
||||
#assert not mask_rotate[idx_edge, u]
|
||||
#assert mask_rotate[idx_edge, v]
|
||||
|
||||
rot_vec = pos[u] - pos[v] # convention: positive rotation if pointing inwards
|
||||
rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec) # idx_edge!
|
||||
@@ -66,6 +72,24 @@ def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_update
|
||||
return pos
|
||||
|
||||
|
||||
def modify_conformer_torsion_angles_batch(pos, edge_index, mask_rotate, torsion_updates):
|
||||
pos = pos + 0
|
||||
for idx_edge, e in enumerate(edge_index):
|
||||
u, v = e[0], e[1]
|
||||
|
||||
# check if need to reverse the edge, v should be connected to the part that gets rotated
|
||||
assert not mask_rotate[idx_edge, u]
|
||||
assert mask_rotate[idx_edge, v]
|
||||
|
||||
rot_vec = pos[:, u] - pos[:, v] # convention: positive rotation if pointing inwards
|
||||
rot_mat = axis_angle_to_matrix(
|
||||
rot_vec / torch.linalg.norm(rot_vec, dim=-1, keepdims=True) * torsion_updates[:, idx_edge:idx_edge + 1])
|
||||
|
||||
pos[:, mask_rotate[idx_edge]] = torch.bmm(pos[:, mask_rotate[idx_edge]] - pos[:, v:v + 1], torch.transpose(rot_mat, 1, 2)) + pos[:, v:v + 1]
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def perturb_batch(data, torsion_updates, split=False, return_updates=False):
|
||||
if type(data) is Data:
|
||||
return modify_conformer_torsion_angles(data.pos,
|
||||
@@ -91,4 +115,24 @@ def perturb_batch(data, torsion_updates, split=False, return_updates=False):
|
||||
idx_edges += mask_rotate.shape[0]
|
||||
if return_updates:
|
||||
return pos_new, torsion_update_list
|
||||
return pos_new
|
||||
return pos_new
|
||||
|
||||
|
||||
def get_dihedrals(data_list):
|
||||
edge_index, edge_mask = data_list[0]['ligand', 'ligand'].edge_index, data_list[0]['ligand'].edge_mask
|
||||
edge_list = [[] for _ in range(torch.max(edge_index) + 1)]
|
||||
|
||||
for p in edge_index.T:
|
||||
edge_list[p[0]].append(p[1])
|
||||
|
||||
rot_bonds = [(p[0], p[1]) for i, p in enumerate(edge_index.T) if edge_mask[i]]
|
||||
|
||||
dihedral = []
|
||||
for a, b in rot_bonds:
|
||||
c = edge_list[a][0] if edge_list[a][0] != b else edge_list[a][1]
|
||||
d = edge_list[b][0] if edge_list[b][0] != a else edge_list[b][1]
|
||||
dihedral.append((c.item(), a.item(), b.item(), d.item()))
|
||||
# dihedral_numpy = np.asarray(dihedral)
|
||||
# print(dihedral_numpy.shape)
|
||||
dihedral = torch.tensor(dihedral)
|
||||
return dihedral
|
||||
|
||||
@@ -32,7 +32,6 @@ if os.path.exists('.p.npy'):
|
||||
p_ = np.load('.p.npy')
|
||||
score_ = np.load('.score.npy')
|
||||
else:
|
||||
print("Precomputing and saving to cache torus distribution table")
|
||||
p_ = p(x, sigma[:, None], N=100)
|
||||
np.save('.p.npy', p_)
|
||||
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
import copy
|
||||
|
||||
import numpy as np
|
||||
from rdkit.Chem import RemoveAllHs
|
||||
from torch_geometric.loader import DataLoader
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
from confidence.dataset import ListDataset
|
||||
from utils import so3, torus
|
||||
from utils.molecules_utils import get_symmetry_rmsd
|
||||
from utils.sampling import randomize_position, sampling
|
||||
import torch
|
||||
from utils.diffusion_utils import get_t_schedule
|
||||
|
||||
|
||||
def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weight=1, rot_weight=1,
|
||||
tor_weight=1, apply_mean=True, no_torsion=False):
|
||||
def loss_function(tr_pred, rot_pred, tor_pred, sidechain_pred, data, t_to_sigma, device, tr_weight=1, rot_weight=1,
|
||||
tor_weight=1, backbone_weight=0, sidechain_weight=0, apply_mean=True, no_torsion=False):
|
||||
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(
|
||||
*[torch.cat([d.complex_t[noise_type] for d in data]) if device.type == 'cuda' else data.complex_t[noise_type]
|
||||
for noise_type in ['tr', 'rot', 'tor']])
|
||||
@@ -21,14 +22,14 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
|
||||
# translation component
|
||||
tr_score = torch.cat([d.tr_score for d in data], dim=0) if device.type == 'cuda' else data.tr_score
|
||||
tr_sigma = tr_sigma.unsqueeze(-1)
|
||||
tr_loss = ((tr_pred.cpu() - tr_score) ** 2 * tr_sigma ** 2).mean(dim=mean_dims)
|
||||
tr_loss = ((tr_pred.cpu() - tr_score.cpu()) ** 2 * tr_sigma.cpu() ** 2).mean(dim=mean_dims)
|
||||
tr_base_loss = (tr_score ** 2 * tr_sigma ** 2).mean(dim=mean_dims).detach()
|
||||
|
||||
# rotation component
|
||||
rot_score = torch.cat([d.rot_score for d in data], dim=0) if device.type == 'cuda' else data.rot_score
|
||||
rot_score_norm = so3.score_norm(rot_sigma.cpu()).unsqueeze(-1)
|
||||
rot_loss = (((rot_pred.cpu() - rot_score) / rot_score_norm) ** 2).mean(dim=mean_dims)
|
||||
rot_base_loss = ((rot_score / rot_score_norm) ** 2).mean(dim=mean_dims).detach()
|
||||
rot_loss = (((rot_pred.cpu() - rot_score.cpu()) / rot_score_norm) ** 2).mean(dim=mean_dims)
|
||||
rot_base_loss = ((rot_score.cpu() / rot_score_norm) ** 2).mean(dim=mean_dims).detach()
|
||||
|
||||
# torsion component
|
||||
if not no_torsion:
|
||||
@@ -36,8 +37,8 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
|
||||
np.concatenate([d.tor_sigma_edge for d in data] if device.type == 'cuda' else data.tor_sigma_edge))
|
||||
tor_score = torch.cat([d.tor_score for d in data], dim=0) if device.type == 'cuda' else data.tor_score
|
||||
tor_score_norm2 = torch.tensor(torus.score_norm(edge_tor_sigma.cpu().numpy())).float()
|
||||
tor_loss = ((tor_pred.cpu() - tor_score) ** 2 / tor_score_norm2)
|
||||
tor_base_loss = ((tor_score ** 2 / tor_score_norm2)).detach()
|
||||
tor_loss = ((tor_pred.cpu() - tor_score.cpu()) ** 2 / tor_score_norm2)
|
||||
tor_base_loss = ((tor_score.cpu() ** 2 / tor_score_norm2)).detach()
|
||||
if apply_mean:
|
||||
tor_loss, tor_base_loss = tor_loss.mean() * torch.ones(1, dtype=torch.float), tor_base_loss.mean() * torch.ones(1, dtype=torch.float)
|
||||
else:
|
||||
@@ -57,8 +58,70 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
|
||||
else:
|
||||
tor_loss, tor_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(len(rot_loss), dtype=torch.float)
|
||||
|
||||
loss = tr_loss * tr_weight + rot_loss * rot_weight + tor_loss * tor_weight
|
||||
return loss, tr_loss.detach(), rot_loss.detach(), tor_loss.detach(), tr_base_loss, rot_base_loss, tor_base_loss
|
||||
if backbone_weight > 0:
|
||||
backbone_vecs = torch.cat([d['receptor'].side_chain_vecs.cpu() for d in data], dim=0) if device.type == 'cuda' else data['receptor'].side_chain_vecs
|
||||
backbone_vecs = backbone_vecs[:, 4:]
|
||||
backbone_pred = sidechain_pred[:, 4:]
|
||||
|
||||
backbone_base_loss = (backbone_vecs ** 2).detach().mean(dim=1) + 0.0001
|
||||
backbone_loss = ((backbone_pred.cpu() - backbone_vecs) ** 2).mean(dim=1) / backbone_base_loss.mean()
|
||||
backbone_base_loss = backbone_base_loss / backbone_base_loss.mean()
|
||||
if apply_mean:
|
||||
backbone_loss, backbone_base_loss = backbone_loss.mean() * torch.ones(1, dtype=torch.float), backbone_base_loss.mean() * torch.ones(1, dtype=torch.float)
|
||||
else:
|
||||
index = torch.cat([torch.ones((d['receptor'].pos.shape[0])) * i for i, d in enumerate(data)], dim=0).long() if device.type == 'cuda' else data['receptor'].batch
|
||||
num_graphs = len(data) if device.type == 'cuda' else data.num_graphs
|
||||
s_l, s_b_l, c = torch.zeros(num_graphs), torch.zeros(num_graphs), torch.zeros(num_graphs)
|
||||
c.index_add_(0, index, torch.ones(backbone_loss.shape[0]))
|
||||
c = c + 0.0001
|
||||
s_l.index_add_(0, index, backbone_loss)
|
||||
s_b_l.index_add_(0, index, backbone_base_loss)
|
||||
backbone_loss, backbone_base_loss = s_l / c, s_b_l / c
|
||||
else:
|
||||
if apply_mean:
|
||||
backbone_loss, backbone_base_loss = torch.zeros(1, dtype=torch.float), torch.zeros(1, dtype=torch.float)
|
||||
else:
|
||||
backbone_loss, backbone_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(len(rot_loss), dtype=torch.float)
|
||||
|
||||
if sidechain_weight > 0:
|
||||
sidechain_vecs = torch.cat([d['receptor'].side_chain_vecs.cpu() for d in data],
|
||||
dim=0) if device.type == 'cuda' else data['receptor'].side_chain_vecs
|
||||
chi_angles = sidechain_vecs[:, :4].to(device)
|
||||
chi_pred = sidechain_pred[:, :4].to(device)
|
||||
|
||||
chi_pred = torch.where(torch.isnan(chi_angles), torch.zeros_like(chi_angles, device=device), chi_pred)
|
||||
chi_angles = torch.where(torch.isnan(chi_angles), torch.zeros_like(chi_angles, device=device), chi_angles)
|
||||
|
||||
difference = torch.abs(chi_pred - chi_angles)
|
||||
difference = torch.min(difference, 1 - difference) # angles are circular and 360 degrees = 1
|
||||
|
||||
sidechain_base_loss = (chi_angles ** 2).detach().mean(dim=1) + 0.0001
|
||||
sidechain_loss = (difference ** 2).mean(dim=1) / sidechain_base_loss.mean()
|
||||
sidechain_base_loss = sidechain_base_loss / sidechain_base_loss.mean()
|
||||
if apply_mean:
|
||||
sidechain_loss, sidechain_base_loss = \
|
||||
sidechain_loss.mean().cpu() * torch.ones(1, dtype=torch.float), \
|
||||
sidechain_base_loss.mean().cpu() * torch.ones(1, dtype=torch.float)
|
||||
else:
|
||||
index = torch.cat([torch.ones((d['receptor'].pos.shape[0])) * i for i, d in enumerate(data)],
|
||||
dim=0).long() if device.type == 'cuda' else data['receptor'].batch
|
||||
num_graphs = len(data) if device.type == 'cuda' else data.num_graphs
|
||||
s_l, s_b_l, c = torch.zeros(num_graphs), torch.zeros(num_graphs), torch.zeros(num_graphs)
|
||||
c.index_add_(0, index, torch.ones(sidechain_loss.shape[0]))
|
||||
c = c + 0.0001
|
||||
s_l.index_add_(0, index, sidechain_loss.cpu())
|
||||
s_b_l.index_add_(0, index, sidechain_base_loss.cpu())
|
||||
sidechain_loss, sidechain_base_loss = s_l / c, s_b_l / c
|
||||
else:
|
||||
if apply_mean:
|
||||
sidechain_loss, sidechain_base_loss = torch.zeros(1, dtype=torch.float), torch.zeros(1, dtype=torch.float)
|
||||
else:
|
||||
sidechain_loss, sidechain_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(
|
||||
len(rot_loss), dtype=torch.float)
|
||||
|
||||
loss = tr_loss * tr_weight + rot_loss * rot_weight + tor_loss * tor_weight + sidechain_loss * sidechain_weight + backbone_loss * backbone_weight
|
||||
return loss, tr_loss.detach(), rot_loss.detach(), tor_loss.detach(), backbone_loss.detach(), sidechain_loss.detach(), \
|
||||
tr_base_loss, rot_base_loss, tor_base_loss, backbone_base_loss, sidechain_base_loss
|
||||
|
||||
|
||||
class AverageMeter():
|
||||
@@ -73,7 +136,7 @@ class AverageMeter():
|
||||
if self.intervals == 1:
|
||||
self.count += 1 if vals[0].dim() == 0 else len(vals[0])
|
||||
for type_idx, v in enumerate(vals):
|
||||
self.acc[self.types[type_idx]] += v.sum() if self.unpooled_metrics else v
|
||||
self.acc[self.types[type_idx]] += v.sum().cpu() if self.unpooled_metrics else v.cpu()
|
||||
else:
|
||||
for type_idx, v in enumerate(vals):
|
||||
self.count[type_idx].index_add_(0, interval_idx[type_idx], torch.ones(len(v)))
|
||||
@@ -93,22 +156,34 @@ class AverageMeter():
|
||||
return out
|
||||
|
||||
|
||||
def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weigths):
|
||||
def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weights):
|
||||
model.train()
|
||||
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'])
|
||||
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
|
||||
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'])
|
||||
|
||||
for data in tqdm(loader, total=len(loader)):
|
||||
if device.type == 'cuda' and len(data) == 1 or device.type == 'cpu' and data.num_graphs == 1:
|
||||
print("Skipping batch of size 1 since otherwise batchnorm would not work.")
|
||||
continue
|
||||
optimizer.zero_grad()
|
||||
data = [d.to(device) for d in data] if device.type == 'cuda' else data
|
||||
try:
|
||||
tr_pred, rot_pred, tor_pred = model(data)
|
||||
loss, tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss = \
|
||||
loss_fn(tr_pred, rot_pred, tor_pred, data=data, t_to_sigma=t_to_sigma, device=device)
|
||||
tr_pred, rot_pred, tor_pred, sidechain_pred = model(data)
|
||||
loss_tuple = loss_fn(tr_pred, rot_pred, tor_pred, sidechain_pred, data=data, t_to_sigma=t_to_sigma, device=device)
|
||||
if loss_tuple is None:
|
||||
print("None loss tuple, skipping")
|
||||
continue
|
||||
loss = loss_tuple[0]
|
||||
|
||||
if torch.any(torch.isnan(loss)):
|
||||
names = data.name if device.type == 'cpu' else [d.name for d in data]
|
||||
print("Nan loss, skipping batch with complexes", names)
|
||||
continue
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
ema_weigths.update(model.parameters())
|
||||
meter.add([loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss])
|
||||
if ema_weights is not None: ema_weights.update(model.parameters())
|
||||
meter.add([loss.cpu().detach(), *loss_tuple[1:]])
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
print('| WARNING: ran out of memory, skipping batch')
|
||||
@@ -125,40 +200,42 @@ def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weigt
|
||||
torch.cuda.empty_cache()
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
#raise e
|
||||
print(e)
|
||||
continue
|
||||
|
||||
return meter.summary()
|
||||
|
||||
|
||||
def test_epoch(model, loader, device, t_to_sigma, loss_fn, test_sigma_intervals=False):
|
||||
model.eval()
|
||||
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'],
|
||||
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
|
||||
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'],
|
||||
unpooled_metrics=True)
|
||||
|
||||
if test_sigma_intervals:
|
||||
meter_all = AverageMeter(
|
||||
['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'],
|
||||
['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
|
||||
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'],
|
||||
unpooled_metrics=True, intervals=10)
|
||||
|
||||
for data in tqdm(loader, total=len(loader)):
|
||||
try:
|
||||
with torch.no_grad():
|
||||
tr_pred, rot_pred, tor_pred = model(data)
|
||||
|
||||
loss, tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss = \
|
||||
loss_fn(tr_pred, rot_pred, tor_pred, data=data, t_to_sigma=t_to_sigma, apply_mean=False, device=device)
|
||||
meter.add([loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss])
|
||||
tr_pred, rot_pred, tor_pred, sidechain_pred = model(data)
|
||||
loss_tuple = loss_fn(tr_pred, rot_pred, tor_pred, sidechain_pred, data=data, t_to_sigma=t_to_sigma, apply_mean=False, device=device)
|
||||
if loss_tuple is None: continue
|
||||
meter.add([loss_tuple[0].cpu().detach(), *loss_tuple[1:]])
|
||||
|
||||
if test_sigma_intervals > 0:
|
||||
complex_t_tr, complex_t_rot, complex_t_tor = [torch.cat([d.complex_t[noise_type] for d in data]) for
|
||||
complex_t_tr, complex_t_rot, complex_t_tor = [torch.cat([data[i].complex_t[noise_type] for i in range(len(data))]) for
|
||||
noise_type in ['tr', 'rot', 'tor']]
|
||||
sigma_index_tr = torch.round(complex_t_tr.cpu() * (10 - 1)).long()
|
||||
sigma_index_rot = torch.round(complex_t_rot.cpu() * (10 - 1)).long()
|
||||
sigma_index_tor = torch.round(complex_t_tor.cpu() * (10 - 1)).long()
|
||||
meter_all.add(
|
||||
[loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss],
|
||||
[sigma_index_tr, sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_rot,
|
||||
sigma_index_tor, sigma_index_tr])
|
||||
meter_all.add([loss_tuple[0].cpu().detach(), *loss_tuple[1:]],
|
||||
[sigma_index_tr, sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_tr,
|
||||
sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_tr])
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'out of memory' in str(e):
|
||||
@@ -177,60 +254,87 @@ def test_epoch(model, loader, device, t_to_sigma, loss_fn, test_sigma_intervals=
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
print(e)
|
||||
continue
|
||||
|
||||
out = meter.summary()
|
||||
if test_sigma_intervals > 0: out.update(meter_all.summary())
|
||||
return out
|
||||
|
||||
|
||||
def inference_epoch(model, complex_graphs, device, t_to_sigma, args):
|
||||
t_schedule = get_t_schedule(inference_steps=args.inference_steps)
|
||||
def inference_epoch_fix(model, complex_graphs, device, t_to_sigma, args):
|
||||
t_schedule = get_t_schedule(sigma_schedule='expbeta', inference_steps=args.inference_steps,
|
||||
inf_sched_alpha=1, inf_sched_beta=1)
|
||||
tr_schedule, rot_schedule, tor_schedule = t_schedule, t_schedule, t_schedule
|
||||
|
||||
dataset = ListDataset(complex_graphs)
|
||||
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
|
||||
rmsds = []
|
||||
rmsds, min_rmsds = [], []
|
||||
|
||||
for orig_complex_graph in tqdm(loader):
|
||||
data_list = [copy.deepcopy(orig_complex_graph)]
|
||||
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(args.inference_samples)]
|
||||
randomize_position(data_list, args.no_torsion, False, args.tr_sigma_max)
|
||||
|
||||
predictions_list = None
|
||||
failed_convergence_counter = 0
|
||||
while predictions_list == None:
|
||||
try:
|
||||
predictions_list, confidences = sampling(data_list=data_list, model=model.module if device.type=='cuda' else model,
|
||||
predictions_list, confidences = sampling(data_list=data_list, model=model.module if device.type == 'cuda' else model,
|
||||
inference_steps=args.inference_steps,
|
||||
tr_schedule=tr_schedule, rot_schedule=rot_schedule,
|
||||
tor_schedule=tor_schedule,
|
||||
device=device, t_to_sigma=t_to_sigma, model_args=args)
|
||||
device=device, t_to_sigma=t_to_sigma, model_args=args,
|
||||
t_schedule=t_schedule)
|
||||
except Exception as e:
|
||||
if 'failed to converge' in str(e):
|
||||
failed_convergence_counter += 1
|
||||
if failed_convergence_counter > 5:
|
||||
print('| WARNING: SVD failed to converge 5 times - skipping the complex')
|
||||
break
|
||||
print('| WARNING: SVD failed to converge - trying again with a new sample')
|
||||
else:
|
||||
raise e
|
||||
if failed_convergence_counter > 5: continue
|
||||
failed_convergence_counter += 1
|
||||
if failed_convergence_counter > 5:
|
||||
print('failed 5 times - skipping the complex')
|
||||
break
|
||||
print("Exception while running inference on complex:", e)
|
||||
if failed_convergence_counter > 5:
|
||||
rmsds.extend([100] * args.inference_samples)
|
||||
min_rmsds.append(100)
|
||||
continue
|
||||
|
||||
if args.no_torsion:
|
||||
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() +
|
||||
orig_complex_graph.original_center.cpu().numpy())
|
||||
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph[
|
||||
'ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())
|
||||
|
||||
filterHs = torch.not_equal(predictions_list[0]['ligand'].x[:, 0], 0).cpu().numpy()
|
||||
|
||||
if isinstance(orig_complex_graph['ligand'].orig_pos, list):
|
||||
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
|
||||
# if len(orig_complex_graph['ligand'].orig_pos.shape) == 3:
|
||||
# orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
|
||||
|
||||
ligand_pos = np.asarray(
|
||||
[complex_graph['ligand'].pos.cpu().numpy()[filterHs] for complex_graph in predictions_list])
|
||||
orig_ligand_pos = np.expand_dims(
|
||||
orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy(), axis=0)
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
|
||||
rmsds.append(rmsd)
|
||||
if len(orig_complex_graph['ligand'].orig_pos.shape) == 2:
|
||||
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[None, :, :]
|
||||
try:
|
||||
orig_ligand_pos = orig_complex_graph['ligand'].orig_pos[:, filterHs] - orig_complex_graph.original_center.cpu().numpy()
|
||||
except Exception as e:
|
||||
print("problem with orig_pos which is of shape:", orig_complex_graph['ligand'].orig_pos.shape, e)
|
||||
continue
|
||||
mol = RemoveAllHs(orig_complex_graph.mol[0])
|
||||
complex_rmsds = []
|
||||
for i in range(len(orig_ligand_pos)):
|
||||
try:
|
||||
rmsd = get_symmetry_rmsd(mol, orig_ligand_pos[i], [l for l in ligand_pos])
|
||||
except Exception as e:
|
||||
print("Using non corrected RMSD because of the error:", e)
|
||||
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos[i]) ** 2).sum(axis=2).mean(axis=1))
|
||||
complex_rmsds.append(rmsd)
|
||||
complex_rmsds = np.asarray(complex_rmsds)
|
||||
rmsd = np.min(complex_rmsds, axis=0)
|
||||
|
||||
rmsds.extend([r for r in rmsd])
|
||||
min_rmsds.append(rmsd.min(axis=0))
|
||||
|
||||
rmsds = np.array(rmsds)
|
||||
min_rmsds = np.array(min_rmsds)
|
||||
losses = {'rmsds_lt2': (100 * (rmsds < 2).sum() / len(rmsds)),
|
||||
'rmsds_lt5': (100 * (rmsds < 5).sum() / len(rmsds))}
|
||||
'rmsds_lt5': (100 * (rmsds < 5).sum() / len(rmsds)),
|
||||
'min_rmsds_lt2': (100 * (min_rmsds < 2).sum() / len(min_rmsds)),
|
||||
'min_rmsds_lt5': (100 * (min_rmsds < 5).sum() / len(min_rmsds)),}
|
||||
return losses
|
||||
|
||||
282
utils/utils.py
282
utils/utils.py
@@ -2,19 +2,23 @@ import os
|
||||
import subprocess
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from rdkit import Chem
|
||||
from rdkit.Chem import RemoveHs, MolToPDBFile
|
||||
from torch import nn, Tensor
|
||||
from torch_geometric.nn.data_parallel import DataParallel
|
||||
from torch_geometric.utils import degree, subgraph
|
||||
|
||||
from models.all_atom_score_model import TensorProductScoreModel as AAScoreModel
|
||||
from models.score_model import TensorProductScoreModel as CGScoreModel
|
||||
from models.aa_model import AAModel
|
||||
from models.cg_model import CGModel
|
||||
from models.old_aa_model import AAOldModel
|
||||
from models.old_cg_model import CGOldModel
|
||||
from utils.diffusion_utils import get_timestep_embedding
|
||||
from spyrmsd import rmsd, molecule
|
||||
|
||||
|
||||
def get_obrmsd(mol1_path, mol2_path, cache_name=None):
|
||||
@@ -61,6 +65,53 @@ def read_strings_from_txt(path):
|
||||
return [line.rstrip() for line in lines]
|
||||
|
||||
|
||||
def unbatch(src, batch: Tensor, dim: int = 0) -> List[Tensor]:
|
||||
r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
|
||||
:obj:`dim`.
|
||||
|
||||
Args:
|
||||
src (Tensor): The source tensor.
|
||||
batch (LongTensor): The batch vector
|
||||
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
|
||||
entry in :obj:`src` to a specific example. Must be ordered.
|
||||
dim (int, optional): The dimension along which to split the :obj:`src`
|
||||
tensor. (default: :obj:`0`)
|
||||
|
||||
:rtype: :class:`List[Tensor]`
|
||||
"""
|
||||
sizes = degree(batch, dtype=torch.long).tolist()
|
||||
if isinstance(src, numpy.ndarray):
|
||||
return np.split(src, np.array(sizes).cumsum()[:-1], axis=dim)
|
||||
else:
|
||||
return src.split(sizes, dim)
|
||||
|
||||
|
||||
def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
|
||||
r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
|
||||
|
||||
Args:
|
||||
edge_index (Tensor): The edge_index tensor. Must be ordered.
|
||||
batch (LongTensor): The batch vector
|
||||
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
|
||||
node to a specific example. Must be ordered.
|
||||
|
||||
:rtype: :class:`List[Tensor]`
|
||||
"""
|
||||
deg = degree(batch, dtype=torch.int64)
|
||||
ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
|
||||
|
||||
edge_batch = batch[edge_index[0]]
|
||||
edge_index = edge_index - ptr[edge_batch]
|
||||
sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
|
||||
return edge_index.split(sizes, dim=1)
|
||||
|
||||
|
||||
def unbatch_edge_attributes(edge_attributes, edge_index: Tensor, batch: Tensor) -> List[Tensor]:
|
||||
edge_batch = batch[edge_index[0]]
|
||||
sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
|
||||
return edge_attributes.split(sizes, dim=0)
|
||||
|
||||
|
||||
def save_yaml_file(path, content):
|
||||
assert isinstance(path, str), f'path must be a string, got {path} which is a {type(path)}'
|
||||
content = yaml.dump(data=content)
|
||||
@@ -70,12 +121,47 @@ def save_yaml_file(path, content):
|
||||
f.write(content)
|
||||
|
||||
|
||||
def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.w_decay)
|
||||
def unfreeze_layer(model):
|
||||
for name, child in (model.named_children()):
|
||||
#print(name, child.parameters())
|
||||
for param in child.parameters():
|
||||
param.requires_grad = True
|
||||
|
||||
|
||||
def get_optimizer_and_scheduler(args, model, scheduler_mode='min', step=0, optimizer=None):
|
||||
if args.scheduler == 'layer_linear_warmup':
|
||||
if step == 0:
|
||||
for name, child in (model.named_children()):
|
||||
if name.find('batch_norm') == -1:
|
||||
for name, param in child.named_parameters():
|
||||
if name.find('batch_norm') == -1:
|
||||
param.requires_grad = False
|
||||
|
||||
for l in [model.center_edge_embedding, model.final_conv, model.tr_final_layer, model.rot_final_layer,
|
||||
model.final_edge_embedding, model.final_tp_tor, model.tor_bond_conv, model.tor_final_layer]:
|
||||
unfreeze_layer(l)
|
||||
|
||||
elif 0 < step <= args.num_conv_layers:
|
||||
unfreeze_layer(model.conv_layers[-step])
|
||||
|
||||
elif step == args.num_conv_layers + 1:
|
||||
for l in [model.lig_node_embedding, model.lig_edge_embedding, model.rec_node_embedding, model.rec_edge_embedding,
|
||||
model.rec_sigma_embedding, model.cross_edge_embedding, model.rec_emb_layers, model.lig_emb_layers]:
|
||||
unfreeze_layer(l)
|
||||
|
||||
if step == 0 or args.scheduler == 'layer_linear_warmup':
|
||||
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.w_decay)
|
||||
|
||||
scheduler_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=scheduler_mode, factor=0.7, patience=args.scheduler_patience, min_lr=args.lr / 100)
|
||||
if args.scheduler == 'plateau':
|
||||
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=scheduler_mode, factor=0.7,
|
||||
patience=args.scheduler_patience, min_lr=args.lr / 100)
|
||||
scheduler = scheduler_plateau
|
||||
elif args.scheduler == 'linear_warmup' or args.scheduler == 'layer_linear_warmup':
|
||||
if (args.scheduler == 'linear_warmup' and step < 1) or \
|
||||
(args.scheduler == 'layer_linear_warmup' and step <= args.num_conv_layers + 1):
|
||||
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_start_factor, end_factor=1.0,
|
||||
total_iters=args.warmup_dur)
|
||||
else:
|
||||
scheduler = scheduler_plateau
|
||||
else:
|
||||
print('No scheduler')
|
||||
scheduler = None
|
||||
@@ -83,63 +169,119 @@ def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
|
||||
return optimizer, scheduler
|
||||
|
||||
|
||||
def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False):
|
||||
if 'all_atoms' in args and args.all_atoms:
|
||||
model_class = AAScoreModel
|
||||
else:
|
||||
model_class = CGScoreModel
|
||||
def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False, old=False):
|
||||
|
||||
timestep_emb_func = get_timestep_embedding(
|
||||
embedding_type=args.embedding_type,
|
||||
embedding_type=args.embedding_type if 'embedding_type' in args else 'sinusoidal',
|
||||
embedding_dim=args.sigma_embed_dim,
|
||||
embedding_scale=args.embedding_scale)
|
||||
embedding_scale=args.embedding_scale if 'embedding_type' in args else 10000)
|
||||
|
||||
lm_embedding_type = None
|
||||
if args.esm_embeddings_path is not None: lm_embedding_type = 'esm'
|
||||
if old:
|
||||
if 'all_atoms' in args and args.all_atoms:
|
||||
model_class = AAOldModel
|
||||
else:
|
||||
model_class = CGOldModel
|
||||
|
||||
model = model_class(t_to_sigma=t_to_sigma,
|
||||
device=device,
|
||||
no_torsion=args.no_torsion,
|
||||
timestep_emb_func=timestep_emb_func,
|
||||
num_conv_layers=args.num_conv_layers,
|
||||
lig_max_radius=args.max_radius,
|
||||
scale_by_sigma=args.scale_by_sigma,
|
||||
sigma_embed_dim=args.sigma_embed_dim,
|
||||
ns=args.ns, nv=args.nv,
|
||||
distance_embed_dim=args.distance_embed_dim,
|
||||
cross_distance_embed_dim=args.cross_distance_embed_dim,
|
||||
batch_norm=not args.no_batch_norm,
|
||||
dropout=args.dropout,
|
||||
use_second_order_repr=args.use_second_order_repr,
|
||||
cross_max_distance=args.cross_max_distance,
|
||||
dynamic_max_cross=args.dynamic_max_cross,
|
||||
lm_embedding_type=lm_embedding_type,
|
||||
confidence_mode=confidence_mode,
|
||||
num_confidence_outputs=len(
|
||||
args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
|
||||
args.rmsd_classification_cutoff, list) else 1)
|
||||
lm_embedding_type = None
|
||||
if args.esm_embeddings_path is not None: lm_embedding_type = 'esm'
|
||||
|
||||
if device.type == 'cuda' and not no_parallel:
|
||||
model = model_class(t_to_sigma=t_to_sigma,
|
||||
device=device,
|
||||
no_torsion=args.no_torsion,
|
||||
timestep_emb_func=timestep_emb_func,
|
||||
num_conv_layers=args.num_conv_layers,
|
||||
lig_max_radius=args.max_radius,
|
||||
scale_by_sigma=args.scale_by_sigma,
|
||||
sigma_embed_dim=args.sigma_embed_dim,
|
||||
norm_by_sigma='norm_by_sigma' in args and args.norm_by_sigma,
|
||||
ns=args.ns, nv=args.nv,
|
||||
distance_embed_dim=args.distance_embed_dim,
|
||||
cross_distance_embed_dim=args.cross_distance_embed_dim,
|
||||
batch_norm=not args.no_batch_norm,
|
||||
dropout=args.dropout,
|
||||
use_second_order_repr=args.use_second_order_repr,
|
||||
cross_max_distance=args.cross_max_distance,
|
||||
dynamic_max_cross=args.dynamic_max_cross,
|
||||
smooth_edges=args.smooth_edges if "smooth_edges" in args else False,
|
||||
odd_parity=args.odd_parity if "odd_parity" in args else False,
|
||||
lm_embedding_type=lm_embedding_type,
|
||||
confidence_mode=confidence_mode,
|
||||
affinity_prediction=args.affinity_prediction if 'affinity_prediction' in args else False,
|
||||
parallel=args.parallel if "parallel" in args else 1,
|
||||
num_confidence_outputs=len(
|
||||
args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
|
||||
args.rmsd_classification_cutoff, list) else 1,
|
||||
parallel_aggregators=args.parallel_aggregators if "parallel_aggregators" in args else "",
|
||||
fixed_center_conv=not args.not_fixed_center_conv if "not_fixed_center_conv" in args else False,
|
||||
no_aminoacid_identities=args.no_aminoacid_identities if "no_aminoacid_identities" in args else False,
|
||||
include_miscellaneous_atoms=args.include_miscellaneous_atoms if hasattr(args, 'include_miscellaneous_atoms') else False,
|
||||
use_old_atom_encoder=args.use_old_atom_encoder if hasattr(args, 'use_old_atom_encoder') else True)
|
||||
|
||||
else:
|
||||
if 'all_atoms' in args and args.all_atoms:
|
||||
model_class = AAModel
|
||||
else:
|
||||
model_class = CGModel
|
||||
|
||||
lm_embedding_type = None
|
||||
if ('moad_esm_embeddings_path' in args and args.moad_esm_embeddings_path is not None) or \
|
||||
('pdbbind_esm_embeddings_path' in args and args.pdbbind_esm_embeddings_path is not None) or \
|
||||
('pdbsidechain_esm_embeddings_path' in args and args.pdbsidechain_esm_embeddings_path is not None) or \
|
||||
('esm_embeddings_path' in args and args.esm_embeddings_path is not None):
|
||||
lm_embedding_type = 'precomputed'
|
||||
if 'esm_embeddings_model' in args and args.esm_embeddings_model is not None: lm_embedding_type = args.esm_embeddings_model
|
||||
|
||||
model = model_class(t_to_sigma=t_to_sigma,
|
||||
device=device,
|
||||
no_torsion=args.no_torsion,
|
||||
timestep_emb_func=timestep_emb_func,
|
||||
num_conv_layers=args.num_conv_layers,
|
||||
lig_max_radius=args.max_radius,
|
||||
scale_by_sigma=args.scale_by_sigma,
|
||||
sigma_embed_dim=args.sigma_embed_dim,
|
||||
norm_by_sigma='norm_by_sigma' in args and args.norm_by_sigma,
|
||||
ns=args.ns, nv=args.nv,
|
||||
distance_embed_dim=args.distance_embed_dim,
|
||||
cross_distance_embed_dim=args.cross_distance_embed_dim,
|
||||
batch_norm=not args.no_batch_norm,
|
||||
dropout=args.dropout,
|
||||
use_second_order_repr=args.use_second_order_repr,
|
||||
cross_max_distance=args.cross_max_distance,
|
||||
dynamic_max_cross=args.dynamic_max_cross,
|
||||
smooth_edges=args.smooth_edges if "smooth_edges" in args else False,
|
||||
odd_parity=args.odd_parity if "odd_parity" in args else False,
|
||||
lm_embedding_type=lm_embedding_type,
|
||||
confidence_mode=confidence_mode,
|
||||
affinity_prediction=args.affinity_prediction if 'affinity_prediction' in args else False,
|
||||
parallel=args.parallel if "parallel" in args else 1,
|
||||
num_confidence_outputs=len(
|
||||
args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
|
||||
args.rmsd_classification_cutoff, list) else 1,
|
||||
atom_num_confidence_outputs=len(
|
||||
args.atom_rmsd_classification_cutoff) + 1 if 'atom_rmsd_classification_cutoff' in args and isinstance(
|
||||
args.atom_rmsd_classification_cutoff, list) else 1,
|
||||
parallel_aggregators=args.parallel_aggregators if "parallel_aggregators" in args else "",
|
||||
fixed_center_conv=not args.not_fixed_center_conv if "not_fixed_center_conv" in args else False,
|
||||
no_aminoacid_identities=args.no_aminoacid_identities if "no_aminoacid_identities" in args else False,
|
||||
include_miscellaneous_atoms=args.include_miscellaneous_atoms if hasattr(args, 'include_miscellaneous_atoms') else False,
|
||||
sh_lmax=args.sh_lmax if 'sh_lmax' in args else 2,
|
||||
differentiate_convolutions=not args.no_differentiate_convolutions if "no_differentiate_convolutions" in args else True,
|
||||
tp_weights_layers=args.tp_weights_layers if "tp_weights_layers" in args else 2,
|
||||
num_prot_emb_layers=args.num_prot_emb_layers if "num_prot_emb_layers" in args else 0,
|
||||
reduce_pseudoscalars=args.reduce_pseudoscalars if "reduce_pseudoscalars" in args else False,
|
||||
embed_also_ligand=args.embed_also_ligand if "embed_also_ligand" in args else False,
|
||||
atom_confidence=args.atom_confidence_loss_weight > 0.0 if "atom_confidence_loss_weight" in args else False,
|
||||
sidechain_pred=(hasattr(args, 'sidechain_loss_weight') and args.sidechain_loss_weight > 0) or
|
||||
(hasattr(args, 'backbone_loss_weight') and args.backbone_loss_weight > 0),
|
||||
depthwise_convolution=args.depthwise_convolution if hasattr(args, 'depthwise_convolution') else False)
|
||||
|
||||
if device.type == 'cuda' and not no_parallel and ('dataset' not in args or not args.dataset == 'torsional'):
|
||||
model = DataParallel(model)
|
||||
model.to(device)
|
||||
return model
|
||||
|
||||
|
||||
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None):
|
||||
with time_limit(10):
|
||||
mol = molecule.Molecule.from_rdkit(mol)
|
||||
mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2
|
||||
mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums
|
||||
mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix
|
||||
RMSD = rmsd.symmrmsd(
|
||||
coords1,
|
||||
coords2,
|
||||
mol.atomicnums,
|
||||
mol2_atomicnums,
|
||||
mol.adjacency_matrix,
|
||||
mol2_adjacency_matrix,
|
||||
)
|
||||
return RMSD
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
class TimeoutException(Exception): pass
|
||||
@@ -241,3 +383,31 @@ class ExponentialMovingAverage:
|
||||
self.decay = state_dict['decay']
|
||||
self.num_updates = state_dict['num_updates']
|
||||
self.shadow_params = [tensor.to(device) for tensor in state_dict['shadow_params']]
|
||||
|
||||
|
||||
def crop_beyond(complex_graph, cutoff, all_atoms):
|
||||
ligand_pos = complex_graph['ligand'].pos
|
||||
receptor_pos = complex_graph['receptor'].pos
|
||||
residues_to_keep = torch.any(torch.sum((ligand_pos.unsqueeze(0) - receptor_pos.unsqueeze(1)) ** 2, -1) < cutoff ** 2, dim=1)
|
||||
|
||||
if all_atoms:
|
||||
#print(complex_graph['atom'].x.shape, complex_graph['atom'].pos.shape, complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index.shape)
|
||||
atom_to_res_mapping = complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index[1]
|
||||
atoms_to_keep = residues_to_keep[atom_to_res_mapping]
|
||||
rec_remapper = (torch.cumsum(residues_to_keep.long(), dim=0) - 1)
|
||||
atom_to_res_new_mapping = rec_remapper[atom_to_res_mapping][atoms_to_keep]
|
||||
atom_res_edge_index = torch.stack([torch.arange(len(atom_to_res_new_mapping), device=atom_to_res_new_mapping.device), atom_to_res_new_mapping])
|
||||
|
||||
complex_graph['receptor'].pos = complex_graph['receptor'].pos[residues_to_keep]
|
||||
complex_graph['receptor'].x = complex_graph['receptor'].x[residues_to_keep]
|
||||
complex_graph['receptor'].side_chain_vecs = complex_graph['receptor'].side_chain_vecs[residues_to_keep]
|
||||
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
|
||||
subgraph(residues_to_keep, complex_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
|
||||
|
||||
if all_atoms:
|
||||
complex_graph['atom'].x = complex_graph['atom'].x[atoms_to_keep]
|
||||
complex_graph['atom'].pos = complex_graph['atom'].pos[atoms_to_keep]
|
||||
complex_graph['atom', 'atom_contact', 'atom'].edge_index = subgraph(atoms_to_keep, complex_graph['atom', 'atom_contact', 'atom'].edge_index, relabel_nodes=True)[0]
|
||||
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
|
||||
|
||||
#print("cropped", 1-torch.mean(residues_to_keep.float()), 'residues', 1-torch.mean(atoms_to_keep.float()), 'atoms')
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
## Visualizations of complexes that were unseen during training. EquiBind (cyan), DockDiff highest confidence sample (red), all other DockDiff samples (orange), and the crystal structure (green).
|
||||
|
||||
Complex 6agt:
|
||||

|
||||
|
||||
Complex 6dz3:
|
||||

|
||||
|
||||
Complex 6gdy:
|
||||

|
||||
|
||||
Complex 6ckl:
|
||||

|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 16 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 14 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 23 MiB |
Binary file not shown.
|
Before Width: | Height: | Size: 12 MiB |
Binary file not shown.
@@ -1,85 +0,0 @@
|
||||
all_atoms: true
|
||||
atom_max_neighbors: 8
|
||||
atom_radius: 5
|
||||
balance: false
|
||||
batch_size: 16
|
||||
best_model_save_frequency: 5
|
||||
c_alpha_max_neighbors: 24
|
||||
cache_creation_id: 1
|
||||
cache_ids_to_combine:
|
||||
- '1'
|
||||
- '2'
|
||||
- '3'
|
||||
- '4'
|
||||
cache_path: data/cache
|
||||
ckpt: best_model.pt
|
||||
confidence_dropout: 0.0
|
||||
confidence_loss_weigth: 1
|
||||
confidence_no_batchnorm: false
|
||||
confidence_weight: 0.33
|
||||
config: null
|
||||
cross_distance_embed_dim: 32
|
||||
cross_max_distance: 80
|
||||
data_dir: data/PDBBind_processed/
|
||||
distance_embed_dim: 32
|
||||
dropout: 0.1
|
||||
dynamic_max_cross: true
|
||||
embedding_scale: 10000
|
||||
embedding_type: sinusoidal
|
||||
esm_embeddings_path: data/esm2_3billion_embeddings.pt
|
||||
high_confidence_threshold: 5.0
|
||||
include_confidence_prediction: false
|
||||
inference_steps: 20
|
||||
limit_complexes: 0
|
||||
lm_embeddings_path: null
|
||||
log_dir: workdir
|
||||
lr: 0.0003
|
||||
main_metric: loss
|
||||
main_metric_goal: min
|
||||
matching_maxiter: 20
|
||||
matching_popsize: 20
|
||||
max_lig_size: null
|
||||
max_radius: 5.0
|
||||
model_save_frequency: 0
|
||||
n_epochs: 100
|
||||
no_batch_norm: false
|
||||
no_torsion: false
|
||||
ns: 24
|
||||
num_conformers: 1
|
||||
num_conv_layers: 5
|
||||
num_workers: 1
|
||||
nv: 6
|
||||
original_model_dir: workdir/temp_restart_ema_ESM2emb_tr34
|
||||
project: diffdock_confidence
|
||||
receptor_radius: 15.0
|
||||
remove_hs: true
|
||||
restart_dir: null
|
||||
rmsd_classification_cutoff:
|
||||
- 2.0
|
||||
rmsd_prediction: false
|
||||
rot_sigma_max: 1.55
|
||||
rot_sigma_min: 0.03
|
||||
rot_weight: 0.33
|
||||
run_name: confidencetrain_samples28_FILTERFROM_ema_ESM2emb_tr34
|
||||
samples_per_complex: 7
|
||||
scale_by_sigma: true
|
||||
scheduler: plateau
|
||||
scheduler_patience: 50
|
||||
sigma_embed_dim: 32
|
||||
split_test: data/splits/timesplit_test
|
||||
split_train: data/splits/timesplit_no_lig_overlap_train
|
||||
split_val: data/splits/timesplit_no_lig_overlap_val
|
||||
tor_sigma_max: 3.14
|
||||
tor_sigma_min: 0.0314
|
||||
tor_sigma_schedule: expbeta
|
||||
tor_weight: 0.33
|
||||
tr_only_confidence: true
|
||||
tr_sigma_max: 34.0
|
||||
tr_sigma_min: 0.1
|
||||
tr_weight: 0.33
|
||||
train_sampling: linear
|
||||
transfer_weights: false
|
||||
use_original_model_cache: false
|
||||
use_second_order_repr: false
|
||||
w_decay: 0.0
|
||||
wandb: true
|
||||
Binary file not shown.
@@ -1,83 +0,0 @@
|
||||
all_atoms: false
|
||||
atom_max_neighbors: 8
|
||||
atom_radius: 5
|
||||
batch_size: 16
|
||||
c_alpha_max_neighbors: 24
|
||||
cache_path: data/cacheNew
|
||||
confidence_dropout: 0.0
|
||||
confidence_no_batchnorm: false
|
||||
config: null
|
||||
cross_distance_embed_dim: 64
|
||||
cross_max_distance: 80
|
||||
cudnn_benchmark: true
|
||||
data_dir: data/PDBBind_processed/
|
||||
dataset: pdbbind
|
||||
distance_embed_dim: 64
|
||||
dropout: 0.1
|
||||
dynamic_max_cross: true
|
||||
ema_rate: 0.999
|
||||
embedding_scale: 10000
|
||||
embedding_type: sinusoidal
|
||||
esm_embeddings_path: data/esm2_3billion_embeddings.pt
|
||||
high_confidence_threshold: 5.0
|
||||
include_confidence_prediction: false
|
||||
inf_pocket_cutoff: 5
|
||||
inf_pocket_knowledge: false
|
||||
inference_earlystop_goal: max
|
||||
inference_earlystop_metric: valinf_rmsds_lt2
|
||||
inference_steps: 20
|
||||
limit_complexes: 0
|
||||
lm_embeddings_path: null
|
||||
log_dir: workdir
|
||||
lr: 0.001
|
||||
matching_maxiter: 20
|
||||
matching_popsize: 20
|
||||
max_lig_size: null
|
||||
max_radius: 5.0
|
||||
multiplicity: 1
|
||||
n_epochs: 850
|
||||
no_batch_norm: false
|
||||
no_torsion: false
|
||||
norm_by_sigma: false
|
||||
not_full_dataset: false
|
||||
ns: 48
|
||||
num_conformers: 1
|
||||
num_conv_layers: 6
|
||||
num_dataloader_workers: 1
|
||||
num_gpus: 1
|
||||
num_inference_complexes: 500
|
||||
num_workers: 1
|
||||
nv: 10
|
||||
odd_parity: false
|
||||
pin_memory: true
|
||||
pretrained_model: null
|
||||
project: diffdock_train
|
||||
receptor_radius: 15.0
|
||||
remove_hs: true
|
||||
restart_dir: null
|
||||
rot_sigma_max: 1.55
|
||||
rot_sigma_min: 0.03
|
||||
rot_weight: 0.33
|
||||
run_name: big_ema_ESM2emb
|
||||
scale_by_sigma: true
|
||||
scheduler: plateau
|
||||
scheduler_patience: 30
|
||||
sigma_embed_dim: 64
|
||||
split_test: data/splits/timesplit_test
|
||||
split_train: data/splits/timesplit_no_lig_overlap_train
|
||||
split_val: data/splits/timesplit_no_lig_overlap_val
|
||||
test_sigma_intervals: true
|
||||
tor_sigma_max: 3.14
|
||||
tor_sigma_min: 0.0314
|
||||
tor_weight: 0.33
|
||||
tr_only_confidence: true
|
||||
tr_sigma_max: 19.0
|
||||
tr_sigma_min: 0.1
|
||||
tr_weight: 0.33
|
||||
train_inference_freq: null
|
||||
train_sampling: linear
|
||||
use_ema: true
|
||||
use_second_order_repr: false
|
||||
val_inference_freq: 5
|
||||
w_decay: 0.0
|
||||
wandb: true
|
||||
Reference in New Issue
Block a user