first commit v1.1

Gabriele Corso
2024-02-28 11:21:46 -05:00
parent bc6b515145
commit 001c4fa46e
83 changed files with 236826 additions and 4668 deletions

218
README.md

@@ -1,52 +1,75 @@
# DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/diffdock-diffusion-steps-twists-and-turns-for/blind-docking-on-pdbbind)](https://paperswithcode.com/sota/blind-docking-on-pdbbind?p=diffdock-diffusion-steps-twists-and-turns-for)
[![Open in HuggingFace](https://huggingface.co/datasets/huggingface/badges/raw/main/open-in-hf-spaces-sm.svg)](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web)
### [Paper on arXiv](https://arxiv.org/abs/2210.01776)
![Alt Text](overview.png)
### [Original paper on arXiv](https://arxiv.org/abs/2210.01776)
Implementation of DiffDock, a state-of-the-art method for molecular docking, by Gabriele Corso*, Hannes Stark*, Bowen Jing*, Regina Barzilay and Tommi Jaakkola.
This repository contains all code, instructions and model weights necessary to run the method or to retrain a model.
This repository contains code and instructions to run the method.
If you have any questions, feel free to open an issue or reach out to us: [gcorso@mit.edu](gcorso@mit.edu), [hstark@mit.edu](hstark@mit.edu), [bjing@mit.edu](bjing@mit.edu).
![Alt Text](visualizations/overview.png)
**Update February 2024:** We have released DiffDock-L, a new version of DiffDock that significantly improves performance and generalization capacity (see the description of the new version in [our new paper]()). By default, the repository now runs the new model; to run the original DiffDock model, please use the GitHub commit history. We also now provide instructions for Docker and for setting up your own local UI.
The repository also contains all the scripts to run the baselines and generate the figures.
Additionally, there are visualization videos in `visualizations`.
You might also be interested in this [Google Colab notebook](https://colab.research.google.com/drive/1CTtUGg05-2MtlWmfJhqzLTtkDDaxCDOQ#scrollTo=zlPOKLIBsiPU) to run DiffDock by Brian Naughton.
# Dataset
The files in `data` contain the names for the time-based data split.
If you want to train one of our models with the data then:
1. download it from [zenodo](https://zenodo.org/record/6408497)
2. unzip the directory and place it into `data` such that you have the path `data/PDBBind_processed`
You can also try out the model on [Hugging Face Spaces](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web).
## Setup Environment
<details><summary><b>Citation</b></summary>
We will set up the environment using [Anaconda](https://docs.anaconda.com/anaconda/install/index.html). Clone the
current repo
If you use this code or the models in your research, please cite the following paper:
git clone https://github.com/gcorso/DiffDock.git
```bibtex
@inproceedings{corso2023diffdock,
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
booktitle={International Conference on Learning Representations (ICLR)},
year={2023}
}
```
This is an example for how to set up a working conda environment to run the code (but make sure to use the correct pytorch, pytorch-geometric, cuda versions or cpu only versions):
If you use the latest version, DiffDock-L, please also cite the following paper:
conda create --name diffdock python=3.9
conda activate diffdock
conda install pytorch==1.11.0 pytorch-cuda=11.7 -c pytorch -c nvidia
pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric==2.0.4 -f https://data.pyg.org/whl/torch-1.11.0+cu117.html
python -m pip install PyYAML scipy "networkx[default]" biopython rdkit-pypi e3nn spyrmsd pandas biopandas
```bibtex
@inproceedings{corso2024discovery,
title={Deep Confident Steps to New Pockets: Strategies for Docking Generalization},
author={Corso, Gabriele and Deng, Arthur and Polizzi, Nicholas and Barzilay, Regina and Jaakkola, Tommi},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
Then you need to install ESM that we use both for protein sequence embeddings and for the protein structure prediction in case you only have the sequence of your target. Note that OpenFold (and so ESMFold) requires a GPU. If you don't have a GPU, you can still use DiffDock with existing protein structures.
```
</details>
pip install "fair-esm[esmfold]"
pip install 'dllogger @ git+https://github.com/NVIDIA/dllogger.git'
pip install 'openfold @ git+https://github.com/aqlaboratory/openfold.git@4b41059694619831a7db195b7e0988fc4ff3a307'
<details open><summary><b>Table of contents</b></summary>
- [Usage](#usage)
- [Quick Start](#quickstart)
- [Setup Environment](#environment)
- [Docking Prediction](#inference)
- [FAQ](#faq)
- [Datasets](#datasets)
- [Replicate results](#replicate)
- [Citations](#citations)
- [License](#license)
- [Acknowledgements](#acknowledgements)
</details>
# Running DiffDock on your own complexes
## Usage <a name="usage"></a>
### Quick Start <a name="quickstart"></a>
You can try out the model directly, without installing anything, through [Hugging Face Spaces](https://huggingface.co/spaces/reginabarzilaygroup/DiffDock-Web). Credit for the current HF interface goes to Jacob Silterra and for the previous version to Simon Duerr.
### Setup Environment <a name="environment"></a>
### Docking Prediction <a name="inference"></a>
We support multiple input formats depending on whether you only want to make predictions for a single complex or for many at once.\
The protein inputs need to be `.pdb` files or sequences that will be folded with ESMFold. The ligand input can either be a SMILES string or a filetype that RDKit can read like `.sdf` or `.mol2`.
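
As a rough illustration of the input rules above, the following sketch tells the two kinds of protein input and the two kinds of ligand input apart by file extension. The helper names and the extension-based check are assumptions made for this example; DiffDock's own input parsing may work differently.

```python
from pathlib import Path

# Hypothetical helpers, not part of DiffDock: they only mirror the README's
# description of accepted inputs.
LIGAND_FILE_SUFFIXES = {".sdf", ".mol2"}  # the two RDKit-readable examples named above


def ligand_input_kind(ligand: str) -> str:
    """Return 'file' for a .sdf/.mol2 path, otherwise treat the string as SMILES."""
    if Path(ligand).suffix.lower() in LIGAND_FILE_SUFFIXES:
        return "file"
    return "smiles"


def protein_input_kind(protein: str) -> str:
    """Return 'pdb_file' for a .pdb path, otherwise treat the string as a sequence."""
    if Path(protein).suffix.lower() == ".pdb":
        return "pdb_file"
    return "sequence"
```

A string classified as `sequence` is the case in which the protein would be folded with ESMFold, which per the setup notes requires a GPU.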
@@ -57,18 +80,60 @@ An example .csv is at `data/protein_ligand_example_csv.csv` and you would use it
And you are ready to run inference:
python -m inference --protein_ligand_csv data/protein_ligand_example_csv.csv --out_dir results/user_predictions_small --inference_steps 20 --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
python -m inference --config inference_args.yaml --protein_ligand_csv data/protein_ligand_example_csv.csv --out_dir results/user_predictions_small
When providing the `.pdb` files, you can also run DiffDock on CPU; however, if possible, we recommend using a GPU, as the model runs significantly faster. Note that the first time you run DiffDock on a device, the program will precompute and cache look-up tables for the SO(2) and SO(3) distributions (this typically takes a couple of minutes); this won't be repeated in subsequent runs.
# Retraining DiffDock
Download the data and place it as described in the "Dataset" section above.
## FAQ <a name="faq"></a>
<details>
<summary><b>How to interpret the DiffDock output confidence score?</b> </summary>
It can be hard to interpret and compare the confidence scores of different complexes or different protein conformations. However, here is a rough guideline that we typically use (c is the confidence score of the top pose):
- c > 0 high confidence
- -1.5 < c < 0 moderate confidence
- c < -1.5 low confidence
This assumes that the complex is similar to what DiffDock saw in the training set, i.e. a not-too-large drug-like molecule bound to a medium-size protein (1 or 2 chains) in a conformation that is similar to the bound one (e.g. if it comes from a homologue crystal structure). If you are dealing with a large ligand, a large protein complex and/or an apo/unbound protein conformation, you should shift these intervals down.
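
The guideline can be written down as a small helper. This is only a sketch: the function name and the `shift` parameter are hypothetical, the README gives no numeric value for how far to shift the intervals, and the assignment of the exact boundary values (c = 0 and c = -1.5) is an assumption.

```python
def confidence_bucket(c: float, shift: float = 0.0) -> str:
    """Map the top-pose confidence score c to a rough confidence label.

    shift (hypothetical parameter) lowers both thresholds, for inputs that are
    unlike the training set (large ligands, large complexes, apo conformations).
    """
    high_threshold = 0.0 - shift
    low_threshold = -1.5 - shift
    if c > high_threshold:
        return "high"
    if c > low_threshold:
        # Assumption: a score exactly at the upper threshold counts as moderate.
        return "moderate"
    # Assumption: a score exactly at the lower threshold counts as low.
    return "low"
```

For example, `confidence_bucket(-0.7)` returns `"moderate"`, while the same score with `shift=1.0` returns `"high"`.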
</details>
<details>
<summary><b>Does DiffDock predict the binding affinity of the ligand to the protein?</b> </summary>
No, DiffDock does not predict the binding affinity of the ligand to the protein. It predicts the 3D structure of the complex and outputs a confidence score. The latter is a measure of the quality of the prediction, i.e. the model's confidence in its prediction of the binding structure. Several of our collaborators have seen it to have some correlation with binding affinity (intuitively, if a ligand does not bind there will be no good pose), but it is not a direct measure of it.
We are working on better affinity prediction models, but in the meantime we recommend combining DiffDock's predictions with other tools such as a docking function (e.g. GNINA), MM/GBSA or absolute binding free energy calculations. For this, we recommend first relaxing DiffDock's structure predictions with the tool/force field used for the affinity prediction.
</details>
<details>
<summary><b>Can I use DiffDock for protein-protein or protein-nucleic acid interactions?</b> </summary>
While the program might not throw an error when fed large biomolecules as input, the model has only been designed, trained and tested for small molecule docking to proteins. Therefore, DiffDock is only likely to be able to deal with small peptides and nucleic acids as ligands, and we do not recommend using it for the interactions of larger biomolecules. For other interactions we recommend looking at [DiffDock-PP](https://github.com/ketatam/DiffDock-PP) (rigid protein-protein interactions), [AlphaFold-Multimer](https://github.com/google-deepmind/alphafold) (flexible protein-protein interactions) or [RoseTTAFold2NA](https://github.com/uw-ipd/RoseTTAFold2NA) (protein-nucleic acid interactions).
</details>
## Datasets <a name="datasets"></a>
The files in `data` contain the splits used for the various datasets. Below are instructions for downloading each of the datasets used for training and evaluation:
- **PDBBind:** download the processed complexes from [zenodo](https://zenodo.org/record/6408497), unzip the directory and place it into `data` such that you have the path `data/PDBBind_processed`.
- **BindingMOAD:** download the processed complexes from [zenodo](https://zenodo.org/records/10656052) under `BindingMOAD_2020_processed.tar`, unzip the directory and place it into `data` such that you have the path `data/BindingMOAD_2020_processed`.
- **DockGen:** to evaluate the performance of `DiffDock-L` with this repository, you should use the BindingMOAD data above directly. For other purposes, you can download only the complexes of the DockGen benchmark, already processed (e.g. chain cutoff), from [zenodo](https://zenodo.org/records/10656052) by downloading the `DockGen.tar` file.
- **PoseBusters:** download the processed complexes from [zenodo](https://zenodo.org/records/8278563).
- **van der Mers:** the protein structures used for the van der Mers data augmentation strategy were downloaded [here](https://files.ipd.uw.edu/pub/training_sets/pdb_2021aug02.tar.gz).
## Replicate results <a name="replicate"></a>
If you are interested in replicating the results of the original DiffDock paper, please check out the following version:
git checkout v1.0
Otherwise, download the data and place it as described in the "Datasets" section above.
### Generate the ESM2 embeddings for the proteins
First run:
To avoid having to compute ESM embeddings every time we evaluate on a dataset, we first cache them and then run the evaluation script. Here are the instructions for generating these for PDBBind; they apply similarly to the other benchmarks. First run the following command to save the list of ESM embeddings:
python datasets/pdbbind_lm_embedding_preparation.py
python datasets/esm_embedding_preparation.py
Use the generated file `data/pdbbind_sequences.fasta` to generate the ESM2 language model embeddings using the library https://github.com/facebookresearch/esm by installing their repository and executing the following in their repository:
@@ -79,65 +144,54 @@ Then run the command:
python datasets/esm_embeddings_to_pt.py
### Using the provided model weights for evaluation
We first generate the language model embeddings for the testset, then run inference with DiffDock, and then evaluate the files that DiffDock produced:
### Run DiffDock-L
python datasets/esm_embedding_preparation.py --protein_ligand_csv data/testset_csv.csv --out_file data/prepared_for_esm_testset.fasta
git clone https://github.com/facebookresearch/esm
cd esm
pip install -e .
cd ..
HOME=esm/model_weights python esm/scripts/extract.py esm2_t33_650M_UR50D data/prepared_for_esm_testset.fasta data/esm2_output --repr_layers 33 --include per_tok
python -m inference --protein_ligand_csv data/testset_csv.csv --out_dir results/user_predictions_testset --inference_steps 20 --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
python evaluate_files.py --results_path results/user_predictions_testset --file_to_exclude rank1.sdf --num_predictions 40
For PDBBind:
<!--
To predict binding structures using the provided model weights run:
python -m evaluate --config inference_args.yaml --split_path data/splits/timesplit_test --batch_size 10 --esm_embeddings_path data/esm2_embeddings.pt --data_dir data/PDBBind_processed/ --tqdm --split test --chain_cutoff 10 --dataset pdbbind
python -m evaluate --model_dir workdir/paper_score_model --ckpt best_ema_inference_epoch_model.pt --confidence_ckpt best_model_epoch75.pt --confidence_model_dir workdir/paper_confidence_model --run_name DiffDockInference --inference_steps 20 --split_path data/splits/timesplit_test --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
For DockGen:
To additionally save the .sdf files of the generated molecules, add the flag `--save_visualisation`
-->
### Training a model yourself and using those weights
Train the large score model:
python -m evaluate --config inference_args.yaml --dataset moad --data_dir data/BindingMOAD_2020_processed --unroll_clusters --tqdm --split test --esm_embeddings_path data/moad_esm2_embeddings.pt --min_ligand_size 2 --moad_esm_embeddings_sequences_path data/moad_sequences_to_id.fasta --chain_cutoff 10 --batch_size 10
python -m train --run_name big_score_model --test_sigma_intervals --esm_embeddings_path data/esm2_3billion_embeddings.pt --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 19 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --batch_size 16 --ns 48 --nv 10 --num_conv_layers 6 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --val_inference_freq 5 --num_inference_complexes 500 --use_ema --distance_embed_dim 64 --cross_distance_embed_dim 64 --sigma_embed_dim 64 --scheduler_patience 30 --n_epochs 850
For PoseBusters:
The model weights are saved in the `workdir` directory.
python -m evaluate --config inference_args.yaml --data_dir data/posebusters_benchmark_set --tqdm --dataset posebusters --split_path data/splits/posebusters_benchmark_set_ids.txt --esm_embeddings_path data/posebusters_ESM.pt --chain_cutoff 10 --batch_size 10 --protein_file protein --ligand_file ligands
Train a small score model with higher maximum translation sigma that will be used to generate the samples for training the confidence model:
To additionally save the .sdf files of the generated molecules, add the flag `--save_visualisation`.
python -m train --run_name small_score_model --test_sigma_intervals --esm_embeddings_path data/esm2_3billion_embeddings.pt --log_dir workdir --lr 1e-3 --tr_sigma_min 0.1 --tr_sigma_max 34 --rot_sigma_min 0.03 --rot_sigma_max 1.55 --batch_size 16 --ns 24 --nv 6 --num_conv_layers 5 --dynamic_max_cross --scheduler plateau --scale_by_sigma --dropout 0.1 --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --num_dataloader_workers 1 --cudnn_benchmark --val_inference_freq 5 --num_inference_complexes 500 --use_ema --scheduler_patience 30 --n_epochs 300
Note: the notebook `data/apo_alignment.ipynb` contains the code used to align the ESMFold-generated apo-structures to the holo-structures.
In practice, you could also likely achieve the same or better results by using the first score model for creating the samples to train the confidence model, but this is what we did in the paper.
The score model used to generate the samples to train the confidence model does not have to be the same as the score model that is used with that confidence model during inference.
## Citations <a name="citations"></a>
If you use this code or the models in your research, please cite the following paper:
Train the confidence model by running the following:
```bibtex
@inproceedings{corso2023diffdock,
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
booktitle={International Conference on Learning Representations (ICLR)},
year={2023}
}
```
python -m confidence.confidence_train --original_model_dir workdir/small_score_model --run_name confidence_model --inference_steps 20 --samples_per_complex 7 --batch_size 16 --n_epochs 100 --lr 3e-4 --scheduler_patience 50 --ns 24 --nv 6 --num_conv_layers 5 --dynamic_max_cross --scale_by_sigma --dropout 0.1 --all_atoms --remove_hs --c_alpha_max_neighbors 24 --receptor_radius 15 --esm_embeddings_path data/esm2_3billion_embeddings.pt --main_metric loss --main_metric_goal min --best_model_save_frequency 5 --rmsd_classification_cutoff 2 --cache_creation_id 1 --cache_ids_to_combine 1 2 3 4
If you use the latest version of our model, DiffDock-L, please also cite the following paper:
first with `--cache_creation_id 1` then `--cache_creation_id 2` etc. up to 4
```bibtex
@inproceedings{corso2024discovery,
title={Deep Confident Steps to New Pockets: Strategies for Docking Generalization},
author={Corso, Gabriele and Deng, Arthur and Polizzi, Nicholas and Barzilay, Regina and Jaakkola, Tommi},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}
```
Now everything is trained and you can run inference with:
## License <a name="license"></a>
The code and model weights are released under MIT license. See the [LICENSE](LICENSE) file for details.
python -m evaluate --model_dir workdir/big_score_model --ckpt best_ema_inference_epoch_model.pt --confidence_ckpt best_model_epoch75.pt --confidence_model_dir workdir/confidence_model --run_name DiffDockInference --inference_steps 20 --split_path data/splits/timesplit_test --samples_per_complex 40 --batch_size 10 --actual_steps 18 --no_final_step_noise
Components of the code of the [spyrmsd](spyrmsd) package by Rocco Meli (also MIT license) were integrated in the repo.
Note: the notebook `data/apo_alignment.ipynb` contains the code used to align the ESMFold-generated apo-structures to the holo-structures.
## Citation
@article{corso2023diffdock,
title={DiffDock: Diffusion Steps, Twists, and Turns for Molecular Docking},
author = {Corso, Gabriele and Stärk, Hannes and Jing, Bowen and Barzilay, Regina and Jaakkola, Tommi},
journal={International Conference on Learning Representations (ICLR)},
year={2023}
}
## License
MIT
## Acknowledgements
We thank Wei Lu and Rachel Wu for pointing out some issues with the code.
![Alt Text](visualizations/example_6agt_symmetric.gif)
## Acknowledgements <a name="acknowledgements"></a>
We sincerely thank:
* Jacob Silterra for his help with the publishing and deployment of the code.
* Arthur Deng, Nicholas Polizzi and Ben Fry for their critical contributions to part of the code in this repository.
* Wei Lu and Rachel Wu for pointing out some issues with the code.


@@ -1,219 +0,0 @@
# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as initial pose
import os
import plotly.express as px
import time
from argparse import FileType, ArgumentParser
import numpy as np
import pandas as pd
import wandb
from biopandas.pdb import PandasPdb
from rdkit import Chem
from tqdm import tqdm
from datasets.pdbbind import read_mol
from datasets.process_mols import read_molecule
from utils.utils import read_strings_from_txt, get_symmetry_rmsd
parser = ArgumentParser()
parser.add_argument('--config', type=FileType(mode='r'), default=None)
parser.add_argument('--run_name', type=str, default='gnina_results', help='')
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--results_path', type=str, default='results/user_inference', help='Path to folder with trained model and hyperparameters')
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand.pdb', help='Path to folder with trained model and hyperparameters')
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
parser.add_argument('--wandb', action='store_true', default=False, help='')
parser.add_argument('--file_to_exclude', type=str, default=None, help='')
parser.add_argument('--all_dirs_in_results', action='store_true', default=True, help='Evaluate all directories in the results path instead of using directly looking for the names')
parser.add_argument('--num_predictions', type=int, default=10, help='')
parser.add_argument('--no_id_in_filename', action='store_true', default=False, help='')
args = parser.parse_args()
print('Reading paths and names.')
names = read_strings_from_txt(f'data/splits/timesplit_test')
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
results_path_containments = os.listdir(args.results_path)
if args.wandb:
    wandb.init(
        entity='coarse-graining-mit',
        settings=wandb.Settings(start_method="fork"),
        project=args.project,
        name=args.run_name,
        config=args
    )
all_times = []
successful_names_list = []
rmsds_list = []
centroid_distances_list = []
min_cross_distances_list = []
min_self_distances_list = []
without_rec_overlap_list = []
start_time = time.time()
for i, name in enumerate(tqdm(names)):
    mol = read_mol(args.data_dir, name, remove_hs=True)
    mol = Chem.RemoveAllHs(mol)
    orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
    if args.all_dirs_in_results:
        directory_with_name = [directory for directory in results_path_containments if name in directory][0]
        ligand_pos = []
        for i in range(args.num_predictions):
            file_paths = os.listdir(os.path.join(args.results_path, directory_with_name))
            file_path = [path for path in file_paths if f'rank{i+1}' in path][0]
            if args.file_to_exclude is not None and args.file_to_exclude in file_path: continue
            mol_pred = read_molecule(os.path.join(args.results_path, directory_with_name, file_path),remove_hs=True, sanitize=True)
            mol_pred = Chem.RemoveAllHs(mol_pred)
            ligand_pos.append(mol_pred.GetConformer().GetPositions())
        ligand_pos = np.asarray(ligand_pos)
    else:
        if not os.path.exists(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}')): raise Exception('path did not exists:', os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'))
        mol_pred = read_molecule(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'), remove_hs=True, sanitize=True)
        if mol_pred == None:
            print("Skipping ", name, ' because RDKIT could not read it.')
            continue
        mol_pred = Chem.RemoveAllHs(mol_pred)
        ligand_pos = np.asarray([np.array(mol_pred.GetConformer(i).GetPositions()) for i in range(args.num_predictions)])
    try:
        rmsd = get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred)
    except Exception as e:
        print("Using non corrected RMSD because of the error:", e)
        rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
    rmsds_list.append(rmsd)
    centroid_distances_list.append(np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1))
    rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
    if not os.path.exists(rec_path):
        rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
    rec = PandasPdb().read_pdb(rec_path)
    rec_df = rec.df['ATOM']
    receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
    receptor_pos = np.tile(receptor_pos, (args.num_predictions, 1, 1))
    cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
    self_distances = np.linalg.norm(ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
    self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
    min_cross_distances_list.append(np.min(cross_distances, axis=(1,2)))
    min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
    successful_names_list.append(name)
    without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
performance_metrics = {}
for overlap in ['', 'no_overlap_']:
    if 'no_overlap_' == overlap:
        without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
        rmsds = np.array(rmsds_list)[without_rec_overlap]
        centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
        min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
        min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
        successful_names = np.array(successful_names_list)[without_rec_overlap]
    else:
        rmsds = np.array(rmsds_list)
        centroid_distances = np.array(centroid_distances_list)
        min_cross_distances = np.array(min_cross_distances_list)
        min_self_distances = np.array(min_self_distances_list)
        successful_names = np.array(successful_names_list)
    np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
    np.save(os.path.join(args.results_path, f'{overlap}names.npy'), successful_names)
    np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), np.array(min_cross_distances))
    np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), np.array(min_self_distances))
    performance_metrics.update({
        f'{overlap}steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / args.num_predictions).__round__(2),
        f'{overlap}self_intersect_fraction': (100 * (min_self_distances < 0.4).sum() / len(min_self_distances) / args.num_predictions).__round__(2),
        f'{overlap}mean_rmsd': rmsds[:,0].mean(),
        f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
        f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
        f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
        f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
        f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
        f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
        f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
        f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
        f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
        f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
        f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
    })
    top5_rmsds = np.min(rmsds[:, :5], axis=1)
    top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
    top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
    top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
    performance_metrics.update({
        f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
        f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
        f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
        f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
        f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
        f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
        f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
        f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
        f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
        f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
        f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
        f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
    })
    top10_rmsds = np.min(rmsds[:, :10], axis=1)
    top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
    top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
    top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
    performance_metrics.update({
        f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
        f'{overlap}top10_steric_clash_fraction': ( 100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
        f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
        f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
        f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
        f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
        f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
        f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
        f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
        f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
        f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
        f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
    })
for k in performance_metrics:
    print(k, performance_metrics[k])
if args.wandb:
    wandb.log(performance_metrics)
    histogram_metrics_list = [('rmsd', rmsds[:,0]),
                              ('centroid_distance', centroid_distances[:,0]),
                              ('mean_rmsd', rmsds[:,0]),
                              ('mean_centroid_distance', centroid_distances[:,0])]
    histogram_metrics_list.append(('top5_rmsds', top5_rmsds))
    histogram_metrics_list.append(('top5_centroid_distances', top5_centroid_distances))
    histogram_metrics_list.append(('top10_rmsds', top10_rmsds))
    histogram_metrics_list.append(('top10_centroid_distances', top10_centroid_distances))
    os.makedirs(f'.plotly_cache/baseline_cache', exist_ok=True)
    images = []
    for metric_name, metric in histogram_metrics_list:
        d = {args.results_path: metric}
        df = pd.DataFrame(data=d)
        fig = px.ecdf(df, width=900, height=600, range_x=[0, 40])
        fig.add_vline(x=2, annotation_text='2 A;', annotation_font_size=20, annotation_position="top right",
                      line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
        fig.add_vline(x=5, annotation_text='5 A;', annotation_font_size=20, annotation_position="top right",
                      line_dash='dash', line_color='green', annotation_font_color='green')
        fig.update_xaxes(title=f'{metric_name} in Angstrom', title_font={"size": 20}, tickfont={"size": 20})
        fig.update_yaxes(title=f'Fraction of predictions with lower error', title_font={"size": 20},
                         tickfont={"size": 20})
        fig.update_layout(autosize=False, margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, plot_bgcolor='white',
                          paper_bgcolor='white', legend_title_text='Method', legend_title_font_size=17,
                          legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ), )
        fig.update_xaxes(showgrid=True, gridcolor='lightgrey')
        fig.update_yaxes(showgrid=True, gridcolor='lightgrey')
        fig.write_image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'))
        wandb.log({metric_name: wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}")})
        images.append(wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}"))
    wandb.log({'images': images})

@@ -1,175 +0,0 @@
# small script to extract the ligand and save it in a separate file because GNINA will use the ligand position as
# initial pose
import os
import shutil
import subprocess
import sys
import time
from argparse import ArgumentParser, FileType
from datetime import datetime
import numpy as np
import pandas as pd
from biopandas.pdb import PandasPdb
from rdkit import Chem
from rdkit.Chem import AllChem, MolToPDBFile
from scipy.spatial.distance import cdist
from datasets.pdbbind import read_mol
from utils.utils import read_strings_from_txt
parser = ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand', help='Suffix appended to the complex name in the prediction and log filenames')
parser.add_argument('--results_path', type=str, default='results/gnina_predictions', help='')
parser.add_argument('--complex_names_path', type=str, default='data/splits/timesplit_test', help='')
parser.add_argument('--seed_molecules_path', type=str, default=None, help='Use the molecules at seed molecule path as initialization and only search around them')
parser.add_argument('--seed_molecule_filename', type=str, default='equibind_corrected.sdf', help='Use the molecules at seed molecule path as initialization and only search around them')
parser.add_argument('--smina', action='store_true', default=False, help='')
parser.add_argument('--no_gpu', action='store_true', default=False, help='')
parser.add_argument('--exhaustiveness', type=int, default=8, help='')
parser.add_argument('--num_cpu', type=int, default=16, help='')
parser.add_argument('--pocket_mode', action='store_true', default=False, help='')
parser.add_argument('--pocket_cutoff', type=int, default=5, help='')
parser.add_argument('--num_modes', type=int, default=10, help='')
parser.add_argument('--autobox_add', type=int, default=4, help='')
parser.add_argument('--use_p2rank_pocket', action='store_true', default=False, help='')
parser.add_argument('--skip_p2rank', action='store_true', default=False, help='')
parser.add_argument('--prank_path', type=str, default='/Users/hstark/projects/p2rank_2.3/prank', help='')
parser.add_argument('--skip_existing', action='store_true', default=False, help='')
args = parser.parse_args()
class Logger(object):
    def __init__(self, logpath, syspart=sys.stdout):
        self.terminal = syspart
        self.log = open(logpath, "a")
    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)
        self.log.flush()
    def flush(self):
        # this flush method is needed for python 3 compatibility.
        # this handles the flush command by doing nothing.
        # you might want to specify some extra behavior here.
        pass
def log(*args):
    print(f'[{datetime.now()}]', *args)
# parameters
names = read_strings_from_txt(args.complex_names_path)
if os.path.exists(args.results_path) and not args.skip_existing:
shutil.rmtree(args.results_path)
os.makedirs(args.results_path, exist_ok=True)
sys.stdout = Logger(logpath=f'{args.results_path}/gnina.log', syspart=sys.stdout)
sys.stderr = Logger(logpath=f'{args.results_path}/error.log', syspart=sys.stderr)
p2rank_cache_path = "results/.p2rank_cache"
if args.use_p2rank_pocket and not args.skip_p2rank:
os.makedirs(p2rank_cache_path, exist_ok=True)
pdb_files_cache = os.path.join(p2rank_cache_path,'pdb_files')
os.makedirs(pdb_files_cache, exist_ok=True)
with open(f"{p2rank_cache_path}/pdb_list_p2rank.txt", "w") as out:
for name in names:
shutil.copy(os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb'), f'{pdb_files_cache}/{name}_protein_processed.pdb')
out.write(os.path.join('pdb_files', f'{name}_protein_processed.pdb\n'))
cmd = f"bash {args.prank_path} predict {p2rank_cache_path}/pdb_list_p2rank.txt -o {p2rank_cache_path}/p2rank_output -threads 4"
os.system(cmd)
all_times = []
start_time = time.time()
for i, name in enumerate(names):
os.makedirs(os.path.join(args.results_path, name), exist_ok=True)
log('\n')
log(f'complex {i} of {len(names)}')
# call gnina to find binding pose
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
prediction_output_name = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.pdb')
log_path = os.path.join(args.results_path, name, f'{name}{args.file_suffix}.log')
if args.seed_molecules_path is not None: seed_mol_path = os.path.join(args.seed_molecules_path, name, f'{args.seed_molecule_filename}')
if args.skip_existing and os.path.exists(prediction_output_name): continue
if args.pocket_mode:
mol = read_mol(args.data_dir, name, remove_hs=False)
rec = PandasPdb().read_pdb(rec_path)
rec_df = rec.get(s='c-alpha')
rec_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
lig_pos = mol.GetConformer().GetPositions()
d = cdist(rec_pos, lig_pos)
label = np.any(d < args.pocket_cutoff, axis=1)
if np.any(label):
center_pocket = rec_pos[label].mean(axis=0)
else:
print("No pocket residue below minimum distance ", args.pocket_cutoff, "taking closest at", np.min(d))
center_pocket = rec_pos[np.argmin(np.min(d, axis=1))]
radius_pocket = np.max(np.linalg.norm(lig_pos - center_pocket[None, :], axis=1))
diameter_pocket = radius_pocket * 2
center_x = center_pocket[0]
size_x = diameter_pocket + 8
center_y = center_pocket[1]
size_y = diameter_pocket + 8
center_z = center_pocket[2]
size_z = diameter_pocket + 8
mol_rdkit = read_mol(args.data_dir, name, remove_hs=False)
single_time = time.time()
mol_rdkit.RemoveAllConformers()
ps = AllChem.ETKDGv2()
id = AllChem.EmbedMolecule(mol_rdkit, ps)
if id == -1:
print('rdkit pos could not be generated without using random pos. using random pos now.')
ps.useRandomCoords = True
AllChem.EmbedMolecule(mol_rdkit, ps)
AllChem.MMFFOptimizeMolecule(mol_rdkit, confId=0)
rdkit_mol_path = os.path.join(args.data_dir, name, f'{name}_rdkit_ligand.pdb')
MolToPDBFile(mol_rdkit, rdkit_mol_path)
fallback_without_p2rank = False
if args.use_p2rank_pocket:
df = pd.read_csv(f'{p2rank_cache_path}/p2rank_output/{name}_protein_processed.pdb_predictions.csv')
rdkit_lig_pos = mol_rdkit.GetConformer().GetPositions()
diameter_pocket = np.max(cdist(rdkit_lig_pos, rdkit_lig_pos))
size_x = diameter_pocket + args.autobox_add * 2
size_y = diameter_pocket + args.autobox_add * 2
size_z = diameter_pocket + args.autobox_add * 2
if df.empty:
fallback_without_p2rank = True
else:
center_x = df.iloc[0][' center_x']
center_y = df.iloc[0][' center_y']
center_z = df.iloc[0][' center_z']
log(f'processing {rec_path}')
if not args.pocket_mode and not args.use_p2rank_pocket or fallback_without_p2rank:
return_code = subprocess.run(
f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} -o {prediction_output_name} {'--no_gpu' if args.no_gpu else ''} --autobox_ligand {rec_path if args.seed_molecules_path is None else seed_mol_path} --autobox_add {args.autobox_add} --log {log_path} --exhaustiveness {args.exhaustiveness} --cpu {args.num_cpu} {'--cnn_scoring none' if args.smina else ''}",
shell=True)
else:
return_code = subprocess.run(
f"gnina --receptor {rec_path} --ligand {rdkit_mol_path} --num_modes {args.num_modes} -o {prediction_output_name} {'--no_gpu' if args.no_gpu else ''} --log {log_path} --exhaustiveness {args.exhaustiveness} --cpu {args.num_cpu} {'--cnn_scoring none' if args.smina else ''} --center_x {center_x} --center_y {center_y} --center_z {center_z} --size_x {size_x} --size_y {size_y} --size_z {size_z}",
shell=True)
log(return_code)
all_times.append(time.time() - single_time)
log("single time: --- %s seconds ---" % (time.time() - single_time))
log("time so far: --- %s seconds ---" % (time.time() - start_time))
log('\n')
log(all_times)
log("--- %s seconds ---" % (time.time() - start_time))
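
The `--pocket_mode` branch of the script above derives a cubic search box from the known ligand position: the box centre is the mean of the receptor C-alpha atoms within `--pocket_cutoff` of any ligand atom (falling back to the single closest atom when none qualifies), and each side is the pocket diameter plus 8 Å. A minimal sketch of that geometry follows. It uses only the standard library instead of NumPy and SciPy, and the names `pocket_box`, `cutoff`, and `padding` are illustrative assumptions, not part of the repository.

```python
import math


def pocket_box(rec_pos, lig_pos, cutoff=5.0, padding=8.0):
    """Return (center, size) of a cubic box around the ligand pocket.

    rec_pos and lig_pos are sequences of (x, y, z) coordinates.
    """
    # Minimum distance from each receptor atom to any ligand atom.
    min_dist = [min(math.dist(r, l) for l in lig_pos) for r in rec_pos]
    pocket = [r for r, d in zip(rec_pos, min_dist) if d < cutoff]
    if pocket:
        # Pocket centre: mean of the receptor atoms within the cutoff.
        center = tuple(sum(axis) / len(pocket) for axis in zip(*pocket))
    else:
        # Fallback: the receptor atom closest to the ligand.
        center = tuple(rec_pos[min_dist.index(min(min_dist))])
    # Radius: farthest ligand atom from the chosen centre.
    radius = max(math.dist(l, center) for l in lig_pos)
    return center, 2 * radius + padding


center, size = pocket_box([(0, 0, 0), (10, 0, 0)], [(1, 0, 0), (3, 0, 0)])
```

The returned centre and size correspond to the `--center_*` and `--size_*` values that the script passes to `gnina`.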

@@ -1,5 +0,0 @@
for i in $(seq 0 15); do
python baseline_tankbind_runtime.py --parallel_id $i --parallel_tot 16 --prank_path /data/rsg/nlp/hstark/TankBind/packages/p2rank_2.3/prank --data_dir /data/rsg/nlp/hstark/ligbind/data/PDBBind_processed --split_path /data/rsg/nlp/hstark/ligbind/data/splits/timesplit_test --results_path /data/rsg/nlp/hstark/ligbind/results/tankbind_16_worker_runtime --device cpu --skip_p2rank --num_workers 1 --skip_multiple_pocket_outputs &
done
wait
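
The loop above starts 16 background workers, each with its own `--parallel_id`. The runtime script then slices the list of complex names so that every worker handles one contiguous chunk. The sketch below reproduces that slicing rule in isolation; the function name `shard` is a hypothetical helper introduced only for illustration.

```python
def shard(names, parallel_id, parallel_tot):
    """Return the contiguous chunk of names assigned to one worker."""
    if parallel_tot <= 1:
        return names
    # Chunk size is rounded up by always adding one, so the last
    # workers may receive fewer items, or none at all.
    size = len(names) // parallel_tot + 1
    return names[parallel_id * size:(parallel_id + 1) * size]


names = list(range(10))
chunks = [shard(names, i, 4) for i in range(4)]
```

Because the chunk size is `len(names) // parallel_tot + 1`, the chunks always cover the full list, but the work is not perfectly balanced: with 16 names and 16 workers each chunk holds two names and the last eight workers get nothing.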

@@ -1,239 +0,0 @@
import copy
import os
import plotly.express as px
import time
from argparse import FileType, ArgumentParser
import numpy as np
import pandas as pd
import wandb
from biopandas.pdb import PandasPdb
from rdkit import Chem
from rdkit.Chem import RemoveHs
from tqdm import tqdm
from datasets.pdbbind import read_mol
from datasets.process_mols import read_molecule, read_sdf_or_mol2
from utils.utils import read_strings_from_txt, get_symmetry_rmsd, remove_all_hs
parser = ArgumentParser()
parser.add_argument('--config', type=FileType(mode='r'), default=None)
parser.add_argument('--run_name', type=str, default='tankbind', help='')
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--renumbered_atoms_dir', type=str, default='../TankBind/examples/tankbind_pdb/renumber_atom_index_same_as_smiles', help='')
parser.add_argument('--results_path', type=str, default='results/tankbind_top5', help='Path to folder with the TANKBind predictions to evaluate')
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
parser.add_argument('--wandb', action='store_true', default=True, help='')
parser.add_argument('--num_predictions', type=int, default=5, help='')
args = parser.parse_args()
names = read_strings_from_txt(f'data/splits/timesplit_test')
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
if args.wandb:
wandb.init(
entity='coarse-graining-mit',
settings=wandb.Settings(start_method="fork"),
project=args.project,
name=args.run_name,
config=args
)
all_times = []
rmsds_list = []
unsym_rmsds_list = []
centroid_distances_list = []
min_cross_distances_list = []
min_self_distances_list = []
made_prediction_list = []
steric_clash_list = []
without_rec_overlap_list = []
start_time = time.time()
successful_names_list = []
for i, name in enumerate(tqdm(names)):
mol, _ = read_sdf_or_mol2(f"{args.renumbered_atoms_dir}/{name}.sdf", None)
sm = Chem.MolToSmiles(mol)
m_order = list(mol.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
mol = Chem.RenumberAtoms(mol, m_order)
mol = Chem.RemoveHs(mol)
orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
assert(os.path.exists(os.path.join(args.results_path, name, f'{name}_tankbind_0.sdf')))
ligand_pos = []
for i in range(args.num_predictions):
if not os.path.exists(os.path.join(args.results_path, name, f'{name}_tankbind_{i}.sdf')): break
mol_pred, _ = read_sdf_or_mol2(os.path.join(args.results_path, name, f'{name}_tankbind_{i}.sdf'),None)
sm = Chem.MolToSmiles(mol_pred)
m_order = list(mol_pred.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
mol_pred = Chem.RenumberAtoms(mol_pred, m_order)
mol_pred = RemoveHs(mol_pred)
ligand_pos.append(np.array(mol_pred.GetConformer().GetPositions()))
ligand_pos = np.asarray(ligand_pos)
try:
unsym_rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
rmsd = np.array(get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred))
except Exception as e:
print("Using non corrected RMSD because of the error:", e)
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
num_pockets = len(ligand_pos)
unsym_rmsds_list.append(np.lib.pad(unsym_rmsd, (0,10-len(unsym_rmsd)), 'constant', constant_values=(0)) )
rmsds_list.append(np.lib.pad(rmsd, (0,10-len(rmsd)), 'constant', constant_values=(0)) )
centroid_distance = np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1)
centroid_distances_list.append(np.lib.pad(centroid_distance, (0,10-len(rmsd)), 'constant', constant_values=(0)) )
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
if not os.path.exists(rec_path):
rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
rec = PandasPdb().read_pdb(rec_path)
rec_df = rec.df['ATOM']
receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
receptor_pos = np.tile(receptor_pos, (10, 1, 1))
ligand_pos_padded = np.lib.pad(ligand_pos, ((0,10-len(ligand_pos)), (0,0), (0,0)), 'constant', constant_values=(np.inf))
ligand_pos_padded_zero = np.lib.pad(ligand_pos, ((0, 10 - len(ligand_pos)), (0, 0), (0, 0)), 'constant',constant_values=0)
cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos_padded[:, None, :, :], axis=-1)
self_distances = np.linalg.norm(ligand_pos_padded_zero[:, :, None, :] - ligand_pos_padded_zero[:, None, :, :], axis=-1)
self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
min_cross_distance = np.min(cross_distances, axis=(1, 2))
individual_made_prediction = np.lib.pad(np.ones(num_pockets), (0,10-len(rmsd)), 'constant', constant_values=(0))
made_prediction_list.append(individual_made_prediction)
min_cross_distances_list.append(min_cross_distance)
successful_names_list.append(name)
without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
performance_metrics = {}
for overlap in ['', 'no_overlap_']:
if 'no_overlap_' == overlap:
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
unsym_rmsds = np.array(unsym_rmsds_list)[without_rec_overlap]
rmsds = np.array(rmsds_list)[without_rec_overlap]
centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
made_prediction = np.array(made_prediction_list)[without_rec_overlap]
successful_names = np.array(successful_names_list)[without_rec_overlap]
else:
unsym_rmsds = np.array(unsym_rmsds_list)
rmsds = np.array(rmsds_list)
centroid_distances = np.array(centroid_distances_list)
min_cross_distances = np.array(min_cross_distances_list)
min_self_distances = np.array(min_self_distances_list)
made_prediction = np.array(made_prediction_list)
successful_names = np.array(successful_names_list)
inf_rmsds = copy.deepcopy(rmsds)
inf_rmsds[~made_prediction.astype(bool)] = np.inf
inf_centroid_distances = copy.deepcopy(centroid_distances)
inf_centroid_distances[~made_prediction.astype(bool)] = np.inf
np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
np.save(os.path.join(args.results_path, f'{overlap}names.npy'), np.array(successful_names))
np.save(os.path.join(args.results_path, f'{overlap}centroid_distances.npy'), centroid_distances)
np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), min_cross_distances)
np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), min_self_distances)
performance_metrics.update({
f'{overlap}self_intersect_fraction': (100 * (min_self_distances[:, 0] < 0.4).sum() / len(min_self_distances[:, 0])),
f'{overlap}steric_clash_fraction': (100 * (min_cross_distances[:,0] < 0.4).sum() / len(min_cross_distances[:,0])),
f'{overlap}mean_rmsd': rmsds[:,0].mean(),
f'{overlap}unsym_rmsds_below_2': (100 * (unsym_rmsds[:,0] < 2).sum() / len(unsym_rmsds[:,0])),
f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
})
top5_rmsds = np.min(inf_rmsds[:, :5], axis=1)
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :5], axis=1)][:,0]
performance_metrics.update({
f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
})
top10_rmsds = np.min(inf_rmsds[:, :10], axis=1)
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(inf_rmsds[:, :10], axis=1)][:,0]
performance_metrics.update({
f'{overlap}top10_steric_clash_fraction': (100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
})
for k in performance_metrics:
print(k, performance_metrics[k])
if args.wandb:
wandb.log(performance_metrics)
histogram_metrics_list = [('rmsd', rmsds[:,0]),
('centroid_distance', centroid_distances[:,0]),
('mean_rmsd', rmsds[:,0]),
('mean_centroid_distance', centroid_distances[:,0])]
histogram_metrics_list.append(('top5_rmsds', top5_rmsds))
histogram_metrics_list.append(('top5_centroid_distances', top5_centroid_distances))
histogram_metrics_list.append(('top10_rmsds', top10_rmsds))
histogram_metrics_list.append(('top10_centroid_distances', top10_centroid_distances))
os.makedirs(f'.plotly_cache/baseline_cache', exist_ok=True)
images = []
for metric_name, metric in histogram_metrics_list:
d = {args.results_path: metric}
df = pd.DataFrame(data=d)
fig = px.ecdf(df, width=900, height=600, range_x=[0, 40])
fig.add_vline(x=2, annotation_text='2 A;', annotation_font_size=20, annotation_position="top right",
line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
fig.add_vline(x=5, annotation_text='5 A;', annotation_font_size=20, annotation_position="top right",
line_dash='dash', line_color='green', annotation_font_color='green')
fig.update_xaxes(title=f'{metric_name} in Angstrom', title_font={"size": 20}, tickfont={"size": 20})
fig.update_yaxes(title=f'Fraction of predictions with lower error', title_font={"size": 20},
tickfont={"size": 20})
fig.update_layout(autosize=False, margin={'l': 0, 'r': 0, 't': 0, 'b': 0}, plot_bgcolor='white',
paper_bgcolor='white', legend_title_text='Method', legend_title_font_size=17,
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ), )
fig.update_xaxes(showgrid=True, gridcolor='lightgrey')
fig.update_yaxes(showgrid=True, gridcolor='lightgrey')
fig.write_image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'))
wandb.log({metric_name: wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}")})
images.append(wandb.Image(os.path.join(f'.plotly_cache/baseline_cache', f'{metric_name}.png'), caption=f"{metric_name}"))
wandb.log({'images': images})
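
The evaluation script above pads each complex to ten prediction slots with zeros and then overwrites the padded slots with `np.inf` before taking the minimum over the first k entries. Without that mask, a padded zero would count as a perfect pose. The sketch below shows the same top-k success computation in plain Python under that assumption; `top_k_success` and its parameters are illustrative names, not functions from the repository.

```python
import math


def top_k_success(rmsds, made_prediction, k, threshold=2.0):
    """Percentage of complexes whose best of the first k poses beats threshold."""
    best = []
    for row, mask in zip(rmsds, made_prediction):
        # Padded slots (no prediction made) count as infinitely bad.
        valid = [r if m else math.inf for r, m in zip(row[:k], mask[:k])]
        best.append(min(valid))
    return round(100 * sum(b < threshold for b in best) / len(best), 2)


# Two complexes, three slots each; zeros in masked slots are padding.
rmsds = [[3.0, 1.5, 0.0], [4.0, 0.0, 0.0]]
made = [[1, 1, 0], [1, 0, 0]]
top1 = top_k_success(rmsds, made, k=1)
top3 = top_k_success(rmsds, made, k=3)
```

In this toy input only the first complex has a real pose under 2 Å within its first three slots, so the masked top-3 rate is half of the complexes rather than all of them.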

@@ -1,342 +0,0 @@
# This file needs to be run in the TANKBind repository together with baseline_run_tankbind_parallel.sh
import sys
import time
from multiprocessing import Pool
import copy
import warnings
from argparse import ArgumentParser
from rdkit.Chem import AllChem, RemoveHs
from feature_utils import save_cleaned_protein, read_mol
from generation_utils import get_LAS_distance_constraint_mask, get_info_pred_distance, write_with_new_coords
import logging
from torch_geometric.loader import DataLoader
from tqdm import tqdm # pip install tqdm if fails.
from model import get_model
# from utils import *
import torch
from data import TankBind_prediction
import os
import numpy as np
import pandas as pd
import rdkit.Chem as Chem
from feature_utils import generate_sdf_from_smiles_using_rdkit
from feature_utils import get_protein_feature
from Bio.PDB import PDBParser
from feature_utils import extract_torchdrug_feature_from_mol
def read_strings_from_txt(path):
    # every line will be one element of the returned list
    with open(path) as file:
        lines = file.readlines()
        return [line.rstrip() for line in lines]
def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):
    if molecule_file.endswith('.mol2'):
        mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.sdf'):
        supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)
        mol = supplier[0]
    elif molecule_file.endswith('.pdbqt'):
        with open(molecule_file) as file:
            pdbqt_data = file.readlines()
        pdb_block = ''
        for line in pdbqt_data:
            pdb_block += '{}\n'.format(line[:66])
        mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)
    elif molecule_file.endswith('.pdb'):
        mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
    else:
        # raise (not return) so that an unsupported format fails loudly
        raise ValueError('Expect the format of the molecule_file to be '
                         'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
    try:
        if sanitize or calc_charges:
            Chem.SanitizeMol(mol)
        if calc_charges:
            # Compute Gasteiger charges on the molecule.
            try:
                AllChem.ComputeGasteigerCharges(mol)
            except Exception:
                warnings.warn('Unable to compute charges for the molecule.')
        if remove_hs:
            mol = Chem.RemoveHs(mol, sanitize=sanitize)
    except Exception:
        return None
    return mol
def parallel_save_prediction(arguments):
dataset, y_pred_list, chosen,rdkit_mol_path, result_folder, name = arguments
for idx, line in chosen.iterrows():
pocket_name = line['pocket_name']
compound_name = line['compound_name']
ligandName = compound_name.split("_")[1]
dataset_index = line['dataset_index']
coords = dataset[dataset_index].coords.to('cpu')
protein_nodes_xyz = dataset[dataset_index].node_xyz.to('cpu')
n_compound = coords.shape[0]
n_protein = protein_nodes_xyz.shape[0]
y_pred = y_pred_list[dataset_index].reshape(n_protein, n_compound).to('cpu')
compound_pair_dis_constraint = torch.cdist(coords, coords)
mol = Chem.MolFromMolFile(rdkit_mol_path)
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
LAS_distance_constraint_mask=LAS_distance_constraint_mask,
n_repeat=1, show_progress=False)
toFile = f'{result_folder}/{name}_tankbind_chosen.sdf'
new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
write_with_new_coords(mol, new_coords, toFile)
if __name__ == '__main__':
tankbind_src_folder = "../tankbind"
sys.path.insert(0, tankbind_src_folder)
torch.set_num_threads(16)
parser = ArgumentParser()
parser.add_argument('--data_dir', type=str, default='/Users/hstark/projects/ligbind/data/PDBBind_processed', help='')
parser.add_argument('--split_path', type=str, default='/Users/hstark/projects/ligbind/data/splits/timesplit_test', help='')
parser.add_argument('--prank_path', type=str, default='/Users/hstark/projects/p2rank_2.3/prank', help='')
parser.add_argument('--results_path', type=str, default='results/tankbind_results', help='')
parser.add_argument('--skip_existing', action='store_true', default=False, help='')
parser.add_argument('--skip_p2rank', action='store_true', default=False, help='')
parser.add_argument('--skip_multiple_pocket_outputs', action='store_true', default=False, help='')
parser.add_argument('--device', type=str, default='cpu', help='')
parser.add_argument('--num_workers', type=int, default=1, help='')
parser.add_argument('--parallel_id', type=int, default=0, help='')
parser.add_argument('--parallel_tot', type=int, default=1, help='')
args = parser.parse_args()
device = args.device
cache_path = "tankbind_cache"
os.makedirs(cache_path, exist_ok=True)
os.makedirs(args.results_path, exist_ok=True)
logging.basicConfig(level=logging.INFO)
model = get_model(0, logging, device)
# re-dock model
# modelFile = "../saved_models/re_dock.pt"
# self-dock model
modelFile = f"{tankbind_src_folder}/../saved_models/self_dock.pt"
model.load_state_dict(torch.load(modelFile, map_location=device))
_ = model.eval()
batch_size = 5
names = read_strings_from_txt(args.split_path)
if args.parallel_tot > 1:
size = len(names) // args.parallel_tot + 1
names = names[args.parallel_id*size:(args.parallel_id+1)*size]
rmsds = []
forward_pass_time = []
times_preprocess = []
times_inference = []
top_10_generation_time = []
top_1_generation_time = []
start_time = time.time()
if not args.skip_p2rank:
for name in names:
if args.skip_existing and os.path.exists(f'{args.results_path}/{name}/{name}_tankbind_1.sdf'): continue
print("Now processing: ", name)
protein_path = f'{args.data_dir}/{name}/{name}_protein_processed.pdb'
cleaned_protein_path = f"{cache_path}/{name}_protein_tankbind_cleaned.pdb" # if you change this you also need to change below
parser = PDBParser(QUIET=True)
s = parser.get_structure(name, protein_path)
c = s[0]
clean_res_list, ligand_list = save_cleaned_protein(c, cleaned_protein_path)
with open(f"{cache_path}/pdb_list_p2rank.txt", "w") as out:
for name in names:
out.write(f"{name}_protein_tankbind_cleaned.pdb\n")
cmd = f"bash {args.prank_path} predict {cache_path}/pdb_list_p2rank.txt -o {cache_path}/p2rank -threads 4"
os.system(cmd)
times_preprocess.append(time.time() - start_time)
p2_rank_time = time.time() - start_time
list_to_parallelize = []
for name in tqdm(names):
single_preprocess_time = time.time()
if args.skip_existing and os.path.exists(f'{args.results_path}/{name}/{name}_tankbind_1.sdf'): continue
print("Now processing: ", name)
protein_path = f'{args.data_dir}/{name}/{name}_protein_processed.pdb'
ligand_path = f"{args.data_dir}/{name}/{name}_ligand.sdf"
cleaned_protein_path = f"{cache_path}/{name}_protein_tankbind_cleaned.pdb" # if you change this you also need to change below
rdkit_mol_path = f"{cache_path}/{name}_rdkit_ligand.sdf"
parser = PDBParser(QUIET=True)
s = parser.get_structure(name, protein_path)
c = s[0]
clean_res_list, ligand_list = save_cleaned_protein(c, cleaned_protein_path)
lig, _ = read_mol(f"{args.data_dir}/{name}/{name}_ligand.sdf", f"{args.data_dir}/{name}/{name}_ligand.mol2")
lig = RemoveHs(lig)
smiles = Chem.MolToSmiles(lig)
generate_sdf_from_smiles_using_rdkit(smiles, rdkit_mol_path, shift_dis=0)
parser = PDBParser(QUIET=True)
s = parser.get_structure("x", cleaned_protein_path)
res_list = list(s.get_residues())
protein_dict = {}
protein_dict[name] = get_protein_feature(res_list)
compound_dict = {}
mol = Chem.MolFromMolFile(rdkit_mol_path)
compound_dict[name + f"_{name}" + "_rdkit"] = extract_torchdrug_feature_from_mol(mol, has_LAS_mask=True)
info = []
for compound_name in list(compound_dict.keys()):
# use protein center as the block center.
com = ",".join([str(a.round(3)) for a in protein_dict[name][0].mean(axis=0).numpy()])
info.append([name, compound_name, "protein_center", com])
p2rankFile = f"{cache_path}/p2rank/{name}_protein_tankbind_cleaned.pdb_predictions.csv"
pocket = pd.read_csv(p2rankFile)
pocket.columns = pocket.columns.str.strip()
pocket_coms = pocket[['center_x', 'center_y', 'center_z']].values
for ith_pocket, com in enumerate(pocket_coms):
com = ",".join([str(a.round(3)) for a in com])
info.append([name, compound_name, f"pocket_{ith_pocket + 1}", com])
info = pd.DataFrame(info, columns=['protein_name', 'compound_name', 'pocket_name', 'pocket_com'])
dataset_path = f"{cache_path}/{name}_dataset/"
os.system(f"rm -r {dataset_path}")
os.system(f"mkdir -p {dataset_path}")
dataset = TankBind_prediction(dataset_path, data=info, protein_dict=protein_dict, compound_dict=compound_dict)
# dataset = TankBind_prediction(dataset_path)
times_preprocess.append(time.time() - single_preprocess_time)
single_forward_pass_time = time.time()
data_loader = DataLoader(dataset, batch_size=batch_size, follow_batch=['x', 'y', 'compound_pair'], shuffle=False,
num_workers=0)
affinity_pred_list = []
y_pred_list = []
for data in tqdm(data_loader):
data = data.to(device)
y_pred, affinity_pred = model(data)
affinity_pred_list.append(affinity_pred.detach().cpu())
for i in range(data.y_batch.max() + 1):
y_pred_list.append((y_pred[data['y_batch'] == i]).detach().cpu())
affinity_pred_list = torch.cat(affinity_pred_list)
forward_pass_time.append(time.time() - single_forward_pass_time)
output_info = copy.deepcopy(dataset.data)
output_info['affinity'] = affinity_pred_list
output_info['dataset_index'] = range(len(output_info))
output_info_sorted = output_info.sort_values('affinity', ascending=False)
result_folder = f'{args.results_path}/{name}'
os.makedirs(result_folder, exist_ok=True)
output_info_sorted.to_csv(f"{result_folder}/output_info_sorted_by_affinity.csv")
if not args.skip_multiple_pocket_outputs:
for idx, (dataframe_idx, line) in enumerate(copy.deepcopy(output_info_sorted).iterrows()):
single_top10_generation_time = time.time()
pocket_name = line['pocket_name']
compound_name = line['compound_name']
ligandName = compound_name.split("_")[1]
coords = dataset[dataframe_idx].coords.to('cpu')
protein_nodes_xyz = dataset[dataframe_idx].node_xyz.to('cpu')
n_compound = coords.shape[0]
n_protein = protein_nodes_xyz.shape[0]
y_pred = y_pred_list[dataframe_idx].reshape(n_protein, n_compound).to('cpu')
y = dataset[dataframe_idx].dis_map.reshape(n_protein, n_compound).to('cpu')
compound_pair_dis_constraint = torch.cdist(coords, coords)
mol = Chem.MolFromMolFile(rdkit_mol_path)
LAS_distance_constraint_mask = get_LAS_distance_constraint_mask(mol).bool()
pred_dist_info = get_info_pred_distance(coords, y_pred, protein_nodes_xyz, compound_pair_dis_constraint,
LAS_distance_constraint_mask=LAS_distance_constraint_mask,
n_repeat=1, show_progress=False)
toFile = f'{result_folder}/{name}_tankbind_{idx}.sdf'
new_coords = pred_dist_info.sort_values("loss")['coords'].iloc[0].astype(np.double)
write_with_new_coords(mol, new_coords, toFile)
if idx < 10:
top_10_generation_time.append(time.time() - single_top10_generation_time)
if idx == 0:
top_1_generation_time.append(time.time() - single_top10_generation_time)
output_info_chosen = copy.deepcopy(dataset.data)
output_info_chosen['affinity'] = affinity_pred_list
output_info_chosen['dataset_index'] = range(len(output_info_chosen))
chosen = output_info_chosen.loc[
output_info_chosen.groupby(['protein_name', 'compound_name'], sort=False)['affinity'].agg(
'idxmax')].reset_index()
list_to_parallelize.append((dataset, y_pred_list, chosen, rdkit_mol_path, result_folder, name))
chosen_generation_start_time = time.time()
if args.num_workers > 1:
p = Pool(args.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(list_to_parallelize), desc='running optimization') as pbar:
map_fn = p.imap_unordered if args.num_workers > 1 else map
for t in map_fn(parallel_save_prediction, list_to_parallelize):
pbar.update()
if args.num_workers > 1: p.__exit__(None, None, None)
chosen_generation_time = time.time() - chosen_generation_start_time
"""
lig, _ = read_mol(f"{args.data_dir}/{name}/{name}_ligand.sdf", f"{args.data_dir}/{name}/{name}_ligand.mol2")
sm = Chem.MolToSmiles(lig)
m_order = list(lig.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
lig = Chem.RenumberAtoms(lig, m_order)
lig = Chem.RemoveAllHs(lig)
lig = RemoveHs(lig)
true_ligand_pos = np.array(lig.GetConformer().GetPositions())
toFile = f'{result_folder}/{name}_tankbind_chosen.sdf'
mol_pred, _ = read_mol(toFile, None)
sm = Chem.MolToSmiles(mol_pred)
m_order = list(mol_pred.GetPropsAsDict(includePrivate=True, includeComputed=True)['_smilesAtomOutputOrder'])
mol_pred = Chem.RenumberAtoms(mol_pred, m_order)
mol_pred = RemoveHs(mol_pred)
mol_pred_pos = np.array(mol_pred.GetConformer().GetPositions())
rmsds.append(np.sqrt(((true_ligand_pos - mol_pred_pos) ** 2).sum(axis=1).mean(axis=0)))
print(np.sqrt(((true_ligand_pos - mol_pred_pos) ** 2).sum(axis=1).mean(axis=0)))
"""
forward_pass_time = np.array(forward_pass_time).sum()
times_preprocess = np.array(times_preprocess).sum()
times_inference = np.array(times_inference).sum()
top_10_generation_time = np.array(top_10_generation_time).sum()
top_1_generation_time = np.array(top_1_generation_time).sum()
rmsds = np.array(rmsds)
print(f'forward_pass_time: {forward_pass_time}')
print(f'times_preprocess: {times_preprocess}')
print(f'times_inference: {times_inference}')
print(f'top_10_generation_time: {top_10_generation_time}')
print(f'top_1_generation_time: {top_1_generation_time}')
print(f'chosen_generation_time: {chosen_generation_time}')
print(f'rmsds_below_2: {(100 * (rmsds < 2).sum() / len(rmsds))}')
print(f'p2rank Time: {p2_rank_time}')
print(
f'total_time: '
f'{forward_pass_time + times_preprocess + times_inference + top_10_generation_time + top_1_generation_time + p2_rank_time}')
with open(os.path.join(args.results_path, 'tankbind_log.log'), 'w') as file:
# terminate each entry with a newline so the log is one metric per line
file.write(f'forward_pass_time: {forward_pass_time}\n')
file.write(f'times_preprocess: {times_preprocess}\n')
file.write(f'times_inference: {times_inference}\n')
file.write(f'top_10_generation_time: {top_10_generation_time}\n')
file.write(f'top_1_generation_time: {top_1_generation_time}\n')
file.write(f'rmsds_below_2: {(100 * (rmsds < 2).sum() / len(rmsds))}\n')
file.write(f'p2rank Time: {p2_rank_time}\n')
file.write(f'total_time: {forward_pass_time + times_preprocess + times_inference + top_10_generation_time + top_1_generation_time + p2_rank_time}\n')


@@ -1,363 +0,0 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "477e3cc1-4143-47c7-8278-80b0cabca0ab",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from tqdm import tqdm\n",
"import shutil\n",
"import sys\n",
"import copy\n",
"import warnings\n",
"import numpy as np\n",
"import pandas as pd\n",
"from scipy import spatial as spa\n",
"from scipy.optimize import minimize, Bounds\n",
"from rdkit import Chem\n",
"import Bio\n",
"import scipy.spatial as spa\n",
"from Bio.PDB import PDBParser\n",
"from Bio.PDB.PDBExceptions import PDBConstructionWarning\n",
"from rdkit.Chem.rdchem import BondType as BT\n",
"from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs\n",
"from rdkit.Geometry import Point3D\n",
"from scipy import spatial\n",
"from scipy.special import softmax"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "ccfaa84e-d6e6-4b24-a0a8-95136256127b",
"metadata": {},
"outputs": [],
"source": [
"biopython_parser = PDBParser()\n",
"\n",
"\n",
"def safe_index(l, e):\n",
" \"\"\" Return index of element e in list l. If e is not present, return the last index \"\"\"\n",
" try:\n",
" return l.index(e)\n",
" except:\n",
" return len(l) - 1\n",
"\n",
"\n",
"def parse_receptor(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):\n",
" rec = parsePDB(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file)\n",
" return rec\n",
"\n",
"\n",
"def parsePDB(pdbid, pdbbind_dir,use_full_size_file, use_original_protein_file):\n",
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_processed.pdb')\n",
" if not os.path.exists(rec_path) or use_full_size_file or use_original_protein_file:\n",
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_obabel_reduce.pdb')\n",
" if not os.path.exists(rec_path) or use_original_protein_file:\n",
" rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein.pdb')\n",
"\n",
"\n",
" return parse_pdb_from_path(rec_path)\n",
"\n",
"def parse_pdb_from_path(path):\n",
" with warnings.catch_warnings():\n",
" warnings.filterwarnings(\"ignore\", category=PDBConstructionWarning)\n",
" structure = biopython_parser.get_structure('random_id', path)\n",
" rec = structure[0]\n",
" return rec\n",
"\n",
"\n",
"\n",
"def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=False):\n",
" if molecule_file.endswith('.mol2'):\n",
" mol = Chem.MolFromMol2File(molecule_file, sanitize=False, removeHs=False)\n",
" elif molecule_file.endswith('.sdf'):\n",
" supplier = Chem.SDMolSupplier(molecule_file, sanitize=False, removeHs=False)\n",
" mol = supplier[0]\n",
" elif molecule_file.endswith('.pdbqt'):\n",
" with open(molecule_file) as file:\n",
" pdbqt_data = file.readlines()\n",
" pdb_block = ''\n",
" for line in pdbqt_data:\n",
" pdb_block += '{}\\n'.format(line[:66])\n",
" mol = Chem.MolFromPDBBlock(pdb_block, sanitize=False, removeHs=False)\n",
" elif molecule_file.endswith('.pdb'):\n",
" mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)\n",
" else:\n",
" return ValueError('Expect the format of the molecule_file to be '\n",
" 'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))\n",
"\n",
" try:\n",
" if sanitize or calc_charges:\n",
" Chem.SanitizeMol(mol)\n",
"\n",
" if calc_charges:\n",
" # Compute Gasteiger charges on the molecule.\n",
" try:\n",
" AllChem.ComputeGasteigerCharges(mol)\n",
" except:\n",
" warnings.warn('Unable to compute charges for the molecule.')\n",
"\n",
" if remove_hs:\n",
" mol = Chem.RemoveHs(mol, sanitize=sanitize)\n",
" except:\n",
" return None\n",
"\n",
" return mol\n",
"\n",
"\n",
"def read_sdf_or_mol2(sdf_fileName):\n",
"\n",
" try:\n",
" mol = read_molecule(sdf_fileName)\n",
" except Exception as e:\n",
" mol2_fileName = sdf_fileName[:-3] + \"mol2\"\n",
" mol = read_molecule(mol2_fileName)\n",
" return mol\n",
"\n",
"def read_mols(pdbbind_dir, name, remove_hs=False):\n",
" ligs = []\n",
" for file in os.listdir(os.path.join(pdbbind_dir, name)):\n",
" if file.endswith(\".sdf\") and 'rdkit' not in file:\n",
" lig = read_molecule(os.path.join(pdbbind_dir, name, file), remove_hs=remove_hs, sanitize=True)\n",
" if lig is None and os.path.exists(os.path.join(pdbbind_dir, name, file[:-4] + \".mol2\")): # read mol2 file if sdf file cannot be sanitized\n",
" #print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')\n",
" lig = read_molecule(os.path.join(pdbbind_dir, name, file[:-4] + \".mol2\"), remove_hs=remove_hs, sanitize=True)\n",
" if lig is not None:\n",
" ligs.append(lig)\n",
" return ligs\n",
"\n",
"def extract_receptor_structure(rec, lig, cutoff=10000, lm_embedding_chains=None):\n",
" conf = lig.GetConformer()\n",
" lig_coords = conf.GetPositions()\n",
" min_distances = []\n",
" coords = []\n",
" c_alpha_coords = []\n",
" n_coords = []\n",
" c_coords = []\n",
" valid_chain_ids = []\n",
" lengths = []\n",
" for i, chain in enumerate(rec):\n",
" chain_coords = [] # num_residues, num_atoms, 3\n",
" chain_c_alpha_coords = []\n",
" chain_n_coords = []\n",
" chain_c_coords = []\n",
" count = 0\n",
" invalid_res_ids = []\n",
" for res_idx, residue in enumerate(chain):\n",
" if residue.get_resname() == 'HOH':\n",
" invalid_res_ids.append(residue.get_id())\n",
" continue\n",
" residue_coords = []\n",
" c_alpha, n, c = None, None, None\n",
" for atom in residue:\n",
" if atom.name == 'CA':\n",
" c_alpha = list(atom.get_vector())\n",
" if atom.name == 'N':\n",
" n = list(atom.get_vector())\n",
" if atom.name == 'C':\n",
" c = list(atom.get_vector())\n",
" residue_coords.append(list(atom.get_vector()))\n",
"\n",
" if c_alpha != None and n != None and c != None:\n",
" # only append residue if it is an amino acid and not some weird molecule that is part of the complex\n",
" chain_c_alpha_coords.append(c_alpha)\n",
" chain_n_coords.append(n)\n",
" chain_c_coords.append(c)\n",
" chain_coords.append(np.array(residue_coords))\n",
" count += 1\n",
" else:\n",
" invalid_res_ids.append(residue.get_id())\n",
" for res_id in invalid_res_ids:\n",
" chain.detach_child(res_id)\n",
" if len(chain_coords) > 0:\n",
" all_chain_coords = np.concatenate(chain_coords, axis=0)\n",
" distances = spatial.distance.cdist(lig_coords, all_chain_coords)\n",
" min_distance = distances.min()\n",
" else:\n",
" min_distance = np.inf\n",
"\n",
" # this removes chains if they are not close enough to the ligand\n",
" min_distances.append(min_distance)\n",
" lengths.append(count)\n",
" coords.append(chain_coords)\n",
" c_alpha_coords.append(np.array(chain_c_alpha_coords))\n",
" n_coords.append(np.array(chain_n_coords))\n",
" c_coords.append(np.array(chain_c_coords))\n",
" if min_distance < cutoff:\n",
" valid_chain_ids.append(chain.get_id())\n",
" min_distances = np.array(min_distances)\n",
" if len(valid_chain_ids) == 0:\n",
" valid_chain_ids.append(np.argmin(min_distances))\n",
" valid_coords = []\n",
" valid_c_alpha_coords = []\n",
" valid_n_coords = []\n",
" valid_c_coords = []\n",
" valid_lengths = []\n",
" invalid_chain_ids = []\n",
" valid_lm_embeddings = []\n",
" for i, chain in enumerate(rec):\n",
" if chain.get_id() in valid_chain_ids:\n",
" valid_coords.append(coords[i])\n",
" valid_c_alpha_coords.append(c_alpha_coords[i])\n",
" if lm_embedding_chains is not None:\n",
" if i >= len(lm_embedding_chains):\n",
" raise ValueError('Encountered valid chain id that was not present in the LM embeddings')\n",
" valid_lm_embeddings.append(lm_embedding_chains[i])\n",
" valid_n_coords.append(n_coords[i])\n",
" valid_c_coords.append(c_coords[i])\n",
" valid_lengths.append(lengths[i])\n",
" else:\n",
" invalid_chain_ids.append(chain.get_id())\n",
" coords = [item for sublist in valid_coords for item in sublist] # list with n_residues arrays: [n_atoms, 3]\n",
"\n",
" c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0) # [n_residues, 3]\n",
" n_coords = np.concatenate(valid_n_coords, axis=0) # [n_residues, 3]\n",
" c_coords = np.concatenate(valid_c_coords, axis=0) # [n_residues, 3]\n",
" lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None\n",
" for invalid_id in invalid_chain_ids:\n",
" rec.detach_child(invalid_id)\n",
"\n",
" assert len(c_alpha_coords) == len(n_coords)\n",
" assert len(c_alpha_coords) == len(c_coords)\n",
" assert sum(valid_lengths) == len(c_alpha_coords)\n",
" return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ce502378-1912-4a65-940e-61eb545a4264",
"metadata": {},
"outputs": [],
"source": [
"def align_prediction(smoothing_factor, pdbbind_calpha_coords, omegafold_calpha_coords, pdbbind_ligand_coords, return_rotation=False):\n",
" pdbbind_dists = spa.distance.cdist(pdbbind_calpha_coords, pdbbind_ligand_coords)\n",
" weights = np.exp(-1 * smoothing_factor * np.amin(pdbbind_dists, axis=1))\n",
" \n",
" pdbbind_calpha_centroid = np.sum(np.expand_dims(weights, axis=1) * pdbbind_calpha_coords, axis=0) / np.sum(weights)\n",
" omegafold_calpha_centroid = np.sum(np.expand_dims(weights, axis=1) * omegafold_calpha_coords, axis=0) / np.sum(weights)\n",
" centered_pdbbind_calpha_coords = pdbbind_calpha_coords - pdbbind_calpha_centroid\n",
" centered_omegafold_calpha_coords = omegafold_calpha_coords - omegafold_calpha_centroid\n",
" centered_pdbbind_ligand_coords = pdbbind_ligand_coords - pdbbind_calpha_centroid\n",
" \n",
" rotation, rec_weighted_rmsd = spa.transform.Rotation.align_vectors(centered_pdbbind_calpha_coords, centered_omegafold_calpha_coords, weights)\n",
" if return_rotation:\n",
" return rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid\n",
" \n",
" aligned_omegafold_calpha_coords = rotation.apply(centered_omegafold_calpha_coords)\n",
" aligned_omegafold_pdbbind_dists = spa.distance.cdist(aligned_omegafold_calpha_coords, centered_pdbbind_ligand_coords)\n",
" inv_r_rmse = np.sqrt(np.mean(((1 / pdbbind_dists) - (1 / aligned_omegafold_pdbbind_dists)) ** 2))\n",
" return inv_r_rmse\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7a20fbfe-a026-4815-9ed9-42c9bbd9cd08",
"metadata": {},
"outputs": [],
"source": [
"def get_alignment_rotation(pdb_id, pdbbind_protein_path, omegafold_protein_path, pdbbind_path):\n",
" pdbbind_rec = parse_pdb_from_path(pdbbind_protein_path)\n",
" omegafold_rec = parse_pdb_from_path(omegafold_protein_path)\n",
" pdbbind_ligand = read_mols(pdbbind_path, pdb_id, remove_hs=True)[0]\n",
" \n",
" pdbbind_calpha_coords = extract_receptor_structure(pdbbind_rec, pdbbind_ligand)[2]\n",
" omegafold_calpha_coords = extract_receptor_structure(omegafold_rec, pdbbind_ligand)[2]\n",
" pdbbind_ligand_coords = pdbbind_ligand.GetConformer().GetPositions()\n",
"\n",
" if pdbbind_calpha_coords.shape != omegafold_calpha_coords.shape:\n",
" print(f'Receptor structures differ for PDB ID {pdb_id} - Skipping', pdbbind_calpha_coords.shape, omegafold_calpha_coords.shape)\n",
" return None, None, None\n",
"\n",
" res = minimize(\n",
" align_prediction,\n",
" [0.1],\n",
" bounds=Bounds([0.0],[1.0]),\n",
" args=(\n",
" pdbbind_calpha_coords,\n",
" omegafold_calpha_coords,\n",
" pdbbind_ligand_coords\n",
" ),\n",
" tol=1e-8\n",
" )\n",
"\n",
" smoothing_factor = res.x\n",
" inv_r_rmse = res.fun\n",
" rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid = align_prediction(\n",
" smoothing_factor,\n",
" pdbbind_calpha_coords,\n",
" omegafold_calpha_coords,\n",
" pdbbind_ligand_coords,\n",
" True\n",
" )\n",
"\n",
" return rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "befbf9d6-fb08-4e85-a326-ff903547e2f9",
"metadata": {},
"outputs": [],
"source": [
"from biopandas.pdb import PandasPdb\n",
"\n",
"for f in tqdm(os.listdir(\"data/esmfold_structures\")):\n",
" pdb_id = f.split(\"_\")[0]\n",
" \n",
" omega_protein_filename = f\"data/esmfold_structures/{pdb_id}_protein_esmfold.pdb\"\n",
" omega_protein_output_filename = f\"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_esmfold_aligned_tr.pdb\"\n",
" \n",
" rotation, pdbbind_calpha_centroid, omegafold_calpha_centroid = get_alignment_rotation(pdb_id, f\"data/PDBBind_processed/{pdb_id}/{pdb_id}_protein_processed.pdb\", \n",
" omega_protein_filename, \"data/PDBBind_processed/\")\n",
" \n",
" if rotation is None:\n",
" continue\n",
" \n",
" ppdb_omegafold = PandasPdb().read_pdb(omega_protein_filename)\n",
" ppdb_omegafold_pre_rot = ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)\n",
" ppdb_omegafold_aligned = rotation.apply(ppdb_omegafold_pre_rot - omegafold_calpha_centroid) + pdbbind_calpha_centroid\n",
" \n",
" \n",
" ppdb_omegafold.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ppdb_omegafold_aligned\n",
" ppdb_omegafold.to_pdb(path=omega_protein_output_filename, records=['ATOM'], gz=False)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "445b0f84-302f-4023-bafb-05780d8967c0",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

Binary file not shown.

225095
data/splits/pdbids_2019 Normal file

File diff suppressed because it is too large


@@ -0,0 +1,428 @@
5S8I_2LY
5SAK_ZRY
5SB2_1K2
5SD5_HWI
5SIS_JSM
6M2B_EZO
6M73_FNR
6T88_MWQ
6TW5_9M2
6TW7_NZB
6VS3_R6V
6VTA_AKN
6W59_SZD
6WTN_RXT
6X8D_ARA
6XAF_GDP
6XBO_5MC
6XCT_478
6XG5_TOP
6XHT_V2V
6XM9_V55
6XUM_30L
6Y7L_QMG
6YDY_K73
6YJA_2BA
6YMS_OZH
6YQV_8K2
6YQW_82I
6YR2_T1C
6YRV_PJ8
6YSP_PAL
6YT6_PKE
6YYO_Q1K
6Z0R_Q4H
6Z14_Q4Z
6Z1C_7EY
6Z2C_Q5E
6Z4N_Q7B
6Z5Z_BDF
6ZAE_ACV
6ZC3_JOR
6ZCY_QF8
6ZK5_IMH
6ZPB_3D1
6ZR8_QOZ
6ZT2_QPK
6ZX3_QRZ
6ZXQ_IMO
7A1P_QW2
7A9E_R4W
7A9H_TPP
7AA0_R6B
7AFX_R9K
7AKL_RK5
7AMC_73B
7AN5_RDH
7AS1_21G
7AVI_S2Q
7B0E_C2E
7B2C_TP7
7B94_ANP
7BA0_T5H
7BCP_GCO
7BHX_TO5
7BJ6_TVK
7BJJ_TVW
7BKA_4JC
7BLA_WCS
7BLG_GAL
7BMI_U4B
7BNH_BEZ
7BTT_F8R
7C0U_FGO
7C3U_AZG
7C6P_SQH
7C8Q_DSG
7CD9_FVR
7CIJ_G0C
7CL8_TES
7CNQ_G8X
7CNS_PMV
7CTM_BDP
7CUO_PHB
7D0P_1VU
7D5C_GV6
7D6O_MTE
7D8Q_GZF
7D9L_GSF
7DIN_MPO
7DKT_GLF
7DQL_4CL
7DUA_HJ0
7E2S_BLA
7E4L_MDN
7EBG_J0L
7ECR_SIN
7ED2_A3P
7ELT_TYM
7EN7_J79
7EPV_FDA
7ES1_UDP
7F51_BA7
7F5D_EUO
7F8T_FAD
7FB7_8NF
7FHA_ADX
7FRX_O88
7FT9_4MB
7JG0_GAR
7JGW_V9S
7JHQ_VAJ
7JMV_4NC
7JNB_A2G
7JR8_VH7
7JUD_MMA
7JXX_VP7
7JY3_VUD
7K0V_VQP
7K41_VUA
7KB1_WBJ
7KC5_BJZ
7KFO_IAC
7KLX_WOV
7KM8_WPD
7KP6_WTP
7KQU_YOF
7KRU_ATP
7KZ9_XN7
7L00_XCJ
7L03_F9F
7L5F_XNG
7L6D_BMF
7L7C_XQ1
7L81_UD4
7LB3_XXS
7LCU_XTA
7LEV_0JO
7LJN_GTP
7LMO_NYO
7LOE_Y84
7LOU_IFM
7LT0_ONJ
7LZD_YHY
7LZQ_YJV
7M31_TDR
7M3H_YPV
7M41_YQG
7M6K_YRJ
7MAE_XUS
7MEU_MGP
7MFP_Z7P
7MGT_ZD4
7MGY_ZD1
7MMH_ZJY
7MOI_HPS
7MRH_ZMJ
7MS7_ZQ1
7MSR_DCA
7MWN_WI5
7MWU_ZPM
7MY1_IPE
7MYU_ZR7
7MZS_GLA
7N03_ZRP
7N4N_0BK
7N4W_P4V
7N6F_0I1
7N7B_T3F
7N7H_CTP
7NA4_1I9
7NB4_U6Q
7NF0_BYN
7NF3_4LU
7NFB_GEN
7NGW_UAW
7NLK_UHK
7NLV_UJE
7NML_I7B
7NP6_UK8
7NPL_UKZ
7NR6_UO8
7NR8_UOE
7NSW_HC4
7NTG_F6R
7NU0_DCL
7NUT_GLP
7NXO_UU8
7O0N_CDP
7O1T_5X8
7OCB_V88
7ODX_DGP
7ODY_DGI
7OEO_V9Z
7OFF_VCB
7OFK_VCH
7OKC_VFE
7OKF_VH5
7OLI_8HG
7OLT_58J
7OMJ_GCP
7OMX_CNA
7OP9_06K
7OPG_06N
7ORW_7WA
7OSO_0V1
7OU8_1XI
7OZ9_NGK
7OZC_G6S
7P1F_KFN
7P1M_4IU
7P2I_MFU
7P2W_4QR
7P4C_5OV
7P4J_5JK
7P4V_DAT
7P5T_5YG
7P85_5ZG
7PA4_C
7PGX_FMN
7PIH_7QW
7PJQ_OWH
7PK0_BYC
7PL1_SFG
7POM_7VZ
7PRI_7TI
7PRM_81I
7PT3_3KK
7PUV_84Z
7Q19_DSM
7Q25_8J9
7Q27_8KC
7Q2B_M6H
7Q5I_I0F
7QE4_NGA
7QF4_RBF
7QFM_AY3
7QGP_DJ8
7QHG_T3B
7QHL_D5P
7QK0_EBL
7QPP_VDX
7QSW_CAP
7QTA_URI
7R3D_APR
7R59_I5F
7R6J_2I7
7R7R_AWJ
7R9N_F97
7RC3_SAH
7REE_4LY
7RH3_59O
7RH8_UTP
7RKW_5TV
7RNI_60I
7ROR_69X
7ROU_66I
7RPZ_6IC
7RSV_7IQ
7RUI_7QZ
7RWO_7WN
7RWS_4UR
7RZL_NPO
7S45_ACO
7S9H_7PP
7SCW_GSP
7SDD_4IP
7SED_8VD
7SFO_98L
7SGV_L30
7SIU_9ID
7SNE_9XR
7SSM_B7L
7SUC_COM
7SZA_DUI
7T0D_FPP
7T0U_E3I
7T1D_E7K
7T2I_E9F
7T3E_SLB
7T3F_EM0
7T9O_GEI
7TB0_UD1
7TBU_S3P
7TE8_P0T
7TH4_FFO
7THI_PGA
7TM6_GPJ
7TOM_5AD
7TS6_KMI
7TSF_H4B
7TUO_KL9
7TWC_CXS
7TXK_LW8
7TXP_0FX
7TYP_KUR
7U0U_FK5
7U3J_L6U
7UAS_MBU
7UAW_MF6
7UEY_N0R
7UF2_5SP
7UJ4_OQ4
7UJ5_DGL
7UJF_R3V
7ULC_56B
7UMV_NUU
7UMW_NAD
7UP3_NZ0
7UQ3_O2U
7USH_82V
7UTW_NAI
7UXS_OJC
7UY4_SMI
7UYB_OK0
7V14_ORU
7V3N_AKG
7V3S_5I9
7V43_C4O
7V8Z_5YH
7VB8_STL
7VBU_6I4
7VC5_9SF
7VJT_7IJ
7VKZ_NOJ
7VQ9_ISY
7VWF_K55
7VYJ_CA0
7W05_GMP
7W06_ITN
7W6F_8I6
7WCF_ACP
7WDT_NGS
7WJB_BGC
7WKL_CAQ
7WL4_JFU
7WN5_JGL
7WPW_F15
7WQQ_5Z6
7WUX_6OI
7WUY_76N
7WY1_D0L
7X5N_5M5
7X9K_8OG
7XBV_APC
7XEK_9YX
7XFA_D9J
7XG5_PLP
7XI7_4RI
7XIJ_EJ3
7XJN_NSD
7XPO_UPG
7XQZ_FPF
7XRL_FWK
7YZU_DO7
7Z1Q_NIO
7Z2O_IAJ
7Z7F_IF3
7ZCC_OGA
7ZDY_6MJ
7ZF0_DHR
7ZHP_IQY
7ZL5_IWE
7ZOC_T8E
7ZTL_BCN
7ZU2_DHT
7ZXV_45D
7ZXZ_K9R
7ZYS_KNR
7ZZB_KGX
7ZZW_KKW
8A1H_DLZ
8A2D_KXY
8AAU_LH0
8ACL_LQL
8AEM_LVF
8AEU_M0L
8AIE_M7L
8AIJ_M9I
8AJX_FUM
8AP0_PRP
8AQL_PLG
8AUH_L9I
8AY3_OE3
8B8H_OJQ
8BN6_R53
8BOM_QU6
8BPL_CP
8BRO_R7E
8BTI_RFO
8C3N_ADP
8C5D_GTB
8C5M_MTA
8C7Y_TXV
8CGC_LMR
8CI0_8EL
8CNH_V6U
8CSD_C5P
8D19_GSH
8D39_QDB
8D5D_5DK
8DHG_T78
8DKO_TFB
8DP2_UMA
8DSC_NCA
8DW5_FQ7
8DZT_G4P
8E77_ULP
8EAB_VN2
8EAD_UY0
8ERS_WQO
8EX2_Q2Q
8EXL_799
8EYE_X4I
8F4J_PHO
8F8E_XJI
8FAV_4Y5
8FLN_Y7W
8FLV_ZB9
8FO5_Y4U
8FV9_80J
8G0V_YHT
8G43_ZU6
8G6P_API
8GFD_ZHR
8H0M_2EH
8HFN_XGC
8HO0_3ZI
8SLG_G5A

Binary file not shown.


@@ -0,0 +1,585 @@
4f49
4z88
3q8d
5yvx
2w2i
4dcd
4nuf
4xxh
4wk2
3odl
1ui0
4ris
4d3h
3k3g
6drt
4b4q
2p59
4u3f
2o9a
3m94
2y3p
3rdv
1ols
2yiv
1jgl
6bj2
2ke1
5ah2
2xtk
6hly
4wk1
1fpy
3iw4
1olu
4xqu
1i3z
5zla
2cfd
5zku
4dj7
3evf
1iup
6e5x
5j5d
3n2c
3odi
4c94
2x7x
4rww
5x33
1olx
4mfe
1wvc
2wzm
4pxf
1v1m
1v11
6hm2
5uxf
3rde
5syn
2qta
2puy
2za5
4nrt
3rg2
2jc0
4xtt
1v16
1bq4
1uef
5o0j
6hlz
2p4s
2o4h
1njj
2cfg
3ucj
16pk
1u9l
4lkj
4p5z
1mpl
4u7t
1jr1
5wbf
3ess
2n7b
6hlx
5j5t
2qw1
4wn5
1xt3
5ks7
1ff1
4up5
4tx6
4tsz
1dpu
1oqp
5bml
6fdt
4ie6
3juo
2xvn
2xye
4olh
4udb
1my2
4n9e
4dht
3azb
6eqp
4eb9
4okg
4q18
5a3w
2mip
4yp1
2wnc
3skf
3zm9
1tsm
3wzp
3qqk
1pmv
3q4l
4x6p
5xij
4wkt
4g31
5y3n
1cw2
4tk0
2y34
4a4l
3dga
4igr
5fnt
2d06
6gjy
4j84
5icy
5oui
3rj7
6ccy
4xu3
5e0l
6fjm
2igx
4odl
4gr8
4bhf
3iux
5v82
4c6z
4yje
4rz1
5hv1
2vc7
1jyc
2lgf
4clj
5aom
1l5r
1nvq
5ukl
6b96
1xmu
2bcd
2i0a
4yrr
4w9g
3lxl
1sqo
3wk6
3t1l
3p1d
2eg7
5o0s
5jn8
2wap
5u9i
6min
3eyh
4p4e
6mr5
1sqp
1f8a
1d4w
5oa6
3kx1
5myr
2iwu
1ero
4xnw
4bch
5tob
4uxl
1jd6
3hfj
4de3
2brp
4lkg
2fme
2y5g
4glr
1k08
1hyv
3ske
4hv7
3fun
3gwu
4xsx
4ec4
5zmq
3zlw
3fcl
5f3t
3mfv
3bzf
4jbo
2ow2
5l2z
2qci
1hiy
6f6n
6fgg
1z2b
1szm
3n9r
5dhr
1jvu
1jd0
4a9r
2vgc
5y5t
4iax
5l4m
2vnp
3n3l
1qxz
1tok
4aqh
1wbg
5mkx
2oqs
6dar
5v49
3ft2
3o2m
3g4l
4bgy
5gsw
1fkw
2gg8
3tv5
4at5
4kn4
6cvw
1x7b
6fsd
3hig
5ewa
5vij
3d4f
5vyy
3v66
3tzd
4xkc
4y4g
2w0x
5mft
3p78
2xn3
4x3h
4d8z
4r17
6bik
1sqc
4jkw
4m7c
4kpx
1a52
4qew
3zlk
5np8
1qpe
4wmx
1z4u
5j7q
4xt9
4tpm
5wlg
5fot
4bcc
5l3e
3nti
1xr9
1zpa
5we9
5avi
4mbj
3cwj
6ck6
4ir5
5l8n
3bz3
1azm
5a82
3u15
5wbl
5mli
1ujj
3wp1
3obu
3vzd
4rj5
5d6p
2xkf
4ib5
3jsw
2zv9
1okz
3ff6
2gz8
6bg5
2wk2
4wet
4l53
3jzg
1g3d
3pka
5yg4
4few
1z9g
5evd
2a4q
5ehg
3max
6msy
1fcy
4nhy
1k4g
5lz8
4kqr
1t1r
2z78
3dz6
4ygf
4ht6
2v7a
3wax
4ivc
3u7l
3hvi
5o0e
2c97
4b6f
5mmn
6ee3
3fmz
5wzv
5znr
5ct2
4fai
5uit
2gkl
2vpg
1jtq
3veh
4jit
4bgg
2ce9
5aqu
6hth
4m5k
1aj6
3h22
5vnd
1q4l
2j7x
3acx
5l6p
5m0m
5oh9
2a5c
5npc
4qk4
1fzj
5bqi
5c4l
1fq4
2v5x
4k63
3uik
1me8
5uq9
4nrq
2r1y
3ldw
4bh4
4pl5
1mfg
3c5u
3iqh
6awo
5d1t
4urn
3dnj
4csy
1caq
4oem
4wke
3pb7
1qwu
5g3m
2buc
4i9h
5epr
5ivy
3krj
3f9w
3rwj
4bo4
4amx
4wvl
5x54
2a58
6ma1
1i41
2ew5
1y57
3ovn
5g1n
2bba
3iej
6eip
3iaf
2uy4
830c
1aqc
2wot
1bmq
4nal
3kbz
1gt1
4pp9
2aq9
1b7h
3lp1
3exf
6g9a
4gzw
1p06
4own
3qo2
5emk
1gvk
6chn
4r1v
1q95
4x5q
2x6x
4lh6
5mb1
4zzz
4ty9
2hkf
1n3i
4b6p
5vo2
2ovq
1vcj
3kjn
4zz1
1f73
6baw
6arv
2yxj
2xo8
3hqz
4li6
1uze
1ih0
4uwh
5l6h
3g0w
5v3r
6ccu
2g78
3eqr
5d10
2vte
4e3j
2a5u
5duc
5a0e
4zzx
6ezq
6e4w
6m8w
3mo0
3eyd
1pl0
4bhz
2wxn
3vb7
3daj
3ara
4b4m
2y59
3tcp
5hjd
4wym
4qo9
4yz9
2jew
5tq4
4qsh
4djp
3rqg
4ddh
6bj3
3czv
4jok
5n2d
2afw
4h2o
2vd7
2gvf
2yhd
3drg
1zc9
5khi
5aix
2y2i
4yax
5d75
4m3b
5hjc
4u0b
4n7j
4b6s
3lq4
1lhu
5ur5
5fcz
5fyx
3ml5
4mb9
4nzn
5wpb
4omd
5t2d
1bzh
5idn
4f9v
2m3o
6cw4
4i6f
3t2t
4ep2
5x4o
3vqh
4itj
1p93
4b5w
5ech
2ltw
2hrp
5lpm
5twh
2gbg
4jjq
4kil
2r7g
5dgu
1hbj
5vsj
6e05
1jq9
4gto
4gii
2wc4


@@ -1,364 +1,364 @@
complex_name,protein_path,ligand_description,protein_sequence
0,data/PDBBind_processed/6qqw/6qqw_protein_processed.pdb,data/PDBBind_processed/6qqw/6qqw_ligand.mol2,
1,data/PDBBind_processed/6d08/6d08_protein_processed.pdb,data/PDBBind_processed/6d08/6d08_ligand.sdf,
2,data/PDBBind_processed/6jap/6jap_protein_processed.pdb,data/PDBBind_processed/6jap/6jap_ligand.sdf,
3,data/PDBBind_processed/6np2/6np2_protein_processed.pdb,data/PDBBind_processed/6np2/6np2_ligand.sdf,
4,data/PDBBind_processed/6uvp/6uvp_protein_processed.pdb,data/PDBBind_processed/6uvp/6uvp_ligand.sdf,
5,data/PDBBind_processed/6oxq/6oxq_protein_processed.pdb,data/PDBBind_processed/6oxq/6oxq_ligand.sdf,
6,data/PDBBind_processed/6jsn/6jsn_protein_processed.pdb,data/PDBBind_processed/6jsn/6jsn_ligand.sdf,
7,data/PDBBind_processed/6hzb/6hzb_protein_processed.pdb,data/PDBBind_processed/6hzb/6hzb_ligand.sdf,
8,data/PDBBind_processed/6qrc/6qrc_protein_processed.pdb,data/PDBBind_processed/6qrc/6qrc_ligand.mol2,
9,data/PDBBind_processed/6oio/6oio_protein_processed.pdb,data/PDBBind_processed/6oio/6oio_ligand.sdf,
10,data/PDBBind_processed/6jag/6jag_protein_processed.pdb,data/PDBBind_processed/6jag/6jag_ligand.sdf,
11,data/PDBBind_processed/6moa/6moa_protein_processed.pdb,data/PDBBind_processed/6moa/6moa_ligand.mol2,
12,data/PDBBind_processed/6hld/6hld_protein_processed.pdb,data/PDBBind_processed/6hld/6hld_ligand.sdf,
13,data/PDBBind_processed/6i9a/6i9a_protein_processed.pdb,data/PDBBind_processed/6i9a/6i9a_ligand.sdf,
14,data/PDBBind_processed/6e4c/6e4c_protein_processed.pdb,data/PDBBind_processed/6e4c/6e4c_ligand.sdf,
15,data/PDBBind_processed/6g24/6g24_protein_processed.pdb,data/PDBBind_processed/6g24/6g24_ligand.sdf,
16,data/PDBBind_processed/6jb4/6jb4_protein_processed.pdb,data/PDBBind_processed/6jb4/6jb4_ligand.sdf,
17,data/PDBBind_processed/6s55/6s55_protein_processed.pdb,data/PDBBind_processed/6s55/6s55_ligand.sdf,
18,data/PDBBind_processed/6seo/6seo_protein_processed.pdb,data/PDBBind_processed/6seo/6seo_ligand.sdf,
19,data/PDBBind_processed/6dyz/6dyz_protein_processed.pdb,data/PDBBind_processed/6dyz/6dyz_ligand.mol2,
20,data/PDBBind_processed/5zk5/5zk5_protein_processed.pdb,data/PDBBind_processed/5zk5/5zk5_ligand.sdf,
21,data/PDBBind_processed/6jid/6jid_protein_processed.pdb,data/PDBBind_processed/6jid/6jid_ligand.sdf,
22,data/PDBBind_processed/5ze6/5ze6_protein_processed.pdb,data/PDBBind_processed/5ze6/5ze6_ligand.sdf,
23,data/PDBBind_processed/6qlu/6qlu_protein_processed.pdb,data/PDBBind_processed/6qlu/6qlu_ligand.sdf,
24,data/PDBBind_processed/6a6k/6a6k_protein_processed.pdb,data/PDBBind_processed/6a6k/6a6k_ligand.sdf,
25,data/PDBBind_processed/6qgf/6qgf_protein_processed.pdb,data/PDBBind_processed/6qgf/6qgf_ligand.sdf,
26,data/PDBBind_processed/6e3z/6e3z_protein_processed.pdb,data/PDBBind_processed/6e3z/6e3z_ligand.sdf,
27,data/PDBBind_processed/6te6/6te6_protein_processed.pdb,data/PDBBind_processed/6te6/6te6_ligand.sdf,
28,data/PDBBind_processed/6pka/6pka_protein_processed.pdb,data/PDBBind_processed/6pka/6pka_ligand.sdf,
29,data/PDBBind_processed/6g2o/6g2o_protein_processed.pdb,data/PDBBind_processed/6g2o/6g2o_ligand.sdf,
30,data/PDBBind_processed/6jsf/6jsf_protein_processed.pdb,data/PDBBind_processed/6jsf/6jsf_ligand.sdf,
31,data/PDBBind_processed/5zxk/5zxk_protein_processed.pdb,data/PDBBind_processed/5zxk/5zxk_ligand.sdf,
32,data/PDBBind_processed/6qxd/6qxd_protein_processed.pdb,data/PDBBind_processed/6qxd/6qxd_ligand.sdf,
33,data/PDBBind_processed/6n97/6n97_protein_processed.pdb,data/PDBBind_processed/6n97/6n97_ligand.sdf,
34,data/PDBBind_processed/6jt3/6jt3_protein_processed.pdb,data/PDBBind_processed/6jt3/6jt3_ligand.sdf,
35,data/PDBBind_processed/6qtr/6qtr_protein_processed.pdb,data/PDBBind_processed/6qtr/6qtr_ligand.sdf,
36,data/PDBBind_processed/6oy1/6oy1_protein_processed.pdb,data/PDBBind_processed/6oy1/6oy1_ligand.sdf,
37,data/PDBBind_processed/6n96/6n96_protein_processed.pdb,data/PDBBind_processed/6n96/6n96_ligand.sdf,
38,data/PDBBind_processed/6qzh/6qzh_protein_processed.pdb,data/PDBBind_processed/6qzh/6qzh_ligand.sdf,
39,data/PDBBind_processed/6qqz/6qqz_protein_processed.pdb,data/PDBBind_processed/6qqz/6qqz_ligand.mol2,
40,data/PDBBind_processed/6qmt/6qmt_protein_processed.pdb,data/PDBBind_processed/6qmt/6qmt_ligand.sdf,
41,data/PDBBind_processed/6ibx/6ibx_protein_processed.pdb,data/PDBBind_processed/6ibx/6ibx_ligand.sdf,
42,data/PDBBind_processed/6hmt/6hmt_protein_processed.pdb,data/PDBBind_processed/6hmt/6hmt_ligand.sdf,
43,data/PDBBind_processed/5zk7/5zk7_protein_processed.pdb,data/PDBBind_processed/5zk7/5zk7_ligand.sdf,
44,data/PDBBind_processed/6k3l/6k3l_protein_processed.pdb,data/PDBBind_processed/6k3l/6k3l_ligand.sdf,
45,data/PDBBind_processed/6cjs/6cjs_protein_processed.pdb,data/PDBBind_processed/6cjs/6cjs_ligand.sdf,
46,data/PDBBind_processed/6n9l/6n9l_protein_processed.pdb,data/PDBBind_processed/6n9l/6n9l_ligand.sdf,
47,data/PDBBind_processed/6ibz/6ibz_protein_processed.pdb,data/PDBBind_processed/6ibz/6ibz_ligand.sdf,
48,data/PDBBind_processed/6ott/6ott_protein_processed.pdb,data/PDBBind_processed/6ott/6ott_ligand.sdf,
49,data/PDBBind_processed/6gge/6gge_protein_processed.pdb,data/PDBBind_processed/6gge/6gge_ligand.sdf,
50,data/PDBBind_processed/6hot/6hot_protein_processed.pdb,data/PDBBind_processed/6hot/6hot_ligand.sdf,
51,data/PDBBind_processed/6e3p/6e3p_protein_processed.pdb,data/PDBBind_processed/6e3p/6e3p_ligand.mol2,
52,data/PDBBind_processed/6md6/6md6_protein_processed.pdb,data/PDBBind_processed/6md6/6md6_ligand.sdf,
53,data/PDBBind_processed/6hlb/6hlb_protein_processed.pdb,data/PDBBind_processed/6hlb/6hlb_ligand.sdf,
54,data/PDBBind_processed/6fe5/6fe5_protein_processed.pdb,data/PDBBind_processed/6fe5/6fe5_ligand.sdf,
55,data/PDBBind_processed/6uwp/6uwp_protein_processed.pdb,data/PDBBind_processed/6uwp/6uwp_ligand.sdf,
56,data/PDBBind_processed/6npp/6npp_protein_processed.pdb,data/PDBBind_processed/6npp/6npp_ligand.sdf,
57,data/PDBBind_processed/6g2f/6g2f_protein_processed.pdb,data/PDBBind_processed/6g2f/6g2f_ligand.sdf,
58,data/PDBBind_processed/6mo7/6mo7_protein_processed.pdb,data/PDBBind_processed/6mo7/6mo7_ligand.sdf,
59,data/PDBBind_processed/6bqd/6bqd_protein_processed.pdb,data/PDBBind_processed/6bqd/6bqd_ligand.mol2,
60,data/PDBBind_processed/6nsv/6nsv_protein_processed.pdb,data/PDBBind_processed/6nsv/6nsv_ligand.mol2,
61,data/PDBBind_processed/6i76/6i76_protein_processed.pdb,data/PDBBind_processed/6i76/6i76_ligand.sdf,
62,data/PDBBind_processed/6n53/6n53_protein_processed.pdb,data/PDBBind_processed/6n53/6n53_ligand.sdf,
63,data/PDBBind_processed/6g2c/6g2c_protein_processed.pdb,data/PDBBind_processed/6g2c/6g2c_ligand.sdf,
64,data/PDBBind_processed/6eeb/6eeb_protein_processed.pdb,data/PDBBind_processed/6eeb/6eeb_ligand.mol2,
65,data/PDBBind_processed/6n0m/6n0m_protein_processed.pdb,data/PDBBind_processed/6n0m/6n0m_ligand.sdf,
66,data/PDBBind_processed/6uvy/6uvy_protein_processed.pdb,data/PDBBind_processed/6uvy/6uvy_ligand.sdf,
67,data/PDBBind_processed/6ovz/6ovz_protein_processed.pdb,data/PDBBind_processed/6ovz/6ovz_ligand.sdf,
68,data/PDBBind_processed/6olx/6olx_protein_processed.pdb,data/PDBBind_processed/6olx/6olx_ligand.sdf,
69,data/PDBBind_processed/6v5l/6v5l_protein_processed.pdb,data/PDBBind_processed/6v5l/6v5l_ligand.mol2,
70,data/PDBBind_processed/6hhg/6hhg_protein_processed.pdb,data/PDBBind_processed/6hhg/6hhg_ligand.sdf,
71,data/PDBBind_processed/5zcu/5zcu_protein_processed.pdb,data/PDBBind_processed/5zcu/5zcu_ligand.sdf,
72,data/PDBBind_processed/6dz2/6dz2_protein_processed.pdb,data/PDBBind_processed/6dz2/6dz2_ligand.mol2,
73,data/PDBBind_processed/6mjq/6mjq_protein_processed.pdb,data/PDBBind_processed/6mjq/6mjq_ligand.sdf,
74,data/PDBBind_processed/6efk/6efk_protein_processed.pdb,data/PDBBind_processed/6efk/6efk_ligand.sdf,
75,data/PDBBind_processed/6s9w/6s9w_protein_processed.pdb,data/PDBBind_processed/6s9w/6s9w_ligand.sdf,
76,data/PDBBind_processed/6gdy/6gdy_protein_processed.pdb,data/PDBBind_processed/6gdy/6gdy_ligand.sdf,
77,data/PDBBind_processed/6kqi/6kqi_protein_processed.pdb,data/PDBBind_processed/6kqi/6kqi_ligand.sdf,
78,data/PDBBind_processed/6ueg/6ueg_protein_processed.pdb,data/PDBBind_processed/6ueg/6ueg_ligand.sdf,
79,data/PDBBind_processed/6oxt/6oxt_protein_processed.pdb,data/PDBBind_processed/6oxt/6oxt_ligand.sdf,
80,data/PDBBind_processed/6oy0/6oy0_protein_processed.pdb,data/PDBBind_processed/6oy0/6oy0_ligand.sdf,
81,data/PDBBind_processed/6qr7/6qr7_protein_processed.pdb,data/PDBBind_processed/6qr7/6qr7_ligand.mol2,
82,data/PDBBind_processed/6i41/6i41_protein_processed.pdb,data/PDBBind_processed/6i41/6i41_ligand.sdf,
83,data/PDBBind_processed/6cyg/6cyg_protein_processed.pdb,data/PDBBind_processed/6cyg/6cyg_ligand.sdf,
84,data/PDBBind_processed/6qmr/6qmr_protein_processed.pdb,data/PDBBind_processed/6qmr/6qmr_ligand.sdf,
85,data/PDBBind_processed/6g27/6g27_protein_processed.pdb,data/PDBBind_processed/6g27/6g27_ligand.sdf,
86,data/PDBBind_processed/6ggb/6ggb_protein_processed.pdb,data/PDBBind_processed/6ggb/6ggb_ligand.sdf,
87,data/PDBBind_processed/6g3c/6g3c_protein_processed.pdb,data/PDBBind_processed/6g3c/6g3c_ligand.sdf,
88,data/PDBBind_processed/6n4e/6n4e_protein_processed.pdb,data/PDBBind_processed/6n4e/6n4e_ligand.sdf,
89,data/PDBBind_processed/6fcj/6fcj_protein_processed.pdb,data/PDBBind_processed/6fcj/6fcj_ligand.sdf,
90,data/PDBBind_processed/6quv/6quv_protein_processed.pdb,data/PDBBind_processed/6quv/6quv_ligand.sdf,
91,data/PDBBind_processed/6iql/6iql_protein_processed.pdb,data/PDBBind_processed/6iql/6iql_ligand.mol2,
92,data/PDBBind_processed/6i74/6i74_protein_processed.pdb,data/PDBBind_processed/6i74/6i74_ligand.sdf,
93,data/PDBBind_processed/6qr4/6qr4_protein_processed.pdb,data/PDBBind_processed/6qr4/6qr4_ligand.mol2,
94,data/PDBBind_processed/6rnu/6rnu_protein_processed.pdb,data/PDBBind_processed/6rnu/6rnu_ligand.sdf,
95,data/PDBBind_processed/6jib/6jib_protein_processed.pdb,data/PDBBind_processed/6jib/6jib_ligand.sdf,
96,data/PDBBind_processed/6izq/6izq_protein_processed.pdb,data/PDBBind_processed/6izq/6izq_ligand.sdf,
97,data/PDBBind_processed/6qw8/6qw8_protein_processed.pdb,data/PDBBind_processed/6qw8/6qw8_ligand.sdf,
98,data/PDBBind_processed/6qto/6qto_protein_processed.pdb,data/PDBBind_processed/6qto/6qto_ligand.sdf,
99,data/PDBBind_processed/6qrd/6qrd_protein_processed.pdb,data/PDBBind_processed/6qrd/6qrd_ligand.mol2,
100,data/PDBBind_processed/6hza/6hza_protein_processed.pdb,data/PDBBind_processed/6hza/6hza_ligand.sdf,
101,data/PDBBind_processed/6e5s/6e5s_protein_processed.pdb,data/PDBBind_processed/6e5s/6e5s_ligand.sdf,
102,data/PDBBind_processed/6dz3/6dz3_protein_processed.pdb,data/PDBBind_processed/6dz3/6dz3_ligand.mol2,
103,data/PDBBind_processed/6e6w/6e6w_protein_processed.pdb,data/PDBBind_processed/6e6w/6e6w_ligand.mol2,
104,data/PDBBind_processed/6cyh/6cyh_protein_processed.pdb,data/PDBBind_processed/6cyh/6cyh_ligand.sdf,
105,data/PDBBind_processed/5zlf/5zlf_protein_processed.pdb,data/PDBBind_processed/5zlf/5zlf_ligand.sdf,
106,data/PDBBind_processed/6om4/6om4_protein_processed.pdb,data/PDBBind_processed/6om4/6om4_ligand.sdf,
107,data/PDBBind_processed/6gga/6gga_protein_processed.pdb,data/PDBBind_processed/6gga/6gga_ligand.sdf,
108,data/PDBBind_processed/6pgp/6pgp_protein_processed.pdb,data/PDBBind_processed/6pgp/6pgp_ligand.sdf,
109,data/PDBBind_processed/6qqv/6qqv_protein_processed.pdb,data/PDBBind_processed/6qqv/6qqv_ligand.mol2,
110,data/PDBBind_processed/6qtq/6qtq_protein_processed.pdb,data/PDBBind_processed/6qtq/6qtq_ligand.sdf,
111,data/PDBBind_processed/6gj6/6gj6_protein_processed.pdb,data/PDBBind_processed/6gj6/6gj6_ligand.mol2,
112,data/PDBBind_processed/6os5/6os5_protein_processed.pdb,data/PDBBind_processed/6os5/6os5_ligand.mol2,
113,data/PDBBind_processed/6s07/6s07_protein_processed.pdb,data/PDBBind_processed/6s07/6s07_ligand.sdf,
114,data/PDBBind_processed/6i77/6i77_protein_processed.pdb,data/PDBBind_processed/6i77/6i77_ligand.sdf,
115,data/PDBBind_processed/6hhj/6hhj_protein_processed.pdb,data/PDBBind_processed/6hhj/6hhj_ligand.sdf,
116,data/PDBBind_processed/6ahs/6ahs_protein_processed.pdb,data/PDBBind_processed/6ahs/6ahs_ligand.sdf,
117,data/PDBBind_processed/6oxx/6oxx_protein_processed.pdb,data/PDBBind_processed/6oxx/6oxx_ligand.sdf,
118,data/PDBBind_processed/6mjj/6mjj_protein_processed.pdb,data/PDBBind_processed/6mjj/6mjj_ligand.sdf,
119,data/PDBBind_processed/6hor/6hor_protein_processed.pdb,data/PDBBind_processed/6hor/6hor_ligand.sdf,
120,data/PDBBind_processed/6jb0/6jb0_protein_processed.pdb,data/PDBBind_processed/6jb0/6jb0_ligand.sdf,
121,data/PDBBind_processed/6i68/6i68_protein_processed.pdb,data/PDBBind_processed/6i68/6i68_ligand.sdf,
122,data/PDBBind_processed/6pz4/6pz4_protein_processed.pdb,data/PDBBind_processed/6pz4/6pz4_ligand.sdf,
123,data/PDBBind_processed/6mhb/6mhb_protein_processed.pdb,data/PDBBind_processed/6mhb/6mhb_ligand.sdf,
124,data/PDBBind_processed/6uim/6uim_protein_processed.pdb,data/PDBBind_processed/6uim/6uim_ligand.sdf,
125,data/PDBBind_processed/6jsg/6jsg_protein_processed.pdb,data/PDBBind_processed/6jsg/6jsg_ligand.sdf,
126,data/PDBBind_processed/6i78/6i78_protein_processed.pdb,data/PDBBind_processed/6i78/6i78_ligand.sdf,
127,data/PDBBind_processed/6oxy/6oxy_protein_processed.pdb,data/PDBBind_processed/6oxy/6oxy_ligand.sdf,
128,data/PDBBind_processed/6gbw/6gbw_protein_processed.pdb,data/PDBBind_processed/6gbw/6gbw_ligand.sdf,
129,data/PDBBind_processed/6mo0/6mo0_protein_processed.pdb,data/PDBBind_processed/6mo0/6mo0_ligand.sdf,
130,data/PDBBind_processed/6ggf/6ggf_protein_processed.pdb,data/PDBBind_processed/6ggf/6ggf_ligand.sdf,
131,data/PDBBind_processed/6qge/6qge_protein_processed.pdb,data/PDBBind_processed/6qge/6qge_ligand.sdf,
132,data/PDBBind_processed/6cjr/6cjr_protein_processed.pdb,data/PDBBind_processed/6cjr/6cjr_ligand.sdf,
133,data/PDBBind_processed/6oxp/6oxp_protein_processed.pdb,data/PDBBind_processed/6oxp/6oxp_ligand.sdf,
134,data/PDBBind_processed/6d07/6d07_protein_processed.pdb,data/PDBBind_processed/6d07/6d07_ligand.sdf,
135,data/PDBBind_processed/6i63/6i63_protein_processed.pdb,data/PDBBind_processed/6i63/6i63_ligand.sdf,
136,data/PDBBind_processed/6ten/6ten_protein_processed.pdb,data/PDBBind_processed/6ten/6ten_ligand.sdf,
137,data/PDBBind_processed/6uii/6uii_protein_processed.pdb,data/PDBBind_processed/6uii/6uii_ligand.sdf,
138,data/PDBBind_processed/6qlr/6qlr_protein_processed.pdb,data/PDBBind_processed/6qlr/6qlr_ligand.sdf,
139,data/PDBBind_processed/6sen/6sen_protein_processed.pdb,data/PDBBind_processed/6sen/6sen_ligand.mol2,
140,data/PDBBind_processed/6oxv/6oxv_protein_processed.pdb,data/PDBBind_processed/6oxv/6oxv_ligand.sdf,
141,data/PDBBind_processed/6g2b/6g2b_protein_processed.pdb,data/PDBBind_processed/6g2b/6g2b_ligand.sdf,
142,data/PDBBind_processed/5zr3/5zr3_protein_processed.pdb,data/PDBBind_processed/5zr3/5zr3_ligand.sdf,
143,data/PDBBind_processed/6kjf/6kjf_protein_processed.pdb,data/PDBBind_processed/6kjf/6kjf_ligand.sdf,
144,data/PDBBind_processed/6qr9/6qr9_protein_processed.pdb,data/PDBBind_processed/6qr9/6qr9_ligand.mol2,
145,data/PDBBind_processed/6g9f/6g9f_protein_processed.pdb,data/PDBBind_processed/6g9f/6g9f_ligand.sdf,
146,data/PDBBind_processed/6e6v/6e6v_protein_processed.pdb,data/PDBBind_processed/6e6v/6e6v_ligand.sdf,
147,data/PDBBind_processed/5zk9/5zk9_protein_processed.pdb,data/PDBBind_processed/5zk9/5zk9_ligand.sdf,
148,data/PDBBind_processed/6pnn/6pnn_protein_processed.pdb,data/PDBBind_processed/6pnn/6pnn_ligand.sdf,
149,data/PDBBind_processed/6nri/6nri_protein_processed.pdb,data/PDBBind_processed/6nri/6nri_ligand.sdf,
150,data/PDBBind_processed/6uwv/6uwv_protein_processed.pdb,data/PDBBind_processed/6uwv/6uwv_ligand.sdf,
151,data/PDBBind_processed/6ooz/6ooz_protein_processed.pdb,data/PDBBind_processed/6ooz/6ooz_ligand.sdf,
152,data/PDBBind_processed/6npi/6npi_protein_processed.pdb,data/PDBBind_processed/6npi/6npi_ligand.sdf,
153,data/PDBBind_processed/6oip/6oip_protein_processed.pdb,data/PDBBind_processed/6oip/6oip_ligand.sdf,
154,data/PDBBind_processed/6miv/6miv_protein_processed.pdb,data/PDBBind_processed/6miv/6miv_ligand.sdf,
155,data/PDBBind_processed/6s57/6s57_protein_processed.pdb,data/PDBBind_processed/6s57/6s57_ligand.sdf,
156,data/PDBBind_processed/6p8x/6p8x_protein_processed.pdb,data/PDBBind_processed/6p8x/6p8x_ligand.sdf,
157,data/PDBBind_processed/6hoq/6hoq_protein_processed.pdb,data/PDBBind_processed/6hoq/6hoq_ligand.sdf,
158,data/PDBBind_processed/6qts/6qts_protein_processed.pdb,data/PDBBind_processed/6qts/6qts_ligand.sdf,
159,data/PDBBind_processed/6ggd/6ggd_protein_processed.pdb,data/PDBBind_processed/6ggd/6ggd_ligand.sdf,
160,data/PDBBind_processed/6pnm/6pnm_protein_processed.pdb,data/PDBBind_processed/6pnm/6pnm_ligand.sdf,
161,data/PDBBind_processed/6oy2/6oy2_protein_processed.pdb,data/PDBBind_processed/6oy2/6oy2_ligand.sdf,
162,data/PDBBind_processed/6oi8/6oi8_protein_processed.pdb,data/PDBBind_processed/6oi8/6oi8_ligand.sdf,
163,data/PDBBind_processed/6mhd/6mhd_protein_processed.pdb,data/PDBBind_processed/6mhd/6mhd_ligand.sdf,
164,data/PDBBind_processed/6agt/6agt_protein_processed.pdb,data/PDBBind_processed/6agt/6agt_ligand.sdf,
165,data/PDBBind_processed/6i5p/6i5p_protein_processed.pdb,data/PDBBind_processed/6i5p/6i5p_ligand.sdf,
166,data/PDBBind_processed/6hhr/6hhr_protein_processed.pdb,data/PDBBind_processed/6hhr/6hhr_ligand.sdf,
167,data/PDBBind_processed/6p8z/6p8z_protein_processed.pdb,data/PDBBind_processed/6p8z/6p8z_ligand.sdf,
168,data/PDBBind_processed/6c85/6c85_protein_processed.pdb,data/PDBBind_processed/6c85/6c85_ligand.sdf,
169,data/PDBBind_processed/6g5u/6g5u_protein_processed.pdb,data/PDBBind_processed/6g5u/6g5u_ligand.sdf,
170,data/PDBBind_processed/6j06/6j06_protein_processed.pdb,data/PDBBind_processed/6j06/6j06_ligand.sdf,
171,data/PDBBind_processed/6qsz/6qsz_protein_processed.pdb,data/PDBBind_processed/6qsz/6qsz_ligand.sdf,
172,data/PDBBind_processed/6jbb/6jbb_protein_processed.pdb,data/PDBBind_processed/6jbb/6jbb_ligand.sdf,
173,data/PDBBind_processed/6hhp/6hhp_protein_processed.pdb,data/PDBBind_processed/6hhp/6hhp_ligand.sdf,
174,data/PDBBind_processed/6np5/6np5_protein_processed.pdb,data/PDBBind_processed/6np5/6np5_ligand.sdf,
175,data/PDBBind_processed/6nlj/6nlj_protein_processed.pdb,data/PDBBind_processed/6nlj/6nlj_ligand.sdf,
176,data/PDBBind_processed/6qlp/6qlp_protein_processed.pdb,data/PDBBind_processed/6qlp/6qlp_ligand.sdf,
177,data/PDBBind_processed/6n94/6n94_protein_processed.pdb,data/PDBBind_processed/6n94/6n94_ligand.sdf,
178,data/PDBBind_processed/6e13/6e13_protein_processed.pdb,data/PDBBind_processed/6e13/6e13_ligand.sdf,
179,data/PDBBind_processed/6qls/6qls_protein_processed.pdb,data/PDBBind_processed/6qls/6qls_ligand.sdf,
180,data/PDBBind_processed/6uil/6uil_protein_processed.pdb,data/PDBBind_processed/6uil/6uil_ligand.sdf,
181,data/PDBBind_processed/6st3/6st3_protein_processed.pdb,data/PDBBind_processed/6st3/6st3_ligand.sdf,
182,data/PDBBind_processed/6n92/6n92_protein_processed.pdb,data/PDBBind_processed/6n92/6n92_ligand.sdf,
183,data/PDBBind_processed/6s56/6s56_protein_processed.pdb,data/PDBBind_processed/6s56/6s56_ligand.sdf,
184,data/PDBBind_processed/6hzd/6hzd_protein_processed.pdb,data/PDBBind_processed/6hzd/6hzd_ligand.sdf,
185,data/PDBBind_processed/6uhv/6uhv_protein_processed.pdb,data/PDBBind_processed/6uhv/6uhv_ligand.sdf,
186,data/PDBBind_processed/6k05/6k05_protein_processed.pdb,data/PDBBind_processed/6k05/6k05_ligand.sdf,
187,data/PDBBind_processed/6q36/6q36_protein_processed.pdb,data/PDBBind_processed/6q36/6q36_ligand.mol2,
188,data/PDBBind_processed/6ic0/6ic0_protein_processed.pdb,data/PDBBind_processed/6ic0/6ic0_ligand.sdf,
189,data/PDBBind_processed/6hhi/6hhi_protein_processed.pdb,data/PDBBind_processed/6hhi/6hhi_ligand.sdf,
190,data/PDBBind_processed/6e3m/6e3m_protein_processed.pdb,data/PDBBind_processed/6e3m/6e3m_ligand.sdf,
191,data/PDBBind_processed/6qtx/6qtx_protein_processed.pdb,data/PDBBind_processed/6qtx/6qtx_ligand.sdf,
192,data/PDBBind_processed/6jse/6jse_protein_processed.pdb,data/PDBBind_processed/6jse/6jse_ligand.sdf,
193,data/PDBBind_processed/5zjy/5zjy_protein_processed.pdb,data/PDBBind_processed/5zjy/5zjy_ligand.sdf,
194,data/PDBBind_processed/6o3y/6o3y_protein_processed.pdb,data/PDBBind_processed/6o3y/6o3y_ligand.sdf,
195,data/PDBBind_processed/6rpg/6rpg_protein_processed.pdb,data/PDBBind_processed/6rpg/6rpg_ligand.sdf,
196,data/PDBBind_processed/6rr0/6rr0_protein_processed.pdb,data/PDBBind_processed/6rr0/6rr0_ligand.sdf,
197,data/PDBBind_processed/6gzy/6gzy_protein_processed.pdb,data/PDBBind_processed/6gzy/6gzy_ligand.sdf,
198,data/PDBBind_processed/6qlt/6qlt_protein_processed.pdb,data/PDBBind_processed/6qlt/6qlt_ligand.sdf,
199,data/PDBBind_processed/6ufo/6ufo_protein_processed.pdb,data/PDBBind_processed/6ufo/6ufo_ligand.sdf,
200,data/PDBBind_processed/6o0h/6o0h_protein_processed.pdb,data/PDBBind_processed/6o0h/6o0h_ligand.sdf,
201,data/PDBBind_processed/6o3x/6o3x_protein_processed.pdb,data/PDBBind_processed/6o3x/6o3x_ligand.sdf,
202,data/PDBBind_processed/5zjz/5zjz_protein_processed.pdb,data/PDBBind_processed/5zjz/5zjz_ligand.mol2,
203,data/PDBBind_processed/6i8t/6i8t_protein_processed.pdb,data/PDBBind_processed/6i8t/6i8t_ligand.sdf,
204,data/PDBBind_processed/6ooy/6ooy_protein_processed.pdb,data/PDBBind_processed/6ooy/6ooy_ligand.sdf,
205,data/PDBBind_processed/6oiq/6oiq_protein_processed.pdb,data/PDBBind_processed/6oiq/6oiq_ligand.sdf,
206,data/PDBBind_processed/6od6/6od6_protein_processed.pdb,data/PDBBind_processed/6od6/6od6_ligand.sdf,
207,data/PDBBind_processed/6nrh/6nrh_protein_processed.pdb,data/PDBBind_processed/6nrh/6nrh_ligand.sdf,
208,data/PDBBind_processed/6qra/6qra_protein_processed.pdb,data/PDBBind_processed/6qra/6qra_ligand.mol2,
209,data/PDBBind_processed/6hhh/6hhh_protein_processed.pdb,data/PDBBind_processed/6hhh/6hhh_ligand.sdf,
210,data/PDBBind_processed/6m7h/6m7h_protein_processed.pdb,data/PDBBind_processed/6m7h/6m7h_ligand.sdf,
211,data/PDBBind_processed/6ufn/6ufn_protein_processed.pdb,data/PDBBind_processed/6ufn/6ufn_ligand.sdf,
212,data/PDBBind_processed/6qr0/6qr0_protein_processed.pdb,data/PDBBind_processed/6qr0/6qr0_ligand.mol2,
213,data/PDBBind_processed/6o5u/6o5u_protein_processed.pdb,data/PDBBind_processed/6o5u/6o5u_ligand.sdf,
214,data/PDBBind_processed/6h14/6h14_protein_processed.pdb,data/PDBBind_processed/6h14/6h14_ligand.sdf,
215,data/PDBBind_processed/6jwa/6jwa_protein_processed.pdb,data/PDBBind_processed/6jwa/6jwa_ligand.sdf,
216,data/PDBBind_processed/6ny0/6ny0_protein_processed.pdb,data/PDBBind_processed/6ny0/6ny0_ligand.sdf,
217,data/PDBBind_processed/6jan/6jan_protein_processed.pdb,data/PDBBind_processed/6jan/6jan_ligand.sdf,
218,data/PDBBind_processed/6ftf/6ftf_protein_processed.pdb,data/PDBBind_processed/6ftf/6ftf_ligand.sdf,
219,data/PDBBind_processed/6oxw/6oxw_protein_processed.pdb,data/PDBBind_processed/6oxw/6oxw_ligand.sdf,
220,data/PDBBind_processed/6jon/6jon_protein_processed.pdb,data/PDBBind_processed/6jon/6jon_ligand.sdf,
221,data/PDBBind_processed/6cf7/6cf7_protein_processed.pdb,data/PDBBind_processed/6cf7/6cf7_ligand.sdf,
222,data/PDBBind_processed/6rtn/6rtn_protein_processed.pdb,data/PDBBind_processed/6rtn/6rtn_ligand.mol2,
223,data/PDBBind_processed/6jsz/6jsz_protein_processed.pdb,data/PDBBind_processed/6jsz/6jsz_ligand.sdf,
224,data/PDBBind_processed/6o9c/6o9c_protein_processed.pdb,data/PDBBind_processed/6o9c/6o9c_ligand.sdf,
225,data/PDBBind_processed/6mo8/6mo8_protein_processed.pdb,data/PDBBind_processed/6mo8/6mo8_ligand.sdf,
226,data/PDBBind_processed/6qln/6qln_protein_processed.pdb,data/PDBBind_processed/6qln/6qln_ligand.sdf,
227,data/PDBBind_processed/6qqu/6qqu_protein_processed.pdb,data/PDBBind_processed/6qqu/6qqu_ligand.mol2,
228,data/PDBBind_processed/6i66/6i66_protein_processed.pdb,data/PDBBind_processed/6i66/6i66_ligand.sdf,
229,data/PDBBind_processed/6mja/6mja_protein_processed.pdb,data/PDBBind_processed/6mja/6mja_ligand.sdf,
230,data/PDBBind_processed/6gwe/6gwe_protein_processed.pdb,data/PDBBind_processed/6gwe/6gwe_ligand.mol2,
231,data/PDBBind_processed/6d3z/6d3z_protein_processed.pdb,data/PDBBind_processed/6d3z/6d3z_ligand.sdf,
232,data/PDBBind_processed/6oxr/6oxr_protein_processed.pdb,data/PDBBind_processed/6oxr/6oxr_ligand.sdf,
233,data/PDBBind_processed/6r4k/6r4k_protein_processed.pdb,data/PDBBind_processed/6r4k/6r4k_ligand.sdf,
234,data/PDBBind_processed/6hle/6hle_protein_processed.pdb,data/PDBBind_processed/6hle/6hle_ligand.sdf,
235,data/PDBBind_processed/6h9v/6h9v_protein_processed.pdb,data/PDBBind_processed/6h9v/6h9v_ligand.sdf,
236,data/PDBBind_processed/6hou/6hou_protein_processed.pdb,data/PDBBind_processed/6hou/6hou_ligand.sdf,
237,data/PDBBind_processed/6nv9/6nv9_protein_processed.pdb,data/PDBBind_processed/6nv9/6nv9_ligand.sdf,
238,data/PDBBind_processed/6py0/6py0_protein_processed.pdb,data/PDBBind_processed/6py0/6py0_ligand.sdf,
239,data/PDBBind_processed/6qlq/6qlq_protein_processed.pdb,data/PDBBind_processed/6qlq/6qlq_ligand.sdf,
240,data/PDBBind_processed/6nv7/6nv7_protein_processed.pdb,data/PDBBind_processed/6nv7/6nv7_ligand.sdf,
241,data/PDBBind_processed/6n4b/6n4b_protein_processed.pdb,data/PDBBind_processed/6n4b/6n4b_ligand.sdf,
242,data/PDBBind_processed/6jaq/6jaq_protein_processed.pdb,data/PDBBind_processed/6jaq/6jaq_ligand.sdf,
243,data/PDBBind_processed/6i8m/6i8m_protein_processed.pdb,data/PDBBind_processed/6i8m/6i8m_ligand.sdf,
244,data/PDBBind_processed/6dz0/6dz0_protein_processed.pdb,data/PDBBind_processed/6dz0/6dz0_ligand.mol2,
245,data/PDBBind_processed/6oxs/6oxs_protein_processed.pdb,data/PDBBind_processed/6oxs/6oxs_ligand.sdf,
246,data/PDBBind_processed/6k2n/6k2n_protein_processed.pdb,data/PDBBind_processed/6k2n/6k2n_ligand.sdf,
247,data/PDBBind_processed/6cjj/6cjj_protein_processed.pdb,data/PDBBind_processed/6cjj/6cjj_ligand.sdf,
248,data/PDBBind_processed/6ffg/6ffg_protein_processed.pdb,data/PDBBind_processed/6ffg/6ffg_ligand.sdf,
249,data/PDBBind_processed/6a73/6a73_protein_processed.pdb,data/PDBBind_processed/6a73/6a73_ligand.sdf,
250,data/PDBBind_processed/6qqt/6qqt_protein_processed.pdb,data/PDBBind_processed/6qqt/6qqt_ligand.mol2,
251,data/PDBBind_processed/6a1c/6a1c_protein_processed.pdb,data/PDBBind_processed/6a1c/6a1c_ligand.sdf,
252,data/PDBBind_processed/6oxu/6oxu_protein_processed.pdb,data/PDBBind_processed/6oxu/6oxu_ligand.sdf,
253,data/PDBBind_processed/6qre/6qre_protein_processed.pdb,data/PDBBind_processed/6qre/6qre_ligand.mol2,
254,data/PDBBind_processed/6qtw/6qtw_protein_processed.pdb,data/PDBBind_processed/6qtw/6qtw_ligand.sdf,
255,data/PDBBind_processed/6np4/6np4_protein_processed.pdb,data/PDBBind_processed/6np4/6np4_ligand.sdf,
256,data/PDBBind_processed/6hv2/6hv2_protein_processed.pdb,data/PDBBind_processed/6hv2/6hv2_ligand.sdf,
257,data/PDBBind_processed/6n55/6n55_protein_processed.pdb,data/PDBBind_processed/6n55/6n55_ligand.sdf,
258,data/PDBBind_processed/6e3o/6e3o_protein_processed.pdb,data/PDBBind_processed/6e3o/6e3o_ligand.sdf,
259,data/PDBBind_processed/6kjd/6kjd_protein_processed.pdb,data/PDBBind_processed/6kjd/6kjd_ligand.sdf,
260,data/PDBBind_processed/6sfc/6sfc_protein_processed.pdb,data/PDBBind_processed/6sfc/6sfc_ligand.sdf,
261,data/PDBBind_processed/6qi7/6qi7_protein_processed.pdb,data/PDBBind_processed/6qi7/6qi7_ligand.sdf,
262,data/PDBBind_processed/6hzc/6hzc_protein_processed.pdb,data/PDBBind_processed/6hzc/6hzc_ligand.sdf,
263,data/PDBBind_processed/6k04/6k04_protein_processed.pdb,data/PDBBind_processed/6k04/6k04_ligand.sdf,
264,data/PDBBind_processed/6op0/6op0_protein_processed.pdb,data/PDBBind_processed/6op0/6op0_ligand.sdf,
265,data/PDBBind_processed/6q38/6q38_protein_processed.pdb,data/PDBBind_processed/6q38/6q38_ligand.mol2,
266,data/PDBBind_processed/6n8x/6n8x_protein_processed.pdb,data/PDBBind_processed/6n8x/6n8x_ligand.sdf,
267,data/PDBBind_processed/6np3/6np3_protein_processed.pdb,data/PDBBind_processed/6np3/6np3_ligand.sdf,
268,data/PDBBind_processed/6uvv/6uvv_protein_processed.pdb,data/PDBBind_processed/6uvv/6uvv_ligand.sdf,
269,data/PDBBind_processed/6pgo/6pgo_protein_processed.pdb,data/PDBBind_processed/6pgo/6pgo_ligand.sdf,
270,data/PDBBind_processed/6jbe/6jbe_protein_processed.pdb,data/PDBBind_processed/6jbe/6jbe_ligand.sdf,
271,data/PDBBind_processed/6i75/6i75_protein_processed.pdb,data/PDBBind_processed/6i75/6i75_ligand.sdf,
272,data/PDBBind_processed/6qqq/6qqq_protein_processed.pdb,data/PDBBind_processed/6qqq/6qqq_ligand.mol2,
273,data/PDBBind_processed/6i62/6i62_protein_processed.pdb,data/PDBBind_processed/6i62/6i62_ligand.sdf,
274,data/PDBBind_processed/6j9y/6j9y_protein_processed.pdb,data/PDBBind_processed/6j9y/6j9y_ligand.sdf,
275,data/PDBBind_processed/6g29/6g29_protein_processed.pdb,data/PDBBind_processed/6g29/6g29_ligand.sdf,
276,data/PDBBind_processed/6h7d/6h7d_protein_processed.pdb,data/PDBBind_processed/6h7d/6h7d_ligand.sdf,
277,data/PDBBind_processed/6mo9/6mo9_protein_processed.pdb,data/PDBBind_processed/6mo9/6mo9_ligand.sdf,
278,data/PDBBind_processed/6jao/6jao_protein_processed.pdb,data/PDBBind_processed/6jao/6jao_ligand.sdf,
279,data/PDBBind_processed/6jmf/6jmf_protein_processed.pdb,data/PDBBind_processed/6jmf/6jmf_ligand.sdf,
280,data/PDBBind_processed/6hmy/6hmy_protein_processed.pdb,data/PDBBind_processed/6hmy/6hmy_ligand.sdf,
281,data/PDBBind_processed/6qfe/6qfe_protein_processed.pdb,data/PDBBind_processed/6qfe/6qfe_ligand.mol2,
282,data/PDBBind_processed/5zml/5zml_protein_processed.pdb,data/PDBBind_processed/5zml/5zml_ligand.sdf,
283,data/PDBBind_processed/6i65/6i65_protein_processed.pdb,data/PDBBind_processed/6i65/6i65_ligand.sdf,
284,data/PDBBind_processed/6e7m/6e7m_protein_processed.pdb,data/PDBBind_processed/6e7m/6e7m_ligand.sdf,
285,data/PDBBind_processed/6i61/6i61_protein_processed.pdb,data/PDBBind_processed/6i61/6i61_ligand.sdf,
286,data/PDBBind_processed/6rz6/6rz6_protein_processed.pdb,data/PDBBind_processed/6rz6/6rz6_ligand.sdf,
287,data/PDBBind_processed/6qtm/6qtm_protein_processed.pdb,data/PDBBind_processed/6qtm/6qtm_ligand.sdf,
288,data/PDBBind_processed/6qlo/6qlo_protein_processed.pdb,data/PDBBind_processed/6qlo/6qlo_ligand.sdf,
289,data/PDBBind_processed/6oie/6oie_protein_processed.pdb,data/PDBBind_processed/6oie/6oie_ligand.sdf,
290,data/PDBBind_processed/6miy/6miy_protein_processed.pdb,data/PDBBind_processed/6miy/6miy_ligand.sdf,
291,data/PDBBind_processed/6nrf/6nrf_protein_processed.pdb,data/PDBBind_processed/6nrf/6nrf_ligand.mol2,
292,data/PDBBind_processed/6gj5/6gj5_protein_processed.pdb,data/PDBBind_processed/6gj5/6gj5_ligand.mol2,
293,data/PDBBind_processed/6jad/6jad_protein_processed.pdb,data/PDBBind_processed/6jad/6jad_ligand.sdf,
294,data/PDBBind_processed/6mj4/6mj4_protein_processed.pdb,data/PDBBind_processed/6mj4/6mj4_ligand.sdf,
295,data/PDBBind_processed/6h12/6h12_protein_processed.pdb,data/PDBBind_processed/6h12/6h12_ligand.sdf,
296,data/PDBBind_processed/6d3y/6d3y_protein_processed.pdb,data/PDBBind_processed/6d3y/6d3y_ligand.sdf,
297,data/PDBBind_processed/6qr2/6qr2_protein_processed.pdb,data/PDBBind_processed/6qr2/6qr2_ligand.mol2,
298,data/PDBBind_processed/6qxa/6qxa_protein_processed.pdb,data/PDBBind_processed/6qxa/6qxa_ligand.mol2,
299,data/PDBBind_processed/6o9b/6o9b_protein_processed.pdb,data/PDBBind_processed/6o9b/6o9b_ligand.sdf,
300,data/PDBBind_processed/6ckl/6ckl_protein_processed.pdb,data/PDBBind_processed/6ckl/6ckl_ligand.sdf,
301,data/PDBBind_processed/6oir/6oir_protein_processed.pdb,data/PDBBind_processed/6oir/6oir_ligand.sdf,
302,data/PDBBind_processed/6d40/6d40_protein_processed.pdb,data/PDBBind_processed/6d40/6d40_ligand.sdf,
303,data/PDBBind_processed/6e6j/6e6j_protein_processed.pdb,data/PDBBind_processed/6e6j/6e6j_ligand.mol2,
304,data/PDBBind_processed/6i7a/6i7a_protein_processed.pdb,data/PDBBind_processed/6i7a/6i7a_ligand.sdf,
305,data/PDBBind_processed/6g25/6g25_protein_processed.pdb,data/PDBBind_processed/6g25/6g25_ligand.mol2,
306,data/PDBBind_processed/6oin/6oin_protein_processed.pdb,data/PDBBind_processed/6oin/6oin_ligand.sdf,
307,data/PDBBind_processed/6jam/6jam_protein_processed.pdb,data/PDBBind_processed/6jam/6jam_ligand.sdf,
308,data/PDBBind_processed/6oxz/6oxz_protein_processed.pdb,data/PDBBind_processed/6oxz/6oxz_ligand.sdf,
309,data/PDBBind_processed/6hop/6hop_protein_processed.pdb,data/PDBBind_processed/6hop/6hop_ligand.sdf,
310,data/PDBBind_processed/6rot/6rot_protein_processed.pdb,data/PDBBind_processed/6rot/6rot_ligand.sdf,
311,data/PDBBind_processed/6uhu/6uhu_protein_processed.pdb,data/PDBBind_processed/6uhu/6uhu_ligand.mol2,
312,data/PDBBind_processed/6mji/6mji_protein_processed.pdb,data/PDBBind_processed/6mji/6mji_ligand.sdf,
313,data/PDBBind_processed/6nrj/6nrj_protein_processed.pdb,data/PDBBind_processed/6nrj/6nrj_ligand.mol2,
314,data/PDBBind_processed/6nt2/6nt2_protein_processed.pdb,data/PDBBind_processed/6nt2/6nt2_ligand.mol2,
315,data/PDBBind_processed/6op9/6op9_protein_processed.pdb,data/PDBBind_processed/6op9/6op9_ligand.sdf,
316,data/PDBBind_processed/6pno/6pno_protein_processed.pdb,data/PDBBind_processed/6pno/6pno_ligand.sdf,
317,data/PDBBind_processed/6e4v/6e4v_protein_processed.pdb,data/PDBBind_processed/6e4v/6e4v_ligand.sdf,
318,data/PDBBind_processed/6k1s/6k1s_protein_processed.pdb,data/PDBBind_processed/6k1s/6k1s_ligand.sdf,
319,data/PDBBind_processed/6a87/6a87_protein_processed.pdb,data/PDBBind_processed/6a87/6a87_ligand.sdf,
320,data/PDBBind_processed/6oim/6oim_protein_processed.pdb,data/PDBBind_processed/6oim/6oim_ligand.sdf,
321,data/PDBBind_processed/6cjp/6cjp_protein_processed.pdb,data/PDBBind_processed/6cjp/6cjp_ligand.sdf,
322,data/PDBBind_processed/6pyb/6pyb_protein_processed.pdb,data/PDBBind_processed/6pyb/6pyb_ligand.sdf,
323,data/PDBBind_processed/6h13/6h13_protein_processed.pdb,data/PDBBind_processed/6h13/6h13_ligand.sdf,
324,data/PDBBind_processed/6qrf/6qrf_protein_processed.pdb,data/PDBBind_processed/6qrf/6qrf_ligand.mol2,
325,data/PDBBind_processed/6mhc/6mhc_protein_processed.pdb,data/PDBBind_processed/6mhc/6mhc_ligand.sdf,
326,data/PDBBind_processed/6j9w/6j9w_protein_processed.pdb,data/PDBBind_processed/6j9w/6j9w_ligand.sdf,
327,data/PDBBind_processed/6nrg/6nrg_protein_processed.pdb,data/PDBBind_processed/6nrg/6nrg_ligand.mol2,
328,data/PDBBind_processed/6fff/6fff_protein_processed.pdb,data/PDBBind_processed/6fff/6fff_ligand.sdf,
329,data/PDBBind_processed/6n93/6n93_protein_processed.pdb,data/PDBBind_processed/6n93/6n93_ligand.sdf,
330,data/PDBBind_processed/6jut/6jut_protein_processed.pdb,data/PDBBind_processed/6jut/6jut_ligand.mol2,
331,data/PDBBind_processed/6g2e/6g2e_protein_processed.pdb,data/PDBBind_processed/6g2e/6g2e_ligand.sdf,
332,data/PDBBind_processed/6nd3/6nd3_protein_processed.pdb,data/PDBBind_processed/6nd3/6nd3_ligand.sdf,
333,data/PDBBind_processed/6os6/6os6_protein_processed.pdb,data/PDBBind_processed/6os6/6os6_ligand.mol2,
334,data/PDBBind_processed/6dql/6dql_protein_processed.pdb,data/PDBBind_processed/6dql/6dql_ligand.mol2,
335,data/PDBBind_processed/6inz/6inz_protein_processed.pdb,data/PDBBind_processed/6inz/6inz_ligand.sdf,
336,data/PDBBind_processed/6i67/6i67_protein_processed.pdb,data/PDBBind_processed/6i67/6i67_ligand.sdf,
337,data/PDBBind_processed/6quw/6quw_protein_processed.pdb,data/PDBBind_processed/6quw/6quw_ligand.sdf,
338,data/PDBBind_processed/6qwi/6qwi_protein_processed.pdb,data/PDBBind_processed/6qwi/6qwi_ligand.sdf,
339,data/PDBBind_processed/6npm/6npm_protein_processed.pdb,data/PDBBind_processed/6npm/6npm_ligand.sdf,
340,data/PDBBind_processed/6i64/6i64_protein_processed.pdb,data/PDBBind_processed/6i64/6i64_ligand.sdf,
341,data/PDBBind_processed/6e3n/6e3n_protein_processed.pdb,data/PDBBind_processed/6e3n/6e3n_ligand.sdf,
342,data/PDBBind_processed/6qrg/6qrg_protein_processed.pdb,data/PDBBind_processed/6qrg/6qrg_ligand.mol2,
343,data/PDBBind_processed/6nxz/6nxz_protein_processed.pdb,data/PDBBind_processed/6nxz/6nxz_ligand.sdf,
344,data/PDBBind_processed/6iby/6iby_protein_processed.pdb,data/PDBBind_processed/6iby/6iby_ligand.sdf,
345,data/PDBBind_processed/6gj7/6gj7_protein_processed.pdb,data/PDBBind_processed/6gj7/6gj7_ligand.mol2,
346,data/PDBBind_processed/6qr3/6qr3_protein_processed.pdb,data/PDBBind_processed/6qr3/6qr3_ligand.mol2,
347,data/PDBBind_processed/6qr1/6qr1_protein_processed.pdb,data/PDBBind_processed/6qr1/6qr1_ligand.mol2,
348,data/PDBBind_processed/6s9x/6s9x_protein_processed.pdb,data/PDBBind_processed/6s9x/6s9x_ligand.sdf,
349,data/PDBBind_processed/6q4q/6q4q_protein_processed.pdb,data/PDBBind_processed/6q4q/6q4q_ligand.mol2,
350,data/PDBBind_processed/6hbn/6hbn_protein_processed.pdb,data/PDBBind_processed/6hbn/6hbn_ligand.sdf,
351,data/PDBBind_processed/6nw3/6nw3_protein_processed.pdb,data/PDBBind_processed/6nw3/6nw3_ligand.sdf,
352,data/PDBBind_processed/6tel/6tel_protein_processed.pdb,data/PDBBind_processed/6tel/6tel_ligand.sdf,
353,data/PDBBind_processed/6p8y/6p8y_protein_processed.pdb,data/PDBBind_processed/6p8y/6p8y_ligand.sdf,
354,data/PDBBind_processed/6d5w/6d5w_protein_processed.pdb,data/PDBBind_processed/6d5w/6d5w_ligand.sdf,
355,data/PDBBind_processed/6t6a/6t6a_protein_processed.pdb,data/PDBBind_processed/6t6a/6t6a_ligand.mol2,
356,data/PDBBind_processed/6o5g/6o5g_protein_processed.pdb,data/PDBBind_processed/6o5g/6o5g_ligand.mol2,
357,data/PDBBind_processed/6r7d/6r7d_protein_processed.pdb,data/PDBBind_processed/6r7d/6r7d_ligand.sdf,
358,data/PDBBind_processed/6pya/6pya_protein_processed.pdb,data/PDBBind_processed/6pya/6pya_ligand.mol2,
359,data/PDBBind_processed/6ffe/6ffe_protein_processed.pdb,data/PDBBind_processed/6ffe/6ffe_ligand.sdf,
360,data/PDBBind_processed/6d3x/6d3x_protein_processed.pdb,data/PDBBind_processed/6d3x/6d3x_ligand.sdf,
361,data/PDBBind_processed/6gj8/6gj8_protein_processed.pdb,data/PDBBind_processed/6gj8/6gj8_ligand.mol2,
362,data/PDBBind_processed/6mo2/6mo2_protein_processed.pdb,data/PDBBind_processed/6mo2/6mo2_ligand.mol2,
,protein_path,ligand
0,data/PDBBind_processed/6qqw/6qqw_protein_processed.pdb,data/PDBBind_processed/6qqw/6qqw_ligand.mol2
1,data/PDBBind_processed/6d08/6d08_protein_processed.pdb,data/PDBBind_processed/6d08/6d08_ligand.sdf
2,data/PDBBind_processed/6jap/6jap_protein_processed.pdb,data/PDBBind_processed/6jap/6jap_ligand.sdf
3,data/PDBBind_processed/6np2/6np2_protein_processed.pdb,data/PDBBind_processed/6np2/6np2_ligand.sdf
4,data/PDBBind_processed/6uvp/6uvp_protein_processed.pdb,data/PDBBind_processed/6uvp/6uvp_ligand.sdf
5,data/PDBBind_processed/6oxq/6oxq_protein_processed.pdb,data/PDBBind_processed/6oxq/6oxq_ligand.sdf
6,data/PDBBind_processed/6jsn/6jsn_protein_processed.pdb,data/PDBBind_processed/6jsn/6jsn_ligand.sdf
7,data/PDBBind_processed/6hzb/6hzb_protein_processed.pdb,data/PDBBind_processed/6hzb/6hzb_ligand.sdf
8,data/PDBBind_processed/6qrc/6qrc_protein_processed.pdb,data/PDBBind_processed/6qrc/6qrc_ligand.mol2
9,data/PDBBind_processed/6oio/6oio_protein_processed.pdb,data/PDBBind_processed/6oio/6oio_ligand.sdf
10,data/PDBBind_processed/6jag/6jag_protein_processed.pdb,data/PDBBind_processed/6jag/6jag_ligand.sdf
11,data/PDBBind_processed/6moa/6moa_protein_processed.pdb,data/PDBBind_processed/6moa/6moa_ligand.mol2
12,data/PDBBind_processed/6hld/6hld_protein_processed.pdb,data/PDBBind_processed/6hld/6hld_ligand.sdf
13,data/PDBBind_processed/6i9a/6i9a_protein_processed.pdb,data/PDBBind_processed/6i9a/6i9a_ligand.sdf
14,data/PDBBind_processed/6e4c/6e4c_protein_processed.pdb,data/PDBBind_processed/6e4c/6e4c_ligand.sdf
15,data/PDBBind_processed/6g24/6g24_protein_processed.pdb,data/PDBBind_processed/6g24/6g24_ligand.sdf
16,data/PDBBind_processed/6jb4/6jb4_protein_processed.pdb,data/PDBBind_processed/6jb4/6jb4_ligand.sdf
17,data/PDBBind_processed/6s55/6s55_protein_processed.pdb,data/PDBBind_processed/6s55/6s55_ligand.sdf
18,data/PDBBind_processed/6seo/6seo_protein_processed.pdb,data/PDBBind_processed/6seo/6seo_ligand.sdf
19,data/PDBBind_processed/6dyz/6dyz_protein_processed.pdb,data/PDBBind_processed/6dyz/6dyz_ligand.mol2
20,data/PDBBind_processed/5zk5/5zk5_protein_processed.pdb,data/PDBBind_processed/5zk5/5zk5_ligand.sdf
21,data/PDBBind_processed/6jid/6jid_protein_processed.pdb,data/PDBBind_processed/6jid/6jid_ligand.sdf
22,data/PDBBind_processed/5ze6/5ze6_protein_processed.pdb,data/PDBBind_processed/5ze6/5ze6_ligand.sdf
23,data/PDBBind_processed/6qlu/6qlu_protein_processed.pdb,data/PDBBind_processed/6qlu/6qlu_ligand.sdf
24,data/PDBBind_processed/6a6k/6a6k_protein_processed.pdb,data/PDBBind_processed/6a6k/6a6k_ligand.sdf
25,data/PDBBind_processed/6qgf/6qgf_protein_processed.pdb,data/PDBBind_processed/6qgf/6qgf_ligand.sdf
26,data/PDBBind_processed/6e3z/6e3z_protein_processed.pdb,data/PDBBind_processed/6e3z/6e3z_ligand.sdf
27,data/PDBBind_processed/6te6/6te6_protein_processed.pdb,data/PDBBind_processed/6te6/6te6_ligand.sdf
28,data/PDBBind_processed/6pka/6pka_protein_processed.pdb,data/PDBBind_processed/6pka/6pka_ligand.sdf
29,data/PDBBind_processed/6g2o/6g2o_protein_processed.pdb,data/PDBBind_processed/6g2o/6g2o_ligand.sdf
30,data/PDBBind_processed/6jsf/6jsf_protein_processed.pdb,data/PDBBind_processed/6jsf/6jsf_ligand.sdf
31,data/PDBBind_processed/5zxk/5zxk_protein_processed.pdb,data/PDBBind_processed/5zxk/5zxk_ligand.sdf
32,data/PDBBind_processed/6qxd/6qxd_protein_processed.pdb,data/PDBBind_processed/6qxd/6qxd_ligand.sdf
33,data/PDBBind_processed/6n97/6n97_protein_processed.pdb,data/PDBBind_processed/6n97/6n97_ligand.sdf
34,data/PDBBind_processed/6jt3/6jt3_protein_processed.pdb,data/PDBBind_processed/6jt3/6jt3_ligand.sdf
35,data/PDBBind_processed/6qtr/6qtr_protein_processed.pdb,data/PDBBind_processed/6qtr/6qtr_ligand.sdf
36,data/PDBBind_processed/6oy1/6oy1_protein_processed.pdb,data/PDBBind_processed/6oy1/6oy1_ligand.sdf
37,data/PDBBind_processed/6n96/6n96_protein_processed.pdb,data/PDBBind_processed/6n96/6n96_ligand.sdf
38,data/PDBBind_processed/6qzh/6qzh_protein_processed.pdb,data/PDBBind_processed/6qzh/6qzh_ligand.sdf
39,data/PDBBind_processed/6qqz/6qqz_protein_processed.pdb,data/PDBBind_processed/6qqz/6qqz_ligand.mol2
40,data/PDBBind_processed/6qmt/6qmt_protein_processed.pdb,data/PDBBind_processed/6qmt/6qmt_ligand.sdf
41,data/PDBBind_processed/6ibx/6ibx_protein_processed.pdb,data/PDBBind_processed/6ibx/6ibx_ligand.sdf
42,data/PDBBind_processed/6hmt/6hmt_protein_processed.pdb,data/PDBBind_processed/6hmt/6hmt_ligand.sdf
43,data/PDBBind_processed/5zk7/5zk7_protein_processed.pdb,data/PDBBind_processed/5zk7/5zk7_ligand.sdf
44,data/PDBBind_processed/6k3l/6k3l_protein_processed.pdb,data/PDBBind_processed/6k3l/6k3l_ligand.sdf
45,data/PDBBind_processed/6cjs/6cjs_protein_processed.pdb,data/PDBBind_processed/6cjs/6cjs_ligand.sdf
46,data/PDBBind_processed/6n9l/6n9l_protein_processed.pdb,data/PDBBind_processed/6n9l/6n9l_ligand.sdf
47,data/PDBBind_processed/6ibz/6ibz_protein_processed.pdb,data/PDBBind_processed/6ibz/6ibz_ligand.sdf
48,data/PDBBind_processed/6ott/6ott_protein_processed.pdb,data/PDBBind_processed/6ott/6ott_ligand.sdf
49,data/PDBBind_processed/6gge/6gge_protein_processed.pdb,data/PDBBind_processed/6gge/6gge_ligand.sdf
50,data/PDBBind_processed/6hot/6hot_protein_processed.pdb,data/PDBBind_processed/6hot/6hot_ligand.sdf
51,data/PDBBind_processed/6e3p/6e3p_protein_processed.pdb,data/PDBBind_processed/6e3p/6e3p_ligand.mol2
52,data/PDBBind_processed/6md6/6md6_protein_processed.pdb,data/PDBBind_processed/6md6/6md6_ligand.sdf
53,data/PDBBind_processed/6hlb/6hlb_protein_processed.pdb,data/PDBBind_processed/6hlb/6hlb_ligand.sdf
54,data/PDBBind_processed/6fe5/6fe5_protein_processed.pdb,data/PDBBind_processed/6fe5/6fe5_ligand.sdf
55,data/PDBBind_processed/6uwp/6uwp_protein_processed.pdb,data/PDBBind_processed/6uwp/6uwp_ligand.sdf
56,data/PDBBind_processed/6npp/6npp_protein_processed.pdb,data/PDBBind_processed/6npp/6npp_ligand.sdf
57,data/PDBBind_processed/6g2f/6g2f_protein_processed.pdb,data/PDBBind_processed/6g2f/6g2f_ligand.sdf
58,data/PDBBind_processed/6mo7/6mo7_protein_processed.pdb,data/PDBBind_processed/6mo7/6mo7_ligand.sdf
59,data/PDBBind_processed/6bqd/6bqd_protein_processed.pdb,data/PDBBind_processed/6bqd/6bqd_ligand.mol2
60,data/PDBBind_processed/6nsv/6nsv_protein_processed.pdb,data/PDBBind_processed/6nsv/6nsv_ligand.mol2
61,data/PDBBind_processed/6i76/6i76_protein_processed.pdb,data/PDBBind_processed/6i76/6i76_ligand.sdf
62,data/PDBBind_processed/6n53/6n53_protein_processed.pdb,data/PDBBind_processed/6n53/6n53_ligand.sdf
63,data/PDBBind_processed/6g2c/6g2c_protein_processed.pdb,data/PDBBind_processed/6g2c/6g2c_ligand.sdf
64,data/PDBBind_processed/6eeb/6eeb_protein_processed.pdb,data/PDBBind_processed/6eeb/6eeb_ligand.mol2
65,data/PDBBind_processed/6n0m/6n0m_protein_processed.pdb,data/PDBBind_processed/6n0m/6n0m_ligand.sdf
66,data/PDBBind_processed/6uvy/6uvy_protein_processed.pdb,data/PDBBind_processed/6uvy/6uvy_ligand.sdf
67,data/PDBBind_processed/6ovz/6ovz_protein_processed.pdb,data/PDBBind_processed/6ovz/6ovz_ligand.sdf
68,data/PDBBind_processed/6olx/6olx_protein_processed.pdb,data/PDBBind_processed/6olx/6olx_ligand.sdf
69,data/PDBBind_processed/6v5l/6v5l_protein_processed.pdb,data/PDBBind_processed/6v5l/6v5l_ligand.mol2
70,data/PDBBind_processed/6hhg/6hhg_protein_processed.pdb,data/PDBBind_processed/6hhg/6hhg_ligand.sdf
71,data/PDBBind_processed/5zcu/5zcu_protein_processed.pdb,data/PDBBind_processed/5zcu/5zcu_ligand.sdf
72,data/PDBBind_processed/6dz2/6dz2_protein_processed.pdb,data/PDBBind_processed/6dz2/6dz2_ligand.mol2
73,data/PDBBind_processed/6mjq/6mjq_protein_processed.pdb,data/PDBBind_processed/6mjq/6mjq_ligand.sdf
74,data/PDBBind_processed/6efk/6efk_protein_processed.pdb,data/PDBBind_processed/6efk/6efk_ligand.sdf
75,data/PDBBind_processed/6s9w/6s9w_protein_processed.pdb,data/PDBBind_processed/6s9w/6s9w_ligand.sdf
76,data/PDBBind_processed/6gdy/6gdy_protein_processed.pdb,data/PDBBind_processed/6gdy/6gdy_ligand.sdf
77,data/PDBBind_processed/6kqi/6kqi_protein_processed.pdb,data/PDBBind_processed/6kqi/6kqi_ligand.sdf
78,data/PDBBind_processed/6ueg/6ueg_protein_processed.pdb,data/PDBBind_processed/6ueg/6ueg_ligand.sdf
79,data/PDBBind_processed/6oxt/6oxt_protein_processed.pdb,data/PDBBind_processed/6oxt/6oxt_ligand.sdf
80,data/PDBBind_processed/6oy0/6oy0_protein_processed.pdb,data/PDBBind_processed/6oy0/6oy0_ligand.sdf
81,data/PDBBind_processed/6qr7/6qr7_protein_processed.pdb,data/PDBBind_processed/6qr7/6qr7_ligand.mol2
82,data/PDBBind_processed/6i41/6i41_protein_processed.pdb,data/PDBBind_processed/6i41/6i41_ligand.sdf
83,data/PDBBind_processed/6cyg/6cyg_protein_processed.pdb,data/PDBBind_processed/6cyg/6cyg_ligand.sdf
84,data/PDBBind_processed/6qmr/6qmr_protein_processed.pdb,data/PDBBind_processed/6qmr/6qmr_ligand.sdf
85,data/PDBBind_processed/6g27/6g27_protein_processed.pdb,data/PDBBind_processed/6g27/6g27_ligand.sdf
86,data/PDBBind_processed/6ggb/6ggb_protein_processed.pdb,data/PDBBind_processed/6ggb/6ggb_ligand.sdf
87,data/PDBBind_processed/6g3c/6g3c_protein_processed.pdb,data/PDBBind_processed/6g3c/6g3c_ligand.sdf
88,data/PDBBind_processed/6n4e/6n4e_protein_processed.pdb,data/PDBBind_processed/6n4e/6n4e_ligand.sdf
89,data/PDBBind_processed/6fcj/6fcj_protein_processed.pdb,data/PDBBind_processed/6fcj/6fcj_ligand.sdf
90,data/PDBBind_processed/6quv/6quv_protein_processed.pdb,data/PDBBind_processed/6quv/6quv_ligand.sdf
91,data/PDBBind_processed/6iql/6iql_protein_processed.pdb,data/PDBBind_processed/6iql/6iql_ligand.mol2
92,data/PDBBind_processed/6i74/6i74_protein_processed.pdb,data/PDBBind_processed/6i74/6i74_ligand.sdf
93,data/PDBBind_processed/6qr4/6qr4_protein_processed.pdb,data/PDBBind_processed/6qr4/6qr4_ligand.mol2
94,data/PDBBind_processed/6rnu/6rnu_protein_processed.pdb,data/PDBBind_processed/6rnu/6rnu_ligand.sdf
95,data/PDBBind_processed/6jib/6jib_protein_processed.pdb,data/PDBBind_processed/6jib/6jib_ligand.sdf
96,data/PDBBind_processed/6izq/6izq_protein_processed.pdb,data/PDBBind_processed/6izq/6izq_ligand.sdf
97,data/PDBBind_processed/6qw8/6qw8_protein_processed.pdb,data/PDBBind_processed/6qw8/6qw8_ligand.sdf
98,data/PDBBind_processed/6qto/6qto_protein_processed.pdb,data/PDBBind_processed/6qto/6qto_ligand.sdf
99,data/PDBBind_processed/6qrd/6qrd_protein_processed.pdb,data/PDBBind_processed/6qrd/6qrd_ligand.mol2
100,data/PDBBind_processed/6hza/6hza_protein_processed.pdb,data/PDBBind_processed/6hza/6hza_ligand.sdf
101,data/PDBBind_processed/6e5s/6e5s_protein_processed.pdb,data/PDBBind_processed/6e5s/6e5s_ligand.sdf
102,data/PDBBind_processed/6dz3/6dz3_protein_processed.pdb,data/PDBBind_processed/6dz3/6dz3_ligand.mol2
103,data/PDBBind_processed/6e6w/6e6w_protein_processed.pdb,data/PDBBind_processed/6e6w/6e6w_ligand.mol2
104,data/PDBBind_processed/6cyh/6cyh_protein_processed.pdb,data/PDBBind_processed/6cyh/6cyh_ligand.sdf
105,data/PDBBind_processed/5zlf/5zlf_protein_processed.pdb,data/PDBBind_processed/5zlf/5zlf_ligand.sdf
106,data/PDBBind_processed/6om4/6om4_protein_processed.pdb,data/PDBBind_processed/6om4/6om4_ligand.sdf
107,data/PDBBind_processed/6gga/6gga_protein_processed.pdb,data/PDBBind_processed/6gga/6gga_ligand.sdf
108,data/PDBBind_processed/6pgp/6pgp_protein_processed.pdb,data/PDBBind_processed/6pgp/6pgp_ligand.sdf
109,data/PDBBind_processed/6qqv/6qqv_protein_processed.pdb,data/PDBBind_processed/6qqv/6qqv_ligand.mol2
110,data/PDBBind_processed/6qtq/6qtq_protein_processed.pdb,data/PDBBind_processed/6qtq/6qtq_ligand.sdf
111,data/PDBBind_processed/6gj6/6gj6_protein_processed.pdb,data/PDBBind_processed/6gj6/6gj6_ligand.mol2
112,data/PDBBind_processed/6os5/6os5_protein_processed.pdb,data/PDBBind_processed/6os5/6os5_ligand.mol2
113,data/PDBBind_processed/6s07/6s07_protein_processed.pdb,data/PDBBind_processed/6s07/6s07_ligand.sdf
114,data/PDBBind_processed/6i77/6i77_protein_processed.pdb,data/PDBBind_processed/6i77/6i77_ligand.sdf
115,data/PDBBind_processed/6hhj/6hhj_protein_processed.pdb,data/PDBBind_processed/6hhj/6hhj_ligand.sdf
116,data/PDBBind_processed/6ahs/6ahs_protein_processed.pdb,data/PDBBind_processed/6ahs/6ahs_ligand.sdf
117,data/PDBBind_processed/6oxx/6oxx_protein_processed.pdb,data/PDBBind_processed/6oxx/6oxx_ligand.sdf
118,data/PDBBind_processed/6mjj/6mjj_protein_processed.pdb,data/PDBBind_processed/6mjj/6mjj_ligand.sdf
119,data/PDBBind_processed/6hor/6hor_protein_processed.pdb,data/PDBBind_processed/6hor/6hor_ligand.sdf
120,data/PDBBind_processed/6jb0/6jb0_protein_processed.pdb,data/PDBBind_processed/6jb0/6jb0_ligand.sdf
121,data/PDBBind_processed/6i68/6i68_protein_processed.pdb,data/PDBBind_processed/6i68/6i68_ligand.sdf
122,data/PDBBind_processed/6pz4/6pz4_protein_processed.pdb,data/PDBBind_processed/6pz4/6pz4_ligand.sdf
123,data/PDBBind_processed/6mhb/6mhb_protein_processed.pdb,data/PDBBind_processed/6mhb/6mhb_ligand.sdf
124,data/PDBBind_processed/6uim/6uim_protein_processed.pdb,data/PDBBind_processed/6uim/6uim_ligand.sdf
125,data/PDBBind_processed/6jsg/6jsg_protein_processed.pdb,data/PDBBind_processed/6jsg/6jsg_ligand.sdf
126,data/PDBBind_processed/6i78/6i78_protein_processed.pdb,data/PDBBind_processed/6i78/6i78_ligand.sdf
127,data/PDBBind_processed/6oxy/6oxy_protein_processed.pdb,data/PDBBind_processed/6oxy/6oxy_ligand.sdf
128,data/PDBBind_processed/6gbw/6gbw_protein_processed.pdb,data/PDBBind_processed/6gbw/6gbw_ligand.sdf
129,data/PDBBind_processed/6mo0/6mo0_protein_processed.pdb,data/PDBBind_processed/6mo0/6mo0_ligand.sdf
130,data/PDBBind_processed/6ggf/6ggf_protein_processed.pdb,data/PDBBind_processed/6ggf/6ggf_ligand.sdf
131,data/PDBBind_processed/6qge/6qge_protein_processed.pdb,data/PDBBind_processed/6qge/6qge_ligand.sdf
132,data/PDBBind_processed/6cjr/6cjr_protein_processed.pdb,data/PDBBind_processed/6cjr/6cjr_ligand.sdf
133,data/PDBBind_processed/6oxp/6oxp_protein_processed.pdb,data/PDBBind_processed/6oxp/6oxp_ligand.sdf
134,data/PDBBind_processed/6d07/6d07_protein_processed.pdb,data/PDBBind_processed/6d07/6d07_ligand.sdf
135,data/PDBBind_processed/6i63/6i63_protein_processed.pdb,data/PDBBind_processed/6i63/6i63_ligand.sdf
136,data/PDBBind_processed/6ten/6ten_protein_processed.pdb,data/PDBBind_processed/6ten/6ten_ligand.sdf
137,data/PDBBind_processed/6uii/6uii_protein_processed.pdb,data/PDBBind_processed/6uii/6uii_ligand.sdf
138,data/PDBBind_processed/6qlr/6qlr_protein_processed.pdb,data/PDBBind_processed/6qlr/6qlr_ligand.sdf
139,data/PDBBind_processed/6sen/6sen_protein_processed.pdb,data/PDBBind_processed/6sen/6sen_ligand.mol2
140,data/PDBBind_processed/6oxv/6oxv_protein_processed.pdb,data/PDBBind_processed/6oxv/6oxv_ligand.sdf
141,data/PDBBind_processed/6g2b/6g2b_protein_processed.pdb,data/PDBBind_processed/6g2b/6g2b_ligand.sdf
142,data/PDBBind_processed/5zr3/5zr3_protein_processed.pdb,data/PDBBind_processed/5zr3/5zr3_ligand.sdf
143,data/PDBBind_processed/6kjf/6kjf_protein_processed.pdb,data/PDBBind_processed/6kjf/6kjf_ligand.sdf
144,data/PDBBind_processed/6qr9/6qr9_protein_processed.pdb,data/PDBBind_processed/6qr9/6qr9_ligand.mol2
145,data/PDBBind_processed/6g9f/6g9f_protein_processed.pdb,data/PDBBind_processed/6g9f/6g9f_ligand.sdf
146,data/PDBBind_processed/6e6v/6e6v_protein_processed.pdb,data/PDBBind_processed/6e6v/6e6v_ligand.sdf
147,data/PDBBind_processed/5zk9/5zk9_protein_processed.pdb,data/PDBBind_processed/5zk9/5zk9_ligand.sdf
148,data/PDBBind_processed/6pnn/6pnn_protein_processed.pdb,data/PDBBind_processed/6pnn/6pnn_ligand.sdf
149,data/PDBBind_processed/6nri/6nri_protein_processed.pdb,data/PDBBind_processed/6nri/6nri_ligand.sdf
150,data/PDBBind_processed/6uwv/6uwv_protein_processed.pdb,data/PDBBind_processed/6uwv/6uwv_ligand.sdf
151,data/PDBBind_processed/6ooz/6ooz_protein_processed.pdb,data/PDBBind_processed/6ooz/6ooz_ligand.sdf
152,data/PDBBind_processed/6npi/6npi_protein_processed.pdb,data/PDBBind_processed/6npi/6npi_ligand.sdf
153,data/PDBBind_processed/6oip/6oip_protein_processed.pdb,data/PDBBind_processed/6oip/6oip_ligand.sdf
154,data/PDBBind_processed/6miv/6miv_protein_processed.pdb,data/PDBBind_processed/6miv/6miv_ligand.sdf
155,data/PDBBind_processed/6s57/6s57_protein_processed.pdb,data/PDBBind_processed/6s57/6s57_ligand.sdf
156,data/PDBBind_processed/6p8x/6p8x_protein_processed.pdb,data/PDBBind_processed/6p8x/6p8x_ligand.sdf
157,data/PDBBind_processed/6hoq/6hoq_protein_processed.pdb,data/PDBBind_processed/6hoq/6hoq_ligand.sdf
158,data/PDBBind_processed/6qts/6qts_protein_processed.pdb,data/PDBBind_processed/6qts/6qts_ligand.sdf
159,data/PDBBind_processed/6ggd/6ggd_protein_processed.pdb,data/PDBBind_processed/6ggd/6ggd_ligand.sdf
160,data/PDBBind_processed/6pnm/6pnm_protein_processed.pdb,data/PDBBind_processed/6pnm/6pnm_ligand.sdf
161,data/PDBBind_processed/6oy2/6oy2_protein_processed.pdb,data/PDBBind_processed/6oy2/6oy2_ligand.sdf
162,data/PDBBind_processed/6oi8/6oi8_protein_processed.pdb,data/PDBBind_processed/6oi8/6oi8_ligand.sdf
163,data/PDBBind_processed/6mhd/6mhd_protein_processed.pdb,data/PDBBind_processed/6mhd/6mhd_ligand.sdf
164,data/PDBBind_processed/6agt/6agt_protein_processed.pdb,data/PDBBind_processed/6agt/6agt_ligand.sdf
165,data/PDBBind_processed/6i5p/6i5p_protein_processed.pdb,data/PDBBind_processed/6i5p/6i5p_ligand.sdf
166,data/PDBBind_processed/6hhr/6hhr_protein_processed.pdb,data/PDBBind_processed/6hhr/6hhr_ligand.sdf
167,data/PDBBind_processed/6p8z/6p8z_protein_processed.pdb,data/PDBBind_processed/6p8z/6p8z_ligand.sdf
168,data/PDBBind_processed/6c85/6c85_protein_processed.pdb,data/PDBBind_processed/6c85/6c85_ligand.sdf
169,data/PDBBind_processed/6g5u/6g5u_protein_processed.pdb,data/PDBBind_processed/6g5u/6g5u_ligand.sdf
170,data/PDBBind_processed/6j06/6j06_protein_processed.pdb,data/PDBBind_processed/6j06/6j06_ligand.sdf
171,data/PDBBind_processed/6qsz/6qsz_protein_processed.pdb,data/PDBBind_processed/6qsz/6qsz_ligand.sdf
172,data/PDBBind_processed/6jbb/6jbb_protein_processed.pdb,data/PDBBind_processed/6jbb/6jbb_ligand.sdf
173,data/PDBBind_processed/6hhp/6hhp_protein_processed.pdb,data/PDBBind_processed/6hhp/6hhp_ligand.sdf
174,data/PDBBind_processed/6np5/6np5_protein_processed.pdb,data/PDBBind_processed/6np5/6np5_ligand.sdf
175,data/PDBBind_processed/6nlj/6nlj_protein_processed.pdb,data/PDBBind_processed/6nlj/6nlj_ligand.sdf
176,data/PDBBind_processed/6qlp/6qlp_protein_processed.pdb,data/PDBBind_processed/6qlp/6qlp_ligand.sdf
177,data/PDBBind_processed/6n94/6n94_protein_processed.pdb,data/PDBBind_processed/6n94/6n94_ligand.sdf
178,data/PDBBind_processed/6e13/6e13_protein_processed.pdb,data/PDBBind_processed/6e13/6e13_ligand.sdf
179,data/PDBBind_processed/6qls/6qls_protein_processed.pdb,data/PDBBind_processed/6qls/6qls_ligand.sdf
180,data/PDBBind_processed/6uil/6uil_protein_processed.pdb,data/PDBBind_processed/6uil/6uil_ligand.sdf
181,data/PDBBind_processed/6st3/6st3_protein_processed.pdb,data/PDBBind_processed/6st3/6st3_ligand.sdf
182,data/PDBBind_processed/6n92/6n92_protein_processed.pdb,data/PDBBind_processed/6n92/6n92_ligand.sdf
183,data/PDBBind_processed/6s56/6s56_protein_processed.pdb,data/PDBBind_processed/6s56/6s56_ligand.sdf
184,data/PDBBind_processed/6hzd/6hzd_protein_processed.pdb,data/PDBBind_processed/6hzd/6hzd_ligand.sdf
185,data/PDBBind_processed/6uhv/6uhv_protein_processed.pdb,data/PDBBind_processed/6uhv/6uhv_ligand.sdf
186,data/PDBBind_processed/6k05/6k05_protein_processed.pdb,data/PDBBind_processed/6k05/6k05_ligand.sdf
187,data/PDBBind_processed/6q36/6q36_protein_processed.pdb,data/PDBBind_processed/6q36/6q36_ligand.mol2
188,data/PDBBind_processed/6ic0/6ic0_protein_processed.pdb,data/PDBBind_processed/6ic0/6ic0_ligand.sdf
189,data/PDBBind_processed/6hhi/6hhi_protein_processed.pdb,data/PDBBind_processed/6hhi/6hhi_ligand.sdf
190,data/PDBBind_processed/6e3m/6e3m_protein_processed.pdb,data/PDBBind_processed/6e3m/6e3m_ligand.sdf
191,data/PDBBind_processed/6qtx/6qtx_protein_processed.pdb,data/PDBBind_processed/6qtx/6qtx_ligand.sdf
192,data/PDBBind_processed/6jse/6jse_protein_processed.pdb,data/PDBBind_processed/6jse/6jse_ligand.sdf
193,data/PDBBind_processed/5zjy/5zjy_protein_processed.pdb,data/PDBBind_processed/5zjy/5zjy_ligand.sdf
194,data/PDBBind_processed/6o3y/6o3y_protein_processed.pdb,data/PDBBind_processed/6o3y/6o3y_ligand.sdf
195,data/PDBBind_processed/6rpg/6rpg_protein_processed.pdb,data/PDBBind_processed/6rpg/6rpg_ligand.sdf
196,data/PDBBind_processed/6rr0/6rr0_protein_processed.pdb,data/PDBBind_processed/6rr0/6rr0_ligand.sdf
197,data/PDBBind_processed/6gzy/6gzy_protein_processed.pdb,data/PDBBind_processed/6gzy/6gzy_ligand.sdf
198,data/PDBBind_processed/6qlt/6qlt_protein_processed.pdb,data/PDBBind_processed/6qlt/6qlt_ligand.sdf
199,data/PDBBind_processed/6ufo/6ufo_protein_processed.pdb,data/PDBBind_processed/6ufo/6ufo_ligand.sdf
200,data/PDBBind_processed/6o0h/6o0h_protein_processed.pdb,data/PDBBind_processed/6o0h/6o0h_ligand.sdf
201,data/PDBBind_processed/6o3x/6o3x_protein_processed.pdb,data/PDBBind_processed/6o3x/6o3x_ligand.sdf
202,data/PDBBind_processed/5zjz/5zjz_protein_processed.pdb,data/PDBBind_processed/5zjz/5zjz_ligand.mol2
203,data/PDBBind_processed/6i8t/6i8t_protein_processed.pdb,data/PDBBind_processed/6i8t/6i8t_ligand.sdf
204,data/PDBBind_processed/6ooy/6ooy_protein_processed.pdb,data/PDBBind_processed/6ooy/6ooy_ligand.sdf
205,data/PDBBind_processed/6oiq/6oiq_protein_processed.pdb,data/PDBBind_processed/6oiq/6oiq_ligand.sdf
206,data/PDBBind_processed/6od6/6od6_protein_processed.pdb,data/PDBBind_processed/6od6/6od6_ligand.sdf
207,data/PDBBind_processed/6nrh/6nrh_protein_processed.pdb,data/PDBBind_processed/6nrh/6nrh_ligand.sdf
208,data/PDBBind_processed/6qra/6qra_protein_processed.pdb,data/PDBBind_processed/6qra/6qra_ligand.mol2
209,data/PDBBind_processed/6hhh/6hhh_protein_processed.pdb,data/PDBBind_processed/6hhh/6hhh_ligand.sdf
210,data/PDBBind_processed/6m7h/6m7h_protein_processed.pdb,data/PDBBind_processed/6m7h/6m7h_ligand.sdf
211,data/PDBBind_processed/6ufn/6ufn_protein_processed.pdb,data/PDBBind_processed/6ufn/6ufn_ligand.sdf
212,data/PDBBind_processed/6qr0/6qr0_protein_processed.pdb,data/PDBBind_processed/6qr0/6qr0_ligand.mol2
213,data/PDBBind_processed/6o5u/6o5u_protein_processed.pdb,data/PDBBind_processed/6o5u/6o5u_ligand.sdf
214,data/PDBBind_processed/6h14/6h14_protein_processed.pdb,data/PDBBind_processed/6h14/6h14_ligand.sdf
215,data/PDBBind_processed/6jwa/6jwa_protein_processed.pdb,data/PDBBind_processed/6jwa/6jwa_ligand.sdf
216,data/PDBBind_processed/6ny0/6ny0_protein_processed.pdb,data/PDBBind_processed/6ny0/6ny0_ligand.sdf
217,data/PDBBind_processed/6jan/6jan_protein_processed.pdb,data/PDBBind_processed/6jan/6jan_ligand.sdf
218,data/PDBBind_processed/6ftf/6ftf_protein_processed.pdb,data/PDBBind_processed/6ftf/6ftf_ligand.sdf
219,data/PDBBind_processed/6oxw/6oxw_protein_processed.pdb,data/PDBBind_processed/6oxw/6oxw_ligand.sdf
220,data/PDBBind_processed/6jon/6jon_protein_processed.pdb,data/PDBBind_processed/6jon/6jon_ligand.sdf
221,data/PDBBind_processed/6cf7/6cf7_protein_processed.pdb,data/PDBBind_processed/6cf7/6cf7_ligand.sdf
222,data/PDBBind_processed/6rtn/6rtn_protein_processed.pdb,data/PDBBind_processed/6rtn/6rtn_ligand.mol2
223,data/PDBBind_processed/6jsz/6jsz_protein_processed.pdb,data/PDBBind_processed/6jsz/6jsz_ligand.sdf
224,data/PDBBind_processed/6o9c/6o9c_protein_processed.pdb,data/PDBBind_processed/6o9c/6o9c_ligand.sdf
225,data/PDBBind_processed/6mo8/6mo8_protein_processed.pdb,data/PDBBind_processed/6mo8/6mo8_ligand.sdf
226,data/PDBBind_processed/6qln/6qln_protein_processed.pdb,data/PDBBind_processed/6qln/6qln_ligand.sdf
227,data/PDBBind_processed/6qqu/6qqu_protein_processed.pdb,data/PDBBind_processed/6qqu/6qqu_ligand.mol2
228,data/PDBBind_processed/6i66/6i66_protein_processed.pdb,data/PDBBind_processed/6i66/6i66_ligand.sdf
229,data/PDBBind_processed/6mja/6mja_protein_processed.pdb,data/PDBBind_processed/6mja/6mja_ligand.sdf
230,data/PDBBind_processed/6gwe/6gwe_protein_processed.pdb,data/PDBBind_processed/6gwe/6gwe_ligand.mol2
231,data/PDBBind_processed/6d3z/6d3z_protein_processed.pdb,data/PDBBind_processed/6d3z/6d3z_ligand.sdf
232,data/PDBBind_processed/6oxr/6oxr_protein_processed.pdb,data/PDBBind_processed/6oxr/6oxr_ligand.sdf
233,data/PDBBind_processed/6r4k/6r4k_protein_processed.pdb,data/PDBBind_processed/6r4k/6r4k_ligand.sdf
234,data/PDBBind_processed/6hle/6hle_protein_processed.pdb,data/PDBBind_processed/6hle/6hle_ligand.sdf
235,data/PDBBind_processed/6h9v/6h9v_protein_processed.pdb,data/PDBBind_processed/6h9v/6h9v_ligand.sdf
236,data/PDBBind_processed/6hou/6hou_protein_processed.pdb,data/PDBBind_processed/6hou/6hou_ligand.sdf
237,data/PDBBind_processed/6nv9/6nv9_protein_processed.pdb,data/PDBBind_processed/6nv9/6nv9_ligand.sdf
238,data/PDBBind_processed/6py0/6py0_protein_processed.pdb,data/PDBBind_processed/6py0/6py0_ligand.sdf
239,data/PDBBind_processed/6qlq/6qlq_protein_processed.pdb,data/PDBBind_processed/6qlq/6qlq_ligand.sdf
240,data/PDBBind_processed/6nv7/6nv7_protein_processed.pdb,data/PDBBind_processed/6nv7/6nv7_ligand.sdf
241,data/PDBBind_processed/6n4b/6n4b_protein_processed.pdb,data/PDBBind_processed/6n4b/6n4b_ligand.sdf
242,data/PDBBind_processed/6jaq/6jaq_protein_processed.pdb,data/PDBBind_processed/6jaq/6jaq_ligand.sdf
243,data/PDBBind_processed/6i8m/6i8m_protein_processed.pdb,data/PDBBind_processed/6i8m/6i8m_ligand.sdf
244,data/PDBBind_processed/6dz0/6dz0_protein_processed.pdb,data/PDBBind_processed/6dz0/6dz0_ligand.mol2
245,data/PDBBind_processed/6oxs/6oxs_protein_processed.pdb,data/PDBBind_processed/6oxs/6oxs_ligand.sdf
246,data/PDBBind_processed/6k2n/6k2n_protein_processed.pdb,data/PDBBind_processed/6k2n/6k2n_ligand.sdf
247,data/PDBBind_processed/6cjj/6cjj_protein_processed.pdb,data/PDBBind_processed/6cjj/6cjj_ligand.sdf
248,data/PDBBind_processed/6ffg/6ffg_protein_processed.pdb,data/PDBBind_processed/6ffg/6ffg_ligand.sdf
249,data/PDBBind_processed/6a73/6a73_protein_processed.pdb,data/PDBBind_processed/6a73/6a73_ligand.sdf
250,data/PDBBind_processed/6qqt/6qqt_protein_processed.pdb,data/PDBBind_processed/6qqt/6qqt_ligand.mol2
251,data/PDBBind_processed/6a1c/6a1c_protein_processed.pdb,data/PDBBind_processed/6a1c/6a1c_ligand.sdf
252,data/PDBBind_processed/6oxu/6oxu_protein_processed.pdb,data/PDBBind_processed/6oxu/6oxu_ligand.sdf
253,data/PDBBind_processed/6qre/6qre_protein_processed.pdb,data/PDBBind_processed/6qre/6qre_ligand.mol2
254,data/PDBBind_processed/6qtw/6qtw_protein_processed.pdb,data/PDBBind_processed/6qtw/6qtw_ligand.sdf
255,data/PDBBind_processed/6np4/6np4_protein_processed.pdb,data/PDBBind_processed/6np4/6np4_ligand.sdf
256,data/PDBBind_processed/6hv2/6hv2_protein_processed.pdb,data/PDBBind_processed/6hv2/6hv2_ligand.sdf
257,data/PDBBind_processed/6n55/6n55_protein_processed.pdb,data/PDBBind_processed/6n55/6n55_ligand.sdf
258,data/PDBBind_processed/6e3o/6e3o_protein_processed.pdb,data/PDBBind_processed/6e3o/6e3o_ligand.sdf
259,data/PDBBind_processed/6kjd/6kjd_protein_processed.pdb,data/PDBBind_processed/6kjd/6kjd_ligand.sdf
260,data/PDBBind_processed/6sfc/6sfc_protein_processed.pdb,data/PDBBind_processed/6sfc/6sfc_ligand.sdf
261,data/PDBBind_processed/6qi7/6qi7_protein_processed.pdb,data/PDBBind_processed/6qi7/6qi7_ligand.sdf
262,data/PDBBind_processed/6hzc/6hzc_protein_processed.pdb,data/PDBBind_processed/6hzc/6hzc_ligand.sdf
263,data/PDBBind_processed/6k04/6k04_protein_processed.pdb,data/PDBBind_processed/6k04/6k04_ligand.sdf
264,data/PDBBind_processed/6op0/6op0_protein_processed.pdb,data/PDBBind_processed/6op0/6op0_ligand.sdf
265,data/PDBBind_processed/6q38/6q38_protein_processed.pdb,data/PDBBind_processed/6q38/6q38_ligand.mol2
266,data/PDBBind_processed/6n8x/6n8x_protein_processed.pdb,data/PDBBind_processed/6n8x/6n8x_ligand.sdf
267,data/PDBBind_processed/6np3/6np3_protein_processed.pdb,data/PDBBind_processed/6np3/6np3_ligand.sdf
268,data/PDBBind_processed/6uvv/6uvv_protein_processed.pdb,data/PDBBind_processed/6uvv/6uvv_ligand.sdf
269,data/PDBBind_processed/6pgo/6pgo_protein_processed.pdb,data/PDBBind_processed/6pgo/6pgo_ligand.sdf
270,data/PDBBind_processed/6jbe/6jbe_protein_processed.pdb,data/PDBBind_processed/6jbe/6jbe_ligand.sdf
271,data/PDBBind_processed/6i75/6i75_protein_processed.pdb,data/PDBBind_processed/6i75/6i75_ligand.sdf
272,data/PDBBind_processed/6qqq/6qqq_protein_processed.pdb,data/PDBBind_processed/6qqq/6qqq_ligand.mol2
273,data/PDBBind_processed/6i62/6i62_protein_processed.pdb,data/PDBBind_processed/6i62/6i62_ligand.sdf
274,data/PDBBind_processed/6j9y/6j9y_protein_processed.pdb,data/PDBBind_processed/6j9y/6j9y_ligand.sdf
275,data/PDBBind_processed/6g29/6g29_protein_processed.pdb,data/PDBBind_processed/6g29/6g29_ligand.sdf
276,data/PDBBind_processed/6h7d/6h7d_protein_processed.pdb,data/PDBBind_processed/6h7d/6h7d_ligand.sdf
277,data/PDBBind_processed/6mo9/6mo9_protein_processed.pdb,data/PDBBind_processed/6mo9/6mo9_ligand.sdf
278,data/PDBBind_processed/6jao/6jao_protein_processed.pdb,data/PDBBind_processed/6jao/6jao_ligand.sdf
279,data/PDBBind_processed/6jmf/6jmf_protein_processed.pdb,data/PDBBind_processed/6jmf/6jmf_ligand.sdf
280,data/PDBBind_processed/6hmy/6hmy_protein_processed.pdb,data/PDBBind_processed/6hmy/6hmy_ligand.sdf
281,data/PDBBind_processed/6qfe/6qfe_protein_processed.pdb,data/PDBBind_processed/6qfe/6qfe_ligand.mol2
282,data/PDBBind_processed/5zml/5zml_protein_processed.pdb,data/PDBBind_processed/5zml/5zml_ligand.sdf
283,data/PDBBind_processed/6i65/6i65_protein_processed.pdb,data/PDBBind_processed/6i65/6i65_ligand.sdf
284,data/PDBBind_processed/6e7m/6e7m_protein_processed.pdb,data/PDBBind_processed/6e7m/6e7m_ligand.sdf
285,data/PDBBind_processed/6i61/6i61_protein_processed.pdb,data/PDBBind_processed/6i61/6i61_ligand.sdf
286,data/PDBBind_processed/6rz6/6rz6_protein_processed.pdb,data/PDBBind_processed/6rz6/6rz6_ligand.sdf
287,data/PDBBind_processed/6qtm/6qtm_protein_processed.pdb,data/PDBBind_processed/6qtm/6qtm_ligand.sdf
288,data/PDBBind_processed/6qlo/6qlo_protein_processed.pdb,data/PDBBind_processed/6qlo/6qlo_ligand.sdf
289,data/PDBBind_processed/6oie/6oie_protein_processed.pdb,data/PDBBind_processed/6oie/6oie_ligand.sdf
290,data/PDBBind_processed/6miy/6miy_protein_processed.pdb,data/PDBBind_processed/6miy/6miy_ligand.sdf
291,data/PDBBind_processed/6nrf/6nrf_protein_processed.pdb,data/PDBBind_processed/6nrf/6nrf_ligand.mol2
292,data/PDBBind_processed/6gj5/6gj5_protein_processed.pdb,data/PDBBind_processed/6gj5/6gj5_ligand.mol2
293,data/PDBBind_processed/6jad/6jad_protein_processed.pdb,data/PDBBind_processed/6jad/6jad_ligand.sdf
294,data/PDBBind_processed/6mj4/6mj4_protein_processed.pdb,data/PDBBind_processed/6mj4/6mj4_ligand.sdf
295,data/PDBBind_processed/6h12/6h12_protein_processed.pdb,data/PDBBind_processed/6h12/6h12_ligand.sdf
296,data/PDBBind_processed/6d3y/6d3y_protein_processed.pdb,data/PDBBind_processed/6d3y/6d3y_ligand.sdf
297,data/PDBBind_processed/6qr2/6qr2_protein_processed.pdb,data/PDBBind_processed/6qr2/6qr2_ligand.mol2
298,data/PDBBind_processed/6qxa/6qxa_protein_processed.pdb,data/PDBBind_processed/6qxa/6qxa_ligand.mol2
299,data/PDBBind_processed/6o9b/6o9b_protein_processed.pdb,data/PDBBind_processed/6o9b/6o9b_ligand.sdf
300,data/PDBBind_processed/6ckl/6ckl_protein_processed.pdb,data/PDBBind_processed/6ckl/6ckl_ligand.sdf
301,data/PDBBind_processed/6oir/6oir_protein_processed.pdb,data/PDBBind_processed/6oir/6oir_ligand.sdf
302,data/PDBBind_processed/6d40/6d40_protein_processed.pdb,data/PDBBind_processed/6d40/6d40_ligand.sdf
303,data/PDBBind_processed/6e6j/6e6j_protein_processed.pdb,data/PDBBind_processed/6e6j/6e6j_ligand.mol2
304,data/PDBBind_processed/6i7a/6i7a_protein_processed.pdb,data/PDBBind_processed/6i7a/6i7a_ligand.sdf
305,data/PDBBind_processed/6g25/6g25_protein_processed.pdb,data/PDBBind_processed/6g25/6g25_ligand.mol2
306,data/PDBBind_processed/6oin/6oin_protein_processed.pdb,data/PDBBind_processed/6oin/6oin_ligand.sdf
307,data/PDBBind_processed/6jam/6jam_protein_processed.pdb,data/PDBBind_processed/6jam/6jam_ligand.sdf
308,data/PDBBind_processed/6oxz/6oxz_protein_processed.pdb,data/PDBBind_processed/6oxz/6oxz_ligand.sdf
309,data/PDBBind_processed/6hop/6hop_protein_processed.pdb,data/PDBBind_processed/6hop/6hop_ligand.sdf
310,data/PDBBind_processed/6rot/6rot_protein_processed.pdb,data/PDBBind_processed/6rot/6rot_ligand.sdf
311,data/PDBBind_processed/6uhu/6uhu_protein_processed.pdb,data/PDBBind_processed/6uhu/6uhu_ligand.mol2
312,data/PDBBind_processed/6mji/6mji_protein_processed.pdb,data/PDBBind_processed/6mji/6mji_ligand.sdf
313,data/PDBBind_processed/6nrj/6nrj_protein_processed.pdb,data/PDBBind_processed/6nrj/6nrj_ligand.mol2
314,data/PDBBind_processed/6nt2/6nt2_protein_processed.pdb,data/PDBBind_processed/6nt2/6nt2_ligand.mol2
315,data/PDBBind_processed/6op9/6op9_protein_processed.pdb,data/PDBBind_processed/6op9/6op9_ligand.sdf
316,data/PDBBind_processed/6pno/6pno_protein_processed.pdb,data/PDBBind_processed/6pno/6pno_ligand.sdf
317,data/PDBBind_processed/6e4v/6e4v_protein_processed.pdb,data/PDBBind_processed/6e4v/6e4v_ligand.sdf
318,data/PDBBind_processed/6k1s/6k1s_protein_processed.pdb,data/PDBBind_processed/6k1s/6k1s_ligand.sdf
319,data/PDBBind_processed/6a87/6a87_protein_processed.pdb,data/PDBBind_processed/6a87/6a87_ligand.sdf
320,data/PDBBind_processed/6oim/6oim_protein_processed.pdb,data/PDBBind_processed/6oim/6oim_ligand.sdf
321,data/PDBBind_processed/6cjp/6cjp_protein_processed.pdb,data/PDBBind_processed/6cjp/6cjp_ligand.sdf
322,data/PDBBind_processed/6pyb/6pyb_protein_processed.pdb,data/PDBBind_processed/6pyb/6pyb_ligand.sdf
323,data/PDBBind_processed/6h13/6h13_protein_processed.pdb,data/PDBBind_processed/6h13/6h13_ligand.sdf
324,data/PDBBind_processed/6qrf/6qrf_protein_processed.pdb,data/PDBBind_processed/6qrf/6qrf_ligand.mol2
325,data/PDBBind_processed/6mhc/6mhc_protein_processed.pdb,data/PDBBind_processed/6mhc/6mhc_ligand.sdf
326,data/PDBBind_processed/6j9w/6j9w_protein_processed.pdb,data/PDBBind_processed/6j9w/6j9w_ligand.sdf
327,data/PDBBind_processed/6nrg/6nrg_protein_processed.pdb,data/PDBBind_processed/6nrg/6nrg_ligand.mol2
328,data/PDBBind_processed/6fff/6fff_protein_processed.pdb,data/PDBBind_processed/6fff/6fff_ligand.sdf
329,data/PDBBind_processed/6n93/6n93_protein_processed.pdb,data/PDBBind_processed/6n93/6n93_ligand.sdf
330,data/PDBBind_processed/6jut/6jut_protein_processed.pdb,data/PDBBind_processed/6jut/6jut_ligand.mol2
331,data/PDBBind_processed/6g2e/6g2e_protein_processed.pdb,data/PDBBind_processed/6g2e/6g2e_ligand.sdf
332,data/PDBBind_processed/6nd3/6nd3_protein_processed.pdb,data/PDBBind_processed/6nd3/6nd3_ligand.sdf
333,data/PDBBind_processed/6os6/6os6_protein_processed.pdb,data/PDBBind_processed/6os6/6os6_ligand.mol2
334,data/PDBBind_processed/6dql/6dql_protein_processed.pdb,data/PDBBind_processed/6dql/6dql_ligand.mol2
335,data/PDBBind_processed/6inz/6inz_protein_processed.pdb,data/PDBBind_processed/6inz/6inz_ligand.sdf
336,data/PDBBind_processed/6i67/6i67_protein_processed.pdb,data/PDBBind_processed/6i67/6i67_ligand.sdf
337,data/PDBBind_processed/6quw/6quw_protein_processed.pdb,data/PDBBind_processed/6quw/6quw_ligand.sdf
338,data/PDBBind_processed/6qwi/6qwi_protein_processed.pdb,data/PDBBind_processed/6qwi/6qwi_ligand.sdf
339,data/PDBBind_processed/6npm/6npm_protein_processed.pdb,data/PDBBind_processed/6npm/6npm_ligand.sdf
340,data/PDBBind_processed/6i64/6i64_protein_processed.pdb,data/PDBBind_processed/6i64/6i64_ligand.sdf
341,data/PDBBind_processed/6e3n/6e3n_protein_processed.pdb,data/PDBBind_processed/6e3n/6e3n_ligand.sdf
342,data/PDBBind_processed/6qrg/6qrg_protein_processed.pdb,data/PDBBind_processed/6qrg/6qrg_ligand.mol2
343,data/PDBBind_processed/6nxz/6nxz_protein_processed.pdb,data/PDBBind_processed/6nxz/6nxz_ligand.sdf
344,data/PDBBind_processed/6iby/6iby_protein_processed.pdb,data/PDBBind_processed/6iby/6iby_ligand.sdf
345,data/PDBBind_processed/6gj7/6gj7_protein_processed.pdb,data/PDBBind_processed/6gj7/6gj7_ligand.mol2
346,data/PDBBind_processed/6qr3/6qr3_protein_processed.pdb,data/PDBBind_processed/6qr3/6qr3_ligand.mol2
347,data/PDBBind_processed/6qr1/6qr1_protein_processed.pdb,data/PDBBind_processed/6qr1/6qr1_ligand.mol2
348,data/PDBBind_processed/6s9x/6s9x_protein_processed.pdb,data/PDBBind_processed/6s9x/6s9x_ligand.sdf
349,data/PDBBind_processed/6q4q/6q4q_protein_processed.pdb,data/PDBBind_processed/6q4q/6q4q_ligand.mol2
350,data/PDBBind_processed/6hbn/6hbn_protein_processed.pdb,data/PDBBind_processed/6hbn/6hbn_ligand.sdf
351,data/PDBBind_processed/6nw3/6nw3_protein_processed.pdb,data/PDBBind_processed/6nw3/6nw3_ligand.sdf
352,data/PDBBind_processed/6tel/6tel_protein_processed.pdb,data/PDBBind_processed/6tel/6tel_ligand.sdf
353,data/PDBBind_processed/6p8y/6p8y_protein_processed.pdb,data/PDBBind_processed/6p8y/6p8y_ligand.sdf
354,data/PDBBind_processed/6d5w/6d5w_protein_processed.pdb,data/PDBBind_processed/6d5w/6d5w_ligand.sdf
355,data/PDBBind_processed/6t6a/6t6a_protein_processed.pdb,data/PDBBind_processed/6t6a/6t6a_ligand.mol2
356,data/PDBBind_processed/6o5g/6o5g_protein_processed.pdb,data/PDBBind_processed/6o5g/6o5g_ligand.mol2
357,data/PDBBind_processed/6r7d/6r7d_protein_processed.pdb,data/PDBBind_processed/6r7d/6r7d_ligand.sdf
358,data/PDBBind_processed/6pya/6pya_protein_processed.pdb,data/PDBBind_processed/6pya/6pya_ligand.mol2
359,data/PDBBind_processed/6ffe/6ffe_protein_processed.pdb,data/PDBBind_processed/6ffe/6ffe_ligand.sdf
360,data/PDBBind_processed/6d3x/6d3x_protein_processed.pdb,data/PDBBind_processed/6d3x/6d3x_ligand.sdf
361,data/PDBBind_processed/6gj8/6gj8_protein_processed.pdb,data/PDBBind_processed/6gj8/6gj8_ligand.mol2
362,data/PDBBind_processed/6mo2/6mo2_protein_processed.pdb,data/PDBBind_processed/6mo2/6mo2_ligand.mol2
1 complex_name protein_path ligand_description ligand protein_sequence
2 0 data/PDBBind_processed/6qqw/6qqw_protein_processed.pdb data/PDBBind_processed/6qqw/6qqw_ligand.mol2
3 1 data/PDBBind_processed/6d08/6d08_protein_processed.pdb data/PDBBind_processed/6d08/6d08_ligand.sdf
4 2 data/PDBBind_processed/6jap/6jap_protein_processed.pdb data/PDBBind_processed/6jap/6jap_ligand.sdf
5 3 data/PDBBind_processed/6np2/6np2_protein_processed.pdb data/PDBBind_processed/6np2/6np2_ligand.sdf
6 4 data/PDBBind_processed/6uvp/6uvp_protein_processed.pdb data/PDBBind_processed/6uvp/6uvp_ligand.sdf
7 5 data/PDBBind_processed/6oxq/6oxq_protein_processed.pdb data/PDBBind_processed/6oxq/6oxq_ligand.sdf
8 6 data/PDBBind_processed/6jsn/6jsn_protein_processed.pdb data/PDBBind_processed/6jsn/6jsn_ligand.sdf
9 7 data/PDBBind_processed/6hzb/6hzb_protein_processed.pdb data/PDBBind_processed/6hzb/6hzb_ligand.sdf
10 8 data/PDBBind_processed/6qrc/6qrc_protein_processed.pdb data/PDBBind_processed/6qrc/6qrc_ligand.mol2
11 9 data/PDBBind_processed/6oio/6oio_protein_processed.pdb data/PDBBind_processed/6oio/6oio_ligand.sdf
12 10 data/PDBBind_processed/6jag/6jag_protein_processed.pdb data/PDBBind_processed/6jag/6jag_ligand.sdf
13 11 data/PDBBind_processed/6moa/6moa_protein_processed.pdb data/PDBBind_processed/6moa/6moa_ligand.mol2
14 12 data/PDBBind_processed/6hld/6hld_protein_processed.pdb data/PDBBind_processed/6hld/6hld_ligand.sdf
15 13 data/PDBBind_processed/6i9a/6i9a_protein_processed.pdb data/PDBBind_processed/6i9a/6i9a_ligand.sdf
16 14 data/PDBBind_processed/6e4c/6e4c_protein_processed.pdb data/PDBBind_processed/6e4c/6e4c_ligand.sdf
17 15 data/PDBBind_processed/6g24/6g24_protein_processed.pdb data/PDBBind_processed/6g24/6g24_ligand.sdf
18 16 data/PDBBind_processed/6jb4/6jb4_protein_processed.pdb data/PDBBind_processed/6jb4/6jb4_ligand.sdf
19 17 data/PDBBind_processed/6s55/6s55_protein_processed.pdb data/PDBBind_processed/6s55/6s55_ligand.sdf
20 18 data/PDBBind_processed/6seo/6seo_protein_processed.pdb data/PDBBind_processed/6seo/6seo_ligand.sdf
21 19 data/PDBBind_processed/6dyz/6dyz_protein_processed.pdb data/PDBBind_processed/6dyz/6dyz_ligand.mol2
22 20 data/PDBBind_processed/5zk5/5zk5_protein_processed.pdb data/PDBBind_processed/5zk5/5zk5_ligand.sdf
23 21 data/PDBBind_processed/6jid/6jid_protein_processed.pdb data/PDBBind_processed/6jid/6jid_ligand.sdf
24 22 data/PDBBind_processed/5ze6/5ze6_protein_processed.pdb data/PDBBind_processed/5ze6/5ze6_ligand.sdf
25 23 data/PDBBind_processed/6qlu/6qlu_protein_processed.pdb data/PDBBind_processed/6qlu/6qlu_ligand.sdf
26 24 data/PDBBind_processed/6a6k/6a6k_protein_processed.pdb data/PDBBind_processed/6a6k/6a6k_ligand.sdf
27 25 data/PDBBind_processed/6qgf/6qgf_protein_processed.pdb data/PDBBind_processed/6qgf/6qgf_ligand.sdf
28 26 data/PDBBind_processed/6e3z/6e3z_protein_processed.pdb data/PDBBind_processed/6e3z/6e3z_ligand.sdf
29 27 data/PDBBind_processed/6te6/6te6_protein_processed.pdb data/PDBBind_processed/6te6/6te6_ligand.sdf
30 28 data/PDBBind_processed/6pka/6pka_protein_processed.pdb data/PDBBind_processed/6pka/6pka_ligand.sdf
31 29 data/PDBBind_processed/6g2o/6g2o_protein_processed.pdb data/PDBBind_processed/6g2o/6g2o_ligand.sdf
32 30 data/PDBBind_processed/6jsf/6jsf_protein_processed.pdb data/PDBBind_processed/6jsf/6jsf_ligand.sdf
33 31 data/PDBBind_processed/5zxk/5zxk_protein_processed.pdb data/PDBBind_processed/5zxk/5zxk_ligand.sdf
34 32 data/PDBBind_processed/6qxd/6qxd_protein_processed.pdb data/PDBBind_processed/6qxd/6qxd_ligand.sdf
35 33 data/PDBBind_processed/6n97/6n97_protein_processed.pdb data/PDBBind_processed/6n97/6n97_ligand.sdf
36 34 data/PDBBind_processed/6jt3/6jt3_protein_processed.pdb data/PDBBind_processed/6jt3/6jt3_ligand.sdf
37 35 data/PDBBind_processed/6qtr/6qtr_protein_processed.pdb data/PDBBind_processed/6qtr/6qtr_ligand.sdf
38 36 data/PDBBind_processed/6oy1/6oy1_protein_processed.pdb data/PDBBind_processed/6oy1/6oy1_ligand.sdf
39 37 data/PDBBind_processed/6n96/6n96_protein_processed.pdb data/PDBBind_processed/6n96/6n96_ligand.sdf
40 38 data/PDBBind_processed/6qzh/6qzh_protein_processed.pdb data/PDBBind_processed/6qzh/6qzh_ligand.sdf
41 39 data/PDBBind_processed/6qqz/6qqz_protein_processed.pdb data/PDBBind_processed/6qqz/6qqz_ligand.mol2
42 40 data/PDBBind_processed/6qmt/6qmt_protein_processed.pdb data/PDBBind_processed/6qmt/6qmt_ligand.sdf
43 41 data/PDBBind_processed/6ibx/6ibx_protein_processed.pdb data/PDBBind_processed/6ibx/6ibx_ligand.sdf
44 42 data/PDBBind_processed/6hmt/6hmt_protein_processed.pdb data/PDBBind_processed/6hmt/6hmt_ligand.sdf
45 43 data/PDBBind_processed/5zk7/5zk7_protein_processed.pdb data/PDBBind_processed/5zk7/5zk7_ligand.sdf
46 44 data/PDBBind_processed/6k3l/6k3l_protein_processed.pdb data/PDBBind_processed/6k3l/6k3l_ligand.sdf
47 45 data/PDBBind_processed/6cjs/6cjs_protein_processed.pdb data/PDBBind_processed/6cjs/6cjs_ligand.sdf
48 46 data/PDBBind_processed/6n9l/6n9l_protein_processed.pdb data/PDBBind_processed/6n9l/6n9l_ligand.sdf
49 47 data/PDBBind_processed/6ibz/6ibz_protein_processed.pdb data/PDBBind_processed/6ibz/6ibz_ligand.sdf
50 48 data/PDBBind_processed/6ott/6ott_protein_processed.pdb data/PDBBind_processed/6ott/6ott_ligand.sdf
51 49 data/PDBBind_processed/6gge/6gge_protein_processed.pdb data/PDBBind_processed/6gge/6gge_ligand.sdf
52 50 data/PDBBind_processed/6hot/6hot_protein_processed.pdb data/PDBBind_processed/6hot/6hot_ligand.sdf
53 51 data/PDBBind_processed/6e3p/6e3p_protein_processed.pdb data/PDBBind_processed/6e3p/6e3p_ligand.mol2
54 52 data/PDBBind_processed/6md6/6md6_protein_processed.pdb data/PDBBind_processed/6md6/6md6_ligand.sdf
55 53 data/PDBBind_processed/6hlb/6hlb_protein_processed.pdb data/PDBBind_processed/6hlb/6hlb_ligand.sdf
56 54 data/PDBBind_processed/6fe5/6fe5_protein_processed.pdb data/PDBBind_processed/6fe5/6fe5_ligand.sdf
57 55 data/PDBBind_processed/6uwp/6uwp_protein_processed.pdb data/PDBBind_processed/6uwp/6uwp_ligand.sdf
58 56 data/PDBBind_processed/6npp/6npp_protein_processed.pdb data/PDBBind_processed/6npp/6npp_ligand.sdf
59 57 data/PDBBind_processed/6g2f/6g2f_protein_processed.pdb data/PDBBind_processed/6g2f/6g2f_ligand.sdf
60 58 data/PDBBind_processed/6mo7/6mo7_protein_processed.pdb data/PDBBind_processed/6mo7/6mo7_ligand.sdf
61 59 data/PDBBind_processed/6bqd/6bqd_protein_processed.pdb data/PDBBind_processed/6bqd/6bqd_ligand.mol2
62 60 data/PDBBind_processed/6nsv/6nsv_protein_processed.pdb data/PDBBind_processed/6nsv/6nsv_ligand.mol2
63 61 data/PDBBind_processed/6i76/6i76_protein_processed.pdb data/PDBBind_processed/6i76/6i76_ligand.sdf
64 62 data/PDBBind_processed/6n53/6n53_protein_processed.pdb data/PDBBind_processed/6n53/6n53_ligand.sdf
65 63 data/PDBBind_processed/6g2c/6g2c_protein_processed.pdb data/PDBBind_processed/6g2c/6g2c_ligand.sdf
66 64 data/PDBBind_processed/6eeb/6eeb_protein_processed.pdb data/PDBBind_processed/6eeb/6eeb_ligand.mol2
67 65 data/PDBBind_processed/6n0m/6n0m_protein_processed.pdb data/PDBBind_processed/6n0m/6n0m_ligand.sdf
68 66 data/PDBBind_processed/6uvy/6uvy_protein_processed.pdb data/PDBBind_processed/6uvy/6uvy_ligand.sdf
69 67 data/PDBBind_processed/6ovz/6ovz_protein_processed.pdb data/PDBBind_processed/6ovz/6ovz_ligand.sdf
70 68 data/PDBBind_processed/6olx/6olx_protein_processed.pdb data/PDBBind_processed/6olx/6olx_ligand.sdf
71 69 data/PDBBind_processed/6v5l/6v5l_protein_processed.pdb data/PDBBind_processed/6v5l/6v5l_ligand.mol2
72 70 data/PDBBind_processed/6hhg/6hhg_protein_processed.pdb data/PDBBind_processed/6hhg/6hhg_ligand.sdf
73 71 data/PDBBind_processed/5zcu/5zcu_protein_processed.pdb data/PDBBind_processed/5zcu/5zcu_ligand.sdf
74 72 data/PDBBind_processed/6dz2/6dz2_protein_processed.pdb data/PDBBind_processed/6dz2/6dz2_ligand.mol2
75 73 data/PDBBind_processed/6mjq/6mjq_protein_processed.pdb data/PDBBind_processed/6mjq/6mjq_ligand.sdf
76 74 data/PDBBind_processed/6efk/6efk_protein_processed.pdb data/PDBBind_processed/6efk/6efk_ligand.sdf
77 75 data/PDBBind_processed/6s9w/6s9w_protein_processed.pdb data/PDBBind_processed/6s9w/6s9w_ligand.sdf
78 76 data/PDBBind_processed/6gdy/6gdy_protein_processed.pdb data/PDBBind_processed/6gdy/6gdy_ligand.sdf
79 77 data/PDBBind_processed/6kqi/6kqi_protein_processed.pdb data/PDBBind_processed/6kqi/6kqi_ligand.sdf
80 78 data/PDBBind_processed/6ueg/6ueg_protein_processed.pdb data/PDBBind_processed/6ueg/6ueg_ligand.sdf
81 79 data/PDBBind_processed/6oxt/6oxt_protein_processed.pdb data/PDBBind_processed/6oxt/6oxt_ligand.sdf
82 80 data/PDBBind_processed/6oy0/6oy0_protein_processed.pdb data/PDBBind_processed/6oy0/6oy0_ligand.sdf
83 81 data/PDBBind_processed/6qr7/6qr7_protein_processed.pdb data/PDBBind_processed/6qr7/6qr7_ligand.mol2
84 82 data/PDBBind_processed/6i41/6i41_protein_processed.pdb data/PDBBind_processed/6i41/6i41_ligand.sdf
85 83 data/PDBBind_processed/6cyg/6cyg_protein_processed.pdb data/PDBBind_processed/6cyg/6cyg_ligand.sdf
86 84 data/PDBBind_processed/6qmr/6qmr_protein_processed.pdb data/PDBBind_processed/6qmr/6qmr_ligand.sdf
87 85 data/PDBBind_processed/6g27/6g27_protein_processed.pdb data/PDBBind_processed/6g27/6g27_ligand.sdf
88 86 data/PDBBind_processed/6ggb/6ggb_protein_processed.pdb data/PDBBind_processed/6ggb/6ggb_ligand.sdf
89 87 data/PDBBind_processed/6g3c/6g3c_protein_processed.pdb data/PDBBind_processed/6g3c/6g3c_ligand.sdf
90 88 data/PDBBind_processed/6n4e/6n4e_protein_processed.pdb data/PDBBind_processed/6n4e/6n4e_ligand.sdf
91 89 data/PDBBind_processed/6fcj/6fcj_protein_processed.pdb data/PDBBind_processed/6fcj/6fcj_ligand.sdf
92 90 data/PDBBind_processed/6quv/6quv_protein_processed.pdb data/PDBBind_processed/6quv/6quv_ligand.sdf
93 91 data/PDBBind_processed/6iql/6iql_protein_processed.pdb data/PDBBind_processed/6iql/6iql_ligand.mol2
94 92 data/PDBBind_processed/6i74/6i74_protein_processed.pdb data/PDBBind_processed/6i74/6i74_ligand.sdf
95 93 data/PDBBind_processed/6qr4/6qr4_protein_processed.pdb data/PDBBind_processed/6qr4/6qr4_ligand.mol2
96 94 data/PDBBind_processed/6rnu/6rnu_protein_processed.pdb data/PDBBind_processed/6rnu/6rnu_ligand.sdf
97 95 data/PDBBind_processed/6jib/6jib_protein_processed.pdb data/PDBBind_processed/6jib/6jib_ligand.sdf
98 96 data/PDBBind_processed/6izq/6izq_protein_processed.pdb data/PDBBind_processed/6izq/6izq_ligand.sdf
99 97 data/PDBBind_processed/6qw8/6qw8_protein_processed.pdb data/PDBBind_processed/6qw8/6qw8_ligand.sdf
100 98 data/PDBBind_processed/6qto/6qto_protein_processed.pdb data/PDBBind_processed/6qto/6qto_ligand.sdf
101 99 data/PDBBind_processed/6qrd/6qrd_protein_processed.pdb data/PDBBind_processed/6qrd/6qrd_ligand.mol2
102 100 data/PDBBind_processed/6hza/6hza_protein_processed.pdb data/PDBBind_processed/6hza/6hza_ligand.sdf
103 101 data/PDBBind_processed/6e5s/6e5s_protein_processed.pdb data/PDBBind_processed/6e5s/6e5s_ligand.sdf
104 102 data/PDBBind_processed/6dz3/6dz3_protein_processed.pdb data/PDBBind_processed/6dz3/6dz3_ligand.mol2
105 103 data/PDBBind_processed/6e6w/6e6w_protein_processed.pdb data/PDBBind_processed/6e6w/6e6w_ligand.mol2
106 104 data/PDBBind_processed/6cyh/6cyh_protein_processed.pdb data/PDBBind_processed/6cyh/6cyh_ligand.sdf
107 105 data/PDBBind_processed/5zlf/5zlf_protein_processed.pdb data/PDBBind_processed/5zlf/5zlf_ligand.sdf
108 106 data/PDBBind_processed/6om4/6om4_protein_processed.pdb data/PDBBind_processed/6om4/6om4_ligand.sdf
109 107 data/PDBBind_processed/6gga/6gga_protein_processed.pdb data/PDBBind_processed/6gga/6gga_ligand.sdf
110 108 data/PDBBind_processed/6pgp/6pgp_protein_processed.pdb data/PDBBind_processed/6pgp/6pgp_ligand.sdf
111 109 data/PDBBind_processed/6qqv/6qqv_protein_processed.pdb data/PDBBind_processed/6qqv/6qqv_ligand.mol2
112 110 data/PDBBind_processed/6qtq/6qtq_protein_processed.pdb data/PDBBind_processed/6qtq/6qtq_ligand.sdf
113 111 data/PDBBind_processed/6gj6/6gj6_protein_processed.pdb data/PDBBind_processed/6gj6/6gj6_ligand.mol2
114 112 data/PDBBind_processed/6os5/6os5_protein_processed.pdb data/PDBBind_processed/6os5/6os5_ligand.mol2
115 113 data/PDBBind_processed/6s07/6s07_protein_processed.pdb data/PDBBind_processed/6s07/6s07_ligand.sdf
116 114 data/PDBBind_processed/6i77/6i77_protein_processed.pdb data/PDBBind_processed/6i77/6i77_ligand.sdf
117 115 data/PDBBind_processed/6hhj/6hhj_protein_processed.pdb data/PDBBind_processed/6hhj/6hhj_ligand.sdf
118 116 data/PDBBind_processed/6ahs/6ahs_protein_processed.pdb data/PDBBind_processed/6ahs/6ahs_ligand.sdf
119 117 data/PDBBind_processed/6oxx/6oxx_protein_processed.pdb data/PDBBind_processed/6oxx/6oxx_ligand.sdf
120 118 data/PDBBind_processed/6mjj/6mjj_protein_processed.pdb data/PDBBind_processed/6mjj/6mjj_ligand.sdf
121 119 data/PDBBind_processed/6hor/6hor_protein_processed.pdb data/PDBBind_processed/6hor/6hor_ligand.sdf
122 120 data/PDBBind_processed/6jb0/6jb0_protein_processed.pdb data/PDBBind_processed/6jb0/6jb0_ligand.sdf
123 121 data/PDBBind_processed/6i68/6i68_protein_processed.pdb data/PDBBind_processed/6i68/6i68_ligand.sdf
124 122 data/PDBBind_processed/6pz4/6pz4_protein_processed.pdb data/PDBBind_processed/6pz4/6pz4_ligand.sdf
125 123 data/PDBBind_processed/6mhb/6mhb_protein_processed.pdb data/PDBBind_processed/6mhb/6mhb_ligand.sdf
126 124 data/PDBBind_processed/6uim/6uim_protein_processed.pdb data/PDBBind_processed/6uim/6uim_ligand.sdf
127 125 data/PDBBind_processed/6jsg/6jsg_protein_processed.pdb data/PDBBind_processed/6jsg/6jsg_ligand.sdf
128 126 data/PDBBind_processed/6i78/6i78_protein_processed.pdb data/PDBBind_processed/6i78/6i78_ligand.sdf
129 127 data/PDBBind_processed/6oxy/6oxy_protein_processed.pdb data/PDBBind_processed/6oxy/6oxy_ligand.sdf
130 128 data/PDBBind_processed/6gbw/6gbw_protein_processed.pdb data/PDBBind_processed/6gbw/6gbw_ligand.sdf
131 129 data/PDBBind_processed/6mo0/6mo0_protein_processed.pdb data/PDBBind_processed/6mo0/6mo0_ligand.sdf
132 130 data/PDBBind_processed/6ggf/6ggf_protein_processed.pdb data/PDBBind_processed/6ggf/6ggf_ligand.sdf
133 131 data/PDBBind_processed/6qge/6qge_protein_processed.pdb data/PDBBind_processed/6qge/6qge_ligand.sdf
134 132 data/PDBBind_processed/6cjr/6cjr_protein_processed.pdb data/PDBBind_processed/6cjr/6cjr_ligand.sdf
135 133 data/PDBBind_processed/6oxp/6oxp_protein_processed.pdb data/PDBBind_processed/6oxp/6oxp_ligand.sdf
136 134 data/PDBBind_processed/6d07/6d07_protein_processed.pdb data/PDBBind_processed/6d07/6d07_ligand.sdf
137 135 data/PDBBind_processed/6i63/6i63_protein_processed.pdb data/PDBBind_processed/6i63/6i63_ligand.sdf
138 136 data/PDBBind_processed/6ten/6ten_protein_processed.pdb data/PDBBind_processed/6ten/6ten_ligand.sdf
139 137 data/PDBBind_processed/6uii/6uii_protein_processed.pdb data/PDBBind_processed/6uii/6uii_ligand.sdf
140 138 data/PDBBind_processed/6qlr/6qlr_protein_processed.pdb data/PDBBind_processed/6qlr/6qlr_ligand.sdf
141 139 data/PDBBind_processed/6sen/6sen_protein_processed.pdb data/PDBBind_processed/6sen/6sen_ligand.mol2
142 140 data/PDBBind_processed/6oxv/6oxv_protein_processed.pdb data/PDBBind_processed/6oxv/6oxv_ligand.sdf
143 141 data/PDBBind_processed/6g2b/6g2b_protein_processed.pdb data/PDBBind_processed/6g2b/6g2b_ligand.sdf
144 142 data/PDBBind_processed/5zr3/5zr3_protein_processed.pdb data/PDBBind_processed/5zr3/5zr3_ligand.sdf
145 143 data/PDBBind_processed/6kjf/6kjf_protein_processed.pdb data/PDBBind_processed/6kjf/6kjf_ligand.sdf
146 144 data/PDBBind_processed/6qr9/6qr9_protein_processed.pdb data/PDBBind_processed/6qr9/6qr9_ligand.mol2
147 145 data/PDBBind_processed/6g9f/6g9f_protein_processed.pdb data/PDBBind_processed/6g9f/6g9f_ligand.sdf
148 146 data/PDBBind_processed/6e6v/6e6v_protein_processed.pdb data/PDBBind_processed/6e6v/6e6v_ligand.sdf
149 147 data/PDBBind_processed/5zk9/5zk9_protein_processed.pdb data/PDBBind_processed/5zk9/5zk9_ligand.sdf
150 148 data/PDBBind_processed/6pnn/6pnn_protein_processed.pdb data/PDBBind_processed/6pnn/6pnn_ligand.sdf
151 149 data/PDBBind_processed/6nri/6nri_protein_processed.pdb data/PDBBind_processed/6nri/6nri_ligand.sdf
152 150 data/PDBBind_processed/6uwv/6uwv_protein_processed.pdb data/PDBBind_processed/6uwv/6uwv_ligand.sdf
153 151 data/PDBBind_processed/6ooz/6ooz_protein_processed.pdb data/PDBBind_processed/6ooz/6ooz_ligand.sdf
154 152 data/PDBBind_processed/6npi/6npi_protein_processed.pdb data/PDBBind_processed/6npi/6npi_ligand.sdf
155 153 data/PDBBind_processed/6oip/6oip_protein_processed.pdb data/PDBBind_processed/6oip/6oip_ligand.sdf
156 154 data/PDBBind_processed/6miv/6miv_protein_processed.pdb data/PDBBind_processed/6miv/6miv_ligand.sdf
157 155 data/PDBBind_processed/6s57/6s57_protein_processed.pdb data/PDBBind_processed/6s57/6s57_ligand.sdf
158 156 data/PDBBind_processed/6p8x/6p8x_protein_processed.pdb data/PDBBind_processed/6p8x/6p8x_ligand.sdf
159 157 data/PDBBind_processed/6hoq/6hoq_protein_processed.pdb data/PDBBind_processed/6hoq/6hoq_ligand.sdf
160 158 data/PDBBind_processed/6qts/6qts_protein_processed.pdb data/PDBBind_processed/6qts/6qts_ligand.sdf
161 159 data/PDBBind_processed/6ggd/6ggd_protein_processed.pdb data/PDBBind_processed/6ggd/6ggd_ligand.sdf
162 160 data/PDBBind_processed/6pnm/6pnm_protein_processed.pdb data/PDBBind_processed/6pnm/6pnm_ligand.sdf
163 161 data/PDBBind_processed/6oy2/6oy2_protein_processed.pdb data/PDBBind_processed/6oy2/6oy2_ligand.sdf
164 162 data/PDBBind_processed/6oi8/6oi8_protein_processed.pdb data/PDBBind_processed/6oi8/6oi8_ligand.sdf
165 163 data/PDBBind_processed/6mhd/6mhd_protein_processed.pdb data/PDBBind_processed/6mhd/6mhd_ligand.sdf
166 164 data/PDBBind_processed/6agt/6agt_protein_processed.pdb data/PDBBind_processed/6agt/6agt_ligand.sdf
167 165 data/PDBBind_processed/6i5p/6i5p_protein_processed.pdb data/PDBBind_processed/6i5p/6i5p_ligand.sdf
168 166 data/PDBBind_processed/6hhr/6hhr_protein_processed.pdb data/PDBBind_processed/6hhr/6hhr_ligand.sdf
169 167 data/PDBBind_processed/6p8z/6p8z_protein_processed.pdb data/PDBBind_processed/6p8z/6p8z_ligand.sdf
170 168 data/PDBBind_processed/6c85/6c85_protein_processed.pdb data/PDBBind_processed/6c85/6c85_ligand.sdf
171 169 data/PDBBind_processed/6g5u/6g5u_protein_processed.pdb data/PDBBind_processed/6g5u/6g5u_ligand.sdf
172 170 data/PDBBind_processed/6j06/6j06_protein_processed.pdb data/PDBBind_processed/6j06/6j06_ligand.sdf
173 171 data/PDBBind_processed/6qsz/6qsz_protein_processed.pdb data/PDBBind_processed/6qsz/6qsz_ligand.sdf
174 172 data/PDBBind_processed/6jbb/6jbb_protein_processed.pdb data/PDBBind_processed/6jbb/6jbb_ligand.sdf
175 173 data/PDBBind_processed/6hhp/6hhp_protein_processed.pdb data/PDBBind_processed/6hhp/6hhp_ligand.sdf
176 174 data/PDBBind_processed/6np5/6np5_protein_processed.pdb data/PDBBind_processed/6np5/6np5_ligand.sdf
177 175 data/PDBBind_processed/6nlj/6nlj_protein_processed.pdb data/PDBBind_processed/6nlj/6nlj_ligand.sdf
178 176 data/PDBBind_processed/6qlp/6qlp_protein_processed.pdb data/PDBBind_processed/6qlp/6qlp_ligand.sdf
179 177 data/PDBBind_processed/6n94/6n94_protein_processed.pdb data/PDBBind_processed/6n94/6n94_ligand.sdf
180 178 data/PDBBind_processed/6e13/6e13_protein_processed.pdb data/PDBBind_processed/6e13/6e13_ligand.sdf
181 179 data/PDBBind_processed/6qls/6qls_protein_processed.pdb data/PDBBind_processed/6qls/6qls_ligand.sdf
182 180 data/PDBBind_processed/6uil/6uil_protein_processed.pdb data/PDBBind_processed/6uil/6uil_ligand.sdf
183 181 data/PDBBind_processed/6st3/6st3_protein_processed.pdb data/PDBBind_processed/6st3/6st3_ligand.sdf
184 182 data/PDBBind_processed/6n92/6n92_protein_processed.pdb data/PDBBind_processed/6n92/6n92_ligand.sdf
185 183 data/PDBBind_processed/6s56/6s56_protein_processed.pdb data/PDBBind_processed/6s56/6s56_ligand.sdf
186 184 data/PDBBind_processed/6hzd/6hzd_protein_processed.pdb data/PDBBind_processed/6hzd/6hzd_ligand.sdf
187 185 data/PDBBind_processed/6uhv/6uhv_protein_processed.pdb data/PDBBind_processed/6uhv/6uhv_ligand.sdf
188 186 data/PDBBind_processed/6k05/6k05_protein_processed.pdb data/PDBBind_processed/6k05/6k05_ligand.sdf
189 187 data/PDBBind_processed/6q36/6q36_protein_processed.pdb data/PDBBind_processed/6q36/6q36_ligand.mol2
190 188 data/PDBBind_processed/6ic0/6ic0_protein_processed.pdb data/PDBBind_processed/6ic0/6ic0_ligand.sdf
191 189 data/PDBBind_processed/6hhi/6hhi_protein_processed.pdb data/PDBBind_processed/6hhi/6hhi_ligand.sdf
192 190 data/PDBBind_processed/6e3m/6e3m_protein_processed.pdb data/PDBBind_processed/6e3m/6e3m_ligand.sdf
193 191 data/PDBBind_processed/6qtx/6qtx_protein_processed.pdb data/PDBBind_processed/6qtx/6qtx_ligand.sdf
194 192 data/PDBBind_processed/6jse/6jse_protein_processed.pdb data/PDBBind_processed/6jse/6jse_ligand.sdf
195 193 data/PDBBind_processed/5zjy/5zjy_protein_processed.pdb data/PDBBind_processed/5zjy/5zjy_ligand.sdf
196 194 data/PDBBind_processed/6o3y/6o3y_protein_processed.pdb data/PDBBind_processed/6o3y/6o3y_ligand.sdf
197 195 data/PDBBind_processed/6rpg/6rpg_protein_processed.pdb data/PDBBind_processed/6rpg/6rpg_ligand.sdf
198 196 data/PDBBind_processed/6rr0/6rr0_protein_processed.pdb data/PDBBind_processed/6rr0/6rr0_ligand.sdf
199 197 data/PDBBind_processed/6gzy/6gzy_protein_processed.pdb data/PDBBind_processed/6gzy/6gzy_ligand.sdf
200 198 data/PDBBind_processed/6qlt/6qlt_protein_processed.pdb data/PDBBind_processed/6qlt/6qlt_ligand.sdf
201 199 data/PDBBind_processed/6ufo/6ufo_protein_processed.pdb data/PDBBind_processed/6ufo/6ufo_ligand.sdf
202 200 data/PDBBind_processed/6o0h/6o0h_protein_processed.pdb data/PDBBind_processed/6o0h/6o0h_ligand.sdf
203 201 data/PDBBind_processed/6o3x/6o3x_protein_processed.pdb data/PDBBind_processed/6o3x/6o3x_ligand.sdf
204 202 data/PDBBind_processed/5zjz/5zjz_protein_processed.pdb data/PDBBind_processed/5zjz/5zjz_ligand.mol2
205 203 data/PDBBind_processed/6i8t/6i8t_protein_processed.pdb data/PDBBind_processed/6i8t/6i8t_ligand.sdf
206 204 data/PDBBind_processed/6ooy/6ooy_protein_processed.pdb data/PDBBind_processed/6ooy/6ooy_ligand.sdf
207 205 data/PDBBind_processed/6oiq/6oiq_protein_processed.pdb data/PDBBind_processed/6oiq/6oiq_ligand.sdf
208 206 data/PDBBind_processed/6od6/6od6_protein_processed.pdb data/PDBBind_processed/6od6/6od6_ligand.sdf
209 207 data/PDBBind_processed/6nrh/6nrh_protein_processed.pdb data/PDBBind_processed/6nrh/6nrh_ligand.sdf
210 208 data/PDBBind_processed/6qra/6qra_protein_processed.pdb data/PDBBind_processed/6qra/6qra_ligand.mol2
211 209 data/PDBBind_processed/6hhh/6hhh_protein_processed.pdb data/PDBBind_processed/6hhh/6hhh_ligand.sdf
212 210 data/PDBBind_processed/6m7h/6m7h_protein_processed.pdb data/PDBBind_processed/6m7h/6m7h_ligand.sdf
213 211 data/PDBBind_processed/6ufn/6ufn_protein_processed.pdb data/PDBBind_processed/6ufn/6ufn_ligand.sdf
214 212 data/PDBBind_processed/6qr0/6qr0_protein_processed.pdb data/PDBBind_processed/6qr0/6qr0_ligand.mol2
215 213 data/PDBBind_processed/6o5u/6o5u_protein_processed.pdb data/PDBBind_processed/6o5u/6o5u_ligand.sdf
216 214 data/PDBBind_processed/6h14/6h14_protein_processed.pdb data/PDBBind_processed/6h14/6h14_ligand.sdf
217 215 data/PDBBind_processed/6jwa/6jwa_protein_processed.pdb data/PDBBind_processed/6jwa/6jwa_ligand.sdf
218 216 data/PDBBind_processed/6ny0/6ny0_protein_processed.pdb data/PDBBind_processed/6ny0/6ny0_ligand.sdf
219 217 data/PDBBind_processed/6jan/6jan_protein_processed.pdb data/PDBBind_processed/6jan/6jan_ligand.sdf
220 218 data/PDBBind_processed/6ftf/6ftf_protein_processed.pdb data/PDBBind_processed/6ftf/6ftf_ligand.sdf
221 219 data/PDBBind_processed/6oxw/6oxw_protein_processed.pdb data/PDBBind_processed/6oxw/6oxw_ligand.sdf
222 220 data/PDBBind_processed/6jon/6jon_protein_processed.pdb data/PDBBind_processed/6jon/6jon_ligand.sdf
223 221 data/PDBBind_processed/6cf7/6cf7_protein_processed.pdb data/PDBBind_processed/6cf7/6cf7_ligand.sdf
224 222 data/PDBBind_processed/6rtn/6rtn_protein_processed.pdb data/PDBBind_processed/6rtn/6rtn_ligand.mol2
225 223 data/PDBBind_processed/6jsz/6jsz_protein_processed.pdb data/PDBBind_processed/6jsz/6jsz_ligand.sdf
226 224 data/PDBBind_processed/6o9c/6o9c_protein_processed.pdb data/PDBBind_processed/6o9c/6o9c_ligand.sdf
227 225 data/PDBBind_processed/6mo8/6mo8_protein_processed.pdb data/PDBBind_processed/6mo8/6mo8_ligand.sdf
228 226 data/PDBBind_processed/6qln/6qln_protein_processed.pdb data/PDBBind_processed/6qln/6qln_ligand.sdf
229 227 data/PDBBind_processed/6qqu/6qqu_protein_processed.pdb data/PDBBind_processed/6qqu/6qqu_ligand.mol2
230 228 data/PDBBind_processed/6i66/6i66_protein_processed.pdb data/PDBBind_processed/6i66/6i66_ligand.sdf
231 229 data/PDBBind_processed/6mja/6mja_protein_processed.pdb data/PDBBind_processed/6mja/6mja_ligand.sdf
232 230 data/PDBBind_processed/6gwe/6gwe_protein_processed.pdb data/PDBBind_processed/6gwe/6gwe_ligand.mol2
233 231 data/PDBBind_processed/6d3z/6d3z_protein_processed.pdb data/PDBBind_processed/6d3z/6d3z_ligand.sdf
234 232 data/PDBBind_processed/6oxr/6oxr_protein_processed.pdb data/PDBBind_processed/6oxr/6oxr_ligand.sdf
235 233 data/PDBBind_processed/6r4k/6r4k_protein_processed.pdb data/PDBBind_processed/6r4k/6r4k_ligand.sdf
236 234 data/PDBBind_processed/6hle/6hle_protein_processed.pdb data/PDBBind_processed/6hle/6hle_ligand.sdf
237 235 data/PDBBind_processed/6h9v/6h9v_protein_processed.pdb data/PDBBind_processed/6h9v/6h9v_ligand.sdf
238 236 data/PDBBind_processed/6hou/6hou_protein_processed.pdb data/PDBBind_processed/6hou/6hou_ligand.sdf
239 237 data/PDBBind_processed/6nv9/6nv9_protein_processed.pdb data/PDBBind_processed/6nv9/6nv9_ligand.sdf
240 238 data/PDBBind_processed/6py0/6py0_protein_processed.pdb data/PDBBind_processed/6py0/6py0_ligand.sdf
241 239 data/PDBBind_processed/6qlq/6qlq_protein_processed.pdb data/PDBBind_processed/6qlq/6qlq_ligand.sdf
242 240 data/PDBBind_processed/6nv7/6nv7_protein_processed.pdb data/PDBBind_processed/6nv7/6nv7_ligand.sdf
243 241 data/PDBBind_processed/6n4b/6n4b_protein_processed.pdb data/PDBBind_processed/6n4b/6n4b_ligand.sdf
244 242 data/PDBBind_processed/6jaq/6jaq_protein_processed.pdb data/PDBBind_processed/6jaq/6jaq_ligand.sdf
245 243 data/PDBBind_processed/6i8m/6i8m_protein_processed.pdb data/PDBBind_processed/6i8m/6i8m_ligand.sdf
246 244 data/PDBBind_processed/6dz0/6dz0_protein_processed.pdb data/PDBBind_processed/6dz0/6dz0_ligand.mol2
247 245 data/PDBBind_processed/6oxs/6oxs_protein_processed.pdb data/PDBBind_processed/6oxs/6oxs_ligand.sdf
248 246 data/PDBBind_processed/6k2n/6k2n_protein_processed.pdb data/PDBBind_processed/6k2n/6k2n_ligand.sdf
249 247 data/PDBBind_processed/6cjj/6cjj_protein_processed.pdb data/PDBBind_processed/6cjj/6cjj_ligand.sdf
250 248 data/PDBBind_processed/6ffg/6ffg_protein_processed.pdb data/PDBBind_processed/6ffg/6ffg_ligand.sdf
251 249 data/PDBBind_processed/6a73/6a73_protein_processed.pdb data/PDBBind_processed/6a73/6a73_ligand.sdf
252 250 data/PDBBind_processed/6qqt/6qqt_protein_processed.pdb data/PDBBind_processed/6qqt/6qqt_ligand.mol2
253 251 data/PDBBind_processed/6a1c/6a1c_protein_processed.pdb data/PDBBind_processed/6a1c/6a1c_ligand.sdf
254 252 data/PDBBind_processed/6oxu/6oxu_protein_processed.pdb data/PDBBind_processed/6oxu/6oxu_ligand.sdf
255 253 data/PDBBind_processed/6qre/6qre_protein_processed.pdb data/PDBBind_processed/6qre/6qre_ligand.mol2
256 254 data/PDBBind_processed/6qtw/6qtw_protein_processed.pdb data/PDBBind_processed/6qtw/6qtw_ligand.sdf
257 255 data/PDBBind_processed/6np4/6np4_protein_processed.pdb data/PDBBind_processed/6np4/6np4_ligand.sdf
258 256 data/PDBBind_processed/6hv2/6hv2_protein_processed.pdb data/PDBBind_processed/6hv2/6hv2_ligand.sdf
259 257 data/PDBBind_processed/6n55/6n55_protein_processed.pdb data/PDBBind_processed/6n55/6n55_ligand.sdf
260 258 data/PDBBind_processed/6e3o/6e3o_protein_processed.pdb data/PDBBind_processed/6e3o/6e3o_ligand.sdf
261 259 data/PDBBind_processed/6kjd/6kjd_protein_processed.pdb data/PDBBind_processed/6kjd/6kjd_ligand.sdf
262 260 data/PDBBind_processed/6sfc/6sfc_protein_processed.pdb data/PDBBind_processed/6sfc/6sfc_ligand.sdf
263 261 data/PDBBind_processed/6qi7/6qi7_protein_processed.pdb data/PDBBind_processed/6qi7/6qi7_ligand.sdf
264 262 data/PDBBind_processed/6hzc/6hzc_protein_processed.pdb data/PDBBind_processed/6hzc/6hzc_ligand.sdf
265 263 data/PDBBind_processed/6k04/6k04_protein_processed.pdb data/PDBBind_processed/6k04/6k04_ligand.sdf
266 264 data/PDBBind_processed/6op0/6op0_protein_processed.pdb data/PDBBind_processed/6op0/6op0_ligand.sdf
267 265 data/PDBBind_processed/6q38/6q38_protein_processed.pdb data/PDBBind_processed/6q38/6q38_ligand.mol2
268 266 data/PDBBind_processed/6n8x/6n8x_protein_processed.pdb data/PDBBind_processed/6n8x/6n8x_ligand.sdf
269 267 data/PDBBind_processed/6np3/6np3_protein_processed.pdb data/PDBBind_processed/6np3/6np3_ligand.sdf
270 268 data/PDBBind_processed/6uvv/6uvv_protein_processed.pdb data/PDBBind_processed/6uvv/6uvv_ligand.sdf
271 269 data/PDBBind_processed/6pgo/6pgo_protein_processed.pdb data/PDBBind_processed/6pgo/6pgo_ligand.sdf
272 270 data/PDBBind_processed/6jbe/6jbe_protein_processed.pdb data/PDBBind_processed/6jbe/6jbe_ligand.sdf
273 271 data/PDBBind_processed/6i75/6i75_protein_processed.pdb data/PDBBind_processed/6i75/6i75_ligand.sdf
274 272 data/PDBBind_processed/6qqq/6qqq_protein_processed.pdb data/PDBBind_processed/6qqq/6qqq_ligand.mol2
275 273 data/PDBBind_processed/6i62/6i62_protein_processed.pdb data/PDBBind_processed/6i62/6i62_ligand.sdf
276 274 data/PDBBind_processed/6j9y/6j9y_protein_processed.pdb data/PDBBind_processed/6j9y/6j9y_ligand.sdf
277 275 data/PDBBind_processed/6g29/6g29_protein_processed.pdb data/PDBBind_processed/6g29/6g29_ligand.sdf
278 276 data/PDBBind_processed/6h7d/6h7d_protein_processed.pdb data/PDBBind_processed/6h7d/6h7d_ligand.sdf
279 277 data/PDBBind_processed/6mo9/6mo9_protein_processed.pdb data/PDBBind_processed/6mo9/6mo9_ligand.sdf
280 278 data/PDBBind_processed/6jao/6jao_protein_processed.pdb data/PDBBind_processed/6jao/6jao_ligand.sdf
281 279 data/PDBBind_processed/6jmf/6jmf_protein_processed.pdb data/PDBBind_processed/6jmf/6jmf_ligand.sdf
282 280 data/PDBBind_processed/6hmy/6hmy_protein_processed.pdb data/PDBBind_processed/6hmy/6hmy_ligand.sdf
283 281 data/PDBBind_processed/6qfe/6qfe_protein_processed.pdb data/PDBBind_processed/6qfe/6qfe_ligand.mol2
284 282 data/PDBBind_processed/5zml/5zml_protein_processed.pdb data/PDBBind_processed/5zml/5zml_ligand.sdf
285 283 data/PDBBind_processed/6i65/6i65_protein_processed.pdb data/PDBBind_processed/6i65/6i65_ligand.sdf
286 284 data/PDBBind_processed/6e7m/6e7m_protein_processed.pdb data/PDBBind_processed/6e7m/6e7m_ligand.sdf
287 285 data/PDBBind_processed/6i61/6i61_protein_processed.pdb data/PDBBind_processed/6i61/6i61_ligand.sdf
288 286 data/PDBBind_processed/6rz6/6rz6_protein_processed.pdb data/PDBBind_processed/6rz6/6rz6_ligand.sdf
289 287 data/PDBBind_processed/6qtm/6qtm_protein_processed.pdb data/PDBBind_processed/6qtm/6qtm_ligand.sdf
290 288 data/PDBBind_processed/6qlo/6qlo_protein_processed.pdb data/PDBBind_processed/6qlo/6qlo_ligand.sdf
291 289 data/PDBBind_processed/6oie/6oie_protein_processed.pdb data/PDBBind_processed/6oie/6oie_ligand.sdf
292 290 data/PDBBind_processed/6miy/6miy_protein_processed.pdb data/PDBBind_processed/6miy/6miy_ligand.sdf
293 291 data/PDBBind_processed/6nrf/6nrf_protein_processed.pdb data/PDBBind_processed/6nrf/6nrf_ligand.mol2
294 292 data/PDBBind_processed/6gj5/6gj5_protein_processed.pdb data/PDBBind_processed/6gj5/6gj5_ligand.mol2
295 293 data/PDBBind_processed/6jad/6jad_protein_processed.pdb data/PDBBind_processed/6jad/6jad_ligand.sdf
296 294 data/PDBBind_processed/6mj4/6mj4_protein_processed.pdb data/PDBBind_processed/6mj4/6mj4_ligand.sdf
297 295 data/PDBBind_processed/6h12/6h12_protein_processed.pdb data/PDBBind_processed/6h12/6h12_ligand.sdf
298 296 data/PDBBind_processed/6d3y/6d3y_protein_processed.pdb data/PDBBind_processed/6d3y/6d3y_ligand.sdf
299 297 data/PDBBind_processed/6qr2/6qr2_protein_processed.pdb data/PDBBind_processed/6qr2/6qr2_ligand.mol2
300 298 data/PDBBind_processed/6qxa/6qxa_protein_processed.pdb data/PDBBind_processed/6qxa/6qxa_ligand.mol2
301 299 data/PDBBind_processed/6o9b/6o9b_protein_processed.pdb data/PDBBind_processed/6o9b/6o9b_ligand.sdf
302 300 data/PDBBind_processed/6ckl/6ckl_protein_processed.pdb data/PDBBind_processed/6ckl/6ckl_ligand.sdf
303 301 data/PDBBind_processed/6oir/6oir_protein_processed.pdb data/PDBBind_processed/6oir/6oir_ligand.sdf
304 302 data/PDBBind_processed/6d40/6d40_protein_processed.pdb data/PDBBind_processed/6d40/6d40_ligand.sdf
305 303 data/PDBBind_processed/6e6j/6e6j_protein_processed.pdb data/PDBBind_processed/6e6j/6e6j_ligand.mol2
306 304 data/PDBBind_processed/6i7a/6i7a_protein_processed.pdb data/PDBBind_processed/6i7a/6i7a_ligand.sdf
307 305 data/PDBBind_processed/6g25/6g25_protein_processed.pdb data/PDBBind_processed/6g25/6g25_ligand.mol2
308 306 data/PDBBind_processed/6oin/6oin_protein_processed.pdb data/PDBBind_processed/6oin/6oin_ligand.sdf
309 307 data/PDBBind_processed/6jam/6jam_protein_processed.pdb data/PDBBind_processed/6jam/6jam_ligand.sdf
310 308 data/PDBBind_processed/6oxz/6oxz_protein_processed.pdb data/PDBBind_processed/6oxz/6oxz_ligand.sdf
311 309 data/PDBBind_processed/6hop/6hop_protein_processed.pdb data/PDBBind_processed/6hop/6hop_ligand.sdf
312 310 data/PDBBind_processed/6rot/6rot_protein_processed.pdb data/PDBBind_processed/6rot/6rot_ligand.sdf
313 311 data/PDBBind_processed/6uhu/6uhu_protein_processed.pdb data/PDBBind_processed/6uhu/6uhu_ligand.mol2
314 312 data/PDBBind_processed/6mji/6mji_protein_processed.pdb data/PDBBind_processed/6mji/6mji_ligand.sdf
315 313 data/PDBBind_processed/6nrj/6nrj_protein_processed.pdb data/PDBBind_processed/6nrj/6nrj_ligand.mol2
316 314 data/PDBBind_processed/6nt2/6nt2_protein_processed.pdb data/PDBBind_processed/6nt2/6nt2_ligand.mol2
317 315 data/PDBBind_processed/6op9/6op9_protein_processed.pdb data/PDBBind_processed/6op9/6op9_ligand.sdf
318 316 data/PDBBind_processed/6pno/6pno_protein_processed.pdb data/PDBBind_processed/6pno/6pno_ligand.sdf
319 317 data/PDBBind_processed/6e4v/6e4v_protein_processed.pdb data/PDBBind_processed/6e4v/6e4v_ligand.sdf
320 318 data/PDBBind_processed/6k1s/6k1s_protein_processed.pdb data/PDBBind_processed/6k1s/6k1s_ligand.sdf
321 319 data/PDBBind_processed/6a87/6a87_protein_processed.pdb data/PDBBind_processed/6a87/6a87_ligand.sdf
322 320 data/PDBBind_processed/6oim/6oim_protein_processed.pdb data/PDBBind_processed/6oim/6oim_ligand.sdf
323 321 data/PDBBind_processed/6cjp/6cjp_protein_processed.pdb data/PDBBind_processed/6cjp/6cjp_ligand.sdf
324 322 data/PDBBind_processed/6pyb/6pyb_protein_processed.pdb data/PDBBind_processed/6pyb/6pyb_ligand.sdf
325 323 data/PDBBind_processed/6h13/6h13_protein_processed.pdb data/PDBBind_processed/6h13/6h13_ligand.sdf
326 324 data/PDBBind_processed/6qrf/6qrf_protein_processed.pdb data/PDBBind_processed/6qrf/6qrf_ligand.mol2
327 325 data/PDBBind_processed/6mhc/6mhc_protein_processed.pdb data/PDBBind_processed/6mhc/6mhc_ligand.sdf
328 326 data/PDBBind_processed/6j9w/6j9w_protein_processed.pdb data/PDBBind_processed/6j9w/6j9w_ligand.sdf
329 327 data/PDBBind_processed/6nrg/6nrg_protein_processed.pdb data/PDBBind_processed/6nrg/6nrg_ligand.mol2
330 328 data/PDBBind_processed/6fff/6fff_protein_processed.pdb data/PDBBind_processed/6fff/6fff_ligand.sdf
331 329 data/PDBBind_processed/6n93/6n93_protein_processed.pdb data/PDBBind_processed/6n93/6n93_ligand.sdf
332 330 data/PDBBind_processed/6jut/6jut_protein_processed.pdb data/PDBBind_processed/6jut/6jut_ligand.mol2
333 331 data/PDBBind_processed/6g2e/6g2e_protein_processed.pdb data/PDBBind_processed/6g2e/6g2e_ligand.sdf
334 332 data/PDBBind_processed/6nd3/6nd3_protein_processed.pdb data/PDBBind_processed/6nd3/6nd3_ligand.sdf
335 333 data/PDBBind_processed/6os6/6os6_protein_processed.pdb data/PDBBind_processed/6os6/6os6_ligand.mol2
336 334 data/PDBBind_processed/6dql/6dql_protein_processed.pdb data/PDBBind_processed/6dql/6dql_ligand.mol2
337 335 data/PDBBind_processed/6inz/6inz_protein_processed.pdb data/PDBBind_processed/6inz/6inz_ligand.sdf
338 336 data/PDBBind_processed/6i67/6i67_protein_processed.pdb data/PDBBind_processed/6i67/6i67_ligand.sdf
339 337 data/PDBBind_processed/6quw/6quw_protein_processed.pdb data/PDBBind_processed/6quw/6quw_ligand.sdf
340 338 data/PDBBind_processed/6qwi/6qwi_protein_processed.pdb data/PDBBind_processed/6qwi/6qwi_ligand.sdf
341 339 data/PDBBind_processed/6npm/6npm_protein_processed.pdb data/PDBBind_processed/6npm/6npm_ligand.sdf
342 340 data/PDBBind_processed/6i64/6i64_protein_processed.pdb data/PDBBind_processed/6i64/6i64_ligand.sdf
343 341 data/PDBBind_processed/6e3n/6e3n_protein_processed.pdb data/PDBBind_processed/6e3n/6e3n_ligand.sdf
344 342 data/PDBBind_processed/6qrg/6qrg_protein_processed.pdb data/PDBBind_processed/6qrg/6qrg_ligand.mol2
345 343 data/PDBBind_processed/6nxz/6nxz_protein_processed.pdb data/PDBBind_processed/6nxz/6nxz_ligand.sdf
346 344 data/PDBBind_processed/6iby/6iby_protein_processed.pdb data/PDBBind_processed/6iby/6iby_ligand.sdf
347 345 data/PDBBind_processed/6gj7/6gj7_protein_processed.pdb data/PDBBind_processed/6gj7/6gj7_ligand.mol2
348 346 data/PDBBind_processed/6qr3/6qr3_protein_processed.pdb data/PDBBind_processed/6qr3/6qr3_ligand.mol2
349 347 data/PDBBind_processed/6qr1/6qr1_protein_processed.pdb data/PDBBind_processed/6qr1/6qr1_ligand.mol2
350 348 data/PDBBind_processed/6s9x/6s9x_protein_processed.pdb data/PDBBind_processed/6s9x/6s9x_ligand.sdf
351 349 data/PDBBind_processed/6q4q/6q4q_protein_processed.pdb data/PDBBind_processed/6q4q/6q4q_ligand.mol2
352 350 data/PDBBind_processed/6hbn/6hbn_protein_processed.pdb data/PDBBind_processed/6hbn/6hbn_ligand.sdf
353 351 data/PDBBind_processed/6nw3/6nw3_protein_processed.pdb data/PDBBind_processed/6nw3/6nw3_ligand.sdf
354 352 data/PDBBind_processed/6tel/6tel_protein_processed.pdb data/PDBBind_processed/6tel/6tel_ligand.sdf
355 353 data/PDBBind_processed/6p8y/6p8y_protein_processed.pdb data/PDBBind_processed/6p8y/6p8y_ligand.sdf
356 354 data/PDBBind_processed/6d5w/6d5w_protein_processed.pdb data/PDBBind_processed/6d5w/6d5w_ligand.sdf
357 355 data/PDBBind_processed/6t6a/6t6a_protein_processed.pdb data/PDBBind_processed/6t6a/6t6a_ligand.mol2
358 356 data/PDBBind_processed/6o5g/6o5g_protein_processed.pdb data/PDBBind_processed/6o5g/6o5g_ligand.mol2
359 357 data/PDBBind_processed/6r7d/6r7d_protein_processed.pdb data/PDBBind_processed/6r7d/6r7d_ligand.sdf
360 358 data/PDBBind_processed/6pya/6pya_protein_processed.pdb data/PDBBind_processed/6pya/6pya_ligand.mol2
361 359 data/PDBBind_processed/6ffe/6ffe_protein_processed.pdb data/PDBBind_processed/6ffe/6ffe_ligand.sdf
362 360 data/PDBBind_processed/6d3x/6d3x_protein_processed.pdb data/PDBBind_processed/6d3x/6d3x_ligand.sdf
363 361 data/PDBBind_processed/6gj8/6gj8_protein_processed.pdb data/PDBBind_processed/6gj8/6gj8_ligand.mol2
364 362 data/PDBBind_processed/6mo2/6mo2_protein_processed.pdb data/PDBBind_processed/6mo2/6mo2_ligand.mol2


@@ -58,7 +58,7 @@ class OptimizeConformer:
def score_conformation(self, values):
for i, r in enumerate(self.rotable_bonds):
SetDihedral(self.mol.GetConformer(self.probe_id), r, values[i])
return RMSD(self.mol, self.true_mol, self.probe_id, self.ref_id)
return AllChem.AlignMol(self.mol, self.true_mol, self.probe_id, self.ref_id)
def get_torsion_angles(mol):
@@ -83,114 +83,3 @@ def get_torsion_angles(mol):
)
return torsions_list
# GeoMol
def get_torsions(mol_list):
print('USING GEOMOL GET TORSIONS FUNCTION')
atom_counter = 0
torsionList = []
for m in mol_list:
torsionSmarts = '[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]'
torsionQuery = Chem.MolFromSmarts(torsionSmarts)
matches = m.GetSubstructMatches(torsionQuery)
for match in matches:
idx2 = match[0]
idx3 = match[1]
bond = m.GetBondBetweenAtoms(idx2, idx3)
jAtom = m.GetAtomWithIdx(idx2)
kAtom = m.GetAtomWithIdx(idx3)
for b1 in jAtom.GetBonds():
if (b1.GetIdx() == bond.GetIdx()):
continue
idx1 = b1.GetOtherAtomIdx(idx2)
for b2 in kAtom.GetBonds():
if ((b2.GetIdx() == bond.GetIdx())
or (b2.GetIdx() == b1.GetIdx())):
continue
idx4 = b2.GetOtherAtomIdx(idx3)
# skip 3-membered rings
if (idx4 == idx1):
continue
if m.GetAtomWithIdx(idx4).IsInRing():
torsionList.append(
(idx4 + atom_counter, idx3 + atom_counter, idx2 + atom_counter, idx1 + atom_counter))
break
else:
torsionList.append(
(idx1 + atom_counter, idx2 + atom_counter, idx3 + atom_counter, idx4 + atom_counter))
break
break
atom_counter += m.GetNumAtoms()
return torsionList
def A_transpose_matrix(alpha):
return np.array([[np.cos(alpha), np.sin(alpha)], [-np.sin(alpha), np.cos(alpha)]], dtype=np.double)
def S_vec(alpha):
return np.array([[np.cos(alpha)], [np.sin(alpha)]], dtype=np.double)
def GetDihedralFromPointCloud(Z, atom_idx):
p = Z[list(atom_idx)]
b = p[:-1] - p[1:]
b[0] *= -1
v = np.array([v - (v.dot(b[1]) / b[1].dot(b[1])) * b[1] for v in [b[0], b[2]]])
# Normalize vectors
v /= np.sqrt(np.einsum('...i,...i', v, v)).reshape(-1, 1)
b1 = b[1] / np.linalg.norm(b[1])
x = np.dot(v[0], v[1])
m = np.cross(v[0], b1)
y = np.dot(m, v[1])
return np.arctan2(y, x)
def get_dihedral_vonMises(mol, conf, atom_idx, Z):
Z = np.array(Z)
v = np.zeros((2, 1))
iAtom = mol.GetAtomWithIdx(atom_idx[1])
jAtom = mol.GetAtomWithIdx(atom_idx[2])
k_0 = atom_idx[0]
i = atom_idx[1]
j = atom_idx[2]
l_0 = atom_idx[3]
for b1 in iAtom.GetBonds():
k = b1.GetOtherAtomIdx(i)
if k == j:
continue
for b2 in jAtom.GetBonds():
l = b2.GetOtherAtomIdx(j)
if l == i:
continue
assert k != l
s_star = S_vec(GetDihedralFromPointCloud(Z, (k, i, j, l)))
a_mat = A_transpose_matrix(GetDihedral(conf, (k, i, j, k_0)) + GetDihedral(conf, (l_0, i, j, l)))
v = v + np.matmul(a_mat, s_star)
v = v / np.linalg.norm(v)
v = v.reshape(-1)
return np.arctan2(v[1], v[0])
def get_von_mises_rms(mol, mol_rdkit, rotable_bonds, conf_id):
new_dihedrals = np.zeros(len(rotable_bonds))
for idx, r in enumerate(rotable_bonds):
new_dihedrals[idx] = get_dihedral_vonMises(mol_rdkit,
mol_rdkit.GetConformer(conf_id), r,
mol.GetConformer().GetPositions())
mol_rdkit = apply_changes(mol_rdkit, new_dihedrals, rotable_bonds, conf_id)
return RMSD(mol_rdkit, mol, conf_id)
def mmff_func(mol):
mol_mmff = copy.deepcopy(mol)
AllChem.MMFFOptimizeMoleculeConfs(mol_mmff, mmffVariant='MMFF94s')
for i in range(mol.GetNumConformers()):
coords = mol_mmff.GetConformers()[i].GetPositions()
for j in range(coords.shape[0]):
mol.GetConformer(i).SetAtomPosition(j,
Geometry.Point3D(*coords[j]))
RMSD = AllChem.AlignMol
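
For reference, the signed dihedral computed by `GetDihedralFromPointCloud` in the hunk above can be reproduced without NumPy. This is a minimal sketch, not code from this commit: the name `dihedral_from_points` and the tuple-based vector helpers are assumptions made for illustration. Only the arithmetic follows the helper shown above.

```python
import math


def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def _scale(a, s):
    return tuple(x * s for x in a)


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _unit(a):
    return _scale(a, 1.0 / math.sqrt(_dot(a, a)))


def dihedral_from_points(p0, p1, p2, p3):
    """Signed dihedral (radians) of four 3D points (hypothetical helper)."""
    b0 = _sub(p1, p0)  # equals (p0 - p1) * -1, as in the hunk above
    b1 = _sub(p1, p2)
    b2 = _sub(p2, p3)

    def reject(v):
        # Remove the component of v that lies along the central bond b1.
        return _sub(v, _scale(b1, _dot(v, b1) / _dot(b1, b1)))

    v0 = _unit(reject(b0))
    v1 = _unit(reject(b2))
    x = _dot(v0, v1)
    m = _cross(v0, _unit(b1))
    y = _dot(m, v1)
    return math.atan2(y, x)


quarter_turn = dihedral_from_points((1, 0, 0), (0, 0, 0), (0, 0, 1), (0, 1, 1))
```

With these four points the two outer bonds are perpendicular when viewed along the central bond, so `quarter_turn` is π/2 up to rounding.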

179
datasets/constants.py Normal file

@@ -0,0 +1,179 @@
# Significant contribution from Ben Fry and Nick Polizzi
three_to_one = {'ALA': 'A',
'ARG': 'R',
'ASN': 'N',
'ASP': 'D',
'CYS': 'C',
'GLN': 'Q',
'GLU': 'E',
'GLY': 'G',
'HIS': 'H',
'ILE': 'I',
'LEU': 'L',
'LYS': 'K',
'MET': 'M',
'MSE': 'M',  # MSE (selenomethionine) is nearly identical to MET: the sulfur atom is replaced by selenium
'PHE': 'F',
'PRO': 'P',
'PYL': 'O',
'SER': 'S',
'SEC': 'U',
'THR': 'T',
'TRP': 'W',
'TYR': 'Y',
'VAL': 'V',
'ASX': 'B',
'GLX': 'Z',
'XAA': 'X',
'XLE': 'J'}
aa_name2aa_idx = {'ALA': 0, 'ARG': 1, 'ASN': 2, 'ASP': 3, 'CYS': 4, 'GLU': 5, 'GLN': 6, 'GLY': 7,
'HIS': 8, 'ILE': 9, 'LEU': 10, 'LYS': 11, 'MET': 12, 'PHE': 13, 'PRO': 14,
'SER': 15, 'THR': 16, 'TRP': 17, 'TYR': 18, 'VAL': 19, 'MSE': 12}
aa_short2long = {'C': 'CYS', 'D': 'ASP', 'S': 'SER', 'Q': 'GLN', 'K': 'LYS', 'I': 'ILE',
'P': 'PRO', 'T': 'THR', 'F': 'PHE', 'N': 'ASN', 'G': 'GLY', 'H': 'HIS',
'L': 'LEU', 'R': 'ARG', 'W': 'TRP', 'A': 'ALA', 'V': 'VAL', 'E': 'GLU',
'Y': 'TYR', 'M': 'MET'}
aa_short2aa_idx = {aa_short: aa_name2aa_idx[aa_long] for aa_short, aa_long in aa_short2long.items()}
aa_idx2aa_short = {aa_idx: aa_short for aa_short, aa_idx in aa_short2aa_idx.items()}
aa_long2short = {aa_long: aa_short for aa_short, aa_long in aa_short2long.items()}
aa_long2short['MSE'] = 'M'
chi = { 'C' :
{ 1: ('N' , 'CA' , 'CB' , 'SG' ) },
'D' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'OD1'), },
'E' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD' ),
3: ('CB' , 'CG' , 'CD' , 'OE1'), },
'F' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
'H' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'ND1'), },
'I' :
{ 1: ('N' , 'CA' , 'CB' , 'CG1'),
2: ('CA' , 'CB' , 'CG1', 'CD1'), },
'K' :
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
2: ('CA' , 'CB' , 'CG' ,'CD' ),
3: ('CB' , 'CG' , 'CD' ,'CE' ),
4: ('CG' , 'CD' , 'CE' ,'NZ' ), },
'L' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
'M' :
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
2: ('CA' , 'CB' , 'CG' ,'SD' ),
3: ('CB' , 'CG' , 'SD' ,'CE' ), },
'N' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'OD1'), },
'P' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD' ), },
'Q' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD' ),
3: ('CB' , 'CG' , 'CD' , 'OE1'), },
'R' :
{ 1: ('N' , 'CA' , 'CB' ,'CG' ),
2: ('CA' , 'CB' , 'CG' ,'CD' ),
3: ('CB' , 'CG' , 'CD' ,'NE' ),
4: ('CG' , 'CD' , 'NE' ,'CZ' ), },
'S' :
{ 1: ('N' , 'CA' , 'CB' , 'OG' ), },
'T' :
{ 1: ('N' , 'CA' , 'CB' , 'OG1'), },
'V' :
{ 1: ('N' , 'CA' , 'CB' , 'CG1'), },
'W' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
'Y' :
{ 1: ('N' , 'CA' , 'CB' , 'CG' ),
2: ('CA' , 'CB' , 'CG' , 'CD1'), },
}
atom_order = {'G': ['N', 'CA', 'C', 'O'],
'A': ['N', 'CA', 'C', 'O', 'CB'],
'S': ['N', 'CA', 'C', 'O', 'CB', 'OG'],
'C': ['N', 'CA', 'C', 'O', 'CB', 'SG'],
'T': ['N', 'CA', 'C', 'O', 'CB', 'OG1', 'CG2'],
'P': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD'],
'V': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2'],
'M': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'SD', 'CE'],
'N': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'ND2'],
'I': ['N', 'CA', 'C', 'O', 'CB', 'CG1', 'CG2', 'CD1'],
'L': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2'],
'D': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'OD1', 'OD2'],
'E': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'OE2'],
'K': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'CE', 'NZ'],
'Q': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'OE1', 'NE2'],
'H': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'ND1', 'CD2', 'CE1', 'NE2'],
'F': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ'],
'R': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD', 'NE', 'CZ', 'NH1', 'NH2'],
'Y': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ', 'OH'],
'W': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3', 'NE1', 'CZ2', 'CZ3', 'CH2'],
'X': ['N', 'CA', 'C', 'O']} # unknown amino acid
amino_acid_smiles = {
'PHE': '[NH3+]CC(=O)N[C@@H](Cc1ccccc1)C(=O)NCC(=O)O',
'MET': 'CSCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'TYR': '[NH3+]CC(=O)N[C@@H](Cc1ccc(O)cc1)C(=O)NCC(=O)O',
'ILE': 'CC[C@H](C)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'TRP': '[NH3+]CC(=O)N[C@@H](Cc1c[nH]c2ccccc12)C(=O)NCC(=O)O',
'THR': 'C[C@@H](O)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'CYS': '[NH3+]CC(=O)N[C@@H](CS)C(=O)NCC(=O)O',
'ALA': 'C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'LYS': '[NH3+]CCCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'PRO': '[NH3+]CC(=O)N1CCC[C@H]1C(=O)NCC(=O)O',
'LEU': 'CC(C)C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'GLY': '[NH3+]CC(=O)NCC(=O)NCC(=O)O',
'ASP': '[NH3+]CC(=O)N[C@@H](CC(=O)O)C(=O)NCC(=O)O',
'HIS': '[NH3+]CC(=O)N[C@@H](Cc1c[nH]c[nH+]1)C(=O)NCC(=O)O',
'VAL': 'CC(C)[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'SER': '[NH3+]CC(=O)N[C@@H](CO)C(=O)NCC(=O)O',
'ARG': 'NC(=[NH2+])NCCC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'GLU': '[NH3+]CC(=O)N[C@@H](CCC(=O)O)C(=O)NCC(=O)O',
'GLN': 'NC(=O)CC[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
'ASN': 'NC(=O)C[C@H](NC(=O)C[NH3+])C(=O)NCC(=O)O',
}
cg_rdkit_indices = {
'PHE': {4: 'N', 5: 'CA', 13: 'C', 14: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 12: 'CD2', 9: 'CE1', 11: 'CE2', 10: 'CZ'},
'MET': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 2: 'CG', 1: 'SD', 0: 'CE'},
'TYR': {4: 'N', 5: 'CA', 14: 'C', 15: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 13: 'CD2', 9: 'CE1', 12: 'CE2', 10: 'CZ', 11: 'OH'},
'ILE': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 2: 'CB', 1: 'CG1', 3: 'CG2', 0: 'CD1'},
'TRP': {4: 'N', 5: 'CA', 16: 'C', 17: 'O', 6: 'CB', 7: 'CG', 8: 'CD1', 15: 'CD2', 9: 'NE1', 10: 'CE2', 14: 'CE3', 11: 'CZ2', 13: 'CZ3', 12: 'CH2'},
'THR': {4: 'N', 3: 'CA', 9: 'C', 10: 'O', 1: 'CB', 2: 'OG1', 0: 'CG2'},
'CYS': {4: 'N', 5: 'CA', 8: 'C', 9: 'O', 6: 'CB', 7: 'SG'},
'ALA': {2: 'N', 1: 'CA', 7: 'C', 8: 'O', 0: 'CB'},
'LYS': {6: 'N', 5: 'CA', 11: 'C', 12: 'O', 4: 'CB', 3: 'CG', 2: 'CD', 1: 'CE', 0: 'NZ'},
'PRO': {4: 'N', 8: 'CA', 9: 'C', 10: 'O', 7: 'CB', 6: 'CG', 5: 'CD'},
'LEU': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 1: 'CG', 0: 'CD1', 2: 'CD2'},
'GLY': {4: 'N', 5: 'CA', 6: 'C', 7: 'O'},
'ASP': {4: 'N', 5: 'CA', 10: 'C', 11: 'O', 6: 'CB', 7: 'CG', 8: 'OD1', 9: 'OD2'},
'HIS': {4: 'N', 5: 'CA', 12: 'C', 13: 'O', 6: 'CB', 7: 'CG', 11: 'ND1', 8: 'CD2', 10: 'CE1', 9: 'NE2'},
'VAL': {4: 'N', 3: 'CA', 9: 'C', 10: 'O', 1: 'CB', 0: 'CG1', 2: 'CG2'},
'SER': {4: 'N', 5: 'CA', 8: 'C', 9: 'O', 6: 'CB', 7: 'OG'},
'ARG': {8: 'N', 7: 'CA', 13: 'C', 14: 'O', 6: 'CB', 5: 'CG', 4: 'CD', 3: 'NE', 1: 'CZ', 0: 'NH1', 2: 'NH2'},
'GLU': {4: 'N', 5: 'CA', 11: 'C', 12: 'O', 6: 'CB', 7: 'CG', 8: 'CD', 9: 'OE1', 10: 'OE2'},
'GLN': {6: 'N', 5: 'CA', 11: 'C', 12: 'O', 4: 'CB', 3: 'CG', 1: 'CD', 2: 'OE1', 0: 'NE2'},
'ASN': {5: 'N', 4: 'CA', 10: 'C', 11: 'O', 3: 'CB', 1: 'CG', 2: 'OD1', 0: 'ND2'}
}
aa_to_cg_indices = {aa_long2short[x]: [atom_order[aa_long2short[x]].index(aname) for aname in index_dict.values()] for x, index_dict in cg_rdkit_indices.items()}
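
The dict comprehension that builds `aa_to_cg_indices` is dense, so the sketch below unrolls it on two residues. The table entries are copied from the file above. The function name `build_aa_to_cg_indices` is an assumption introduced for illustration and does not exist in the commit.

```python
# Subset of the tables from datasets/constants.py (values copied from above).
atom_order = {
    'A': ['N', 'CA', 'C', 'O', 'CB'],
    'W': ['N', 'CA', 'C', 'O', 'CB', 'CG', 'CD1', 'CD2', 'CE2', 'CE3',
          'NE1', 'CZ2', 'CZ3', 'CH2'],
}
cg_rdkit_indices = {
    'ALA': {2: 'N', 1: 'CA', 7: 'C', 8: 'O', 0: 'CB'},
    'TRP': {4: 'N', 5: 'CA', 16: 'C', 17: 'O', 6: 'CB', 7: 'CG', 8: 'CD1',
            15: 'CD2', 9: 'NE1', 10: 'CE2', 14: 'CE3', 11: 'CZ2',
            13: 'CZ3', 12: 'CH2'},
}
aa_long2short = {'ALA': 'A', 'TRP': 'W'}


def build_aa_to_cg_indices(cg_indices, long2short, order):
    """Unrolled form of the comprehension above (hypothetical helper)."""
    result = {}
    for aa_long, index_dict in cg_indices.items():
        aa_short = long2short[aa_long]
        # For each atom name, in the insertion order of index_dict,
        # look up its position in the canonical atom_order list.
        result[aa_short] = [order[aa_short].index(name)
                            for name in index_dict.values()]
    return result


aa_to_cg_indices = build_aa_to_cg_indices(cg_rdkit_indices, aa_long2short, atom_order)
```

Each list follows the insertion order of the matching `cg_rdkit_indices` entry, not the order of the RDKit atom indices used as keys. For alanine the two orders agree, giving `[0, 1, 2, 3, 4]`. For tryptophan `NE1` is listed before `CE2` and `CE3`, so the result is not a simple `range`.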

101
datasets/dataloader.py Normal file

@@ -0,0 +1,101 @@
from collections.abc import Mapping, Sequence
from typing import List, Optional, Union
import torch.utils.data
from torch.utils.data.dataloader import default_collate
from torch_geometric.data import Batch, Dataset
from torch_geometric.data.data import BaseData
class Collater:
def __init__(self, follow_batch, exclude_keys):
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
def __call__(self, batch):
batch = [x for x in batch if x is not None]
elem = batch[0]
if isinstance(elem, BaseData):
return Batch.from_data_list(batch, self.follow_batch,
self.exclude_keys)
elif isinstance(elem, torch.Tensor):
return default_collate(batch)
elif isinstance(elem, float):
return torch.tensor(batch, dtype=torch.float)
elif isinstance(elem, int):
return torch.tensor(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, Mapping):
return {key: self([data[key] for data in batch]) for key in elem}
elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
return type(elem)(*(self(s) for s in zip(*batch)))
elif isinstance(elem, Sequence) and not isinstance(elem, str):
return [self(s) for s in zip(*batch)]
raise TypeError(f'DataLoader found invalid type: {type(elem)}')
def collate(self, batch): # Deprecated...
return self(batch)
class DataLoader(torch.utils.data.DataLoader):
r"""A data loader which merges data objects from a
:class:`torch_geometric.data.Dataset` to a mini-batch.
Data objects can be either of type :class:`~torch_geometric.data.Data` or
:class:`~torch_geometric.data.HeteroData`.
Args:
dataset (Dataset): The dataset from which to load the data.
batch_size (int, optional): How many samples per batch to load.
(default: :obj:`1`)
shuffle (bool, optional): If set to :obj:`True`, the data will be
reshuffled at every epoch. (default: :obj:`False`)
follow_batch (List[str], optional): Creates assignment batch
vectors for each key in the list. (default: :obj:`None`)
exclude_keys (List[str], optional): Will exclude each key in the
list. (default: :obj:`None`)
**kwargs (optional): Additional arguments of
:class:`torch.utils.data.DataLoader`.
"""
def __init__(
self,
dataset: Union[Dataset, List[BaseData]],
batch_size: int = 1,
shuffle: bool = False,
follow_batch: Optional[List[str]] = None,
exclude_keys: Optional[List[str]] = None,
**kwargs,
):
if 'collate_fn' in kwargs:
del kwargs['collate_fn']
# Save for PyTorch Lightning:
self.follow_batch = follow_batch
self.exclude_keys = exclude_keys
super().__init__(
dataset,
batch_size,
shuffle,
collate_fn=Collater(follow_batch, exclude_keys),
**kwargs,
)
def collate_fn(data_list):
data_list = [x for x in data_list if x is not None]
return data_list
class DataListLoader(torch.utils.data.DataLoader):
def __init__(self, dataset: Union[Dataset, List[BaseData]],
batch_size: int = 1, shuffle: bool = False, **kwargs):
if 'collate_fn' in kwargs:
del kwargs['collate_fn']
super().__init__(dataset, batch_size=batch_size, shuffle=shuffle,
collate_fn=collate_fn, **kwargs)
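
The recursion in `Collater.__call__` can be followed without PyTorch. The sketch below is a simplified stand-in, not the class from this file: it keeps the `None` filtering and the mapping, namedtuple, and sequence branches, but leaves scalar leaves as plain Python lists instead of converting them to tensors or `Batch` objects. The name `collate_plain` is an assumption.

```python
from collections import namedtuple
from collections.abc import Mapping, Sequence


def collate_plain(batch):
    """Pure-Python analogue of the recursive collate above (hypothetical)."""
    # Drop failed samples first, as the Collater does.
    batch = [x for x in batch if x is not None]
    elem = batch[0]
    if isinstance(elem, (int, float, str)):
        # The real Collater builds tensors for numbers; here leaves stay lists.
        return batch
    if isinstance(elem, Mapping):
        return {key: collate_plain([data[key] for data in batch]) for key in elem}
    if isinstance(elem, tuple) and hasattr(elem, '_fields'):
        # namedtuple: collate field by field and rebuild the same type.
        return type(elem)(*(collate_plain(list(s)) for s in zip(*batch)))
    if isinstance(elem, Sequence):
        return [collate_plain(list(s)) for s in zip(*batch)]
    raise TypeError(f'found invalid type: {type(elem)}')


Pair = namedtuple('Pair', ['x', 'y'])
dict_batch = collate_plain([{'a': 1, 'b': 'x'}, None, {'a': 2, 'b': 'y'}])
pair_batch = collate_plain([Pair(1, 2.0), Pair(3, 4.0)])
```

Both this sketch and the `Collater` above read `batch[0]` after filtering, so a batch in which every sample is `None` raises an `IndexError`. Callers that expect whole batches to fail may want to guard against that case.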


@@ -1,60 +1,27 @@
import os
from argparse import FileType, ArgumentParser
import numpy as np
import pandas as pd
import pickle
from argparse import ArgumentParser
from Bio.PDB import PDBParser
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm
from Bio import SeqIO
from datasets.constants import three_to_one
parser = ArgumentParser()
parser.add_argument('--out_file', type=str, default="data/prepared_for_esm.fasta")
parser.add_argument('--protein_ligand_csv', type=str, default='data/protein_ligand_example_csv.csv', help='Path to a .csv specifying the input as described in the main README')
parser.add_argument('--protein_path', type=str, default=None, help='Path to a single PDB file. If this is not None then it will be used instead of the --protein_ligand_csv')
parser.add_argument('--dataset', type=str, default="pdbbind")
parser.add_argument('--data_dir', type=str, default='../data/BindingMOAD_2020_ab_processed_biounit/pdb_protein/', help='')
args = parser.parse_args()
biopython_parser = PDBParser()
three_to_one = {'ALA': 'A',
'ARG': 'R',
'ASN': 'N',
'ASP': 'D',
'CYS': 'C',
'GLN': 'Q',
'GLU': 'E',
'GLY': 'G',
'HIS': 'H',
'ILE': 'I',
'LEU': 'L',
'LYS': 'K',
'MET': 'M',
'MSE': 'M',  # MSE (selenomethionine) is nearly identical to MET: the sulfur atom is replaced by selenium
'PHE': 'F',
'PRO': 'P',
'PYL': 'O',
'SER': 'S',
'SEC': 'U',
'THR': 'T',
'TRP': 'W',
'TYR': 'Y',
'VAL': 'V',
'ASX': 'B',
'GLX': 'Z',
'XAA': 'X',
'XLE': 'J'}
if args.protein_path is not None:
file_paths = [args.protein_path]
else:
df = pd.read_csv(args.protein_ligand_csv)
file_paths = list(set(df['protein_path'].tolist()))
sequences = []
ids = []
for file_path in tqdm(file_paths):
def get_structure_from_file(file_path):
structure = biopython_parser.get_structure('random_id', file_path)
structure = structure[0]
l = []
for i, chain in enumerate(structure):
seq = ''
for res_idx, residue in enumerate(chain):
@@ -75,13 +42,48 @@ for file_path in tqdm(file_paths):
except Exception as e:
seq += '-'
print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', file_path, '. Replacing it with a dash - .')
sequences.append(seq)
ids.append(f'{os.path.basename(file_path)}_chain_{i}')
records = []
for (index, seq) in zip(ids,sequences):
record = SeqRecord(Seq(seq), str(index))
record.description = ''
records.append(record)
SeqIO.write(records, args.out_file, "fasta")
l.append(seq)
return l
data_dir = args.data_dir
names = os.listdir(data_dir)
if args.dataset == 'pdbbind':
sequences = []
ids = []
for name in tqdm(names):
if name == '.DS_Store': continue
if os.path.exists(os.path.join(data_dir, name, f'{name}_protein_processed.pdb')):
rec_path = os.path.join(data_dir, name, f'{name}_protein_processed.pdb')
else:
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
l = get_structure_from_file(rec_path)
for i, seq in enumerate(l):
sequences.append(seq)
ids.append(f'{name}_chain_{i}')
records = []
for (index, seq) in zip(ids, sequences):
record = SeqRecord(Seq(seq), str(index))
record.description = ''
records.append(record)
SeqIO.write(records, args.out_file, "fasta")
elif args.dataset == 'moad':
names = [n[:6] for n in names]
name_to_sequence = {}
for name in tqdm(names):
if name == '.DS_Store': continue
if not os.path.exists(os.path.join(data_dir, f'{name}_protein.pdb')):
print(f"We are skipping {name} because there was no {name}_protein.pdb")
continue
rec_path = os.path.join(data_dir, f'{name}_protein.pdb')
l = get_structure_from_file(rec_path)
for i, seq in enumerate(l):
name_to_sequence[name + '_chain_' + str(i)] = seq
# save to file
with open(args.out_file, 'wb') as f:
pickle.dump(name_to_sequence, f)
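
The per-chain sequence building and the `{name}_chain_{i}` keys used by both branches can be sketched in isolation. This is an illustration under stated assumptions: the residue loop is partly elided in the hunk above, so the sketch assumes it looks residue names up in `three_to_one` and appends a dash on failure, as the visible `except` branch suggests. The helper names `residues_to_sequence` and `chain_ids` and the reduced lookup table are hypothetical, and Biopython parsing is replaced by plain lists of residue names.

```python
# Reduced copy of three_to_one from datasets/constants.py, for illustration only.
THREE_TO_ONE_SUBSET = {'ALA': 'A', 'GLY': 'G', 'MET': 'M', 'MSE': 'M', 'SEC': 'U'}


def residues_to_sequence(resnames, table=THREE_TO_ONE_SUBSET):
    """Map three-letter residue names to a one-letter sequence (hypothetical)."""
    seq = ''
    for resname in resnames:
        try:
            seq += table[resname]
        except KeyError:
            # Unknown residue: keep the position, mark it with a dash.
            seq += '-'
    return seq


def chain_ids(name, chain_sequences):
    """Key each chain as '<name>_chain_<i>', the format used by both branches."""
    return {f'{name}_chain_{i}': seq for i, seq in enumerate(chain_sequences)}


chains = [residues_to_sequence(['ALA', 'MSE', 'UNK', 'GLY']),
          residues_to_sequence(['GLY', 'SEC'])]
name_to_sequence = chain_ids('6dz0', chains)
```

Note that the two branches in the script write different formats to `--out_file`: the `pdbbind` branch writes FASTA through `SeqIO.write`, while the `moad` branch pickles the `name_to_sequence` dict, so downstream code has to know which branch produced the file.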


@@ -1,4 +1,3 @@
import os
from argparse import ArgumentParser
@@ -7,8 +6,8 @@ from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument('--esm_embeddings_path', type=str, default='data/embeddings_output', help='')
parser.add_argument('--output_path', type=str, default='data/esm2_3billion_embeddings.pt', help='')
parser.add_argument('--esm_embeddings_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new', help='')
parser.add_argument('--output_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new.pt', help='')
args = parser.parse_args()
dict = {}

123
datasets/loader.py Normal file

@@ -0,0 +1,123 @@
import torch
from torch_geometric.data import Dataset
from datasets.dataloader import DataLoader, DataListLoader
from datasets.moad import MOAD
from datasets.pdb import PDBSidechain
from datasets.pdbbind import NoiseTransform, PDBBind
from utils.utils import read_strings_from_txt
class CombineDatasets(Dataset):
def __init__(self, dataset1, dataset2):
super(CombineDatasets, self).__init__()
self.dataset1 = dataset1
self.dataset2 = dataset2
def len(self):
return len(self.dataset1) + len(self.dataset2)
def get(self, idx):
if idx < len(self.dataset1):
return self.dataset1[idx]
else:
return self.dataset2[idx - len(self.dataset1)]
def add_complexes(self, new_complex_list):
self.dataset1.add_complexes(new_complex_list)
def construct_loader(args, t_to_sigma, device):
val_dataset2 = None
transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
all_atom=args.all_atoms, alpha=args.sampling_alpha, beta=args.sampling_beta,
include_miscellaneous_atoms=False if not hasattr(args, 'include_miscellaneous_atoms') else args.include_miscellaneous_atoms,
crop_beyond_cutoff=args.crop_beyond)
if args.triple_training: assert args.combined_training
sequences_to_embeddings = None
if args.dataset == 'pdbsidechain' or args.triple_training:
if args.pdbsidechain_esm_embeddings_path is not None:
print('Loading ESM embeddings')
id_to_embeddings = torch.load(args.pdbsidechain_esm_embeddings_path)
sequences_list = read_strings_from_txt(args.pdbsidechain_esm_embeddings_sequences_path)
sequences_to_embeddings = {}
for i, seq in enumerate(sequences_list):
if str(i) in id_to_embeddings:
sequences_to_embeddings[seq] = id_to_embeddings[str(i)]
if args.dataset == 'pdbsidechain' or args.triple_training:
common_args = {'root': args.pdbsidechain_dir, 'transform': transform, 'limit_complexes': args.limit_complexes,
'receptor_radius': args.receptor_radius,
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
'remove_hs': args.remove_hs, 'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
'knn_only_graph': not args.not_knn_only_graph, 'sequences_to_embeddings': sequences_to_embeddings,
'vandermers_max_dist': args.vandermers_max_dist,
'vandermers_buffer_residue_num': args.vandermers_buffer_residue_num,
'vandermers_min_contacts': args.vandermers_min_contacts,
'remove_second_segment': args.remove_second_segment,
'merge_clusters': args.merge_clusters}
train_dataset3 = PDBSidechain(cache_path=args.cache_path, split='train', multiplicity=args.train_multiplicity, **common_args)
if args.dataset == 'pdbsidechain':
train_dataset = train_dataset3
val_dataset = PDBSidechain(cache_path=args.cache_path, split='val', multiplicity=args.val_multiplicity, **common_args)
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
if args.dataset in ['pdbbind', 'moad', 'generalisation', 'distillation']:
common_args = {'transform': transform, 'limit_complexes': args.limit_complexes,
'chain_cutoff': args.chain_cutoff, 'receptor_radius': args.receptor_radius,
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
'knn_only_graph': False if not hasattr(args, 'not_knn_only_graph') else not args.not_knn_only_graph,
'include_miscellaneous_atoms': False if not hasattr(args, 'include_miscellaneous_atoms') else args.include_miscellaneous_atoms,
'matching_tries': args.matching_tries}
if args.dataset == 'pdbbind' or args.dataset == 'generalisation' or args.combined_training:
train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
num_conformers=args.num_conformers, root=args.pdbbind_dir,
esm_embeddings_path=args.pdbbind_esm_embeddings_path,
protein_file=args.protein_file, **common_args)
if args.dataset == 'moad' or args.combined_training:
train_dataset2 = MOAD(cache_path=args.cache_path, split='train', keep_original=True,
num_conformers=args.num_conformers, max_receptor_size=args.max_receptor_size,
remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
multiplicity= args.train_multiplicity, unroll_clusters=args.unroll_clusters,
esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
root=args.moad_dir, esm_embeddings_path=args.moad_esm_embeddings_path,
enforce_timesplit=args.enforce_timesplit, **common_args)
if args.combined_training:
train_dataset = CombineDatasets(train_dataset2, train_dataset)
if args.triple_training:
train_dataset = CombineDatasets(train_dataset, train_dataset3)
else:
train_dataset = train_dataset2
if args.dataset == 'pdbbind' or args.double_val:
val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True,
esm_embeddings_path=args.pdbbind_esm_embeddings_path, root=args.pdbbind_dir,
protein_file=args.protein_file, require_ligand=True, **common_args)
if args.double_val:
val_dataset2 = val_dataset
if args.dataset == 'moad' or args.dataset == 'generalisation':
val_dataset = MOAD(cache_path=args.cache_path, split='val', keep_original=True,
multiplicity= args.val_multiplicity, max_receptor_size=args.max_receptor_size,
remove_promiscuous_targets=args.remove_promiscuous_targets, min_ligand_size=args.min_ligand_size,
esm_embeddings_sequences_path=args.moad_esm_embeddings_sequences_path,
unroll_clusters=args.unroll_clusters, root=args.moad_dir,
esm_embeddings_path=args.moad_esm_embeddings_path, require_ligand=True, **common_args)
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=False, pin_memory=args.pin_memory, drop_last=args.dataloader_drop_last)
return train_loader, val_loader, val_dataset2
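
Note on the optional-flag pattern used in `common_args` above: the `False if not hasattr(args, ...) else ...` expressions let configurations saved before a flag existed still load. A minimal sketch of the same fallback using `getattr` follows; the helper names `knn_only_graph_flag` and `optional_flag` are hypothetical and not part of the repository.

```python
from argparse import Namespace


def knn_only_graph_flag(args):
    # Same result as:
    #   False if not hasattr(args, 'not_knn_only_graph') else not args.not_knn_only_graph
    # A missing attribute is treated as not_knn_only_graph=True, i.e. knn_only_graph=False.
    return not getattr(args, "not_knn_only_graph", True)


def optional_flag(args, name, default=False):
    # Generic fallback for flags that older saved configurations may lack.
    return getattr(args, name, default)


# Hypothetical argument sets: an older config without the flags and a newer one with them.
old_args = Namespace()
new_args = Namespace(not_knn_only_graph=False, include_miscellaneous_atoms=True)

print(knn_only_graph_flag(old_args), knn_only_graph_flag(new_args))
print(optional_flag(old_args, "include_miscellaneous_atoms"),
      optional_flag(new_args, "include_miscellaneous_atoms"))
```

This only illustrates the fallback; the commit itself keeps the explicit `hasattr` form.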

547
datasets/moad.py Normal file

@@ -0,0 +1,547 @@
import os
import pickle
from multiprocessing import Pool
import random
import copy
from torch_geometric.data import Batch
import numpy as np
import torch
from prody import confProDy
from rdkit import Chem
from rdkit.Chem import RemoveHs
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import subgraph
from tqdm import tqdm
confProDy(verbosity='none')
from datasets.process_mols import get_lig_graph_with_matching, moad_extract_receptor_structure
from utils.utils import read_strings_from_txt
class MOAD(Dataset):
def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0, chain_cutoff=None,
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, esm_embeddings_sequences_path=None, require_ligand=False,
include_miscellaneous_atoms=False, keep_local_structures=False,
min_ligand_size=0, knn_only_graph=False, matching_tries=1, multiplicity=1,
max_receptor_size=None, remove_promiscuous_targets=None, unroll_clusters=False, remove_pdbbind=False,
enforce_timesplit=False, no_randomness=False, single_cluster_name=None, total_dataset_size=None, skip_matching=False):
super(MOAD, self).__init__(root, transform)
self.moad_dir = root
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.max_lig_size = max_lig_size
self.split = split
self.limit_complexes = limit_complexes
self.receptor_radius = receptor_radius
self.num_workers = num_workers
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.remove_hs = remove_hs
self.require_ligand = require_ligand
self.esm_embeddings_path = esm_embeddings_path
self.esm_embeddings_sequences_path = esm_embeddings_sequences_path
self.keep_local_structures = keep_local_structures
self.knn_only_graph = knn_only_graph
self.matching_tries = matching_tries
self.all_atoms = all_atoms
self.multiplicity = multiplicity
self.chain_cutoff = chain_cutoff
self.no_randomness = no_randomness
self.total_dataset_size = total_dataset_size
self.skip_matching = skip_matching
self.prot_cache_path = os.path.join(cache_path, f'MOAD12_limit{self.limit_complexes}_INDEX{self.split}'
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
+ (''if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
+ ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
+ ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
+ ('' if not self.knn_only_graph else '_knnOnly'))
self.lig_cache_path = os.path.join(cache_path, f'MOAD12_limit{self.limit_complexes}_INDEX{self.split}'
f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
+ ('' if not matching else f'_matching')
+ ('' if not skip_matching else f'skip')
+ (''if not matching or num_conformers == 1 else f'_confs{num_conformers}')
+ ('' if not keep_local_structures else f'_keptLocalStruct')
+ ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
self.popsize, self.maxiter = popsize, maxiter
self.matching, self.keep_original = matching, keep_original
self.num_conformers = num_conformers
self.single_cluster_name = single_cluster_name
if split == 'train':
split = 'PDBBind'
with open("./data/splits/MOAD_generalisation_splits.pkl", "rb") as f:
self.split_clusters = pickle.load(f)[split]
clusters_path = os.path.join(self.moad_dir, "new_cluster_to_ligands.pkl")
with open(clusters_path, "rb") as f:
self.cluster_to_ligands = pickle.load(f)
#self.cluster_to_ligands = {k: [s.split('.')[0] for s in v] for k, v in self.cluster_to_ligands.items()}
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
if not self.check_all_receptors():
os.makedirs(self.prot_cache_path, exist_ok=True)
self.preprocessing_receptors()
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
if not os.path.exists(os.path.join(self.lig_cache_path, "ligands.pkl")):
os.makedirs(self.lig_cache_path, exist_ok=True)
self.preprocessing_ligands()
print('loading ligands from memory: ', os.path.join(self.lig_cache_path, "ligands.pkl"))
with open(os.path.join(self.lig_cache_path, "ligands.pkl"), 'rb') as f:
self.ligands = pickle.load(f)
if require_ligand:
with open(os.path.join(self.lig_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
self.rdkit_ligands = pickle.load(f)
self.rdkit_ligands = {lig.name:mol for mol, lig in zip(self.rdkit_ligands, self.ligands)}
len_before = len(self.ligands)
if self.single_cluster_name is not None:
self.ligands = [lig for lig in self.ligands if lig.name in self.cluster_to_ligands[self.single_cluster_name]]
print('Kept', len(self.ligands), f'ligands in {self.single_cluster_name} out of', len_before)
len_before = len(self.ligands)
self.ligands = {lig.name: lig for lig in self.ligands if min_ligand_size == 0 or lig['ligand'].x.shape[0] >= min_ligand_size}
print('removed', len_before - len(self.ligands), 'ligands below minimum size out of', len_before)
receptors_names = set([lig.name[:6] for lig in self.ligands.values()])
self.collect_receptors(receptors_names, max_receptor_size, remove_promiscuous_targets)
# filter ligands for which the receptor failed
tot_before = len(self.ligands)
self.ligands = {k:v for k, v in self.ligands.items() if k[:6] in self.receptors}
print('removed', tot_before - len(self.ligands), 'ligands with no receptor out of', tot_before)
if remove_pdbbind:
complexes_pdbbind = read_strings_from_txt('data/splits/timesplit_no_lig_overlap_train') + read_strings_from_txt('data/splits/timesplit_no_lig_overlap_val')
with open('data/BindingMOAD_2020_ab_processed_biounit/ecod_t_group_binding_site_assignment_dict_major_domain.pkl', 'rb') as f:
pdbbind_to_cluster = pickle.load(f)
clusters_pdbbind = set([pdbbind_to_cluster[c] for c in complexes_pdbbind])
self.split_clusters = [c for c in self.split_clusters if c not in clusters_pdbbind]
self.cluster_to_ligands = {k: v for k, v in self.cluster_to_ligands.items() if k not in clusters_pdbbind}
ligand_accepted = []
for c, ligands in self.cluster_to_ligands.items():
ligand_accepted += ligands
ligand_accepted = set(ligand_accepted)
tot_before = len(self.ligands)
self.ligands = {k: v for k, v in self.ligands.items() if k in ligand_accepted}
print('removed', tot_before - len(self.ligands), 'ligands in overlap with PDBBind out of', tot_before)
if enforce_timesplit:
with open("data/splits/pdbids_2019", "r") as f:
lines = f.readlines()
pdbids_from2019 = []
for i in range(6, len(lines), 4):
pdbids_from2019.append(lines[i][18:22])
pdbids_from2019 = set(pdbids_from2019)
len_before = len(self.ligands)
self.ligands = {k: v for k, v in self.ligands.items() if k[:4].upper() not in pdbids_from2019}
print('removed', len_before - len(self.ligands), 'ligands from 2019 out of', len_before)
if unroll_clusters:
rec_keys = set([k[:6] for k in self.ligands.keys()])
self.cluster_to_ligands = {k:[k2 for k2 in self.ligands.keys() if k2[:6] == k] for k in rec_keys}
self.split_clusters = list(rec_keys)
else:
for c in self.cluster_to_ligands.keys():
self.cluster_to_ligands[c] = [v for v in self.cluster_to_ligands[c] if v in self.ligands]
self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_ligands[c])>0]
print_statistics(self)
list_names = [name for cluster in self.split_clusters for name in self.cluster_to_ligands[cluster]]
with open(os.path.join(self.prot_cache_path, f'moad_{self.split}_names.txt'), 'w') as f:
f.write('\n'.join(list_names))
def len(self):
return len(self.split_clusters) * self.multiplicity if self.total_dataset_size is None else self.total_dataset_size
def get_by_name(self, ligand_name, cluster):
ligand_graph = copy.deepcopy(self.ligands[ligand_name])
complex_graph = copy.deepcopy(self.receptors[ligand_name[:6]])
if False and self.keep_original and hasattr(ligand_graph['ligand'], 'orig_pos'):
lig_path = os.path.join(self.moad_dir, 'pdb_superligand', ligand_name + '.pdb')
lig = Chem.MolFromPDBFile(lig_path)
formula = np.asarray([atom.GetSymbol() for atom in lig.GetAtoms()])
# check for same receptor/ligand pair with a different binding position
for ligand_comp in self.cluster_to_ligands[cluster]:
if ligand_comp == ligand_name or ligand_comp[:6] != ligand_name[:6]:
continue
lig_path_comp = os.path.join(self.moad_dir, 'pdb_superligand', ligand_comp + '.pdb')
if not os.path.exists(lig_path_comp):
continue
lig_comp = Chem.MolFromPDBFile(lig_path_comp)
formula_comp = np.asarray([atom.GetSymbol() for atom in lig_comp.GetAtoms()])
if formula.shape == formula_comp.shape and np.all(formula == formula_comp) and hasattr(
self.ligands[ligand_comp], 'orig_pos'):
print(f'Found complex {ligand_comp} to have the same complex/ligand pair, adding it into orig_pos')
# add the orig_pos of the binding position
if not isinstance(ligand_graph['ligand'].orig_pos, list):
ligand_graph['ligand'].orig_pos = [ligand_graph['ligand'].orig_pos]
ligand_graph['ligand'].orig_pos.append(self.ligands[ligand_comp].orig_pos)
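# copy every ligand node/edge store into the receptor graph (avoid shadowing the builtin `type`)
for store_type in ligand_graph.node_types + ligand_graph.edge_types:
for key, value in ligand_graph[store_type].items():
complex_graph[store_type][key] = value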
complex_graph.name = ligand_graph.name
if isinstance(complex_graph['ligand'].pos, list):
for i in range(len(complex_graph['ligand'].pos)):
complex_graph['ligand'].pos[i] -= complex_graph.original_center
else:
complex_graph['ligand'].pos -= complex_graph.original_center
if self.require_ligand:
complex_graph.mol = copy.deepcopy(self.rdkit_ligands[ligand_name])
if self.chain_cutoff:
distances = torch.norm(
(torch.from_numpy(complex_graph['ligand'].orig_pos[0]) - complex_graph.original_center).unsqueeze(1) - complex_graph['receptor'].pos.unsqueeze(0), dim=2)
distances = distances.min(dim=0)[0]
if torch.min(distances) >= self.chain_cutoff:
print('minimum distance', torch.min(distances), 'too large', ligand_name,
'skipping and returning random. Number of chains',
torch.max(complex_graph['receptor'].chain_ids) + 1)
# random.randint includes both endpoints, so the upper bound is len - 1
return self.get(random.randint(0, self.len() - 1))
within_cutoff = distances < self.chain_cutoff
chains_within_cutoff = torch.zeros(torch.max(complex_graph['receptor'].chain_ids) + 1)
chains_within_cutoff.index_add_(0, complex_graph['receptor'].chain_ids, within_cutoff.float())
chains_within_cutoff_bool = chains_within_cutoff > 0
residues_to_keep = chains_within_cutoff_bool[complex_graph['receptor'].chain_ids]
if self.all_atoms:
atom_to_res_mapping = complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index[1]
atoms_to_keep = residues_to_keep[atom_to_res_mapping]
rec_remapper = (torch.cumsum(residues_to_keep.long(), dim=0) - 1)
atom_to_res_new_mapping = rec_remapper[atom_to_res_mapping][atoms_to_keep]
atom_res_edge_index = torch.stack([torch.arange(len(atom_to_res_new_mapping)), atom_to_res_new_mapping])
complex_graph['atom'].x = complex_graph['atom'].x[atoms_to_keep]
complex_graph['atom'].pos = complex_graph['atom'].pos[atoms_to_keep]
complex_graph['atom', 'atom_contact', 'atom'].edge_index = \
subgraph(atoms_to_keep, complex_graph['atom', 'atom_contact', 'atom'].edge_index,
relabel_nodes=True)[0]
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
complex_graph['receptor'].pos = complex_graph['receptor'].pos[residues_to_keep]
complex_graph['receptor'].x = complex_graph['receptor'].x[residues_to_keep]
complex_graph['receptor'].side_chain_vecs = complex_graph['receptor'].side_chain_vecs[residues_to_keep]
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
subgraph(residues_to_keep, complex_graph['receptor', 'rec_contact', 'receptor'].edge_index,
relabel_nodes=True)[0]
extra_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= extra_center
if isinstance(complex_graph['ligand'].pos, list):
for i in range(len(complex_graph['ligand'].pos)):
complex_graph['ligand'].pos[i] -= extra_center
else:
complex_graph['ligand'].pos -= extra_center
complex_graph.original_center += extra_center
complex_graph['receptor'].pop('chain_ids')
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq',
'to_keep', 'chain_ids']:
if hasattr(complex_graph, a):
delattr(complex_graph, a)
if hasattr(complex_graph['receptor'], a):
delattr(complex_graph['receptor'], a)
return complex_graph
def get(self, idx):
if self.total_dataset_size is not None:
idx = random.randint(0, len(self.split_clusters) - 1)
idx = idx % len(self.split_clusters)
cluster = self.split_clusters[idx]
if self.no_randomness:
ligand_name = sorted(self.cluster_to_ligands[cluster])[0]
else:
ligand_name = random.choice(self.cluster_to_ligands[cluster])
complex_graph = self.get_by_name(ligand_name, cluster)
if self.total_dataset_size is not None:
complex_graph = Batch.from_data_list([complex_graph])
return complex_graph
def get_all_complexes(self):
complexes = {}
for cluster in self.split_clusters:
for ligand_name in self.cluster_to_ligands[cluster]:
complexes[ligand_name] = self.get_by_name(ligand_name, cluster)
return complexes
def preprocessing_receptors(self):
print(f'Processing receptors from [{self.split}] and saving it to [{self.prot_cache_path}]')
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
receptor_names_all = [l[:6] for l in complex_names_all]
receptor_names_all = sorted(list(dict.fromkeys(receptor_names_all)))
print(f'Loading {len(receptor_names_all)} receptors.')
if self.esm_embeddings_path is not None:
id_to_embeddings = torch.load(self.esm_embeddings_path)
sequences_list = read_strings_from_txt(self.esm_embeddings_sequences_path)
sequences_to_embeddings = {}
for i, seq in enumerate(sequences_list):
sequences_to_embeddings[seq] = id_to_embeddings[str(i)]
else:
sequences_to_embeddings = None
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(receptor_names_all)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.prot_cache_path, f"receptors{i}.pkl")):
continue
receptor_names = receptor_names_all[1000*i:1000*(i+1)]
receptor_graphs = []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(receptor_names), desc=f'loading receptors {i}/{len(receptor_names_all)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_receptor, zip(receptor_names, [sequences_to_embeddings]*len(receptor_names))):
if t is not None:
print(len(receptor_graphs))
receptor_graphs.append(t)
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
print('Number of receptors: ', len(receptor_graphs))
with open(os.path.join(self.prot_cache_path, f"receptors{i}.pkl"), 'wb') as f:
pickle.dump((receptor_graphs), f)
return receptor_names_all
def check_all_receptors(self):
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
receptor_names_all = [l[:6] for l in complex_names_all]
receptor_names_all = list(dict.fromkeys(receptor_names_all))
for i in range(len(receptor_names_all)//1000+1):
if not os.path.exists(os.path.join(self.prot_cache_path, f"receptors{i}.pkl")):
return False
return True
def collect_receptors(self, receptors_to_keep=None, max_receptor_size=None, remove_promiscuous_targets=None):
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
receptor_names_all = [l[:6] for l in complex_names_all]
receptor_names_all = sorted(list(dict.fromkeys(receptor_names_all)))
receptor_graphs_all = []
total_recovered = 0
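# receptors_to_keep defaults to None, so guard the len() call
print(f'Loading {len(receptor_names_all)} receptors to keep {"all" if receptors_to_keep is None else len(receptors_to_keep)}.')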
for i in range(len(receptor_names_all)//1000+1):
print(f'prot path: {os.path.join(self.prot_cache_path, f"receptors{i}.pkl")}')
with open(os.path.join(self.prot_cache_path, f"receptors{i}.pkl"), 'rb') as f:
l = pickle.load(f)
total_recovered += len(l)
if receptors_to_keep is not None:
l = [t for t in l if t['receptor_name'] in receptors_to_keep]
receptor_graphs_all.extend(l)
cur_len = len(receptor_graphs_all)
print(f"Kept {len(receptor_graphs_all)} receptors out of {len(receptor_names_all)} total and recovered {total_recovered}")
if max_receptor_size is not None:
receptor_graphs_all = [rec for rec in receptor_graphs_all if rec["receptor"].pos.shape[0] <= max_receptor_size]
print(f"Kept {len(receptor_graphs_all)} receptors out of {cur_len} after filtering by size")
cur_len = len(receptor_graphs_all)
if remove_promiscuous_targets is not None:
promiscuous_targets = set()
for name in complex_names_all:
l = name.split('_')
if int(l[3]) > remove_promiscuous_targets:
promiscuous_targets.add(name[:6])
receptor_graphs_all = [rec for rec in receptor_graphs_all if rec["receptor_name"] not in promiscuous_targets]
print(f"Kept {len(receptor_graphs_all)} receptors out of {cur_len} after removing promiscuous targets")
self.receptors = {}
for r in receptor_graphs_all:
self.receptors[r['receptor_name']] = r
return
def get_receptor(self, par):
name, sequences_to_embeddings = par
rec_path = os.path.join(self.moad_dir, 'pdb_protein', name + '_protein.pdb')
if not os.path.exists(rec_path):
print("Receptor not found", name, rec_path)
return None
complex_graph = HeteroData()
complex_graph['receptor_name'] = name
try:
moad_extract_receptor_structure(path=rec_path, complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors, sequences_to_embeddings=sequences_to_embeddings,
knn_only_graph=self.knn_only_graph, all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
return None
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= protein_center
if self.all_atoms:
complex_graph['atom'].pos -= protein_center
complex_graph.original_center = protein_center
return complex_graph
def preprocessing_ligands(self):
print(f'Processing complexes from [{self.split}] and saving it to [{self.lig_cache_path}]')
complex_names_all = sorted([l for c in self.split_clusters for l in self.cluster_to_ligands[c]])
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
print(f'Loading {len(complex_names_all)} ligands.')
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(complex_names_all)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.lig_cache_path, f"ligands{i}.pkl")):
continue
complex_names = complex_names_all[1000*i:1000*(i+1)]
ligand_graphs, rdkit_ligands = [], []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_ligand, complex_names):
if t is not None:
ligand_graphs.append(t[0])
rdkit_ligands.append(t[1])
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.lig_cache_path, f"ligands{i}.pkl"), 'wb') as f:
pickle.dump((ligand_graphs), f)
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
ligand_graphs_all = []
for i in range(len(complex_names_all)//1000+1):
with open(os.path.join(self.lig_cache_path, f"ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
ligand_graphs_all.extend(l)
with open(os.path.join(self.lig_cache_path, f"ligands.pkl"), 'wb') as f:
pickle.dump((ligand_graphs_all), f)
rdkit_ligands_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
rdkit_ligands_all.extend(l)
with open(os.path.join(self.lig_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands_all), f)
def get_ligand(self, name):
if self.split == 'train':
lig_path = os.path.join(self.moad_dir, 'pdb_superligand', name + '.pdb')
else:
lig_path = os.path.join(self.moad_dir, 'pdb_ligand', name + '.pdb')
if not os.path.exists(lig_path):
print("Ligand not found", name, lig_path)
return None
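# read the ligand structure from its PDB file
lig = Chem.MolFromPDBFile(lig_path)
if lig is None:
print("Ligand could not be parsed", name, lig_path)
return None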
if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
return None
try:
if self.matching:
smile = Chem.MolToSmiles(lig)
if '.' in smile:
print(f'Ligand {name} has multiple fragments and we are doing matching. Not including {name} in preprocessed data.')
return None
complex_graph = HeteroData()
complex_graph['name'] = name
Chem.SanitizeMol(lig)
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries, skip_matching=self.skip_matching)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
return None
if self.split != 'train':
other_positions = [complex_graph['ligand'].orig_pos]
nsplit = name.split('_')
for i in range(100):
new_file = os.path.join(self.moad_dir, 'pdb_ligand', f'{nsplit[0]}_{nsplit[1]}_{nsplit[2]}_{i}.pdb')
if os.path.exists(new_file):
if i != int(nsplit[3]):
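# use a separate name so the molecule returned below stays the one loaded for `name`
other_lig = Chem.MolFromPDBFile(new_file)
other_lig = RemoveHs(other_lig, sanitize=True)
other_positions.append(other_lig.GetConformer().GetPositions())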
else:
break
complex_graph['ligand'].orig_pos = np.asarray(other_positions)
return complex_graph, lig
def print_statistics(dataset):
statistics = ([], [], [], [], [], [])
receptor_sizes = []
for i in range(len(dataset)):
complex_graph = dataset[i]
lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
receptor_sizes.append(complex_graph['receptor'].pos.shape[0])
radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
molecule_center = torch.mean(lig_pos, dim=0)
radius_molecule = torch.max(
torch.linalg.vector_norm(lig_pos - molecule_center.unsqueeze(0), dim=1))
distance_center = torch.linalg.vector_norm(molecule_center)
statistics[0].append(radius_protein)
statistics[1].append(radius_molecule)
statistics[2].append(distance_center)
if "rmsd_matching" in complex_graph:
statistics[3].append(complex_graph.rmsd_matching)
else:
statistics[3].append(0)
statistics[4].append(int(complex_graph.random_coords) if "random_coords" in complex_graph else -1)
if "random_coords" in complex_graph and complex_graph.random_coords and "rmsd_matching" in complex_graph:
statistics[5].append(complex_graph.rmsd_matching)
if len(statistics[5]) == 0:
statistics[5].append(-1)
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
print('Number of complexes: ', len(dataset))
for i in range(len(name)):
array = np.asarray(statistics[i])
print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
return
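
Note on the chunked preprocessing in `preprocessing_receptors` and `preprocessing_ligands` above: both save results in chunks of 1000 and skip any chunk whose pickle already exists, so an interrupted run can resume. The sketch below shows that pattern in isolation, under simplifying assumptions: it runs serially (no `Pool`), and the names `process_in_chunks`, `chunk{i}.pkl` and `double_unless_multiple_of_five` are hypothetical.

```python
import os
import pickle
import tempfile


def process_in_chunks(names, process_fn, cache_dir, chunk_size=1000):
    """Process names chunk by chunk, pickling each chunk so a rerun skips finished chunks."""
    os.makedirs(cache_dir, exist_ok=True)
    # Same chunk count as the code above; the last chunk may be empty.
    num_chunks = len(names) // chunk_size + 1
    for i in range(num_chunks):
        chunk_path = os.path.join(cache_dir, f"chunk{i}.pkl")
        if os.path.exists(chunk_path):
            continue  # this chunk was completed by an earlier run
        chunk = names[chunk_size * i:chunk_size * (i + 1)]
        # Items for which process_fn returns None are dropped, as in the dataset code.
        results = [r for r in map(process_fn, chunk) if r is not None]
        with open(chunk_path, "wb") as f:
            pickle.dump(results, f)
    # Merge all chunk files into one list, in chunk order.
    merged = []
    for i in range(num_chunks):
        with open(os.path.join(cache_dir, f"chunk{i}.pkl"), "rb") as f:
            merged.extend(pickle.load(f))
    return merged


def double_unless_multiple_of_five(x):
    # Toy stand-in for get_receptor / get_ligand: returns None to signal a failure.
    return None if x % 5 == 0 else 2 * x


with tempfile.TemporaryDirectory() as demo_dir:
    demo_result = process_in_chunks(list(range(25)), double_unless_multiple_of_five,
                                    demo_dir, chunk_size=10)
    print(len(demo_result))
```

Because the existence check is made per chunk, a rerun after a crash repeats at most one partially processed chunk.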

146
datasets/parse_chi.py Normal file

@@ -0,0 +1,146 @@
# From Nick Polizzi
import numpy as np
from collections import defaultdict
import prody as pr
import os
from datasets.constants import chi, atom_order, aa_long2short, aa_short2aa_idx, aa_idx2aa_short
def get_dihedral_indices(resname, chi_num):
"""Return the atom indices for the specified dihedral angle.
"""
if resname not in chi:
return np.array([np.nan]*4)
if chi_num not in chi[resname]:
return np.array([np.nan]*4)
return np.array([atom_order[resname].index(x) for x in chi[resname][chi_num]])
dihedral_indices = defaultdict(list)
for aa in atom_order.keys():
for i in range(1, 5):
inds = get_dihedral_indices(aa, i)
dihedral_indices[aa].append(inds)
dihedral_indices[aa] = np.array(dihedral_indices[aa])
def vector_batch(a, b):
return a - b
def unit_vector_batch(v):
return v / np.linalg.norm(v, axis=1, keepdims=True)
def dihedral_angle_batch(p):
b0 = vector_batch(p[:, 0], p[:, 1])
b1 = vector_batch(p[:, 1], p[:, 2])
b2 = vector_batch(p[:, 2], p[:, 3])
n1 = np.cross(b0, b1)
n2 = np.cross(b1, b2)
m1 = np.cross(n1, b1 / np.linalg.norm(b1, axis=1, keepdims=True))
x = np.sum(n1 * n2, axis=1)
y = np.sum(m1 * n2, axis=1)
deg = np.degrees(np.arctan2(y, x))
deg[deg < 0] += 360
return deg
def batch_compute_dihedral_angles(sidechains):
sidechains_np = np.array(sidechains)
dihedral_angles = dihedral_angle_batch(sidechains_np)
return dihedral_angles
def get_coords(prody_pdb):
resindices = sorted(set(prody_pdb.ca.getResindices()))
coords = np.zeros((len(resindices), 14, 3))
for i, resind in enumerate(resindices):
sel = prody_pdb.select(f'resindex {resind}')
resname = sel.getResnames()[0]
for j, name in enumerate(atom_order[aa_long2short[resname] if resname in aa_long2short else 'X']):
sel_resnum_name = sel.select(f'name {name}')
if sel_resnum_name is not None:
coords[i, j, :] = sel_resnum_name.getCoords()[0]
else:
coords[i, j, :] = [np.nan, np.nan, np.nan]
return coords
def get_onehot_sequence(seq):
onehot = np.zeros((len(seq), 20))
for i, aa in enumerate(seq):
idx = aa_short2aa_idx[aa] if aa in aa_short2aa_idx else 7 # 7 is the index for GLY
onehot[i, idx] = 1
return onehot
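# Note: this redefines get_dihedral_indices; the (resname, chi_num) version above is only used to build the dihedral_indices table.
def get_dihedral_indices(onehot_sequence):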
return np.array([dihedral_indices[aa_idx2aa_short[aa_idx]] for aa_idx in np.where(onehot_sequence)[1]])
def _get_chi_angles(coords, indices):
X = coords
Y = indices.astype(int)
N = coords.shape[0]
mask = np.isnan(indices)
Y[mask] = 0
Z = X[np.arange(N)[:, None, None], Y, :]
Z[mask] = np.nan
chi_angles = batch_compute_dihedral_angles(Z.reshape(-1, 4, 3)).reshape(N, 4)
return chi_angles
def get_chi_angles(coords, seq, return_onehot=False):
"""
Compute side-chain chi angles from per-residue atom coordinates.
Parameters
----------
coords : numpy array of shape (N, 14, 3)
atom coordinates per residue in the atom order produced by get_coords; missing atoms are np.nan
seq : sequence of str
short amino acid codes (keys of aa_short2aa_idx), one per residue, in the same order as coords
return_onehot : bool, optional
also return the one-hot encoding of seq, by default False
Returns
-------
numpy array of shape (N, 4)
Array contains chi angles of sidechains, in degrees in [0, 360), in the row order of coords.
If a chi angle is not defined for a residue, due to missing atoms or GLY / ALA, it is set to np.nan.
"""
onehot = get_onehot_sequence(seq)
dihedral_indices = get_dihedral_indices(onehot)
if return_onehot:
return _get_chi_angles(coords, dihedral_indices), onehot
return _get_chi_angles(coords, dihedral_indices)
def test_get_chi_angles(print_chi_angles=False):
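# needs an internet connection or '6w70.pdb' in the working directory
pdb = pr.parsePDB('6w70')
prody_pdb = pdb.select('chain A')
# get_chi_angles takes coordinates and a sequence; getSequence() is assumed to be available in the installed ProDy
coords = get_coords(prody_pdb)
seq = prody_pdb.ca.getSequence()
chi_angles = get_chi_angles(coords, seq)
assert chi_angles.shape == (prody_pdb.ca.numAtoms(), 4)
assert chi_angles[0,0] < 56.0 and chi_angles[0,0] > 55.0
print('test_get_chi_angles passed')
try:
os.remove('6w70.pdb.gz')
except OSError:
pass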
if print_chi_angles:
print(chi_angles)
return True
if __name__ == '__main__':
test_get_chi_angles(print_chi_angles=True)
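
Note on `dihedral_angle_batch` above: it derives each chi angle from four atom positions using the normals of the two planes they span. The sketch below applies the same formula to a single set of four points using only the standard library, so it can be checked by hand. It is an illustration, not the batched NumPy code; `dihedral_angle` and its helpers are hypothetical names.

```python
import math


def _sub(a, b):
    return tuple(x - y for x, y in zip(a, b))


def _cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])


def _dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def dihedral_angle(p0, p1, p2, p3):
    """Dihedral angle in degrees, in [0, 360), using the same sign convention as dihedral_angle_batch."""
    b0 = _sub(p0, p1)
    b1 = _sub(p1, p2)
    b2 = _sub(p2, p3)
    n1 = _cross(b0, b1)  # normal of the plane through p0, p1, p2
    n2 = _cross(b1, b2)  # normal of the plane through p1, p2, p3
    b1_norm = math.sqrt(_dot(b1, b1))
    b1_unit = tuple(x / b1_norm for x in b1)
    m1 = _cross(n1, b1_unit)
    x = _dot(n1, n2)
    y = _dot(m1, n2)
    deg = math.degrees(math.atan2(y, x))
    return deg + 360.0 if deg < 0 else deg


# Four points forming a right-angle twist around the p1-p2 axis.
print(dihedral_angle((1.0, 0.0, 0.0), (0.0, 0.0, 0.0), (0.0, 0.0, 1.0), (0.0, 1.0, 1.0)))
```

As in the batched version, negative angles are shifted by 360 so results fall in [0, 360) rather than (-180, 180].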

536
datasets/pdb.py Normal file

@@ -0,0 +1,536 @@
# Significant contribution from Ben Fry
import copy
import os.path
import pickle
import random
from multiprocessing import Pool
import numpy as np
import pandas as pd
import torch
from rdkit import Chem
from rdkit.Chem import AllChem, MolFromSmiles
from scipy.spatial.distance import pdist, squareform
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.utils import subgraph
from tqdm import tqdm
from datasets.constants import aa_to_cg_indices, amino_acid_smiles, cg_rdkit_indices
from datasets.parse_chi import aa_long2short, atom_order
from datasets.process_mols import new_extract_receptor_structure, get_lig_graph, generate_conformer
from utils.torsion import get_transformation_mask
def read_strings_from_txt(path):
# every line will be one element of the returned list
with open(path) as file:
lines = file.readlines()
return [line.rstrip() for line in lines]
def compute_num_ca_neighbors(coords, cg_coords, idx, is_valid_bb_node, max_dist=5, buffer_residue_num=7):
"""
Counts number of residues with heavy atoms within max_dist (Angstroms) of this sidechain that are not
residues within +/- buffer_residue_num in primary sequence.
From Ben's code
Note: Gabriele removed the chain_index
"""
# Extract coordinates of all residues in the protein.
bb_coords = coords
# Compute the indices of residues whose interactions we should ignore (the residue itself and its sequence neighbors).
excluded_neighbors = [idx - x for x in reversed(range(0, buffer_residue_num+1)) if (idx - x) >= 0]
excluded_neighbors.extend([idx + x for x in range(1, buffer_residue_num+1)])
# Create indices of an N x M distance matrix where N is num BB nodes and M is num CG nodes.
e_idx = torch.stack([
torch.arange(bb_coords.shape[0]).unsqueeze(-1).expand((-1, cg_coords.shape[0])).flatten(),
torch.arange(cg_coords.shape[0]).unsqueeze(0).expand((bb_coords.shape[0], -1)).flatten()
])
# Expand bb_coords and cg_coords into the same dimensionality.
bb_coords_exp = bb_coords[e_idx[0]]
cg_coords_exp = cg_coords[e_idx[1]].unsqueeze(1)
# Every row is distance of chemical group to each atom in backbone coordinate frame.
bb_exp_idces, _ = (torch.cdist(bb_coords_exp, cg_coords_exp).squeeze(-1) < max_dist).nonzero(as_tuple=True)
bb_idces_within_thresh = torch.unique(e_idx[0][bb_exp_idces])
# Only count residues that are not adjacent or origin in primary sequence and are valid backbone residues (fully resolved coordinate frame).
bb_idces_within_thresh = bb_idces_within_thresh[~torch.isin(bb_idces_within_thresh, torch.tensor(excluded_neighbors)) & is_valid_bb_node[bb_idces_within_thresh]]
return len(bb_idces_within_thresh)
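The exclusion window used by `compute_num_ca_neighbors` can be illustrated without tensors. The following is a minimal pure-Python sketch, not the repository's implementation: `excluded_window` and `count_contacts` are hypothetical helper names, and `min_dist` is assumed to already hold, for every residue `j`, the minimum distance between residue `j` and the chemical group of residue `idx`.

```python
def excluded_window(idx, buffer_residue_num=7):
    """Indices within +/- buffer_residue_num of idx (idx included).

    Mirrors the two list comprehensions in compute_num_ca_neighbors: the lower
    side is clipped at 0, the upper side is NOT clipped at the chain length.
    """
    lower = [idx - x for x in reversed(range(0, buffer_residue_num + 1)) if idx - x >= 0]
    upper = [idx + x for x in range(1, buffer_residue_num + 1)]
    return lower + upper


def count_contacts(min_dist, idx, max_dist=5.0, buffer_residue_num=7):
    """Count residues closer than max_dist, ignoring the sequence window around idx.

    min_dist is a hypothetical precomputed list of per-residue minimum distances.
    """
    excluded = set(excluded_window(idx, buffer_residue_num))
    return sum(1 for j, d in enumerate(min_dist) if j not in excluded and d < max_dist)


if __name__ == "__main__":
    print(excluded_window(2, 3))
    print(count_contacts([1.0, 9.0, 0.0, 2.0, 3.0, 4.0, 4.5, 6.0], idx=2, buffer_residue_num=1))
```

Indices above the chain length in the window are harmless here because they are only used for membership tests; the sketch also omits the valid-backbone mask that the original applies.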
def identify_valid_vandermers(args):
    """
    Constructs a tensor containing the number of contacts of each residue, from which chemical groups can be sampled.
    Every sidechain is treated as a chemical group here; the actual chemical groups are loaded at training time.
    Once divided by their sum, the counts can be used as sampling probabilities.
    Returns the name of the complex together with the tensor of counts.
    """
    complex_graph, max_dist, buffer_residue_num = args
    # Constructs a mask tracking whether index is a valid coordinate frame / residue label to train over.
    #is_in_residue_vocabulary = torch.tensor([x in aa_short2long for x in data['seq']]).bool()
    coords, seq = complex_graph.coords, complex_graph.seq
    is_valid_bb_node = (coords[:, :4].isnan().sum(dim=(1,2)) == 0).bool() #* is_in_residue_vocabulary
    valid_cg_idces = []
    for idx, aa in enumerate(seq):
        if aa not in aa_to_cg_indices:
            valid_cg_idces.append(0)
        else:
            indices = aa_to_cg_indices[aa]
            cg_coordinates = coords[idx][indices]
            # assign zero contacts to chemical group residues that aren't fully resolved.
            if torch.any(cg_coordinates.isnan()).item():
                valid_cg_idces.append(0)
                continue
            nbr_count = compute_num_ca_neighbors(coords, cg_coordinates, idx, is_valid_bb_node,
                                                 max_dist=max_dist, buffer_residue_num=buffer_residue_num)
            valid_cg_idces.append(nbr_count)
    return complex_graph.name, torch.tensor(valid_cg_idces)
def fast_identify_valid_vandermers(coords, seq, max_dist=5, buffer_residue_num=7):
    # large offset added to entries that must never count as contacts
    offset = 10000 + max_dist
    R = coords.shape[0]
    coords = coords.numpy().reshape(-1, 3)
    # compute pairwise distances between all atoms (14 atom slots per residue)
    pdist_mat = squareform(pdist(coords))
    pdist_mat = pdist_mat.reshape((R, 14, R, 14))
    # unresolved (NaN) atoms are pushed beyond max_dist
    pdist_mat = np.nan_to_num(pdist_mat, nan=offset)
    # minimum atom-atom distance for every pair of residues
    pdist_mat = np.min(pdist_mat, axis=(1, 3))
    # exclude each residue itself and its neighbors within buffer_residue_num in the primary sequence
    pdist_mat = pdist_mat + np.diag(np.ones(len(seq)) * offset)
    for i in range(1, buffer_residue_num+1):
        pdist_mat += np.diag(np.ones(len(seq)-i) * offset, k=i) + np.diag(np.ones(len(seq)-i) * offset, k=-i)
    # get the number of residues that are within max_dist of each residue
    nbr_count = np.sum(pdist_mat < max_dist, axis=1)
    return torch.tensor(nbr_count)
def compute_cg_features(aa, aa_smile):
    """
    Given an amino acid and a SMILES string, returns a HeteroData graph restricted to the chemical group atoms,
    or None if no chemical group is defined for the residue.
    The order of the node feature rows corresponds to the index the atoms appear in aa_to_cg_indices from constants.
    """
    # Handle any residues that we don't have chemical groups for (ex: GLY if not using bb_cnh and bb_cco)
    aa_short = aa_long2short[aa]
    if aa_short not in aa_to_cg_indices:
        return None
    # Create rdkit molecule from smiles string.
    mol = Chem.MolFromSmiles(aa_smile)
    complex_graph = HeteroData()
    get_lig_graph(mol, complex_graph)
    # Keep only the atoms that belong to the chemical group, along with the bonds between them.
    atoms_to_keep = torch.tensor(list(cg_rdkit_indices[aa].keys())).long()
    complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr = \
        subgraph(atoms_to_keep, complex_graph['ligand', 'ligand'].edge_index, complex_graph['ligand', 'ligand'].edge_attr, relabel_nodes=True)
    complex_graph['ligand'].x = complex_graph['ligand'].x[atoms_to_keep]
    edge_mask, mask_rotate = get_transformation_mask(complex_graph)
    complex_graph['ligand'].edge_mask = torch.tensor(edge_mask)
    complex_graph['ligand'].mask_rotate = mask_rotate
    return complex_graph
class PDBSidechain(Dataset):
def __init__(self, root, transform=None, cache_path='data/cache', split='train', limit_complexes=0,
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, remove_hs=True, all_atoms=False,
atom_radius=5, atom_max_neighbors=None, sequences_to_embeddings=None,
knn_only_graph=True, multiplicity=1, vandermers_max_dist=5, vandermers_buffer_residue_num=7,
vandermers_min_contacts=5, remove_second_segment=False, merge_clusters=1, vandermers_extraction=True,
add_random_ligand=False):
super(PDBSidechain, self).__init__(root, transform)
assert remove_hs, "remove_hs=False is not implemented yet"
self.root = root
self.split = split
self.limit_complexes = limit_complexes
self.receptor_radius = receptor_radius
self.knn_only_graph = knn_only_graph
self.multiplicity = multiplicity
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.num_workers = num_workers
self.sequences_to_embeddings = sequences_to_embeddings
self.remove_second_segment = remove_second_segment
self.merge_clusters = merge_clusters
self.vandermers_extraction = vandermers_extraction
self.add_random_ligand = add_random_ligand
self.all_atoms = all_atoms
self.atom_radius = atom_radius
self.atom_max_neighbors = atom_max_neighbors
if vandermers_extraction:
self.cg_node_feature_lookup_dict = {aa_long2short[aa]: compute_cg_features(aa, aa_smile) for aa, aa_smile in
amino_acid_smiles.items()}
self.cache_path = os.path.join(cache_path, f'PDB3_limit{self.limit_complexes}_INDEX{self.split}'
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
+ ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
+ ('' if not self.knn_only_graph else '_knnOnly'))
self.read_split()
if not self.check_all_proteins():
os.makedirs(self.cache_path, exist_ok=True)
self.preprocess()
self.vandermers_max_dist = vandermers_max_dist
self.vandermers_buffer_residue_num = vandermers_buffer_residue_num
self.vandermers_min_contacts = vandermers_min_contacts
self.collect_proteins()
filtered_proteins = []
if vandermers_extraction:
for complex_graph in tqdm(self.protein_graphs):
if complex_graph.name in self.vandermers and torch.any(self.vandermers[complex_graph.name] >= 10):
filtered_proteins.append(complex_graph)
print(f"Computed vandermers and kept {len(filtered_proteins)} proteins out of {len(self.protein_graphs)}")
else:
filtered_proteins = self.protein_graphs
second_filter = []
for complex_graph in tqdm(filtered_proteins):
if sequences_to_embeddings is None or complex_graph.orig_seq in sequences_to_embeddings:
second_filter.append(complex_graph)
print(f"Checked embeddings available and kept {len(second_filter)} proteins out of {len(filtered_proteins)}")
self.protein_graphs = second_filter
# filter clusters that have no protein graphs
self.split_clusters = list(set([g.cluster for g in self.protein_graphs]))
self.cluster_to_complexes = {c: [] for c in self.split_clusters}
for p in self.protein_graphs:
self.cluster_to_complexes[p['cluster']].append(p)
self.split_clusters = [c for c in self.split_clusters if len(self.cluster_to_complexes[c]) > 0]
print("Total elements in set", len(self.split_clusters) * self.multiplicity // self.merge_clusters)
self.name_to_complex = {p.name: p for p in self.protein_graphs}
self.define_probabilities()
if self.add_random_ligand:
# read csv with all smiles
with open('data/smiles_list.csv', 'r') as f:
self.smiles_list = f.readlines()
self.smiles_list = [s.split(',')[0] for s in self.smiles_list]
def define_probabilities(self):
if not self.vandermers_extraction:
return
if self.vandermers_min_contacts is not None:
self.probabilities = torch.arange(1000) - self.vandermers_min_contacts + 1
self.probabilities[:self.vandermers_min_contacts] = 0
else:
with open('data/pdbbind_counts.pkl', 'rb') as f:
pdbbind_counts = pickle.load(f)
pdb_counts = torch.ones(1000)
for contacts in self.vandermers.values():
pdb_counts.index_add_(0, contacts, torch.ones(contacts.shape))
print(pdbbind_counts[:30])
print(pdb_counts[:30])
self.probabilities = pdbbind_counts / pdb_counts
self.probabilities[:7] = 0
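When `vandermers_min_contacts` is set, `define_probabilities` builds a lookup table that maps a contact count to an unnormalised sampling weight: zero below the threshold, then growing linearly. The following is a minimal pure-Python sketch of that mapping and of the normalisation later applied in `get`; `contact_weights` and `to_probabilities` are hypothetical names used only for illustration.

```python
def contact_weights(contacts, min_contacts=5):
    """Unnormalised weight per residue: 0 below min_contacts, then 1, 2, 3, ...

    Equivalent to indexing `torch.arange(1000) - min_contacts + 1` with the
    first min_contacts entries zeroed, for counts below 1000.
    """
    return [max(0, c - min_contacts + 1) for c in contacts]


def to_probabilities(weights):
    """Normalise weights so they sum to 1, as done before np.random.choice."""
    total = sum(weights)
    if total <= 0:
        raise ValueError("no residue has enough contacts to be sampled")
    return [w / total for w in weights]


if __name__ == "__main__":
    weights = contact_weights([0, 4, 5, 6, 10], min_contacts=5)
    print(weights)
    print(to_probabilities(weights))
```

Note that `get` additionally requires at least one residue with 10 or more contacts before sampling, which this sketch does not model.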
def len(self):
return len(self.split_clusters) * self.multiplicity // self.merge_clusters
def get(self, idx=None, protein=None, smiles=None):
assert idx is not None or (protein is not None and smiles is not None), "provide idx, or both protein and smiles"
if protein is None or smiles is None:
idx = idx % len(self.split_clusters)
if self.merge_clusters > 1:
idx = idx * self.merge_clusters
idx = idx + random.randint(0, self.merge_clusters - 1)
idx = min(idx, len(self.split_clusters) - 1)
cluster = self.split_clusters[idx]
protein_graph = copy.deepcopy(random.choice(self.cluster_to_complexes[cluster]))
else:
protein_graph = copy.deepcopy(self.name_to_complex[protein])
if self.sequences_to_embeddings is not None:
#print(self.sequences_to_embeddings[protein_graph.orig_seq].shape, len(protein_graph.orig_seq), protein_graph.to_keep.shape)
if len(protein_graph.orig_seq) != len(self.sequences_to_embeddings[protein_graph.orig_seq]):
print('problem with ESM embeddings')
return self.get(random.randint(0, self.len()))
lm_embeddings = self.sequences_to_embeddings[protein_graph.orig_seq][protein_graph.to_keep]
protein_graph['receptor'].x = torch.cat([protein_graph['receptor'].x, lm_embeddings], dim=1)
if self.vandermers_extraction:
# select sidechain to remove
vandermers_contacts = self.vandermers[protein_graph.name]
vandermers_probs = self.probabilities[vandermers_contacts].numpy()
if not np.any(vandermers_contacts.numpy() >= 10):
print('no vandermers >= 10, retrying with a new one')
return self.get(random.randint(0, self.len()))
sidechain_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
# remove part of the sequence
residues_to_keep = np.ones(len(protein_graph.seq), dtype=bool)
residues_to_keep[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
if self.remove_second_segment:
pos_idx = protein_graph['receptor'].pos[sidechain_idx]
limit_closeness = 10
far_enough = torch.sum((protein_graph['receptor'].pos - pos_idx[None, :]) ** 2, dim=-1) > limit_closeness ** 2
vandermers_probs = vandermers_probs * far_enough.float().numpy()
vandermers_probs[max(0, sidechain_idx - self.vandermers_buffer_residue_num):
min(sidechain_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = 0
if np.all(vandermers_probs<=0):
print('no second vandermer available retrying with new one')
return self.get(random.randint(0, self.len()))
sc2_idx = np.random.choice(np.arange(len(vandermers_probs)), p=vandermers_probs / np.sum(vandermers_probs))
residues_to_keep[max(0, sc2_idx - self.vandermers_buffer_residue_num):
min(sc2_idx + self.vandermers_buffer_residue_num + 1, len(protein_graph.seq))] = False
residues_to_keep = torch.from_numpy(residues_to_keep)
protein_graph['receptor'].pos = protein_graph['receptor'].pos[residues_to_keep]
protein_graph['receptor'].x = protein_graph['receptor'].x[residues_to_keep]
protein_graph['receptor'].side_chain_vecs = protein_graph['receptor'].side_chain_vecs[residues_to_keep]
protein_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
subgraph(residues_to_keep, protein_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
# create the sidechain ligand
sidechain_aa = protein_graph.seq[sidechain_idx]
ligand_graph = self.cg_node_feature_lookup_dict[sidechain_aa]
ligand_graph['ligand'].pos = protein_graph.coords[sidechain_idx][protein_graph.mask[sidechain_idx]]
for store_type in ligand_graph.node_types + ligand_graph.edge_types:
for key, value in ligand_graph[store_type].items():
protein_graph[store_type][key] = value
protein_graph['ligand'].orig_pos = protein_graph['ligand'].pos.numpy()
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
protein_graph['ligand'].pos = protein_graph['ligand'].pos - protein_center
protein_graph.original_center = protein_center
protein_graph['receptor_name'] = protein_graph.name
else:
protein_center = torch.mean(protein_graph['receptor'].pos, dim=0, keepdim=True)
protein_graph['receptor'].pos = protein_graph['receptor'].pos - protein_center
protein_graph.original_center = protein_center
protein_graph['receptor_name'] = protein_graph.name
if self.add_random_ligand:
if smiles is not None:
mol = MolFromSmiles(smiles)
try:
generate_conformer(mol)
except Exception as e:
print("failed to generate the given ligand returning None", e)
return None
else:
success = False
while not success:
smiles = random.choice(self.smiles_list)
mol = MolFromSmiles(smiles)
try:
success = not generate_conformer(mol)
except Exception as e:
print(e, "changing ligand")
lig_graph = HeteroData()
get_lig_graph(mol, lig_graph)
edge_mask, mask_rotate = get_transformation_mask(lig_graph)
lig_graph['ligand'].edge_mask = torch.tensor(edge_mask)
lig_graph['ligand'].mask_rotate = mask_rotate
lig_graph['ligand'].smiles = smiles
lig_graph['ligand'].pos = lig_graph['ligand'].pos - torch.mean(lig_graph['ligand'].pos, dim=0, keepdim=True)
for store_type in lig_graph.node_types + lig_graph.edge_types:
for key, value in lig_graph[store_type].items():
protein_graph[store_type][key] = value
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
if hasattr(protein_graph, a):
delattr(protein_graph, a)
if hasattr(protein_graph['receptor'], a):
delattr(protein_graph['receptor'], a)
return protein_graph
def read_split(self):
# read CSV file
df = pd.read_csv(self.root + "/list.csv")
print("Loaded list CSV file")
# get clusters and filter by split
if self.split == "train":
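# convert to int so that membership tests against the integer CLUSTER column actually match
val_clusters = set(int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt"))
test_clusters = set(int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt"))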
clusters = df["CLUSTER"].unique()
clusters = [int(c) for c in clusters if c not in val_clusters and c not in test_clusters]
elif self.split == "val":
clusters = [int(s) for s in read_strings_from_txt(self.root + "/valid_clusters.txt")]
elif self.split == "test":
clusters = [int(s) for s in read_strings_from_txt(self.root + "/test_clusters.txt")]
else:
raise ValueError("Split must be train, val or test")
print(self.split, "clusters", len(clusters))
clusters = set(clusters)
self.chains_in_cluster = []
complexes_in_cluster = set()
for chain, cluster in zip(df["CHAINID"], df["CLUSTER"]):
if cluster not in clusters:
continue
# limit to one chain per complex
if chain[:4] not in complexes_in_cluster:
self.chains_in_cluster.append((chain, cluster))
complexes_in_cluster.add(chain[:4])
print("Filtered chains in cluster", len(self.chains_in_cluster))
if self.limit_complexes > 0:
self.chains_in_cluster = self.chains_in_cluster[:self.limit_complexes]
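The chain selection in `read_split` keeps at most one chain per PDB entry (the first four characters of the chain ID) and only chains whose cluster belongs to the split. The following is a small stand-alone sketch of that filter under the assumption that cluster IDs are integers on both sides of the comparison; `select_chains` and its `limit` parameter are hypothetical names, not part of the repository.

```python
def select_chains(rows, allowed_clusters, limit=0):
    """Keep the first chain seen for each PDB entry, restricted to allowed clusters.

    rows: iterable of (chain_id, cluster) pairs, e.g. ("1abc_A", 17).
    allowed_clusters: set of integer cluster IDs belonging to the split.
    limit: if > 0, truncate the result to this many chains.
    """
    seen_entries = set()
    selected = []
    for chain_id, cluster in rows:
        if cluster not in allowed_clusters:
            continue
        entry = chain_id[:4]  # PDB entry code, shared by all chains of a complex
        if entry not in seen_entries:
            selected.append((chain_id, cluster))
            seen_entries.add(entry)
    return selected[:limit] if limit > 0 else selected


if __name__ == "__main__":
    rows = [("1abc_A", 1), ("1abc_B", 1), ("2xyz_A", 2), ("3def_A", 3)]
    print(select_chains(rows, {1, 3}))
```

Comparing a string such as `"1"` against an integer cluster ID never matches in Python, so any cluster list read from a text file has to be converted with `int` before it is used in such a membership test.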
def check_all_proteins(self):
for i in range(len(self.chains_in_cluster)//10000+1):
if not os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
return False
return True
def collect_proteins(self):
self.protein_graphs = []
self.vandermers = {}
total_recovered = 0
print(f'Loading {len(self.chains_in_cluster)} protein graphs.')
list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
random.shuffle(list_indices)
for i in list_indices:
with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'rb') as f:
print(i)
l = pickle.load(f)
total_recovered += len(l)
self.protein_graphs.extend(l)
if not self.vandermers_extraction:
continue
if os.path.exists(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl')):
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'rb') as f:
vandermers = pickle.load(f)
self.vandermers.update(vandermers)
continue
vandermers = {}
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(l), desc=f'computing vandermers {i}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
arguments = zip(l, [self.vandermers_max_dist] * len(l),
[self.vandermers_buffer_residue_num] * len(l))
for t in map_fn(identify_valid_vandermers, arguments):
if t is not None:
vandermers[t[0]] = t[1]
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.cache_path, f'vandermers{i}_{self.vandermers_max_dist}_{self.vandermers_buffer_residue_num}.pkl'), 'wb') as f:
pickle.dump(vandermers, f)
self.vandermers.update(vandermers)
print(f"Kept {len(self.protein_graphs)} proteins out of {len(self.chains_in_cluster)} total")
return
def preprocess(self):
# running preprocessing in parallel on multiple workers and saving the progress every 10000 proteins
list_indices = list(range(len(self.chains_in_cluster) // 10000 + 1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.cache_path, f"protein_graphs{i}.pkl")):
continue
chains_names = self.chains_in_cluster[10000 * i:10000 * (i + 1)]
protein_graphs = []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(chains_names),
desc=f'loading protein batch {i}/{len(self.chains_in_cluster) // 10000 + 1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.load_chain, chains_names):
if t is not None:
protein_graphs.append(t)
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.cache_path, f"protein_graphs{i}.pkl"), 'wb') as f:
pickle.dump(protein_graphs, f)
print("Finished preprocessing and saving protein graphs")
def load_chain(self, c):
chain, cluster = c
if not os.path.exists(self.root + f"/pdb/{chain[1:3]}/{chain}.pt"):
print("File not found", chain)
return None
data = torch.load(self.root + f"/pdb/{chain[1:3]}/{chain}.pt")
complex_graph = HeteroData()
complex_graph['name'] = chain
orig_seq = data["seq"]
coords = data["xyz"]
mask = data["mask"].bool()
# remove residues with NaN backbone coordinates
to_keep = torch.logical_not(torch.any(torch.isnan(coords[:, :4, 0]), dim=1))
coords = coords[to_keep]
seq = ''.join(np.asarray(list(orig_seq))[to_keep.numpy()].tolist())
mask = mask[to_keep]
if len(coords) == 0:
print("All coords were NaN", chain)
return None
try:
new_extract_receptor_structure(seq, coords.numpy(), complex_graph=complex_graph, neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors, knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms, atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print("Error in extracting receptor", chain)
print(e)
return None
if torch.any(torch.isnan(complex_graph['receptor'].pos)):
print("NaN in pos receptor", chain)
return None
complex_graph.coords = coords
complex_graph.seq = seq
complex_graph.mask = mask
complex_graph.cluster = cluster
complex_graph.orig_seq = orig_seq
complex_graph.to_keep = to_keep
return complex_graph
if __name__ == "__main__":
dataset = PDBSidechain(root="data/pdb_2021aug02_sample", split="train", multiplicity=1, limit_complexes=150)
print(len(dataset))
print(dataset[0])
for p in dataset:
print(p)
pass


@@ -1,125 +1,208 @@
import binascii
import glob
import hashlib
import os
import pickle
from collections import defaultdict
from multiprocessing import Pool
import random
import copy
import torch.nn.functional as F
import numpy as np
import torch
from rdkit.Chem import MolToSmiles, MolFromSmiles, AddHs
from rdkit import Chem
from rdkit.Chem import MolFromSmiles, AddHs
from torch_geometric.data import Dataset, HeteroData
from torch_geometric.loader import DataLoader, DataListLoader
from torch_geometric.transforms import BaseTransform
from tqdm import tqdm
from rdkit.Chem import RemoveAllHs
from datasets.process_mols import read_molecule, get_rec_graph, generate_conformer, \
get_lig_graph_with_matching, extract_receptor_structure, parse_receptor, parse_pdb_from_path
from datasets.process_mols import read_molecule, get_lig_graph_with_matching, generate_conformer, moad_extract_receptor_structure
from utils.diffusion_utils import modify_conformer, set_time
from utils.utils import read_strings_from_txt
from utils.utils import read_strings_from_txt, crop_beyond
from utils import so3, torus
class NoiseTransform(BaseTransform):
def __init__(self, t_to_sigma, no_torsion, all_atom):
def __init__(self, t_to_sigma, no_torsion, all_atom, alpha=1, beta=1,
include_miscellaneous_atoms=False, crop_beyond_cutoff=None, time_independent=False, rmsd_cutoff=0,
minimum_t=0, sampling_mixing_coeff=0):
self.t_to_sigma = t_to_sigma
self.no_torsion = no_torsion
self.all_atom = all_atom
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.minimum_t = minimum_t
self.mixing_coeff = sampling_mixing_coeff
self.alpha = alpha
self.beta = beta
self.crop_beyond_cutoff = crop_beyond_cutoff
self.rmsd_cutoff = rmsd_cutoff
self.time_independent = time_independent
def __call__(self, data):
t = np.random.uniform()
t_tr, t_rot, t_tor = t, t, t
return self.apply_noise(data, t_tr, t_rot, t_tor)
t_tr, t_rot, t_tor, t = self.get_time()
return self.apply_noise(data, t_tr, t_rot, t_tor, t)
def apply_noise(self, data, t_tr, t_rot, t_tor, tr_update = None, rot_update=None, torsion_updates=None):
def get_time(self):
if self.time_independent:
t = np.random.beta(self.alpha, self.beta)
t_tr, t_rot, t_tor = t,t,t
else:
t = None
if self.mixing_coeff == 0:
t = np.random.beta(self.alpha, self.beta)
t = self.minimum_t + t * (1 - self.minimum_t)
else:
choice = np.random.binomial(1, self.mixing_coeff)
t1 = np.random.beta(self.alpha, self.beta)
t1 = t1 * self.minimum_t
t2 = np.random.beta(self.alpha, self.beta)
t2 = self.minimum_t + t2 * (1 - self.minimum_t)
t = choice * t1 + (1 - choice) * t2
t_tr, t_rot, t_tor = t,t,t
return t_tr, t_rot, t_tor, t
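The new `get_time` draws the diffusion time from a Beta distribution rescaled to `[minimum_t, 1]`, and, when `sampling_mixing_coeff` is non-zero, mixes in samples from `[0, minimum_t]`. The following is a minimal sketch of that sampling scheme using only the standard library; `sample_time` and its `rng` parameter are hypothetical, and unlike the original it draws a single Beta sample per call, which yields the same distribution.

```python
import random


def sample_time(alpha=1.0, beta=1.0, minimum_t=0.0, mixing_coeff=0.0, rng=random):
    """Sample a diffusion time in [0, 1].

    With probability mixing_coeff the sample falls in [0, minimum_t],
    otherwise in [minimum_t, 1]. With mixing_coeff == 0 only the upper
    interval is used, matching the first branch of get_time.
    """
    if mixing_coeff != 0 and rng.random() < mixing_coeff:
        # low-noise regime: Beta sample squeezed into [0, minimum_t]
        return rng.betavariate(alpha, beta) * minimum_t
    # default regime: Beta sample rescaled to [minimum_t, 1]
    return minimum_t + rng.betavariate(alpha, beta) * (1 - minimum_t)


if __name__ == "__main__":
    rng = random.Random(0)
    t = sample_time(minimum_t=0.2, mixing_coeff=0.3, rng=rng)
    t_tr, t_rot, t_tor = t, t, t  # the same time is shared by all three components
    print(t_tr, t_rot, t_tor)
```

Passing a seeded `random.Random` instance makes the draws reproducible, which can be convenient when debugging a noise schedule.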
def apply_noise(self, data, t_tr, t_rot, t_tor, t, tr_update = None, rot_update=None, torsion_updates=None):
if not torch.is_tensor(data['ligand'].pos):
data['ligand'].pos = random.choice(data['ligand'].pos)
if self.time_independent:
orig_complex_graph = copy.deepcopy(data)
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(t_tr, t_rot, t_tor)
set_time(data, t_tr, t_rot, t_tor, 1, self.all_atom, device=None)
if self.time_independent:
set_time(data, 0, 0, 0, 0, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
else:
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
tr_update = torch.normal(mean=0, std=tr_sigma, size=(1, 3)) if tr_update is None else tr_update
rot_update = so3.sample_vec(eps=rot_sigma) if rot_update is None else rot_update
torsion_updates = np.random.normal(loc=0.0, scale=tor_sigma, size=data['ligand'].edge_mask.sum()) if torsion_updates is None else torsion_updates
torsion_updates = None if self.no_torsion else torsion_updates
modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
try:
modify_conformer(data, tr_update, torch.from_numpy(rot_update).float(), torsion_updates)
except Exception as e:
print("failed modify conformer")
print(e)
if self.time_independent:
if self.no_torsion:
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())
filterHs = torch.not_equal(data['ligand'].x[:, 0], 0).cpu().numpy()
if isinstance(orig_complex_graph['ligand'].orig_pos, list):
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
ligand_pos = data['ligand'].pos.cpu().numpy()[filterHs]
orig_ligand_pos = orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy()
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=1).mean(axis=0))
data.y = torch.tensor(rmsd < self.rmsd_cutoff).float().unsqueeze(0)
data.atom_y = data.y
return data
data.tr_score = -tr_update / tr_sigma ** 2
data.rot_score = torch.from_numpy(so3.score_vec(vec=rot_update, eps=rot_sigma)).float().unsqueeze(0)
data.tor_score = None if self.no_torsion else torch.from_numpy(torus.score(torsion_updates, tor_sigma)).float()
data.tor_sigma_edge = None if self.no_torsion else np.ones(data['ligand'].edge_mask.sum()) * tor_sigma
if data['ligand'].pos.shape[0] == 1:
# if the ligand is a single atom, the rotational score is always 0
data.rot_score = data.rot_score * 0
if self.crop_beyond_cutoff is not None:
crop_beyond(data, tr_sigma * 3 + self.crop_beyond_cutoff, self.all_atom)
set_time(data, t, t_tr, t_rot, t_tor, 1, self.all_atom, device=None, include_miscellaneous_atoms=self.include_miscellaneous_atoms)
return data
class PDBBind(Dataset):
def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0,
def __init__(self, root, transform=None, cache_path='data/cache', split_path='data/', limit_complexes=0, chain_cutoff=10,
receptor_radius=30, num_workers=1, c_alpha_max_neighbors=None, popsize=15, maxiter=15,
matching=True, keep_original=False, max_lig_size=None, remove_hs=False, num_conformers=1, all_atoms=False,
atom_radius=5, atom_max_neighbors=None, esm_embeddings_path=None, require_ligand=False,
ligands_list=None, protein_path_list=None, ligand_descriptions=None, keep_local_structures=False):
include_miscellaneous_atoms=False,
protein_path_list=None, ligand_descriptions=None, keep_local_structures=False,
protein_file="protein_processed", ligand_file="ligand",
knn_only_graph=False, matching_tries=1, dataset='PDBBind'):
super(PDBBind, self).__init__(root, transform)
self.pdbbind_dir = root
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.max_lig_size = max_lig_size
self.split_path = split_path
self.limit_complexes = limit_complexes
self.chain_cutoff = chain_cutoff
self.receptor_radius = receptor_radius
self.num_workers = num_workers
self.c_alpha_max_neighbors = c_alpha_max_neighbors
self.remove_hs = remove_hs
self.esm_embeddings_path = esm_embeddings_path
self.use_old_wrong_embedding_order = False
self.require_ligand = require_ligand
self.protein_path_list = protein_path_list
self.ligand_descriptions = ligand_descriptions
self.keep_local_structures = keep_local_structures
self.protein_file = protein_file
self.fixed_knn_radius_graph = True
self.knn_only_graph = knn_only_graph
self.matching_tries = matching_tries
self.ligand_file = ligand_file
self.dataset = dataset
assert knn_only_graph or (not all_atoms)
self.all_atoms = all_atoms
if matching or (protein_path_list is not None and ligand_descriptions is not None):
cache_path += '_torsion'
if all_atoms:
cache_path += '_allatoms'
self.full_cache_path = os.path.join(cache_path, f'limit{self.limit_complexes}'
self.full_cache_path = os.path.join(cache_path, f'{dataset}3_limit{self.limit_complexes}'
f'_INDEX{os.path.splitext(os.path.basename(self.split_path))[0]}'
f'_maxLigSize{self.max_lig_size}_H{int(not self.remove_hs)}'
f'_recRad{self.receptor_radius}_recMax{self.c_alpha_max_neighbors}'
+ ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
+ ('' if not matching or num_conformers == 1 else f'_confs{num_conformers}')
f'_chainCutoff{self.chain_cutoff if self.chain_cutoff is None else int(self.chain_cutoff)}'
+ ('' if not all_atoms else f'_atomRad{atom_radius}_atomMax{atom_max_neighbors}')
+ ('' if not matching or num_conformers == 1 else f'_confs{num_conformers}')
+ ('' if self.esm_embeddings_path is None else f'_esmEmbeddings')
+ '_full'
+ ('' if not keep_local_structures else f'_keptLocalStruct')
+ ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode()))))
+ ('' if protein_path_list is None or ligand_descriptions is None else str(binascii.crc32(''.join(ligand_descriptions + protein_path_list).encode())))
+ ('' if protein_file == "protein_processed" else '_' + protein_file)
+ ('' if not self.fixed_knn_radius_graph else (f'_fixedKNN' if not self.knn_only_graph else '_fixedKNNonly'))
+ ('' if not self.include_miscellaneous_atoms else '_miscAtoms')
+ ('' if self.use_old_wrong_embedding_order else '_chainOrd')
+ ('' if self.matching_tries == 1 else f'_tries{matching_tries}'))
self.popsize, self.maxiter = popsize, maxiter
self.matching, self.keep_original = matching, keep_original
self.num_conformers = num_conformers
self.all_atoms = all_atoms
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
if not os.path.exists(os.path.join(self.full_cache_path, "heterographs.pkl"))\
or (require_ligand and not os.path.exists(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"))):
if not self.check_all_complexes():
os.makedirs(self.full_cache_path, exist_ok=True)
if protein_path_list is None or ligand_descriptions is None:
self.preprocessing()
else:
self.inference_preprocessing()
print('loading data from memory: ', os.path.join(self.full_cache_path, "heterographs.pkl"))
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
self.complex_graphs = pickle.load(f)
if require_ligand:
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
self.rdkit_ligands = pickle.load(f)
self.complex_graphs, self.rdkit_ligands = self.collect_all_complexes()
print_statistics(self.complex_graphs)
list_names = [complex['name'] for complex in self.complex_graphs]
with open(os.path.join(self.full_cache_path, f'pdbbind_{os.path.splitext(os.path.basename(self.split_path))[0][:3]}_names.txt'), 'w') as f:
f.write('\n'.join(list_names))
def len(self):
return len(self.complex_graphs)
def get(self, idx):
complex_graph = copy.deepcopy(self.complex_graphs[idx])
if self.require_ligand:
complex_graph = copy.deepcopy(self.complex_graphs[idx])
complex_graph.mol = copy.deepcopy(self.rdkit_ligands[idx])
return complex_graph
else:
return copy.deepcopy(self.complex_graphs[idx])
complex_graph.mol = RemoveAllHs(copy.deepcopy(self.rdkit_ligands[idx]))
for a in ['random_coords', 'coords', 'seq', 'sequence', 'mask', 'rmsd_matching', 'cluster', 'orig_seq', 'to_keep', 'chain_ids']:
if hasattr(complex_graph, a):
delattr(complex_graph, a)
if hasattr(complex_graph['receptor'], a):
delattr(complex_graph['receptor'], a)
return complex_graph
def preprocessing(self):
print(f'Processing complexes from [{self.split_path}] and saving it to [{self.full_cache_path}]')
@@ -132,94 +215,63 @@ class PDBBind(Dataset):
if self.esm_embeddings_path is not None:
id_to_embeddings = torch.load(self.esm_embeddings_path)
chain_embeddings_dictlist = defaultdict(list)
chain_indices_dictlist = defaultdict(list)
for key, embedding in id_to_embeddings.items():
key_name = key.split('_')[0]
key_name = key.split('_chain_')[0]
if key_name in complex_names_all:
chain_embeddings_dictlist[key_name].append(embedding)
chain_indices_dictlist[key_name].append(int(key.split('_chain_')[1]))
lm_embeddings_chains_all = []
for name in complex_names_all:
lm_embeddings_chains_all.append(chain_embeddings_dictlist[name])
complex_chains_embeddings = chain_embeddings_dictlist[name]
complex_chains_indices = chain_indices_dictlist[name]
chain_reorder_idx = np.argsort(complex_chains_indices)
reordered_chains = [complex_chains_embeddings[i] for i in chain_reorder_idx]
lm_embeddings_chains_all.append(reordered_chains)
else:
lm_embeddings_chains_all = [None] * len(complex_names_all)
if self.num_workers > 1:
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
for i in range(len(complex_names_all)//1000+1):
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
complex_names = complex_names_all[1000*i:1000*(i+1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
complex_graphs_all = []
for i in range(len(complex_names_all)//1000+1):
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
l = pickle.load(f)
complex_graphs_all.extend(l)
with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
pickle.dump((complex_graphs_all), f)
rdkit_ligands_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
rdkit_ligands_all.extend(l)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands_all), f)
else:
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(complex_names_all)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
complex_names = complex_names_all[1000*i:1000*(i+1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
with tqdm(total=len(complex_names_all), desc='loading complexes') as pbar:
for t in map(self.get_complex, zip(complex_names_all, lm_embeddings_chains_all, [None] * len(complex_names_all), [None] * len(complex_names_all))):
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(complex_names), desc=f'loading complexes {i}/{len(complex_names_all)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(complex_names, lm_embeddings_chains, [None] * len(complex_names), [None] * len(complex_names))):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
def inference_preprocessing(self):
ligands_list = []
print('Reading molecules and generating local structures with RDKit (unless --keep_local_structures is turned on).')
failed_ligand_indices = []
for idx, ligand_description in tqdm(enumerate(self.ligand_descriptions)):
try:
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
if mol is not None:
print('Reading molecules and generating local structures with RDKit')
for ligand_description in tqdm(self.ligand_descriptions):
mol = MolFromSmiles(ligand_description) # check if it is a smiles or a path
if mol is not None:
mol = AddHs(mol)
generate_conformer(mol)
ligands_list.append(mol)
else:
mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
if not self.keep_local_structures:
mol.RemoveAllConformers()
mol = AddHs(mol)
generate_conformer(mol)
ligands_list.append(mol)
else:
mol = read_molecule(ligand_description, remove_hs=False, sanitize=True)
if mol is None:
raise Exception('RDKit could not read the molecule ', ligand_description)
if not self.keep_local_structures:
mol.RemoveAllConformers()
mol = AddHs(mol)
generate_conformer(mol)
ligands_list.append(mol)
except Exception as e:
print('Failed to read molecule ', ligand_description, ' We are skipping it. The reason is the exception: ', e)
failed_ligand_indices.append(idx)
for index in sorted(failed_ligand_indices, reverse=True):
del self.protein_path_list[index]
del self.ligand_descriptions[index]
ligands_list.append(mol)
if self.esm_embeddings_path is not None:
print('Reading language model embeddings.')
@@ -235,129 +287,144 @@ class PDBBind(Dataset):
lm_embeddings_chains_all = [None] * len(self.protein_path_list)
print('Generating graphs for ligands and proteins')
if self.num_workers > 1:
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
for i in range(len(self.protein_path_list)//1000+1):
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
complex_graphs_all = []
for i in range(len(self.protein_path_list)//1000+1):
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
l = pickle.load(f)
complex_graphs_all.extend(l)
with open(os.path.join(self.full_cache_path, f"heterographs.pkl"), 'wb') as f:
pickle.dump((complex_graphs_all), f)
rdkit_ligands_all = []
for i in range(len(self.protein_path_list) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
rdkit_ligands_all.extend(l)
with open(os.path.join(self.full_cache_path, f"rdkit_ligands.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands_all), f)
else:
# running preprocessing in parallel on multiple workers and saving the progress every 1000 complexes
list_indices = list(range(len(self.protein_path_list)//1000+1))
random.shuffle(list_indices)
for i in list_indices:
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
continue
protein_paths_chunk = self.protein_path_list[1000*i:1000*(i+1)]
ligand_description_chunk = self.ligand_descriptions[1000*i:1000*(i+1)]
ligands_chunk = ligands_list[1000 * i:1000 * (i + 1)]
lm_embeddings_chains = lm_embeddings_chains_all[1000*i:1000*(i+1)]
complex_graphs, rdkit_ligands = [], []
with tqdm(total=len(self.protein_path_list), desc='loading complexes') as pbar:
for t in map(self.get_complex, zip(self.protein_path_list, lm_embeddings_chains_all, ligands_list, self.ligand_descriptions)):
if self.num_workers > 1:
p = Pool(self.num_workers, maxtasksperchild=1)
p.__enter__()
with tqdm(total=len(protein_paths_chunk), desc=f'loading complexes {i}/{len(protein_paths_chunk)//1000+1}') as pbar:
map_fn = p.imap_unordered if self.num_workers > 1 else map
for t in map_fn(self.get_complex, zip(protein_paths_chunk, lm_embeddings_chains, ligands_chunk,ligand_description_chunk)):
complex_graphs.extend(t[0])
rdkit_ligands.extend(t[1])
pbar.update()
if complex_graphs == []: raise Exception('Preprocessing did not succeed for any complex')
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'wb') as f:
if self.num_workers > 1: p.__exit__(None, None, None)
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'wb') as f:
pickle.dump((complex_graphs), f)
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'wb') as f:
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'wb') as f:
pickle.dump((rdkit_ligands), f)
def check_all_complexes(self):
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
return True
complex_names_all = read_strings_from_txt(self.split_path)
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
for i in range(len(complex_names_all) // 1000 + 1):
if not os.path.exists(os.path.join(self.full_cache_path, f"heterographs{i}.pkl")):
return False
return True
def collect_all_complexes(self):
print('Collecting all complexes from cache', self.full_cache_path)
if os.path.exists(os.path.join(self.full_cache_path, f"heterographs.pkl")):
with open(os.path.join(self.full_cache_path, "heterographs.pkl"), 'rb') as f:
complex_graphs = pickle.load(f)
if self.require_ligand:
with open(os.path.join(self.full_cache_path, "rdkit_ligands.pkl"), 'rb') as f:
rdkit_ligands = pickle.load(f)
else:
rdkit_ligands = None
return complex_graphs, rdkit_ligands
complex_names_all = read_strings_from_txt(self.split_path)
if self.limit_complexes is not None and self.limit_complexes != 0:
complex_names_all = complex_names_all[:self.limit_complexes]
complex_graphs_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"heterographs{i}.pkl"), 'rb') as f:
print(i)
l = pickle.load(f)
complex_graphs_all.extend(l)
rdkit_ligands_all = []
for i in range(len(complex_names_all) // 1000 + 1):
with open(os.path.join(self.full_cache_path, f"rdkit_ligands{i}.pkl"), 'rb') as f:
l = pickle.load(f)
rdkit_ligands_all.extend(l)
return complex_graphs_all, rdkit_ligands_all
def get_complex(self, par):
name, lm_embedding_chains, ligand, ligand_description = par
if not os.path.exists(os.path.join(self.pdbbind_dir, name)) and ligand is None:
print("Folder not found", name)
return [], []
if ligand is not None:
rec_model = parse_pdb_from_path(name)
name = f'{name}____{ligand_description}'
ligs = [ligand]
else:
try:
rec_model = parse_receptor(name, self.pdbbind_dir)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
try:
lig = read_mol(self.pdbbind_dir, name, suffix=self.ligand_file, remove_hs=False)
if self.max_lig_size != None and lig.GetNumHeavyAtoms() > self.max_lig_size:
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
return [], []
ligs = read_mols(self.pdbbind_dir, name, remove_hs=False)
complex_graphs = []
failed_indices = []
for i, lig in enumerate(ligs):
if self.max_lig_size is not None and lig.GetNumHeavyAtoms() > self.max_lig_size:
print(f'Ligand with {lig.GetNumHeavyAtoms()} heavy atoms is larger than max_lig_size {self.max_lig_size}. Not including {name} in preprocessed data.')
continue
complex_graph = HeteroData()
complex_graph['name'] = name
try:
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
self.num_conformers, remove_hs=self.remove_hs)
rec, rec_coords, c_alpha_coords, n_coords, c_coords, lm_embeddings = extract_receptor_structure(copy.deepcopy(rec_model), lig, lm_embedding_chains=lm_embedding_chains)
if lm_embeddings is not None and len(c_alpha_coords) != len(lm_embeddings):
print(f'LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}.')
failed_indices.append(i)
continue
get_lig_graph_with_matching(lig, complex_graph, self.popsize, self.maxiter, self.matching, self.keep_original,
self.num_conformers, remove_hs=self.remove_hs, tries=self.matching_tries)
get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius=self.receptor_radius,
c_alpha_max_neighbors=self.c_alpha_max_neighbors, all_atoms=self.all_atoms,
atom_radius=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors, remove_hs=self.remove_hs, lm_embeddings=lm_embeddings)
moad_extract_receptor_structure(path=os.path.join(self.pdbbind_dir, name, f'{name}_{self.protein_file}.pdb'),
complex_graph=complex_graph,
neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors,
lm_embeddings=lm_embedding_chains,
knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms,
atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
failed_indices.append(i)
continue
except Exception as e:
print(f'Skipping {name} because of the error:')
print(e)
return [], []
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= protein_center
if self.all_atoms:
complex_graph['atom'].pos -= protein_center
if self.dataset == 'posebusters':
other_positions = []
all_mol_file = os.path.join(self.pdbbind_dir, name, f'{name}_ligands.sdf')
supplier = Chem.SDMolSupplier(all_mol_file, sanitize=False, removeHs=False)
for mol in supplier:
Chem.SanitizeMol(mol)
all_mol = RemoveAllHs(mol)
for conf in all_mol.GetConformers():
other_positions.append(conf.GetPositions())
if (not self.matching) or self.num_conformers == 1:
complex_graph['ligand'].pos -= protein_center
else:
for p in complex_graph['ligand'].pos:
p -= protein_center
print(f'Found {len(other_positions)} alternative poses for {name}')
complex_graph['ligand'].orig_pos = np.asarray(other_positions)
complex_graph.original_center = protein_center
complex_graphs.append(complex_graph)
for idx_to_delete in sorted(failed_indices, reverse=True):
del ligs[idx_to_delete]
return complex_graphs, ligs
protein_center = torch.mean(complex_graph['receptor'].pos, dim=0, keepdim=True)
complex_graph['receptor'].pos -= protein_center
if self.all_atoms:
complex_graph['atom'].pos -= protein_center
if (not self.matching) or self.num_conformers == 1:
complex_graph['ligand'].pos -= protein_center
else:
for p in complex_graph['ligand'].pos:
p -= protein_center
complex_graph.original_center = protein_center
complex_graph['receptor_name'] = name
return [complex_graph], [lig]
def print_statistics(complex_graphs):
statistics = ([], [], [], [])
statistics = ([], [], [], [], [], [])
receptor_sizes = []
for complex_graph in complex_graphs:
lig_pos = complex_graph['ligand'].pos if torch.is_tensor(complex_graph['ligand'].pos) else complex_graph['ligand'].pos[0]
receptor_sizes.append(complex_graph['receptor'].pos.shape[0])
radius_protein = torch.max(torch.linalg.vector_norm(complex_graph['receptor'].pos, dim=1))
molecule_center = torch.mean(lig_pos, dim=0)
radius_molecule = torch.max(
@@ -370,43 +437,25 @@ def print_statistics(complex_graphs):
statistics[3].append(complex_graph.rmsd_matching)
else:
statistics[3].append(0)
statistics[4].append(int(complex_graph.random_coords) if "random_coords" in complex_graph else -1)
if "random_coords" in complex_graph and complex_graph.random_coords and "rmsd_matching" in complex_graph:
statistics[5].append(complex_graph.rmsd_matching)
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching']
if len(statistics[5]) == 0:
statistics[5].append(-1)
name = ['radius protein', 'radius molecule', 'distance protein-mol', 'rmsd matching', 'random coordinates', 'random rmsd matching']
print('Number of complexes: ', len(complex_graphs))
for i in range(4):
for i in range(len(name)):
array = np.asarray(statistics[i])
print(f"{name[i]}: mean {np.mean(array)}, std {np.std(array)}, max {np.max(array)}")
def construct_loader(args, t_to_sigma):
transform = NoiseTransform(t_to_sigma=t_to_sigma, no_torsion=args.no_torsion,
all_atom=args.all_atoms)
common_args = {'transform': transform, 'root': args.data_dir, 'limit_complexes': args.limit_complexes,
'receptor_radius': args.receptor_radius,
'c_alpha_max_neighbors': args.c_alpha_max_neighbors,
'remove_hs': args.remove_hs, 'max_lig_size': args.max_lig_size,
'matching': not args.no_torsion, 'popsize': args.matching_popsize, 'maxiter': args.matching_maxiter,
'num_workers': args.num_workers, 'all_atoms': args.all_atoms,
'atom_radius': args.atom_radius, 'atom_max_neighbors': args.atom_max_neighbors,
'esm_embeddings_path': args.esm_embeddings_path}
train_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_train, keep_original=True,
num_conformers=args.num_conformers, **common_args)
val_dataset = PDBBind(cache_path=args.cache_path, split_path=args.split_val, keep_original=True, **common_args)
loader_class = DataListLoader if torch.cuda.is_available() else DataLoader
train_loader = loader_class(dataset=train_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
val_loader = loader_class(dataset=val_dataset, batch_size=args.batch_size, num_workers=args.num_dataloader_workers, shuffle=True, pin_memory=args.pin_memory)
return train_loader, val_loader
return
def read_mol(pdbbind_dir, name, remove_hs=False):
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.sdf'), remove_hs=remove_hs, sanitize=True)
def read_mol(pdbbind_dir, name, suffix='ligand', remove_hs=False):
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.sdf'), remove_hs=remove_hs, sanitize=True)
if lig is None: # read mol2 file if sdf file cannot be sanitized
print('Using the .sdf file failed. We found a .mol2 file instead and are trying to use that.')
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_ligand.mol2'), remove_hs=remove_hs, sanitize=True)
lig = read_molecule(os.path.join(pdbbind_dir, name, f'{name}_{suffix}.mol2'), remove_hs=remove_hs, sanitize=True)
return lig
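
The embedding-loading change in `preprocessing` above switches the key parsing from `key.split('_')[0]` to `key.split('_chain_')[0]` and sorts each complex's chain embeddings by their chain index. A minimal sketch of that grouping and reordering follows. It assumes keys of the form `<complex>_chain_<index>`; the helper name `group_chain_embeddings` and the toy data are hypothetical and not part of the repository, and strings stand in for embedding tensors.

```python
from collections import defaultdict

import numpy as np


def group_chain_embeddings(id_to_embeddings, complex_names):
    """Group per-chain embeddings by complex and sort them by chain index.

    Hypothetical helper for illustration; keys are assumed to look like
    "<complex>_chain_<index>".
    """
    wanted = set(complex_names)
    embeddings_by_complex = defaultdict(list)
    indices_by_complex = defaultdict(list)
    for key, embedding in id_to_embeddings.items():
        # Split on the full '_chain_' marker so that complex names that
        # themselves contain underscores are not truncated.
        name, chain_idx = key.rsplit('_chain_', 1)
        if name in wanted:
            embeddings_by_complex[name].append(embedding)
            indices_by_complex[name].append(int(chain_idx))

    ordered = []
    for name in complex_names:
        # argsort over the chain indices restores chain order even when the
        # dictionary was filled in a different order.
        order = np.argsort(indices_by_complex[name])
        ordered.append([embeddings_by_complex[name][i] for i in order])
    return ordered


# Toy data: the second complex name contains an underscore, which a plain
# split('_') would cut down to '2x'.
toy = {
    '1abc_chain_1': 'emb_b',
    '1abc_chain_0': 'emb_a',
    '2x_yz_chain_0': 'emb_c',
}
result = group_chain_embeddings(toy, ['1abc', '2x_yz'])
```

Sorting by the parsed index matters because the order of entries in the saved embedding dictionary is not guaranteed to match the chain order of the structure file.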


@@ -1,96 +0,0 @@
import os
from argparse import FileType, ArgumentParser
import numpy as np
from Bio.PDB import PDBParser
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--chain_cutoff', type=int, default=10, help='')
parser.add_argument('--out_file', type=str, default="data/pdbbind_sequences.fasta")
args = parser.parse_args()
cutoff = args.chain_cutoff
data_dir = args.data_dir
names = os.listdir(data_dir)
#%%
from Bio import SeqIO
biopython_parser = PDBParser()
three_to_one = {'ALA': 'A',
'ARG': 'R',
'ASN': 'N',
'ASP': 'D',
'CYS': 'C',
'GLN': 'Q',
'GLU': 'E',
'GLY': 'G',
'HIS': 'H',
'ILE': 'I',
'LEU': 'L',
'LYS': 'K',
'MET': 'M',
'MSE': 'M', # selenomethionine: almost the same amino acid as MET, with the sulfur replaced by selenium
'PHE': 'F',
'PRO': 'P',
'PYL': 'O',
'SER': 'S',
'SEC': 'U',
'THR': 'T',
'TRP': 'W',
'TYR': 'Y',
'VAL': 'V',
'ASX': 'B',
'GLX': 'Z',
'XAA': 'X',
'XLE': 'J'}
sequences = []
ids = []
for name in tqdm(names):
if name == '.DS_Store': continue
if os.path.exists(os.path.join(data_dir, name, f'{name}_protein_processed.pdb')):
rec_path = os.path.join(data_dir, name, f'{name}_protein_processed.pdb')
elif os.path.exists(os.path.join(data_dir, name, f'{name}_protein.pdb')):
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
else:
continue
if cutoff > 10:
rec_path = os.path.join(data_dir, name, f'{name}_protein_obabel_reduce.pdb')
if not os.path.exists(rec_path):
rec_path = os.path.join(data_dir, name, f'{name}_protein.pdb')
structure = biopython_parser.get_structure('random_id', rec_path)
structure = structure[0]
for i, chain in enumerate(structure):
seq = ''
for res_idx, residue in enumerate(chain):
if residue.get_resname() == 'HOH':
continue
residue_coords = []
c_alpha, n, c = None, None, None
for atom in residue:
if atom.name == 'CA':
c_alpha = list(atom.get_vector())
if atom.name == 'N':
n = list(atom.get_vector())
if atom.name == 'C':
c = list(atom.get_vector())
if c_alpha != None and n != None and c != None: # only append residue if it is an amino acid and not
try:
seq += three_to_one[residue.get_resname()]
except Exception as e:
seq += '-'
print("encountered unknown AA: ", residue.get_resname(), ' in the complex ', name, '. Replacing it with a dash - .')
sequences.append(seq)
ids.append(f'{name}_chain_{i}')
records = []
for (index, seq) in zip(ids,sequences):
record = SeqRecord(Seq(seq), str(index))
record.description = ''
records.append(record)
SeqIO.write(records, args.out_file, "fasta")
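
The removed script above builds one sequence per chain by mapping three-letter residue names to one-letter codes, skipping water and writing a dash for unknown residues. The sketch below isolates that mapping step without Biopython. It is a simplified illustration: the function name `residues_to_sequence` and the shortened lookup table are assumptions, and the backbone-atom check (CA, N, C present) from the original loop is omitted.

```python
# Shortened lookup table for illustration; the script uses a fuller mapping.
THREE_TO_ONE = {
    'ALA': 'A', 'GLY': 'G', 'LEU': 'L', 'MET': 'M',
    'MSE': 'M',  # selenomethionine is treated like methionine
    'SER': 'S', 'TRP': 'W',
}


def residues_to_sequence(resnames, three_to_one=THREE_TO_ONE):
    """Convert three-letter residue names into a one-letter sequence.

    Water ('HOH') is skipped and unknown residues become '-'.
    Hypothetical helper, not part of the repository.
    """
    letters = []
    for resname in resnames:
        if resname == 'HOH':
            continue
        letters.append(three_to_one.get(resname, '-'))
    return ''.join(letters)


def chain_record_id(complex_name, chain_index):
    """Build the FASTA record id in the '<complex>_chain_<index>' style."""
    return f'{complex_name}_chain_{chain_index}'


sequence = residues_to_sequence(['ALA', 'HOH', 'MSE', 'XYZ', 'GLY'])
record_id = chain_record_id('1abc', 0)
```

The record id format is the same `<complex>_chain_<index>` pattern that the embedding loader later parses when it matches embeddings back to chains.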


@@ -1,28 +1,24 @@
import copy
import os
import warnings
import numpy as np
import scipy.spatial as spa
import torch
from Bio.PDB import PDBParser
from Bio.PDB.PDBExceptions import PDBConstructionWarning
from rdkit import Chem
from rdkit.Chem.rdchem import BondType as BT
from rdkit.Chem import AllChem, GetPeriodicTable, RemoveHs
from rdkit.Geometry import Point3D
from scipy import spatial
from scipy.special import softmax
from torch_cluster import radius_graph
from torch import cdist
from torch_cluster import knn_graph
import prody as pr
import torch.nn.functional as F
from datasets.conformer_matching import get_torsion_angles, optimize_rotatable_bonds
from datasets.constants import aa_short2long, atom_order, three_to_one
from datasets.parse_chi import get_chi_angles, get_coords, aa_idx2aa_short, get_onehot_sequence
from utils.torsion import get_transformation_mask
biopython_parser = PDBParser()
periodic_table = GetPeriodicTable()
allowable_features = {
'possible_atomic_num_list': list(range(1, 119)) + ['misc'],
@@ -94,9 +90,13 @@ def lig_atom_featurizer(mol):
ringinfo = mol.GetRingInfo()
atom_features_list = []
for idx, atom in enumerate(mol.GetAtoms()):
chiral_tag = str(atom.GetChiralTag())
if chiral_tag in ['CHI_SQUAREPLANAR', 'CHI_TRIGONALBIPYRAMIDAL', 'CHI_OCTAHEDRAL']:
chiral_tag = 'CHI_OTHER'
atom_features_list.append([
safe_index(allowable_features['possible_atomic_num_list'], atom.GetAtomicNum()),
allowable_features['possible_chirality_list'].index(str(atom.GetChiralTag())),
allowable_features['possible_chirality_list'].index(str(chiral_tag)),
safe_index(allowable_features['possible_degree_list'], atom.GetTotalDegree()),
safe_index(allowable_features['possible_formal_charge_list'], atom.GetFormalCharge()),
safe_index(allowable_features['possible_implicit_valence_list'], atom.GetImplicitValence()),
@@ -111,18 +111,11 @@ def lig_atom_featurizer(mol):
allowable_features['possible_is_in_ring6_list'].index(ringinfo.IsAtomInRingOfSize(idx, 6)),
allowable_features['possible_is_in_ring7_list'].index(ringinfo.IsAtomInRingOfSize(idx, 7)),
allowable_features['possible_is_in_ring8_list'].index(ringinfo.IsAtomInRingOfSize(idx, 8)),
#g_charge if not np.isnan(g_charge) and not np.isinf(g_charge) else 0.
])
return torch.tensor(atom_features_list)
def rec_residue_featurizer(rec):
feature_list = []
for residue in rec.get_residues():
feature_list.append([safe_index(allowable_features['possible_amino_acids'], residue.get_resname())])
return torch.tensor(feature_list, dtype=torch.float32) # (N_res, 1)
def safe_index(l, e):
""" Return index of element e in list l. If e is not present, return the last index """
try:
@@ -131,122 +124,158 @@ def safe_index(l, e):
return len(l) - 1
def moad_extract_receptor_structure(path, complex_graph, neighbor_cutoff=20, max_neighbors=None, sequences_to_embeddings=None,
knn_only_graph=False, lm_embeddings=None, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
# load the entire pdb file
pdb = pr.parsePDB(path)
seq = pdb.ca.getSequence()
coords = get_coords(pdb)
one_hot = get_onehot_sequence(seq)
def parse_receptor(pdbid, pdbbind_dir):
rec = parsePDB(pdbid, pdbbind_dir)
return rec
chain_ids = np.zeros(len(one_hot))
res_chain_ids = pdb.ca.getChids()
res_seg_ids = pdb.ca.getSegnames()
res_chain_ids = np.asarray([s + c for s, c in zip(res_seg_ids, res_chain_ids)])
ids = np.unique(res_chain_ids)
sequences = []
lm_embeddings = lm_embeddings if sequences_to_embeddings is None else []
for i, id in enumerate(ids):
chain_ids[res_chain_ids == id] = i
s = np.argmax(one_hot[res_chain_ids == id], axis=1)
s = ''.join([aa_idx2aa_short[aa_idx] for aa_idx in s])
sequences.append(s)
if sequences_to_embeddings is not None:
lm_embeddings.append(sequences_to_embeddings[s])
complex_graph['receptor'].sequence = sequences
complex_graph['receptor'].chain_ids = torch.from_numpy(np.asarray(chain_ids)).long()
new_extract_receptor_structure(seq, coords, complex_graph, neighbor_cutoff=neighbor_cutoff, max_neighbors=max_neighbors,
lm_embeddings=lm_embeddings, knn_only_graph=knn_only_graph, all_atoms=all_atoms,
atom_cutoff=atom_cutoff, atom_max_neighbors=atom_max_neighbors)
def parsePDB(pdbid, pdbbind_dir):
rec_path = os.path.join(pdbbind_dir, pdbid, f'{pdbid}_protein_processed.pdb')
return parse_pdb_from_path(rec_path)
def new_extract_receptor_structure(seq, all_coords, complex_graph, neighbor_cutoff=20, max_neighbors=None, lm_embeddings=None,
knn_only_graph=False, all_atoms=False, atom_cutoff=None, atom_max_neighbors=None):
chi_angles, one_hot = get_chi_angles(all_coords, seq, return_onehot=True)
n_rel_pos, c_rel_pos = all_coords[:, 0, :] - all_coords[:, 1, :], all_coords[:, 2, :] - all_coords[:, 1, :]
side_chain_vecs = torch.from_numpy(np.concatenate([chi_angles / 360, n_rel_pos, c_rel_pos], axis=1))
def parse_pdb_from_path(path):
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=PDBConstructionWarning)
structure = biopython_parser.get_structure('random_id', path)
rec = structure[0]
return rec
# Build the k-NN graph
coords = torch.tensor(all_coords[:, 1, :], dtype=torch.float)
if len(coords) > 3000:
raise ValueError(f'The receptor is too large {len(coords)}')
if knn_only_graph:
edge_index = knn_graph(coords, k=max_neighbors if max_neighbors else 32)
else:
distances = cdist(coords, coords)
src_list = []
dst_list = []
for i in range(len(coords)):
dst = list(np.where(distances[i, :] < neighbor_cutoff)[0])
dst.remove(i)
max_neighbors = max_neighbors if max_neighbors else 1000
if max_neighbors != None and len(dst) > max_neighbors:
dst = list(np.argsort(distances[i, :]))[1: max_neighbors + 1]
if len(dst) == 0:
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
print(
f'The cutoff {neighbor_cutoff} was too small for one atom such that it had no neighbors. '
f'So we connected it to the closest other atom')
assert i not in dst
src = [i] * len(dst)
src_list.extend(src)
dst_list.extend(dst)
edge_index = torch.from_numpy(np.asarray([dst_list, src_list]))
res_names_list = [aa_short2long[seq[i]] if seq[i] in aa_short2long else 'misc' for i in range(len(seq))]
feature_list = [[safe_index(allowable_features['possible_amino_acids'], res)] for res in res_names_list]
node_feat = torch.tensor(feature_list, dtype=torch.float32)
lm_embeddings = torch.tensor(np.concatenate(lm_embeddings, axis=0)) if lm_embeddings is not None else None
complex_graph['receptor'].x = torch.cat([node_feat, lm_embeddings], axis=1) if lm_embeddings is not None else node_feat
complex_graph['receptor'].pos = coords
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = edge_index
if all_atoms:
atom_coords = all_coords.reshape(-1, 3)
atom_coords = torch.from_numpy(atom_coords[~np.any(np.isnan(atom_coords), axis=1)]).float()
if knn_only_graph:
atoms_edge_index = knn_graph(atom_coords, k=atom_max_neighbors if atom_max_neighbors else 1000)
else:
atoms_distances = cdist(atom_coords, atom_coords)
atom_src_list = []
atom_dst_list = []
for i in range(len(atom_coords)):
dst = list(np.where(atoms_distances[i, :] < atom_cutoff)[0])
dst.remove(i)
max_neighbors = atom_max_neighbors if atom_max_neighbors else 1000
if max_neighbors != None and len(dst) > max_neighbors:
dst = list(np.argsort(atoms_distances[i, :]))[1: max_neighbors + 1]
if len(dst) == 0:
dst = list(np.argsort(atoms_distances[i, :]))[1:2] # choose second because first is i itself
print(
f'The atom_cutoff {atom_cutoff} was too small for one atom such that it had no neighbors. '
f'So we connected it to the closest other atom')
assert i not in dst
src = [i] * len(dst)
atom_src_list.extend(src)
atom_dst_list.extend(dst)
atoms_edge_index = torch.from_numpy(np.asarray([atom_dst_list, atom_src_list]))
feats = [get_moad_atom_feats(res, all_coords[i]) for i, res in enumerate(seq)]
atom_feat = torch.from_numpy(np.concatenate(feats, axis=0)).float()
c_alpha_idx = np.concatenate([np.zeros(len(f)) + i for i, f in enumerate(feats)])
np_array = np.stack([np.arange(len(atom_feat)), c_alpha_idx])
atom_res_edge_index = torch.from_numpy(np_array).long()
complex_graph['atom'].x = atom_feat
complex_graph['atom'].pos = atom_coords
assert len(complex_graph['atom'].x) == len(complex_graph['atom'].pos)
complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
return
def extract_receptor_structure(rec, lig, lm_embedding_chains=None):
conf = lig.GetConformer()
lig_coords = conf.GetPositions()
min_distances = []
coords = []
c_alpha_coords = []
n_coords = []
c_coords = []
valid_chain_ids = []
lengths = []
for i, chain in enumerate(rec):
chain_coords = [] # num_residues, num_atoms, 3
chain_c_alpha_coords = []
chain_n_coords = []
chain_c_coords = []
count = 0
invalid_res_ids = []
for res_idx, residue in enumerate(chain):
if residue.get_resname() == 'HOH':
invalid_res_ids.append(residue.get_id())
continue
residue_coords = []
c_alpha, n, c = None, None, None
for atom in residue:
if atom.name == 'CA':
c_alpha = list(atom.get_vector())
if atom.name == 'N':
n = list(atom.get_vector())
if atom.name == 'C':
c = list(atom.get_vector())
residue_coords.append(list(atom.get_vector()))
if c_alpha != None and n != None and c != None:
# only append residue if it is an amino acid and not some weird molecule that is part of the complex
chain_c_alpha_coords.append(c_alpha)
chain_n_coords.append(n)
chain_c_coords.append(c)
chain_coords.append(np.array(residue_coords))
count += 1
def get_moad_atom_feats(res, coords):
feats = []
res_long = aa_short2long[res]
res_order = atom_order[res]
for i, c in enumerate(coords):
if np.any(np.isnan(c)):
continue
atom_feats = []
if res == '-':
atom_feats = [safe_index(allowable_features['possible_amino_acids'], 'misc'),
safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
safe_index(allowable_features['possible_atom_type_3'], 'misc')]
else:
atom_feats.append(safe_index(allowable_features['possible_amino_acids'], res_long))
if i >= len(res_order):
atom_feats.extend([safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
safe_index(allowable_features['possible_atom_type_3'], 'misc')])
else:
invalid_res_ids.append(residue.get_id())
for res_id in invalid_res_ids:
chain.detach_child(res_id)
if len(chain_coords) > 0:
all_chain_coords = np.concatenate(chain_coords, axis=0)
distances = spatial.distance.cdist(lig_coords, all_chain_coords)
min_distance = distances.min()
else:
min_distance = np.inf
atom_name = res_order[i]
try:
atomic_num = periodic_table.GetAtomicNumber(atom_name[:1])
except:
print("element", res_order[i][:1], 'not found')
atomic_num = -1
min_distances.append(min_distance)
lengths.append(count)
coords.append(chain_coords)
c_alpha_coords.append(np.array(chain_c_alpha_coords))
n_coords.append(np.array(chain_n_coords))
c_coords.append(np.array(chain_c_coords))
if not count == 0: valid_chain_ids.append(chain.get_id())
min_distances = np.array(min_distances)
if len(valid_chain_ids) == 0:
valid_chain_ids.append(np.argmin(min_distances))
valid_coords = []
valid_c_alpha_coords = []
valid_n_coords = []
valid_c_coords = []
valid_lengths = []
invalid_chain_ids = []
valid_lm_embeddings = []
for i, chain in enumerate(rec):
if chain.get_id() in valid_chain_ids:
valid_coords.append(coords[i])
valid_c_alpha_coords.append(c_alpha_coords[i])
if lm_embedding_chains is not None:
if i >= len(lm_embedding_chains):
raise ValueError('Encountered valid chain id that was not present in the LM embeddings')
valid_lm_embeddings.append(lm_embedding_chains[i])
valid_n_coords.append(n_coords[i])
valid_c_coords.append(c_coords[i])
valid_lengths.append(lengths[i])
else:
invalid_chain_ids.append(chain.get_id())
coords = [item for sublist in valid_coords for item in sublist] # list with n_residues arrays: [n_atoms, 3]
c_alpha_coords = np.concatenate(valid_c_alpha_coords, axis=0) # [n_residues, 3]
n_coords = np.concatenate(valid_n_coords, axis=0) # [n_residues, 3]
c_coords = np.concatenate(valid_c_coords, axis=0) # [n_residues, 3]
lm_embeddings = np.concatenate(valid_lm_embeddings, axis=0) if lm_embedding_chains is not None else None
for invalid_id in invalid_chain_ids:
rec.detach_child(invalid_id)
assert len(c_alpha_coords) == len(n_coords)
assert len(c_alpha_coords) == len(c_coords)
assert sum(valid_lengths) == len(c_alpha_coords)
return rec, coords, c_alpha_coords, n_coords, c_coords, lm_embeddings
atom_feats.extend([safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
safe_index(allowable_features['possible_atom_type_3'], atom_name)])
feats.append(atom_feats)
feats = np.asarray(feats)
return feats
def get_lig_graph(mol, complex_graph):
lig_coords = torch.from_numpy(mol.GetConformer().GetPositions()).float()
atom_feats = lig_atom_featurizer(mol)
row, col, edge_type = [], [], []
@@ -261,52 +290,77 @@ def get_lig_graph(mol, complex_graph):
edge_attr = F.one_hot(edge_type, num_classes=len(bonds)).to(torch.float)
complex_graph['ligand'].x = atom_feats
complex_graph['ligand'].pos = lig_coords
complex_graph['ligand', 'lig_bond', 'ligand'].edge_index = edge_index
complex_graph['ligand', 'lig_bond', 'ligand'].edge_attr = edge_attr
if mol.GetNumConformers() > 0:
lig_coords = torch.from_numpy(mol.GetConformer().GetPositions()).float()
complex_graph['ligand'].pos = lig_coords
return
def generate_conformer(mol):
ps = AllChem.ETKDGv2()
id = AllChem.EmbedMolecule(mol, ps)
failures, id = 0, -1
while failures < 3 and id == -1:
if failures > 0:
print(f'rdkit coords could not be generated. trying again {failures}.')
id = AllChem.EmbedMolecule(mol, ps)
failures += 1
if id == -1:
print('rdkit coords could not be generated without using random coords. using random coords now.')
ps.useRandomCoords = True
AllChem.EmbedMolecule(mol, ps)
AllChem.MMFFOptimizeMolecule(mol, confId=0)
# else:
# AllChem.MMFFOptimizeMolecule(mol_rdkit, confId=0)
return True
#else:
# AllChem.MMFFOptimizeMolecule(mol, confId=0)
return False
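The retry logic in `generate_conformer` follows a general pattern: attempt a fallible embedding step a bounded number of times, then fall back to a more permissive setting. Below is a minimal sketch of that pattern, not the repository's implementation. It assumes a callable that returns `-1` on failure, which mirrors the return convention of `AllChem.EmbedMolecule`. The names `embed_with_fallback`, `embed` and `fallback` are hypothetical and are stubbed with plain Python callables so the sketch runs without RDKit.

```python
from typing import Callable, Tuple


def embed_with_fallback(embed: Callable[[], int],
                        fallback: Callable[[], int],
                        max_failures: int = 3) -> Tuple[int, bool]:
    """Try `embed` up to `max_failures` times, then use `fallback`.

    Returns (conformer_id, used_fallback). Both callables are hypothetical
    stand-ins for an embedding call that returns -1 when it fails.
    """
    failures, conf_id = 0, -1
    while failures < max_failures and conf_id == -1:
        conf_id = embed()
        failures += 1
    if conf_id == -1:
        # All regular attempts failed: switch to the permissive variant.
        return fallback(), True
    return conf_id, False


# Stub that fails twice and succeeds on the third attempt.
attempts = iter([-1, -1, 0])
result = embed_with_fallback(lambda: next(attempts), lambda: 0)

# Stub that always fails, so the fallback is used.
always_failing = embed_with_fallback(lambda: -1, lambda: 7)
```

Returning a flag alongside the identifier lets the caller distinguish a regular embedding from one that needed the fallback, which is the role the `True`/`False` return values play in the function above.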
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs):
def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching, keep_original, num_conformers, remove_hs, tries=10, skip_matching=False):
if matching:
mol_maybe_noh = copy.deepcopy(mol_)
if remove_hs:
mol_maybe_noh = RemoveHs(mol_maybe_noh, sanitize=True)
mol_maybe_noh = AllChem.RemoveAllHs(mol_maybe_noh)
if keep_original:
complex_graph['ligand'].orig_pos = mol_maybe_noh.GetConformer().GetPositions()
positions = []
for conf in mol_maybe_noh.GetConformers():
positions.append(conf.GetPositions())
complex_graph['ligand'].orig_pos = np.asarray(positions) if len(positions) > 1 else positions[0]
rotable_bonds = get_torsion_angles(mol_maybe_noh)
if not rotable_bonds: print("no_rotable_bonds but still using it")
#if not rotable_bonds: print("no_rotable_bonds but still using it")
for i in range(num_conformers):
mol_rdkit = copy.deepcopy(mol_)
mols, rmsds = [], []
for _ in range(tries):
mol_rdkit = copy.deepcopy(mol_)
mol_rdkit.RemoveAllConformers()
mol_rdkit = AllChem.AddHs(mol_rdkit)
generate_conformer(mol_rdkit)
if remove_hs:
mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
mol = copy.deepcopy(mol_maybe_noh)
if rotable_bonds:
optimize_rotatable_bonds(mol_rdkit, mol, rotable_bonds, popsize=popsize, maxiter=maxiter)
mol.AddConformer(mol_rdkit.GetConformer())
rms_list = []
AllChem.AlignMolConformers(mol, RMSlist=rms_list)
mol_rdkit.RemoveAllConformers()
mol_rdkit.AddConformer(mol.GetConformers()[1])
mol_rdkit.RemoveAllConformers()
mol_rdkit = AllChem.AddHs(mol_rdkit)
generate_conformer(mol_rdkit)
if remove_hs:
mol_rdkit = RemoveHs(mol_rdkit, sanitize=True)
mol_rdkit = AllChem.RemoveAllHs(mol_rdkit)
mol = AllChem.RemoveAllHs(copy.deepcopy(mol_maybe_noh))
if rotable_bonds and not skip_matching:
optimize_rotatable_bonds(mol_rdkit, mol, rotable_bonds, popsize=popsize, maxiter=maxiter)
mol.AddConformer(mol_rdkit.GetConformer())
rms_list = []
AllChem.AlignMolConformers(mol, RMSlist=rms_list)
mol_rdkit.RemoveAllConformers()
mol_rdkit.AddConformer(mol.GetConformers()[1])
mols.append(mol_rdkit)
rmsds.append(rms_list[0])
# select molecule with lowest rmsd
#print("mean std min max", np.mean(rmsds), np.std(rmsds), np.min(rmsds), np.max(rmsds))
mol_rdkit = mols[np.argmin(rmsds)]
if i == 0:
complex_graph.rmsd_matching = rms_list[0]
complex_graph.rmsd_matching = min(rmsds)
get_lig_graph(mol_rdkit, complex_graph)
else:
if torch.is_tensor(complex_graph['ligand'].pos):
@@ -325,157 +379,34 @@ def get_lig_graph_with_matching(mol_, complex_graph, popsize, maxiter, matching,
return
def get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, cutoff=20, max_neighbor=None, lm_embeddings=None):
n_rel_pos = n_coords - c_alpha_coords
c_rel_pos = c_coords - c_alpha_coords
num_residues = len(c_alpha_coords)
if num_residues <= 1:
raise ValueError(f"rec contains only 1 residue!")
# Build the k-NN graph
distances = spa.distance.cdist(c_alpha_coords, c_alpha_coords)
src_list = []
dst_list = []
mean_norm_list = []
for i in range(num_residues):
dst = list(np.where(distances[i, :] < cutoff)[0])
dst.remove(i)
if max_neighbor != None and len(dst) > max_neighbor:
dst = list(np.argsort(distances[i, :]))[1: max_neighbor + 1]
if len(dst) == 0:
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
print(f'The c_alpha_cutoff {cutoff} was too small for one c_alpha such that it had no neighbors. '
f'So we connected it to the closest other c_alpha')
assert i not in dst
src = [i] * len(dst)
src_list.extend(src)
dst_list.extend(dst)
valid_dist = list(distances[i, dst])
valid_dist_np = distances[i, dst]
sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num)
assert weights[0].sum() > 1 - 1e-2 and weights[0].sum() < 1.01
diff_vecs = c_alpha_coords[src, :] - c_alpha_coords[dst, :] # (neigh_num, 3)
mean_vec = weights.dot(diff_vecs) # (sigma_num, 3)
denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,)
mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,)
mean_norm_list.append(mean_vec_ratio_norm)
assert len(src_list) == len(dst_list)
node_feat = rec_residue_featurizer(rec)
mu_r_norm = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
side_chain_vecs = torch.from_numpy(
np.concatenate([np.expand_dims(n_rel_pos, axis=1), np.expand_dims(c_rel_pos, axis=1)], axis=1))
complex_graph['receptor'].x = torch.cat([node_feat, torch.tensor(lm_embeddings)], axis=1) if lm_embeddings is not None else node_feat
complex_graph['receptor'].pos = torch.from_numpy(c_alpha_coords).float()
complex_graph['receptor'].mu_r_norm = mu_r_norm
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = torch.from_numpy(np.asarray([src_list, dst_list]))
return
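The neighbor selection inside the loop of `get_calpha_graph` combines three rules: keep every residue closer than `cutoff`, cap the list at `max_neighbor` nearest residues, and fall back to the single closest residue when the cutoff leaves none. The sketch below isolates that rule for one row of a distance matrix. It is an illustration under assumptions: `select_neighbors` is a hypothetical helper name, and the distance matrix is built with NumPy broadcasting rather than `scipy.spatial.distance.cdist` to keep the example dependency-light.

```python
import numpy as np


def select_neighbors(distances, i, cutoff, max_neighbor=None):
    """Return neighbor indices of node i (hypothetical helper).

    Assumes distances[i, i] is the unique minimum of row i, so that the
    first entry of the argsort is i itself.
    """
    row = distances[i]
    dst = [j for j in np.where(row < cutoff)[0] if j != i]
    if max_neighbor is not None and len(dst) > max_neighbor:
        # Too many neighbors within the cutoff: keep only the nearest ones.
        dst = list(np.argsort(row)[1:max_neighbor + 1])
    if len(dst) == 0:
        # Nothing within the cutoff: connect to the closest other node.
        dst = list(np.argsort(row)[1:2])
    return [int(j) for j in dst]


coords = np.array([[0.0, 0.0, 0.0],
                   [1.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0],
                   [10.0, 0.0, 0.0]])
distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

within_cutoff = select_neighbors(distances, 0, cutoff=5.0)
capped = select_neighbors(distances, 0, cutoff=5.0, max_neighbor=1)
isolated = select_neighbors(distances, 3, cutoff=5.0)
```

In this toy input the last point lies more than 5 units from every other point, so it exercises the fallback branch and is connected to its nearest neighbor only.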
def rec_atom_featurizer(rec):
atom_feats = []
for i, atom in enumerate(rec.get_atoms()):
atom_name, element = atom.name, atom.element
if element == 'CD':
element = 'C'
assert not element == ''
try:
atomic_num = periodic_table.GetAtomicNumber(element)
except:
atomic_num = -1
atom_feat = [safe_index(allowable_features['possible_amino_acids'], atom.get_parent().get_resname()),
safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
safe_index(allowable_features['possible_atom_type_3'], atom_name)]
atom_feats.append(atom_feat)
return atom_feats
def get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius, c_alpha_max_neighbors=None, all_atoms=False,
atom_radius=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
if all_atoms:
return get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph,
c_alpha_cutoff=rec_radius, c_alpha_max_neighbors=c_alpha_max_neighbors,
atom_cutoff=atom_radius, atom_max_neighbors=atom_max_neighbors, remove_hs=remove_hs,lm_embeddings=lm_embeddings)
def get_rec_misc_atom_feat(bio_atom=None, atom_name=None, element=None, get_misc_features=False):
if get_misc_features:
return [safe_index(allowable_features['possible_amino_acids'], 'misc'),
safe_index(allowable_features['possible_atomic_num_list'], 'misc'),
safe_index(allowable_features['possible_atom_type_2'], 'misc'),
safe_index(allowable_features['possible_atom_type_3'], 'misc')]
if atom_name is not None:
atom_name = atom_name
else:
return get_calpha_graph(rec, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius, c_alpha_max_neighbors,lm_embeddings=lm_embeddings)
atom_name = bio_atom.name
if element is not None:
element = element
else:
element = bio_atom.element
if element == 'CD':
element = 'C'
assert not element == ''
try:
atomic_num = periodic_table.GetAtomicNumber(element.lower().capitalize())
except:
atomic_num = -1
atom_feat = [safe_index(allowable_features['possible_amino_acids'], bio_atom.get_parent().get_resname()),
safe_index(allowable_features['possible_atomic_num_list'], atomic_num),
safe_index(allowable_features['possible_atom_type_2'], (atom_name + '*')[:2]),
safe_index(allowable_features['possible_atom_type_3'], atom_name)]
return atom_feat
def get_fullrec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, c_alpha_cutoff=20,
c_alpha_max_neighbors=None, atom_cutoff=5, atom_max_neighbors=None, remove_hs=False, lm_embeddings=None):
# builds the receptor graph with both residues and atoms
n_rel_pos = n_coords - c_alpha_coords
c_rel_pos = c_coords - c_alpha_coords
num_residues = len(c_alpha_coords)
if num_residues <= 1:
raise ValueError(f"rec contains only 1 residue!")
# Build the k-NN graph of residues
distances = spa.distance.cdist(c_alpha_coords, c_alpha_coords)
src_list = []
dst_list = []
mean_norm_list = []
for i in range(num_residues):
dst = list(np.where(distances[i, :] < c_alpha_cutoff)[0])
dst.remove(i)
if c_alpha_max_neighbors != None and len(dst) > c_alpha_max_neighbors:
dst = list(np.argsort(distances[i, :]))[1: c_alpha_max_neighbors + 1]
if len(dst) == 0:
dst = list(np.argsort(distances[i, :]))[1:2] # choose second because first is i itself
print(f'The c_alpha_cutoff {c_alpha_cutoff} was too small for one c_alpha such that it had no neighbors. '
f'So we connected it to the closest other c_alpha')
assert i not in dst
src = [i] * len(dst)
src_list.extend(src)
dst_list.extend(dst)
valid_dist = list(distances[i, dst])
valid_dist_np = distances[i, dst]
sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num)
assert 1 - 1e-2 < weights[0].sum() < 1.01
diff_vecs = c_alpha_coords[src, :] - c_alpha_coords[dst, :] # (neigh_num, 3)
mean_vec = weights.dot(diff_vecs) # (sigma_num, 3)
denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,)
mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,)
mean_norm_list.append(mean_vec_ratio_norm)
assert len(src_list) == len(dst_list)
node_feat = rec_residue_featurizer(rec)
mu_r_norm = torch.from_numpy(np.array(mean_norm_list).astype(np.float32))
side_chain_vecs = torch.from_numpy(
np.concatenate([np.expand_dims(n_rel_pos, axis=1), np.expand_dims(c_rel_pos, axis=1)], axis=1))
complex_graph['receptor'].x = torch.cat([node_feat, torch.tensor(lm_embeddings)], axis=1) if lm_embeddings is not None else node_feat
complex_graph['receptor'].pos = torch.from_numpy(c_alpha_coords).float()
complex_graph['receptor'].mu_r_norm = mu_r_norm
complex_graph['receptor'].side_chain_vecs = side_chain_vecs.float()
complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = torch.from_numpy(np.asarray([src_list, dst_list]))
src_c_alpha_idx = np.concatenate([np.asarray([i]*len(l)) for i, l in enumerate(rec_coords)])
atom_feat = torch.from_numpy(np.asarray(rec_atom_featurizer(rec)))
atom_coords = torch.from_numpy(np.concatenate(rec_coords, axis=0)).float()
if remove_hs:
not_hs = (atom_feat[:, 1] != 0)
src_c_alpha_idx = src_c_alpha_idx[not_hs]
atom_feat = atom_feat[not_hs]
atom_coords = atom_coords[not_hs]
atoms_edge_index = radius_graph(atom_coords, atom_cutoff, max_num_neighbors=atom_max_neighbors if atom_max_neighbors else 1000)
atom_res_edge_index = torch.from_numpy(np.asarray([np.arange(len(atom_feat)), src_c_alpha_idx])).long()
complex_graph['atom'].x = atom_feat
complex_graph['atom'].pos = atom_coords
complex_graph['atom', 'atom_contact', 'atom'].edge_index = atoms_edge_index
complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index
return
def write_mol_with_coords(mol, new_coords, path):
w = Chem.SDWriter(path)
@@ -502,8 +433,8 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
elif molecule_file.endswith('.pdb'):
mol = Chem.MolFromPDBFile(molecule_file, sanitize=False, removeHs=False)
else:
raise ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
return ValueError('Expect the format of the molecule_file to be '
'one of .mol2, .sdf, .pdbqt and .pdb, got {}'.format(molecule_file))
try:
if sanitize or calc_charges:
@@ -518,30 +449,7 @@ def read_molecule(molecule_file, sanitize=False, calc_charges=False, remove_hs=F
if remove_hs:
mol = Chem.RemoveHs(mol, sanitize=sanitize)
except Exception as e:
print(e)
print("RDKit was unable to read the molecule.")
except:
return None
return mol
def read_sdf_or_mol2(sdf_fileName, mol2_fileName):
mol = Chem.MolFromMolFile(sdf_fileName, sanitize=False)
problem = False
try:
Chem.SanitizeMol(mol)
mol = Chem.RemoveHs(mol)
except Exception as e:
problem = True
if problem:
mol = Chem.MolFromMol2File(mol2_fileName, sanitize=False)
try:
Chem.SanitizeMol(mol)
mol = Chem.RemoveHs(mol)
problem = False
except Exception as e:
problem = True
return mol, problem

@@ -0,0 +1,39 @@
import os
import pickle
from argparse import ArgumentParser
import torch
from tqdm import tqdm
parser = ArgumentParser()
parser.add_argument('--esm_embeddings_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new', help='')
parser.add_argument('--output_path', type=str, default='data/BindingMOAD_2020_ab_processed_biounit/moad_sequences_new.pt', help='')
args = parser.parse_args()
dic = {}
# read text file with all sequences
with open('data/pdb_2021aug02/sequences_to_id.fasta') as f:
lines = f.readlines()
# read sequences
with open('data/pdb_2021aug02/useful_sequences.pkl', 'rb') as f:
sequences = pickle.load(f)
ids = set()
dict_seq_id = {seq[:-1]: str(id) for id, seq in enumerate(lines)}
for i, seq in tqdm(enumerate(sequences)):
ids.add(dict_seq_id[seq])
if i == 20000: break
print("total", len(ids), "out of", len(os.listdir(args.esm_embeddings_path)))
available = set([filename.split('.')[0] for filename in os.listdir(args.esm_embeddings_path)])
final = available.intersection(ids)
for idp in tqdm(final):
dic[idp] = torch.load(os.path.join(args.esm_embeddings_path, idp+'.pt'))['representations'][33]
torch.save(dic,args.output_path)
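The script above maps each sequence to the index of its line in the sequence file, then keeps only the ids for which an embedding file exists. The sketch below shows that bookkeeping in isolation, using only the standard library. It rests on a few assumptions that differ from the script: `build_seq_to_id` and `ids_with_embeddings` are hypothetical helper names, the newline is removed with `rstrip("\n")` so that a final line without a trailing newline is not truncated, sequences missing from the mapping are skipped instead of raising `KeyError`, and file stems are taken with `os.path.splitext` rather than `split('.')[0]`.

```python
import os


def build_seq_to_id(lines):
    """Map each sequence to its line index, stored as a string."""
    return {line.rstrip("\n"): str(idx) for idx, line in enumerate(lines)}


def ids_with_embeddings(sequences, seq_to_id, filenames):
    """Return the ids of `sequences` that also have an embedding file."""
    wanted = {seq_to_id[seq] for seq in sequences if seq in seq_to_id}
    available = {os.path.splitext(name)[0] for name in filenames}
    return wanted & available


# Toy input: the last line deliberately has no trailing newline.
lines = ["MKT\n", "GAV\n", "LLP"]
seq_to_id = build_seq_to_id(lines)
selected = ids_with_embeddings(["MKT", "LLP"], seq_to_id,
                               ["0.pt", "1.pt", "2.pt"])
```

With slicing (`seq[:-1]`) the last toy sequence would be stored as `LL` and would no longer match `LLP`; stripping only the newline avoids that edge case.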

@@ -1,102 +1,201 @@
name: diffdock
channels:
- pytorch
- conda-forge
- defaults
dependencies:
- blas=1.0
- brotlipy=0.7.0
- boost=1.74.0
- boost-cpp=1.74.0
- brotli=1.0.9
- brotli-bin=1.0.9
- bzip2=1.0.8
- ca-certificates=2022.07.19
- certifi=2022.9.14
- cffi=1.15.1
- charset-normalizer=2.0.4
- cryptography=37.0.1
- ffmpeg=4.3
- freetype=2.11.0
- gettext=0.21.0
- ca-certificates=2022.6.15
- cairo=1.16.0
- certifi=2022.6.15
- cycler=0.11.0
- expat=2.4.8
- font-ttf-dejavu-sans-mono=2.37
- font-ttf-inconsolata=3.000
- font-ttf-source-code-pro=2.038
- font-ttf-ubuntu=0.83
- fontconfig=2.14.0
- fonts-conda-ecosystem=1
- fonts-conda-forge=1
- fonttools=4.33.3
- freetype=2.10.4
- gettext=0.19.8.1
- giflib=5.2.1
- gmp=6.2.1
- gnutls=3.6.15
- icu=58.2
- idna=3.3
- intel-openmp=2021.4.0
- greenlet=1.1.2
- icu=70.1
- jpeg=9e
- lame=3.100
- kiwisolver=1.4.3
- lcms2=2.12
- lerc=3.0
- libblas=3.9.0
- libbrotlicommon=1.0.9
- libbrotlidec=1.0.9
- libbrotlienc=1.0.9
- libcblas=3.9.0
- libcxx=14.0.6
- libdeflate=1.8
- libffi=3.3
- libdeflate=1.12
- libffi=3.4.2
- libgfortran=5.0.0
- libgfortran5=9.3.0
- libglib=2.70.2
- libiconv=1.16
- libidn2=2.3.2
- liblapack=3.9.0
- libopenblas=0.3.20
- libpng=1.6.37
- libtasn1=4.16.0
- libtiff=4.4.0
- libunistring=0.9.10
- libwebp=1.2.2
- libwebp-base=1.2.2
- libxcb=1.13
- libxml2=2.9.14
- llvm-openmp=14.0.6
- libzlib=1.2.12
- llvm-openmp=14.0.4
- lz4-c=1.9.3
- mkl=2021.4.0
- mkl-service=2.4.0
- mkl_fft=1.3.1
- mkl_random=1.2.2
- matplotlib=3.5.2
- matplotlib-base=3.5.2
- munkres=1.1.4
- ncurses=6.3
- nettle=3.7.3
- numpy=1.23.1
- numpy-base=1.23.1
- openh264=2.1.1
- openssl=1.1.1q
- pillow=9.2.0
- pip=22.2.2
- pycparser=2.21
- pyopenssl=22.0.0
- pysocks=1.7.1
- nomkl=3.0
- numpy=1.23.0
- openbabel=3.1.1
- openjpeg=2.4.0
- openssl=3.0.5
- packaging=21.3
- pandas=1.4.3
- pcre=8.45
- pillow=9.1.1
- pip=22.1.2
- pixman=0.40.0
- pthread-stubs=0.4
- pyaml=21.10.1
- pycairo=1.21.0
- pyparsing=3.0.9
- python=3.9.13
- pytorch=1.12.1
- python-dateutil=2.8.2
- python_abi=3.9
- pytz=2022.1
- pyyaml=6.0
- rdkit=2022.03.3
- readline=8.1.2
- requests=2.28.1
- setuptools=63.4.1
- setuptools=62.6.0
- six=1.16.0
- sqlite=3.39.3
- spyrmsd=0.5.2
- sqlalchemy=1.4.39
- sqlite=3.39.0
- tk=8.6.12
- torchaudio=0.12.1
- torchvision=0.13.1
- typing_extensions=4.3.0
- tzdata=2022c
- urllib3=1.26.11
- tornado=6.1
- tzdata=2022a
- unicodedata2=14.0.0
- wheel=0.37.1
- xz=5.2.6
- xorg-libxau=1.0.9
- xorg-libxdmcp=1.1.3
- xz=5.2.5
- yaml=0.2.5
- zlib=1.2.12
- zstd=1.5.2
- pip:
- appnope==0.1.3
- argon2-cffi==21.3.0
- argon2-cffi-bindings==21.2.0
- asttokens==2.0.5
- attrs==21.4.0
- backcall==0.2.0
- beautifulsoup4==4.11.1
- biopandas==0.4.1
- biopython==1.79
- bleach==5.0.1
- cffi==1.15.1
- charset-normalizer==2.1.0
- click==8.1.3
- debugpy==1.6.2
- decorator==5.1.1
- defusedxml==0.7.1
- docker-pycreds==0.4.0
- e3nn==0.5.0
- entrypoints==0.4
- executing==0.8.3
- fastjsonschema==2.15.3
- gitdb==4.0.9
- gitpython==3.1.27
- h5py==3.7.0
- idna==3.3
- ipykernel==6.15.1
- ipython==8.4.0
- ipython-genutils==0.2.0
- ipywidgets==7.7.1
- jedi==0.18.1
- jinja2==3.1.2
- joblib==1.2.0
- joblib==1.1.0
- jsonschema==4.7.1
- jupyter==1.0.0
- jupyter-client==7.3.4
- jupyter-console==6.4.4
- jupyter-core==4.11.1
- jupyterlab-pygments==0.2.2
- jupyterlab-widgets==1.1.1
- kaleido==0.2.1
- markupsafe==2.1.1
- matplotlib-inline==0.1.3
- mistune==0.8.4
- mpmath==1.2.1
- networkx==2.8.7
- nbclient==0.6.6
- nbconvert==6.5.0
- nbformat==5.4.0
- nest-asyncio==1.5.5
- networkx==2.8.4
- notebook==6.4.12
- opt-einsum==3.3.0
- opt-einsum-fx==0.1.4
- packaging==21.3
- pandas==1.5.0
- pyaml==21.10.1
- pyparsing==3.0.9
- python-dateutil==2.8.2
- pytz==2022.4
- pyyaml==6.0
- rdkit-pypi==2022.3.5
- scikit-learn==1.1.2
- scipy==1.9.1
- spyrmsd==0.5.2
- sympy==1.11.1
- pandocfilters==1.5.0
- parso==0.8.3
- pathtools==0.1.2
- pexpect==4.8.0
- pickleshare==0.7.5
- plotly==5.9.0
- prometheus-client==0.14.1
- promise==2.3
- prompt-toolkit==3.0.30
- protobuf==3.20.1
- psutil==5.9.1
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pycparser==2.21
- pygments==2.12.0
- pyrsistent==0.18.1
- pyzmq==23.2.0
- qtconsole==5.3.1
- qtpy==2.1.0
- requests==2.28.1
- scikit-learn==1.1.1
- scipy==1.8.1
- send2trash==1.8.0
- sentry-sdk==1.6.0
- setproctitle==1.2.3
- shortuuid==1.0.9
- smmap==5.0.0
- soupsieve==2.3.2.post1
- stack-data==0.3.0
- sympy==1.10.1
- tenacity==8.0.1
- terminado==0.15.0
- threadpoolctl==3.1.0
- tinycss2==1.1.1
- torch==1.11.0
- torch-cluster==1.6.0
- torch-geometric==2.1.0.post1
- torch-geometric==2.0.4
- torch-scatter==2.0.9
- torch-sparse==0.6.15
- torch-sparse==0.6.14
- torch-spline-conv==1.2.1
- tqdm==4.64.1
- torchaudio==0.11.0
- torchvision==0.12.0
- tqdm==4.64.0
- traitlets==5.3.0
- typing-extensions==4.3.0
- urllib3==1.26.9
- wandb==0.12.20
- wcwidth==0.2.5
- webencodings==0.5.1
- widgetsnbextension==3.6.1

File diff suppressed because it is too large.

@@ -1,361 +0,0 @@
import os
from argparse import ArgumentParser
import pandas as pd
import plotly.express as px
import numpy as np
import scipy
from utils.utils import read_strings_from_txt
parser = ArgumentParser()
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--results_path', type=str, default='inference_out_dir_not_specified/TEST_top40_epoch75_FILTER_restart_cacheNewRestart_big_ema_ESM2emb_tr34_WITH_fixedSamples28_id1_FILTERFROM_temp_restart_ema_ESM2emb_tr34', help='')
parser.add_argument('--gnina_results_path', type=str, default='results/gnina_rosetta13', help='')
parser.add_argument('--smina_results_path', type=str, default='results/smina_rosetta13', help='')
parser.add_argument('--glide_results_path', type=str, default='results/glide', help='')
parser.add_argument('--qvinaw_results_path', type=str, default='results/qvinaw', help='')
parser.add_argument('--tankbind_results_path', type=str, default='results/tankbind_top5', help='')
parser.add_argument('--equibind_results_path', type=str, default='results/equibind_paper', help='')
parser.add_argument('--no_rec_overlap', action='store_true', default=False, help='')
args = parser.parse_args()
min_cross_distances = np.load(f'{args.results_path}/min_cross_distances.npy')
#min_self_distances = np.load(f'{args.results_path}/min_self_distances.npy')
base_min_cross_distances = np.load(f'{args.results_path}/base_min_cross_distances.npy')
rmsds = np.load(f'{args.results_path}/rmsds.npy')
centroid_distances = np.load(f'{args.results_path}/centroid_distances.npy')
confidences = np.load(f'{args.results_path}/confidences.npy')
#complex_names = np.load(f'{args.results_path}/complex_names.npy')
complex_names = read_strings_from_txt('data/splits/timesplit_test')
if args.no_rec_overlap:
names_no_rec_overlap = read_strings_from_txt(f'data/splits/timesplit_test_no_rec_overlap')
without_rec_overlap_list = []
for name in complex_names:
if name in names_no_rec_overlap:
without_rec_overlap_list.append(1)
else:
without_rec_overlap_list.append(0)
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
rmsds = np.array(rmsds)[without_rec_overlap]
#min_self_distances = np.array(min_self_distances)[without_rec_overlap]
centroid_distances = np.array(centroid_distances)[without_rec_overlap]
confidences = np.array(confidences)[without_rec_overlap]
min_cross_distances = np.array(min_cross_distances)[without_rec_overlap]
base_min_cross_distances = np.array(base_min_cross_distances)[without_rec_overlap]
complex_names = names_no_rec_overlap
N = rmsds.shape[1]
performance_metrics = {
'steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / N).__round__(2),
'mean_rmsd': rmsds.mean(),
'rmsds_below_2': (100 * (rmsds < 2).sum() / len(rmsds) / N),
'rmsds_below_5': (100 * (rmsds < 5).sum() / len(rmsds) / N),
'rmsds_percentile_25': np.percentile(rmsds, 25).round(2),
'rmsds_percentile_50': np.percentile(rmsds, 50).round(2),
'rmsds_percentile_75': np.percentile(rmsds, 75).round(2),
'mean_centroid': centroid_distances.mean().__round__(2),
'centroid_below_2': (100 * (centroid_distances < 2).sum() / len(centroid_distances) / N).__round__(2),
'centroid_below_5': (100 * (centroid_distances < 5).sum() / len(centroid_distances) / N).__round__(2),
'centroid_percentile_25': np.percentile(centroid_distances, 25).round(2),
'centroid_percentile_50': np.percentile(centroid_distances, 50).round(2),
'centroid_percentile_75': np.percentile(centroid_distances, 75).round(2),
}
if N >= 5:
top5_rmsds = np.min(rmsds[:, :5], axis=1)
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :5], axis=1)][ :, 0]
top5_min_cross_distances = min_cross_distances[ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :5], axis=1)][:, 0]
performance_metrics.update({
'top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
'top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
'top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
'top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
'top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
'top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
'top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
'top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
'top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
'top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
'top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
})
if N >= 10:
top10_rmsds = np.min(rmsds[:, :10], axis=1)
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :10], axis=1)][:, 0]
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[:, :10], axis=1)][:, 0]
performance_metrics.update({
'top10_steric_clash_fraction': (100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
'top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
'top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
'top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
'top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
'top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
'top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
'top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
'top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
'top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
'top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
})
confidence_ordering = np.argsort(confidences,axis=1)[:,::-1]
filtered_rmsds = rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,0]
filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,0]
filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, 0]
performance_metrics.update({
'filtered_steric_clash_fraction': (100 * (filtered_min_cross_distances < 0.4).sum() / len(filtered_min_cross_distances)).__round__(2),
'filtered_rmsds_below_2': (100 * (filtered_rmsds < 2).sum() / len(filtered_rmsds)).__round__(2),
'filtered_rmsds_below_5': (100 * (filtered_rmsds < 5).sum() / len(filtered_rmsds)).__round__(2),
'filtered_rmsds_percentile_25': np.percentile(filtered_rmsds, 25).round(2),
'filtered_rmsds_percentile_50': np.percentile(filtered_rmsds, 50).round(2),
'filtered_rmsds_percentile_75': np.percentile(filtered_rmsds, 75).round(2),
'filtered_centroid_below_2': (100 * (filtered_centroid_distances < 2).sum() / len(filtered_centroid_distances)).__round__(2),
'filtered_centroid_below_5': (100 * (filtered_centroid_distances < 5).sum() / len(filtered_centroid_distances)).__round__(2),
'filtered_centroid_percentile_25': np.percentile(filtered_centroid_distances, 25).round(2),
'filtered_centroid_percentile_50': np.percentile(filtered_centroid_distances, 50).round(2),
'filtered_centroid_percentile_75': np.percentile(filtered_centroid_distances, 75).round(2),
})
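The `filtered_*` metrics above reorder each complex's samples by descending confidence and then read off the first column. The sketch below reproduces that ranking on a toy array. It is an illustration, not the script itself: `rank_by_confidence` is a hypothetical helper, the input numbers are invented for the example, and `np.take_along_axis` is used in place of the row-index fancy indexing, which it is shown to match.

```python
import numpy as np


def rank_by_confidence(values, confidences):
    """Reorder each row of `values` by descending confidence."""
    order = np.argsort(confidences, axis=1)[:, ::-1]
    return np.take_along_axis(values, order, axis=1)


# Two complexes with three samples each (made-up numbers).
rmsds = np.array([[3.0, 1.5, 4.0],
                  [0.8, 2.5, 6.0]])
confidences = np.array([[0.1, 0.9, -0.3],
                        [-1.0, 0.2, 0.5]])

ranked = rank_by_confidence(rmsds, confidences)
top1 = ranked[:, 0]                    # RMSD of the most confident sample
top2_best = ranked[:, :2].min(axis=1)  # best RMSD among the two most confident
success_rate = 100 * (top1 < 2).sum() / len(top1)

# Same ranking written with the indexing style used in the script.
order = np.argsort(confidences, axis=1)[:, ::-1]
ranked_fancy = rmsds[np.arange(rmsds.shape[0])[:, None], order]
```

In the second toy complex the most confident sample has the worst RMSD, which illustrates why confidence-ranked top-1 accuracy can differ from the best-of-N numbers computed earlier in the script.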
if N >= 5:
top5_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:5], axis=1)
top5_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:5][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :5], axis=1)][:, 0]
top5_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, :5][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :5], axis=1)][:, 0]
performance_metrics.update({
'top5_filtered_steric_clash_fraction': (100 * (top5_filtered_min_cross_distances < 0.4).sum() / len(top5_filtered_min_cross_distances)).__round__(2),
'top5_filtered_rmsds_below_2': (100 * (top5_filtered_rmsds < 2).sum() / len(top5_filtered_rmsds)).__round__(2),
'top5_filtered_rmsds_below_5': (100 * (top5_filtered_rmsds < 5).sum() / len(top5_filtered_rmsds)).__round__(2),
'top5_filtered_rmsds_percentile_25': np.percentile(top5_filtered_rmsds, 25).round(2),
'top5_filtered_rmsds_percentile_50': np.percentile(top5_filtered_rmsds, 50).round(2),
'top5_filtered_rmsds_percentile_75': np.percentile(top5_filtered_rmsds, 75).round(2),
'top5_filtered_centroid_below_2': (100 * (top5_filtered_centroid_distances < 2).sum() / len(top5_filtered_centroid_distances)).__round__(2),
'top5_filtered_centroid_below_5': (100 * (top5_filtered_centroid_distances < 5).sum() / len(top5_filtered_centroid_distances)).__round__(2),
'top5_filtered_centroid_percentile_25': np.percentile(top5_filtered_centroid_distances, 25).round(2),
'top5_filtered_centroid_percentile_50': np.percentile(top5_filtered_centroid_distances, 50).round(2),
'top5_filtered_centroid_percentile_75': np.percentile(top5_filtered_centroid_distances, 75).round(2),
})
if N >= 10:
top10_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:10], axis=1)
top10_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:,:10][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :10], axis=1)][:, 0]
top10_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], confidence_ordering][:, :10][ np.arange(rmsds.shape[0])[:, None], np.argsort(rmsds[np.arange(rmsds.shape[0])[:,None],confidence_ordering][:, :10], axis=1)][:, 0]
performance_metrics.update({
'top10_filtered_steric_clash_fraction': (100 * (top10_filtered_min_cross_distances < 0.4).sum() / len(top10_filtered_min_cross_distances)).__round__(2),
'top10_filtered_rmsds_below_2': (100 * (top10_filtered_rmsds < 2).sum() / len(top10_filtered_rmsds)).__round__(2),
'top10_filtered_rmsds_below_5': (100 * (top10_filtered_rmsds < 5).sum() / len(top10_filtered_rmsds)).__round__(2),
'top10_filtered_rmsds_percentile_25': np.percentile(top10_filtered_rmsds, 25).round(2),
'top10_filtered_rmsds_percentile_50': np.percentile(top10_filtered_rmsds, 50).round(2),
'top10_filtered_rmsds_percentile_75': np.percentile(top10_filtered_rmsds, 75).round(2),
'top10_filtered_centroid_below_2': (100 * (top10_filtered_centroid_distances < 2).sum() / len(top10_filtered_centroid_distances)).__round__(2),
'top10_filtered_centroid_below_5': (100 * (top10_filtered_centroid_distances < 5).sum() / len(top10_filtered_centroid_distances)).__round__(2),
'top10_filtered_centroid_percentile_25': np.percentile(top10_filtered_centroid_distances, 25).round(2),
'top10_filtered_centroid_percentile_50': np.percentile(top10_filtered_centroid_distances, 50).round(2),
'top10_filtered_centroid_percentile_75': np.percentile(top10_filtered_centroid_distances, 75).round(2),
})
reverse_confidence_ordering = np.argsort(confidences,axis=1)
reverse_filtered_rmsds = rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
reverse_filtered_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
reverse_filtered_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, 0]
performance_metrics.update({
'reversefiltered_steric_clash_fraction': (100 * (reverse_filtered_min_cross_distances < 0.4).sum() / len(reverse_filtered_min_cross_distances)).__round__(2),
'reversefiltered_rmsds_below_2': (100 * (reverse_filtered_rmsds < 2).sum() / len(reverse_filtered_rmsds)).__round__(2),
'reversefiltered_rmsds_below_5': (100 * (reverse_filtered_rmsds < 5).sum() / len(reverse_filtered_rmsds)).__round__(2),
'reversefiltered_rmsds_percentile_25': np.percentile(reverse_filtered_rmsds, 25).round(2),
'reversefiltered_rmsds_percentile_50': np.percentile(reverse_filtered_rmsds, 50).round(2),
'reversefiltered_rmsds_percentile_75': np.percentile(reverse_filtered_rmsds, 75).round(2),
'reversefiltered_centroid_below_2': (100 * (reverse_filtered_centroid_distances < 2).sum() / len(reverse_filtered_centroid_distances)).__round__(2),
'reversefiltered_centroid_below_5': (100 * (reverse_filtered_centroid_distances < 5).sum() / len(reverse_filtered_centroid_distances)).__round__(2),
'reversefiltered_centroid_percentile_25': np.percentile(reverse_filtered_centroid_distances, 25).round(2),
'reversefiltered_centroid_percentile_50': np.percentile(reverse_filtered_centroid_distances, 50).round(2),
'reversefiltered_centroid_percentile_75': np.percentile(reverse_filtered_centroid_distances, 75).round(2),
})
if N >= 5:
top5_reverse_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
top5_reverse_filtered_centroid_distances = np.min(centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
top5_reverse_filtered_min_cross_distances = np.max(min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :5], axis=1)
performance_metrics.update({
'top5_reverse_filtered_steric_clash_fraction': (100 * (top5_reverse_filtered_min_cross_distances < 0.4).sum() / len(top5_reverse_filtered_min_cross_distances)).__round__(2),
'top5_reversefiltered_rmsds_below_2': (100 * (top5_reverse_filtered_rmsds < 2).sum() / len(top5_reverse_filtered_rmsds)).__round__(2),
'top5_reversefiltered_rmsds_below_5': (100 * (top5_reverse_filtered_rmsds < 5).sum() / len(top5_reverse_filtered_rmsds)).__round__(2),
'top5_reversefiltered_rmsds_percentile_25': np.percentile(top5_reverse_filtered_rmsds, 25).round(2),
'top5_reversefiltered_rmsds_percentile_50': np.percentile(top5_reverse_filtered_rmsds, 50).round(2),
'top5_reversefiltered_rmsds_percentile_75': np.percentile(top5_reverse_filtered_rmsds, 75).round(2),
'top5_reversefiltered_centroid_below_2': (100 * (top5_reverse_filtered_centroid_distances < 2).sum() / len(top5_reverse_filtered_centroid_distances)).__round__(2),
'top5_reversefiltered_centroid_below_5': (100 * (top5_reverse_filtered_centroid_distances < 5).sum() / len(top5_reverse_filtered_centroid_distances)).__round__(2),
'top5_reversefiltered_centroid_percentile_25': np.percentile(top5_reverse_filtered_centroid_distances, 25).round(2),
'top5_reversefiltered_centroid_percentile_50': np.percentile(top5_reverse_filtered_centroid_distances, 50).round(2),
'top5_reversefiltered_centroid_percentile_75': np.percentile(top5_reverse_filtered_centroid_distances, 75).round(2),
})
if N >= 10:
top10_reverse_filtered_rmsds = np.min(rmsds[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
top10_reverse_filtered_centroid_distances = np.min(centroid_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
top10_reverse_filtered_min_cross_distances = np.max(min_cross_distances[np.arange(rmsds.shape[0])[:, None], reverse_confidence_ordering][:, :10], axis=1)
performance_metrics.update({
'top10_reverse_filtered_steric_clash_fraction': (100 * (top10_reverse_filtered_min_cross_distances < 0.4).sum() / len(top10_reverse_filtered_min_cross_distances)).__round__(2),
'top10_reversefiltered_rmsds_below_2': (100 * (top10_reverse_filtered_rmsds < 2).sum() / len(top10_reverse_filtered_rmsds)).__round__(2),
'top10_reversefiltered_rmsds_below_5': (100 * (top10_reverse_filtered_rmsds < 5).sum() / len(top10_reverse_filtered_rmsds)).__round__(2),
'top10_reversefiltered_rmsds_percentile_25': np.percentile(top10_reverse_filtered_rmsds, 25).round(2),
'top10_reversefiltered_rmsds_percentile_50': np.percentile(top10_reverse_filtered_rmsds, 50).round(2),
'top10_reversefiltered_rmsds_percentile_75': np.percentile(top10_reverse_filtered_rmsds, 75).round(2),
'top10_reversefiltered_centroid_below_2': (100 * (top10_reverse_filtered_centroid_distances < 2).sum() / len(top10_reverse_filtered_centroid_distances)).__round__(2),
'top10_reversefiltered_centroid_below_5': (100 * (top10_reverse_filtered_centroid_distances < 5).sum() / len(top10_reverse_filtered_centroid_distances)).__round__(2),
'top10_reversefiltered_centroid_percentile_25': np.percentile(top10_reverse_filtered_centroid_distances, 25).round(2),
'top10_reversefiltered_centroid_percentile_50': np.percentile(top10_reverse_filtered_centroid_distances, 50).round(2),
'top10_reversefiltered_centroid_percentile_75': np.percentile(top10_reverse_filtered_centroid_distances, 75).round(2),
})
filtered_confidences = confidences[np.arange(confidences.shape[0])[:,None],confidence_ordering][:,0]
confident_mask = filtered_confidences > 0
confident_rmsds = filtered_rmsds[confident_mask]
confident_centroid_distances = filtered_centroid_distances[confident_mask]
confident_min_cross_distances = filtered_min_cross_distances[confident_mask]
performance_metrics.update({
'fraction_confident_predictions': (100 * len(confident_rmsds) / len(rmsds)).__round__(2),
'confident_steric_clash_fraction': (100 * (confident_min_cross_distances < 0.4).sum() / len(confident_min_cross_distances)).__round__(2),
'confident_rmsds_below_2': (100 * (confident_rmsds < 2).sum() / len(confident_rmsds)).__round__(2),
'confident_rmsds_below_5': (100 * (confident_rmsds < 5).sum() / len(confident_rmsds)).__round__(2),
'confident_rmsds_percentile_25': np.percentile(confident_rmsds, 25).round(2),
'confident_rmsds_percentile_50': np.percentile(confident_rmsds, 50).round(2),
'confident_rmsds_percentile_75': np.percentile(confident_rmsds, 75).round(2),
'confident_centroid_below_2': (100 * (confident_centroid_distances < 2).sum() / len(confident_centroid_distances)).__round__(2),
'confident_centroid_below_5': (100 * (confident_centroid_distances < 5).sum() / len(confident_centroid_distances)).__round__(2),
'confident_centroid_percentile_25': np.percentile(confident_centroid_distances, 25).round(2),
'confident_centroid_percentile_50': np.percentile(confident_centroid_distances, 50).round(2),
'confident_centroid_percentile_75': np.percentile(confident_centroid_distances, 75).round(2),
})
for k in performance_metrics:
print(k, performance_metrics[k])
fraction_dataset_rmsds_below_2 = []
perfect_calibration = []
no_calibration = []
for dataset_percentage in range(100):
dataset_percentage += 1
dataset_fraction = (dataset_percentage)/100
num_samples = round(len(rmsds)*dataset_fraction)
per_complex_confidence_ordering = np.argsort(filtered_confidences)[::-1]
confident_complexes_rmsds = filtered_rmsds[per_complex_confidence_ordering][:num_samples]
confident_complexes_centroid_distances = filtered_centroid_distances[per_complex_confidence_ordering][:num_samples]
confident_complexes_min_cross_distances = filtered_min_cross_distances[per_complex_confidence_ordering][:num_samples]
confident_complexes_metrics = {
'fraction_confident_complexes_predictions': (100 * len(confident_complexes_rmsds) / len(rmsds)).__round__(2),
'confident_complexes_steric_clash_fraction': (100 * (confident_complexes_min_cross_distances < 0.4).sum() / len(confident_complexes_min_cross_distances)).__round__(2),
'confident_complexes_rmsds_below_2': (100 * (confident_complexes_rmsds < 2).sum() / len(confident_complexes_rmsds)).__round__(2),
'confident_complexes_rmsds_below_5': (100 * (confident_complexes_rmsds < 5).sum() / len(confident_complexes_rmsds)).__round__(2),
'confident_complexes_rmsds_percentile_25': np.percentile(confident_complexes_rmsds, 25).round(2),
'confident_complexes_rmsds_percentile_50': np.percentile(confident_complexes_rmsds, 50).round(2),
'confident_complexes_rmsds_percentile_75': np.percentile(confident_complexes_rmsds, 75).round(2),
'confident_complexes_centroid_below_2': (100 * (confident_complexes_centroid_distances < 2).sum() / len(confident_complexes_centroid_distances)).__round__(2),
'confident_complexes_centroid_below_5': (100 * (confident_complexes_centroid_distances < 5).sum() / len(confident_complexes_centroid_distances)).__round__(2),
'confident_complexes_centroid_percentile_25': np.percentile(confident_complexes_centroid_distances, 25).round(2),
'confident_complexes_centroid_percentile_50': np.percentile(confident_complexes_centroid_distances, 50).round(2),
'confident_complexes_centroid_percentile_75': np.percentile(confident_complexes_centroid_distances, 75).round(2),
}
fraction_dataset_rmsds_below_2.append(confident_complexes_metrics['confident_complexes_rmsds_below_2'])
perfect_calibration.append((100 * (np.sort(filtered_rmsds)[:num_samples] < 2).sum() / len(confident_complexes_rmsds)).__round__(2))
no_calibration.append(performance_metrics['filtered_rmsds_below_2'])
#print('percentage: ',dataset_percentage)
#print(confident_complexes_metrics['confident_complexes_rmsds_below_2'])
print(scipy.stats.spearmanr(filtered_rmsds, filtered_confidences))
df = {'conf': filtered_confidences, 'rmsd': filtered_rmsds}
fig = px.scatter(df, x='rmsd',y='conf').update_layout(
xaxis_title="RMSD (Å)", yaxis_title="Confidence"
)
fig.update_layout(margin={'l': 0, 'r': 0, 't': 20, 'b': 100}, plot_bgcolor='white',
paper_bgcolor='white', legend_title_text='', legend_title_font_size=1,
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ),
)
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
fig.show()
df = {'Confidence Model': fraction_dataset_rmsds_below_2[::-1], 'No Calibration': no_calibration[::-1], 'Perfect Calibration': perfect_calibration[::-1]}
fig = px.line(df, y=list(df.keys())).update_layout(
xaxis_title="Percentage of datapoints that may be abstained", yaxis_title="Percentage of predictions with RMSD < 2A"
)
fig.update_yaxes(range = [0,103])
fig.update_layout(margin={'l': 0, 'r': 0, 't': 20, 'b': 100}, plot_bgcolor='white',
paper_bgcolor='white', legend_title_text='', legend_title_font_size=1,
legend=dict(yanchor="bottom", y=0.1, xanchor="right", x=0.99, font=dict(size=17), ),
)
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=19),mirror=True,ticks='outside',showline=True,)
fig.write_image('results/confidence_calibration.pdf')
fig.show()
def filter_by_names(method_names, method_array, names_to_keep):
output_array = []
output_names = []
for method_name, array_element in zip(method_names,method_array):
if method_name in names_to_keep:
output_array.append(array_element)
output_names.append(method_name)
return np.array(output_array), np.array(output_names)
qvinaw_rmsds = np.load(os.path.join(args.qvinaw_results_path, 'rmsds.npy'))
qvinaw_names = np.load(os.path.join(args.qvinaw_results_path, 'names.npy'))
qvinaw_rmsds, qvinaw_names = filter_by_names(qvinaw_names, qvinaw_rmsds, complex_names)
qvinaw_rmsds = np.concatenate([qvinaw_rmsds, np.random.choice(qvinaw_rmsds, size=len(complex_names) - len(qvinaw_rmsds))])
glide_rmsds = np.load(os.path.join(args.glide_results_path, 'rmsds.npy'))
glide_names = np.load(os.path.join(args.glide_results_path, 'names.npy')).tolist()
glide_rmsds, glide_names = filter_by_names(glide_names, glide_rmsds, complex_names)
glide_rmsds = np.concatenate([glide_rmsds, np.random.choice(glide_rmsds, size=len(complex_names) - len(glide_rmsds))])
smina_rmsds = np.load(os.path.join(args.smina_results_path, 'rmsds.npy'))[:,0]
smina_names = np.load(os.path.join(args.smina_results_path, 'names.npy'))
smina_rmsds, smina_names = filter_by_names(smina_names, smina_rmsds, complex_names)
smina_rmsds = np.concatenate([smina_rmsds, np.random.choice(smina_rmsds, size=len(complex_names) - len(smina_rmsds))])
gnina_rmsds = np.load(os.path.join(args.gnina_results_path, 'rmsds.npy'))[:,0]
gnina_names = np.load(os.path.join(args.gnina_results_path, 'names.npy'))
gnina_rmsds, gnina_names = filter_by_names(gnina_names, gnina_rmsds, complex_names)
gnina_rmsds = np.concatenate([gnina_rmsds, np.random.choice(gnina_rmsds, size=len(complex_names) - len(gnina_rmsds))])
tankbind_rmsds = np.load(os.path.join(args.tankbind_results_path, 'rmsds.npy'))[:,0]
tankbind_names = np.load(os.path.join(args.tankbind_results_path, 'names.npy'))
tankbind_rmsds, tankbind_names = filter_by_names(tankbind_names, tankbind_rmsds, complex_names)
equibind_rmsds = np.load(os.path.join(args.equibind_results_path, 'rmsds.npy'))
equibind_names = np.load(os.path.join(args.equibind_results_path, 'names.npy'))
equibind_rmsds, equibind_names = filter_by_names(equibind_names, equibind_rmsds, complex_names)
df = {'DiffDock': filtered_rmsds, 'GLIDE': glide_rmsds, 'GNINA': gnina_rmsds, 'SMINA': smina_rmsds, 'QVinaW':qvinaw_rmsds, 'TANKBind': tankbind_rmsds, 'EquiBind': equibind_rmsds}
fig = px.ecdf(df, range_x=[0, 5], range_y=[0.001, 0.75], width=600, height=400)
fig.add_vline(x=2, annotation_text='', annotation_font_size=20, annotation_position="top right",
line_dash='dash', line_color='firebrick', annotation_font_color='firebrick')
fig.update_xaxes(title='RMSD (Å)')
fig.update_yaxes(title='Fraction with lower RMSD')
fig.update_layout(autosize=False, margin={'l': 65, 'r': 5, 't': 5, 'b': 60}, plot_bgcolor='white',
paper_bgcolor='white', legend_title_text='', legend_title_font_size=18,
legend=dict(yanchor="top", y=0.995, xanchor="left", x=0.02, font=dict(size=18, color='black'), ), )
fig.update_xaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=23, color='black'),mirror=True,ticks='outside',showline=True, linewidth=1, linecolor='black', tickfont = dict(size = 18, color='black'))
fig.update_yaxes(showgrid=True, gridcolor='lightgrey',title_font=dict(size=23, color='black'),mirror=True,ticks='outside',showline=True, linewidth=1, linecolor='black', tickfont = dict(size = 18, color='black'))
fig.update_traces(line=dict(width=3))
fig.write_image('results/rmsds_nooverlap.pdf')
fig.show()
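
The top-k "filtered" metrics in the script above all follow one pattern: reorder each complex's poses by confidence, keep the first k, and take the best RMSD among them. The sketch below shows that pattern on toy data. It assumes `confidence_ordering` sorts poses by descending confidence (its definition lies outside this excerpt), and it uses `np.take_along_axis` as a more readable equivalent of the `rmsds[np.arange(n)[:, None], ordering]` indexing. The helper name `top_k_by_confidence` and the arrays are illustrative only, not part of the repository.

```python
import numpy as np

def top_k_by_confidence(rmsds, confidences, k):
    """Return, per complex, the lowest RMSD among the k most confident poses."""
    # Pose indices sorted by descending confidence for every complex (row).
    order = np.argsort(confidences, axis=1)[:, ::-1]
    # Reorder RMSDs row by row; same result as rmsds[np.arange(n)[:, None], order].
    ranked = np.take_along_axis(rmsds, order, axis=1)
    return ranked[:, :k].min(axis=1)

# Toy data (hypothetical): 2 complexes, 3 sampled poses each.
rmsds = np.array([[3.0, 1.5, 4.0],
                  [0.8, 6.0, 2.5]])
confidences = np.array([[0.9, 0.4, 0.1],
                        [0.7, 0.5, -0.3]])

top1 = top_k_by_confidence(rmsds, confidences, k=1)
top2 = top_k_by_confidence(rmsds, confidences, k=2)
# Percentage of complexes whose top-ranked pose is within 2 Å.
success_top1 = 100 * (top1 < 2).mean()
```

Computing the reordered array once and slicing it for each k would avoid repeating the long indexing expression for every metric.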


@@ -1,180 +0,0 @@
# Small script to evaluate predicted ligand poses against the reference ligands: RMSD, centroid distance, steric clashes and self-intersections
import os
import time
from argparse import FileType, ArgumentParser
import numpy as np
from biopandas.pdb import PandasPdb
from rdkit import Chem
from tqdm import tqdm
from datasets.pdbbind import read_mol
from datasets.process_mols import read_molecule
from utils.utils import read_strings_from_txt, get_symmetry_rmsd
parser = ArgumentParser()
parser.add_argument('--config', type=FileType(mode='r'), default=None)
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed', help='')
parser.add_argument('--results_path', type=str, default='results/user_predictions_testset', help='Path to folder with the predicted ligand poses')
parser.add_argument('--file_suffix', type=str, default='_baseline_ligand.pdb', help='Suffix of the predicted ligand file inside each complex directory')
parser.add_argument('--project', type=str, default='ligbind_inf', help='')
parser.add_argument('--file_to_exclude', type=str, default=None, help='')
parser.add_argument('--all_dirs_in_results', action='store_true', default=True, help='Match each complex name against all directories in the results path instead of looking directly for a directory with that exact name')
parser.add_argument('--num_predictions', type=int, default=10, help='')
parser.add_argument('--no_id_in_filename', action='store_true', default=False, help='')
parser.add_argument('--test_names_path', type=str, default='data/splits/timesplit_test', help='Path to text file with the folder names in the test set')
parser.add_argument('--no_overlap_names_path', type=str, default='data/splits/timesplit_test_no_rec_overlap', help='Path to text file with the folder names in the test set that have no receptor overlap with the train set')
args = parser.parse_args()
print('Reading paths and names.')
names = read_strings_from_txt(args.test_names_path)
names_no_rec_overlap = read_strings_from_txt(args.no_overlap_names_path)
results_path_containments = os.listdir(args.results_path)
all_times = []
successful_names_list = []
rmsds_list = []
centroid_distances_list = []
min_cross_distances_list = []
min_self_distances_list = []
without_rec_overlap_list = []
start_time = time.time()
for i, name in enumerate(tqdm(names)):
mol = read_mol(args.data_dir, name, remove_hs=True)
mol = Chem.RemoveAllHs(mol)
orig_ligand_pos = np.array(mol.GetConformer().GetPositions())
if args.all_dirs_in_results:
directory_with_name_list = [directory for directory in results_path_containments if name in directory]
if not directory_with_name_list:
print('Did not find a directory for ', name, '. We are skipping that complex')
continue
else:
directory_with_name = directory_with_name_list[0]
ligand_pos = []
debug_paths = []
for i in range(args.num_predictions):
file_paths = sorted(os.listdir(os.path.join(args.results_path, directory_with_name)))
if args.file_to_exclude is not None:
file_paths = [path for path in file_paths if args.file_to_exclude not in path]
file_path = [path for path in file_paths if f'rank{i+1}_' in path][0]
mol_pred = read_molecule(os.path.join(args.results_path, directory_with_name, file_path),remove_hs=True, sanitize=True)
mol_pred = Chem.RemoveAllHs(mol_pred)
ligand_pos.append(mol_pred.GetConformer().GetPositions())
debug_paths.append(file_path)
ligand_pos = np.asarray(ligand_pos)
else:
if not os.path.exists(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}')): raise Exception('path does not exist:', os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'))
mol_pred = read_molecule(os.path.join(args.results_path, name, f'{"" if args.no_id_in_filename else name}{args.file_suffix}'), remove_hs=True, sanitize=True)
if mol_pred is None:
print("Skipping ", name, ' because RDKit could not read it.')
continue
mol_pred = Chem.RemoveAllHs(mol_pred)
ligand_pos = np.asarray([np.array(mol_pred.GetConformer(i).GetPositions()) for i in range(args.num_predictions)])
try:
rmsd = get_symmetry_rmsd(mol, orig_ligand_pos, [l for l in ligand_pos], mol_pred)
except Exception as e:
print("Using non-corrected RMSD because of the error:", e)
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
rmsds_list.append(rmsd)
centroid_distances_list.append(np.linalg.norm(ligand_pos.mean(axis=1) - orig_ligand_pos[None,:].mean(axis=1), axis=1))
rec_path = os.path.join(args.data_dir, name, f'{name}_protein_processed.pdb')
if not os.path.exists(rec_path):
rec_path = os.path.join(args.data_dir, name,f'{name}_protein_obabel_reduce.pdb')
rec = PandasPdb().read_pdb(rec_path)
rec_df = rec.df['ATOM']
receptor_pos = rec_df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
receptor_pos = np.tile(receptor_pos, (args.num_predictions, 1, 1))
cross_distances = np.linalg.norm(receptor_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
self_distances = np.linalg.norm(ligand_pos[:, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
self_distances = np.where(np.eye(self_distances.shape[2]), np.inf, self_distances)
min_cross_distances_list.append(np.min(cross_distances, axis=(1,2)))
min_self_distances_list.append(np.min(self_distances, axis=(1, 2)))
successful_names_list.append(name)
without_rec_overlap_list.append(1 if name in names_no_rec_overlap else 0)
performance_metrics = {}
for overlap in ['', 'no_overlap_']:
if 'no_overlap_' == overlap:
without_rec_overlap = np.array(without_rec_overlap_list, dtype=bool)
rmsds = np.array(rmsds_list)[without_rec_overlap]
centroid_distances = np.array(centroid_distances_list)[without_rec_overlap]
min_cross_distances = np.array(min_cross_distances_list)[without_rec_overlap]
min_self_distances = np.array(min_self_distances_list)[without_rec_overlap]
successful_names = np.array(successful_names_list)[without_rec_overlap]
else:
rmsds = np.array(rmsds_list)
centroid_distances = np.array(centroid_distances_list)
min_cross_distances = np.array(min_cross_distances_list)
min_self_distances = np.array(min_self_distances_list)
successful_names = np.array(successful_names_list)
np.save(os.path.join(args.results_path, f'{overlap}rmsds.npy'), rmsds)
np.save(os.path.join(args.results_path, f'{overlap}names.npy'), successful_names)
np.save(os.path.join(args.results_path, f'{overlap}min_cross_distances.npy'), np.array(min_cross_distances))
np.save(os.path.join(args.results_path, f'{overlap}min_self_distances.npy'), np.array(min_self_distances))
performance_metrics.update({
f'{overlap}steric_clash_fraction': (100 * (min_cross_distances < 0.4).sum() / len(min_cross_distances) / args.num_predictions).__round__(2),
f'{overlap}self_intersect_fraction': (100 * (min_self_distances < 0.4).sum() / len(min_self_distances) / args.num_predictions).__round__(2),
f'{overlap}mean_rmsd': rmsds[:,0].mean(),
f'{overlap}rmsds_below_2': (100 * (rmsds[:,0] < 2).sum() / len(rmsds[:,0])),
f'{overlap}rmsds_below_5': (100 * (rmsds[:,0] < 5).sum() / len(rmsds[:,0])),
f'{overlap}rmsds_percentile_25': np.percentile(rmsds[:,0], 25).round(2),
f'{overlap}rmsds_percentile_50': np.percentile(rmsds[:,0], 50).round(2),
f'{overlap}rmsds_percentile_75': np.percentile(rmsds[:,0], 75).round(2),
f'{overlap}mean_centroid': centroid_distances[:,0].mean().__round__(2),
f'{overlap}centroid_below_2': (100 * (centroid_distances[:,0] < 2).sum() / len(centroid_distances[:,0])).__round__(2),
f'{overlap}centroid_below_5': (100 * (centroid_distances[:,0] < 5).sum() / len(centroid_distances[:,0])).__round__(2),
f'{overlap}centroid_percentile_25': np.percentile(centroid_distances[:,0], 25).round(2),
f'{overlap}centroid_percentile_50': np.percentile(centroid_distances[:,0], 50).round(2),
f'{overlap}centroid_percentile_75': np.percentile(centroid_distances[:,0], 75).round(2),
})
top5_rmsds = np.min(rmsds[:, :5], axis=1)
top5_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
top5_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
top5_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :5], axis=1)][:,0]
performance_metrics.update({
f'{overlap}top5_steric_clash_fraction': (100 * (top5_min_cross_distances < 0.4).sum() / len(top5_min_cross_distances)).__round__(2),
f'{overlap}top5_self_intersect_fraction': (100 * (top5_min_self_distances < 0.4).sum() / len(top5_min_self_distances)).__round__(2),
f'{overlap}top5_rmsds_below_2': (100 * (top5_rmsds < 2).sum() / len(top5_rmsds)).__round__(2),
f'{overlap}top5_rmsds_below_5': (100 * (top5_rmsds < 5).sum() / len(top5_rmsds)).__round__(2),
f'{overlap}top5_rmsds_percentile_25': np.percentile(top5_rmsds, 25).round(2),
f'{overlap}top5_rmsds_percentile_50': np.percentile(top5_rmsds, 50).round(2),
f'{overlap}top5_rmsds_percentile_75': np.percentile(top5_rmsds, 75).round(2),
f'{overlap}top5_centroid_below_2': (100 * (top5_centroid_distances < 2).sum() / len(top5_centroid_distances)).__round__(2),
f'{overlap}top5_centroid_below_5': (100 * (top5_centroid_distances < 5).sum() / len(top5_centroid_distances)).__round__(2),
f'{overlap}top5_centroid_percentile_25': np.percentile(top5_centroid_distances, 25).round(2),
f'{overlap}top5_centroid_percentile_50': np.percentile(top5_centroid_distances, 50).round(2),
f'{overlap}top5_centroid_percentile_75': np.percentile(top5_centroid_distances, 75).round(2),
})
top10_rmsds = np.min(rmsds[:, :10], axis=1)
top10_centroid_distances = centroid_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
top10_min_cross_distances = min_cross_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
top10_min_self_distances = min_self_distances[np.arange(rmsds.shape[0])[:,None],np.argsort(rmsds[:, :10], axis=1)][:,0]
performance_metrics.update({
f'{overlap}top10_self_intersect_fraction': (100 * (top10_min_self_distances < 0.4).sum() / len(top10_min_self_distances)).__round__(2),
f'{overlap}top10_steric_clash_fraction': ( 100 * (top10_min_cross_distances < 0.4).sum() / len(top10_min_cross_distances)).__round__(2),
f'{overlap}top10_rmsds_below_2': (100 * (top10_rmsds < 2).sum() / len(top10_rmsds)).__round__(2),
f'{overlap}top10_rmsds_below_5': (100 * (top10_rmsds < 5).sum() / len(top10_rmsds)).__round__(2),
f'{overlap}top10_rmsds_percentile_25': np.percentile(top10_rmsds, 25).round(2),
f'{overlap}top10_rmsds_percentile_50': np.percentile(top10_rmsds, 50).round(2),
f'{overlap}top10_rmsds_percentile_75': np.percentile(top10_rmsds, 75).round(2),
f'{overlap}top10_centroid_below_2': (100 * (top10_centroid_distances < 2).sum() / len(top10_centroid_distances)).__round__(2),
f'{overlap}top10_centroid_below_5': (100 * (top10_centroid_distances < 5).sum() / len(top10_centroid_distances)).__round__(2),
f'{overlap}top10_centroid_percentile_25': np.percentile(top10_centroid_distances, 25).round(2),
f'{overlap}top10_centroid_percentile_50': np.percentile(top10_centroid_distances, 50).round(2),
f'{overlap}top10_centroid_percentile_75': np.percentile(top10_centroid_distances, 75).round(2),
})
for k in performance_metrics:
print(k, performance_metrics[k])
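
The geometric quantities computed per complex in the script above (the non-corrected RMSD fallback, the centroid distance, and the minimum receptor-ligand distance used for the 0.4 Å steric-clash check) can be sketched on toy coordinates as follows. The function name `pose_metrics` and the coordinates are hypothetical. The sketch also relies on broadcasting instead of `np.tile` for the receptor, which is a substitution rather than the script's own approach.

```python
import numpy as np

def pose_metrics(ligand_pos, ref_pos, receptor_pos):
    """ligand_pos: (P, L, 3) predicted poses; ref_pos: (L, 3); receptor_pos: (R, 3)."""
    # Plain (not symmetry-corrected) RMSD of every pose to the reference.
    rmsd = np.sqrt(((ligand_pos - ref_pos) ** 2).sum(axis=2).mean(axis=1))
    # Distance between each predicted centroid and the reference centroid.
    centroid = np.linalg.norm(ligand_pos.mean(axis=1) - ref_pos.mean(axis=0), axis=1)
    # All receptor-ligand atom distances per pose, shape (P, R, L).
    cross = np.linalg.norm(receptor_pos[None, :, None, :] - ligand_pos[:, None, :, :], axis=-1)
    # Closest receptor-ligand atom pair per pose.
    min_cross = cross.min(axis=(1, 2))
    return rmsd, centroid, min_cross

# Toy data (hypothetical): a 2-atom ligand, 2 poses, a 1-atom receptor.
ref_pos = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
ligand_pos = np.stack([ref_pos, ref_pos + np.array([3.0, 0.0, 0.0])])
receptor_pos = np.array([[4.0, 0.0, 0.3]])

rmsd, centroid, min_cross = pose_metrics(ligand_pos, ref_pos, receptor_pos)
# A pose counts as a steric clash when any receptor atom is closer than 0.4 Å.
clash = min_cross < 0.4
```

Here the second pose is shifted by 3 Å along x, so it lands 0.3 Å from the receptor atom and is flagged as a clash, while the first pose matches the reference exactly.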


@@ -1,13 +1,13 @@
import copy
import os
import torch
from argparse import ArgumentParser, Namespace
from rdkit.Chem import RemoveHs
from argparse import ArgumentParser, Namespace, FileType
from functools import partial
import numpy as np
import pandas as pd
from rdkit import RDLogger
from torch_geometric.loader import DataLoader
from rdkit.Chem import RemoveAllHs
from datasets.process_mols import write_mol_with_coords
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, get_t_schedule
@@ -20,7 +20,8 @@ from tqdm import tqdm
RDLogger.DisableLog('rdApp.*')
import yaml
parser = ArgumentParser()
parser.add_argument('--protein_ligand_csv', type=str, default=None, help='Path to a .csv file specifying the input as described in the README. If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters')
parser.add_argument('--config', type=FileType(mode='r'), default='inference_args.yaml')
parser.add_argument('--protein_ligand_csv', type=str, default="data/protein_ligand_example_csv.csv", help='Path to a .csv file specifying the input as described in the README. If this is not None, it will be used instead of the --protein_path, --protein_sequence and --ligand parameters')
parser.add_argument('--complex_name', type=str, default='1a0q', help='Name that the complex will be saved with')
parser.add_argument('--protein_path', type=str, default=None, help='Path to the protein file')
parser.add_argument('--protein_sequence', type=str, default=None, help='Sequence of the protein for ESMFold, this is ignored if --protein_path is not None')
@@ -30,17 +31,51 @@ parser.add_argument('--out_dir', type=str, default='results/user_inference', hel
parser.add_argument('--save_visualisation', action='store_true', default=False, help='Save a pdb file with all of the steps of the reverse diffusion')
parser.add_argument('--samples_per_complex', type=int, default=10, help='Number of samples to generate')
parser.add_argument('--model_dir', type=str, default='workdir/paper_score_model', help='Path to folder with trained score model and hyperparameters')
parser.add_argument('--model_dir', type=str, default=None, help='Path to folder with trained score model and hyperparameters')
parser.add_argument('--ckpt', type=str, default='best_ema_inference_epoch_model.pt', help='Checkpoint to use for the score model')
parser.add_argument('--confidence_model_dir', type=str, default='workdir/paper_confidence_model', help='Path to folder with trained confidence model and hyperparameters')
parser.add_argument('--confidence_ckpt', type=str, default='best_model_epoch75.pt', help='Checkpoint to use for the confidence model')
parser.add_argument('--confidence_model_dir', type=str, default=None, help='Path to folder with trained confidence model and hyperparameters')
parser.add_argument('--confidence_ckpt', type=str, default='best_model.pt', help='Checkpoint to use for the confidence model')
parser.add_argument('--batch_size', type=int, default=32, help='')
parser.add_argument('--no_final_step_noise', action='store_true', default=False, help='Use no noise in the final step of the reverse diffusion')
parser.add_argument('--batch_size', type=int, default=10, help='')
parser.add_argument('--no_final_step_noise', action='store_true', default=True, help='Use no noise in the final step of the reverse diffusion')
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps')
parser.add_argument('--actual_steps', type=int, default=None, help='Number of denoising steps that are actually performed')
parser.add_argument('--old_score_model', action='store_true', default=False, help='')
parser.add_argument('--old_confidence_model', action='store_true', default=True, help='')
parser.add_argument('--initial_noise_std_proportion', type=float, default=-1.0, help='Initial noise std proportion')
parser.add_argument('--choose_residue', action='store_true', default=False, help='')
parser.add_argument('--temp_sampling_tr', type=float, default=1.0)
parser.add_argument('--temp_psi_tr', type=float, default=0.0)
parser.add_argument('--temp_sigma_data_tr', type=float, default=0.5)
parser.add_argument('--temp_sampling_rot', type=float, default=1.0)
parser.add_argument('--temp_psi_rot', type=float, default=0.0)
parser.add_argument('--temp_sigma_data_rot', type=float, default=0.5)
parser.add_argument('--temp_sampling_tor', type=float, default=1.0)
parser.add_argument('--temp_psi_tor', type=float, default=0.0)
parser.add_argument('--temp_sigma_data_tor', type=float, default=0.5)
parser.add_argument('--gnina_minimize', action='store_true', default=False, help='')
parser.add_argument('--gnina_path', type=str, default='gnina', help='')
parser.add_argument('--gnina_log_file', type=str, default='gnina_log.txt', help='') # To redirect gnina subprocesses stdouts from the terminal window
parser.add_argument('--gnina_full_dock', action='store_true', default=False, help='')
parser.add_argument('--gnina_autobox_add', type=float, default=4.0)
parser.add_argument('--gnina_poses_to_optimize', type=int, default=1)
args = parser.parse_args()
if args.config:
config_dict = yaml.load(args.config, Loader=yaml.FullLoader)
arg_dict = args.__dict__
for key, value in config_dict.items():
if isinstance(value, list):
for v in value:
arg_dict[key].append(v)
else:
arg_dict[key] = value
# TODO check that the args are actually updated
os.makedirs(args.out_dir, exist_ok=True)
with open(f'{args.model_dir}/model_parameters.yml') as f:
score_model_args = Namespace(**yaml.full_load(f))
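
The config-override loop above carries a TODO about checking that the arguments are actually updated. One way to approach that check is sketched below, under stated assumptions: a plain dictionary stands in for the parsed YAML file, `apply_config` is a hypothetical helper, and `'workdir/example_model'` and `'typo_key'` are made-up values. Unlike the loop in the diff, this sketch does not silently add keys the parser does not define. It collects and returns them so a misspelled config key can be noticed.

```python
from argparse import ArgumentParser

def apply_config(args, config_dict):
    """Overwrite parsed arguments with config values; return unrecognised keys."""
    unknown = []
    for key, value in config_dict.items():
        if not hasattr(args, key):
            unknown.append(key)  # key is not defined by the parser
            continue
        current = getattr(args, key)
        if isinstance(value, list) and isinstance(current, list):
            current.extend(value)  # list-valued arguments are extended
        else:
            setattr(args, key, value)  # everything else is replaced
    return unknown

parser = ArgumentParser()
parser.add_argument('--samples_per_complex', type=int, default=10)
parser.add_argument('--model_dir', type=str, default=None)
args = parser.parse_args([])

# Stand-in for the dictionary a YAML loader would return (hypothetical values).
config_dict = {'samples_per_complex': 40, 'model_dir': 'workdir/example_model', 'typo_key': 1}
unknown = apply_config(args, config_dict)
```

Guarding the list branch with an `isinstance` check on the current value also avoids calling `append` on an argument whose default is `None`.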
@@ -70,11 +105,12 @@ for name in complex_name_list:
# preprocessing of complexes into geometric graphs
test_dataset = InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list,
ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list,
lm_embeddings=score_model_args.esm_embeddings_path is not None,
lm_embeddings=True,
receptor_radius=score_model_args.receptor_radius, remove_hs=score_model_args.remove_hs,
c_alpha_max_neighbors=score_model_args.c_alpha_max_neighbors,
all_atoms=score_model_args.all_atoms, atom_radius=score_model_args.atom_radius,
atom_max_neighbors=score_model_args.atom_max_neighbors)
atom_max_neighbors=score_model_args.atom_max_neighbors,
knn_only_graph=not getattr(score_model_args, 'not_knn_only_graph', True))
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
if args.confidence_model_dir is not None and not confidence_args.use_original_model_cache:
@@ -83,25 +119,27 @@ if args.confidence_model_dir is not None and not confidence_args.use_original_mo
confidence_test_dataset = \
InferenceDataset(out_dir=args.out_dir, complex_names=complex_name_list, protein_files=protein_path_list,
ligand_descriptions=ligand_description_list, protein_sequences=protein_sequence_list,
lm_embeddings=confidence_args.esm_embeddings_path is not None,
lm_embeddings=True,
receptor_radius=confidence_args.receptor_radius, remove_hs=confidence_args.remove_hs,
c_alpha_max_neighbors=confidence_args.c_alpha_max_neighbors,
all_atoms=confidence_args.all_atoms, atom_radius=confidence_args.atom_radius,
atom_max_neighbors=confidence_args.atom_max_neighbors,
precomputed_lm_embeddings=test_dataset.lm_embeddings)
precomputed_lm_embeddings=test_dataset.lm_embeddings,
knn_only_graph=not getattr(score_model_args, 'not_knn_only_graph', True))
else:
confidence_test_dataset = None
t_to_sigma = partial(t_to_sigma_compl, args=score_model_args)
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True)
model = get_model(score_model_args, device, t_to_sigma=t_to_sigma, no_parallel=True, old=args.old_score_model)
state_dict = torch.load(f'{args.model_dir}/{args.ckpt}', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)
model = model.to(device)
model.eval()
if args.confidence_model_dir is not None:
confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True, confidence_mode=True)
confidence_model = get_model(confidence_args, device, t_to_sigma=t_to_sigma, no_parallel=True,
confidence_mode=True, old=args.old_confidence_model)
state_dict = torch.load(f'{args.confidence_model_dir}/{args.confidence_ckpt}', map_location=torch.device('cpu'))
confidence_model.load_state_dict(state_dict, strict=True)
confidence_model = confidence_model.to(device)
@@ -110,7 +148,7 @@ else:
confidence_model = None
confidence_args = None
tr_schedule = get_t_schedule(inference_steps=args.inference_steps)
tr_schedule = get_t_schedule(inference_steps=args.inference_steps, sigma_schedule='expbeta')
failures, skipped = 0, 0
N = args.samples_per_complex
@@ -131,7 +169,10 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
else:
confidence_data_list = None
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(N)]
randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max)
randomize_position(data_list, score_model_args.no_torsion, False, score_model_args.tr_sigma_max,
initial_noise_std_proportion=args.initial_noise_std_proportion,
choose_residue=args.choose_residue)
lig = orig_complex_graph.mol[0]
# initialize visualisation
@@ -154,7 +195,13 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
device=device, t_to_sigma=t_to_sigma, model_args=score_model_args,
visualization_list=visualization_list, confidence_model=confidence_model,
confidence_data_list=confidence_data_list, confidence_model_args=confidence_args,
batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise)
batch_size=args.batch_size, no_final_step_noise=args.no_final_step_noise,
temp_sampling=[args.temp_sampling_tr, args.temp_sampling_rot,
args.temp_sampling_tor],
temp_psi=[args.temp_psi_tr, args.temp_psi_rot, args.temp_psi_tor],
temp_sigma_data=[args.temp_sigma_data_tr, args.temp_sigma_data_rot,
args.temp_sigma_data_tor])
ligand_pos = np.asarray([complex_graph['ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy() for complex_graph in data_list])
# reorder predictions based on confidence output
@@ -170,7 +217,7 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
write_dir = f'{args.out_dir}/{complex_name_list[idx]}'
for rank, pos in enumerate(ligand_pos):
mol_pred = copy.deepcopy(lig)
if score_model_args.remove_hs: mol_pred = RemoveHs(mol_pred)
if score_model_args.remove_hs: mol_pred = RemoveAllHs(mol_pred)
if rank == 0: write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}.sdf'))
write_mol_with_coords(mol_pred, pos, os.path.join(write_dir, f'rank{rank+1}_confidence{confidence[rank]:.2f}.sdf'))
@@ -189,6 +236,4 @@ for idx, orig_complex_graph in tqdm(enumerate(test_loader)):
print(f'Failed for {failures} complexes')
print(f'Skipped {skipped} complexes')
print(f'Results are in {args.out_dir}')
print(f'Results are in {args.out_dir}')
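The config handling near the top of this script applies the YAML values after `parse_args()`, so a config file overrides both the argparse defaults and any flags given on the command line. The following is a minimal, standalone sketch of that merge pattern. The `merge_config` helper, the `--tags` option, and the `None` guard for list-valued options are illustrative assumptions and are not part of the script; a plain dict stands in for the result of `yaml.load`.

```python
import argparse


def merge_config(args, config_dict):
    """Overlay config values on parsed args: lists are extended, scalars replaced."""
    arg_dict = vars(args)  # the same mapping as args.__dict__, so updates are visible on args
    for key, value in config_dict.items():
        if isinstance(value, list):
            # Assumption: start from an empty list when the argparse default is None.
            if arg_dict.get(key) is None:
                arg_dict[key] = []
            arg_dict[key].extend(value)
        else:
            arg_dict[key] = value
    return args


parser = argparse.ArgumentParser()
parser.add_argument('--inference_steps', type=int, default=20)
parser.add_argument('--old_score_model', action='store_true', default=False)
parser.add_argument('--tags', nargs='*', default=None)  # hypothetical list-valued option

args = parser.parse_args([])  # no CLI flags, defaults only
# Stand-in for the dict that the YAML loader would return from the config file.
config = {'inference_steps': 10, 'tags': ['a', 'b']}
merged = merge_config(args, config)
```

Keys missing from the config keep their argparse values, which is why `old_score_model` stays `False` here.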

30
inference_args.yaml Normal file

@@ -0,0 +1,30 @@
actual_steps: 19
ckpt: best_ema_inference_epoch_model.pt
confidence_ckpt: best_model_epoch75.pt
confidence_model_dir: workdir/paper_confidence_model
different_schedules: false
inf_sched_alpha: 1
inf_sched_beta: 1
inference_steps: 20
initial_noise_std_proportion: 1.4601642460337794
limit_failures: 5
model_dir: workdir/diffdockL
no_final_step_noise: true
no_model: false
no_random: false
no_random_pocket: false
ode: false
old_filtering_model: true
old_score_model: false
resample_rdkit: false
samples_per_complex: 10
sigma_schedule: expbeta
temp_psi_rot: 0.9022615585677628
temp_psi_tor: 0.5946212391366862
temp_psi_tr: 0.727287304570729
temp_sampling_rot: 2.06391612594481
temp_sampling_tor: 7.044261621607846
temp_sampling_tr: 1.170050527854316
temp_sigma_data_rot: 0.7464326999906034
temp_sigma_data_tor: 0.6943254174849822
temp_sigma_data_tr: 0.9299802531572672
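The temperature settings above are stored as flat keys with a `_tr`, `_rot`, or `_tor` suffix, while the inference script passes them to the sampler as three-element lists ordered translation, rotation, torsion. A small sketch of that regrouping, using a plain dict with a subset of the values from this file; the `group_by_component` helper is a hypothetical name introduced for illustration.

```python
# Subset of the keys from inference_args.yaml, written as a plain dict.
config = {
    'temp_sampling_tr': 1.170050527854316,
    'temp_sampling_rot': 2.06391612594481,
    'temp_sampling_tor': 7.044261621607846,
    'temp_psi_tr': 0.727287304570729,
    'temp_psi_rot': 0.9022615585677628,
    'temp_psi_tor': 0.5946212391366862,
}

# Order used by the inference script when building the lists.
COMPONENTS = ('tr', 'rot', 'tor')


def group_by_component(cfg, prefix):
    """Collect '<prefix>_tr', '<prefix>_rot', '<prefix>_tor' into one ordered list."""
    return [cfg[f'{prefix}_{component}'] for component in COMPONENTS]


temp_sampling = group_by_component(config, 'temp_sampling')
temp_psi = group_by_component(config, 'temp_psi')
```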

667
models/aa_model.py Normal file

@@ -0,0 +1,667 @@
from e3nn import o3
import torch
from esm.pretrained import load_model_and_alphabet
from torch import nn
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_geometric.utils import subgraph
from torch_scatter import scatter_mean
import numpy as np
from models.layers import GaussianSmearing, AtomEncoder
from models.tensor_layers import get_irrep_seq, TensorProductConvLayer
from utils import so3, torus
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
AGGREGATORS = {"mean": lambda x: torch.mean(x, dim=1),
"max": lambda x: torch.max(x, dim=1)[0],
"min": lambda x: torch.min(x, dim=1)[0],
"std": lambda x: torch.std(x, dim=1)}
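Each entry of `AGGREGATORS` reduces a tensor along `dim=1`, and the per-aggregator results are later concatenated along the feature dimension. The sketch below mirrors that behaviour for a single graph in pure Python, without torch, so it is an illustration rather than the model code: rows play the role of the parallel copies and columns the role of the features. `statistics.stdev` is used because `torch.std` applies Bessel's correction by default. The names `AGG` and `aggregate` are hypothetical.

```python
import statistics

# Pure-Python stand-ins for the tensor reductions in AGGREGATORS.
AGG = {
    "mean": statistics.fmean,
    "max": max,
    "min": min,
    "std": statistics.stdev,  # sample standard deviation (n - 1 in the denominator)
}


def aggregate(rows, names):
    """Reduce over rows for every column, then concatenate the results per aggregator."""
    columns = list(zip(*rows))
    out = []
    for name in names:
        out.extend(AGG[name](column) for column in columns)
    return out


# Two parallel copies, two features each.
rows = [[1.0, 10.0], [3.0, 30.0]]
features = aggregate(rows, "mean max min std".split(' '))
```

With four aggregators and two features the result has eight entries, matching the `len(self.parallel_aggregators) * ns` input size of the affinity predictor.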
class AAModel(torch.nn.Module):
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
separate_noise_schedule=False, lm_embedding_type=False, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm=False,
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
parallel_aggregators="mean max min std", num_confidence_outputs=1, atom_num_confidence_outputs=1, fixed_center_conv=False,
no_aminoacid_identities=False, include_miscellaneous_atoms=False,
differentiate_convolutions=True, tp_weights_layers=2, num_prot_emb_layers=0,
reduce_pseudoscalars=False, embed_also_ligand=False, atom_confidence=False, sidechain_pred=False,
depthwise_convolution=False, crop_beyond=None):
super(AAModel, self).__init__()
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
assert not sidechain_pred, "sidechain prediction is not implemented (and does not make sense) for the all-atom model"
assert not depthwise_convolution, "depthwise convolution not implemented for all atom model"
if parallel > 1: assert affinity_prediction
self.t_to_sigma = t_to_sigma
self.in_lig_edge_features = in_lig_edge_features
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
self.sigma_embed_dim = sigma_embed_dim
self.lig_max_radius = lig_max_radius
self.rec_max_radius = rec_max_radius
self.cross_max_distance = cross_max_distance
self.dynamic_max_cross = dynamic_max_cross
self.center_max_distance = center_max_distance
self.distance_embed_dim = distance_embed_dim
self.cross_distance_embed_dim = cross_distance_embed_dim
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
self.ns, self.nv = ns, nv
self.scale_by_sigma = scale_by_sigma
self.norm_by_sigma = norm_by_sigma
self.device = device
self.no_torsion = no_torsion
self.smooth_edges = smooth_edges
self.odd_parity = odd_parity
self.num_conv_layers = num_conv_layers
self.timestep_emb_func = timestep_emb_func
self.separate_noise_schedule = separate_noise_schedule
self.confidence_mode = confidence_mode
self.num_conv_layers = num_conv_layers
self.num_prot_emb_layers = num_prot_emb_layers
self.asyncronous_noise_schedule = asyncronous_noise_schedule
self.affinity_prediction = affinity_prediction
self.parallel, self.parallel_aggregators = parallel, parallel_aggregators.split(' ')
self.fixed_center_conv = fixed_center_conv
self.no_aminoacid_identities = no_aminoacid_identities
self.differentiate_convolutions = differentiate_convolutions
self.reduce_pseudoscalars = reduce_pseudoscalars
self.atom_confidence = atom_confidence
self.atom_num_confidence_outputs = atom_num_confidence_outputs
self.crop_beyond = crop_beyond
self.lm_embedding_type = lm_embedding_type
if lm_embedding_type is None:
lm_embedding_dim = 0
elif lm_embedding_type == "precomputed":
lm_embedding_dim = 1280
else:
lm, alphabet = load_model_and_alphabet(lm_embedding_type)
self.batch_converter = alphabet.get_batch_converter()
lm.lm_head = torch.nn.Identity()
lm.contact_head = torch.nn.Identity()
lm_embedding_dim = lm.embed_dim
self.lm = lm
# embedding layers
atom_encoder_class = AtomEncoder
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
self.rec_sigma_embedding = nn.Sequential(nn.Linear(sigma_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=0, lm_embedding_dim=lm_embedding_dim)
self.rec_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
self.atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=0)
self.atom_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
self.lr_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.ar_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
irrep_seq = get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars)
assert not include_miscellaneous_atoms, "currently not supported"
rec_emb_layers = []
for i in range(num_prot_emb_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1 if not differentiate_convolutions else 4,
)
rec_emb_layers.append(layer)
self.rec_emb_layers = nn.ModuleList(rec_emb_layers)
self.embed_also_ligand = embed_also_ligand
if embed_also_ligand:
lig_emb_layers = []
for i in range(num_prot_emb_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1,
)
lig_emb_layers.append(layer)
self.lig_emb_layers = nn.ModuleList(lig_emb_layers)
# convolutional layers
conv_layers = []
for i in range(num_prot_emb_layers, num_prot_emb_layers + num_conv_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1 if not differentiate_convolutions else (3 if i == num_prot_emb_layers + num_conv_layers - 1 else 9),
)
conv_layers.append(layer)
self.conv_layers = nn.ModuleList(conv_layers)
# confidence and affinity prediction layers
if self.confidence_mode:
if self.affinity_prediction:
if self.parallel > 1:
output_confidence_dim = 1 + ns
else:
output_confidence_dim = num_confidence_outputs + 1
else:
output_confidence_dim = num_confidence_outputs
input_size = ns + (nv if reduce_pseudoscalars else ns) if num_conv_layers + num_prot_emb_layers >= 3 else ns
if self.atom_confidence:
self.atom_confidence_predictor = nn.Sequential(
nn.Linear(input_size, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, atom_num_confidence_outputs + ns)
)
input_size = ns
self.confidence_predictor = nn.Sequential(
nn.Linear(input_size, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, output_confidence_dim)
)
if self.parallel > 1:
self.affinity_predictor = nn.Sequential(
nn.Linear(len(self.parallel_aggregators) * ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, 1)
)
else:
# convolution for translational and rotational scores
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
self.center_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_conv = TensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.sh_irreps,
out_irreps='2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
n_edge_features=2 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
if not no_torsion:
# convolution for torsional score
self.final_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
self.tor_bond_conv = TensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.final_tp_tor.irreps_out,
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
n_edge_features=3 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tor_final_layer = nn.Sequential(
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(ns, 1, bias=False)
)
def embedding(self, data):
if not hasattr(data['receptor'], "rec_node_attr"):
if self.lm_embedding_type not in [None, 'precomputed']:
sequences = [s for l in data['receptor'].sequence for s in l]
if isinstance(sequences[0], list):
sequences = [s for l in sequences for s in l]
sequences = [(i, s) for i, s in enumerate(sequences)]
batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
out = self.lm(batch_tokens.to(data['receptor'].x.device), repr_layers=[self.lm.num_layers], return_contacts=False)
rec_lm_emb = torch.cat([t[:len(sequences[i][1])] for i, t in enumerate(out['representations'][self.lm.num_layers])], dim=0)
data['receptor'].x = torch.cat([data['receptor'].x, rec_lm_emb], dim=-1)
rec_node_attr, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
rec_node_attr = self.rec_node_embedding(rec_node_attr)
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
atom_node_attr, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_atom_conv_graph(data)
atom_node_attr = self.atom_node_embedding(atom_node_attr)
atom_edge_attr = self.atom_edge_embedding(atom_edge_attr)
ar_edge_attr, ar_edge_sh, ar_edge_weight = self.build_cross_rec_conv_graph(data)
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
rec_edge_index = data['receptor', 'receptor'].edge_index.clone()
atom_edge_index = data['atom', 'atom'].edge_index.clone()
ar_edge_index = data['atom', 'receptor'].edge_index.clone()
node_attr = torch.cat([rec_node_attr, atom_node_attr], dim=0)
ar_edge_index[0] = ar_edge_index[0] + len(rec_node_attr)
edge_index = torch.cat([rec_edge_index, ar_edge_index, atom_edge_index + len(rec_node_attr), torch.flip(ar_edge_index, dims=[0])], dim=1)
edge_attr = torch.cat([rec_edge_attr, ar_edge_attr, atom_edge_attr, ar_edge_attr], dim=0)
edge_sh = torch.cat([rec_edge_sh, ar_edge_sh, atom_edge_sh, ar_edge_sh], dim=0)
edge_weight = torch.cat([rec_edge_weight, ar_edge_weight, atom_edge_weight, ar_edge_weight], dim=0) \
if torch.is_tensor(rec_edge_weight) else torch.ones((len(edge_index[0]), 1), device=edge_index.device)
s1, s2, s3 = len(rec_edge_index[0]), len(rec_edge_index[0]) + len(ar_edge_index[0]), len(rec_edge_index[0]) + len(ar_edge_index[0]) + len(atom_edge_index[0])
for l in range(len(self.rec_emb_layers)):
edge_attr_ = torch.cat(
[edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:]]
node_attr = self.rec_emb_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
data['receptor'].rec_node_attr = node_attr[:len(rec_node_attr)]
data['receptor', 'receptor'].rec_edge_attr = rec_edge_attr
data['receptor', 'receptor'].edge_sh = rec_edge_sh
data['receptor', 'receptor'].edge_weight = rec_edge_weight
data['atom'].atom_node_attr = node_attr[len(rec_node_attr):]
data['atom', 'atom'].atom_edge_attr = atom_edge_attr
data['atom', 'atom'].edge_sh = atom_edge_sh
data['atom', 'atom'].edge_weight = atom_edge_weight
data['atom', 'receptor'].edge_attr = ar_edge_attr
data['atom', 'receptor'].edge_sh = ar_edge_sh
data['atom', 'receptor'].edge_weight = ar_edge_weight
# receptor embedding
rec_sigma_emb = self.rec_sigma_embedding(self.timestep_emb_func(data.complex_t['tr']))
rec_node_attr = data['receptor'].rec_node_attr + 0
rec_node_attr[:, :self.ns] = rec_node_attr[:, :self.ns] + rec_sigma_emb[data['receptor'].batch]
rec_edge_attr = data['receptor', 'receptor'].rec_edge_attr + rec_sigma_emb[data['receptor'].batch[data['receptor', 'receptor'].edge_index[0]]]
# atom embedding
atom_node_attr = data['atom'].atom_node_attr + 0
atom_node_attr[:, :self.ns] = atom_node_attr[:, :self.ns] + rec_sigma_emb[data['atom'].batch]
atom_edge_attr = data['atom', 'atom'].atom_edge_attr + rec_sigma_emb[data['atom'].batch[data['atom', 'atom'].edge_index[0]]]
# atom-receptor embedding
ar_edge_attr = data['atom', 'receptor'].edge_attr + rec_sigma_emb[data['atom'].batch[data['atom', 'receptor'].edge_index[0]]]
# ligand embedding
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
lig_node_attr = self.lig_node_embedding(lig_node_attr)
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
if self.embed_also_ligand:
for l in range(len(self.lig_emb_layers)):
edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns], lig_node_attr[lig_edge_index[1], :self.ns]], -1)
lig_node_attr = self.lig_emb_layers[l](lig_node_attr, lig_edge_index, edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
else:
lig_node_attr = F.pad(lig_node_attr, (0, rec_node_attr.shape[-1] - lig_node_attr.shape[-1]))
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, \
rec_node_attr, data['receptor', 'receptor'].edge_index, rec_edge_attr, data['receptor', 'receptor'].edge_sh, data['receptor', 'receptor'].edge_weight, \
atom_node_attr, data['atom', 'atom'].edge_index, atom_edge_attr, data['atom', 'atom'].edge_sh, data['atom', 'atom'].edge_weight, \
data['atom', 'receptor'].edge_index, ar_edge_attr, data['atom', 'receptor'].edge_sh, data['atom', 'receptor'].edge_weight
def forward(self, data):
if self.crop_beyond is not None:
# TODO missing filtering atoms
raise NotImplementedError
ligand_pos = data['ligand'].pos
receptor_pos = data['receptor'].pos
residues_to_keep = torch.any(torch.sum((ligand_pos.unsqueeze(0) - receptor_pos.unsqueeze(1)) ** 2, -1) < self.crop_beyond ** 2, dim=1)
data['receptor'].pos = data['receptor'].pos[residues_to_keep]
data['receptor'].x = data['receptor'].x[residues_to_keep]
data['receptor'].side_chain_vecs = data['receptor'].side_chain_vecs[residues_to_keep]
data['receptor', 'rec_contact', 'receptor'].edge_index = subgraph(residues_to_keep, data['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]
if self.no_aminoacid_identities:
data['receptor'].x = data['receptor'].x * 0
if not self.confidence_mode:
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
else:
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, rec_node_attr, \
rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight,\
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight, \
ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = self.embedding(data)
# build lig cross graph
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1) if self.dynamic_max_cross else self.cross_max_distance
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight = self.build_cross_lig_conv_graph(data, cross_cutoff)
lr_edge_attr = self.lr_edge_embedding(lr_edge_attr)
la_edge_attr = self.la_edge_embedding(la_edge_attr)
n_lig, n_rec = len(lig_node_attr), len(rec_node_attr)
node_attr = torch.cat([lig_node_attr, rec_node_attr, atom_node_attr], dim=0)
rec_edge_index, atom_edge_index, lr_edge_index, la_edge_index, ar_edge_index = rec_edge_index.clone(), atom_edge_index.clone(), lr_edge_index.clone(), la_edge_index.clone(), ar_edge_index.clone()
rec_edge_index[0], rec_edge_index[1] = rec_edge_index[0] + n_lig, rec_edge_index[1] + n_lig
atom_edge_index[0], atom_edge_index[1] = atom_edge_index[0] + n_lig + n_rec, atom_edge_index[1] + n_lig + n_rec
lr_edge_index[1] = lr_edge_index[1] + n_lig
la_edge_index[1] = la_edge_index[1] + n_lig + n_rec
ar_edge_index[0], ar_edge_index[1] = ar_edge_index[0] + n_lig + n_rec, ar_edge_index[1] + n_lig
edge_index = torch.cat([lig_edge_index, lr_edge_index, la_edge_index, rec_edge_index,
torch.flip(lr_edge_index, dims=[0]), torch.flip(ar_edge_index, dims=[0]),
atom_edge_index, torch.flip(la_edge_index, dims=[0]), ar_edge_index], dim=1)
edge_attr = torch.cat([lig_edge_attr, lr_edge_attr, la_edge_attr, rec_edge_attr, lr_edge_attr,
ar_edge_attr, atom_edge_attr, la_edge_attr, ar_edge_attr], dim=0)
edge_sh = torch.cat([lig_edge_sh, lr_edge_sh, la_edge_sh, rec_edge_sh, lr_edge_sh, ar_edge_sh,
atom_edge_sh, la_edge_sh, ar_edge_sh], dim=0)
edge_weight = torch.cat([lig_edge_weight, lr_edge_weight, la_edge_weight, rec_edge_weight, lr_edge_weight,
ar_edge_weight, atom_edge_weight, la_edge_weight, ar_edge_weight],
dim=0) if torch.is_tensor(lig_edge_weight) else torch.ones((len(edge_index[0]), 1),
device=edge_index.device)
s1, s2, s3, s4, s5, s6, s7, s8, _ = tuple(np.cumsum(list(map(len, [lig_edge_attr, lr_edge_attr, la_edge_attr,
rec_edge_attr, lr_edge_attr, ar_edge_attr, atom_edge_attr, la_edge_attr, ar_edge_attr]))).tolist())
for l in range(len(self.conv_layers)):
if l < len(self.conv_layers) - 1:
edge_attr_ = torch.cat([edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:s4],
edge_attr_[s4:s5], edge_attr_[s5:s6], edge_attr_[s6:s7], edge_attr_[s7:s8], edge_attr_[s8:]]
node_attr = self.conv_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
else:
edge_attr_ = torch.cat([edge_attr[:s3], node_attr[edge_index[0, :s3], :self.ns], node_attr[edge_index[1, :s3], :self.ns]], -1)
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3]]
node_attr = self.conv_layers[l](node_attr, edge_index[:, :s3], edge_attr_, edge_sh[:s3], edge_weight=edge_weight[:s3])
lig_node_attr = node_attr[:len(lig_node_attr)]
# confidence and affinity prediction
if self.confidence_mode:
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns], lig_node_attr[:,-(self.nv if self.reduce_pseudoscalars else self.ns):] ], dim=1) \
if self.num_conv_layers + self.num_prot_emb_layers >= 3 else lig_node_attr[:,:self.ns]
if self.atom_confidence:
scalar_lig_attr = self.atom_confidence_predictor(scalar_lig_attr)
atom_confidence = scalar_lig_attr[:, :self.atom_num_confidence_outputs]
scalar_lig_attr = scalar_lig_attr[:, self.atom_num_confidence_outputs:]
else:
atom_confidence = torch.zeros((len(lig_node_attr),), device=lig_node_attr.device)
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
if self.parallel > 1:
confidence, affinity = confidence[:, 0], confidence[:, 1:]
confidence = confidence.reshape(data.num_graphs, self.parallel)
affinity = affinity.reshape(data.num_graphs, self.parallel, -1)
affinity = torch.cat([AGGREGATORS[agg](affinity) for agg in self.parallel_aggregators], dim=-1)
affinity = self.affinity_predictor(affinity).squeeze(dim=-1)
confidence = confidence, affinity
return confidence, atom_confidence
assert self.parallel == 1
# compute translational and rotational score vectors
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
center_edge_attr = self.center_edge_embedding(center_edge_attr)
if self.fixed_center_conv:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
else:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
if self.separate_noise_schedule:
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
elif self.asyncronous_noise_schedule:
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
else: # tr rot and tor noise is all the same in this case
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# adjust the magnitude of the score vectors
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
if self.scale_by_sigma:
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0,device=self.device), None
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
if self.scale_by_sigma:
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
.to(data['ligand'].x.device))
return tr_pred, rot_pred, tor_pred, None
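The `forward` method above stacks ligand, receptor, and atom nodes into one node tensor, shifts each edge index by the number of nodes that precede its node type, and then uses cumulative edge counts (`s1`, `s2`, ...) to slice the concatenated edge attributes back into per-type groups. Below is a toy sketch of that bookkeeping in pure Python with only three edge types; the edge lists and the `offset_edges` helper are made up for illustration.

```python
from itertools import accumulate


def offset_edges(edges, src_offset, dst_offset):
    """Shift source and destination node indices into the stacked node numbering."""
    return [(src + src_offset, dst + dst_offset) for src, dst in edges]


n_lig = 3  # toy ligand node count; receptor nodes are numbered from n_lig onward
lig_edges = [(0, 1), (1, 2)]  # ligand -> ligand, no shift needed
lr_edges = [(0, 0), (2, 1)]   # ligand -> receptor, only the destination is shifted
rec_edges = [(0, 1)]          # receptor -> receptor, both ends are shifted

groups = [
    lig_edges,
    offset_edges(lr_edges, 0, n_lig),
    offset_edges(rec_edges, n_lig, n_lig),
]
edge_index = [edge for group in groups for edge in group]

# Cumulative group sizes play the role of s1, s2, ... in the model.
bounds = list(accumulate(len(group) for group in groups))
starts = [0] + bounds[:-1]
slices = [edge_index[a:b] for a, b in zip(starts, bounds)]
```

Slicing with the cumulative bounds recovers each edge group from the concatenated list, which is what allows a separate set of convolution weights per edge type when `differentiate_convolutions` is set.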
def get_edge_weight(self, edge_vec, max_norm):
if self.smooth_edges:
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
return 1.0
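When `smooth_edges` is enabled, `get_edge_weight` applies a cosine cutoff: the weight is 1 at zero distance, falls smoothly to 0 at `max_norm`, and stays at 0 beyond it because the scaled distance is clipped at pi. A scalar sketch of the same formula using only the standard library (the function name `smooth_edge_weight` is chosen here for illustration):

```python
import math


def smooth_edge_weight(distance, max_norm):
    """Cosine cutoff: 1 at distance 0, 0 at and beyond max_norm."""
    scaled = min(distance * math.pi / max_norm, math.pi)
    return 0.5 * (math.cos(scaled) + 1.0)


weights = [smooth_edge_weight(d, 5.0) for d in (0.0, 2.5, 5.0, 8.0)]
```

At half the cutoff radius the weight is 0.5, so edges fade out gradually instead of switching off abruptly at the radius.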
def build_lig_conv_graph(self, data):
# build the graph between ligand atoms
if self.separate_noise_schedule:
data['ligand'].node_sigma_emb = torch.cat(
[self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
dim=1)
elif self.asyncronous_noise_schedule:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
else:
data['ligand'].node_sigma_emb = self.timestep_emb_func(
data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
if self.parallel == 1:
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
else:
batches = torch.zeros(data.num_graphs, device=data['ligand'].x.device).long()
batches = batches.index_add(0, data['ligand'].batch, torch.ones(len(data['ligand'].batch), device=data['ligand'].x.device).long())
outer_batches = data.num_graphs
b = [torch.ones(batches[i].item()//self.parallel, device=data['ligand'].x.device).long() * (self.parallel * i + j)
for i in range(outer_batches) for j in range(self.parallel)]
data['ligand'].batch_parallel = torch.cat(b)
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch_parallel)
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
edge_attr = torch.cat([
data['ligand', 'ligand'].edge_attr,
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
], 0)
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
src, dst = edge_index
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_rec_conv_graph(self, data):
# build the graph between receptor residues
node_attr = data['receptor'].x
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['receptor', 'receptor'].edge_index
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
edge_attr = self.rec_distance_expansion(edge_vec.norm(dim=-1))
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
return node_attr, edge_attr, edge_sh, edge_weight
def build_atom_conv_graph(self, data):
# build the graph between receptor atoms
node_attr = data['atom'].x
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['atom', 'atom'].edge_index
src, dst = edge_index
edge_vec = data['atom'].pos[dst.long()] - data['atom'].pos[src.long()]
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_attr, edge_sh, edge_weight
def build_cross_lig_conv_graph(self, data, lr_cross_distance_cutoff):
# build the cross edges between ligand atoms and receptor residues + atoms
# LIGAND to RECEPTOR
if torch.is_tensor(lr_cross_distance_cutoff):
# different cutoff for every graph
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
# LIGAND to ATOM
la_edge_index = radius(data['atom'].pos, data['ligand'].pos, self.lig_max_radius,
data['atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
la_edge_vec = data['atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
la_edge_length_emb = self.lig_distance_expansion(la_edge_vec.norm(dim=-1))
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight
def build_cross_rec_conv_graph(self, data):
# build the cross edges between receptor atoms and receptor residues
# ATOM to RECEPTOR
ar_edge_index = data['atom', 'receptor'].edge_index
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['atom'].pos[ar_edge_index[0].long()]
ar_edge_attr = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
ar_edge_weight = 1
return ar_edge_attr, ar_edge_sh, ar_edge_weight
def build_center_conv_graph(self, data):
# build the filter for the convolution of the center with the ligand atoms
# for translational and rotational score
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return edge_index, edge_attr, edge_sh
def build_bond_conv_graph(self, data):
# build graph for the pseudotorque layer
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
bond_batch = data['ligand'].batch[bonds[0]]
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = self.final_edge_embedding(edge_attr)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return bonds, edge_index, edge_attr, edge_sh, edge_weight
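
The smooth edge weighting used by get_edge_weight above can be illustrated on scalar distances. This is a minimal sketch, assuming plain Python floats instead of torch tensors; the function name smooth_edge_weight and its smooth_edges argument are hypothetical stand-ins for the method and the self.smooth_edges flag, not part of the repository.

```python
import math


def smooth_edge_weight(distance, max_norm, smooth_edges=True):
    """Scalar sketch of a cosine cutoff weight (hypothetical helper).

    Mirrors the formula in get_edge_weight: the distance is rescaled so
    that max_norm maps to pi, clipped at pi, and passed through
    0.5 * (cos(x) + 1), giving a weight in [0, 1].
    """
    if not smooth_edges:
        # Without smoothing every edge gets the same constant weight.
        return 1.0
    scaled = min(distance * math.pi / max_norm, math.pi)
    return 0.5 * (math.cos(scaled) + 1.0)


if __name__ == "__main__":
    # Weights for a 5 Angstrom cutoff (the default lig_max_radius).
    for d in (0.0, 2.5, 5.0, 7.0):
        print(d, smooth_edge_weight(d, 5.0))
```

The weight is 1 for coincident nodes, one half at half the cutoff, and 0 at or beyond the cutoff, so edges fade out instead of appearing or disappearing abruptly when atoms cross the radius threshold.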

640
models/cg_model.py Normal file

@@ -0,0 +1,640 @@
import math
from e3nn import o3
import torch
from e3nn.o3 import Linear
from esm.pretrained import load_model_and_alphabet
from torch import nn
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_scatter import scatter, scatter_mean
import numpy as np
from models.layers import GaussianSmearing, AtomEncoder
from models.tensor_layers import TensorProductConvLayer, get_irrep_seq
from utils import so3, torus
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
class CGModel(torch.nn.Module):
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
separate_noise_schedule=False, lm_embedding_type=None, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm=False,
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
parallel_aggregators="mean max min std", num_confidence_outputs=1, atom_num_confidence_outputs=1, fixed_center_conv=False,
no_aminoacid_identities=False, include_miscellaneous_atoms=False,
differentiate_convolutions=True, tp_weights_layers=2, num_prot_emb_layers=0, reduce_pseudoscalars=False,
embed_also_ligand=False, atom_confidence=False, sidechain_pred=False, depthwise_convolution=False):
super(CGModel, self).__init__()
assert parallel == 1, "not implemented"
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
self.t_to_sigma = t_to_sigma
self.in_lig_edge_features = in_lig_edge_features
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
self.sigma_embed_dim = sigma_embed_dim
self.lig_max_radius = lig_max_radius
self.rec_max_radius = rec_max_radius
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.cross_max_distance = cross_max_distance
self.dynamic_max_cross = dynamic_max_cross
self.center_max_distance = center_max_distance
self.distance_embed_dim = distance_embed_dim
self.cross_distance_embed_dim = cross_distance_embed_dim
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
self.ns, self.nv = ns, nv
self.scale_by_sigma = scale_by_sigma
self.norm_by_sigma = norm_by_sigma
self.device = device
self.no_torsion = no_torsion
self.smooth_edges = smooth_edges
self.odd_parity = odd_parity
self.timestep_emb_func = timestep_emb_func
self.separate_noise_schedule = separate_noise_schedule
self.confidence_mode = confidence_mode
self.num_conv_layers = num_conv_layers
self.num_prot_emb_layers = num_prot_emb_layers
self.asyncronous_noise_schedule = asyncronous_noise_schedule
self.affinity_prediction = affinity_prediction
self.fixed_center_conv = fixed_center_conv
self.no_aminoacid_identities = no_aminoacid_identities
self.differentiate_convolutions = differentiate_convolutions
self.reduce_pseudoscalars = reduce_pseudoscalars
self.atom_confidence = atom_confidence
self.atom_num_confidence_outputs = atom_num_confidence_outputs
self.sidechain_pred = sidechain_pred
self.lm_embedding_type = lm_embedding_type
if lm_embedding_type is None:
lm_embedding_dim = 0
elif lm_embedding_type == "precomputed":
lm_embedding_dim = 1280
else:
lm, alphabet = load_model_and_alphabet(lm_embedding_type)
self.batch_converter = alphabet.get_batch_converter()
lm.lm_head = torch.nn.Identity()
lm.contact_head = torch.nn.Identity()
lm_embedding_dim = lm.embed_dim
self.lm = lm
atom_encoder_class = AtomEncoder
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=0, lm_embedding_dim=lm_embedding_dim)
self.rec_edge_embedding = nn.Sequential(nn.Linear(distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
self.rec_sigma_embedding = nn.Sequential(nn.Linear(sigma_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout), nn.Linear(ns, ns))
if self.include_miscellaneous_atoms:
self.misc_atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.misc_atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.ar_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
irrep_seq = get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars)
assert not self.include_miscellaneous_atoms, "currently not supported"
rec_emb_layers = []
for i in range(num_prot_emb_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1,
depthwise=depthwise_convolution
)
rec_emb_layers.append(layer)
self.rec_emb_layers = nn.ModuleList(rec_emb_layers)
self.embed_also_ligand = embed_also_ligand
if embed_also_ligand:
lig_emb_layers = []
for i in range(num_prot_emb_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1,
depthwise=depthwise_convolution
)
lig_emb_layers.append(layer)
self.lig_emb_layers = nn.ModuleList(lig_emb_layers)
conv_layers = []
for i in range(num_prot_emb_layers, num_prot_emb_layers + num_conv_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
layer = TensorProductConvLayer(
in_irreps=in_irreps,
sh_irreps=self.sh_irreps,
out_irreps=out_irreps,
n_edge_features=3 * ns,
hidden_features=3 * ns,
residual=True,
batch_norm=batch_norm,
dropout=dropout,
faster=sh_lmax == 1 and not use_second_order_repr,
tp_weights_layers=tp_weights_layers,
edge_groups=1 if not differentiate_convolutions else (2 if i == num_prot_emb_layers + num_conv_layers - 1 else 4),
depthwise=depthwise_convolution
)
conv_layers.append(layer)
self.conv_layers = nn.ModuleList(conv_layers)
if sidechain_pred:
self.sidechain_predictor = Linear(
irreps_in=irrep_seq[min(num_prot_emb_layers + num_conv_layers, len(irrep_seq) - 1)],
irreps_out='4x0e + 2x1e + 4x0o + 2x1o',
internal_weights=True,
shared_weights=True,
)
if self.confidence_mode:
input_size = ns + (nv if reduce_pseudoscalars else ns) if num_conv_layers + num_prot_emb_layers >= 3 else ns
if self.atom_confidence:
self.atom_confidence_predictor = nn.Sequential(
nn.Linear(input_size, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, atom_num_confidence_outputs + ns)
)
input_size = ns
self.confidence_predictor = nn.Sequential(
nn.Linear(input_size, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, num_confidence_outputs + (1 if self.affinity_prediction else 0))
)
else:
# center of mass translation and rotation components
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
self.center_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_conv = TensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.sh_irreps,
out_irreps='2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
n_edge_features=2 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
if not no_torsion:
# torsion angles components
self.final_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
self.tor_bond_conv = TensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.final_tp_tor.irreps_out,
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
n_edge_features=3 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tor_final_layer = nn.Sequential(
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(ns, 1, bias=False)
)
def ligand_embedding(self, data):
# ligand embedding
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
lig_node_attr = self.lig_node_embedding(lig_node_attr)
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
assert self.embed_also_ligand, "otherwise reimplement padding"
for l in range(len(self.lig_emb_layers)):
edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns],
lig_node_attr[lig_edge_index[1], :self.ns]], -1)
lig_node_attr = self.lig_emb_layers[l](lig_node_attr, lig_edge_index, edge_attr_, lig_edge_sh,
edge_weight=lig_edge_weight)
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight
def embedding(self, data):
if not hasattr(data['receptor'], "rec_node_attr"):
if self.lm_embedding_type not in [None, 'precomputed']:
sequences = [s for l in data['receptor'].sequence for s in l]
if isinstance(sequences[0], list):
sequences = [s for l in sequences for s in l]
sequences = [(i, s) for i, s in enumerate(sequences)]
batch_labels, batch_strs, batch_tokens = self.batch_converter(sequences)
out = self.lm(batch_tokens.to(data['receptor'].x.device), repr_layers=[self.lm.num_layers], return_contacts=False)
rec_lm_emb = torch.cat([t[:len(sequences[i][1])] for i, t in enumerate(out['representations'][self.lm.num_layers])], dim=0)
data['receptor'].x = torch.cat([data['receptor'].x, rec_lm_emb], dim=-1)
rec_node_attr, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
rec_node_attr = self.rec_node_embedding(rec_node_attr)
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
for l in range(len(self.rec_emb_layers)):
edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[data['receptor', 'receptor'].edge_index[0], :self.ns], rec_node_attr[data['receptor', 'receptor'].edge_index[1], :self.ns]], -1)
rec_node_attr = self.rec_emb_layers[l](rec_node_attr, data['receptor', 'receptor'].edge_index, edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
data['receptor'].rec_node_attr = rec_node_attr
data['receptor', 'receptor'].rec_edge_attr = rec_edge_attr
data['receptor', 'receptor'].edge_sh = rec_edge_sh
data['receptor', 'receptor'].edge_weight = rec_edge_weight
# receptor embedding
rec_sigma_emb = self.rec_sigma_embedding(self.timestep_emb_func(data.complex_t['tr']))
rec_node_attr = data['receptor'].rec_node_attr + 0
rec_node_attr[:, :self.ns] = rec_node_attr[:, :self.ns] + rec_sigma_emb[data['receptor'].batch]
rec_edge_attr = data['receptor', 'receptor'].rec_edge_attr + rec_sigma_emb[data['receptor'].batch[data['receptor', 'receptor'].edge_index[0]]]
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.ligand_embedding(data)
return lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, \
rec_node_attr, data['receptor', 'receptor'].edge_index, rec_edge_attr, data['receptor', 'receptor'].edge_sh, data['receptor', 'receptor'].edge_weight
def forward(self, data):
if self.no_aminoacid_identities:
data['receptor'].x = data['receptor'].x * 0
if not self.confidence_mode:
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
else:
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight, rec_node_attr, \
rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.embedding(data)
# build cross graph
if self.dynamic_max_cross:
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
else:
cross_cutoff = self.cross_max_distance
lr_edge_index, lr_edge_attr, lr_edge_sh, rev_lr_edge_sh, lr_edge_weight = self.build_cross_conv_graph(data, cross_cutoff)
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
node_attr = torch.cat([lig_node_attr, rec_node_attr], dim=0)
lr_edge_index[1] = lr_edge_index[1] + len(lig_node_attr)
edge_index = torch.cat([lig_edge_index, lr_edge_index, rec_edge_index + len(lig_node_attr),
torch.flip(lr_edge_index, dims=[0])], dim=1)
edge_attr = torch.cat([lig_edge_attr, lr_edge_attr, rec_edge_attr, lr_edge_attr], dim=0)
edge_sh = torch.cat([lig_edge_sh, lr_edge_sh, rec_edge_sh, rev_lr_edge_sh], dim=0)
edge_weight = torch.cat([lig_edge_weight, lr_edge_weight, rec_edge_weight, lr_edge_weight],
dim=0) if torch.is_tensor(lig_edge_weight) else torch.ones((len(edge_index[0]), 1),
device=edge_index.device)
s1, s2, s3 = len(lig_edge_index[0]), len(lig_edge_index[0]) + len(lr_edge_index[0]), len(lig_edge_index[0]) + len(lr_edge_index[0]) + len(rec_edge_index[0])
for l in range(len(self.conv_layers)):
if l < len(self.conv_layers) - 1:
edge_attr_ = torch.cat(
[edge_attr, node_attr[edge_index[0], :self.ns], node_attr[edge_index[1], :self.ns]], -1)
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2], edge_attr_[s2:s3], edge_attr_[s3:]]
node_attr = self.conv_layers[l](node_attr, edge_index, edge_attr_, edge_sh, edge_weight=edge_weight)
else:
edge_attr_ = torch.cat([edge_attr[:s2], node_attr[edge_index[0, :s2], :self.ns], node_attr[edge_index[1, :s2], :self.ns]], -1)
if self.differentiate_convolutions: edge_attr_ = [edge_attr_[:s1], edge_attr_[s1:s2]]
node_attr = self.conv_layers[l](node_attr, edge_index[:, :s2], edge_attr_, edge_sh[:s2], edge_weight=edge_weight[:s2])
lig_node_attr = node_attr[:len(lig_node_attr)]
# compute confidence score
if self.confidence_mode:
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns], lig_node_attr[:,-(self.nv if self.reduce_pseudoscalars else self.ns):] ], dim=1) \
if self.num_conv_layers + self.num_prot_emb_layers >= 3 else lig_node_attr[:,:self.ns]
if self.atom_confidence:
scalar_lig_attr = self.atom_confidence_predictor(scalar_lig_attr)
atom_confidence = scalar_lig_attr[:, :self.atom_num_confidence_outputs]
scalar_lig_attr = scalar_lig_attr[:, self.atom_num_confidence_outputs:]
else:
atom_confidence = torch.zeros((len(lig_node_attr),), device=lig_node_attr.device)
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
return confidence, atom_confidence
# compute translational and rotational score vectors
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
center_edge_attr = self.center_edge_embedding(center_edge_attr)
if self.fixed_center_conv:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
else:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
if self.separate_noise_schedule:
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
elif self.asyncronous_noise_schedule:
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
else: # tr rot and tor noise is all the same in this case
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# fix the magnitude of translational and rotational score vectors
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
if self.scale_by_sigma:
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
# predict sidechain orientation
sidechain_pred = None
if self.sidechain_pred:
rec_node_attr = node_attr[len(lig_node_attr):]
sidechain_pred = self.sidechain_predictor(rec_node_attr)
sidechain_pred = sidechain_pred[:, :10] + sidechain_pred[:, 10:] # sum even and odd components
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device), sidechain_pred
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
if self.scale_by_sigma:
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
.to(data['ligand'].x.device))
return tr_pred, rot_pred, tor_pred, sidechain_pred
def torsional_forward(self, data):
tor_sigma = self.t_to_sigma(data.complex_t['tor'])
# build ligand graph
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.ligand_embedding(data)
if self.separate_noise_schedule:
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
elif self.asyncronous_noise_schedule:
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
else: # tr rot and tor noise is all the same in this case
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
if self.scale_by_sigma:
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
.to(data['ligand'].x.device))
return 0, 0, tor_pred, 0
def get_edge_weight(self, edge_vec, max_norm):
# computes edge weights that decay with distance (cosine cutoff from 1 at distance 0 to 0 at max_norm)
# this only has an effect if self.smooth_edges is True; otherwise every edge gets weight 1.0
if self.smooth_edges:
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
return 1.0
def build_lig_conv_graph(self, data):
# builds the ligand graph edges and initial node and edge features
if self.separate_noise_schedule:
data['ligand'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
elif self.asyncronous_noise_schedule:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
else:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
# compute edges
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
edge_attr = torch.cat([
data['ligand', 'ligand'].edge_attr,
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
], 0)
# compute initial features
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
src, dst = edge_index
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_rec_conv_graph(self, data):
# builds the receptor initial node and edge embeddings
assert not self.separate_noise_schedule or self.asyncronous_noise_schedule, "removed support in this function"
node_attr = data['receptor'].x
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['receptor', 'receptor'].edge_index
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = edge_length_emb
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
return node_attr, edge_attr, edge_sh, edge_weight
def build_misc_atom_conv_graph(self, data):
# build the graph between receptor misc_atoms
if self.separate_noise_schedule:
data['misc_atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['misc_atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
elif self.asyncronous_noise_schedule:
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['t'])
else:
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['misc_atom'].x, data['misc_atom'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['misc_atom', 'misc_atom'].edge_index
src, dst = edge_index
edge_vec = data['misc_atom'].pos[dst.long()] - data['misc_atom'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['misc_atom'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_cross_conv_graph(self, data, cross_distance_cutoff):
# builds the cross edges between ligand and receptor
if torch.is_tensor(cross_distance_cutoff):
# different cutoff for every graph (depends on the diffusion time)
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
rev_edge_sh = o3.spherical_harmonics(self.sh_irreps, -edge_vec, normalize=True, normalization='component')
cutoff_d = cross_distance_cutoff[data['ligand'].batch[src]].squeeze() if torch.is_tensor(cross_distance_cutoff) else cross_distance_cutoff
edge_weight = self.get_edge_weight(edge_vec, cutoff_d)
return edge_index, edge_attr, edge_sh, rev_edge_sh, edge_weight
def build_misc_cross_conv_graph(self, data, lr_cross_distance_cutoff):
# build the cross edges between ligand atoms, receptor residues and receptor atoms
# LIGAND to RECEPTOR
if torch.is_tensor(lr_cross_distance_cutoff):
# different cutoff for every graph
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
# LIGAND to ATOM
la_edge_index = radius(data['misc_atom'].pos, data['ligand'].pos, self.lig_max_radius,
data['misc_atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
la_edge_vec = data['misc_atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
la_edge_length_emb = self.cross_distance_expansion(la_edge_vec.norm(dim=-1))
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
# ATOM to RECEPTOR
ar_edge_index = data['misc_atom', 'receptor'].edge_index
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['misc_atom'].pos[ar_edge_index[0].long()]
ar_edge_length_emb = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
ar_edge_sigma_emb = data['misc_atom'].node_sigma_emb[ar_edge_index[0].long()]
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
ar_edge_weight = 1
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
def build_center_conv_graph(self, data):
# builds the filter and edges for the convolution generating translational and rotational scores
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return edge_index, edge_attr, edge_sh
def build_bond_conv_graph(self, data):
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
bond_batch = data['ligand'].batch[bonds[0]]
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = self.final_edge_embedding(edge_attr)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return bonds, edge_index, edge_attr, edge_sh, edge_weight
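
The graph builders above all call `get_edge_weight(edge_vec, max_norm)`, which (per its definition elsewhere in this commit) returns a cosine cutoff when `smooth_edges` is enabled and a constant 1.0 otherwise. The following is a minimal plain-Python sketch of that cosine cutoff for a single scalar distance, not the PyTorch code from the repository; the function name `cosine_edge_weight` and its scalar signature are assumptions made for illustration.

```python
import math


def cosine_edge_weight(distance: float, max_radius: float) -> float:
    """Scalar analogue of the smooth edge weight (hypothetical helper name)."""
    # Map [0, max_radius] onto [0, pi]; clip so longer edges stay at pi.
    scaled = min(distance * math.pi / max_radius, math.pi)
    # 0.5 * (cos(x) + 1) decays smoothly from 1 at distance 0 to 0 at the cutoff.
    return 0.5 * (math.cos(scaled) + 1.0)


if __name__ == "__main__":
    # Distances are in the same unit as max_radius (5.0 mirrors lig_max_radius's default).
    for d in (0.0, 2.5, 5.0, 7.0):
        print(d, cosine_edge_weight(d, max_radius=5.0))
```

With this weighting, an edge contributes fully when its endpoints coincide and fades to zero as its length approaches the radius used to build the graph, so messages do not change abruptly when a neighbour crosses the cutoff.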

94
models/layers.py Normal file

@@ -0,0 +1,94 @@
import torch
from torch import nn
ACTIVATIONS = {
'relu': nn.ReLU,
'silu': nn.SiLU
}
def FCBlock(in_dim, hidden_dim, out_dim, layers, dropout, activation='relu'):
activation = ACTIVATIONS[activation]
assert layers >= 2
sequential = [nn.Linear(in_dim, hidden_dim), activation(), nn.Dropout(dropout)]
for i in range(layers - 2):
sequential += [nn.Linear(hidden_dim, hidden_dim), activation(), nn.Dropout(dropout)]
sequential += [nn.Linear(hidden_dim, out_dim)]
return nn.Sequential(*sequential)
class GaussianSmearing(torch.nn.Module):
# used to embed the edge distances
def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
super().__init__()
offset = torch.linspace(start, stop, num_gaussians)
self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
self.register_buffer('offset', offset)
def forward(self, dist):
dist = dist.view(-1, 1) - self.offset.view(1, -1)
return torch.exp(self.coeff * torch.pow(dist, 2))
class AtomEncoder(torch.nn.Module):
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_dim=0):
# first element of feature_dims tuple is a list with the length of each categorical feature and the second is the number of scalar features
super(AtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
self.num_categorical_features = len(feature_dims[0])
self.additional_features_dim = feature_dims[1] + sigma_embed_dim + lm_embedding_dim
for i, dim in enumerate(feature_dims[0]):
emb = torch.nn.Embedding(dim, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)
if self.additional_features_dim > 0:
self.additional_features_embedder = torch.nn.Linear(self.additional_features_dim + emb_dim, emb_dim)
def forward(self, x):
x_embedding = 0
assert x.shape[1] == self.num_categorical_features + self.additional_features_dim
for i in range(self.num_categorical_features):
x_embedding += self.atom_embedding_list[i](x[:, i].long())
if self.additional_features_dim > 0:
x_embedding = self.additional_features_embedder(torch.cat([x_embedding, x[:, self.num_categorical_features:]], axis=1))
return x_embedding
class OldAtomEncoder(torch.nn.Module):
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type= None):
# first element of feature_dims tuple is a list with the length of each categorical feature and the second is the number of scalar features
super(OldAtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
self.num_categorical_features = len(feature_dims[0])
self.num_scalar_features = feature_dims[1] + sigma_embed_dim
self.lm_embedding_type = lm_embedding_type
for i, dim in enumerate(feature_dims[0]):
emb = torch.nn.Embedding(dim, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)
if self.num_scalar_features > 0:
self.linear = torch.nn.Linear(self.num_scalar_features, emb_dim)
if self.lm_embedding_type is not None:
if self.lm_embedding_type == 'esm':
self.lm_embedding_dim = 1280
else: raise ValueError('LM Embedding type was not correctly determined. LM embedding type: ', self.lm_embedding_type)
self.lm_embedding_layer = torch.nn.Linear(self.lm_embedding_dim + emb_dim, emb_dim)
def forward(self, x):
x_embedding = 0
if self.lm_embedding_type is not None:
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features + self.lm_embedding_dim
else:
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features
for i in range(self.num_categorical_features):
x_embedding += self.atom_embedding_list[i](x[:, i].long())
if self.num_scalar_features > 0:
x_embedding += self.linear(x[:, self.num_categorical_features:self.num_categorical_features + self.num_scalar_features])
if self.lm_embedding_type is not None:
x_embedding = self.lm_embedding_layer(torch.cat([x_embedding, x[:, -self.lm_embedding_dim:]], axis=1))
return x_embedding
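
`GaussianSmearing` in this file embeds an edge distance as a vector of Gaussian responses centred on evenly spaced offsets, with the width tied to the offset spacing. The sketch below reproduces that computation for one scalar distance in plain Python, as an illustration only; the function name `gaussian_smearing` and the list-based return type are assumptions, and the real module operates on batched tensors.

```python
import math


def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    """Embed one distance as num_gaussians Gaussian responses (illustrative helper)."""
    # Evenly spaced centres, mirroring torch.linspace(start, stop, num_gaussians).
    step = (stop - start) / (num_gaussians - 1)
    offsets = [start + i * step for i in range(num_gaussians)]
    # Same coefficient as the module: -0.5 / spacing^2, i.e. sigma equals the spacing.
    coeff = -0.5 / step ** 2
    return [math.exp(coeff * (dist - mu) ** 2) for mu in offsets]


if __name__ == "__main__":
    embedding = gaussian_smearing(2.5, start=0.0, stop=5.0, num_gaussians=11)
    # The largest response sits at the centre closest to the input distance.
    print(max(range(len(embedding)), key=lambda i: embedding[i]))
```

Because the width equals the spacing between centres, neighbouring Gaussians overlap, so a small change in distance shifts weight gradually between adjacent embedding channels.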


@@ -3,24 +3,38 @@ import torch
from torch import nn
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_scatter import scatter_mean
from torch_scatter import scatter, scatter_mean
import numpy as np
from e3nn.nn import BatchNorm
from models.score_model import AtomEncoder, TensorProductConvLayer, GaussianSmearing
from models.layers import GaussianSmearing, OldAtomEncoder, AtomEncoder
from models.tensor_layers import OldTensorProductConvLayer
from utils import so3, torus
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
AGGREGATORS = {"mean": lambda x: torch.mean(x, dim=1),
"max": lambda x: torch.max(x, dim=1)[0],
"min": lambda x: torch.min(x, dim=1)[0],
"std": lambda x: torch.std(x, dim=1)}
class TensorProductScoreModel(torch.nn.Module):
class AAOldModel(torch.nn.Module):
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
scale_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, lm_embedding_type=False, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm=False, num_confidence_outputs=1):
super(TensorProductScoreModel, self).__init__()
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
separate_noise_schedule=False, lm_embedding_type=False, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm = False,
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
parallel_aggregators="mean max min std", num_confidence_outputs=1, fixed_center_conv=False,
no_aminoacid_identities=False, include_miscellaneous_atoms=False, use_old_atom_encoder=False):
super(AAOldModel, self).__init__()
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
if parallel > 1: assert affinity_prediction
self.t_to_sigma = t_to_sigma
self.in_lig_edge_features = in_lig_edge_features
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
self.sigma_embed_dim = sigma_embed_dim
self.lig_max_radius = lig_max_radius
self.rec_max_radius = rec_max_radius
@@ -32,21 +46,31 @@ class TensorProductScoreModel(torch.nn.Module):
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
self.ns, self.nv = ns, nv
self.scale_by_sigma = scale_by_sigma
self.norm_by_sigma = norm_by_sigma
self.device = device
self.no_torsion = no_torsion
self.smooth_edges = smooth_edges
self.odd_parity = odd_parity
self.num_conv_layers = num_conv_layers
self.timestep_emb_func = timestep_emb_func
self.separate_noise_schedule = separate_noise_schedule
self.confidence_mode = confidence_mode
self.num_conv_layers = num_conv_layers
self.asyncronous_noise_schedule = asyncronous_noise_schedule
self.affinity_prediction = affinity_prediction
self.parallel, self.parallel_aggregators = parallel, parallel_aggregators.split(' ')
self.fixed_center_conv = fixed_center_conv
self.no_aminoacid_identities = no_aminoacid_identities
# embedding layers
self.lig_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
atom_encoder_class = OldAtomEncoder if use_old_atom_encoder else AtomEncoder
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
self.rec_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.atom_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.lr_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
@@ -88,13 +112,19 @@ class TensorProductScoreModel(torch.nn.Module):
}
for _ in range(9): # 3 intra & 6 inter per each layer
conv_layers.append(TensorProductConvLayer(**parameters))
conv_layers.append(OldTensorProductConvLayer(**parameters))
self.conv_layers = nn.ModuleList(conv_layers)
# confidence and affinity prediction layers
if self.confidence_mode:
output_confidence_dim = num_confidence_outputs
if self.affinity_prediction:
if self.parallel > 1:
output_confidence_dim = 1 + ns
else:
output_confidence_dim = num_confidence_outputs +1
else:
output_confidence_dim = num_confidence_outputs
self.confidence_predictor = nn.Sequential(
nn.Linear(2 * self.ns if num_conv_layers >= 3 else self.ns, ns),
@@ -108,6 +138,19 @@ class TensorProductScoreModel(torch.nn.Module):
nn.Linear(ns, output_confidence_dim)
)
if self.parallel > 1:
self.affinity_predictor = nn.Sequential(
nn.Linear(len(self.parallel_aggregators) * ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, 1)
)
else:
# convolution for translational and rotational scores
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
@@ -118,18 +161,18 @@ class TensorProductScoreModel(torch.nn.Module):
nn.Linear(ns, ns)
)
self.final_conv = TensorProductConvLayer(
self.final_conv = OldTensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.sh_irreps,
out_irreps=f'2x1o + 2x1e',
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
n_edge_features=2 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns), nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns), nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
if not no_torsion:
# convolution for torsional score
@@ -140,47 +183,51 @@ class TensorProductScoreModel(torch.nn.Module):
nn.Linear(ns, ns)
)
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
self.tor_bond_conv = TensorProductConvLayer(
self.tor_bond_conv = OldTensorProductConvLayer(
in_irreps=self.conv_layers[-1].out_irreps,
sh_irreps=self.final_tp_tor.irreps_out,
out_irreps=f'{ns}x0o + {ns}x0e',
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
n_edge_features=3 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tor_final_layer = nn.Sequential(
nn.Linear(2 * ns, ns, bias=False),
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(ns, 1, bias=False)
)
def forward(self, data):
if self.no_aminoacid_identities:
data['receptor'].x = data['receptor'].x * 0
if not self.confidence_mode:
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
else:
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
# build ligand graph
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh = self.build_lig_conv_graph(data)
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
lig_node_attr = self.lig_node_embedding(lig_node_attr)
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
# build receptor graph
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh = self.build_rec_conv_graph(data)
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
rec_node_attr = self.rec_node_embedding(rec_node_attr)
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
# build atom graph
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh = self.build_atom_conv_graph(data)
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_atom_conv_graph(data)
atom_node_attr = self.atom_node_embedding(atom_node_attr)
atom_edge_attr = self.atom_edge_embedding(atom_edge_attr)
# build cross graph
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1) if self.dynamic_max_cross else self.cross_max_distance
lr_edge_index, lr_edge_attr, lr_edge_sh, la_edge_index, la_edge_attr, \
la_edge_sh, ar_edge_index, ar_edge_attr, ar_edge_sh = self.build_cross_conv_graph(data, cross_cutoff)
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = \
self.build_cross_conv_graph(data, cross_cutoff)
lr_edge_attr= self.lr_edge_embedding(lr_edge_attr)
la_edge_attr = self.la_edge_embedding(la_edge_attr)
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
@@ -188,47 +235,47 @@ class TensorProductScoreModel(torch.nn.Module):
for l in range(self.num_conv_layers):
# LIGAND updates
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_edge_index[0], :self.ns], lig_node_attr[lig_edge_index[1], :self.ns]], -1)
lig_update = self.conv_layers[9*l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh)
lig_update = self.conv_layers[9*l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
lr_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[lr_edge_index[0], :self.ns], rec_node_attr[lr_edge_index[1], :self.ns]], -1)
lr_update = self.conv_layers[9*l+1](rec_node_attr, lr_edge_index, lr_edge_attr_, lr_edge_sh,
out_nodes=lig_node_attr.shape[0])
out_nodes=lig_node_attr.shape[0], edge_weight=lr_edge_weight)
la_edge_attr_ = torch.cat([la_edge_attr, lig_node_attr[la_edge_index[0], :self.ns], atom_node_attr[la_edge_index[1], :self.ns]], -1)
la_update = self.conv_layers[9*l+2](atom_node_attr, la_edge_index, la_edge_attr_, la_edge_sh,
out_nodes=lig_node_attr.shape[0])
out_nodes=lig_node_attr.shape[0], edge_weight=la_edge_weight)
if l != self.num_conv_layers-1: # last layer optimisation
# ATOM UPDATES
atom_edge_attr_ = torch.cat([atom_edge_attr, atom_node_attr[atom_edge_index[0], :self.ns], atom_node_attr[atom_edge_index[1], :self.ns]], -1)
atom_update = self.conv_layers[9*l+3](atom_node_attr, atom_edge_index, atom_edge_attr_, atom_edge_sh)
atom_update = self.conv_layers[9*l+3](atom_node_attr, atom_edge_index, atom_edge_attr_, atom_edge_sh, edge_weight=atom_edge_weight)
al_edge_attr_ = torch.cat([la_edge_attr, atom_node_attr[la_edge_index[1], :self.ns], lig_node_attr[la_edge_index[0], :self.ns]], -1)
al_update = self.conv_layers[9*l+4](lig_node_attr, torch.flip(la_edge_index, dims=[0]), al_edge_attr_,
la_edge_sh, out_nodes=atom_node_attr.shape[0])
la_edge_sh, out_nodes=atom_node_attr.shape[0], edge_weight=la_edge_weight)
ar_edge_attr_ = torch.cat([ar_edge_attr, atom_node_attr[ar_edge_index[0], :self.ns], rec_node_attr[ar_edge_index[1], :self.ns]],-1)
ar_update = self.conv_layers[9*l+5](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh, out_nodes=atom_node_attr.shape[0])
ar_update = self.conv_layers[9*l+5](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh, out_nodes=atom_node_attr.shape[0], edge_weight=ar_edge_weight)
# RECEPTOR updates
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_edge_index[0], :self.ns], rec_node_attr[rec_edge_index[1], :self.ns]], -1)
rec_update = self.conv_layers[9*l+6](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh)
rec_update = self.conv_layers[9*l+6](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
rl_edge_attr_ = torch.cat([lr_edge_attr, rec_node_attr[lr_edge_index[1], :self.ns], lig_node_attr[lr_edge_index[0], :self.ns]], -1)
rl_update = self.conv_layers[9*l+7](lig_node_attr, torch.flip(lr_edge_index, dims=[0]), rl_edge_attr_,
lr_edge_sh, out_nodes=rec_node_attr.shape[0])
lr_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=lr_edge_weight)
ra_edge_attr_ = torch.cat([ar_edge_attr, rec_node_attr[ar_edge_index[1], :self.ns], atom_node_attr[ar_edge_index[0], :self.ns]], -1)
ra_update = self.conv_layers[9*l+8](atom_node_attr, torch.flip(ar_edge_index, dims=[0]), ra_edge_attr_,
ar_edge_sh, out_nodes=rec_node_attr.shape[0])
ar_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=ar_edge_weight)
# padding original features and update features with residual updates
lig_node_attr = F.pad(lig_node_attr, (0, lig_update.shape[-1] - lig_node_attr.shape[-1]))
lig_node_attr = lig_node_attr + lig_update + la_update + lr_update
if l != self.num_conv_layers - 1: # last layer optimisation
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - rec_node_attr.shape[-1]))
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - atom_node_attr.shape[-1]))
atom_node_attr = atom_node_attr + atom_update + al_update + ar_update
rec_node_attr = F.pad(rec_node_attr, (0, rec_update.shape[-1] - rec_node_attr.shape[-1]))
rec_node_attr = rec_node_attr + rec_update + ra_update + rl_update
@@ -236,18 +283,36 @@ class TensorProductScoreModel(torch.nn.Module):
# confidence and affinity prediction
if self.confidence_mode:
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:]], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch if self.parallel == 1 else data['ligand'].batch_parallel, dim=0)).squeeze(dim=-1)
if self.parallel > 1:
confidence, affinity = confidence[:, 0], confidence[:, 1:]
confidence = confidence.reshape(data.num_graphs, self.parallel)
affinity = affinity.reshape(data.num_graphs, self.parallel, -1)
affinity = torch.cat([AGGREGATORS[agg](affinity) for agg in self.parallel_aggregators], dim=-1)
affinity = self.affinity_predictor(affinity).squeeze(dim=-1)
confidence = confidence, affinity
return confidence
assert self.parallel == 1
# compute translational and rotational score vectors
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
center_edge_attr = self.center_edge_embedding(center_edge_attr)
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
if self.fixed_center_conv:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
else:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
tr_pred = global_pred[:, :3] + global_pred[:, 6:9]
rot_pred = global_pred[:, 3:6] + global_pred[:, 9:]
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
if self.separate_noise_schedule:
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
elif self.asyncronous_noise_schedule:
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
else: # tr rot and tor noise is all the same in this case
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# adjust the magnitude of the score vectors
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
@@ -263,7 +328,7 @@ class TensorProductScoreModel(torch.nn.Module):
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0,device=self.device)
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh = self.build_bond_conv_graph(data)
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
@@ -273,7 +338,7 @@ class TensorProductScoreModel(torch.nn.Module):
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean')
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
@@ -282,11 +347,34 @@ class TensorProductScoreModel(torch.nn.Module):
.to(data['ligand'].x.device))
return tr_pred, rot_pred, tor_pred
def get_edge_weight(self, edge_vec, max_norm):
if self.smooth_edges:
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
return 1.0
def build_lig_conv_graph(self, data):
# build the graph between ligand atoms
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr'])
if self.separate_noise_schedule:
data['ligand'].node_sigma_emb = torch.cat(
[self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
dim=1)
elif self.asyncronous_noise_schedule:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
else:
data['ligand'].node_sigma_emb = self.timestep_emb_func(
data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
if self.parallel == 1:
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
else:
batches = torch.zeros(data.num_graphs, device=data['ligand'].x.device).long()
batches = batches.index_add(0, data['ligand'].batch, torch.ones(len(data['ligand'].batch), device=data['ligand'].x.device).long())
outer_batches = data.num_graphs
b = [torch.ones(batches[i].item()//self.parallel, device=data['ligand'].x.device).long() * (self.parallel * i + j)
for i in range(outer_batches) for j in range(self.parallel)]
data['ligand'].batch_parallel = torch.cat(b)
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch_parallel)
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
edge_attr = torch.cat([
data['ligand', 'ligand'].edge_attr,
@@ -303,29 +391,45 @@ class TensorProductScoreModel(torch.nn.Module):
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_rec_conv_graph(self, data):
# build the graph between receptor residues
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr'])
if self.separate_noise_schedule:
data['receptor'].node_sigma_emb = torch.cat(
[self.timestep_emb_func(data['receptor'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],
dim=1)
elif self.asyncronous_noise_schedule:
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['t'])
else:
data['receptor'].node_sigma_emb = self.timestep_emb_func(
data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['receptor', 'receptor'].edge_index
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
#assert torch.all(edge_vec.norm(dim=-1) < self.rec_max_radius)
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
return node_attr, edge_index, edge_attr, edge_sh
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_atom_conv_graph(self, data):
# build the graph between receptor atoms
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['tr'])
if self.separate_noise_schedule:
data['atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
elif self.asyncronous_noise_schedule:
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['t'])
else:
data['atom'].node_sigma_emb = self.timestep_emb_func(data['atom'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['atom'].x, data['atom'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
@@ -337,8 +441,9 @@ class TensorProductScoreModel(torch.nn.Module):
edge_sigma_emb = data['atom'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_cross_conv_graph(self, data, lr_cross_distance_cutoff):
# build the cross edges between ligand atoms, receptor residues and receptor atoms
@@ -361,6 +466,7 @@ class TensorProductScoreModel(torch.nn.Module):
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
# LIGAND to ATOM
la_edge_index = radius(data['atom'].pos, data['ligand'].pos, self.lig_max_radius,
@@ -371,6 +477,7 @@ class TensorProductScoreModel(torch.nn.Module):
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
# ATOM to RECEPTOR
ar_edge_index = data['atom', 'receptor'].edge_index
@@ -379,9 +486,10 @@ class TensorProductScoreModel(torch.nn.Module):
ar_edge_sigma_emb = data['atom'].node_sigma_emb[ar_edge_index[0].long()]
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
ar_edge_weight = 1
return lr_edge_index, lr_edge_attr, lr_edge_sh, la_edge_index, la_edge_attr, \
la_edge_sh, ar_edge_index, ar_edge_attr, ar_edge_sh
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
def build_center_conv_graph(self, data):
# build the filter for the convolution of the center with the ligand atoms
@@ -411,5 +519,6 @@ class TensorProductScoreModel(torch.nn.Module):
edge_attr = self.final_edge_embedding(edge_attr)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return bonds, edge_index, edge_attr, edge_sh
return bonds, edge_index, edge_attr, edge_sh, edge_weight

538
models/old_cg_model.py Normal file

@@ -0,0 +1,538 @@
import math
from e3nn import o3
import torch
from torch import nn
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_scatter import scatter, scatter_mean
import numpy as np
from e3nn.nn import BatchNorm
from models.layers import OldAtomEncoder, AtomEncoder, GaussianSmearing
from models.tensor_layers import OldTensorProductConvLayer
from utils import so3, torus
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims, rec_atom_feature_dims
class CGOldModel(torch.nn.Module):
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
scale_by_sigma=True, norm_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, smooth_edges=False, odd_parity=False,
separate_noise_schedule=False, lm_embedding_type=None, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm=False,
asyncronous_noise_schedule=False, affinity_prediction=False, parallel=1,
parallel_aggregators="mean max min std", num_confidence_outputs=1, fixed_center_conv=False,
no_aminoacid_identities=False, include_miscellaneous_atoms=False, use_old_atom_encoder=False):
super(CGOldModel, self).__init__()
assert parallel == 1, "not implemented"
assert (not no_aminoacid_identities) or (lm_embedding_type is None), "no language model emb without identities"
self.t_to_sigma = t_to_sigma
self.in_lig_edge_features = in_lig_edge_features
sigma_embed_dim *= (3 if separate_noise_schedule else 1)
self.sigma_embed_dim = sigma_embed_dim
self.lig_max_radius = lig_max_radius
self.rec_max_radius = rec_max_radius
self.include_miscellaneous_atoms = include_miscellaneous_atoms
self.cross_max_distance = cross_max_distance
self.dynamic_max_cross = dynamic_max_cross
self.center_max_distance = center_max_distance
self.distance_embed_dim = distance_embed_dim
self.cross_distance_embed_dim = cross_distance_embed_dim
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
self.ns, self.nv = ns, nv
self.scale_by_sigma = scale_by_sigma
self.norm_by_sigma = norm_by_sigma
self.device = device
self.no_torsion = no_torsion
self.smooth_edges = smooth_edges
self.odd_parity = odd_parity
self.timestep_emb_func = timestep_emb_func
self.separate_noise_schedule = separate_noise_schedule
self.confidence_mode = confidence_mode
self.num_conv_layers = num_conv_layers
self.asyncronous_noise_schedule = asyncronous_noise_schedule
self.affinity_prediction = affinity_prediction
self.fixed_center_conv = fixed_center_conv
self.no_aminoacid_identities = no_aminoacid_identities
atom_encoder_class = OldAtomEncoder if use_old_atom_encoder else AtomEncoder
self.lig_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(),nn.Dropout(dropout),nn.Linear(ns, ns))
self.rec_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
if self.include_miscellaneous_atoms:
self.misc_atom_node_embedding = atom_encoder_class(emb_dim=ns, feature_dims=rec_atom_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.misc_atom_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.ar_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.la_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(),nn.Dropout(dropout), nn.Linear(ns, ns))
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
if use_second_order_repr:
irrep_seq = [
f'{ns}x0e',
f'{ns}x0e + {nv}x1o + {nv}x2e',
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {ns}x0o'
]
else:
irrep_seq = [
f'{ns}x0e',
f'{ns}x0e + {nv}x1o',
f'{ns}x0e + {nv}x1o + {nv}x1e',
f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o'
]
lig_conv_layers, rec_conv_layers, lig_to_rec_conv_layers, rec_to_lig_conv_layers = [], [], [], []
if self.include_miscellaneous_atoms:
misc_conv_layers, la_conv_layers, ra_conv_layers, al_conv_layers, ar_conv_layers = [], [], [], [], []
for i in range(num_conv_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
parameters = {
'in_irreps': in_irreps,
'sh_irreps': self.sh_irreps,
'out_irreps': out_irreps,
'n_edge_features': 3 * ns,
'hidden_features': 3 * ns,
'residual': False,
'batch_norm': batch_norm,
'dropout': dropout
}
lig_layer = OldTensorProductConvLayer(**parameters)
lig_conv_layers.append(lig_layer)
rec_layer = OldTensorProductConvLayer(**parameters)
rec_conv_layers.append(rec_layer)
lig_to_rec_layer = OldTensorProductConvLayer(**parameters)
lig_to_rec_conv_layers.append(lig_to_rec_layer)
rec_to_lig_layer = OldTensorProductConvLayer(**parameters)
rec_to_lig_conv_layers.append(rec_to_lig_layer)
if self.include_miscellaneous_atoms:
misc_conv_layer = OldTensorProductConvLayer(**parameters)
la_conv_layer = OldTensorProductConvLayer(**parameters)
ra_conv_layer = OldTensorProductConvLayer(**parameters)
al_conv_layer = OldTensorProductConvLayer(**parameters)
ar_conv_layer = OldTensorProductConvLayer(**parameters)
misc_conv_layers.append(misc_conv_layer)
la_conv_layers.append(la_conv_layer)
ra_conv_layers.append(ra_conv_layer)
al_conv_layers.append(al_conv_layer)
ar_conv_layers.append(ar_conv_layer)
self.lig_conv_layers = nn.ModuleList(lig_conv_layers)
self.rec_conv_layers = nn.ModuleList(rec_conv_layers)
self.lig_to_rec_conv_layers = nn.ModuleList(lig_to_rec_conv_layers)
self.rec_to_lig_conv_layers = nn.ModuleList(rec_to_lig_conv_layers)
if self.include_miscellaneous_atoms:
self.misc_conv_layers = nn.ModuleList(misc_conv_layers)
self.la_conv_layers = nn.ModuleList(la_conv_layers)
self.ra_conv_layers = nn.ModuleList(ra_conv_layers)
self.al_conv_layers = nn.ModuleList(al_conv_layers)
self.ar_conv_layers = nn.ModuleList(ar_conv_layers)
if self.confidence_mode:
self.confidence_predictor = nn.Sequential(
nn.Linear(2*self.ns if num_conv_layers >= 3 else self.ns,ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, 2 if self.affinity_prediction else 1)
)
else:
# center of mass translation and rotation components
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
self.center_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_conv = OldTensorProductConvLayer(
in_irreps=self.lig_conv_layers[-1].out_irreps,
sh_irreps=self.sh_irreps,
out_irreps=f'2x1o + 2x1e' if not self.odd_parity else '1x1o + 1x1e',
n_edge_features=2 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
if not no_torsion:
# torsion angles components
self.final_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
self.tor_bond_conv = OldTensorProductConvLayer(
in_irreps=self.lig_conv_layers[-1].out_irreps,
sh_irreps=self.final_tp_tor.irreps_out,
out_irreps=f'{ns}x0o + {ns}x0e' if not self.odd_parity else f'{ns}x0o',
n_edge_features=3 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tor_final_layer = nn.Sequential(
nn.Linear(2 * ns if not self.odd_parity else ns, ns, bias=False),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(ns, 1, bias=False)
)
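Note on the irreps schedule: the constructor above picks each layer's input and output irreps by clamping the layer index into `irrep_seq`. The sketch below restates that indexing in plain Python for the first-order case only. The function name `irreps_schedule` is hypothetical and not part of the repository.

```python
def irreps_schedule(ns, nv, num_conv_layers):
    """Return the (in_irreps, out_irreps) pair for every conv layer.

    Hypothetical helper: mirrors the min(i, len - 1) clamping in __init__
    for the branch where use_second_order_repr is False.
    """
    irrep_seq = [
        f'{ns}x0e',
        f'{ns}x0e + {nv}x1o',
        f'{ns}x0e + {nv}x1o + {nv}x1e',
        f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o',
    ]
    last = len(irrep_seq) - 1
    # Layers past the end of the sequence reuse its final entry.
    return [(irrep_seq[min(i, last)], irrep_seq[min(i + 1, last)])
            for i in range(num_conv_layers)]
```

With the default `num_conv_layers=2` the `{ns}x0o` scalars are never reached. They first appear in the output of the third layer, which seems to be why the confidence head concatenates the last `ns` channels only when `num_conv_layers >= 3`.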
def forward(self, data):
if self.no_aminoacid_identities:
data['receptor'].x = data['receptor'].x * 0
if not self.confidence_mode:
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
else:
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
# build ligand graph
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh, lig_edge_weight = self.build_lig_conv_graph(data)
lig_src, lig_dst = lig_edge_index
lig_node_attr = self.lig_node_embedding(lig_node_attr)
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
# build receptor graph
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh, rec_edge_weight = self.build_rec_conv_graph(data)
rec_src, rec_dst = rec_edge_index
rec_node_attr = self.rec_node_embedding(rec_node_attr)
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
if self.include_miscellaneous_atoms:
# build misc_atom graph
atom_node_attr, atom_edge_index, atom_edge_attr, atom_edge_sh, atom_edge_weight = self.build_misc_atom_conv_graph(data)
atom_node_attr = self.misc_atom_node_embedding(atom_node_attr)
atom_edge_attr = self.misc_atom_edge_embedding(atom_edge_attr)
# build cross graph
if self.dynamic_max_cross:
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
else:
cross_cutoff = self.cross_max_distance
if self.include_miscellaneous_atoms:
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight = \
self.build_misc_cross_conv_graph(data, cross_cutoff)
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
la_edge_attr = self.la_edge_embedding(la_edge_attr)
ar_edge_attr = self.ar_edge_embedding(ar_edge_attr)
cross_lig, cross_rec = lr_edge_index
else:
lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight = self.build_cross_conv_graph(data, cross_cutoff)
cross_lig, cross_rec = lr_edge_index
lr_edge_attr = self.cross_edge_embedding(lr_edge_attr)
for l in range(len(self.lig_conv_layers)):
# intra graph message passing
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_src, :self.ns], lig_node_attr[lig_dst, :self.ns]], -1)
lig_intra_update = self.lig_conv_layers[l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh, edge_weight=lig_edge_weight)
# inter graph message passing
rec_to_lig_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
lig_inter_update = self.rec_to_lig_conv_layers[l](rec_node_attr, lr_edge_index, rec_to_lig_edge_attr_, lr_edge_sh,
out_nodes=lig_node_attr.shape[0], edge_weight=lr_edge_weight)
if self.include_miscellaneous_atoms:
la_edge_attr_ = torch.cat([la_edge_attr, lig_node_attr[la_edge_index[0], :self.ns],atom_node_attr[la_edge_index[1], :self.ns]], -1)
la_update = self.la_conv_layers[l](atom_node_attr, la_edge_index, la_edge_attr_, la_edge_sh,out_nodes=lig_node_attr.shape[0], edge_weight=la_edge_weight)
if l != len(self.lig_conv_layers) - 1:
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_src, :self.ns], rec_node_attr[rec_dst, :self.ns]], -1)
rec_intra_update = self.rec_conv_layers[l](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh, edge_weight=rec_edge_weight)
lig_to_rec_edge_attr_ = torch.cat([lr_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
rl_update = self.lig_to_rec_conv_layers[l](lig_node_attr, torch.flip(lr_edge_index, dims=[0]),lig_to_rec_edge_attr_,lr_edge_sh, out_nodes=rec_node_attr.shape[0],edge_weight=lr_edge_weight)
if self.include_miscellaneous_atoms:
# ATOM UPDATES
atom_edge_attr_ = torch.cat([atom_edge_attr, atom_node_attr[atom_edge_index[0], :self.ns],atom_node_attr[atom_edge_index[1], :self.ns]], -1)
atom_update = self.misc_conv_layers[l](atom_node_attr, atom_edge_index, atom_edge_attr_,atom_edge_sh, edge_weight=atom_edge_weight)
al_edge_attr_ = torch.cat([la_edge_attr, atom_node_attr[la_edge_index[1], :self.ns],lig_node_attr[la_edge_index[0], :self.ns]], -1)
al_update = self.al_conv_layers[l](lig_node_attr, torch.flip(la_edge_index, dims=[0]),al_edge_attr_,la_edge_sh, out_nodes=atom_node_attr.shape[0],edge_weight=la_edge_weight)
ar_edge_attr_ = torch.cat([ar_edge_attr, atom_node_attr[ar_edge_index[0], :self.ns],rec_node_attr[ar_edge_index[1], :self.ns]], -1)
ar_update = self.ar_conv_layers[l](rec_node_attr, ar_edge_index, ar_edge_attr_, ar_edge_sh,out_nodes=atom_node_attr.shape[0],edge_weight=ar_edge_weight)
ra_edge_attr_ = torch.cat([ar_edge_attr, rec_node_attr[ar_edge_index[1], :self.ns],atom_node_attr[ar_edge_index[0], :self.ns]], -1)
ra_update = self.ra_conv_layers[l](atom_node_attr, torch.flip(ar_edge_index, dims=[0]), ra_edge_attr_, ar_edge_sh, out_nodes=rec_node_attr.shape[0], edge_weight=ar_edge_weight)
# padding original features
lig_node_attr = F.pad(lig_node_attr, (0, lig_intra_update.shape[-1] - lig_node_attr.shape[-1]))
# update features with residual updates
lig_node_attr = lig_node_attr + lig_intra_update + lig_inter_update
if self.include_miscellaneous_atoms:
lig_node_attr += la_update
if l != len(self.lig_conv_layers) - 1:
rec_node_attr = F.pad(rec_node_attr, (0, rec_intra_update.shape[-1] - rec_node_attr.shape[-1]))
rec_node_attr = rec_node_attr + rec_intra_update + rl_update
if self.include_miscellaneous_atoms:
rec_node_attr += ra_update
atom_node_attr = F.pad(atom_node_attr, (0, atom_update.shape[-1] - atom_node_attr.shape[-1]))
atom_node_attr = atom_node_attr + atom_update + al_update + ar_update
# compute confidence score
if self.confidence_mode:
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:] ], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
return confidence
# compute translational and rotational score vectors
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
center_edge_attr = self.center_edge_embedding(center_edge_attr)
if self.fixed_center_conv:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
else:
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[0], :self.ns]], -1)
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
tr_pred = global_pred[:, :3] + (global_pred[:, 6:9] if not self.odd_parity else 0)
rot_pred = global_pred[:, 3:6] + (global_pred[:, 9:] if not self.odd_parity else 0)
if self.separate_noise_schedule:
data.graph_sigma_emb = torch.cat([self.timestep_emb_func(data.complex_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
elif self.asyncronous_noise_schedule:
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['t'])
else: # tr rot and tor noise is all the same in this case
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# fix the magnitude of translational and rotational score vectors
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
if self.scale_by_sigma:
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device)
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh, tor_edge_weight = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean', edge_weight=tor_edge_weight)
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
if self.scale_by_sigma:
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
.to(data['ligand'].x.device))
return tr_pred, rot_pred, tor_pred
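Note on the magnitude fix in `forward`: the translational and rotational predictions keep the direction produced by the final convolution, while their length is replaced by a scalar computed from the norm and the sigma embedding. The sketch below shows this for a single 3-vector in plain Python. `magnitude_fn` is a hypothetical stand-in for `tr_final_layer` / `rot_final_layer`, and the sigma embedding is left out for brevity.

```python
import math

def rescale_vector(vec, magnitude_fn):
    """Keep the direction of vec, replace its length by magnitude_fn(norm).

    Assumption: magnitude_fn stands in for the learned MLP. Like the
    original code, this does not guard against a zero-norm input.
    """
    norm = math.sqrt(sum(c * c for c in vec))
    scale = magnitude_fn(norm)
    return [c / norm * scale for c in vec]
```

Because the scalar is unconstrained, a negative value flips the direction of the vector.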
def get_edge_weight(self, edge_vec, max_norm):
# computes weights for edges that are decreasing with the distance
# it has an effect only if smooth edges is true
if self.smooth_edges:
normalised_norm = torch.clip(edge_vec.norm(dim=-1) * np.pi / max_norm, max=np.pi)
return 0.5 * (torch.cos(normalised_norm) + 1.0).unsqueeze(-1)
return 1.0
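Note on `get_edge_weight`: with `smooth_edges` enabled, the weight follows a cosine cutoff that falls from 1 at distance 0 to 0 at `max_norm` and stays at 0 beyond it. The sketch below is a scalar, torch-free restatement of that formula. The name `cosine_edge_weight` is hypothetical.

```python
import math

def cosine_edge_weight(distance, max_norm, smooth_edges=True):
    """Scalar restatement of the cosine cutoff in get_edge_weight."""
    if not smooth_edges:
        # Without smoothing, every edge contributes with full weight.
        return 1.0
    # Map the distance onto [0, pi]; the clamp keeps weights at 0 past max_norm.
    scaled = min(distance * math.pi / max_norm, math.pi)
    return 0.5 * (math.cos(scaled) + 1.0)
```

An edge at half the cutoff radius gets weight 0.5, so messages from distant neighbours fade out instead of disappearing abruptly when an atom crosses the radius.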
def build_lig_conv_graph(self, data):
# builds the ligand graph edges and initial node and edge features
if self.separate_noise_schedule:
data['ligand'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['ligand'].node_t[noise_type]) for noise_type in ['tr','rot','tor']], dim=1)
elif self.asyncronous_noise_schedule:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['t'])
else:
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr']) # tr rot and tor noise is all the same
# compute edges
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
edge_attr = torch.cat([
data['ligand', 'ligand'].edge_attr,
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
], 0)
# compute initial features
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
src, dst = edge_index
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_rec_conv_graph(self, data):
# builds the receptor initial node and edge embeddings
if self.separate_noise_schedule:
data['receptor'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['receptor'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']], dim=1)
elif self.asyncronous_noise_schedule:
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['t'])
else:
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['receptor', 'receptor'].edge_index
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.rec_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_misc_atom_conv_graph(self, data):
# build the graph between receptor misc_atoms
if self.separate_noise_schedule:
data['misc_atom'].node_sigma_emb = torch.cat([self.timestep_emb_func(data['misc_atom'].node_t[noise_type]) for noise_type in ['tr', 'rot', 'tor']],dim=1)
elif self.asyncronous_noise_schedule:
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['t'])
else:
data['misc_atom'].node_sigma_emb = self.timestep_emb_func(data['misc_atom'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['misc_atom'].x, data['misc_atom'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['misc_atom', 'misc_atom'].edge_index
src, dst = edge_index
edge_vec = data['misc_atom'].pos[dst.long()] - data['misc_atom'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['misc_atom'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return node_attr, edge_index, edge_attr, edge_sh, edge_weight
def build_cross_conv_graph(self, data, cross_distance_cutoff):
# builds the cross edges between ligand and receptor
if torch.is_tensor(cross_distance_cutoff):
# different cutoff for every graph (depends on the diffusion time)
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
cutoff_d = cross_distance_cutoff[data['ligand'].batch[src]].squeeze() if torch.is_tensor(cross_distance_cutoff) else cross_distance_cutoff
edge_weight = self.get_edge_weight(edge_vec, cutoff_d)
return edge_index, edge_attr, edge_sh, edge_weight
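Note on the per-graph cutoff: when `cross_distance_cutoff` is a tensor, the method above divides the positions by each graph's cutoff and then searches with radius 1. This works because scaling both points by the same positive factor scales their distance by that factor. The sketch below checks the equivalence for one pair of points in plain Python. The function names are hypothetical.

```python
import math

def within_cutoff_direct(x, y, cutoff):
    """Plain test: is the distance between x and y at most cutoff?"""
    return math.dist(x, y) <= cutoff

def within_cutoff_scaled(x, y, cutoff):
    """Same test via rescaling, as in the tensor branch above.

    Assumption: x and y belong to the same graph, so they share one cutoff.
    """
    scaled_x = [c / cutoff for c in x]
    scaled_y = [c / cutoff for c in y]
    return math.dist(scaled_x, scaled_y) <= 1.0
```

The rescaling lets a single fixed-radius search serve a batch in which every complex has its own cutoff.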
def build_misc_cross_conv_graph(self, data, lr_cross_distance_cutoff):
# build the cross edges between ligand atoms, receptor residues and receptor atoms
# LIGAND to RECEPTOR
if torch.is_tensor(lr_cross_distance_cutoff):
# different cutoff for every graph
lr_edge_index = radius(data['receptor'].pos / lr_cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / lr_cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
lr_edge_index = radius(data['receptor'].pos, data['ligand'].pos, lr_cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
lr_edge_vec = data['receptor'].pos[lr_edge_index[1].long()] - data['ligand'].pos[lr_edge_index[0].long()]
lr_edge_length_emb = self.cross_distance_expansion(lr_edge_vec.norm(dim=-1))
lr_edge_sigma_emb = data['ligand'].node_sigma_emb[lr_edge_index[0].long()]
lr_edge_attr = torch.cat([lr_edge_sigma_emb, lr_edge_length_emb], 1)
lr_edge_sh = o3.spherical_harmonics(self.sh_irreps, lr_edge_vec, normalize=True, normalization='component')
cutoff_d = lr_cross_distance_cutoff[data['ligand'].batch[lr_edge_index[0]]].squeeze() \
if torch.is_tensor(lr_cross_distance_cutoff) else lr_cross_distance_cutoff
lr_edge_weight = self.get_edge_weight(lr_edge_vec, cutoff_d)
# LIGAND to ATOM
la_edge_index = radius(data['misc_atom'].pos, data['ligand'].pos, self.lig_max_radius,
data['misc_atom'].batch, data['ligand'].batch, max_num_neighbors=10000)
la_edge_vec = data['misc_atom'].pos[la_edge_index[1].long()] - data['ligand'].pos[la_edge_index[0].long()]
la_edge_length_emb = self.cross_distance_expansion(la_edge_vec.norm(dim=-1))
la_edge_sigma_emb = data['ligand'].node_sigma_emb[la_edge_index[0].long()]
la_edge_attr = torch.cat([la_edge_sigma_emb, la_edge_length_emb], 1)
la_edge_sh = o3.spherical_harmonics(self.sh_irreps, la_edge_vec, normalize=True, normalization='component')
la_edge_weight = self.get_edge_weight(la_edge_vec, self.lig_max_radius)
# ATOM to RECEPTOR
ar_edge_index = data['misc_atom', 'receptor'].edge_index
ar_edge_vec = data['receptor'].pos[ar_edge_index[1].long()] - data['misc_atom'].pos[ar_edge_index[0].long()]
ar_edge_length_emb = self.rec_distance_expansion(ar_edge_vec.norm(dim=-1))
ar_edge_sigma_emb = data['misc_atom'].node_sigma_emb[ar_edge_index[0].long()]
ar_edge_attr = torch.cat([ar_edge_sigma_emb, ar_edge_length_emb], 1)
ar_edge_sh = o3.spherical_harmonics(self.sh_irreps, ar_edge_vec, normalize=True, normalization='component')
ar_edge_weight = 1
return lr_edge_index, lr_edge_attr, lr_edge_sh, lr_edge_weight, la_edge_index, la_edge_attr, \
la_edge_sh, la_edge_weight, ar_edge_index, ar_edge_attr, ar_edge_sh, ar_edge_weight
def build_center_conv_graph(self, data):
# builds the filter and edges for the convolution generating translational and rotational scores
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return edge_index, edge_attr, edge_sh
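Note on the ligand centers: `build_center_conv_graph` sums the atom positions of each graph with `index_add_` and divides by the per-graph atom count from `torch.bincount`. The sketch below performs the same accumulation with plain Python lists. The name `per_graph_centers` is hypothetical, and the sketch assumes that every graph contains at least one atom.

```python
def per_graph_centers(positions, batch, num_graphs):
    """Average the 3D positions of all atoms that share a batch index.

    Assumption: every graph index in range(num_graphs) occurs in batch,
    otherwise the division below fails.
    """
    sums = [[0.0, 0.0, 0.0] for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for pos, graph in zip(positions, batch):
        for k in range(3):
            sums[graph][k] += pos[k]
        counts[graph] += 1
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]
```

Each ligand atom is then connected to the center of its own graph, which gives the final convolution one output node per complex.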
def build_bond_conv_graph(self, data):
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
bond_batch = data['ligand'].batch[bonds[0]]
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = self.final_edge_embedding(edge_attr)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
edge_weight = self.get_edge_weight(edge_vec, self.lig_max_radius)
return bonds, edge_index, edge_attr, edge_sh, edge_weight


@@ -1,442 +0,0 @@
import math
from e3nn import o3
import torch
from torch import nn
from torch.nn import functional as F
from torch_cluster import radius, radius_graph
from torch_scatter import scatter, scatter_mean
import numpy as np
from e3nn.nn import BatchNorm
from utils import so3, torus
from datasets.process_mols import lig_feature_dims, rec_residue_feature_dims
class AtomEncoder(torch.nn.Module):
def __init__(self, emb_dim, feature_dims, sigma_embed_dim, lm_embedding_type= None):
# first element of feature_dims tuple is a list with the length of each categorical feature and the second is the number of scalar features
super(AtomEncoder, self).__init__()
self.atom_embedding_list = torch.nn.ModuleList()
self.num_categorical_features = len(feature_dims[0])
self.num_scalar_features = feature_dims[1] + sigma_embed_dim
self.lm_embedding_type = lm_embedding_type
for i, dim in enumerate(feature_dims[0]):
emb = torch.nn.Embedding(dim, emb_dim)
torch.nn.init.xavier_uniform_(emb.weight.data)
self.atom_embedding_list.append(emb)
if self.num_scalar_features > 0:
self.linear = torch.nn.Linear(self.num_scalar_features, emb_dim)
if self.lm_embedding_type is not None:
if self.lm_embedding_type == 'esm':
self.lm_embedding_dim = 1280
else: raise ValueError('LM Embedding type was not correctly determined. LM embedding type: ', self.lm_embedding_type)
self.lm_embedding_layer = torch.nn.Linear(self.lm_embedding_dim + emb_dim, emb_dim)
def forward(self, x):
x_embedding = 0
if self.lm_embedding_type is not None:
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features + self.lm_embedding_dim
else:
assert x.shape[1] == self.num_categorical_features + self.num_scalar_features
for i in range(self.num_categorical_features):
x_embedding += self.atom_embedding_list[i](x[:, i].long())
if self.num_scalar_features > 0:
x_embedding += self.linear(x[:, self.num_categorical_features:self.num_categorical_features + self.num_scalar_features])
if self.lm_embedding_type is not None:
x_embedding = self.lm_embedding_layer(torch.cat([x_embedding, x[:, -self.lm_embedding_dim:]], axis=1))
return x_embedding
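The assertions and slices in `AtomEncoder.forward` imply a fixed column layout for each row of `x`: categorical indices first, then scalar features (which include the sigma embedding), then the optional language-model embedding at the end. A minimal pure-Python sketch of that layout follows; `split_columns` and the toy sizes are hypothetical and only illustrate the slicing, not the embedding layers themselves.

```python
def split_columns(row, num_categorical, num_scalar, lm_dim=0):
    """Split one feature row the way AtomEncoder.forward slices it."""
    expected = num_categorical + num_scalar + lm_dim
    assert len(row) == expected, f"expected {expected} columns, got {len(row)}"
    # Categorical columns are cast to integers before the embedding lookup.
    categorical = [int(v) for v in row[:num_categorical]]
    # Scalar columns feed the linear layer.
    scalars = row[num_categorical:num_categorical + num_scalar]
    # The language-model embedding, when present, occupies the last lm_dim columns.
    lm = row[len(row) - lm_dim:] if lm_dim else []
    return categorical, scalars, lm


# Toy row: 3 categorical features, 2 scalar features, 2 LM dimensions.
row = [2, 0, 5, 0.1, 0.2, 9.0, 8.0]
cat, sc, lm = split_columns(row, num_categorical=3, num_scalar=2, lm_dim=2)
```

In the real encoder the LM width is 1280 for `'esm'`, so a mismatch in any of the three counts trips the shape assertion before any embedding is computed.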
class TensorProductConvLayer(torch.nn.Module):
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
hidden_features=None):
super(TensorProductConvLayer, self).__init__()
self.in_irreps = in_irreps
self.out_irreps = out_irreps
self.sh_irreps = sh_irreps
self.residual = residual
if hidden_features is None:
hidden_features = n_edge_features
self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
self.fc = nn.Sequential(
nn.Linear(n_edge_features, hidden_features),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_features, tp.weight_numel)
)
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean'):
edge_src, edge_dst = edge_index
tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr))
out_nodes = out_nodes or node_attr.shape[0]
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
if self.residual:
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
out = out + padded
if self.batch_norm:
out = self.batch_norm(out)
return out
class TensorProductScoreModel(torch.nn.Module):
def __init__(self, t_to_sigma, device, timestep_emb_func, in_lig_edge_features=4, sigma_embed_dim=32, sh_lmax=2,
ns=16, nv=4, num_conv_layers=2, lig_max_radius=5, rec_max_radius=30, cross_max_distance=250,
center_max_distance=30, distance_embed_dim=32, cross_distance_embed_dim=32, no_torsion=False,
scale_by_sigma=True, use_second_order_repr=False, batch_norm=True,
dynamic_max_cross=False, dropout=0.0, lm_embedding_type=None, confidence_mode=False,
confidence_dropout=0, confidence_no_batchnorm=False, num_confidence_outputs=1):
super(TensorProductScoreModel, self).__init__()
self.t_to_sigma = t_to_sigma
self.in_lig_edge_features = in_lig_edge_features
self.sigma_embed_dim = sigma_embed_dim
self.lig_max_radius = lig_max_radius
self.rec_max_radius = rec_max_radius
self.cross_max_distance = cross_max_distance
self.dynamic_max_cross = dynamic_max_cross
self.center_max_distance = center_max_distance
self.distance_embed_dim = distance_embed_dim
self.cross_distance_embed_dim = cross_distance_embed_dim
self.sh_irreps = o3.Irreps.spherical_harmonics(lmax=sh_lmax)
self.ns, self.nv = ns, nv
self.scale_by_sigma = scale_by_sigma
self.device = device
self.no_torsion = no_torsion
self.timestep_emb_func = timestep_emb_func
self.confidence_mode = confidence_mode
self.num_conv_layers = num_conv_layers
self.lig_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=lig_feature_dims, sigma_embed_dim=sigma_embed_dim)
self.lig_edge_embedding = nn.Sequential(nn.Linear(in_lig_edge_features + sigma_embed_dim + distance_embed_dim, ns),nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.rec_node_embedding = AtomEncoder(emb_dim=ns, feature_dims=rec_residue_feature_dims, sigma_embed_dim=sigma_embed_dim, lm_embedding_type=lm_embedding_type)
self.rec_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.cross_edge_embedding = nn.Sequential(nn.Linear(sigma_embed_dim + cross_distance_embed_dim, ns), nn.ReLU(), nn.Dropout(dropout),nn.Linear(ns, ns))
self.lig_distance_expansion = GaussianSmearing(0.0, lig_max_radius, distance_embed_dim)
self.rec_distance_expansion = GaussianSmearing(0.0, rec_max_radius, distance_embed_dim)
self.cross_distance_expansion = GaussianSmearing(0.0, cross_max_distance, cross_distance_embed_dim)
if use_second_order_repr:
irrep_seq = [
f'{ns}x0e',
f'{ns}x0e + {nv}x1o + {nv}x2e',
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {ns}x0o'
]
else:
irrep_seq = [
f'{ns}x0e',
f'{ns}x0e + {nv}x1o',
f'{ns}x0e + {nv}x1o + {nv}x1e',
f'{ns}x0e + {nv}x1o + {nv}x1e + {ns}x0o'
]
lig_conv_layers, rec_conv_layers, lig_to_rec_conv_layers, rec_to_lig_conv_layers = [], [], [], []
for i in range(num_conv_layers):
in_irreps = irrep_seq[min(i, len(irrep_seq) - 1)]
out_irreps = irrep_seq[min(i + 1, len(irrep_seq) - 1)]
parameters = {
'in_irreps': in_irreps,
'sh_irreps': self.sh_irreps,
'out_irreps': out_irreps,
'n_edge_features': 3 * ns,
'hidden_features': 3 * ns,
'residual': False,
'batch_norm': batch_norm,
'dropout': dropout
}
lig_layer = TensorProductConvLayer(**parameters)
lig_conv_layers.append(lig_layer)
rec_layer = TensorProductConvLayer(**parameters)
rec_conv_layers.append(rec_layer)
lig_to_rec_layer = TensorProductConvLayer(**parameters)
lig_to_rec_conv_layers.append(lig_to_rec_layer)
rec_to_lig_layer = TensorProductConvLayer(**parameters)
rec_to_lig_conv_layers.append(rec_to_lig_layer)
self.lig_conv_layers = nn.ModuleList(lig_conv_layers)
self.rec_conv_layers = nn.ModuleList(rec_conv_layers)
self.lig_to_rec_conv_layers = nn.ModuleList(lig_to_rec_conv_layers)
self.rec_to_lig_conv_layers = nn.ModuleList(rec_to_lig_conv_layers)
if self.confidence_mode:
self.confidence_predictor = nn.Sequential(
nn.Linear(2*self.ns if num_conv_layers >= 3 else self.ns,ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, ns),
nn.BatchNorm1d(ns) if not confidence_no_batchnorm else nn.Identity(),
nn.ReLU(),
nn.Dropout(confidence_dropout),
nn.Linear(ns, num_confidence_outputs)
)
else:
# center of mass translation and rotation components
self.center_distance_expansion = GaussianSmearing(0.0, center_max_distance, distance_embed_dim)
self.center_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim + sigma_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_conv = TensorProductConvLayer(
in_irreps=self.lig_conv_layers[-1].out_irreps,
sh_irreps=self.sh_irreps,
out_irreps=f'2x1o + 2x1e',
n_edge_features=2 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tr_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
self.rot_final_layer = nn.Sequential(nn.Linear(1 + sigma_embed_dim, ns),nn.Dropout(dropout), nn.ReLU(), nn.Linear(ns, 1))
if not no_torsion:
# torsion angles components
self.final_edge_embedding = nn.Sequential(
nn.Linear(distance_embed_dim, ns),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(ns, ns)
)
self.final_tp_tor = o3.FullTensorProduct(self.sh_irreps, "2e")
self.tor_bond_conv = TensorProductConvLayer(
in_irreps=self.lig_conv_layers[-1].out_irreps,
sh_irreps=self.final_tp_tor.irreps_out,
out_irreps=f'{ns}x0o + {ns}x0e',
n_edge_features=3 * ns,
residual=False,
dropout=dropout,
batch_norm=batch_norm
)
self.tor_final_layer = nn.Sequential(
nn.Linear(2 * ns, ns, bias=False),
nn.Tanh(),
nn.Dropout(dropout),
nn.Linear(ns, 1, bias=False)
)
def forward(self, data):
if not self.confidence_mode:
tr_sigma, rot_sigma, tor_sigma = self.t_to_sigma(*[data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']])
else:
tr_sigma, rot_sigma, tor_sigma = [data.complex_t[noise_type] for noise_type in ['tr', 'rot', 'tor']]
# build ligand graph
lig_node_attr, lig_edge_index, lig_edge_attr, lig_edge_sh = self.build_lig_conv_graph(data)
lig_src, lig_dst = lig_edge_index
lig_node_attr = self.lig_node_embedding(lig_node_attr)
lig_edge_attr = self.lig_edge_embedding(lig_edge_attr)
# build receptor graph
rec_node_attr, rec_edge_index, rec_edge_attr, rec_edge_sh = self.build_rec_conv_graph(data)
rec_src, rec_dst = rec_edge_index
rec_node_attr = self.rec_node_embedding(rec_node_attr)
rec_edge_attr = self.rec_edge_embedding(rec_edge_attr)
# build cross graph
if self.dynamic_max_cross:
cross_cutoff = (tr_sigma * 3 + 20).unsqueeze(1)
else:
cross_cutoff = self.cross_max_distance
cross_edge_index, cross_edge_attr, cross_edge_sh = self.build_cross_conv_graph(data, cross_cutoff)
cross_lig, cross_rec = cross_edge_index
cross_edge_attr = self.cross_edge_embedding(cross_edge_attr)
for l in range(len(self.lig_conv_layers)):
# intra graph message passing
lig_edge_attr_ = torch.cat([lig_edge_attr, lig_node_attr[lig_src, :self.ns], lig_node_attr[lig_dst, :self.ns]], -1)
lig_intra_update = self.lig_conv_layers[l](lig_node_attr, lig_edge_index, lig_edge_attr_, lig_edge_sh)
# inter graph message passing
rec_to_lig_edge_attr_ = torch.cat([cross_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
lig_inter_update = self.rec_to_lig_conv_layers[l](rec_node_attr, cross_edge_index, rec_to_lig_edge_attr_, cross_edge_sh,
out_nodes=lig_node_attr.shape[0])
if l != len(self.lig_conv_layers) - 1:
rec_edge_attr_ = torch.cat([rec_edge_attr, rec_node_attr[rec_src, :self.ns], rec_node_attr[rec_dst, :self.ns]], -1)
rec_intra_update = self.rec_conv_layers[l](rec_node_attr, rec_edge_index, rec_edge_attr_, rec_edge_sh)
lig_to_rec_edge_attr_ = torch.cat([cross_edge_attr, lig_node_attr[cross_lig, :self.ns], rec_node_attr[cross_rec, :self.ns]], -1)
rec_inter_update = self.lig_to_rec_conv_layers[l](lig_node_attr, torch.flip(cross_edge_index, dims=[0]), lig_to_rec_edge_attr_,
cross_edge_sh, out_nodes=rec_node_attr.shape[0])
# padding original features
lig_node_attr = F.pad(lig_node_attr, (0, lig_intra_update.shape[-1] - lig_node_attr.shape[-1]))
# update features with residual updates
lig_node_attr = lig_node_attr + lig_intra_update + lig_inter_update
if l != len(self.lig_conv_layers) - 1:
rec_node_attr = F.pad(rec_node_attr, (0, rec_intra_update.shape[-1] - rec_node_attr.shape[-1]))
rec_node_attr = rec_node_attr + rec_intra_update + rec_inter_update
# compute confidence score
if self.confidence_mode:
scalar_lig_attr = torch.cat([lig_node_attr[:,:self.ns],lig_node_attr[:,-self.ns:] ], dim=1) if self.num_conv_layers >= 3 else lig_node_attr[:,:self.ns]
confidence = self.confidence_predictor(scatter_mean(scalar_lig_attr, data['ligand'].batch, dim=0)).squeeze(dim=-1)
return confidence
# compute translational and rotational score vectors
center_edge_index, center_edge_attr, center_edge_sh = self.build_center_conv_graph(data)
center_edge_attr = self.center_edge_embedding(center_edge_attr)
center_edge_attr = torch.cat([center_edge_attr, lig_node_attr[center_edge_index[1], :self.ns]], -1)
global_pred = self.final_conv(lig_node_attr, center_edge_index, center_edge_attr, center_edge_sh, out_nodes=data.num_graphs)
tr_pred = global_pred[:, :3] + global_pred[:, 6:9]
rot_pred = global_pred[:, 3:6] + global_pred[:, 9:]
data.graph_sigma_emb = self.timestep_emb_func(data.complex_t['tr'])
# fix the magnitude of translational and rotational score vectors
tr_norm = torch.linalg.vector_norm(tr_pred, dim=1).unsqueeze(1)
tr_pred = tr_pred / tr_norm * self.tr_final_layer(torch.cat([tr_norm, data.graph_sigma_emb], dim=1))
rot_norm = torch.linalg.vector_norm(rot_pred, dim=1).unsqueeze(1)
rot_pred = rot_pred / rot_norm * self.rot_final_layer(torch.cat([rot_norm, data.graph_sigma_emb], dim=1))
if self.scale_by_sigma:
tr_pred = tr_pred / tr_sigma.unsqueeze(1)
rot_pred = rot_pred * so3.score_norm(rot_sigma.cpu()).unsqueeze(1).to(data['ligand'].x.device)
if self.no_torsion or data['ligand'].edge_mask.sum() == 0: return tr_pred, rot_pred, torch.empty(0, device=self.device)
# torsional components
tor_bonds, tor_edge_index, tor_edge_attr, tor_edge_sh = self.build_bond_conv_graph(data)
tor_bond_vec = data['ligand'].pos[tor_bonds[1]] - data['ligand'].pos[tor_bonds[0]]
tor_bond_attr = lig_node_attr[tor_bonds[0]] + lig_node_attr[tor_bonds[1]]
tor_bonds_sh = o3.spherical_harmonics("2e", tor_bond_vec, normalize=True, normalization='component')
tor_edge_sh = self.final_tp_tor(tor_edge_sh, tor_bonds_sh[tor_edge_index[0]])
tor_edge_attr = torch.cat([tor_edge_attr, lig_node_attr[tor_edge_index[1], :self.ns],
tor_bond_attr[tor_edge_index[0], :self.ns]], -1)
tor_pred = self.tor_bond_conv(lig_node_attr, tor_edge_index, tor_edge_attr, tor_edge_sh,
out_nodes=data['ligand'].edge_mask.sum(), reduce='mean')
tor_pred = self.tor_final_layer(tor_pred).squeeze(1)
edge_sigma = tor_sigma[data['ligand'].batch][data['ligand', 'ligand'].edge_index[0]][data['ligand'].edge_mask]
if self.scale_by_sigma:
tor_pred = tor_pred * torch.sqrt(torch.tensor(torus.score_norm(edge_sigma.cpu().numpy())).float()
.to(data['ligand'].x.device))
return tr_pred, rot_pred, tor_pred
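The forward pass above keeps the direction of the raw translation and rotation vectors and replaces their length with the output of a small network (`tr_final_layer`, `rot_final_layer`) that sees the original norm and the sigma embedding. A minimal sketch of that rescaling, in plain Python; `rescale` and the lambda standing in for the learned layer are hypothetical, and the sigma embedding input is omitted for brevity.

```python
import math


def rescale(vec, magnitude_fn):
    """Keep the direction of `vec`; set its length to magnitude_fn(norm)."""
    norm = math.sqrt(sum(c * c for c in vec))
    # Like the original code, this does not guard against a zero norm.
    scale = magnitude_fn(norm)
    return [c / norm * scale for c in vec]


# Stand-in for the learned magnitude network: simply doubles the norm.
out = rescale([3.0, 4.0, 0.0], lambda n: 2.0 * n)
```

Because the magnitude function may return a negative value, the rescaled vector can also point opposite to the raw prediction.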
def build_lig_conv_graph(self, data):
# builds the ligand graph edges and initial node and edge features
data['ligand'].node_sigma_emb = self.timestep_emb_func(data['ligand'].node_t['tr'])
# compute edges
radius_edges = radius_graph(data['ligand'].pos, self.lig_max_radius, data['ligand'].batch)
edge_index = torch.cat([data['ligand', 'ligand'].edge_index, radius_edges], 1).long()
edge_attr = torch.cat([
data['ligand', 'ligand'].edge_attr,
torch.zeros(radius_edges.shape[-1], self.in_lig_edge_features, device=data['ligand'].x.device)
], 0)
# compute initial features
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
node_attr = torch.cat([data['ligand'].x, data['ligand'].node_sigma_emb], 1)
src, dst = edge_index
edge_vec = data['ligand'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = torch.cat([edge_attr, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return node_attr, edge_index, edge_attr, edge_sh
def build_rec_conv_graph(self, data):
# builds the receptor initial node and edge embeddings
data['receptor'].node_sigma_emb = self.timestep_emb_func(data['receptor'].node_t['tr']) # tr rot and tor noise is all the same
node_attr = torch.cat([data['receptor'].x, data['receptor'].node_sigma_emb], 1)
# this assumes the edges were already created in preprocessing since protein's structure is fixed
edge_index = data['receptor', 'receptor'].edge_index
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['receptor'].pos[src.long()]
edge_length_emb = self.rec_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['receptor'].node_sigma_emb[edge_index[0].long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return node_attr, edge_index, edge_attr, edge_sh
def build_cross_conv_graph(self, data, cross_distance_cutoff):
# builds the cross edges between ligand and receptor
if torch.is_tensor(cross_distance_cutoff):
# different cutoff for every graph (depends on the diffusion time)
edge_index = radius(data['receptor'].pos / cross_distance_cutoff[data['receptor'].batch],
data['ligand'].pos / cross_distance_cutoff[data['ligand'].batch], 1,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
else:
edge_index = radius(data['receptor'].pos, data['ligand'].pos, cross_distance_cutoff,
data['receptor'].batch, data['ligand'].batch, max_num_neighbors=10000)
src, dst = edge_index
edge_vec = data['receptor'].pos[dst.long()] - data['ligand'].pos[src.long()]
edge_length_emb = self.cross_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[src.long()]
edge_attr = torch.cat([edge_sigma_emb, edge_length_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return edge_index, edge_attr, edge_sh
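When the cutoff is a tensor, `build_cross_conv_graph` divides every position by the cutoff of its own graph and then searches with radius 1. This works because the scaled distance `||a/c - b/c||` is at most 1 exactly when `||a - b||` is at most `c`. The brute-force sketch below illustrates the trick without `torch_cluster`; `cross_edges` is a hypothetical helper, not the library call.

```python
import math


def cross_edges(rec_pos, rec_batch, lig_pos, lig_batch, cutoff_per_graph):
    """Return (ligand_index, receptor_index) pairs within each graph's own cutoff."""
    edges = []
    for i, (lp, lb) in enumerate(zip(lig_pos, lig_batch)):
        for j, (rp, rb) in enumerate(zip(rec_pos, rec_batch)):
            if lb != rb:
                continue  # never connect atoms from different complexes
            c = cutoff_per_graph[lb]
            # Scale both points by the graph's cutoff, then compare against radius 1.
            scaled = math.dist([v / c for v in lp], [v / c for v in rp])
            if scaled <= 1.0:
                edges.append((i, j))
    return edges


# Two complexes with the same geometry (7 units apart) but different cutoffs.
lig_pos, lig_batch = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], [0, 1]
rec_pos, rec_batch = [[7.0, 0.0, 0.0], [7.0, 0.0, 0.0]], [0, 1]
edges = cross_edges(rec_pos, rec_batch, lig_pos, lig_batch, [5.0, 10.0])
```

Only the second complex gets an edge, since 7 exceeds the first cutoff of 5 but not the second cutoff of 10.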
def build_center_conv_graph(self, data):
# builds the filter and edges for the convolution generating translational and rotational scores
edge_index = torch.cat([data['ligand'].batch.unsqueeze(0), torch.arange(len(data['ligand'].batch)).to(data['ligand'].x.device).unsqueeze(0)], dim=0)
center_pos, count = torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device), torch.zeros((data.num_graphs, 3)).to(data['ligand'].x.device)
center_pos.index_add_(0, index=data['ligand'].batch, source=data['ligand'].pos)
center_pos = center_pos / torch.bincount(data['ligand'].batch).unsqueeze(1)
edge_vec = data['ligand'].pos[edge_index[1]] - center_pos[edge_index[0]]
edge_attr = self.center_distance_expansion(edge_vec.norm(dim=-1))
edge_sigma_emb = data['ligand'].node_sigma_emb[edge_index[1].long()]
edge_attr = torch.cat([edge_attr, edge_sigma_emb], 1)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return edge_index, edge_attr, edge_sh
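`build_center_conv_graph` connects every ligand atom to the centroid of its own ligand, which it obtains by summing positions per graph (`index_add_`) and dividing by the atom count (`bincount`). A pure-Python sketch of that centroid step, with a hypothetical helper name:

```python
def graph_centers(pos, batch, num_graphs):
    """Per-graph mean position, mirroring index_add_ followed by a bincount division."""
    sums = [[0.0, 0.0, 0.0] for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for p, b in zip(pos, batch):
        counts[b] += 1
        for k in range(3):
            sums[b][k] += p[k]
    # Assumes every graph owns at least one atom, as the original code does.
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]


pos = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.0, 1.0]]
batch = [0, 0, 1]
centers = graph_centers(pos, batch, num_graphs=2)
```

Each edge vector is then the atom position minus its graph's centroid, so a single-atom ligand yields a zero-length edge.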
def build_bond_conv_graph(self, data):
# builds the graph for the convolution between the center of the rotatable bonds and the neighbouring nodes
bonds = data['ligand', 'ligand'].edge_index[:, data['ligand'].edge_mask].long()
bond_pos = (data['ligand'].pos[bonds[0]] + data['ligand'].pos[bonds[1]]) / 2
bond_batch = data['ligand'].batch[bonds[0]]
edge_index = radius(data['ligand'].pos, bond_pos, self.lig_max_radius, batch_x=data['ligand'].batch, batch_y=bond_batch)
edge_vec = data['ligand'].pos[edge_index[1]] - bond_pos[edge_index[0]]
edge_attr = self.lig_distance_expansion(edge_vec.norm(dim=-1))
edge_attr = self.final_edge_embedding(edge_attr)
edge_sh = o3.spherical_harmonics(self.sh_irreps, edge_vec, normalize=True, normalization='component')
return bonds, edge_index, edge_attr, edge_sh
class GaussianSmearing(torch.nn.Module):
    # used to embed the edge distances
    def __init__(self, start=0.0, stop=5.0, num_gaussians=50):
        super().__init__()
        offset = torch.linspace(start, stop, num_gaussians)
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        self.register_buffer('offset', offset)

    def forward(self, dist):
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))
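`GaussianSmearing` expands a scalar distance into a vector of Gaussian bumps whose centers are evenly spaced between `start` and `stop`, with a width equal to the spacing between centers. The sketch below reproduces the same formula for a single distance using only the standard library; it is an illustration, not a replacement for the module.

```python
import math


def gaussian_smearing(dist, start=0.0, stop=5.0, num_gaussians=50):
    """Expand one distance into num_gaussians radial basis features."""
    step = (stop - start) / (num_gaussians - 1)
    offsets = [start + i * step for i in range(num_gaussians)]
    coeff = -0.5 / step ** 2  # same coefficient as the module: -0.5 / spacing^2
    return [math.exp(coeff * (dist - mu) ** 2) for mu in offsets]


# With 11 centers on [0, 5] the spacing is 0.5, so 2.5 sits exactly on center 5.
feats = gaussian_smearing(2.5, start=0.0, stop=5.0, num_gaussians=11)
```

A distance that coincides with a center gives 1.0 at that index and `exp(-0.5)` at its two neighbours, which is why neighbouring features overlap smoothly.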

255
models/tensor_layers.py Normal file

@@ -0,0 +1,255 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from e3nn import o3
from e3nn.nn import BatchNorm
from e3nn.o3 import TensorProduct, Linear
from torch_scatter import scatter, scatter_mean
from models.layers import FCBlock
def get_irrep_seq(ns, nv, use_second_order_repr, reduce_pseudoscalars):
    if use_second_order_repr:
        irrep_seq = [
            f'{ns}x0e',
            f'{ns}x0e + {nv}x1o + {nv}x2e',
            f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o',
            f'{ns}x0e + {nv}x1o + {nv}x2e + {nv}x1e + {nv}x2o + {nv if reduce_pseudoscalars else ns}x0o'
        ]
    else:
        irrep_seq = [
            f'{ns}x0e',
            f'{ns}x0e + {nv}x1o',
            f'{ns}x0e + {nv}x1o + {nv}x1e',
            f'{ns}x0e + {nv}x1o + {nv}x1e + {nv if reduce_pseudoscalars else ns}x0o'
        ]
    return irrep_seq
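The irrep sequence lists the representation available after 0, 1, 2 and 3 convolution layers. The removed `TensorProductScoreModel.__init__` picks each layer's input and output with `irrep_seq[min(i, len(irrep_seq) - 1)]`, so layers beyond the third reuse the last entry. A small sketch of that indexing, assuming first-order representations with `ns=16`, `nv=4` and `reduce_pseudoscalars=False`; `layer_irreps` is a hypothetical helper.

```python
def layer_irreps(irrep_seq, num_conv_layers):
    """Return (in_irreps, out_irreps) per layer, clamping at the last entry."""
    last = len(irrep_seq) - 1
    return [(irrep_seq[min(i, last)], irrep_seq[min(i + 1, last)])
            for i in range(num_conv_layers)]


seq = [
    '16x0e',
    '16x0e + 4x1o',
    '16x0e + 4x1o + 4x1e',
    '16x0e + 4x1o + 4x1e + 16x0o',
]
pairs = layer_irreps(seq, 5)
```

The pseudoscalar `0o` block only appears in the output of the third layer, which matches the removed confidence head reading `2 * ns` scalar features only when `num_conv_layers >= 3`.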
def irrep_to_size(irrep):
    irreps = irrep.split(' + ')
    size = 0
    for ir in irreps:
        m, (l, p) = ir.split('x')
        size += int(m) * (2 * int(l) + 1)
    return size
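`irrep_to_size` turns an irreps string into the flat feature width: each `m x l p` term contributes `m * (2l + 1)` channels. The worked example below restates the function so it runs on its own and evaluates it on two strings that appear in this commit. Note that the parser assumes a single-digit `l` and terms separated by exactly `' + '`.

```python
def irrep_to_size(irrep):
    """Flat dimension of an irreps string such as '16x0e + 4x1o'."""
    size = 0
    for ir in irrep.split(' + '):
        m, (l, p) = ir.split('x')  # e.g. '4x1o' -> m='4', l='1', p='o'
        size += int(m) * (2 * int(l) + 1)
    return size


hidden = irrep_to_size('16x0e + 4x1o')  # 16 scalars + 4 vectors of 3 components
final = irrep_to_size('2x1o + 2x1e')    # output irreps of the removed final_conv
```

The width of 12 for `'2x1o + 2x1e'` is what makes the slices `[:3]`, `[3:6]`, `[6:9]` and `[9:]` in the removed forward pass line up with four 3-vectors.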
class FasterTensorProduct(torch.nn.Module):
# Implemented by Bowen Jing
def __init__(self, in_irreps, sh_irreps, out_irreps, **kwargs):
super().__init__()
#for ir in in_irreps:
# m, (l, p) = ir
# assert l in [0, 1], "Higher order in irreps are not supported"
#for ir in out_irreps:
# m, (l, p) = ir
# assert l in [0, 1], "Higher order out irreps are not supported"
assert o3.Irreps(sh_irreps) == o3.Irreps('1x0e+1x1o'), "sh_irreps don't look like 1st order spherical harmonics"
self.in_irreps = o3.Irreps(in_irreps)
self.out_irreps = o3.Irreps(out_irreps)
in_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
out_muls = {'0e': 0, '1o': 0, '1e': 0, '0o': 0}
for (m, ir) in self.in_irreps: in_muls[str(ir)] = m
for (m, ir) in self.out_irreps: out_muls[str(ir)] = m
self.weight_shapes = {
'0e': (in_muls['0e'] + in_muls['1o'], out_muls['0e']),
'1o': (in_muls['0e'] + in_muls['1o'] + in_muls['1e'], out_muls['1o']),
'1e': (in_muls['1o'] + in_muls['1e'] + in_muls['0o'], out_muls['1e']),
'0o': (in_muls['1e'] + in_muls['0o'], out_muls['0o'])
}
self.weight_numel = sum(a * b for (a, b) in self.weight_shapes.values())
def forward(self, in_, sh, weight):
in_dict, out_dict = {}, {'0e': [], '1o': [], '1e': [], '0o': []}
for (m, ir), sl in zip(self.in_irreps, self.in_irreps.slices()):
in_dict[str(ir)] = in_[..., sl]
if ir[0] == 1: in_dict[str(ir)] = in_dict[str(ir)].reshape(list(in_dict[str(ir)].shape)[:-1] + [-1, 3])
sh_0e, sh_1o = sh[..., 0], sh[..., 1:]
if '0e' in in_dict:
out_dict['0e'].append(in_dict['0e'] * sh_0e.unsqueeze(-1))
out_dict['1o'].append(in_dict['0e'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
if '1o' in in_dict:
out_dict['0e'].append((in_dict['1o'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
out_dict['1o'].append(in_dict['1o'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
out_dict['1e'].append(torch.linalg.cross(in_dict['1o'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
if '1e' in in_dict:
out_dict['1o'].append(torch.linalg.cross(in_dict['1e'], sh_1o.unsqueeze(-2), dim=-1) / np.sqrt(2))
out_dict['1e'].append(in_dict['1e'] * sh_0e.unsqueeze(-1).unsqueeze(-1))
out_dict['0o'].append((in_dict['1e'] * sh_1o.unsqueeze(-2)).sum(-1) / np.sqrt(3))
if '0o' in in_dict:
out_dict['1e'].append(in_dict['0o'].unsqueeze(-1) * sh_1o.unsqueeze(-2))
out_dict['0o'].append(in_dict['0o'] * sh_0e.unsqueeze(-1))
weight_dict = {}
start = 0
for key in self.weight_shapes:
in_, out = self.weight_shapes[key]
weight_dict[key] = weight[..., start:start + in_ * out].reshape(
list(weight.shape)[:-1] + [in_, out]) / np.sqrt(in_)
start += in_ * out
if out_dict['0e']:
out_dict['0e'] = torch.cat(out_dict['0e'], dim=-1)
out_dict['0e'] = torch.matmul(out_dict['0e'].unsqueeze(-2), weight_dict['0e']).squeeze(-2)
if out_dict['1o']:
out_dict['1o'] = torch.cat(out_dict['1o'], dim=-2)
out_dict['1o'] = (out_dict['1o'].unsqueeze(-2) * weight_dict['1o'].unsqueeze(-1)).sum(-3)
out_dict['1o'] = out_dict['1o'].reshape(list(out_dict['1o'].shape)[:-2] + [-1])
if out_dict['1e']:
out_dict['1e'] = torch.cat(out_dict['1e'], dim=-2)
out_dict['1e'] = (out_dict['1e'].unsqueeze(-2) * weight_dict['1e'].unsqueeze(-1)).sum(-3)
out_dict['1e'] = out_dict['1e'].reshape(list(out_dict['1e'].shape)[:-2] + [-1])
if out_dict['0o']:
out_dict['0o'] = torch.cat(out_dict['0o'], dim=-1)
# out_dict['0o'] = (out_dict['0o'].unsqueeze(-1) * weight_dict['0o']).sum(-2)
out_dict['0o'] = torch.matmul(out_dict['0o'].unsqueeze(-2), weight_dict['0o']).squeeze(-2)
out = []
for _, ir in self.out_irreps:
out.append(out_dict[str(ir)])
return torch.cat(out, dim=-1)
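`FasterTensorProduct` expects one flat weight vector per edge and carves it into four matrices, one per output irrep type. The number of input paths per output type follows the `weight_shapes` table in `__init__`. The sketch below recomputes `weight_numel` for one configuration in plain Python; the chosen multiplicities are illustrative assumptions.

```python
def weight_shapes(in_muls, out_muls):
    """(paths_in, channels_out) per output irrep, as in FasterTensorProduct.__init__."""
    i, o = in_muls, out_muls
    return {
        '0e': (i['0e'] + i['1o'], o['0e']),
        '1o': (i['0e'] + i['1o'] + i['1e'], o['1o']),
        '1e': (i['1o'] + i['1e'] + i['0o'], o['1e']),
        '0o': (i['1e'] + i['0o'], o['0o']),
    }


# Input '16x0e + 4x1o', output '16x0e + 4x1o + 4x1e'.
in_muls = {'0e': 16, '1o': 4, '1e': 0, '0o': 0}
out_muls = {'0e': 16, '1o': 4, '1e': 4, '0o': 0}
shapes = weight_shapes(in_muls, out_muls)
numel = sum(a * b for a, b in shapes.values())
```

This count is the output width that the edge network (`self.fc`) must produce for every edge, so it grows quickly with the multiplicities.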
class TensorProductConvLayer(torch.nn.Module):
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
hidden_features=None, faster=False, edge_groups=1, tp_weights_layers=2, activation='relu', depthwise=False):
super(TensorProductConvLayer, self).__init__()
self.in_irreps = in_irreps
self.out_irreps = out_irreps
self.sh_irreps = sh_irreps
self.residual = residual
self.edge_groups = edge_groups
self.out_size = irrep_to_size(out_irreps)
self.depthwise = depthwise
if hidden_features is None:
hidden_features = n_edge_features
if depthwise:
in_irreps = o3.Irreps(in_irreps)
sh_irreps = o3.Irreps(sh_irreps)
out_irreps = o3.Irreps(out_irreps)
irreps_mid = []
instructions = []
for i, (mul, ir_in) in enumerate(in_irreps):
for j, (_, ir_edge) in enumerate(sh_irreps):
for ir_out in ir_in * ir_edge:
if ir_out in out_irreps:
k = len(irreps_mid)
irreps_mid.append((mul, ir_out))
instructions.append((i, j, k, "uvu", True))
# We sort the output irreps of the tensor product so that we can simplify them
# when they are provided to the second o3.Linear
irreps_mid = o3.Irreps(irreps_mid)
irreps_mid, p, _ = irreps_mid.sort()
# Permute the output indexes of the instructions to match the sorted irreps:
instructions = [
(i_in1, i_in2, p[i_out], mode, train)
for i_in1, i_in2, i_out, mode, train in instructions
]
self.tp = TensorProduct(
in_irreps,
sh_irreps,
irreps_mid,
instructions,
shared_weights=False,
internal_weights=False,
)
self.linear_2 = Linear(
# irreps_mid has uncoalesced irreps because of the uvu instructions,
# but there's no reason to treat them separately for the Linear
# Note that normalization of o3.Linear changes if irreps are coalesced
# (likely for the better)
irreps_in=irreps_mid.simplify(),
irreps_out=out_irreps,
internal_weights=True,
shared_weights=True,
)
else:
if faster:
print("Faster Tensor Product")
self.tp = FasterTensorProduct(in_irreps, sh_irreps, out_irreps)
else:
self.tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
if edge_groups == 1:
self.fc = FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation)
else:
self.fc = [FCBlock(n_edge_features, hidden_features, self.tp.weight_numel, tp_weights_layers, dropout, activation) for _ in range(edge_groups)]
self.fc = nn.ModuleList(self.fc)
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
if edge_index.shape[1] == 0:
out = torch.zeros((node_attr.shape[0], self.out_size), dtype=node_attr.dtype, device=node_attr.device)
else:
edge_src, edge_dst = edge_index
edge_attr_ = self.fc(edge_attr) if self.edge_groups == 1 else torch.cat(
[self.fc[i](edge_attr[i]) for i in range(self.edge_groups)], dim=0).to(node_attr.device)
tp = self.tp(node_attr[edge_dst], edge_sh, edge_attr_ * edge_weight)
out_nodes = out_nodes or node_attr.shape[0]
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
if self.depthwise:
out = self.linear_2(out)
if self.batch_norm:
out = self.batch_norm(out)
if self.residual:
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
out = out + padded
return out
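After the tensor product, the layer aggregates per-edge messages onto nodes with `scatter(..., reduce='mean')`, using `out_nodes` as the number of output rows. The pure-Python sketch below shows the same aggregation; `scatter_mean_rows` is a hypothetical stand-in, and it assumes, as `torch_scatter` does for `mean`, that nodes receiving no message end up as zeros.

```python
def scatter_mean_rows(messages, index, dim_size):
    """Average the rows of `messages` that share the same target index."""
    width = len(messages[0]) if messages else 0
    sums = [[0.0] * width for _ in range(dim_size)]
    counts = [0] * dim_size
    for msg, i in zip(messages, index):
        counts[i] += 1
        for k, v in enumerate(msg):
            sums[i][k] += v
    # Rows that received nothing keep their zeros (division by max(count, 1)).
    return [[s / max(c, 1) for s in row] for row, c in zip(sums, counts)]


messages = [[1.0, 1.0], [3.0, 3.0], [5.0, 5.0]]
index = [0, 0, 2]  # plays the role of edge_src
out = scatter_mean_rows(messages, index, dim_size=3)
```

Passing an explicit `dim_size` matters for the cross-graph and center convolutions, where the number of output nodes differs from the number of input nodes.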
class OldTensorProductConvLayer(torch.nn.Module):
def __init__(self, in_irreps, sh_irreps, out_irreps, n_edge_features, residual=True, batch_norm=True, dropout=0.0,
hidden_features=None):
super(OldTensorProductConvLayer, self).__init__()
self.in_irreps = in_irreps
self.out_irreps = out_irreps
self.sh_irreps = sh_irreps
self.residual = residual
if hidden_features is None:
hidden_features = n_edge_features
self.tp = tp = o3.FullyConnectedTensorProduct(in_irreps, sh_irreps, out_irreps, shared_weights=False)
self.fc = nn.Sequential(
nn.Linear(n_edge_features, hidden_features),
nn.ReLU(),
nn.Dropout(dropout),
nn.Linear(hidden_features, tp.weight_numel)
)
self.batch_norm = BatchNorm(out_irreps) if batch_norm else None
def forward(self, node_attr, edge_index, edge_attr, edge_sh, out_nodes=None, reduce='mean', edge_weight=1.0):
edge_src, edge_dst = edge_index
tp = self.tp(node_attr[edge_dst], edge_sh, self.fc(edge_attr) * edge_weight)
out_nodes = out_nodes or node_attr.shape[0]
out = scatter(tp, edge_src, dim=0, dim_size=out_nodes, reduce=reduce)
if self.residual:
padded = F.pad(node_attr, (0, out.shape[-1] - node_attr.shape[-1]))
out = out + padded
if self.batch_norm:
out = self.batch_norm(out)
return out
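Both layer variants add a residual by zero-padding the input features on the right until they match the output width. They differ in ordering: `TensorProductConvLayer` applies batch norm before the residual, while `OldTensorProductConvLayer` adds the residual first and normalizes afterwards. A minimal sketch of the padding step for a single node, assuming the output is at least as wide as the input; `add_residual` is a hypothetical helper.

```python
def add_residual(out, node_attr):
    """Right-pad node_attr with zeros to len(out), then add elementwise."""
    pad = len(out) - len(node_attr)
    assert pad >= 0, "sketch assumes the output is at least as wide as the input"
    padded = list(node_attr) + [0.0] * pad
    return [o + p for o, p in zip(out, padded)]


result = add_residual([1.0, 1.0, 1.0, 1.0], [10.0, 20.0])
```

Only the leading channels, which the irrep sequence keeps in a stable position from layer to layer, receive the skip connection.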

[Binary image file: 334 KiB before, 334 KiB after]

22
spyrmsd/LICENSE Normal file

@@ -0,0 +1,22 @@
MIT License
Copyright (c) 2019-2021 Rocco Meli
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

24
spyrmsd/__init__.py Normal file

@@ -0,0 +1,24 @@
"""
Python RMSD tool with symmetry correction.
"""
from ._version import get_versions
from .due import Doi, due
versions = get_versions()
__version__ = versions["version"]
__git_revision__ = versions["full-revisionid"]
del get_versions, versions
# This will print latest Zenodo version
due.cite(
Doi("10.5281/zenodo.3631876"),
path="spyrmsd",
description="spyrmsd",
)
due.cite(
Doi("10.1186/s13321-020-00455-2"),
path="spyrmsd",
description="spyrmsd",
)

63
spyrmsd/__main__.py Normal file

@@ -0,0 +1,63 @@
"""
Symmetry-corrected RMSD calculations in Python
"""
if __name__ == "__main__":
    import argparse as ap
    import importlib.util
    import sys

    from spyrmsd import io
    from spyrmsd.rmsd import rmsdwrapper

    parser = ap.ArgumentParser(
        prog="python -m spyrmsd",
        description="Symmetry-corrected RMSD calculations in Python.",
    )
    parser.add_argument("reference", type=str, help="Reference file")
    parser.add_argument("molecules", type=str, nargs="+", help="Input file(s)")
    parser.add_argument("-m", "--minimize", action="store_true", help="Minimize (fit)")
    parser.add_argument(
        "-c", "--center", action="store_true", help="Center molecules at origin"
    )
    parser.add_argument("--hydrogens", action="store_true", help="Keep hydrogen atoms")
    parser.add_argument(
        "-n", "--nosymm", action="store_false", help="No graph isomorphism"
    )
    args = parser.parse_args()

    if (
        importlib.util.find_spec("openbabel") is None
        and importlib.util.find_spec("rdkit") is None
    ):
        raise ImportError(
            "OpenBabel or RDKit not found. Please install OpenBabel or RDKit to use sPyRMSD as a standalone tool."
        )

    try:
        ref = io.loadmol(args.reference)
    except OSError:
        print("ERROR: Reference file not found.", file=sys.stderr)
        exit(-1)

    # Load all molecules
    try:
        mols = [mol for molfile in args.molecules for mol in io.loadallmols(molfile)]
    except OSError:
        print("ERROR: Molecule file(s) not found.", file=sys.stderr)
        exit(-1)

    # Loop over molecules within file
    RMSDlist = rmsdwrapper(
        ref,
        mols,
        symmetry=args.nosymm,  # args.nosymm holds False when -n/--nosymm is given
        center=args.center,
        minimize=args.minimize,
        strip=not args.hydrogens,
    )

    for RMSD in RMSDlist:
        print(f"{RMSD:.5f}")
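The `--nosymm` flag uses `action="store_false"`, which can look inverted at first: the attribute defaults to `True` and becomes `False` only when the flag is passed, which is why it is forwarded directly as `symmetry=args.nosymm`. The standalone sketch below demonstrates this with a throwaway parser; the `prog` name is hypothetical.

```python
import argparse

# Throwaway parser that only reproduces the --nosymm option.
parser = argparse.ArgumentParser(prog="demo")
parser.add_argument("-n", "--nosymm", action="store_false", help="No graph isomorphism")

default_args = parser.parse_args([])      # flag absent: symmetry correction stays on
flagged_args = parser.parse_args(["-n"])  # flag present: symmetry correction off
```

So `python -m spyrmsd ref.sdf mol.sdf` runs with symmetry correction, and adding `-n` disables it.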

693
spyrmsd/_version.py Normal file

@@ -0,0 +1,693 @@
# This file helps to compute a version number in source trees obtained from
# git-archive tarball (such as those provided by githubs download-from-tag
# feature). Distribution tarballs (built by setup.py sdist) and build
# directories (produced by setup.py build) will contain a much shorter file
# that just contains the computed version number.
# This file is released into the public domain.
# Generated by versioneer-0.28
# https://github.com/python-versioneer/python-versioneer
"""Git implementation of _version.py."""
import errno
import functools
import os
import re
import subprocess
import sys
from typing import Callable, Dict
def get_keywords():
"""Get the keywords needed to look up the version information."""
# these strings will be replaced by git during git-archive.
# setup.py/versioneer.py will grep for the variable names, so they must
# each be defined on a line of their own. _version.py will just call
# get_keywords().
git_refnames = " (HEAD -> develop, refs/pull/86/head)"
git_full = "b5532dd1b677a686c9d31c7bf39446e0a66e0af8"
git_date = "2023-09-08 19:57:37 +0200"
keywords = {"refnames": git_refnames, "full": git_full, "date": git_date}
return keywords
class VersioneerConfig:
"""Container for Versioneer configuration parameters."""
def get_config():
"""Create, populate and return the VersioneerConfig() object."""
# these strings are filled in when 'setup.py versioneer' creates
# _version.py
cfg = VersioneerConfig()
cfg.VCS = "git"
cfg.style = "pep440"
cfg.tag_prefix = ""
cfg.parentdir_prefix = "None"
cfg.versionfile_source = "spyrmsd/_version.py"
cfg.verbose = False
return cfg
class NotThisMethod(Exception):
"""Exception raised if a method is not valid for the current scenario."""
LONG_VERSION_PY: Dict[str, str] = {}
HANDLERS: Dict[str, Dict[str, Callable]] = {}
def register_vcs_handler(vcs, method): # decorator
"""Create decorator to mark a method as the handler of a VCS."""
def decorate(f):
"""Store f in HANDLERS[vcs][method]."""
if vcs not in HANDLERS:
HANDLERS[vcs] = {}
HANDLERS[vcs][method] = f
return f
return decorate
def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, env=None):
"""Call the given command(s)."""
assert isinstance(commands, list)
process = None
popen_kwargs = {}
if sys.platform == "win32":
# This hides the console window if pythonw.exe is used
startupinfo = subprocess.STARTUPINFO()
startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW
popen_kwargs["startupinfo"] = startupinfo
for command in commands:
try:
dispcmd = str([command] + args)
# remember shell=False, so use git.cmd on windows, not just git
process = subprocess.Popen(
[command] + args,
cwd=cwd,
env=env,
stdout=subprocess.PIPE,
stderr=(subprocess.PIPE if hide_stderr else None),
**popen_kwargs,
)
break
except OSError:
e = sys.exc_info()[1]
if e.errno == errno.ENOENT:
continue
if verbose:
print("unable to run %s" % dispcmd)
print(e)
return None, None
else:
if verbose:
print("unable to find command, tried %s" % (commands,))
return None, None
stdout = process.communicate()[0].strip().decode()
if process.returncode != 0:
if verbose:
print("unable to run %s (error)" % dispcmd)
print("stdout was %s" % stdout)
return None, process.returncode
return stdout, process.returncode
def versions_from_parentdir(parentdir_prefix, root, verbose):
"""Try to determine the version from the parent directory name.
Source tarballs conventionally unpack into a directory that includes both
the project name and a version string. We will also support searching up
two directory levels for an appropriately named parent directory
"""
rootdirs = []
for _ in range(3):
dirname = os.path.basename(root)
if dirname.startswith(parentdir_prefix):
return {
"version": dirname[len(parentdir_prefix) :],
"full-revisionid": None,
"dirty": False,
"error": None,
"date": None,
}
rootdirs.append(root)
root = os.path.dirname(root) # up a level
if verbose:
print(
"Tried directories %s but none started with prefix %s"
% (str(rootdirs), parentdir_prefix)
)
raise NotThisMethod("rootdir doesn't start with parentdir_prefix")
@register_vcs_handler("git", "get_keywords")
def git_get_keywords(versionfile_abs):
"""Extract version information from the given file."""
# the code embedded in _version.py can just fetch the value of these
# keywords. When used from setup.py, we don't want to import _version.py,
# so we do it with a regexp instead. This function is not used from
# _version.py.
keywords = {}
try:
with open(versionfile_abs, "r") as fobj:
for line in fobj:
if line.strip().startswith("git_refnames ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["refnames"] = mo.group(1)
if line.strip().startswith("git_full ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["full"] = mo.group(1)
if line.strip().startswith("git_date ="):
mo = re.search(r'=\s*"(.*)"', line)
if mo:
keywords["date"] = mo.group(1)
except OSError:
pass
return keywords
@register_vcs_handler("git", "keywords")
def git_versions_from_keywords(keywords, tag_prefix, verbose):
"""Get version information from git keywords."""
if "refnames" not in keywords:
raise NotThisMethod("Short version file found")
date = keywords.get("date")
if date is not None:
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
# git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant
# datestamp. However we prefer "%ci" (which expands to an "ISO-8601
# -like" string, which we must then edit to make compliant), because
# it's been around since git-1.5.3, and it's too difficult to
# discover which version we're using, or to work around using an
# older one.
date = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
refnames = keywords["refnames"].strip()
if refnames.startswith("$Format"):
if verbose:
print("keywords are unexpanded, not using")
raise NotThisMethod("unexpanded keywords, not a git-archive tarball")
refs = {r.strip() for r in refnames.strip("()").split(",")}
# starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of
# just "foo-1.0". If we see a "tag: " prefix, prefer those.
TAG = "tag: "
tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)}
if not tags:
# Either we're using git < 1.8.3, or there really are no tags. We use
# a heuristic: assume all version tags have a digit. The old git %d
# expansion behaves like git log --decorate=short and strips out the
# refs/heads/ and refs/tags/ prefixes that would let us distinguish
# between branches and tags. By ignoring refnames without digits, we
# filter out many common branch names like "release" and
# "stabilization", as well as "HEAD" and "master".
tags = {r for r in refs if re.search(r"\d", r)}
if verbose:
print("discarding '%s', no digits" % ",".join(refs - tags))
if verbose:
print("likely tags: %s" % ",".join(sorted(tags)))
for ref in sorted(tags):
# sorting will prefer e.g. "2.0" over "2.0rc1"
if ref.startswith(tag_prefix):
r = ref[len(tag_prefix) :]
# Filter out refs that exactly match prefix or that don't start
# with a number once the prefix is stripped (mostly a concern
# when prefix is '')
if not re.match(r"\d", r):
continue
if verbose:
print("picking %s" % r)
return {
"version": r,
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": None,
"date": date,
}
# no suitable tags, so version is "0+unknown", but full hex is still there
if verbose:
print("no suitable tags, using unknown + full revision id")
return {
"version": "0+unknown",
"full-revisionid": keywords["full"].strip(),
"dirty": False,
"error": "no suitable tags",
"date": None,
}
@register_vcs_handler("git", "pieces_from_vcs")
def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command):
"""Get version from 'git describe' in the root of the source tree.
This only gets called if the git-archive 'subst' keywords were *not*
expanded, and _version.py hasn't already been rewritten with a short
version string, meaning we're inside a checked out source tree.
"""
GITS = ["git"]
if sys.platform == "win32":
GITS = ["git.cmd", "git.exe"]
# GIT_DIR can interfere with correct operation of Versioneer.
# It may be intended to be passed to the Versioneer-versioned project,
# but that should not change where we get our version from.
env = os.environ.copy()
env.pop("GIT_DIR", None)
runner = functools.partial(runner, env=env)
_, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose)
if rc != 0:
if verbose:
print("Directory %s not under git control" % root)
raise NotThisMethod("'git rev-parse --git-dir' returned error")
# if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty]
# if there isn't one, this yields HEX[-dirty] (no NUM)
describe_out, rc = runner(
GITS,
[
"describe",
"--tags",
"--dirty",
"--always",
"--long",
"--match",
f"{tag_prefix}[[:digit:]]*",
],
cwd=root,
)
# --long was added in git-1.5.5
if describe_out is None:
raise NotThisMethod("'git describe' failed")
describe_out = describe_out.strip()
full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root)
if full_out is None:
raise NotThisMethod("'git rev-parse' failed")
full_out = full_out.strip()
pieces = {}
pieces["long"] = full_out
pieces["short"] = full_out[:7] # maybe improved later
pieces["error"] = None
branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root)
# --abbrev-ref was added in git-1.6.3
if rc != 0 or branch_name is None:
raise NotThisMethod("'git rev-parse --abbrev-ref' returned error")
branch_name = branch_name.strip()
if branch_name == "HEAD":
# If we aren't exactly on a branch, pick a branch which represents
# the current commit. If all else fails, we are on a branchless
# commit.
branches, rc = runner(GITS, ["branch", "--contains"], cwd=root)
# --contains was added in git-1.5.4
if rc != 0 or branches is None:
raise NotThisMethod("'git branch --contains' returned error")
branches = branches.split("\n")
# Remove the first line if we're running detached
if "(" in branches[0]:
branches.pop(0)
# Strip off the leading "* " from the list of branches.
branches = [branch[2:] for branch in branches]
if "master" in branches:
branch_name = "master"
elif not branches:
branch_name = None
else:
# Pick the first branch that is returned. Good or bad.
branch_name = branches[0]
pieces["branch"] = branch_name
# parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty]
# TAG might have hyphens.
git_describe = describe_out
# look for -dirty suffix
dirty = git_describe.endswith("-dirty")
pieces["dirty"] = dirty
if dirty:
git_describe = git_describe[: git_describe.rindex("-dirty")]
# now we have TAG-NUM-gHEX or HEX
if "-" in git_describe:
# TAG-NUM-gHEX
mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe)
if not mo:
# unparsable. Maybe git-describe is misbehaving?
pieces["error"] = "unable to parse git-describe output: '%s'" % describe_out
return pieces
# tag
full_tag = mo.group(1)
if not full_tag.startswith(tag_prefix):
if verbose:
fmt = "tag '%s' doesn't start with prefix '%s'"
print(fmt % (full_tag, tag_prefix))
pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % (
full_tag,
tag_prefix,
)
return pieces
pieces["closest-tag"] = full_tag[len(tag_prefix) :]
# distance: number of commits since tag
pieces["distance"] = int(mo.group(2))
# commit: short hex revision ID
pieces["short"] = mo.group(3)
else:
# HEX: no tags
pieces["closest-tag"] = None
out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root)
pieces["distance"] = len(out.split()) # total number of commits
# commit date: see ISO-8601 comment in git_versions_from_keywords()
date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[0].strip()
# Use only the last line. Previous lines may contain GPG signature
# information.
date = date.splitlines()[-1]
pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1)
return pieces
def plus_or_dot(pieces):
"""Return a + if we don't already have one, else return a ."""
if "+" in pieces.get("closest-tag", ""):
return "."
return "+"
def render_pep440(pieces):
"""Build up version string, with post-release "local version identifier".
Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you
get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty
Exceptions:
1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_branch(pieces):
"""TAG[[.dev0]+DISTANCE.gHEX[.dirty]] .
The ".dev0" means not master branch. Note that .dev0 sorts backwards
(a feature branch will appear "older" than the master branch).
Exceptions:
1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0"
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def pep440_split_post(ver):
    """Split pep440 version string at the post-release segment.

    Returns the release segments before the post-release and the
    post-release version number (or None if no post-release segment is present).
    """
    vc = str.split(ver, ".post")
    return vc[0], int(vc[1] or 0) if len(vc) == 2 else None
def render_pep440_pre(pieces):
"""TAG[.postN.devDISTANCE] -- No -dirty.
Exceptions:
1: no tags. 0.post0.devDISTANCE
"""
if pieces["closest-tag"]:
if pieces["distance"]:
# update the post release segment
tag_version, post_version = pep440_split_post(pieces["closest-tag"])
rendered = tag_version
if post_version is not None:
rendered += ".post%d.dev%d" % (post_version + 1, pieces["distance"])
else:
rendered += ".post0.dev%d" % (pieces["distance"])
else:
# no commits, use the tag as the version
rendered = pieces["closest-tag"]
else:
# exception #1
rendered = "0.post0.dev%d" % pieces["distance"]
return rendered
def render_pep440_post(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX] .
The ".dev0" means dirty. Note that .dev0 sorts backwards
(a dirty tree will appear "older" than the corresponding clean one),
but you shouldn't be releasing software with -dirty anyways.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
return rendered
def render_pep440_post_branch(pieces):
"""TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] .
The ".dev0" means not master branch.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += plus_or_dot(pieces)
rendered += "g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["branch"] != "master":
rendered += ".dev0"
rendered += "+g%s" % pieces["short"]
if pieces["dirty"]:
rendered += ".dirty"
return rendered
def render_pep440_old(pieces):
"""TAG[.postDISTANCE[.dev0]] .
The ".dev0" means dirty.
Exceptions:
1: no tags. 0.postDISTANCE[.dev0]
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"] or pieces["dirty"]:
rendered += ".post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
else:
# exception #1
rendered = "0.post%d" % pieces["distance"]
if pieces["dirty"]:
rendered += ".dev0"
return rendered
def render_git_describe(pieces):
"""TAG[-DISTANCE-gHEX][-dirty].
Like 'git describe --tags --dirty --always'.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
if pieces["distance"]:
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render_git_describe_long(pieces):
"""TAG-DISTANCE-gHEX[-dirty].
Like 'git describe --tags --dirty --always --long'.
The distance/hash is unconditional.
Exceptions:
1: no tags. HEX[-dirty] (note: no 'g' prefix)
"""
if pieces["closest-tag"]:
rendered = pieces["closest-tag"]
rendered += "-%d-g%s" % (pieces["distance"], pieces["short"])
else:
# exception #1
rendered = pieces["short"]
if pieces["dirty"]:
rendered += "-dirty"
return rendered
def render(pieces, style):
"""Render the given version pieces into the requested style."""
if pieces["error"]:
return {
"version": "unknown",
"full-revisionid": pieces.get("long"),
"dirty": None,
"error": pieces["error"],
"date": None,
}
if not style or style == "default":
style = "pep440" # the default
if style == "pep440":
rendered = render_pep440(pieces)
elif style == "pep440-branch":
rendered = render_pep440_branch(pieces)
elif style == "pep440-pre":
rendered = render_pep440_pre(pieces)
elif style == "pep440-post":
rendered = render_pep440_post(pieces)
elif style == "pep440-post-branch":
rendered = render_pep440_post_branch(pieces)
elif style == "pep440-old":
rendered = render_pep440_old(pieces)
elif style == "git-describe":
rendered = render_git_describe(pieces)
elif style == "git-describe-long":
rendered = render_git_describe_long(pieces)
else:
raise ValueError("unknown style '%s'" % style)
return {
"version": rendered,
"full-revisionid": pieces["long"],
"dirty": pieces["dirty"],
"error": None,
"date": pieces.get("date"),
}
def get_versions():
"""Get version information or return default if unable to do so."""
# I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have
# __file__, we can work backwards from there to the root. Some
# py2exe/bbfreeze/non-CPython implementations don't do __file__, in which
# case we can only use expanded keywords.
cfg = get_config()
verbose = cfg.verbose
try:
return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, verbose)
except NotThisMethod:
pass
try:
root = os.path.realpath(__file__)
# versionfile_source is the relative path from the top of the source
# tree (where the .git directory might live) to this file. Invert
# this to find the root from __file__.
for _ in cfg.versionfile_source.split("/"):
root = os.path.dirname(root)
except NameError:
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to find root of source tree",
"date": None,
}
try:
pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose)
return render(pieces, cfg.style)
except NotThisMethod:
pass
try:
if cfg.parentdir_prefix:
return versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
except NotThisMethod:
pass
return {
"version": "0+unknown",
"full-revisionid": None,
"dirty": None,
"error": "unable to compute version",
"date": None,
}
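To make the default `pep440` style more concrete, here is a condensed restatement of the rendering rule together with the date normalization used in this file. It is a sketch only: `render_pep440_sketch` and the `pieces` dictionaries are hypothetical illustrations, not the generated Versioneer code or real repository state.

```python
def render_pep440_sketch(pieces):
    """Condensed restatement of the pep440 style: TAG[+DISTANCE.gHEX[.dirty]]."""
    tag = pieces["closest-tag"]
    if tag:
        rendered = tag
        if pieces["distance"] or pieces["dirty"]:
            # Use "." instead of "+" when the tag already has a local version part.
            rendered += "." if "+" in tag else "+"
            rendered += "%d.g%s" % (pieces["distance"], pieces["short"])
            if pieces["dirty"]:
                rendered += ".dirty"
    else:
        # No tag reachable: fall back to the "untagged" form.
        rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"])
        if pieces["dirty"]:
            rendered += ".dirty"
    return rendered


# Hypothetical inputs, chosen only to exercise the three branches.
on_tag = render_pep440_sketch(
    {"closest-tag": "0.6.0", "distance": 0, "short": "b5532dd", "dirty": False}
)
ahead_dirty = render_pep440_sketch(
    {"closest-tag": "0.6.0", "distance": 3, "short": "b5532dd", "dirty": True}
)
untagged = render_pep440_sketch(
    {"closest-tag": None, "distance": 7, "short": "b5532dd", "dirty": False}
)

# Date normalization: turn git's "%ci" output into an ISO-8601 style string.
git_date = "2023-09-08 19:57:37 +0200"
iso_date = git_date.strip().replace(" ", "T", 1).replace(" ", "", 1)

print(on_tag, ahead_dirty, untagged, iso_date)
```

The first `replace` swaps the space between date and time for a `T`, and the second removes the space before the UTC offset.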

235
spyrmsd/constants.py Normal file

@@ -0,0 +1,235 @@
"""
Useful constants.
Notes
-----
Periodic table data (atomic masses and covalent radii) are extracted from
QCElemental_
.. _QCElemental: http://docs.qcarchive.molssi.org/projects/qcelemental/en/latest/
"""
from typing import Dict
connectivity_tolerance: float = 0.4
# Values from QCElemental
anum_to_mass: Dict = {
1: 1.00782503223,
2: 4.00260325413,
3: 7.0160034366,
4: 9.012183065,
5: 11.00930536,
6: 12.0,
7: 14.00307400443,
8: 15.99491461957,
9: 18.99840316273,
10: 19.9924401762,
11: 22.989769282,
12: 23.985041697,
13: 26.98153853,
14: 27.97692653465,
15: 30.97376199842,
16: 31.9720711744,
17: 34.968852682,
18: 39.9623831237,
19: 38.9637064864,
20: 39.962590863,
21: 44.95590828,
22: 47.94794198,
23: 50.94395704,
24: 51.94050623,
25: 54.93804391,
26: 55.93493633,
27: 58.93319429,
28: 57.93534241,
29: 62.92959772,
30: 63.92914201,
31: 68.9255735,
32: 73.921177761,
33: 74.92159457,
34: 79.9165218,
35: 78.9183376,
36: 83.9114977282,
37: 84.9117897379,
38: 87.9056125,
39: 88.9058403,
40: 89.9046977,
41: 92.906373,
42: 97.90540482,
43: 97.9072124,
44: 101.9043441,
45: 102.905498,
46: 105.9034804,
47: 106.9050916,
48: 113.90336509,
49: 114.903878776,
50: 119.90220163,
51: 120.903812,
52: 129.906222748,
53: 126.9044719,
54: 131.9041550856,
55: 132.905451961,
56: 137.905247,
57: 138.9063563,
58: 139.9054431,
59: 140.9076576,
60: 141.907729,
61: 144.9127559,
62: 151.9197397,
63: 152.921238,
64: 157.9241123,
65: 158.9253547,
66: 163.9291819,
67: 164.9303288,
68: 165.9302995,
69: 168.9342179,
70: 173.9388664,
71: 174.9407752,
72: 179.946557,
73: 180.9479958,
74: 183.95093092,
75: 186.9557501,
76: 191.961477,
77: 192.9629216,
78: 194.9647917,
79: 196.96656879,
80: 201.9706434,
81: 204.9744278,
82: 207.9766525,
83: 208.9803991,
84: 208.9824308,
85: 209.9871479,
86: 222.0175782,
87: 223.019736,
88: 226.0254103,
89: 227.0277523,
90: 232.0380558,
91: 231.0358842,
92: 238.0507884,
93: 237.0481736,
94: 244.0642053,
95: 243.0613813,
96: 247.0703541,
97: 247.0703073,
98: 251.0795886,
99: 252.08298,
100: 257.0951061,
101: 258.0984315,
102: 259.10103,
103: 266.11983,
104: 267.12179,
105: 268.12567,
106: 271.13393,
107: 270.13336,
108: 269.13375,
109: 278.15631,
110: 281.16451,
111: 282.16912,
112: 285.17712,
113: 286.18221,
114: 289.19042,
115: 289.19363,
116: 293.20449,
117: 294.21046,
}
# Data from QCElemental (in angstroms)
anum_to_covalentradius: Dict = {
1: 0.31,
2: 0.28,
3: 1.28,
4: 0.96,
5: 0.84,
6: 0.76,
7: 0.71,
8: 0.66,
9: 0.57,
10: 0.58,
11: 1.66,
12: 1.41,
13: 1.21,
14: 1.11,
15: 1.07,
16: 1.05,
17: 1.02,
18: 1.06,
19: 2.03,
20: 1.76,
21: 1.7,
22: 1.6,
23: 1.53,
24: 1.39,
25: 1.61,
26: 1.52,
27: 1.5,
28: 1.24,
29: 1.32,
30: 1.22,
31: 1.22,
32: 1.2,
33: 1.19,
34: 1.2,
35: 1.2,
36: 1.16,
37: 2.2,
38: 1.95,
39: 1.9,
40: 1.75,
41: 1.64,
42: 1.54,
43: 1.47,
44: 1.46,
45: 1.42,
46: 1.39,
47: 1.45,
48: 1.44,
49: 1.42,
50: 1.39,
51: 1.39,
52: 1.38,
53: 1.39,
54: 1.4,
55: 2.44,
56: 2.15,
57: 2.07,
58: 2.04,
59: 2.03,
60: 2.01,
61: 1.99,
62: 1.98,
63: 1.98,
64: 1.96,
65: 1.94,
66: 1.92,
67: 1.92,
68: 1.89,
69: 1.9,
70: 1.87,
71: 1.87,
72: 1.75,
73: 1.7,
74: 1.62,
75: 1.51,
76: 1.44,
77: 1.41,
78: 1.36,
79: 1.36,
80: 1.32,
81: 1.45,
82: 1.46,
83: 1.48,
84: 1.4,
85: 1.5,
86: 1.5,
87: 2.6,
88: 2.21,
89: 2.15,
90: 2.06,
91: 2.0,
92: 1.96,
93: 1.9,
94: 1.87,
95: 1.8,
96: 1.69,
}
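Note that the covalent radii table above stops at atomic number 96 while the mass table runs to 117, so a plain dictionary lookup raises `KeyError` for heavier elements. The sketch below shows one possible way to compute a bonding cutoff with a clearer error; `bond_cutoff` and `radii_excerpt` are hypothetical names introduced for illustration and are not part of the module.

```python
from typing import Dict

# Small excerpt of the covalent radii table (angstroms); the full dict covers Z = 1..96.
radii_excerpt: Dict[int, float] = {1: 0.31, 6: 0.76, 8: 0.66}
connectivity_tolerance: float = 0.4


def bond_cutoff(z1: int, z2: int, radii: Dict[int, float] = radii_excerpt) -> float:
    """Return the distance below which two atoms would be treated as bonded."""
    try:
        return radii[z1] + radii[z2] + connectivity_tolerance
    except KeyError as err:
        raise ValueError(
            f"No covalent radius tabulated for atomic number {err.args[0]}"
        ) from err


ch_cutoff = bond_cutoff(6, 1)  # carbon-hydrogen: 0.76 + 0.31 + 0.4
print(round(ch_cutoff, 2))
```

Whether a missing radius should raise or fall back to a default is a design choice this sketch does not settle.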

79
spyrmsd/due.py Normal file

@@ -0,0 +1,79 @@
# emacs: at the end of the file
# ex: set sts=4 ts=4 sw=4 et:
# ## ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### #
"""
Stub file for a guaranteed safe import of duecredit constructs, even if duecredit
is not available.
To use it, place it into your project codebase to be imported, e.g. copy as
cp stub.py /path/tomodule/module/due.py
Note that it might be better to avoid naming it duecredit.py to avoid shadowing
installed duecredit.
Then use in your code as
from .due import due, Doi, BibTeX, Text
See https://github.com/duecredit/duecredit/blob/master/README.md for examples.
Origin: Originally a part of duecredit
Copyright: 2015-2019 DueCredit developers
License: BSD-2
"""
__version__ = "0.0.8"
class InactiveDueCreditCollector(object):
"""Just a stub at the Collector which would not do anything"""
def _donothing(self, *args, **kwargs):
"""Perform no good and no bad"""
pass
def dcite(self, *args, **kwargs):
"""If I could cite I would"""
def nondecorating_decorator(func):
return func
return nondecorating_decorator
active = False
activate = add = cite = dump = load = _donothing
def __repr__(self):
return self.__class__.__name__ + "()"
def _donothing_func(*args, **kwargs):
"""Perform no good and no bad"""
pass
try:
from duecredit import BibTeX, Doi, Text, Url, due
if "due" in locals() and not hasattr(due, "cite"):
raise RuntimeError("Imported due lacks .cite. DueCredit is now disabled")
except Exception as e:
if not isinstance(e, ImportError):
import logging
logging.getLogger("duecredit").error(
"Failed to import duecredit due to %s" % str(e)
)
# Initiate due stub
due = InactiveDueCreditCollector()
BibTeX = Doi = Url = Text = _donothing_func
# Emacs mode definitions
# Local Variables:
# mode: python
# py-indent-offset: 4
# tab-width: 4
# indent-tabs-mode: nil
# End:
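The point of the stub is that citation hooks become no-ops when duecredit cannot be imported. The sketch below mimics that behaviour with a hypothetical `InactiveCollector` class (a simplified stand-in for the stub above, not the duecredit API) and shows that a function passed through `dcite` comes back unchanged.

```python
class InactiveCollector:
    """Hypothetical stand-in for the stub collector: every hook does nothing."""

    active = False

    def _donothing(self, *args, **kwargs):
        """Accept any arguments and return None."""
        return None

    def dcite(self, *args, **kwargs):
        """Return a decorator that hands the function back untouched."""

        def nondecorating_decorator(func):
            return func

        return nondecorating_decorator

    cite = _donothing


due = InactiveCollector()


def square(x):
    return x * x


# Arguments are ignored by the stub, so any citation metadata is accepted.
decorated = due.dcite("10.1021/ci400534h", description="Hungarian method")(square)
cite_result = due.cite("10.1021/ci400534h", path="spyrmsd.hungarian")

print(decorated is square, cite_result, decorated(4))
```

Because the decorator returns the original function object, decorated code pays no runtime cost when citations are inactive.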

6
spyrmsd/exceptions.py Normal file

@@ -0,0 +1,6 @@
class NonIsomorphicGraphs(ValueError):
    """
    Raised when graphs are not isomorphic
    """

    pass

94
spyrmsd/graph.py Normal file

@@ -0,0 +1,94 @@
try:
from spyrmsd.graphs.gt import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
)
except ImportError:
try:
from spyrmsd.graphs.nx import (
cycle,
graph_from_adjacency_matrix,
lattice,
match_graphs,
num_edges,
num_vertices,
vertex_property,
)
except ImportError:
raise ImportError("graph_tool or NetworkX libraries not found.")
__all__ = [
"graph_from_adjacency_matrix",
"match_graphs",
"vertex_property",
"num_vertices",
"num_edges",
"lattice",
"cycle",
"adjacency_matrix_from_atomic_coordinates",
]
import numpy as np
from spyrmsd import constants
def adjacency_matrix_from_atomic_coordinates(
    aprops: np.ndarray, coordinates: np.ndarray
) -> np.ndarray:
    """
    Compute adjacency matrix from atomic coordinates.

    Parameters
    ----------
    aprops: numpy.ndarray
        Atomic properties (atomic numbers, used to look up covalent radii)
    coordinates: numpy.ndarray
        Atomic coordinates

    Returns
    -------
    numpy.ndarray
        Adjacency matrix

    Notes
    -----
    This function is based on an automatic bond perception algorithm: two
    atoms are considered to be bonded when their distance is smaller than
    the sum of their covalent radii plus a tolerance value. [3]_

    .. warning::
        The automatic bond perception rule implemented in this function
        is very simple and only depends on atomic coordinates. Use
        with care!

    .. [3] E. C. Meng and R. A. Lewis, *Determination of molecular topology and atomic
       hybridization states from heavy atom coordinates*, J. Comp. Chem. **12**, 891-898
       (1991).
    """
    n = len(aprops)
    assert coordinates.shape == (n, 3)
    A = np.zeros((n, n))
    for i in range(n):
        r_i = constants.anum_to_covalentradius[aprops[i]]
        for j in range(i + 1, n):
            r_j = constants.anum_to_covalentradius[aprops[j]]
            distance = np.sqrt(np.sum((coordinates[i] - coordinates[j]) ** 2))
            if distance < (r_i + r_j + constants.connectivity_tolerance):
                A[i, j] = A[j, i] = 1
    return A
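As a worked illustration of the distance rule, the sketch below applies it to an approximate water geometry using only the standard library. The names `adjacency_sketch`, `radii` and the coordinates are assumptions made for this example; the radii and tolerance are copied from the constants shown earlier in this commit.

```python
import math

# Excerpt of the covalent radii table (angstroms) and the connectivity tolerance.
radii = {1: 0.31, 8: 0.66}
tolerance = 0.4


def adjacency_sketch(anums, coords):
    """Pure-Python restatement of the rule: bonded if distance < r_i + r_j + tolerance."""
    n = len(anums)
    A = [[0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            cutoff = radii[anums[i]] + radii[anums[j]] + tolerance
            if math.dist(coords[i], coords[j]) < cutoff:
                A[i][j] = A[j][i] = 1
    return A


# Approximate water geometry (illustrative values, not taken from the repository).
water_anums = [8, 1, 1]
water_coords = [(0.0, 0.0, 0.0), (0.96, 0.0, 0.0), (-0.24, 0.93, 0.0)]
water_adjacency = adjacency_sketch(water_anums, water_coords)

print(water_adjacency)
```

Both O-H distances (about 0.96 angstroms) fall below the 1.37 cutoff, while the H-H distance (about 1.52 angstroms) exceeds the 1.02 cutoff, so only the two O-H bonds are perceived.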

View File

View File

@@ -0,0 +1,8 @@
warn_disconnected_graph: str = "Disconnected graph detected. Is this expected?"
warn_no_atomic_properties: str = (
"No atomic property information stored on nodes. Node matching is not performed..."
)
error_non_isomorphic_graphs: str = (
"Graphs are not isomorphic.\nMake sure graphs have the same connectivity."
)

239
spyrmsd/graphs/gt.py Normal file

@@ -0,0 +1,239 @@
import warnings
from typing import Any, List, Optional, Tuple, Union
import graph_tool as gt
import numpy as np
from graph_tool import generation, topology
from spyrmsd.exceptions import NonIsomorphicGraphs
from spyrmsd.graphs._common import (
error_non_isomorphic_graphs,
warn_disconnected_graph,
warn_no_atomic_properties,
)
# TODO: Implement all graph-tool supported types
def _c_type(numpy_dtype):
"""
Get C type compatible with graph-tool from numpy dtype
Parameters
----------
numpy_dtype: np.dtype
Numpy dtype
Returns
-------
str
C type
Raises
------
ValueError
If the data type is not supported
Notes
-----
https://graph-tool.skewed.de/static/doc/quickstart.html#sec-property-maps
"""
name: str = numpy_dtype.name
if "int" in name:
return "int"
elif "float" in name:
return "double"
elif "str" in name:
return "string"
else:
raise ValueError(f"Unsupported property type: {name}")
def graph_from_adjacency_matrix(
adjacency_matrix: Union[np.ndarray, List[List[int]]],
aprops: Optional[Union[np.ndarray, List[Any]]] = None,
):
"""
Graph from adjacency matrix.
Parameters
----------
adjacency_matrix: Union[np.ndarray, List[List[int]]]
Adjacency matrix
aprops: Union[np.ndarray, List[Any]], optional
Atomic properties
Returns
-------
Graph
Molecular graph
Notes
-----
If atomic properties (e.g. atomic numbers) are passed, they are used as node attributes.
"""
# Get upper triangular adjacency matrix
adj = np.triu(adjacency_matrix)
assert adj.shape[0] == adj.shape[1]
num_vertices = adj.shape[0]
G = gt.Graph(directed=False)
G.add_vertex(n=num_vertices)
G.add_edge_list(np.transpose(adj.nonzero()))
# Check if graph is connected, for warning
cc, _ = topology.label_components(G)
if set(cc.a) != {0}:
warnings.warn(warn_disconnected_graph)
if aprops is not None:
if not isinstance(aprops, np.ndarray):
aprops = np.array(aprops)
assert aprops.shape[0] == num_vertices
ptype: str = _c_type(aprops.dtype) # Get C type
vprop = G.new_vertex_property(ptype, vals=aprops) # Create property map
G.vertex_properties["aprops"] = vprop # Set property map
return G
def match_graphs(G1, G2) -> List[Tuple[List[int], List[int]]]:
"""
Compute graph isomorphisms.
Parameters
----------
G1:
Graph 1
G2:
Graph 2
Returns
-------
List[Tuple[List[int],List[int]]]
All possible mappings between nodes of graph 1 and graph 2 (isomorphisms)
Raises
------
NonIsomorphicGraphs
If the graphs `G1` and `G2` are not isomorphic
"""
try:
maps = topology.subgraph_isomorphism(
G1,
G2,
vertex_label=(
G1.vertex_properties["aprops"],
G2.vertex_properties["aprops"],
),
subgraph=False,
)
except KeyError:
warnings.warn(warn_no_atomic_properties)
maps = topology.subgraph_isomorphism(G1, G2, subgraph=False)
# Check if graphs are actually isomorphic
if len(maps) == 0:
raise NonIsomorphicGraphs(error_non_isomorphic_graphs)
n = num_vertices(G1)
# Extract all isomorphisms in a list
return [(np.arange(0, n, dtype=int), m.a) for m in maps]
def vertex_property(G, vproperty: str, idx: int) -> Any:
"""
Get vertex (node) property from graph
Parameters
----------
G:
Graph
vproperty: str
Vertex property name
idx: int
Vertex index
Returns
-------
Any
Vertex property value
"""
return G.vertex_properties[vproperty][idx]
def num_vertices(G) -> int:
"""
Number of vertices
Parameters
----------
G:
Graph
Returns
-------
int
Number of vertices (nodes)
"""
return G.num_vertices()
def num_edges(G) -> int:
"""
Number of edges
Parameters
----------
G:
Graph
Returns
-------
int
Number of edges
"""
return G.num_edges()
def lattice(n1: int, n2: int):
"""
Build 2D lattice graph
Parameters
----------
n1: int
Number of nodes in dimension 1
n2: int
Number of nodes in dimension 2
Returns
-------
Graph
Lattice graph
"""
return generation.lattice((n1, n2))
def cycle(n):
"""
Build cycle graph
Parameters
----------
n: int
Number of nodes
Returns
-------
Graph
Cycle graph
"""
return generation.circular_graph(n)

192
spyrmsd/graphs/nx.py Normal file

@@ -0,0 +1,192 @@
import warnings
from typing import Any, List, Optional, Tuple, Union
import networkx as nx
import numpy as np
from spyrmsd.exceptions import NonIsomorphicGraphs
from spyrmsd.graphs._common import (
error_non_isomorphic_graphs,
warn_disconnected_graph,
warn_no_atomic_properties,
)
def graph_from_adjacency_matrix(
adjacency_matrix: Union[np.ndarray, List[List[int]]],
aprops: Optional[Union[np.ndarray, List[Any]]] = None,
) -> nx.Graph:
"""
Graph from adjacency matrix.
Parameters
----------
adjacency_matrix: Union[np.ndarray, List[List[int]]]
Adjacency matrix
aprops: Union[np.ndarray, List[Any]], optional
Atomic properties
Returns
-------
Graph
Molecular graph
Notes
-----
If atomic properties (e.g. atomic numbers) are passed, they are used as node attributes.
"""
G = nx.Graph(adjacency_matrix)
if not nx.is_connected(G):
warnings.warn(warn_disconnected_graph)
if aprops is not None:
attributes = {idx: aprop for idx, aprop in enumerate(aprops)}
nx.set_node_attributes(G, attributes, "aprops")
return G
def match_graphs(G1, G2) -> List[Tuple[List[int], List[int]]]:
"""
Compute graph isomorphisms.
Parameters
----------
G1:
Graph 1
G2:
Graph 2
Returns
-------
List[Tuple[List[int],List[int]]]
All possible mappings between nodes of graph 1 and graph 2 (isomorphisms)
Raises
------
NonIsomorphicGraphs
If the graphs `G1` and `G2` are not isomorphic
"""
def match_aprops(node1, node2):
"""
Check if atomic properties for two nodes match.
"""
return node1["aprops"] == node2["aprops"]
if (
nx.get_node_attributes(G1, "aprops") == {}
or nx.get_node_attributes(G2, "aprops") == {}
):
# Nodes without atomic number information
# No node-matching check
node_match = None
warnings.warn(warn_no_atomic_properties)
else:
node_match = match_aprops
GM = nx.algorithms.isomorphism.GraphMatcher(G1, G2, node_match)
# Check if graphs are actually isomorphic
if not GM.is_isomorphic():
raise NonIsomorphicGraphs(error_non_isomorphic_graphs)
return [
(list(isomorphism.keys()), list(isomorphism.values()))
for isomorphism in GM.isomorphisms_iter()
]
def vertex_property(G, vproperty: str, idx: int) -> Any:
"""
Get vertex (node) property from graph
Parameters
----------
G:
Graph
vproperty: str
Vertex property name
idx: int
Vertex index
Returns
-------
Any
Vertex property value
"""
return G.nodes[idx][vproperty]
def num_vertices(G) -> int:
"""
Number of vertices
Parameters
----------
G:
Graph
Returns
-------
int
Number of vertices (nodes)
"""
return G.number_of_nodes()
def num_edges(G) -> int:
"""
Number of edges
Parameters
----------
G:
Graph
Returns
-------
int
Number of edges
"""
return G.number_of_edges()
def lattice(n1, n2):
"""
Build 2D lattice graph
Parameters
----------
n1: int
Number of nodes in dimension 1
n2: int
Number of nodes in dimension 2
Returns
-------
Graph
Lattice graph
"""
return nx.generators.lattice.grid_2d_graph(n1, n2)
def cycle(n):
"""
Build cycle graph
Parameters
----------
n: int
Number of nodes
Returns
-------
Graph
Cycle graph
"""
return nx.cycle_graph(n)
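To make the return format of `match_graphs` concrete, here is a minimal sketch that enumerates isomorphisms by brute force rather than through NetworkX's `GraphMatcher`. The helper `brute_force_isomorphisms` is hypothetical and not part of `spyrmsd`; it tries every permutation, so it assumes very small graphs.

```python
from itertools import permutations

import numpy as np


def brute_force_isomorphisms(am1, am2, props1, props2):
    """Enumerate node mappings that preserve both bonds and atomic properties.

    Hypothetical helper for illustration; cost grows as n!, so tiny graphs only.
    """
    am1, am2 = np.asarray(am1), np.asarray(am2)
    props1, props2 = np.asarray(props1), np.asarray(props2)
    n = len(props1)
    if am1.shape != am2.shape or len(props2) != n:
        return []

    mappings = []
    for perm in permutations(range(n)):
        # Node i of graph 1 is mapped to node p[i] of graph 2
        p = np.array(perm, dtype=int)
        if not np.array_equal(props1, props2[p]):
            continue  # atomic properties do not match under this mapping
        if np.array_equal(am1, am2[np.ix_(p, p)]):
            # Same (idx1, idx2) layout as the tuples returned by match_graphs
            mappings.append((list(range(n)), list(perm)))
    return mappings


triangle = [[0, 1, 1], [1, 0, 1], [1, 1, 0]]
all_carbon = brute_force_isomorphisms(triangle, triangle, [6, 6, 6], [6, 6, 6])
one_oxygen = brute_force_isomorphisms(triangle, triangle, [6, 6, 8], [6, 6, 8])
print(len(all_carbon), len(one_oxygen))
```

For a triangle of identical atoms every permutation is an isomorphism; giving one node a different atomic property leaves only the identity and the swap of the two remaining nodes.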

120
spyrmsd/hungarian.py Normal file

@@ -0,0 +1,120 @@
import numpy as np
import scipy.optimize
import scipy.spatial.distance
from .due import Doi, due
due.cite(
Doi("10.1021/ci400534h"),
path="spyrmsd.hungarian",
description="Hungarian method",
)
def cost_mtx(A: np.ndarray, B: np.ndarray):
"""
Compute the cost matrix for atom-atom assignment.
Parameters
----------
A: numpy.ndarray
Atomic coordinates of molecule A
B: numpy.ndarray
Atomic coordinates of molecule B
Returns
-------
np.ndarray
Cost matrix of squared atomic distances between atoms of
molecules A and B
"""
return scipy.spatial.distance.cdist(A, B, metric="sqeuclidean")
def optimal_assignment(A: np.ndarray, B: np.ndarray):
"""
Solve the optimal assignment problem between atomic coordinates of
molecules A and B.
Parameters
----------
A: numpy.ndarray
Atomic coordinates of molecule A
B: numpy.ndarray
Atomic coordinates of molecule B
Returns
-------
Tuple[float, np.ndarray, np.ndarray]
Cost of the optimal assignment, together with the row and column
indices of said assignment
"""
C = cost_mtx(A, B)
row_idx, col_idx = scipy.optimize.linear_sum_assignment(C)
# Compute assignment cost
cost = C[row_idx, col_idx].sum()
return cost, row_idx, col_idx
def hungarian_rmsd(
A: np.ndarray, B: np.ndarray, apropsA: np.ndarray, apropsB: np.ndarray
) -> float:
"""
Compute the RMSD between molecules A and B after solving the optimal
assignment problem between their atomic coordinates (Hungarian method).
Parameters
----------
A: numpy.ndarray
Atomic coordinates of molecule A
B: numpy.ndarray
Atomic coordinates of molecule B
apropsA: numpy.ndarray
Atomic properties of molecule A
apropsB: numpy.ndarray
Atomic properties of molecule B
Returns
-------
float
RMSD computed with the Hungarian method
Notes
-----
The Hungarian algorithm is used to solve the linear assignment problem, which is
a minimum weight matching of the molecular graphs (bipartite).
The linear assignment problem is solved for every element separately. [1]_
.. [1] W. J. Allen and R. C. Rizzo, *Implementation of the Hungarian Algorithm to
Account for Ligand Symmetry and Similarity in Structure-Based Design*,
J. Chem. Inf. Model. **54**, 518-529 (2014)
"""
assert A.shape == B.shape
assert apropsA.shape == apropsB.shape
elements = set(apropsA)
total_cost: float = 0.0
for t in elements:
apropsA_idx = apropsA == t
apropsB_idx = apropsB == t
assert apropsA_idx.shape == apropsB_idx.shape
cost, row_idx, col_idx = optimal_assignment(
A[apropsA_idx, :], B[apropsB_idx, :]
)
total_cost += cost
N = A.shape[0]
rmsd = np.sqrt(total_cost / N)
return rmsd
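The per-element assignment in `hungarian_rmsd` can be tried on a toy input. The sketch below uses the same SciPy calls as the file (`cdist` and `linear_sum_assignment`); the function name `per_element_assignment_rmsd` and the three-atom example are assumptions made for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def per_element_assignment_rmsd(A, B, elements_a, elements_b):
    """RMSD after an optimal atom-atom assignment solved separately per element."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    elements_a = np.asarray(elements_a)
    elements_b = np.asarray(elements_b)

    total_cost = 0.0
    for element in np.unique(elements_a):
        # Squared distances between all atoms of this element in A and in B
        cost = cdist(
            A[elements_a == element], B[elements_b == element], metric="sqeuclidean"
        )
        rows, cols = linear_sum_assignment(cost)
        total_cost += cost[rows, cols].sum()
    return float(np.sqrt(total_cost / len(A)))


# One carbon and two oxygens; the second copy lists the oxygens in swapped order
coords_a = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
coords_b = coords_a[[0, 2, 1]]
elements = np.array([6, 8, 8])
value = per_element_assignment_rmsd(coords_a, coords_b, elements, elements)
```

Only coordinates and atomic properties enter the cost matrix, so the assignment is free to pair atoms without regard to the bonds between them.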

87
spyrmsd/io.py Normal file

@@ -0,0 +1,87 @@
try:
from spyrmsd.optional.obabel import (
adjacency_matrix,
bonds,
load,
loadall,
numatoms,
numbonds,
to_molecule,
)
except ImportError:
try:
from spyrmsd.optional.rdkit import (
adjacency_matrix,
bonds,
load,
loadall,
numatoms,
numbonds,
to_molecule,
)
except ImportError:
# Use sPyRMSD as standalone library
__all__ = []
else:
# Avoid flake8 complaint "imported but unused"
__all__ = [
"load",
"loadall",
"adjacency_matrix",
"to_molecule",
"numatoms",
"numbonds",
"bonds",
]
else:
# Avoid flake8 complaint "imported but unused"
__all__ = [
"load",
"loadall",
"adjacency_matrix",
"to_molecule",
"numatoms",
"numbonds",
"bonds",
]
def loadmol(fname: str, adjacency: bool = True):
"""
Load molecule from file.
Parameters
----------
fname: str
File name
adjacency: bool, optional
Flag to compute the adjacency matrix
Returns
-------
molecule.Molecule
`spyrmsd` molecule
"""
mol = load(fname)
return to_molecule(mol, adjacency=adjacency)
def loadallmols(fname: str, adjacency: bool = True):
"""
Load molecules from file.
Parameters
----------
fname: str
File name
adjacency: bool, optional
Flag to compute the adjacency matrix
Returns
-------
List[molecule.Molecule]
`spyrmsd` molecules
"""
mols = loadall(fname)
return [to_molecule(mol, adjacency=adjacency) for mol in mols]
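The nested `try`/`except ImportError` blocks at the top of this file pick whichever optional backend is installed, preferring OpenBabel over RDKit. A minimal sketch of the same preference order, using the standard library's `importlib.util.find_spec`; `first_available` is a hypothetical helper, not a `spyrmsd` function, and unlike the original it only checks that a module can be found without importing it.

```python
from importlib.util import find_spec


def first_available(candidates):
    """Return the first top-level module name that can be found, or None."""
    for name in candidates:
        if find_spec(name) is not None:
            return name
    return None


# Same preference order as the import fallback: OpenBabel first, then RDKit.
# None corresponds to the standalone case where __all__ is left empty.
backend = first_available(["openbabel", "rdkit"])
print(backend)
```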

266
spyrmsd/molecule.py Normal file

@@ -0,0 +1,266 @@
import warnings
from typing import List, Optional, Union
import numpy as np
from spyrmsd import constants, graph, utils
class Molecule:
def __init__(
self,
atomicnums: Union[np.ndarray, List[int]],
coordinates: Union[np.ndarray, List[List[float]]],
adjacency_matrix: Optional[Union[np.ndarray, List[List[int]]]] = None,
) -> None:
"""
Molecule initialisation.
Parameters
----------
atomicnums: Union[np.ndarray, List[int]]
Atomic numbers
coordinates: Union[np.ndarray, List[List[float]]]
Atomic coordinates
adjacency_matrix: Union[np.ndarray, List[List[int]]], optional
Molecular graph adjacency matrix
Notes
-----
A molecule is built from atomic numbers and atomic coordinates only.
Optionally, a good representation of the molecular graph (obtained with
OpenBabel or RDKit) can be stored as adjacency matrix.
"""
atomicnums = np.asarray(atomicnums, dtype=int)
coordinates = np.asarray(coordinates, dtype=float)
self.natoms: int = len(atomicnums)
assert atomicnums.shape == (self.natoms,)
assert coordinates.shape == (self.natoms, 3)
self.atomicnums = atomicnums
self.coordinates = coordinates
self.stripped: bool = bool(np.all(atomicnums != 1))
if adjacency_matrix is not None:
self.adjacency_matrix: np.ndarray = np.asarray(adjacency_matrix, dtype=int)
# Molecular graph
self.G = None
self.masses: Optional[List[float]] = None
@classmethod
def from_obabel(cls, obmol, adjacency: bool = True):
"""
Constructor from OpenBabel molecule.
Parameters
----------
obmol:
OpenBabel molecule
adjacency:
Flag to compute the adjacency matrix
Returns
-------
spyrmsd.molecule.Molecule
:code:`spyrmsd` Molecule
"""
# TODO: Check if OpenBabel is available?
from spyrmsd.optional import obabel as ob
return ob.to_molecule(obmol, adjacency=adjacency)
@classmethod
def from_rdkit(cls, rdmol, adjacency: bool = True):
"""
Constructor from RDKit molecule.
Parameters
----------
rdmol:
RDKit molecule
adjacency:
Flag to compute the adjacency matrix
Returns
-------
spyrmsd.molecule.Molecule
:code:`spyrmsd` Molecule
"""
# TODO: Check if RDKit is available?
from spyrmsd.optional import rdkit as rd
return rd.to_molecule(rdmol, adjacency=adjacency)
def translate(self, vector: Union[np.ndarray, List[float]]) -> None:
"""
Translate molecule.
Parameters
----------
vector: np.ndarray
Translation vector (in 3D)
"""
assert len(vector) == 3
vector = np.asarray(vector)
self.coordinates += vector
def rotate(
self, angle: float, axis: Union[np.ndarray, List[float]], units: str = "rad"
) -> None:
"""
Rotate molecule.
Parameters
----------
angle: float
Rotation angle
axis: np.ndarray
Axis of rotation (in 3D)
units: {"rad", "deg"}
Units of the angle (radians `rad` or degrees `deg`)
"""
axis = np.asarray(axis)
self.coordinates = utils.rotate(self.coordinates, angle, axis, units)
def center_of_mass(self) -> np.ndarray:
"""
Center of mass.
Returns
-------
np.ndarray
Center of mass
Notes
-----
Atomic masses are cached.
"""
# Get masses and cache them
if self.masses is None:
self.masses = [constants.anum_to_mass[anum] for anum in self.atomicnums]
return np.average(self.coordinates, axis=0, weights=self.masses)
def center_of_geometry(self) -> np.ndarray:
"""
Center of geometry.
Returns
-------
np.ndarray
Center of geometry
"""
return utils.center_of_geometry(self.coordinates)
# TODO: Change name (to stripH)
def strip(self) -> None:
"""
Strip hydrogen atoms.
"""
if not self.stripped:
idx = self.atomicnums != 1 # Non-hydrogen atoms
# Strip
self.atomicnums = self.atomicnums[idx]
self.coordinates = self.coordinates[idx, :]
# Update number of atoms
self.natoms = len(self.atomicnums)
# Update adjacency matrix
if getattr(self, "adjacency_matrix", None) is not None:
self.adjacency_matrix = self.adjacency_matrix[np.ix_(idx, idx)]
# Reset molecular graph when stripping
self.G = None
self.stripped = True
def to_graph(self):
"""
Convert molecule to graph.
Returns
-------
Graph
Molecular graph.
Notes
-----
If the molecule does not have an associated adjacency matrix, a simple
bond perception is used.
The molecular graph is cached.
"""
if self.G is None:
try:
self.G = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)
except AttributeError:
warnings.warn(
"Molecule was not initialized with an adjacency matrix. "
+ "Using bond perception..."
)
# Automatic bond perception (with very simple rule)
self.adjacency_matrix = graph.adjacency_matrix_from_atomic_coordinates(
self.atomicnums, self.coordinates
)
self.G = graph.graph_from_adjacency_matrix(
self.adjacency_matrix, self.atomicnums
)
return self.G
def __len__(self) -> int:
"""
Molecule size.
Returns
-------
int
Number of atoms within the molecule
"""
return self.natoms
def coords_from_molecule(mol: Molecule, center: bool = False) -> np.ndarray:
"""
Atomic coordinates from molecule.
Parameters
----------
mol: molecule.Molecule
Molecule
center: bool
Center flag
Returns
-------
np.ndarray
Atomic coordinates (possibly centred)
Notes
-----
Atomic coordinates are centred according to the center of geometry, not the center
of mass.
"""
if center:
coords = mol.coordinates - mol.center_of_geometry()
else:
coords = mol.coordinates
return coords
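`Molecule.strip` removes hydrogens from the atomic numbers, the coordinates and the adjacency matrix with a single boolean mask. The sketch below repeats that step on plain arrays; `strip_hydrogens` is a hypothetical stand-alone function, and the three-atom O-H/O-C fragment is made up for the example.

```python
import numpy as np


def strip_hydrogens(atomicnums, coordinates, adjacency):
    """Drop hydrogen atoms consistently from all three per-atom arrays."""
    atomicnums = np.asarray(atomicnums, dtype=int)
    coordinates = np.asarray(coordinates, dtype=float)
    adjacency = np.asarray(adjacency, dtype=int)

    heavy = atomicnums != 1  # boolean mask of non-hydrogen atoms
    # np.ix_ applies the mask to rows and columns of the adjacency matrix
    return atomicnums[heavy], coordinates[heavy], adjacency[np.ix_(heavy, heavy)]


nums, xyz, adj = strip_hydrogens(
    [8, 1, 6],
    [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [1.4, 0.0, 0.0]],
    [[0, 1, 1], [1, 0, 0], [1, 0, 0]],
)
```

Indexing rows and columns together keeps the adjacency matrix square and aligned with the remaining atoms, which is why the graph is rebuilt from it after stripping.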


@@ -0,0 +1,7 @@
# Handle versioneer
from spyrmsd._version import get_versions
versions = get_versions()
__version__ = versions["version"]
__git_revision__ = versions["full-revisionid"]
del get_versions, versions

181
spyrmsd/optional/obabel.py Normal file

@@ -0,0 +1,181 @@
from typing import List, Optional, Tuple
import numpy as np
from openbabel import openbabel as ob
from openbabel import pybel
from spyrmsd import molecule, utils
def load(fname: str):
"""
Load molecule from file.
Parameters
----------
fname: str
File name
Returns
-------
Molecule
"""
fmt = utils.molformat(fname)
obmol = next(pybel.readfile(fmt, fname))
return obmol
def loadall(fname: str):
"""
Load molecules from file.
Parameters
----------
fname: str
File name
Returns
-------
List of molecules
"""
fmt = utils.molformat(fname)
obmols = [obmol for obmol in pybel.readfile(fmt, fname)]
# FIXME: Special handling for multi-model PDB files
# See OpenBabel Issue #2097
if fmt == "pdb":
if len(obmols) > 1: # Multi-model PDB file
obmols = obmols[:-1]
return obmols
def adjacency_matrix(mol) -> np.ndarray:
"""
Adjacency matrix from OpenBabel molecule.
Parameters
----------
mol:
Molecule
Returns
-------
np.ndarray
Adjacency matrix of the molecule
"""
n = len(mol.atoms)
# Pre-allocate memory for the adjacency matrix
A = np.zeros((n, n), dtype=int)
# Loop over molecular bonds
for bond in ob.OBMolBondIter(mol.OBMol):
# Bonds are 1-indexed
i: int = bond.GetBeginAtomIdx() - 1
j: int = bond.GetEndAtomIdx() - 1
# A molecular graph is undirected
A[i, j] = A[j, i] = 1
return A
def to_molecule(mol, adjacency: bool = True):
"""
Transform molecule to `spyrmsd` molecule.
Parameters
----------
mol:
Molecule
adjacency: boolean, optional
Flag to decide whether to build the adjacency matrix from molecule
Returns
-------
spyrmsd.molecule.Molecule
`spyrmsd` molecule
"""
n = len(mol.atoms)
atomicnums = np.zeros(n, dtype=int)
coordinates = np.zeros((n, 3))
for i, atom in enumerate(mol.atoms):
atomicnums[i] = atom.atomicnum
coordinates[i] = atom.coords
A: Optional[np.ndarray] = adjacency_matrix(mol) if adjacency else None
return molecule.Molecule(atomicnums, coordinates, A)
def numatoms(mol) -> int:
"""
Number of atoms.
Parameters
----------
mol:
Molecule
Returns
-------
int
Number of atoms
"""
return mol.OBMol.NumAtoms()
def numbonds(mol) -> int:
"""
Number of bonds.
Parameters
----------
mol:
Molecule
Returns
-------
int
Number of bonds
"""
return mol.OBMol.NumBonds()
def bonds(mol) -> List[Tuple[int, int]]:
"""
List of bonds
Parameters
----------
mol:
Molecule
Returns
-------
List[Tuple[int, int]]
List of bonds
Notes
-----
A bond is defined by a tuple of (0-based) indices of two atoms.
"""
b = []
for bond in ob.OBMolBondIter(mol.OBMol):
i = bond.GetBeginAtomIdx() - 1
j = bond.GetEndAtomIdx() - 1
b.append((i, j))
return b
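OpenBabel reports 1-based atom indices for bonds, which `adjacency_matrix` and `bonds` above shift to 0-based. The same bookkeeping on plain tuples, as a sketch; `adjacency_from_bonds` and its input format are assumptions made for illustration and do not require OpenBabel.

```python
import numpy as np


def adjacency_from_bonds(n_atoms, bonds_one_based):
    """Symmetric 0/1 adjacency matrix from 1-based (begin, end) atom index pairs."""
    A = np.zeros((n_atoms, n_atoms), dtype=int)
    for begin, end in bonds_one_based:
        i, j = begin - 1, end - 1  # convert to 0-based indices
        A[i, j] = A[j, i] = 1  # a molecular graph is undirected
    return A


# Three atoms in a chain: 1-2 and 2-3
A = adjacency_from_bonds(3, [(1, 2), (2, 3)])
```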

227
spyrmsd/optional/rdkit.py Normal file

@@ -0,0 +1,227 @@
import gzip
import os
from typing import List, Optional, Tuple
import numpy as np
import rdkit.Chem as Chem
from spyrmsd import molecule, utils
def _load_block_gzipped(loader, fname: str):
"""
Load gzipped files using MolBlocks.
Parameters
----------
loader:
RDKit MolBlock loader (MolFromMol2Block, MolFromPDBBlock, ...)
fname: str
File name
Returns
-------
Molecule
"""
with gzip.open(fname, "r") as fgz:
content = fgz.read()
rdmol = loader(content, removeHs=False)
return rdmol
def load(fname: str):
"""
Load molecule from file.
Parameters
----------
fname: str
File name
Returns
-------
Molecule
"""
gzipped = os.path.splitext(fname)[-1] == ".gz"
fmt = utils.molformat(fname)
if fmt == "mol2":
if not gzipped:
rdmol = Chem.MolFromMol2File(fname, removeHs=False)
else:
rdmol = _load_block_gzipped(Chem.MolFromMol2Block, fname)
elif fmt == "sdf":
if not gzipped:
rdmol = next(Chem.SDMolSupplier(fname, removeHs=False))
else:
with gzip.open(fname, "r") as fgz:
rdmol = next(Chem.ForwardSDMolSupplier(fgz, removeHs=False))
elif fmt == "pdb":
if not gzipped:
rdmol = Chem.MolFromPDBFile(fname, removeHs=False)
else:
rdmol = _load_block_gzipped(Chem.MolFromPDBBlock, fname)
else:
raise NotImplementedError
return rdmol
def loadall(fname: str):
"""
Load molecules from file.
Parameters
----------
fname: str
File name
Returns
-------
List of molecules
"""
gzipped = os.path.splitext(fname)[-1] == ".gz"
fmt = utils.molformat(fname)
if fmt == "mol2":
raise NotImplementedError # See RDKit Issue #415
elif fmt == "sdf":
if not gzipped:
rdmols = Chem.SDMolSupplier(fname, removeHs=False)
mols = [rdmol for rdmol in rdmols]
else:
with gzip.open(fname, "r") as fgz:
rdmols = Chem.ForwardSDMolSupplier(fgz, removeHs=False)
# Load all molecules before closing file
mols = [rdmol for rdmol in rdmols]
elif fmt == "pdb":
# TODO: Implement
raise NotImplementedError
else:
raise NotImplementedError
return mols
def adjacency_matrix(mol) -> np.ndarray:
"""
Adjacency matrix from RDKit molecule.
Parameters
----------
mol:
Molecule
Returns
-------
np.ndarray
Adjacency matrix of the molecule
"""
return Chem.rdmolops.GetAdjacencyMatrix(mol)
def to_molecule(mol, adjacency: bool = True):
"""
Transform molecule to `spyrmsd` molecule.
Parameters
----------
mol:
Molecule
adjacency: boolean, optional
Flag to decide whether to build the adjacency matrix from molecule
Returns
-------
spyrmsd.molecule.Molecule
`spyrmsd` molecule
"""
if mol is None:
# Propagate RDKit parsing failure
return None
atoms = mol.GetAtoms()
n = len(atoms)
atomicnums = np.zeros(n, dtype=int)
coordinates = np.zeros((n, 3))
conformer = mol.GetConformer()
for i, atom in enumerate(atoms):
atomicnums[i] = atom.GetAtomicNum()
pos = conformer.GetAtomPosition(i)
coordinates[i] = np.array([pos.x, pos.y, pos.z])
A: Optional[np.ndarray] = adjacency_matrix(mol) if adjacency else None
return molecule.Molecule(atomicnums, coordinates, A)
def numatoms(mol) -> int:
"""
Number of atoms.
Parameters
----------
mol:
Molecule
Returns
-------
int
Number of atoms
"""
return mol.GetNumAtoms()
def numbonds(mol) -> int:
"""
Number of bonds.
Parameters
----------
mol:
Molecule
Returns
-------
int
Number of bonds
"""
return mol.GetNumBonds()
def bonds(mol) -> List[Tuple[int, int]]:
"""
List of bonds.
Parameters
----------
mol:
Molecule
Returns
-------
List[Tuple[int, int]]
List of bonds
Notes
-----
A bond is defined by a tuple of (0-based) indices of two atoms.
"""
b = []
for bond in mol.GetBonds():
b.append((bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()))
return b

288
spyrmsd/qcp.py Normal file

@@ -0,0 +1,288 @@
from typing import Tuple
import numpy as np
from scipy import optimize
from .due import Doi, due
due.cite(
Doi("10.1107/S0108767305015266"),
path="spyrmsd.qcp",
description="QCP method",
)
def M_mtx(A: np.ndarray, B: np.ndarray) -> np.ndarray:
"""
Compute inner product between coordinate matrices.
Parameters
----------
A: numpy.ndarray
Coordinates `A`
B: numpy.ndarray
Coordinates `B`
Returns
-------
numpy.ndarray
Inner product of the coordinate matrices `A` and `B`
Notes
-----
The inner product of the coordinate matrices `A` and `B` corresponds to the matrix
:math:`\\mathbf{M}`. [1]_
If :math:`S_{xy}` is defined as
.. math:: S_{xy} = \\sum_i^N x_{B,i} y_{A,i}
then :math:`\\mathbf{M}` is the :math:`3\\times 3` matrix given by
.. math::
\\begin{pmatrix}
S_{xx} & S_{xy} & S_{xz} \\\\
S_{yx} & S_{yy} & S_{yz} \\\\
S_{zx} & S_{zy} & S_{zz} \\\\
\\end{pmatrix}
.. [1] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
"""
return B.T @ A
def K_mtx(M):
"""
Compute symmetric key matrix.
Parameters
----------
M : numpy.ndarray
Inner product between coordinate matrices
Returns
-------
numpy.ndarray
Symmetric key matrix
Notes
-----
The symmetric key matrix corresponds to the matrix :math:`\\mathbf{K}`. [2]_
If :math:`S_{xy}` is defined as
.. math:: S_{xy} = \\sum_i^N x_{B,i} y_{A,i}
then :math:`\\mathbf{K}` is the :math:`4\\times 4` symmetric matrix given by
.. math::
\\begin{pmatrix}
S_{xx} + S_{yy} + S_{zz} & S_{yz} - S_{zy} & S_{zx} - S_{xz} & S_{xy} - S_{yx} \\\\
& S_{xx} - S_{yy} - S_{zz} & S_{xy} + S_{yx} & S_{zx} + S_{xz}\\\\
& & -S_{xx} + S_{yy} - S_{zz} & S_{yz} - S_{zy} \\\\
& & & -S_{xx} - S_{yy} + S_{zz} \\\\
\\end{pmatrix}
.. [2] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
"""
assert M.shape == (3, 3)
S_xx = M[0, 0]
S_xy = M[0, 1]
S_xz = M[0, 2]
S_yx = M[1, 0]
S_yy = M[1, 1]
S_yz = M[1, 2]
S_zx = M[2, 0]
S_zy = M[2, 1]
S_zz = M[2, 2]
# p = plus, m = minus
S_xx_yy_zz_ppp = S_xx + S_yy + S_zz
S_yz_zy_pm = S_yz - S_zy
S_zx_xz_pm = S_zx - S_xz
S_xy_yx_pm = S_xy - S_yx
S_xx_yy_zz_pmm = S_xx - S_yy - S_zz
S_xy_yx_pp = S_xy + S_yx
S_zx_xz_pp = S_zx + S_xz
S_xx_yy_zz_mpm = -S_xx + S_yy - S_zz
S_yz_zy_pp = S_yz + S_zy
S_xx_yy_zz_mmp = -S_xx - S_yy + S_zz
return np.array(
[
[S_xx_yy_zz_ppp, S_yz_zy_pm, S_zx_xz_pm, S_xy_yx_pm],
[S_yz_zy_pm, S_xx_yy_zz_pmm, S_xy_yx_pp, S_zx_xz_pp],
[S_zx_xz_pm, S_xy_yx_pp, S_xx_yy_zz_mpm, S_yz_zy_pp],
[S_xy_yx_pm, S_zx_xz_pp, S_yz_zy_pp, S_xx_yy_zz_mmp],
]
)
def coefficients(M: np.ndarray, K: np.ndarray) -> Tuple[float, float, float]:
"""
Compute quaternion polynomial coefficients.
Parameters
----------
M : numpy.ndarray
Inner product between coordinate matrices
K: numpy.ndarray
Symmetric key matrix
Returns
-------
Tuple[float, float, float]
Quaternion polynomial coefficients
Notes
-----
Only the :math:`\\mathbf{M}`- and :math:`\\mathbf{K}`-dependent coefficients
are returned; :math:`c_4=1` and :math:`c_3=0` are not.
The :math:`\\mathbf{M}`- and :math:`\\mathbf{K}`-dependent quaternion polynomial
coefficients are given by
.. math:: c_2 = -2 \\text{ tr}\\left(\\mathbf{M}^T\\mathbf{M}\\right)
.. math:: c_1 = -8 \\text{ det}(\\mathbf{M})
.. math:: c_0 = \\text{ det}(\\mathbf{K})
"""
c2 = -2 * np.trace(M.T @ M)
c1 = -8 * np.linalg.det(M) # TODO: Slow?
c0 = np.linalg.det(K) # TODO: Slow?
return c2, c1, c0
def lambda_max(Ga: float, Gb: float, c2: float, c1: float, c0: float) -> float:
"""
Find largest root of the quaternion polynomial.
Parameters
----------
Ga: float
Inner product of structure A
Gb: float
Inner product of structure B
c2: float
Coefficient :math:`c_2` of the quaternion polynomial
c1: float
Coefficient :math:`c_1` of the quaternion polynomial
c0: float
Coefficient :math:`c_0` of the quaternion polynomial
Returns
-------
float
Largest root of the quaternion polynomial (:math:`\\lambda_\\text{max}`)
"""
def P(x):
"""
Quaternion polynomial
"""
return x**4 + c2 * x**2 + c1 * x + c0
def dP(x):
"""
First derivative of the quaternion polynomial
"""
return 4 * x**3 + 2 * c2 * x + c1
x0 = (Ga + Gb) * 0.5
lmax = optimize.newton(P, x0, fprime=dP)
return lmax
def _lambda_max_eig(K: np.ndarray) -> float:
"""
Find largest eigenvalue of :math:`K`.
Parameters
----------
K: np.ndarray
Symmetric key matrix
Returns
-------
float
Largest eigenvalue of :math:`K`, :math:`\\lambda_\\text{max}`
"""
e, _ = np.linalg.eig(K)
return max(e)
def qcp_rmsd(A: np.ndarray, B: np.ndarray, atol: float = 1e-9) -> float:
"""
Compute RMSD using the quaternion polynomial method.
Parameters
----------
A: numpy.ndarray
Coordinates of structure A
B: numpy.ndarray
Coordinates of structure B
atol: float
Absolute tolerance parameter (see notes)
Returns
-------
float
RMSD between structures `A` and `B`
Raises
------
AssertionError
If the shape of structures `A` and `B` is different
Notes
-----
If the structures `A` and `B` can be superimposed exactly (i.e. they differ only
by center-of-mass translations and rotations), we have
.. math:: G_a + G_b = 2 \\lambda_\\text{max}
This means that :math:`s = G_a + G_b - 2 \\lambda_\\text{max}` can become
negative because of numerical errors and therefore :math:`\\sqrt{s}` fails.
In order to avoid this problem, the final RMSD is set to :math:`0`
if :math:`|s| < atol`.
"""
assert A.shape == B.shape
N = A.shape[0]
Ga = np.trace(A.T @ A)
Gb = np.trace(B.T @ B)
M = M_mtx(A, B)
K = K_mtx(M)
c2, c1, c0 = coefficients(M, K)
try:
# Fast calculation of the largest eigenvalue of K as root of the characteristic
# polynomial.
l_max = lambda_max(Ga, Gb, c2, c1, c0)
except RuntimeError: # Newton method fails to converge; see GitHub Issue #35
# Fallback to (slower) explicit calculation of the largest eigenvalue of K
l_max = _lambda_max_eig(K)
s = Ga + Gb - 2 * l_max
if abs(s) < atol: # Avoid numerical errors when Ga + Gb = 2 * l_max
rmsd = 0.0
else:
rmsd = np.sqrt(s / N)
return rmsd
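The two routes to :math:`\lambda_\text{max}` used in `qcp_rmsd` (Newton iteration on the quaternion polynomial, with an eigenvalue fallback) can be compared in a NumPy-only sketch. The names `key_matrix`, `largest_root_newton` and `minimum_rmsd` are hypothetical; a hand-written Newton loop stands in for `scipy.optimize.newton`, and `numpy.linalg.eigvalsh` is used for the symmetric matrix instead of `eig`.

```python
import numpy as np


def key_matrix(M):
    """4x4 symmetric key matrix K built from the 3x3 inner product M."""
    (Sxx, Sxy, Sxz), (Syx, Syy, Syz), (Szx, Szy, Szz) = M
    return np.array(
        [
            [Sxx + Syy + Szz, Syz - Szy, Szx - Sxz, Sxy - Syx],
            [Syz - Szy, Sxx - Syy - Szz, Sxy + Syx, Szx + Sxz],
            [Szx - Sxz, Sxy + Syx, -Sxx + Syy - Szz, Syz + Szy],
            [Sxy - Syx, Szx + Sxz, Syz + Szy, -Sxx - Syy + Szz],
        ]
    )


def largest_root_newton(Ga, Gb, M, K, tol=1e-12, max_iter=50):
    """Largest root of x^4 + c2 x^2 + c1 x + c0, starting from (Ga + Gb) / 2."""
    c2 = -2.0 * np.trace(M.T @ M)
    c1 = -8.0 * np.linalg.det(M)
    c0 = np.linalg.det(K)
    x = 0.5 * (Ga + Gb)
    for _ in range(max_iter):
        p = x**4 + c2 * x**2 + c1 * x + c0
        dp = 4.0 * x**3 + 2.0 * c2 * x + c1
        if dp == 0.0:
            break
        step = p / dp
        x -= step
        if abs(step) <= tol * max(1.0, abs(x)):
            return x
    raise RuntimeError("Newton iteration did not converge")


def minimum_rmsd(A, B, use_newton=True):
    """Minimum RMSD over rigid superpositions of B onto A."""
    A = np.asarray(A, dtype=float)
    B = np.asarray(B, dtype=float)
    A = A - A.mean(axis=0)  # centre both structures at the origin
    B = B - B.mean(axis=0)
    Ga, Gb = np.trace(A.T @ A), np.trace(B.T @ B)
    M = B.T @ A
    K = key_matrix(M)
    if use_newton:
        try:
            lmax = largest_root_newton(Ga, Gb, M, K)
        except RuntimeError:
            lmax = np.linalg.eigvalsh(K)[-1]  # explicit eigenvalue fallback
    else:
        lmax = np.linalg.eigvalsh(K)[-1]
    s = max(Ga + Gb - 2.0 * lmax, 0.0)  # guard against tiny negative values
    return float(np.sqrt(s / len(A)))


A = np.array([[0.0, 0.0, 0.0], [1.5, 0.0, 0.0], [1.5, 1.2, 0.0], [0.0, 1.2, 0.8]])
c, s = np.cos(0.7), np.sin(0.7)
R = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
B = A @ R.T + np.array([3.0, -2.0, 1.0])  # rigidly moved copy of A
C = B + np.array([[0.1, 0, 0], [0, -0.1, 0], [0, 0, 0.1], [0, 0, 0]])  # distorted copy
```

Clamping `s` at zero plays the role of the `atol` check in `qcp_rmsd`: for structures that superimpose exactly, rounding can push `Ga + Gb - 2 * l_max` slightly below zero.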

382
spyrmsd/rmsd.py Normal file

@@ -0,0 +1,382 @@
from typing import Any, List, Optional, Tuple, Union
import numpy as np
from spyrmsd import graph, hungarian, molecule, qcp, utils
def rmsd(
coords1: np.ndarray,
coords2: np.ndarray,
atomicn1: np.ndarray,
atomicn2: np.ndarray,
center: bool = False,
minimize: bool = False,
atol: float = 1e-9,
) -> float:
"""
Compute RMSD
Parameters
----------
coords1: np.ndarray
Coordinates of molecule 1
coords2: np.ndarray
Coordinates of molecule 2
atomicn1: np.ndarray
Atomic numbers for molecule 1
atomicn2: np.ndarray
Atomic numbers for molecule 2
center: bool
Center molecules at origin
minimize: bool
Compute minimum RMSD (with QCP method)
atol: float
Absolute tolerance parameter for QCP method (see :func:`qcp_rmsd`)
Returns
-------
float
RMSD
Notes
-----
When `minimize=True`, the QCP method is used. [1]_ The molecules are
centred at the origin according to the center of geometry and superimposed
in order to minimize the RMSD.
.. [1] D. L. Theobald, *Rapid calculation of RMSDs using a quaternion-based
characteristic polynomial*, Acta Crys. A **61**, 478-480 (2005).
"""
assert np.all(atomicn1 == atomicn2)
assert coords1.shape == coords2.shape
# Center coordinates if required
c1 = utils.center(coords1) if center or minimize else coords1
c2 = utils.center(coords2) if center or minimize else coords2
if minimize:
rmsd = qcp.qcp_rmsd(c1, c2, atol)
else:
n = coords1.shape[0]
rmsd = np.sqrt(np.sum((c1 - c2) ** 2) / n)
return rmsd
def hrmsd(
coords1: np.ndarray,
coords2: np.ndarray,
atomicn1: np.ndarray,
atomicn2: np.ndarray,
center=False,
):
"""
Compute minimum RMSD using the Hungarian method.
Parameters
----------
coords1: np.ndarray
Coordinates of molecule 1
coords2: np.ndarray
Coordinates of molecule 2
atomicn1: np.ndarray
Atomic numbers for molecule 1
atomicn2: np.ndarray
Atomic numbers for molecule 2
center: bool
Center molecules at origin
Returns
-------
float
Minimum RMSD (after assignment)
Notes
-----
The Hungarian algorithm is used to solve the linear assignment problem, which is
a minimum weight matching of the molecular graphs (bipartite). [2]_
The linear assignment problem is solved for every element separately.
.. [2] W. J. Allen and R. C. Rizzo, *Implementation of the Hungarian Algorithm to
Account for Ligand Symmetry and Similarity in Structure-Based Design*,
J. Chem. Inf. Model. **54**, 518-529 (2014)
"""
assert atomicn1.shape == atomicn2.shape
assert coords1.shape == coords2.shape
# Center coordinates if required
c1 = utils.center(coords1) if center else coords1
c2 = utils.center(coords2) if center else coords2
return hungarian.hungarian_rmsd(c1, c2, atomicn1, atomicn2)
def _rmsd_isomorphic_core(
coords1: np.ndarray,
coords2: np.ndarray,
aprops1: np.ndarray,
aprops2: np.ndarray,
am1: np.ndarray,
am2: np.ndarray,
center: bool = False,
minimize: bool = False,
isomorphisms: Optional[List[Tuple[List[int], List[int]]]] = None,
atol: float = 1e-9,
) -> Tuple[float, List[Tuple[List[int], List[int]]], Tuple[List[int], List[int]]]:
"""
Compute RMSD using graph isomorphism.
Parameters
----------
coords1: np.ndarray
Coordinates of molecule 1
coords2: np.ndarray
Coordinates of molecule 2
aprops1: np.ndarray
Atomic properties for molecule 1
aprops2: np.ndarray
Atomic properties for molecule 2
am1: np.ndarray
Adjacency matrix for molecule 1
am2: np.ndarray
Adjacency matrix for molecule 2
center: bool
Centering flag
minimize: bool
Compute minimized RMSD
isomorphisms: Optional[List[Tuple[List[int], List[int]]]]
Previously computed graph isomorphisms
atol: float
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
Returns
-------
Tuple[float, List[Tuple[List[int], List[int]]], Tuple[List[int], List[int]]]
RMSD (after graph matching), all graph isomorphisms, and the isomorphism
that gives the lowest RMSD
"""
assert coords1.shape == coords2.shape
n = coords1.shape[0]
# Center coordinates if required
c1 = utils.center(coords1) if center or minimize else coords1
c2 = utils.center(coords2) if center or minimize else coords2
# No cached isomorphisms
if isomorphisms is None:
# Convert molecules to graphs
G1 = graph.graph_from_adjacency_matrix(am1, aprops1)
G2 = graph.graph_from_adjacency_matrix(am2, aprops2)
# Get all the possible graph isomorphisms
isomorphisms = graph.match_graphs(G1, G2)
# Minimum result
# Squared displacement (not minimize) or RMSD (minimize)
min_result = np.inf
min_isomorphisms = None
# Loop over all graph isomorphisms to find the lowest RMSD
for idx1, idx2 in isomorphisms:
# Use the isomorphism to shuffle coordinates around (from original order)
c1i = c1[idx1, :]
c2i = c2[idx2, :]
if not minimize:
# Compute square displacement
# Avoid dividing by n and an expensive sqrt() operation
result = np.sum((c1i - c2i) ** 2)
else:
# Compute minimized RMSD using QCP
result = qcp.qcp_rmsd(c1i, c2i, atol)
if result < min_result:
min_result = result
min_isomorphisms = (idx1, idx2)
if not minimize:
# Compute actual RMSD from square displacement
min_result = np.sqrt(min_result / n)
# Return the actual RMSD
return min_result, isomorphisms, min_isomorphisms
def symmrmsd(
coordsref: np.ndarray,
coords: Union[np.ndarray, List[np.ndarray]],
apropsref: np.ndarray,
aprops: np.ndarray,
amref: np.ndarray,
am: np.ndarray,
center: bool = False,
minimize: bool = False,
cache: bool = True,
atol: float = 1e-9,
return_permutation: bool = False,
) -> Any:
"""
Compute RMSD using graph isomorphism for multiple coordinates.
Parameters
----------
coordsref: np.ndarray
Coordinates of reference molecule
coords: List[np.ndarray]
Coordinates of other molecule
apropsref: np.ndarray
Atomic properties for reference
aprops: np.ndarray
Atomic properties for other molecule
amref: np.ndarray
Adjacency matrix for reference molecule
am: np.ndarray
Adjacency matrix for other molecule
center: bool
Centering flag
minimize: bool
Minimum RMSD
cache: bool
Cache graph isomorphisms
atol: float
Absolute tolerance parameter for QCP (see :func:`qcp_rmsd`)
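return_permutation: bool
Also return the isomorphism(s) that give the lowest RMSD
Returns
-------
Union[float, List[float]]
Symmetry-corrected RMSD(s); if `return_permutation` is set, a tuple of the
RMSD(s) and the corresponding best isomorphism(s)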
Notes
-----
Graph isomorphism is introduced for symmetry corrections. However, it is also
useful when two molecules do not have the atoms in the same order since atom
matching according to atomic numbers and the molecular connectivity is
performed. If atoms are in the same order and there is no symmetry, use the
`rmsd` function.
"""
if isinstance(coords, list): # Multiple RMSD calculations
RMSD: Any = []
isomorphism = None
min_iso = []
for c in coords:
if not cache:
# Reset isomorphism
isomorphism = None
srmsd, isomorphism, min_i = _rmsd_isomorphic_core(
coordsref,
c,
apropsref,
aprops,
amref,
am,
center=center,
minimize=minimize,
isomorphisms=isomorphism,
atol=atol,
)
min_iso.append(min_i)
RMSD.append(srmsd)
else: # Single RMSD calculation
RMSD, isomorphism, min_iso = _rmsd_isomorphic_core(
coordsref,
coords,
apropsref,
aprops,
amref,
am,
center=center,
minimize=minimize,
isomorphisms=None,
atol=atol,
)
if return_permutation:
return RMSD, min_iso
return RMSD
def rmsdwrapper(
molref: molecule.Molecule,
mols: Union[molecule.Molecule, List[molecule.Molecule]],
symmetry: bool = True,
center: bool = False,
minimize: bool = False,
strip: bool = True,
cache: bool = True,
) -> Any:
"""
Compute RMSD between a reference molecule and one or more molecules.
Parameters
----------
molref: molecule.Molecule
Reference molecule
mols: Union[molecule.Molecule, List[molecule.Molecule]]
Molecules to compare to reference molecule
symmetry: bool, optional
Symmetry-corrected RMSD (using graph isomorphism)
center: bool, optional
Center molecules at origin
minimize: bool, optional
Minimised RMSD (using the quaternion polynomial method)
strip: bool, optional
Strip hydrogen atoms
cache: bool, optional
Cache graph isomorphisms
Returns
-------
List[float]
RMSDs
"""
if not isinstance(mols, list):
mols = [mols]
if strip:
molref.strip()
for mol in mols:
mol.strip()
if minimize:
center = True
cref = molecule.coords_from_molecule(molref, center)
cmols = [molecule.coords_from_molecule(mol, center) for mol in mols]
RMSDlist = []
if symmetry:
RMSDlist = symmrmsd(
cref,
cmols,
molref.atomicnums,
mols[0].atomicnums,
molref.adjacency_matrix,
mols[0].adjacency_matrix,
center=center,
minimize=minimize,
cache=cache,
)
else: # No symmetry
for c in cmols:
RMSDlist.append(
rmsd(
cref,
c,
molref.atomicnums,
mols[0].atomicnums,
center=center,
minimize=minimize,
)
)
return RMSDlist
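The loop in `_rmsd_isomorphic_core` keeps the isomorphism with the smallest squared displacement and takes the square root only once at the end. A NumPy-only sketch of that loop without superposition; `min_rmsd_over_isomorphisms` is a hypothetical name, and the isomorphisms are written out by hand instead of being computed by `graph.match_graphs`.

```python
import numpy as np


def min_rmsd_over_isomorphisms(coords_ref, coords, isomorphisms):
    """Lowest RMSD over a list of (idx_ref, idx) atom mappings (no superposition)."""
    coords_ref = np.asarray(coords_ref, dtype=float)
    coords = np.asarray(coords, dtype=float)
    best_sq = np.inf
    best_mapping = None
    for idx_ref, idx in isomorphisms:
        # Sum of squared displacements; division by n and sqrt are deferred
        sq = np.sum((coords_ref[idx_ref] - coords[idx]) ** 2)
        if sq < best_sq:
            best_sq, best_mapping = sq, (idx_ref, idx)
    return float(np.sqrt(best_sq / len(coords_ref))), best_mapping


# CO2-like arrangement; the second copy lists the two oxygens in swapped order
ref = [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]]
other = [[0.0, 0.0, 0.0], [-1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]
mappings = [([0, 1, 2], [0, 1, 2]), ([0, 1, 2], [0, 2, 1])]

naive = float(np.sqrt(np.sum((np.array(ref) - np.array(other)) ** 2) / 3))
corrected, mapping = min_rmsd_over_isomorphisms(ref, other, mappings)
print(naive, corrected, mapping)
```

The atom-order RMSD is nonzero only because equivalent atoms are listed in a different order; the mapping that swaps them brings the value to zero. Because the isomorphisms depend only on the molecular graphs, they can be reused for every pose of the same molecule, which is what the `cache` flag of `symmrmsd` does.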

176
spyrmsd/utils.py Normal file

@@ -0,0 +1,176 @@
import os
import numpy as np
def format(fname: str) -> str:
"""
Extract format extension from file name.
Parameters
----------
fname : str
File name
Returns
-------
str
File extension
Notes
-----
The file extension is returned without the `.` character, i.e. for the file
`path/filename.ext` the string `ext` is returned.
If a file is compressed, the `.gz` extension is ignored.
"""
name, ext = os.path.splitext(fname)
if ext == ".gz":
_, ext = os.path.splitext(name)
return ext[1:] # Remove "."
def molformat(fname: str) -> str:
"""
Extract an OpenBabel-friendly format from file name.
Parameters
----------
fname : str
File name
Returns
-------
str
File extension in an OpenBabel-friendly format
Notes
-----
File types in OpenBabel do not always correspond to the file extension. This
function converts the file extension to an OpenBabel file type.
The following table shows the different conversions performed by this function:
========= =========
Extension File Type
--------- ---------
xyz XYZ
========= =========
"""
ext = format(fname)
if ext == "xyz":
# xyz files in OpenBabel are called XYZ
ext = "XYZ"
return ext
def deg_to_rad(angle: float) -> float:
"""
Convert angle in degrees to angle in radians.
Parameters
----------
angle : float
Angle (in degrees)
Returns
-------
float
Angle (in radians)
"""
return angle * np.pi / 180.0
def rotate(
v: np.ndarray, angle: float, axis: np.ndarray, units: str = "rad"
) -> np.ndarray:
"""
Rotate vector.
Parameters
----------
v: numpy.array
3D vector to be rotated
angle : float
Angle of rotation (in `units`)
axis : numpy.array
3D axis of rotation
units: {"rad", "deg"}
Units of `angle` (in radians `rad` or degrees `deg`)
Returns
-------
numpy.array
Rotated vector
Raises
------
AssertionError
If the axis of rotation is not a 3D vector
ValueError
If `units` is not `rad` or `deg`
"""
assert len(axis) == 3
# Ensure rotation axis is normalised
axis = axis / np.linalg.norm(axis)
if units.lower() == "rad":
pass
elif units.lower() == "deg":
angle = deg_to_rad(angle)
else:
raise ValueError(
f"Units {units} for angle is not supported. Use 'deg' or 'rad' instead."
)
t1 = np.outer(axis, np.inner(axis, v)).T
t2 = np.cos(angle) * np.cross(np.cross(axis, v), axis)
t3 = np.sin(angle) * np.cross(axis, v)
return t1 + t2 + t3
def center_of_geometry(coordinates: np.ndarray) -> np.ndarray:
"""
Center of geometry.
Parameters
----------
coordinates: np.ndarray
Coordinates
Returns
-------
np.ndarray
Center of geometry
"""
assert coordinates.shape[1] == 3
return np.mean(coordinates, axis=0)
def center(coordinates: np.ndarray) -> np.ndarray:
"""
Center coordinates.
Parameters
----------
coordinates: np.ndarray
Coordinates
Returns
-------
np.ndarray
Centred coordinates
"""
return coordinates - center_of_geometry(coordinates)
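
The three terms in `rotate` above correspond to the Rodrigues rotation formula: the component of `v` along the axis, plus the perpendicular component scaled by the cosine, plus the cross product scaled by the sine. A minimal pure-Python sketch for a single 3D vector follows. The helpers `dot`, `cross` and `rotate_vec` are hypothetical names used only for illustration.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cross(a, b):
    return (
        a[1] * b[2] - a[2] * b[1],
        a[2] * b[0] - a[0] * b[2],
        a[0] * b[1] - a[1] * b[0],
    )


def rotate_vec(v, angle, axis):
    # Normalise the axis, as the original function does
    norm = math.sqrt(dot(axis, axis))
    k = tuple(x / norm for x in axis)
    parallel = tuple(dot(k, v) * x for x in k)  # t1: component along the axis
    k_cross_v = cross(k, v)                     # direction used by t3
    perpendicular = cross(k_cross_v, k)         # (k x v) x k = v - (k.v) k, used by t2
    return tuple(
        p + math.cos(angle) * q + math.sin(angle) * r
        for p, q, r in zip(parallel, perpendicular, k_cross_v)
    )


# Rotating the x unit vector by 90 degrees about a (non-normalised) z axis
rotated = rotate_vec((1.0, 0.0, 0.0), math.pi / 2, (0.0, 0.0, 2.0))
```

The result is the y unit vector up to floating-point error, and the length of the axis does not matter because it is normalised first.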

139
train.py

@@ -1,6 +1,7 @@
import copy
import math
import os
import shutil
from functools import partial
import wandb
@@ -12,46 +13,88 @@ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (64000, rlimit[1]))
import yaml
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl
from datasets.pdbbind import construct_loader
from utils.diffusion_utils import t_to_sigma as t_to_sigma_compl, t_to_sigma_individual
from datasets.loader import construct_loader
from utils.parsing import parse_train_args
from utils.training import train_epoch, test_epoch, loss_function, inference_epoch
from utils.training import train_epoch, test_epoch, loss_function, inference_epoch_fix
from utils.utils import save_yaml_file, get_optimizer_and_scheduler, get_model, ExponentialMovingAverage
def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir):
def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir, val_dataset2):
loss_fn = partial(loss_function, tr_weight=args.tr_weight, rot_weight=args.rot_weight,
tor_weight=args.tor_weight, no_torsion=args.no_torsion, backbone_weight=args.backbone_loss_weight,
sidechain_weight=args.sidechain_loss_weight)
best_val_loss = math.inf
best_val_inference_value = math.inf if args.inference_earlystop_goal == 'min' else 0
best_val_secondary_value = math.inf if args.inference_earlystop_goal == 'min' else 0
best_epoch = 0
best_val_inference_epoch = 0
loss_fn = partial(loss_function, tr_weight=args.tr_weight, rot_weight=args.rot_weight,
tor_weight=args.tor_weight, no_torsion=args.no_torsion)
freeze_params = 0
scheduler_mode = args.inference_earlystop_goal if args.val_inference_freq is not None else 'min'
if args.scheduler == 'layer_linear_warmup':
freeze_params = args.warmup_dur * (args.num_conv_layers + 2) - 1
print("Freezing some parameters until epoch {}".format(freeze_params))
print("Starting training...")
for epoch in range(args.n_epochs):
if epoch % 5 == 0: print("Run name: ", args.run_name)
logs = {}
train_losses = train_epoch(model, train_loader, optimizer, device, t_to_sigma, loss_fn, ema_weights)
print("Epoch {}: Training loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f}"
.format(epoch, train_losses['loss'], train_losses['tr_loss'], train_losses['rot_loss'],
train_losses['tor_loss']))
ema_weights.store(model.parameters())
if args.use_ema: ema_weights.copy_to(model.parameters()) # load ema parameters into model for running validation and inference
if args.scheduler == 'layer_linear_warmup' and (epoch+1) % args.warmup_dur == 0:
step = (epoch+1) // args.warmup_dur
if step < args.num_conv_layers + 2:
print("New unfreezing step")
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=step, scheduler_mode=scheduler_mode)
elif step == args.num_conv_layers + 2:
print("Unfreezing all parameters")
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=step, scheduler_mode=scheduler_mode)
ema_weights = ExponentialMovingAverage(model.parameters(), decay=args.ema_rate)
elif args.scheduler == 'linear_warmup' and epoch == args.warmup_dur:
print("Moving to plateau scheduler")
optimizer, scheduler = get_optimizer_and_scheduler(args, model, step=1, scheduler_mode=scheduler_mode,
optimizer=optimizer)
logs = {}
train_losses = train_epoch(model, train_loader, optimizer, device, t_to_sigma, loss_fn, ema_weights if epoch > freeze_params else None)
print("Epoch {}: Training loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f} sc {:.4f} lr {:.4f}"
.format(epoch, train_losses['loss'], train_losses['tr_loss'], train_losses['rot_loss'],
train_losses['tor_loss'], train_losses['sidechain_loss'], optimizer.param_groups[0]['lr']))
if epoch > freeze_params:
ema_weights.store(model.parameters())
if args.use_ema: ema_weights.copy_to(model.parameters()) # load ema parameters into model for running validation and inference
val_losses = test_epoch(model, val_loader, device, t_to_sigma, loss_fn, args.test_sigma_intervals)
print("Epoch {}: Validation loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f}"
.format(epoch, val_losses['loss'], val_losses['tr_loss'], val_losses['rot_loss'], val_losses['tor_loss']))
print("Epoch {}: Validation loss {:.4f} tr {:.4f} rot {:.4f} tor {:.4f} sc {:.4f}"
.format(epoch, val_losses['loss'], val_losses['tr_loss'], val_losses['rot_loss'], val_losses['tor_loss'], val_losses['sidechain_loss']))
if args.val_inference_freq != None and (epoch + 1) % args.val_inference_freq == 0:
inf_metrics = inference_epoch(model, val_loader.dataset.complex_graphs[:args.num_inference_complexes], device, t_to_sigma, args)
print("Epoch {}: Val inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f}"
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5']))
inf_dataset = [val_loader.dataset.get(i) for i in range(min(args.num_inference_complexes, val_loader.dataset.__len__()))]
inf_metrics = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
print("Epoch {}: Val inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5'], inf_metrics['min_rmsds_lt2'], inf_metrics['min_rmsds_lt5']))
logs.update({'valinf_' + k: v for k, v in inf_metrics.items()}, step=epoch + 1)
if not args.use_ema: ema_weights.copy_to(model.parameters())
ema_state_dict = copy.deepcopy(model.module.state_dict() if device.type == 'cuda' else model.state_dict())
ema_weights.restore(model.parameters())
if args.double_val and args.val_inference_freq != None and (epoch + 1) % args.val_inference_freq == 0:
inf_dataset = [val_dataset2.get(i) for i in range(min(args.num_inference_complexes, val_dataset2.__len__()))]
inf_metrics2 = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
print("Epoch {}: Val inference on second validation rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
.format(epoch, inf_metrics2['rmsds_lt2'], inf_metrics2['rmsds_lt5'], inf_metrics2['min_rmsds_lt2'], inf_metrics2['min_rmsds_lt5']))
logs.update({'valinf2_' + k: v for k, v in inf_metrics2.items()}, step=epoch + 1)
logs.update({'valinfcomb_' + k: (v + inf_metrics[k])/2 for k, v in inf_metrics2.items()}, step=epoch + 1)
if args.train_inference_freq != None and (epoch + 1) % args.train_inference_freq == 0:
inf_dataset = [train_loader.dataset.get(i) for i in range(min(min(args.num_inference_complexes, 300), train_loader.dataset.__len__()))]
inf_metrics = inference_epoch_fix(model, inf_dataset, device, t_to_sigma, args)
print("Epoch {}: Train inference rmsds_lt2 {:.3f} rmsds_lt5 {:.3f} min_rmsds_lt2 {:.3f} min_rmsds_lt5 {:.3f}"
.format(epoch, inf_metrics['rmsds_lt2'], inf_metrics['rmsds_lt5'], inf_metrics['min_rmsds_lt2'], inf_metrics['min_rmsds_lt5']))
logs.update({'traininf_' + k: v for k, v in inf_metrics.items()}, step=epoch + 1)
if epoch > freeze_params:
if not args.use_ema: ema_weights.copy_to(model.parameters())
ema_state_dict = copy.deepcopy(model.module.state_dict() if device.type == 'cuda' else model.state_dict())
ema_weights.restore(model.parameters())
if args.wandb:
logs.update({'train_' + k: v for k, v in train_losses.items()})
@@ -66,15 +109,31 @@ def train(args, model, optimizer, scheduler, ema_weights, train_loader, val_load
best_val_inference_value = logs[args.inference_earlystop_metric]
best_val_inference_epoch = epoch
torch.save(state_dict, os.path.join(run_dir, 'best_inference_epoch_model.pt'))
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_inference_epoch_model.pt'))
if epoch > freeze_params:
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_inference_epoch_model.pt'))
if args.inference_secondary_metric is not None and args.inference_secondary_metric in logs.keys() and \
(args.inference_earlystop_goal == 'min' and logs[args.inference_secondary_metric] <= best_val_secondary_value or
args.inference_earlystop_goal == 'max' and logs[args.inference_secondary_metric] >= best_val_secondary_value):
best_val_secondary_value = logs[args.inference_secondary_metric]
if epoch > freeze_params:
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_secondary_epoch_model.pt'))
if val_losses['loss'] <= best_val_loss:
best_val_loss = val_losses['loss']
best_epoch = epoch
torch.save(state_dict, os.path.join(run_dir, 'best_model.pt'))
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_model.pt'))
if epoch > freeze_params:
torch.save(ema_state_dict, os.path.join(run_dir, 'best_ema_model.pt'))
if args.save_model_freq is not None and (epoch + 1) % args.save_model_freq == 0:
shutil.copyfile(os.path.join(run_dir, 'best_model.pt'),
os.path.join(run_dir, f'epoch{epoch+1}_best_model.pt'))
if scheduler:
if args.val_inference_freq is not None:
if epoch < freeze_params or (args.scheduler == 'linear_warmup' and epoch < args.warmup_dur):
scheduler.step()
elif args.val_inference_freq is not None:
scheduler.step(best_val_inference_value)
else:
scheduler.step(val_losses['loss'])
@@ -108,17 +167,26 @@ def main_function():
if args.cudnn_benchmark:
torch.backends.cudnn.benchmark = True
if args.wandb:
wandb.init(
entity='',
settings=wandb.Settings(start_method="fork"),
project=args.project,
name=args.run_name,
config=args
)
# construct loader
t_to_sigma = partial(t_to_sigma_compl, args=args)
train_loader, val_loader = construct_loader(args, t_to_sigma)
train_loader, val_loader, val_dataset2 = construct_loader(args, t_to_sigma, device)
model = get_model(args, device, t_to_sigma=t_to_sigma)
optimizer, scheduler = get_optimizer_and_scheduler(args, model, scheduler_mode=args.inference_earlystop_goal if args.val_inference_freq is not None else 'min')
ema_weights = ExponentialMovingAverage(model.parameters(),decay=args.ema_rate)
if args.restart_dir:
try:
dict = torch.load(f'{args.restart_dir}/last_model.pt', map_location=torch.device('cpu'))
dict = torch.load(f'{args.restart_dir}/{args.restart_ckpt}.pt', map_location=torch.device('cpu'))
if args.restart_lr is not None: dict['optimizer']['param_groups'][0]['lr'] = args.restart_lr
optimizer.load_state_dict(dict['optimizer'])
model.module.load_state_dict(dict['model'], strict=True)
@@ -130,18 +198,15 @@ def main_function():
dict = torch.load(f'{args.restart_dir}/best_model.pt', map_location=torch.device('cpu'))
model.module.load_state_dict(dict, strict=True)
print("Due to exception had to take the best epoch and no optimiser")
elif args.pretrain_dir:
dict = torch.load(f'{args.pretrain_dir}/{args.pretrain_ckpt}.pt', map_location=torch.device('cpu'))
model.module.load_state_dict(dict, strict=True)
print("Using pretrained model", f'{args.pretrain_dir}/{args.pretrain_ckpt}.pt')
numel = sum([p.numel() for p in model.parameters()])
print('Model with', numel, 'parameters')
if args.wandb:
wandb.init(
entity='entity',
settings=wandb.Settings(start_method="fork"),
project=args.project,
name=args.run_name,
config=args
)
wandb.log({'numel': numel})
# record parameters
@@ -150,9 +215,9 @@ def main_function():
save_yaml_file(yaml_file_name, args.__dict__)
args.device = device
train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir)
train(args, model, optimizer, scheduler, ema_weights, train_loader, val_loader, t_to_sigma, run_dir, val_dataset2)
if __name__ == '__main__':
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
main_function()
main_function()
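
The `layer_linear_warmup` bookkeeping in `train` is easier to follow with concrete numbers. The sketch below only reproduces the epoch arithmetic from the diff (`freeze_params`, the `step` computation and the two unfreezing branches). The function `unfreezing_events` and the example values are hypothetical, and no optimizer or model is involved.

```python
def unfreezing_events(warmup_dur, num_conv_layers, n_epochs):
    # Same formula as in train(): last epoch during which some parameters are frozen
    freeze_params = warmup_dur * (num_conv_layers + 2) - 1
    events = []
    for epoch in range(n_epochs):
        if (epoch + 1) % warmup_dur == 0:
            step = (epoch + 1) // warmup_dur
            if step < num_conv_layers + 2:
                events.append((epoch, f"unfreeze step {step}"))
            elif step == num_conv_layers + 2:
                events.append((epoch, "unfreeze all"))
    return freeze_params, events


# Made-up configuration: 2 epochs per warmup stage, 2 convolution layers
freeze_params, events = unfreezing_events(warmup_dur=2, num_conv_layers=2, n_epochs=12)
```

With these values `freeze_params` is 7, the optimizer is rebuilt at epochs 1, 3, 5 and 7, and, following the `epoch > freeze_params` checks in the diff, EMA weights would only be updated and saved from epoch 8 onward.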


@@ -5,8 +5,24 @@ import torch.nn.functional as F
from torch import nn
from scipy.stats import beta
from utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch
from utils.torsion import modify_conformer_torsion_angles
from utils.geometry import axis_angle_to_matrix, rigid_transform_Kabsch_3D_torch, rigid_transform_Kabsch_3D_torch_batch
from utils.torsion import modify_conformer_torsion_angles, modify_conformer_torsion_angles_batch
def sigmoid(t):
return 1 / (1 + np.e**(-t))
def sigmoid_schedule(t, k=10, m=0.5):
s = lambda t: sigmoid(k*(t-m))
return (s(t)-s(0))/(s(1)-s(0))
def t_to_sigma_individual(t, schedule_type, sigma_min, sigma_max, schedule_k=10, schedule_m=0.4):
if schedule_type == "exponential":
return sigma_min ** (1 - t) * sigma_max ** t
elif schedule_type == 'sigmoid':
return sigmoid_schedule(t, k=schedule_k, m=schedule_m) * (sigma_max - sigma_min) + sigma_min
def t_to_sigma(t_tr, t_rot, t_tor, args):
@@ -16,7 +32,7 @@ def t_to_sigma(t_tr, t_rot, t_tor, args):
return tr_sigma, rot_sigma, tor_sigma
def modify_conformer(data, tr_update, rot_update, torsion_updates):
def modify_conformer(data, tr_update, rot_update, torsion_updates, pivot=None):
lig_center = torch.mean(data['ligand'].pos, dim=0, keepdim=True)
rot_mat = axis_angle_to_matrix(rot_update.squeeze())
rigid_new_pos = (data['ligand'].pos - lig_center) @ rot_mat.T + tr_update + lig_center
@@ -26,14 +42,60 @@ def modify_conformer(data, tr_update, rot_update, torsion_updates):
data['ligand', 'ligand'].edge_index.T[data['ligand'].edge_mask],
data['ligand'].mask_rotate if isinstance(data['ligand'].mask_rotate, np.ndarray) else data['ligand'].mask_rotate[0],
torsion_updates).to(rigid_new_pos.device)
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
if pivot is None:
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
else:
R1, t1 = rigid_transform_Kabsch_3D_torch(pivot.T, rigid_new_pos.T)
R2, t2 = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, pivot.T)
aligned_flexible_pos = (flexible_new_pos @ R2.T + t2.T) @ R1.T + t1.T
data['ligand'].pos = aligned_flexible_pos
else:
data['ligand'].pos = rigid_new_pos
return data
def modify_conformer_batch(orig_pos, data, tr_update, rot_update, torsion_updates, mask_rotate):
B = data.num_graphs
N, M, R = data['ligand'].num_nodes // B, data['ligand', 'ligand'].num_edges // B, data['ligand'].edge_mask.sum().item() // B
pos, edge_index, edge_mask = orig_pos.reshape(B, N, 3) + 0, data['ligand', 'ligand'].edge_index[:, :M], data['ligand'].edge_mask[:M]
torsion_updates = torsion_updates.reshape(B, -1) if torsion_updates is not None else None
lig_center = torch.mean(pos, dim=1, keepdim=True)
rot_mat = axis_angle_to_matrix(rot_update)
rigid_new_pos = torch.bmm(pos - lig_center, rot_mat.permute(0, 2, 1)) + tr_update.unsqueeze(1) + lig_center
if torsion_updates is not None:
flexible_new_pos = modify_conformer_torsion_angles_batch(rigid_new_pos, edge_index.T[edge_mask], mask_rotate, torsion_updates)
R, t = rigid_transform_Kabsch_3D_torch_batch(flexible_new_pos, rigid_new_pos)
aligned_flexible_pos = torch.bmm(flexible_new_pos, R.transpose(1, 2)) + t.transpose(1, 2)
final_pos = aligned_flexible_pos.reshape(-1, 3)
else:
final_pos = rigid_new_pos.reshape(-1, 3)
return final_pos
def modify_conformer_coordinates(pos, tr_update, rot_update, torsion_updates, edge_mask, mask_rotate, edge_index):
# Made this function which does the same as modify_conformer because passing a graph would require
# creating a new heterograph for each graph when unbatching a batch of graphs
lig_center = torch.mean(pos, dim=0, keepdim=True)
rot_mat = axis_angle_to_matrix(rot_update.squeeze())
rigid_new_pos = (pos - lig_center) @ rot_mat.T + tr_update + lig_center
if torsion_updates is not None:
flexible_new_pos = modify_conformer_torsion_angles(rigid_new_pos,edge_index.T[edge_mask],mask_rotate \
if isinstance(mask_rotate, np.ndarray) else mask_rotate[0], torsion_updates).to(rigid_new_pos.device)
R, t = rigid_transform_Kabsch_3D_torch(flexible_new_pos.T, rigid_new_pos.T)
aligned_flexible_pos = flexible_new_pos @ R.T + t.T
return aligned_flexible_pos
else:
return rigid_new_pos
def sinusoidal_embedding(timesteps, embedding_dim, max_positions=10000):
""" from https://github.com/hojonathanho/diffusion/blob/master/diffusion_tf/nn.py """
assert len(timesteps.shape) == 1
@@ -73,11 +135,15 @@ def get_timestep_embedding(embedding_type, embedding_dim, embedding_scale=10000)
return emb_func
def get_t_schedule(inference_steps):
return np.linspace(1, 0, inference_steps + 1)[:-1]
def get_t_schedule(sigma_schedule, inference_steps, inf_sched_alpha=1, inf_sched_beta=1, t_max=1):
if sigma_schedule == 'expbeta':
lin_max = beta.cdf(t_max, a=inf_sched_alpha, b=inf_sched_beta)
c = np.linspace(lin_max, 0, inference_steps + 1)[:-1]
return beta.ppf(c, a=inf_sched_alpha, b=inf_sched_beta)
raise Exception(f"Unsupported sigma_schedule: {sigma_schedule}")
def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device):
def set_time(complex_graphs, t, t_tr, t_rot, t_tor, batchsize, all_atoms, device, include_miscellaneous_atoms=False):
complex_graphs['ligand'].node_t = {
'tr': t_tr * torch.ones(complex_graphs['ligand'].num_nodes).to(device),
'rot': t_rot * torch.ones(complex_graphs['ligand'].num_nodes).to(device),
@@ -93,4 +159,10 @@ def set_time(complex_graphs, t_tr, t_rot, t_tor, batchsize, all_atoms, device):
complex_graphs['atom'].node_t = {
'tr': t_tr * torch.ones(complex_graphs['atom'].num_nodes).to(device),
'rot': t_rot * torch.ones(complex_graphs['atom'].num_nodes).to(device),
'tor': t_tor * torch.ones(complex_graphs['atom'].num_nodes).to(device)}
'tor': t_tor * torch.ones(complex_graphs['atom'].num_nodes).to(device)}
if include_miscellaneous_atoms and not all_atoms:
complex_graphs['misc_atom'].node_t = {
'tr': t_tr * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device),
'rot': t_rot * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device),
'tor': t_tor * torch.ones(complex_graphs['misc_atom'].num_nodes).to(device)}
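
The two noise schedules added in `t_to_sigma_individual` can be checked at their endpoints without torch or scipy. The sketch below mirrors the exponential and sigmoid formulas from the diff using only `math`. The name `sigma_at` is hypothetical, and unlike the original it raises on an unknown schedule type. `linear_t_schedule` is an assumption-labelled special case: with `inf_sched_alpha = inf_sched_beta = 1` the Beta distribution is uniform, so the `expbeta` schedule should reduce to an evenly spaced grid from `t_max` down towards 0.

```python
import math


def sigmoid(t):
    return 1 / (1 + math.exp(-t))


def sigmoid_schedule(t, k=10, m=0.5):
    # Rescaled so that the schedule is exactly 0 at t=0 and 1 at t=1
    s = lambda x: sigmoid(k * (x - m))
    return (s(t) - s(0)) / (s(1) - s(0))


def sigma_at(t, schedule_type, sigma_min, sigma_max, k=10, m=0.4):
    # Hypothetical stand-in for t_to_sigma_individual
    if schedule_type == "exponential":
        return sigma_min ** (1 - t) * sigma_max ** t
    if schedule_type == "sigmoid":
        return sigmoid_schedule(t, k=k, m=m) * (sigma_max - sigma_min) + sigma_min
    raise ValueError(f"unknown schedule type: {schedule_type}")


def linear_t_schedule(inference_steps, t_max=1.0):
    # Assumed special case of get_t_schedule for alpha = beta = 1
    return [t_max * (1 - i / inference_steps) for i in range(inference_steps)]


# Made-up sigma range
exp_mid = sigma_at(0.5, "exponential", 0.1, 10.0)
sig_ends = (sigma_at(0.0, "sigmoid", 0.1, 10.0), sigma_at(1.0, "sigmoid", 0.1, 10.0))
ts = linear_t_schedule(4)
```

Both schedules hit `sigma_min` at `t = 0` and `sigma_max` at `t = 1`. They differ only in how quickly sigma grows in between: the exponential one passes through the geometric mean at `t = 0.5`.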


@@ -1,5 +1,6 @@
import math
import torch.nn.functional as F
import numpy as np
import torch
@@ -85,6 +86,126 @@ def axis_angle_to_matrix(axis_angle):
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
"""
Returns torch.sqrt(torch.max(0, x))
but with a zero subgradient where x is 0.
"""
ret = torch.zeros_like(x)
positive_mask = x > 0
ret[positive_mask] = torch.sqrt(x[positive_mask])
return ret
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to quaternions.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
quaternions with real part first, as tensor of shape (..., 4).
"""
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
batch_dim = matrix.shape[:-2]
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
matrix.reshape(batch_dim + (9,)), dim=-1
)
q_abs = _sqrt_positive_part(
torch.stack(
[
1.0 + m00 + m11 + m22,
1.0 + m00 - m11 - m22,
1.0 - m00 + m11 - m22,
1.0 - m00 - m11 + m22,
],
dim=-1,
)
)
# we produce the desired quaternion multiplied by each of r, i, j, k
quat_by_rijk = torch.stack(
[
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
# pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
# `int`.
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
],
dim=-2,
)
# We floor here at 0.1 but the exact level is not important; if q_abs is small,
# the candidate won't be picked.
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
# forall i; we pick the best-conditioned one (with the largest denominator)
return quat_candidates[
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
].reshape(batch_dim + (4,))
def quaternion_to_axis_angle(quaternions: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as quaternions to axis/angle.
Args:
quaternions: quaternions with real part first,
as tensor of shape (..., 4).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
half_angles = torch.atan2(norms, quaternions[..., :1])
angles = 2 * half_angles
eps = 1e-6
small_angles = angles.abs() < eps
sin_half_angles_over_angles = torch.empty_like(angles)
sin_half_angles_over_angles[~small_angles] = (
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
)
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
# so sin(x/2)/x is about 1/2 - (x*x)/48
sin_half_angles_over_angles[small_angles] = (
0.5 - (angles[small_angles] * angles[small_angles]) / 48
)
return quaternions[..., 1:] / sin_half_angles_over_angles
def matrix_to_axis_angle(matrix: torch.Tensor) -> torch.Tensor:
"""
Convert rotations given as rotation matrices to axis/angle.
Args:
matrix: Rotation matrices as tensor of shape (..., 3, 3).
Returns:
Rotations given as a vector in axis angle form, as a tensor
of shape (..., 3), where the magnitude is the angle
turned anticlockwise in radians around the vector's
direction.
"""
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
def rigid_transform_Kabsch_3D_torch(A, B):
# R = 3x3 rotation matrix, t = 3x1 column vector
# This already takes residue identity into account.
@@ -97,7 +218,6 @@ def rigid_transform_Kabsch_3D_torch(A, B):
if num_rows != 3:
raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
# find mean column wise: 3 x 1
centroid_A = torch.mean(A, axis=1, keepdims=True)
centroid_B = torch.mean(B, axis=1, keepdims=True)
@@ -121,3 +241,78 @@ def rigid_transform_Kabsch_3D_torch(A, B):
t = -R @ centroid_A + centroid_B
return R, t
def rigid_transform_Kabsch_3D_torch_batch(A, B):
# R = Bx3x3 rotation matrix, t = Bx3x1 column vector
assert A.shape == B.shape
_, N, M = A.shape
if M != 3:
raise Exception("matrices A and B should be BxNx3")
A, B = A.permute(0, 2, 1), B.permute(0, 2, 1)
# find mean column wise: 3 x 1
centroid_A = torch.mean(A, axis=2, keepdims=True)
centroid_B = torch.mean(B, axis=2, keepdims=True)
# subtract mean
Am = A - centroid_A
Bm = B - centroid_B
H = torch.bmm(Am, Bm.transpose(1, 2))
# find rotation
U, S, Vt = torch.linalg.svd(H)
R = torch.bmm(Vt.transpose(1, 2), U.transpose(1, 2))
# reflection case
SS = torch.diag(torch.tensor([1., 1., -1.], device=A.device))
Rm = torch.bmm(Vt.transpose(1,2) @ SS, U.transpose(1, 2))
R = torch.where(torch.linalg.det(R)[:, None, None] < 0, Rm, R)
assert torch.all(torch.abs(torch.linalg.det(R) - 1) < 3e-3) # note I had to change this error bound to be higher
t = torch.bmm(-R, centroid_A) + centroid_B
return R, t
def rigid_transform_Kabsch_independent_torch(A, B):
# R = 3x3 rotation matrix, t = 3x1 column vector
# This already takes residue identity into account.
assert A.shape[1] == B.shape[1]
num_rows, num_cols = A.shape
if num_rows != 3:
raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
num_rows, num_cols = B.shape
if num_rows != 3:
raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
# find mean column wise: 3 x 1
centroid_A = torch.mean(A, axis=1, keepdims=True)
centroid_B = torch.mean(B, axis=1, keepdims=True)
# subtract mean
Am = A - centroid_A
Bm = B - centroid_B
H = Am @ Bm.T
# find rotation
U, S, Vt = torch.linalg.svd(H)
R = Vt.T @ U.T
# special reflection case
if torch.linalg.det(R) < 0:
# print("det(R) < 0, reflection detected!, correcting for it ...")
SS = torch.diag(torch.tensor([1.,1.,-1.], device=A.device))
R = (Vt.T @ SS) @ U.T
assert math.fabs(torch.linalg.det(R) - 1) < 3e-3 # note I had to change this error bound to be higher
t = - centroid_A + centroid_B # note does not change rotation
R_vec = matrix_to_axis_angle(R)
return t, R_vec
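
The quaternion to axis-angle conversion added above divides the vector part by `sin(angle/2) / angle`, switching to a Taylor expansion near zero to avoid a 0/0. A scalar pure-Python sketch of the same steps follows. `quat_to_axis_angle` is a hypothetical name, and it handles one quaternion rather than a batched tensor.

```python
import math


def quat_to_axis_angle(q, eps=1e-6):
    # q = (w, x, y, z) with the real part first, as in the torch version
    w, x, y, z = q
    norm = math.sqrt(x * x + y * y + z * z)
    half_angle = math.atan2(norm, w)
    angle = 2 * half_angle
    if abs(angle) < eps:
        # sin(a/2)/a is about 1/2 - a^2/48 for small a
        scale = 0.5 - angle * angle / 48
    else:
        scale = math.sin(half_angle) / angle
    return (x / scale, y / scale, z / scale)


# 90 degree rotation about z, and the identity rotation
quarter_turn = quat_to_axis_angle((math.cos(math.pi / 4), 0.0, 0.0, math.sin(math.pi / 4)))
identity = quat_to_axis_angle((1.0, 0.0, 0.0, 0.0))
```

The returned vector points along the rotation axis and its length is the rotation angle in radians, so the quarter turn maps to roughly `(0, 0, pi/2)` and the identity maps to the zero vector.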

89
utils/gnina_utils.py Normal file

@@ -0,0 +1,89 @@
import os
import subprocess
import numpy as np
from rdkit.Chem import AllChem, RemoveHs, RemoveAllHs
from datasets.process_mols import write_mol_with_coords, read_molecule
import re
from utils.utils import remove_all_hs
def read_gnina_metrics(gnina_sdf_path):
with open(gnina_sdf_path, 'r') as f:
pattern = re.compile(r'> <(.*?)>\n(.*?)\n')
content = f.read()
matches = pattern.findall(content)
metrics = {k: v for k, v in matches}
return metrics
def read_gnina_score(gnina_sdf_path):
with open(gnina_sdf_path, 'r') as f:
pattern = re.compile(r'> <CNNscore>\n(.*?)\n')
content = f.read()
matches = pattern.findall(content)
return float(matches[0])
def invert_permutation(p):
"""Return an array s with which np.array_equal(arr[p][s], arr) is True.
The array_like argument p must be some permutation of 0, 1, ..., len(p)-1.
"""
p = np.asanyarray(p) # in case p is a tuple, etc.
s = np.empty_like(p)
s[p] = np.arange(p.size)
return s
def get_gnina_poses(args, mol, pos, orig_center, name, folder, gnina_path, thread_id=0):
#folder = "data/MOAD_new_test_processed" if args.split == 'test' else "data/MOAD_new_val_processed"
out_dir = args.out_dir if hasattr(args, 'out_dir') else args.inference_out_dir
rec_path = os.path.join(folder, name[:6] + '_protein_chain_removed.pdb')
pred_lig_path = os.path.join(out_dir, f'pred_{name}_tid{thread_id}_lig.sdf')
if not os.path.exists(os.path.dirname(pred_lig_path)):
os.mkdir(os.path.dirname(pred_lig_path))
print(f'Ligand path {pred_lig_path}')
write_mol_with_coords(mol, pos + orig_center, pred_lig_path)
gnina_pred_path = os.path.join(out_dir, f'gnina_{name}_tid{thread_id}_lig.sdf')
gnina_logs_dir = os.path.join(out_dir, "gnina_logs")
with open(os.path.join(gnina_logs_dir, f'{name}'), "w+") as f:
if args.gnina_full_dock:
return_code = subprocess.run(
f'{gnina_path} -r {rec_path} -l "{pred_lig_path}" --autobox_ligand "{pred_lig_path}" -o "{gnina_pred_path}" --no_gpu --autobox_add {args.gnina_autobox_add}',
shell=True, stdout=f, stderr=f)
else:
return_code = subprocess.run(
f'{gnina_path} --receptor {rec_path} --ligand "{pred_lig_path}" --minimize -o "{gnina_pred_path}"',
shell=True, stdout=f, stderr=f)
# print(f'gnina return code: {return_code}')
try:
gnina_mol = RemoveAllHs(read_molecule(gnina_pred_path, remove_hs=True, sanitize=True))
gnina_minimized_ligand_pos = np.array(gnina_mol.GetConformer(0).GetPositions())
gnina_atoms = np.array([atom.GetSymbol() for atom in gnina_mol.GetAtoms()])
gnina_filter_Hs = np.where(gnina_atoms != 'H')
gnina_ligand_pos = gnina_minimized_ligand_pos[gnina_filter_Hs] - orig_center
try:
gnina_score = read_gnina_score(gnina_pred_path)
if gnina_score is None:
gnina_score = 0
except Exception as e:
print(f'Error reading gnina score: {e}')
gnina_score = 0
except Exception as e:
print(f'Error when running gnina with {name} to minimize energy')
print('Error:', e)
print('Using score model output pos instead.')
gnina_ligand_pos = pos
gnina_mol = RemoveAllHs(mol)
gnina_score = 0
return gnina_ligand_pos, gnina_mol, gnina_score
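
`read_gnina_metrics` relies on the SDF convention that each data item is a `> <name>` line followed by its value. The sketch below applies the same regular expression to an in-memory string instead of a file. `parse_sdf_fields` is a hypothetical name and the sample values are made up, not real gnina output.

```python
import re

# Same pattern as in read_gnina_metrics: field name, then the first value line
FIELD_PATTERN = re.compile(r'> <(.*?)>\n(.*?)\n')


def parse_sdf_fields(text):
    # Hypothetical helper: map each data-item name to its value as a string
    return dict(FIELD_PATTERN.findall(text))


# Made-up SDF tail with two data items
sample = (
    "> <minimizedAffinity>\n"
    "-7.5\n"
    "\n"
    "> <CNNscore>\n"
    "0.91\n"
    "\n"
    "$$$$\n"
)
fields = parse_sdf_fields(sample)
cnn_score = float(fields["CNNscore"])
```

Values come back as strings and only the first line of each item is captured, so callers convert them explicitly, as `read_gnina_score` does with `float(matches[0])`.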


@@ -1,4 +1,6 @@
import copy
import os
import pickle
import torch
from Bio.PDB import PDBParser
@@ -7,38 +9,10 @@ from rdkit.Chem import AddHs, MolFromSmiles
from torch_geometric.data import Dataset, HeteroData
import esm
from datasets.process_mols import parse_pdb_from_path, generate_conformer, read_molecule, get_lig_graph_with_matching, \
extract_receptor_structure, get_rec_graph
from datasets.constants import three_to_one
from datasets.process_mols import generate_conformer, read_molecule, get_lig_graph_with_matching, moad_extract_receptor_structure
three_to_one = {'ALA': 'A',
'ARG': 'R',
'ASN': 'N',
'ASP': 'D',
'CYS': 'C',
'GLN': 'Q',
'GLU': 'E',
'GLY': 'G',
'HIS': 'H',
'ILE': 'I',
'LEU': 'L',
'LYS': 'K',
'MET': 'M',
'MSE': 'M', # MSE is almost the same amino acid as MET; the sulfur is just replaced by selenium
'PHE': 'F',
'PRO': 'P',
'PYL': 'O',
'SER': 'S',
'SEC': 'U',
'THR': 'T',
'TRP': 'W',
'TYR': 'Y',
'VAL': 'V',
'ASX': 'B',
'GLX': 'Z',
'XAA': 'X',
'XLE': 'J'}
def get_sequences_from_pdbfile(file_path):
biopython_parser = PDBParser()
structure = biopython_parser.get_structure('random_id', file_path)
@@ -153,7 +127,7 @@ def generate_ESM_structure(model, filename, sequence):
class InferenceDataset(Dataset):
def __init__(self, out_dir, complex_names, protein_files, ligand_descriptions, protein_sequences, lm_embeddings,
receptor_radius=30, c_alpha_max_neighbors=None, precomputed_lm_embeddings=None,
remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None):
remove_hs=False, all_atoms=False, atom_radius=5, atom_max_neighbors=None, knn_only_graph=False):
super(InferenceDataset, self).__init__()
self.receptor_radius = receptor_radius
@@ -161,6 +135,7 @@ class InferenceDataset(Dataset):
self.remove_hs = remove_hs
self.all_atoms = all_atoms
self.atom_radius, self.atom_max_neighbors = atom_radius, atom_max_neighbors
self.knn_only_graph = knn_only_graph
self.complex_names = complex_names
self.protein_files = protein_files
@@ -242,18 +217,19 @@ class InferenceDataset(Dataset):
try:
# parse the receptor from the pdb file
rec_model = parse_pdb_from_path(protein_file)
get_lig_graph_with_matching(mol, complex_graph, popsize=None, maxiter=None, matching=False, keep_original=False,
num_conformers=1, remove_hs=self.remove_hs)
rec, rec_coords, c_alpha_coords, n_coords, c_coords, lm_embeddings = extract_receptor_structure(rec_model, mol, lm_embedding_chains=lm_embedding)
if lm_embeddings is not None and len(c_alpha_coords) != len(lm_embeddings):
print(f'LM embeddings for complex {name} did not have the right length for the protein. Skipping {name}.')
complex_graph['success'] = False
return complex_graph
get_rec_graph(rec, rec_coords, c_alpha_coords, n_coords, c_coords, complex_graph, rec_radius=self.receptor_radius,
c_alpha_max_neighbors=self.c_alpha_max_neighbors, all_atoms=self.all_atoms,
atom_radius=self.atom_radius, atom_max_neighbors=self.atom_max_neighbors, remove_hs=self.remove_hs, lm_embeddings=lm_embeddings)
moad_extract_receptor_structure(
path=os.path.join(protein_file),
complex_graph=complex_graph,
neighbor_cutoff=self.receptor_radius,
max_neighbors=self.c_alpha_max_neighbors,
lm_embeddings=lm_embedding,
knn_only_graph=self.knn_only_graph,
all_atoms=self.all_atoms,
atom_cutoff=self.atom_radius,
atom_max_neighbors=self.atom_max_neighbors)
except Exception as e:
print(f'Skipping {name} because of the error:')

39
utils/molecules_utils.py Normal file

@@ -0,0 +1,39 @@
from spyrmsd import rmsd, molecule
def get_symmetry_rmsd(mol, coords1, coords2, mol2=None, return_permutation=False):
with time_limit(10):
mol = molecule.Molecule.from_rdkit(mol)
mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2
mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums
mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix
RMSD = rmsd.symmrmsd(
coords1,
coords2,
mol.atomicnums,
mol2_atomicnums,
mol.adjacency_matrix,
mol2_adjacency_matrix,
return_permutation=return_permutation
)
return RMSD
import signal
from contextlib import contextmanager
class TimeoutException(Exception): pass
@contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
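
Note on `time_limit`: the context manager above relies on `SIGALRM`, so it only works on Unix-like systems and only when called from the main thread. The following is a minimal, self-contained sketch of how such a guard can be used around a slow call. `run_with_limit` and `slow_task` are hypothetical names for illustration and are not part of this commit; restoring the previous signal handler is also an addition made here, not something the committed code does.

```python
import signal
import time
from contextlib import contextmanager


class TimeoutException(Exception):
    pass


@contextmanager
def time_limit(seconds):
    """Raise TimeoutException if the body runs longer than `seconds`.

    Relies on SIGALRM: Unix only, main thread only, whole seconds only.
    """
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")

    # Assumption (not in the commit): remember the previous handler and restore it.
    previous_handler = signal.signal(signal.SIGALRM, signal_handler)
    signal.alarm(seconds)
    try:
        yield
    finally:
        signal.alarm(0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous_handler)


def slow_task(duration):
    """Hypothetical stand-in for an expensive call such as a symmetry-corrected RMSD."""
    time.sleep(duration)
    return "finished"


def run_with_limit(fn, seconds, default=None):
    """Hypothetical helper: return fn(), or `default` if the time limit is hit."""
    try:
        with time_limit(seconds):
            return fn()
    except TimeoutException:
        return default
```

A caller that prefers a fallback value over an exception could write `run_with_limit(lambda: slow_task(5), 1, default=None)` and treat `None` as "RMSD not available".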


@@ -6,32 +6,46 @@ def parse_train_args():
# General arguments
parser = ArgumentParser()
parser.add_argument('--config', type=FileType(mode='r'), default=None)
parser.add_argument('--log_dir', type=str, default='workdir', help='Folder in which to save model and logs')
parser.add_argument('--log_dir', type=str, default='workdir/test_score', help='Folder in which to save model and logs')
parser.add_argument('--restart_dir', type=str, help='Folder of previous training model from which to restart')
parser.add_argument('--restart_ckpt', type=str, default='last_model', help='')
parser.add_argument('--pretrain_dir', type=str, help='Folder of pretrained model from which to restart')
parser.add_argument('--pretrain_ckpt', type=str, help='')
parser.add_argument('--freeze_params', type=int, default=0, help='')
parser.add_argument('--cache_path', type=str, default='data/cache', help='Folder from where to load/restore cached dataset')
parser.add_argument('--data_dir', type=str, default='data/PDBBind_processed/', help='Folder containing original structures')
parser.add_argument('--moad_dir', type=str, default='data/BindingMOAD_2020_processed/', help='Folder containing original structures')
parser.add_argument('--pdbbind_dir', type=str, default='data/PDBBind_processed/', help='Folder containing original structures')
parser.add_argument('--dataset', type=str, default='pdbbind', help='Folder containing original structures')
parser.add_argument('--split_train', type=str, default='data/splits/timesplit_no_lig_overlap_train', help='Path of file defining the split')
parser.add_argument('--split_val', type=str, default='data/splits/timesplit_no_lig_overlap_val', help='Path of file defining the split')
parser.add_argument('--split_test', type=str, default='data/splits/timesplit_test', help='Path of file defining the split')
parser.add_argument('--test_sigma_intervals', action='store_true', default=False, help='Whether to log loss per noise interval')
parser.add_argument('--val_inference_freq', type=int, default=5, help='Frequency of epochs for which to run expensive inference on val data')
parser.add_argument('--save_model_freq', type=int, default=None, help='')
parser.add_argument('--inference_samples', type=int, default=1, help='')
parser.add_argument('--train_inference_freq', type=int, default=None, help='Frequency of epochs for which to run expensive inference on train data')
parser.add_argument('--inference_steps', type=int, default=20, help='Number of denoising steps for inference on val')
parser.add_argument('--num_inference_complexes', type=int, default=100, help='Number of complexes for which inference is run every val/train_inference_freq epochs (None will run it on all)')
parser.add_argument('--inference_earlystop_metric', type=str, default='valinf_rmsds_lt2', help='This is the metric that is additionally used when val_inference_freq is not None')
parser.add_argument('--inference_earlystop_metric', type=str, default='valinf_min_rmsds_lt2', help='This is the metric that is additionally used when val_inference_freq is not None')
parser.add_argument('--inference_secondary_metric', type=str, default=None, help='')
parser.add_argument('--inference_earlystop_goal', type=str, default='max', help='Whether to maximize or minimize metric')
parser.add_argument('--wandb', action='store_true', default=False, help='')
parser.add_argument('--project', type=str, default='difdock_train', help='')
parser.add_argument('--project', type=str, default='diffdock', help='')
parser.add_argument('--run_name', type=str, default='', help='')
parser.add_argument('--cudnn_benchmark', action='store_true', default=False, help='CUDA optimization parameter for faster training')
parser.add_argument('--num_dataloader_workers', type=int, default=0, help='Number of workers for dataloader')
parser.add_argument('--pin_memory', action='store_true', default=False, help='pin_memory arg of dataloader')
parser.add_argument('--dataloader_drop_last', action='store_true', default=False, help='drop_last arg of dataloader')
parser.add_argument('--double_val', action='store_true', default=False, help='')
parser.add_argument('--combined_training', action='store_true', default=False, help='')
# Training arguments
parser.add_argument('--n_epochs', type=int, default=400, help='Number of epochs for training')
parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
parser.add_argument('--scheduler', type=str, default=None, help='LR scheduler')
parser.add_argument('--scheduler_patience', type=int, default=20, help='Patience of the LR scheduler')
parser.add_argument('--lr_start_factor', type=float, default=0.001, help='')
parser.add_argument('--warmup_dur', type=int, default=4, help='')
parser.add_argument('--lr', type=float, default=1e-3, help='Initial learning rate')
parser.add_argument('--restart_lr', type=float, default=None, help='If this is not none, the lr of the optimizer will be overwritten with this value when restarting from a checkpoint.')
parser.add_argument('--w_decay', type=float, default=0.0, help='Weight decay added to loss')
@@ -40,23 +54,42 @@ def parse_train_args():
parser.add_argument('--ema_rate', type=float, default=0.999, help='decay rate for the exponential moving average model parameters ')
# Dataset
parser.add_argument('--limit_complexes', type=int, default=0, help='If positive, the number of training and validation complexes is capped')
parser.add_argument('--limit_complexes', type=int, default=5, help='If positive, the number of training and validation complexes is capped') # TODO change
parser.add_argument('--all_atoms', action='store_true', default=False, help='Whether to use the all atoms model')
parser.add_argument('--chain_cutoff', type=float, default=None, help='Cutoff on whether to include non-interacting chains')
parser.add_argument('--receptor_radius', type=float, default=30, help='Cutoff on distances for receptor edges')
parser.add_argument('--c_alpha_max_neighbors', type=int, default=10, help='Maximum number of neighbors for each residue')
parser.add_argument('--atom_radius', type=float, default=5, help='Cutoff on distances for atom connections')
parser.add_argument('--atom_max_neighbors', type=int, default=8, help='Maximum number of atom neighbours for receptor')
parser.add_argument('--matching_popsize', type=int, default=20, help='Differential evolution popsize parameter in matching')
parser.add_argument('--matching_maxiter', type=int, default=20, help='Differential evolution maxiter parameter in matching')
parser.add_argument('--matching_tries', type=int, default=1, help='')
parser.add_argument('--max_lig_size', type=int, default=None, help='Maximum number of heavy atoms in ligand')
parser.add_argument('--remove_hs', action='store_true', default=False, help='remove Hs')
parser.add_argument('--remove_hs', action='store_true', default=True, help='remove Hs')
parser.add_argument('--num_conformers', type=int, default=1, help='Number of conformers to match to each ligand')
parser.add_argument('--esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
parser.add_argument('--moad_esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
parser.add_argument('--pdbbind_esm_embeddings_path', type=str, default=None, help='If this is set then the LM embeddings at that path will be used for the receptor features')
parser.add_argument('--moad_esm_embeddings_sequences_path', type=str, default=None, help='')
parser.add_argument('--esm_embeddings_model', type=str, default=None, help='')
parser.add_argument('--not_fixed_knn_radius_graph', action='store_true', default=False, help='Use knn graph and radius graph with closest neighbors instead of random ones as with radius_graph')
parser.add_argument('--not_knn_only_graph', action='store_true', default=False, help='Use knn graph only and not restrict to a specific radius')
parser.add_argument('--include_miscellaneous_atoms', action='store_true', default=False, help='include non amino acid atoms for the receptor')
parser.add_argument('--train_multiplicity', type=int, default=1, help='')
parser.add_argument('--val_multiplicity', type=int, default=1, help='')
parser.add_argument('--max_receptor_size', type=int, default=None, help='')
parser.add_argument('--remove_promiscuous_targets', type=int, default=None, help='')
parser.add_argument('--min_ligand_size', type=int, default=0, help='')
parser.add_argument('--unroll_clusters', action='store_true', default=False, help='')
parser.add_argument('--enforce_timesplit', action='store_true', default=False, help='')
parser.add_argument('--merge_clusters', type=int, default=1, help='')
parser.add_argument('--triple_training', action='store_true', default=False, help='')
parser.add_argument('--crop_beyond', type=float, default=20, help='')
# Diffusion
parser.add_argument('--tr_weight', type=float, default=0.33, help='Weight of translation loss')
parser.add_argument('--rot_weight', type=float, default=0.33, help='Weight of rotation loss')
parser.add_argument('--tor_weight', type=float, default=0.33, help='Weight of torsional loss')
parser.add_argument('--confidence_weight', type=float, default=0.33, help='Weight of confidence loss')
parser.add_argument('--rot_sigma_min', type=float, default=0.1, help='Minimum sigma for rotational component')
parser.add_argument('--rot_sigma_max', type=float, default=1.65, help='Maximum sigma for rotational component')
parser.add_argument('--tr_sigma_min', type=float, default=0.1, help='Minimum sigma for translational component')
@@ -64,11 +97,17 @@ def parse_train_args():
parser.add_argument('--tor_sigma_min', type=float, default=0.0314, help='Minimum sigma for torsional component')
parser.add_argument('--tor_sigma_max', type=float, default=3.14, help='Maximum sigma for torsional component')
parser.add_argument('--no_torsion', action='store_true', default=False, help='If set only rigid matching')
parser.add_argument('--sampling_alpha', type=float, default=1, help='Alpha parameter of beta distribution for sampling t')
parser.add_argument('--sampling_beta', type=float, default=1, help='Beta parameter of beta distribution for sampling t')
parser.add_argument('--bootstrap_alpha', type=float, default=1, help='Alpha parameter of beta distribution for sampling t in bootstrapping')
parser.add_argument('--bootstrap_beta', type=float, default=1, help='Beta parameter of beta distribution for sampling t in bootstrapping')
parser.add_argument('--bootstrap_tmin', type=float, default=0, help='')
# Model
parser.add_argument('--num_conv_layers', type=int, default=2, help='Number of interaction layers')
parser.add_argument('--max_radius', type=float, default=5.0, help='Radius cutoff for geometric graph')
parser.add_argument('--scale_by_sigma', action='store_true', default=True, help='Whether to normalise the score')
parser.add_argument('--norm_by_sigma', action='store_true', default=False, help='Whether to normalise the score')
parser.add_argument('--ns', type=int, default=16, help='Number of hidden features per node of order 0')
parser.add_argument('--nv', type=int, default=4, help='Number of hidden features per node of order >0')
parser.add_argument('--distance_embed_dim', type=int, default=32, help='Embedding size for the distance')
@@ -78,9 +117,36 @@ def parse_train_args():
parser.add_argument('--cross_max_distance', type=float, default=80, help='Maximum cross distance in case not dynamic')
parser.add_argument('--dynamic_max_cross', action='store_true', default=False, help='Whether to use the dynamic distance cutoff')
parser.add_argument('--dropout', type=float, default=0.0, help='MLP dropout')
parser.add_argument('--smooth_edges', action='store_true', default=False, help='Whether to apply additional smoothing weight to edges')
parser.add_argument('--odd_parity', action='store_true', default=False, help='Whether to impose odd parity in output')
parser.add_argument('--embedding_type', type=str, default="sinusoidal", help='Type of diffusion time embedding')
parser.add_argument('--sigma_embed_dim', type=int, default=32, help='Size of the embedding of the diffusion time')
parser.add_argument('--embedding_scale', type=int, default=1000, help='Parameter of the diffusion time embedding')
parser.add_argument('--use_old_atom_encoder', action='store_true', default=False, help='option to use old atom encoder for backward compatibility')
parser.add_argument('--depthwise_convolution', action='store_true', default=False, help='')
parser.add_argument('--protein_file', type=str, default='protein_processed', help='')
parser.add_argument('--no_aminoacid_identities', action='store_true', default=False, help='')
parser.add_argument('--sh_lmax', type=int, default=2, help='Maximum degree of the spherical harmonics')
parser.add_argument('--no_differentiate_convolutions', action='store_true', default=False, help='')
parser.add_argument('--tp_weights_layers', type=int, default=2, help='')
parser.add_argument('--num_prot_emb_layers', type=int, default=0, help='')
parser.add_argument('--reduce_pseudoscalars', action='store_true', default=False, help='')
parser.add_argument('--embed_also_ligand', action='store_true', default=True, help='')
parser.add_argument('--sidechain_loss_weight', type=float, default=0, help='')
parser.add_argument('--backbone_loss_weight', type=float, default=0, help='')
# pdb sidechain training
parser.add_argument('--pdbsidechain_dir', type=str, default='data/pdb_2021aug02_sample', help='')
parser.add_argument('--pdbsidechain_esm_embeddings_path', type=str, default=None, help='')
parser.add_argument('--pdbsidechain_esm_embeddings_sequences_path', type=str, default=None, help='')
parser.add_argument('--vandermers_max_dist', type=int, default=5, help='')
parser.add_argument('--vandermers_buffer_residue_num', type=int, default=7, help='')
parser.add_argument('--vandermers_min_contacts', type=int, default=None, help='')
parser.add_argument('--remove_second_segment', action='store_true', default=False, help='')
args = parser.parse_args()
assert (not args.dynamic_max_cross) or (args.tr_sigma_max * 3 + 20 < args.cross_max_distance)
assert args.esm_embeddings_model is None or args.esm_embeddings_path is None
return args
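
Note on boolean flags: several arguments in this parser combine `action='store_true'` with `default=True` (for example `--remove_hs`, `--scale_by_sigma` and `--embed_also_ligand`). With that combination the value is `True` whether or not the flag is passed, so it cannot be switched off from the command line alone. The sketch below reproduces the behaviour and shows one possible alternative using `argparse.BooleanOptionalAction` (Python 3.9+). The `--strip_hs` flag is hypothetical and exists only for this illustration; it is not an option of the repository.

```python
from argparse import ArgumentParser, BooleanOptionalAction


def build_parser():
    parser = ArgumentParser()
    # Same pattern as in the diff: the value is True with or without the flag.
    parser.add_argument('--remove_hs', action='store_true', default=True, help='remove Hs')
    # Hypothetical alternative (not in the repository): generates both
    # `--strip_hs` and `--no-strip_hs`, so the default can be overridden.
    parser.add_argument('--strip_hs', action=BooleanOptionalAction, default=True, help='remove Hs')
    return parser


if __name__ == '__main__':
    parser = build_parser()
    print(parser.parse_args([]))
    print(parser.parse_args(['--remove_hs', '--no-strip_hs']))
```

Whether the always-true behaviour is intended here (for instance because the value is expected to come from a config file) is not stated in the commit.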


@@ -1,14 +1,33 @@
import copy
import random
import numpy as np
import torch
from torch_geometric.data import Batch
from torch_geometric.loader import DataLoader
from utils.diffusion_utils import modify_conformer, set_time
from utils.diffusion_utils import modify_conformer, set_time, modify_conformer_batch
from utils.torsion import modify_conformer_torsion_angles
from scipy.spatial.transform import Rotation as R
from utils.utils import crop_beyond
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max):
def randomize_position(data_list, no_torsion, no_random, tr_sigma_max, pocket_knowledge=False, pocket_cutoff=7,
initial_noise_std_proportion=-1.0, choose_residue=False):
# in place modification of the list
center_pocket = data_list[0]['receptor'].pos.mean(dim=0)
if pocket_knowledge:
complex = data_list[0]
d = torch.cdist(complex['receptor'].pos, torch.from_numpy(complex['ligand'].orig_pos[0]).float() - complex.original_center)
label = torch.any(d < pocket_cutoff, dim=1)
if torch.any(label):
center_pocket = complex['receptor'].pos[label].mean(dim=0)
else:
print("No pocket residue below minimum distance ", pocket_cutoff, "taking closest at", torch.min(d))
center_pocket = complex['receptor'].pos[torch.argmin(torch.min(d, dim=1)[0])]
if not no_torsion:
# randomize torsion angles
for complex_graph in data_list:
@@ -23,92 +42,212 @@ def randomize_position(data_list, no_torsion, no_random, tr_sigma_max):
# randomize position
molecule_center = torch.mean(complex_graph['ligand'].pos, dim=0, keepdim=True)
random_rotation = torch.from_numpy(R.random().as_matrix()).float()
complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T
complex_graph['ligand'].pos = (complex_graph['ligand'].pos - molecule_center) @ random_rotation.T + center_pocket
# base_rmsd = np.sqrt(np.sum((complex_graph['ligand'].pos.cpu().numpy() - orig_complex_graph['ligand'].pos.numpy()) ** 2, axis=1).mean())
if not no_random: # note for now the torsion angles are still randomised
tr_update = torch.normal(mean=0, std=tr_sigma_max, size=(1, 3))
if choose_residue:
idx = random.randint(0, len(complex_graph['receptor'].pos)-1)
tr_update = torch.normal(mean=complex_graph['receptor'].pos[idx:idx+1], std=0.01)
elif initial_noise_std_proportion >= 0.0:
std_rec = torch.sqrt(torch.mean(torch.sum(complex_graph['receptor'].pos ** 2, dim=1)))
tr_update = torch.normal(mean=0, std=std_rec * initial_noise_std_proportion / 1.73, size=(1, 3))
else:
# if initial_noise_std_proportion < 0.0, we use the tr_sigma_max multiplied by -initial_noise_std_proportion
tr_update = torch.normal(mean=0, std=-initial_noise_std_proportion * tr_sigma_max, size=(1, 3))
complex_graph['ligand'].pos += tr_update
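
Note on the initial translation noise: apart from the `choose_residue` branch, the new `randomize_position` picks the standard deviation of the Gaussian translation in one of two ways, depending on the sign of `initial_noise_std_proportion`. The sketch below restates that choice in plain Python, without torch, as a reading aid. The function name `initial_translation_std` and the list-of-tuples input are assumptions made for this illustration.

```python
import math


def initial_translation_std(receptor_pos, tr_sigma_max, initial_noise_std_proportion=-1.0):
    """Return the std used for the initial ligand translation.

    receptor_pos: iterable of (x, y, z) receptor coordinates (assumed to be
    centred, as in the complex graphs).
    """
    if initial_noise_std_proportion >= 0.0:
        # Root-mean-square distance of receptor points from the origin,
        # scaled by the proportion and divided by 1.73 (approximately sqrt(3)).
        rms = math.sqrt(sum(x * x + y * y + z * z for x, y, z in receptor_pos) / len(receptor_pos))
        return rms * initial_noise_std_proportion / 1.73
    # Negative proportion: use tr_sigma_max multiplied by its absolute value.
    return -initial_noise_std_proportion * tr_sigma_max
```

With the default of `-1.0` this reduces to `tr_sigma_max`, which matches the behaviour of the previous version of the function.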
def is_iterable(arr):
try:
some_object_iterator = iter(arr)
return True
except TypeError as te:
return False
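
Note on `is_iterable`: `sampling` uses this helper to accept either a single temperature value or one value per component (translation, rotation, torsion). A minimal sketch of that normalisation follows; `as_triple` is a hypothetical name, and raising `ValueError` instead of using `assert` is a choice made for this illustration.

```python
def is_iterable(value):
    """True if iter(value) succeeds. Note that strings also count as iterable."""
    try:
        iter(value)
        return True
    except TypeError:
        return False


def as_triple(value):
    """Expand a scalar to [v, v, v]; pass a length-3 sequence through as a list."""
    if not is_iterable(value):
        return [value] * 3
    triple = list(value)
    if len(triple) != 3:
        raise ValueError("expected one value or exactly three values")
    return triple
```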
def sampling(data_list, model, inference_steps, tr_schedule, rot_schedule, tor_schedule, device, t_to_sigma, model_args,
no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None,
confidence_model_args=None, batch_size=32, no_final_step_noise=False):
no_random=False, ode=False, visualization_list=None, confidence_model=None, confidence_data_list=None, confidence_model_args=None,
t_schedule=None, batch_size=32, no_final_step_noise=False, pivot=None, return_full_trajectory=False,
temp_sampling=1.0, temp_psi=0.0, temp_sigma_data=0.5, return_features=False):
N = len(data_list)
trajectory = []
if return_features:
lig_features, rec_features = [], []
assert batch_size >= N, "Not implemented yet"
for t_idx in range(inference_steps):
t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
dt_tr = tr_schedule[t_idx] - tr_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tr_schedule[t_idx]
dt_rot = rot_schedule[t_idx] - rot_schedule[t_idx + 1] if t_idx < inference_steps - 1 else rot_schedule[t_idx]
dt_tor = tor_schedule[t_idx] - tor_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tor_schedule[t_idx]
loader = DataLoader(data_list, batch_size=batch_size)
assert not (return_full_trajectory or return_features or pivot), "Not implemented yet in new inference version"
loader = DataLoader(data_list, batch_size=batch_size)
new_data_list = []
mask_rotate = torch.from_numpy(data_list[0]['ligand'].mask_rotate[0]).to(device)
for complex_graph_batch in loader:
b = complex_graph_batch.num_graphs
complex_graph_batch = complex_graph_batch.to(device)
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)
set_time(complex_graph_batch, t_tr, t_rot, t_tor, b, model_args.all_atoms, device)
with torch.no_grad():
tr_score, rot_score, tor_score = model(complex_graph_batch)
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
rot_g = 2 * rot_sigma * torch.sqrt(torch.tensor(np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
if ode:
tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score.cpu()).cpu()
rot_perturb = (0.5 * rot_score.cpu() * dt_rot * rot_g ** 2).cpu()
else:
tr_z = torch.zeros((b, 3)) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=(b, 3))
tr_perturb = (tr_g ** 2 * dt_tr * tr_score.cpu() + tr_g * np.sqrt(dt_tr) * tr_z).cpu()
rot_z = torch.zeros((b, 3)) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=(b, 3))
rot_perturb = (rot_score.cpu() * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z).cpu()
if not model_args.no_torsion:
tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
if ode:
tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score.cpu()).numpy()
else:
tor_z = torch.zeros(tor_score.shape) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=tor_score.shape)
tor_perturb = (tor_g ** 2 * dt_tor * tor_score.cpu() + tor_g * np.sqrt(dt_tor) * tor_z).numpy()
torsions_per_molecule = tor_perturb.shape[0] // b
else:
tor_perturb = None
# Apply noise
new_data_list.extend([modify_conformer(complex_graph, tr_perturb[i:i + 1], rot_perturb[i:i + 1].squeeze(0),
tor_perturb[i * torsions_per_molecule:(i + 1) * torsions_per_molecule] if not model_args.no_torsion else None)
for i, complex_graph in enumerate(complex_graph_batch.to('cpu').to_data_list())])
data_list = new_data_list
if visualization_list is not None:
for idx, visualization in enumerate(visualization_list):
visualization.add((data_list[idx]['ligand'].pos + data_list[idx].original_center).detach().cpu(),
part=1, order=t_idx + 2)
confidence = None
if confidence_model is not None:
confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
confidence = []
with torch.no_grad():
if confidence_model is not None:
loader = DataLoader(data_list, batch_size=batch_size)
confidence_loader = iter(DataLoader(confidence_data_list, batch_size=batch_size))
confidence = []
for complex_graph_batch in loader:
complex_graph_batch = complex_graph_batch.to(device)
if confidence_data_list is not None:
confidence_complex_graph_batch = next(confidence_loader).to(device)
confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos
set_time(confidence_complex_graph_batch, 0, 0, 0, N, confidence_model_args.all_atoms, device)
confidence.append(confidence_model(confidence_complex_graph_batch))
for batch_id, complex_graph_batch in enumerate(loader):
b = complex_graph_batch.num_graphs
n = len(complex_graph_batch['ligand'].pos) // b
complex_graph_batch = complex_graph_batch.to(device)
for t_idx in range(inference_steps):
t_tr, t_rot, t_tor = tr_schedule[t_idx], rot_schedule[t_idx], tor_schedule[t_idx]
dt_tr = tr_schedule[t_idx] - tr_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tr_schedule[t_idx]
dt_rot = rot_schedule[t_idx] - rot_schedule[t_idx + 1] if t_idx < inference_steps - 1 else rot_schedule[t_idx]
dt_tor = tor_schedule[t_idx] - tor_schedule[t_idx + 1] if t_idx < inference_steps - 1 else tor_schedule[t_idx]
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(t_tr, t_rot, t_tor)
if hasattr(model_args, 'crop_beyond') and model_args.crop_beyond is not None:
#print('Cropping beyond', tr_sigma * 3 + model_args.crop_beyond, 'for score model')
mod_complex_graph_batch = copy.deepcopy(complex_graph_batch).to_data_list()
for batch in mod_complex_graph_batch:
crop_beyond(batch, tr_sigma * 3 + model_args.crop_beyond, model_args.all_atoms)
mod_complex_graph_batch = Batch.from_data_list(mod_complex_graph_batch)
else:
confidence.append(confidence_model(complex_graph_batch))
confidence = torch.cat(confidence, dim=0)
else:
confidence = None
mod_complex_graph_batch = complex_graph_batch
set_time(mod_complex_graph_batch, t_schedule[t_idx] if t_schedule is not None else None, t_tr, t_rot, t_tor, b,
'all_atoms' in model_args and model_args.all_atoms, device)
tr_score, rot_score, tor_score = model(mod_complex_graph_batch)[:3]
tr_g = tr_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tr_sigma_max / model_args.tr_sigma_min)))
rot_g = rot_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.rot_sigma_max / model_args.rot_sigma_min)))
if ode:
tr_perturb = (0.5 * tr_g ** 2 * dt_tr * tr_score)
rot_perturb = (0.5 * rot_score * dt_rot * rot_g ** 2)
else:
tr_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
tr_perturb = (tr_g ** 2 * dt_tr * tr_score + tr_g * np.sqrt(dt_tr) * tr_z)
rot_z = torch.zeros((min(batch_size, N), 3), device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=(min(batch_size, N), 3), device=device)
rot_perturb = (rot_score * dt_rot * rot_g ** 2 + rot_g * np.sqrt(dt_rot) * rot_z)
if not model_args.no_torsion:
tor_g = tor_sigma * torch.sqrt(torch.tensor(2 * np.log(model_args.tor_sigma_max / model_args.tor_sigma_min)))
if ode:
tor_perturb = (0.5 * tor_g ** 2 * dt_tor * tor_score)
else:
tor_z = torch.zeros(tor_score.shape, device=device) if no_random or (no_final_step_noise and t_idx == inference_steps - 1) \
else torch.normal(mean=0, std=1, size=tor_score.shape, device=device)
tor_perturb = (tor_g ** 2 * dt_tor * tor_score + tor_g * np.sqrt(dt_tor) * tor_z)
torsions_per_molecule = tor_perturb.shape[0] // b
else:
tor_perturb = None
if not is_iterable(temp_sampling):
temp_sampling = [temp_sampling] * 3
if not is_iterable(temp_psi):
temp_psi = [temp_psi] * 3
if not is_iterable(temp_sampling): temp_sampling = [temp_sampling] * 3
if not is_iterable(temp_psi): temp_psi = [temp_psi] * 3
if not is_iterable(temp_sigma_data): temp_sigma_data = [temp_sigma_data] * 3
assert len(temp_sampling) == 3
assert len(temp_psi) == 3
assert len(temp_sigma_data) == 3
if temp_sampling[0] != 1.0:
tr_sigma_data = np.exp(temp_sigma_data[0] * np.log(model_args.tr_sigma_max) + (1 - temp_sigma_data[0]) * np.log(model_args.tr_sigma_min))
lambda_tr = (tr_sigma_data + tr_sigma) / (tr_sigma_data + tr_sigma / temp_sampling[0])
tr_perturb = (tr_g ** 2 * dt_tr * (lambda_tr + temp_sampling[0] * temp_psi[0] / 2) * tr_score + tr_g * np.sqrt(dt_tr * (1 + temp_psi[0])) * tr_z)
if temp_sampling[1] != 1.0:
rot_sigma_data = np.exp(temp_sigma_data[1] * np.log(model_args.rot_sigma_max) + (1 - temp_sigma_data[1]) * np.log(model_args.rot_sigma_min))
lambda_rot = (rot_sigma_data + rot_sigma) / (rot_sigma_data + rot_sigma / temp_sampling[1])
rot_perturb = (rot_g ** 2 * dt_rot * (lambda_rot + temp_sampling[1] * temp_psi[1] / 2) * rot_score + rot_g * np.sqrt(dt_rot * (1 + temp_psi[1])) * rot_z)
if temp_sampling[2] != 1.0:
tor_sigma_data = np.exp(temp_sigma_data[2] * np.log(model_args.tor_sigma_max) + (1 - temp_sigma_data[2]) * np.log(model_args.tor_sigma_min))
lambda_tor = (tor_sigma_data + tor_sigma) / (tor_sigma_data + tor_sigma / temp_sampling[2])
tor_perturb = (tor_g ** 2 * dt_tor * (lambda_tor + temp_sampling[2] * temp_psi[2] / 2) * tor_score + tor_g * np.sqrt(dt_tor * (1 + temp_psi[2])) * tor_z)
# Apply noise
complex_graph_batch['ligand'].pos = \
modify_conformer_batch(complex_graph_batch['ligand'].pos, complex_graph_batch, tr_perturb, rot_perturb,
tor_perturb if not model_args.no_torsion else None, mask_rotate)
if visualization_list is not None:
for idx_b in range(b):
visualization_list[batch_id * batch_size + idx_b].add((
complex_graph_batch['ligand'].pos[idx_b*n:n*(idx_b+1)].detach().cpu() +
data_list[batch_id * batch_size + idx_b].original_center.detach().cpu()),
part=1, order=t_idx + 2)
for i in range(b):
data_list[batch_id * batch_size + i]['ligand'].pos = complex_graph_batch['ligand'].pos[i*n:n*(i+1)]
if visualization_list is not None:
for idx, visualization in enumerate(visualization_list):
visualization.add((data_list[idx]['ligand'].pos.detach().cpu() + data_list[idx].original_center.detach().cpu()),
part=1, order=2)
if confidence_model is not None:
if confidence_data_list is not None:
confidence_complex_graph_batch = next(confidence_loader)
confidence_complex_graph_batch['ligand'].pos = complex_graph_batch['ligand'].pos.cpu()
if hasattr(confidence_model_args, 'crop_beyond') and confidence_model_args.crop_beyond is not None:
confidence_complex_graph_batch = confidence_complex_graph_batch.to_data_list()
for batch in confidence_complex_graph_batch:
crop_beyond(batch, confidence_model_args.crop_beyond, confidence_model_args.all_atoms)
confidence_complex_graph_batch = Batch.from_data_list(confidence_complex_graph_batch)
confidence_complex_graph_batch = confidence_complex_graph_batch.to(device)
set_time(confidence_complex_graph_batch, 0, 0, 0, 0, b, confidence_model_args.all_atoms, device)
out = confidence_model(confidence_complex_graph_batch)
else:
out = confidence_model(complex_graph_batch)
if type(out) is tuple:
out = out[0]
confidence.append(out)
if confidence_model is not None:
confidence = torch.cat(confidence, dim=0)
confidence = torch.nan_to_num(confidence, nan=-1000)
if return_full_trajectory:
return data_list, confidence, trajectory
elif return_features:
lig_features = torch.cat(lig_features, dim=0)
rec_features = torch.cat(rec_features, dim=0)
return data_list, confidence, lig_features, rec_features
return data_list, confidence
def compute_affinity(data_list, affinity_model, affinity_data_list, device, parallel, all_atoms, include_miscellaneous_atoms):
with torch.no_grad():
if affinity_model is not None:
assert parallel <= len(data_list)
loader = DataLoader(data_list, batch_size=parallel)
complex_graph_batch = next(iter(loader)).to(device)
positions = complex_graph_batch['ligand'].pos
assert affinity_data_list is not None
complex_graph = affinity_data_list[0]
N = complex_graph['ligand'].num_nodes
complex_graph['ligand'].x = complex_graph['ligand'].x.repeat(parallel, 1)
complex_graph['ligand'].edge_mask = complex_graph['ligand'].edge_mask.repeat(parallel)
complex_graph['ligand', 'ligand'].edge_index = torch.cat(
[N * i + complex_graph['ligand', 'ligand'].edge_index for i in range(parallel)], dim=1)
complex_graph['ligand', 'ligand'].edge_attr = complex_graph['ligand', 'ligand'].edge_attr.repeat(parallel, 1)
complex_graph['ligand'].pos = positions
affinity_loader = DataLoader([complex_graph], batch_size=1)
affinity_batch = next(iter(affinity_loader)).to(device)
set_time(affinity_batch, 0, 0, 0, 0, 1, all_atoms, device, include_miscellaneous_atoms=include_miscellaneous_atoms)
_, affinity = affinity_model(affinity_batch)
else:
affinity = None
return affinity
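
Note on the update rule in `sampling`: for each component the code computes a diffusion coefficient `g = sigma * sqrt(2 * ln(sigma_max / sigma_min))`, then either a deterministic ODE step or a stochastic step, and optionally replaces the stochastic step with a temperature-adjusted one when `temp_sampling != 1.0`. The scalar sketch below restates those formulas in plain Python as a reading aid. All function and parameter names are assumptions made for this illustration, and the real code operates on batched tensors.

```python
import math


def diffusion_coefficient(sigma, sigma_min, sigma_max):
    """g = sigma * sqrt(2 * ln(sigma_max / sigma_min))."""
    return sigma * math.sqrt(2 * math.log(sigma_max / sigma_min))


def step_sizes(schedule):
    """dt for step i is schedule[i] - schedule[i + 1]; the last step uses schedule[i]."""
    n = len(schedule)
    return [schedule[i] - schedule[i + 1] if i < n - 1 else schedule[i] for i in range(n)]


def perturbation(score, g, dt, z=0.0, ode=False):
    """One reverse step: ODE form uses half the drift and no noise."""
    if ode:
        return 0.5 * g ** 2 * dt * score
    return g ** 2 * dt * score + g * math.sqrt(dt) * z


def tempered_perturbation(score, sigma, g, dt, z, sigma_min, sigma_max,
                          temp, psi, sigma_data_frac):
    """Low-temperature variant, following the formulas in the diff."""
    # sigma_data interpolates between sigma_min and sigma_max in log space.
    sigma_data = math.exp(sigma_data_frac * math.log(sigma_max)
                          + (1 - sigma_data_frac) * math.log(sigma_min))
    lam = (sigma_data + sigma) / (sigma_data + sigma / temp)
    return (g ** 2 * dt * (lam + temp * psi / 2) * score
            + g * math.sqrt(dt * (1 + psi)) * z)
```

At `temp = 1` and `psi = 0` the tempered formula reduces to the plain stochastic step, which is consistent with the code only entering that branch when `temp_sampling` differs from `1.0`.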


@@ -3,7 +3,7 @@ import numpy as np
import torch
from scipy.spatial.transform import Rotation
MIN_EPS, MAX_EPS, N_EPS = 0.01, 2, 1000
MIN_EPS, MAX_EPS, N_EPS = 0.0005, 4, 2000
X_N = 2000
"""
@@ -21,7 +21,7 @@ def _compose(r1, r2): # R1 @ R2 but for Euler vecs
def _expansion(omega, eps, L=2000): # the summation term only
p = 0
for l in range(L):
p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2)
p += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * np.sin(omega * (l + 1 / 2)) / np.sin(omega / 2)
return p
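
Note on the changed exponent: the commit replaces `eps**2` with `eps**2 / 2` inside the truncated series, which is equivalent to evaluating the old series at `eps / sqrt(2)`. The sketch below is a plain-Python restatement of the summation term that makes this relationship easy to check. The `halve_exponent` switch and the smaller default `L` are assumptions added for the illustration.

```python
import math


def expansion(omega, eps, L=200, halve_exponent=True):
    """Truncated series for the SO(3) density (summation term only).

    halve_exponent=True  -> exp(-l (l + 1) eps^2 / 2)   (this commit)
    halve_exponent=False -> exp(-l (l + 1) eps^2)       (previous version)
    """
    scale = 0.5 if halve_exponent else 1.0
    total = 0.0
    for l in range(L):
        total += ((2 * l + 1)
                  * math.exp(-l * (l + 1) * eps ** 2 * scale)
                  * math.sin(omega * (l + 0.5)) / math.sin(omega / 2))
    return total
```

Because the meaning of `eps` changes, tables cached by the old code would no longer match, which is presumably why the cache file names move from the `2` suffix to `4` in the same commit.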
@@ -39,17 +39,16 @@ def _score(exp, omega, eps, L=2000): # score of density over SO(3)
dhi = (l + 1 / 2) * np.cos(omega * (l + 1 / 2))
lo = np.sin(omega / 2)
dlo = 1 / 2 * np.cos(omega / 2)
dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2) * (lo * dhi - hi * dlo) / lo ** 2
dSigma += (2 * l + 1) * np.exp(-l * (l + 1) * eps**2 / 2) * (lo * dhi - hi * dlo) / lo ** 2
return dSigma / exp
if os.path.exists('.so3_omegas_array2.npy'):
_omegas_array = np.load('.so3_omegas_array2.npy')
_cdf_vals = np.load('.so3_cdf_vals2.npy')
_score_norms = np.load('.so3_score_norms2.npy')
_exp_score_norms = np.load('.so3_exp_score_norms2.npy')
if os.path.exists('.so3_omegas_array4.npy'):
_omegas_array = np.load('.so3_omegas_array4.npy')
_cdf_vals = np.load('.so3_cdf_vals4.npy')
_score_norms = np.load('.so3_score_norms4.npy')
_exp_score_norms = np.load('.so3_exp_score_norms4.npy')
else:
print("Precomputing and saving to cache SO(3) distribution table")
_eps_array = 10 ** np.linspace(np.log10(MIN_EPS), np.log10(MAX_EPS), N_EPS)
_omegas_array = np.linspace(0, np.pi, X_N + 1)[1:]
@@ -60,10 +59,10 @@ else:
_exp_score_norms = np.sqrt(np.sum(_score_norms**2 * _pdf_vals, axis=1) / np.sum(_pdf_vals, axis=1) / np.pi)
np.save('.so3_omegas_array2.npy', _omegas_array)
np.save('.so3_cdf_vals2.npy', _cdf_vals)
np.save('.so3_score_norms2.npy', _score_norms)
np.save('.so3_exp_score_norms2.npy', _exp_score_norms)
np.save('.so3_omegas_array4.npy', _omegas_array)
np.save('.so3_cdf_vals4.npy', _cdf_vals)
np.save('.so3_score_norms4.npy', _score_norms)
np.save('.so3_exp_score_norms4.npy', _exp_score_norms)
def sample(eps):
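
The change to `_expansion` and `_score` replaces `eps**2` with `eps**2 / 2` in the exponent. One way to read this: the new series at `eps` equals the old series at `eps / sqrt(2)`. The sketch below checks that identity numerically with the standard library only. The function names `expansion_old` and `expansion_new` are illustrative, and the series is truncated at a smaller `L` than the 2000 used in the diff to keep it quick.

```python
import math


def expansion_old(omega, eps, L=200):
    # Series before the change: exponent is -l(l+1) * eps^2
    return sum(
        (2 * l + 1) * math.exp(-l * (l + 1) * eps ** 2)
        * math.sin(omega * (l + 0.5)) / math.sin(omega / 2)
        for l in range(L)
    )


def expansion_new(omega, eps, L=200):
    # Series after the change: exponent is -l(l+1) * eps^2 / 2
    return sum(
        (2 * l + 1) * math.exp(-l * (l + 1) * eps ** 2 / 2)
        * math.sin(omega * (l + 0.5)) / math.sin(omega / 2)
        for l in range(L)
    )


omega, eps = 1.0, 0.5
same = math.isclose(
    expansion_new(omega, eps),
    expansion_old(omega, eps / math.sqrt(2)),
    rel_tol=1e-9,
)
print(same)
```

Because the tabulated values change, tables cached by the earlier code would no longer be valid, which is consistent with the diff writing to new `.npy` file names.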

@@ -5,6 +5,8 @@ from scipy.spatial.transform import Rotation as R
from torch_geometric.utils import to_networkx
from torch_geometric.data import Data
from utils.geometry import rigid_transform_Kabsch_independent_torch, axis_angle_to_matrix
"""
Preprocessing and computation for torsional updates to conformers
"""
@@ -35,7 +37,7 @@ def get_transformation_mask(pyg_data):
mask_edges = np.asarray([0 if len(l) == 0 else 1 for l in to_rotate], dtype=bool)
mask_rotate = np.zeros((np.sum(mask_edges), len(G.nodes())), dtype=bool)
idx = 0
for i in range(len(G.edges())):
for i in range(min(edges.shape[0], len(G.edges()))):
if mask_edges[i]:
mask_rotate[idx][np.asarray(to_rotate[i], dtype=int)] = True
idx += 1
@@ -46,15 +48,19 @@ def get_transformation_mask(pyg_data):
def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_updates, as_numpy=False):
pos = copy.deepcopy(pos)
if type(pos) != np.ndarray: pos = pos.cpu().numpy()
if type(mask_rotate) == list: mask_rotate = mask_rotate[0]
for idx_edge, e in enumerate(edge_index.cpu().numpy()):
if torsion_updates[idx_edge] == 0:
continue
u, v = e[0], e[1]
# Check whether the edge needs to be reversed: v must be connected to the part that gets rotated
assert not mask_rotate[idx_edge, u]
assert mask_rotate[idx_edge, v]
if mask_rotate[idx_edge, u] or (not mask_rotate[idx_edge, v]):
print("mask rotate exception")
#assert not mask_rotate[idx_edge, u]
#assert mask_rotate[idx_edge, v]
rot_vec = pos[u] - pos[v] # convention: positive rotation if pointing inwards
rot_vec = rot_vec * torsion_updates[idx_edge] / np.linalg.norm(rot_vec) # idx_edge!
@@ -66,6 +72,24 @@ def modify_conformer_torsion_angles(pos, edge_index, mask_rotate, torsion_update
return pos
def modify_conformer_torsion_angles_batch(pos, edge_index, mask_rotate, torsion_updates):
pos = pos + 0
for idx_edge, e in enumerate(edge_index):
u, v = e[0], e[1]
# Check whether the edge needs to be reversed: v must be connected to the part that gets rotated
assert not mask_rotate[idx_edge, u]
assert mask_rotate[idx_edge, v]
rot_vec = pos[:, u] - pos[:, v] # convention: positive rotation if pointing inwards
rot_mat = axis_angle_to_matrix(
rot_vec / torch.linalg.norm(rot_vec, dim=-1, keepdims=True) * torsion_updates[:, idx_edge:idx_edge + 1])
pos[:, mask_rotate[idx_edge]] = torch.bmm(pos[:, mask_rotate[idx_edge]] - pos[:, v:v + 1], torch.transpose(rot_mat, 1, 2)) + pos[:, v:v + 1]
return pos
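
Both torsion routines rotate the atoms selected by `mask_rotate` around the bond axis `pos[u] - pos[v]`, with `pos[v]` as the pivot. A minimal single-point sketch using Rodrigues' formula is shown below. It assumes the usual right-handed axis-angle convention, which the diff does not spell out, and `rotate_about_bond` is a hypothetical name.

```python
import math


def rotate_about_bond(point, u, v, angle):
    """Rotate `point` by `angle` radians around the axis u - v passing through v."""
    axis = [a - b for a, b in zip(u, v)]
    norm = math.sqrt(sum(c * c for c in axis))
    k = [c / norm for c in axis]                 # unit rotation axis
    p = [a - b for a, b in zip(point, v)]        # move the pivot to the origin
    cos_t, sin_t = math.cos(angle), math.sin(angle)
    cross = [
        k[1] * p[2] - k[2] * p[1],
        k[2] * p[0] - k[0] * p[2],
        k[0] * p[1] - k[1] * p[0],
    ]
    dot = sum(a * b for a, b in zip(k, p))
    # Rodrigues: p cos(t) + (k x p) sin(t) + k (k . p)(1 - cos(t))
    rotated = [p[i] * cos_t + cross[i] * sin_t + k[i] * dot * (1 - cos_t) for i in range(3)]
    return [r + b for r, b in zip(rotated, v)]   # move the pivot back


# Quarter turn around the z axis.
print(rotate_about_bond([1.0, 0.0, 0.0], u=[0.0, 0.0, 1.0], v=[0.0, 0.0, 0.0], angle=math.pi / 2))
```

The batched version in the diff does the same thing for many conformers at once by building one rotation matrix per conformer and applying it with `torch.bmm`.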
def perturb_batch(data, torsion_updates, split=False, return_updates=False):
if type(data) is Data:
return modify_conformer_torsion_angles(data.pos,
@@ -91,4 +115,24 @@ def perturb_batch(data, torsion_updates, split=False, return_updates=False):
idx_edges += mask_rotate.shape[0]
if return_updates:
return pos_new, torsion_update_list
return pos_new
return pos_new
def get_dihedrals(data_list):
edge_index, edge_mask = data_list[0]['ligand', 'ligand'].edge_index, data_list[0]['ligand'].edge_mask
edge_list = [[] for _ in range(torch.max(edge_index) + 1)]
for p in edge_index.T:
edge_list[p[0]].append(p[1])
rot_bonds = [(p[0], p[1]) for i, p in enumerate(edge_index.T) if edge_mask[i]]
dihedral = []
for a, b in rot_bonds:
c = edge_list[a][0] if edge_list[a][0] != b else edge_list[a][1]
d = edge_list[b][0] if edge_list[b][0] != a else edge_list[b][1]
dihedral.append((c.item(), a.item(), b.item(), d.item()))
# dihedral_numpy = np.asarray(dihedral)
# print(dihedral_numpy.shape)
dihedral = torch.tensor(dihedral)
return dihedral
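
`get_dihedrals` turns each rotatable bond `(a, b)` into an atom quadruple `(c, a, b, d)`, where `c` is a neighbour of `a` other than `b` and `d` is a neighbour of `b` other than `a`. The sketch below restates that selection with plain lists. It assumes the edge list is directed and contains both directions of every bond, and that no rotatable bond is terminal. `dihedral_quadruples` is a hypothetical name.

```python
def dihedral_quadruples(edges, rotatable):
    """edges: list of directed (u, v) pairs; rotatable: parallel list of bools."""
    neighbours = {}
    for u, v in edges:
        neighbours.setdefault(u, []).append(v)

    quads = []
    for (a, b), is_rotatable in zip(edges, rotatable):
        if not is_rotatable:
            continue
        # First neighbour of a that is not b, and first neighbour of b that is not a.
        c = next(n for n in neighbours[a] if n != b)
        d = next(n for n in neighbours[b] if n != a)
        quads.append((c, a, b, d))
    return quads


# Four-atom chain 0-1-2-3 with the central bond marked rotatable.
edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 3), (3, 2)]
mask = [False, False, True, False, False, False]
print(dihedral_quadruples(edges, mask))
```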

@@ -32,7 +32,6 @@ if os.path.exists('.p.npy'):
p_ = np.load('.p.npy')
score_ = np.load('.score.npy')
else:
print("Precomputing and saving to cache torus distribution table")
p_ = p(x, sigma[:, None], N=100)
np.save('.p.npy', p_)

@@ -1,18 +1,19 @@
import copy
import numpy as np
from rdkit.Chem import RemoveAllHs
from torch_geometric.loader import DataLoader
from tqdm import tqdm
import torch
from confidence.dataset import ListDataset
from utils import so3, torus
from utils.molecules_utils import get_symmetry_rmsd
from utils.sampling import randomize_position, sampling
import torch
from utils.diffusion_utils import get_t_schedule
def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weight=1, rot_weight=1,
tor_weight=1, apply_mean=True, no_torsion=False):
def loss_function(tr_pred, rot_pred, tor_pred, sidechain_pred, data, t_to_sigma, device, tr_weight=1, rot_weight=1,
tor_weight=1, backbone_weight=0, sidechain_weight=0, apply_mean=True, no_torsion=False):
tr_sigma, rot_sigma, tor_sigma = t_to_sigma(
*[torch.cat([d.complex_t[noise_type] for d in data]) if device.type == 'cuda' else data.complex_t[noise_type]
for noise_type in ['tr', 'rot', 'tor']])
@@ -21,14 +22,14 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
# translation component
tr_score = torch.cat([d.tr_score for d in data], dim=0) if device.type == 'cuda' else data.tr_score
tr_sigma = tr_sigma.unsqueeze(-1)
tr_loss = ((tr_pred.cpu() - tr_score) ** 2 * tr_sigma ** 2).mean(dim=mean_dims)
tr_loss = ((tr_pred.cpu() - tr_score.cpu()) ** 2 * tr_sigma.cpu() ** 2).mean(dim=mean_dims)
tr_base_loss = (tr_score ** 2 * tr_sigma ** 2).mean(dim=mean_dims).detach()
# rotation component
rot_score = torch.cat([d.rot_score for d in data], dim=0) if device.type == 'cuda' else data.rot_score
rot_score_norm = so3.score_norm(rot_sigma.cpu()).unsqueeze(-1)
rot_loss = (((rot_pred.cpu() - rot_score) / rot_score_norm) ** 2).mean(dim=mean_dims)
rot_base_loss = ((rot_score / rot_score_norm) ** 2).mean(dim=mean_dims).detach()
rot_loss = (((rot_pred.cpu() - rot_score.cpu()) / rot_score_norm) ** 2).mean(dim=mean_dims)
rot_base_loss = ((rot_score.cpu() / rot_score_norm) ** 2).mean(dim=mean_dims).detach()
# torsion component
if not no_torsion:
@@ -36,8 +37,8 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
np.concatenate([d.tor_sigma_edge for d in data] if device.type == 'cuda' else data.tor_sigma_edge))
tor_score = torch.cat([d.tor_score for d in data], dim=0) if device.type == 'cuda' else data.tor_score
tor_score_norm2 = torch.tensor(torus.score_norm(edge_tor_sigma.cpu().numpy())).float()
tor_loss = ((tor_pred.cpu() - tor_score) ** 2 / tor_score_norm2)
tor_base_loss = ((tor_score ** 2 / tor_score_norm2)).detach()
tor_loss = ((tor_pred.cpu() - tor_score.cpu()) ** 2 / tor_score_norm2)
tor_base_loss = ((tor_score.cpu() ** 2 / tor_score_norm2)).detach()
if apply_mean:
tor_loss, tor_base_loss = tor_loss.mean() * torch.ones(1, dtype=torch.float), tor_base_loss.mean() * torch.ones(1, dtype=torch.float)
else:
@@ -57,8 +58,70 @@ def loss_function(tr_pred, rot_pred, tor_pred, data, t_to_sigma, device, tr_weig
else:
tor_loss, tor_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(len(rot_loss), dtype=torch.float)
loss = tr_loss * tr_weight + rot_loss * rot_weight + tor_loss * tor_weight
return loss, tr_loss.detach(), rot_loss.detach(), tor_loss.detach(), tr_base_loss, rot_base_loss, tor_base_loss
if backbone_weight > 0:
backbone_vecs = torch.cat([d['receptor'].side_chain_vecs.cpu() for d in data], dim=0) if device.type == 'cuda' else data['receptor'].side_chain_vecs
backbone_vecs = backbone_vecs[:, 4:]
backbone_pred = sidechain_pred[:, 4:]
backbone_base_loss = (backbone_vecs ** 2).detach().mean(dim=1) + 0.0001
backbone_loss = ((backbone_pred.cpu() - backbone_vecs) ** 2).mean(dim=1) / backbone_base_loss.mean()
backbone_base_loss = backbone_base_loss / backbone_base_loss.mean()
if apply_mean:
backbone_loss, backbone_base_loss = backbone_loss.mean() * torch.ones(1, dtype=torch.float), backbone_base_loss.mean() * torch.ones(1, dtype=torch.float)
else:
index = torch.cat([torch.ones((d['receptor'].pos.shape[0])) * i for i, d in enumerate(data)], dim=0).long() if device.type == 'cuda' else data['receptor'].batch
num_graphs = len(data) if device.type == 'cuda' else data.num_graphs
s_l, s_b_l, c = torch.zeros(num_graphs), torch.zeros(num_graphs), torch.zeros(num_graphs)
c.index_add_(0, index, torch.ones(backbone_loss.shape[0]))
c = c + 0.0001
s_l.index_add_(0, index, backbone_loss)
s_b_l.index_add_(0, index, backbone_base_loss)
backbone_loss, backbone_base_loss = s_l / c, s_b_l / c
else:
if apply_mean:
backbone_loss, backbone_base_loss = torch.zeros(1, dtype=torch.float), torch.zeros(1, dtype=torch.float)
else:
backbone_loss, backbone_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(len(rot_loss), dtype=torch.float)
if sidechain_weight > 0:
sidechain_vecs = torch.cat([d['receptor'].side_chain_vecs.cpu() for d in data],
dim=0) if device.type == 'cuda' else data['receptor'].side_chain_vecs
chi_angles = sidechain_vecs[:, :4].to(device)
chi_pred = sidechain_pred[:, :4].to(device)
chi_pred = torch.where(torch.isnan(chi_angles), torch.zeros_like(chi_angles, device=device), chi_pred)
chi_angles = torch.where(torch.isnan(chi_angles), torch.zeros_like(chi_angles, device=device), chi_angles)
difference = torch.abs(chi_pred - chi_angles)
difference = torch.min(difference, 1 - difference) # angles are circular and 360 degrees = 1
sidechain_base_loss = (chi_angles ** 2).detach().mean(dim=1) + 0.0001
sidechain_loss = (difference ** 2).mean(dim=1) / sidechain_base_loss.mean()
sidechain_base_loss = sidechain_base_loss / sidechain_base_loss.mean()
if apply_mean:
sidechain_loss, sidechain_base_loss = \
sidechain_loss.mean().cpu() * torch.ones(1, dtype=torch.float), \
sidechain_base_loss.mean().cpu() * torch.ones(1, dtype=torch.float)
else:
index = torch.cat([torch.ones((d['receptor'].pos.shape[0])) * i for i, d in enumerate(data)],
dim=0).long() if device.type == 'cuda' else data['receptor'].batch
num_graphs = len(data) if device.type == 'cuda' else data.num_graphs
s_l, s_b_l, c = torch.zeros(num_graphs), torch.zeros(num_graphs), torch.zeros(num_graphs)
c.index_add_(0, index, torch.ones(sidechain_loss.shape[0]))
c = c + 0.0001
s_l.index_add_(0, index, sidechain_loss.cpu())
s_b_l.index_add_(0, index, sidechain_base_loss.cpu())
sidechain_loss, sidechain_base_loss = s_l / c, s_b_l / c
else:
if apply_mean:
sidechain_loss, sidechain_base_loss = torch.zeros(1, dtype=torch.float), torch.zeros(1, dtype=torch.float)
else:
sidechain_loss, sidechain_base_loss = torch.zeros(len(rot_loss), dtype=torch.float), torch.zeros(
len(rot_loss), dtype=torch.float)
loss = tr_loss * tr_weight + rot_loss * rot_weight + tor_loss * tor_weight + sidechain_loss * sidechain_weight + backbone_loss * backbone_weight
return loss, tr_loss.detach(), rot_loss.detach(), tor_loss.detach(), backbone_loss.detach(), sidechain_loss.detach(), \
tr_base_loss, rot_base_loss, tor_base_loss, backbone_base_loss, sidechain_base_loss
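
The side-chain term compares chi angles stored as fractions of a full turn (the inline comment notes that 360 degrees equals 1), so the error has to wrap around. A small sketch of that wrap-around distance follows, assuming both inputs lie in the range 0 to 1. `circular_difference` is a hypothetical name.

```python
def circular_difference(pred, target):
    """Shortest distance between two angles given as fractions of a full turn."""
    diff = abs(pred - target)
    # Going the other way around the circle may be shorter.
    return min(diff, 1 - diff)


print(circular_difference(0.95, 0.05))  # wraps through 1.0 rather than spanning 0.9
print(circular_difference(0.20, 0.50))
```

Squaring this distance, as the diff does, penalises 0.95 versus 0.05 as a small error instead of a large one.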
class AverageMeter():
@@ -73,7 +136,7 @@ class AverageMeter():
if self.intervals == 1:
self.count += 1 if vals[0].dim() == 0 else len(vals[0])
for type_idx, v in enumerate(vals):
self.acc[self.types[type_idx]] += v.sum() if self.unpooled_metrics else v
self.acc[self.types[type_idx]] += v.sum().cpu() if self.unpooled_metrics else v.cpu()
else:
for type_idx, v in enumerate(vals):
self.count[type_idx].index_add_(0, interval_idx[type_idx], torch.ones(len(v)))
@@ -93,22 +156,34 @@ class AverageMeter():
return out
def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weigths):
def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weights):
model.train()
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'])
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'])
for data in tqdm(loader, total=len(loader)):
if device.type == 'cuda' and len(data) == 1 or device.type == 'cpu' and data.num_graphs == 1:
print("Skipping batch of size 1 since otherwise batchnorm would not work.")
continue
optimizer.zero_grad()
data = [d.to(device) for d in data] if device.type == 'cuda' else data
try:
tr_pred, rot_pred, tor_pred = model(data)
loss, tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss = \
loss_fn(tr_pred, rot_pred, tor_pred, data=data, t_to_sigma=t_to_sigma, device=device)
tr_pred, rot_pred, tor_pred, sidechain_pred = model(data)
loss_tuple = loss_fn(tr_pred, rot_pred, tor_pred, sidechain_pred, data=data, t_to_sigma=t_to_sigma, device=device)
if loss_tuple is None:
print("None loss tuple, skipping")
continue
loss = loss_tuple[0]
if torch.any(torch.isnan(loss)):
names = data.name if device.type == 'cpu' else [d.name for d in data]
print("Nan loss, skipping batch with complexes", names)
continue
loss.backward()
optimizer.step()
ema_weigths.update(model.parameters())
meter.add([loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss])
if ema_weights is not None: ema_weights.update(model.parameters())
meter.add([loss.cpu().detach(), *loss_tuple[1:]])
except RuntimeError as e:
if 'out of memory' in str(e):
print('| WARNING: ran out of memory, skipping batch')
@@ -125,40 +200,42 @@ def train_epoch(model, loader, optimizer, device, t_to_sigma, loss_fn, ema_weigt
torch.cuda.empty_cache()
continue
else:
raise e
#raise e
print(e)
continue
return meter.summary()
def test_epoch(model, loader, device, t_to_sigma, loss_fn, test_sigma_intervals=False):
model.eval()
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'],
meter = AverageMeter(['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'],
unpooled_metrics=True)
if test_sigma_intervals:
meter_all = AverageMeter(
['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'tr_base_loss', 'rot_base_loss', 'tor_base_loss'],
['loss', 'tr_loss', 'rot_loss', 'tor_loss', 'backbone_loss', 'sidechain_loss',
'tr_base_loss', 'rot_base_loss', 'tor_base_loss', 'backbone_base_loss', 'sidechain_base_loss'],
unpooled_metrics=True, intervals=10)
for data in tqdm(loader, total=len(loader)):
try:
with torch.no_grad():
tr_pred, rot_pred, tor_pred = model(data)
loss, tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss = \
loss_fn(tr_pred, rot_pred, tor_pred, data=data, t_to_sigma=t_to_sigma, apply_mean=False, device=device)
meter.add([loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss])
tr_pred, rot_pred, tor_pred, sidechain_pred = model(data)
loss_tuple = loss_fn(tr_pred, rot_pred, tor_pred, sidechain_pred, data=data, t_to_sigma=t_to_sigma, apply_mean=False, device=device)
if loss_tuple is None: continue
meter.add([loss_tuple[0].cpu().detach(), *loss_tuple[1:]])
if test_sigma_intervals > 0:
complex_t_tr, complex_t_rot, complex_t_tor = [torch.cat([d.complex_t[noise_type] for d in data]) for
complex_t_tr, complex_t_rot, complex_t_tor = [torch.cat([data[i].complex_t[noise_type] for i in range(len(data))]) for
noise_type in ['tr', 'rot', 'tor']]
sigma_index_tr = torch.round(complex_t_tr.cpu() * (10 - 1)).long()
sigma_index_rot = torch.round(complex_t_rot.cpu() * (10 - 1)).long()
sigma_index_tor = torch.round(complex_t_tor.cpu() * (10 - 1)).long()
meter_all.add(
[loss.cpu().detach(), tr_loss, rot_loss, tor_loss, tr_base_loss, rot_base_loss, tor_base_loss],
[sigma_index_tr, sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_rot,
sigma_index_tor, sigma_index_tr])
meter_all.add([loss_tuple[0].cpu().detach(), *loss_tuple[1:]],
[sigma_index_tr, sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_tr,
sigma_index_tr, sigma_index_rot, sigma_index_tor, sigma_index_tr, sigma_index_tr])
except RuntimeError as e:
if 'out of memory' in str(e):
@@ -177,60 +254,87 @@ def test_epoch(model, loader, device, t_to_sigma, loss_fn, test_sigma_intervals=
continue
else:
raise e
print(e)
continue
out = meter.summary()
if test_sigma_intervals > 0: out.update(meter_all.summary())
return out
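
When `test_sigma_intervals` is set, each sample's diffusion time is mapped to one of 10 buckets with `round(t * (10 - 1))`, and the meter accumulates per-bucket sums with `index_add_`. A plain-Python sketch of the same bookkeeping is given below. The helper names are hypothetical, and returning `None` for an empty bucket is a choice made for this sketch, not behaviour taken from the repository.

```python
def sigma_interval(t, intervals=10):
    # Map a diffusion time t in [0, 1] to a bucket index in [0, intervals - 1].
    return round(t * (intervals - 1))


def bucket_means(values, bucket_index, num_buckets):
    """Average `values` per bucket, mirroring two index_add_ accumulations."""
    sums = [0.0] * num_buckets
    counts = [0] * num_buckets
    for value, idx in zip(values, bucket_index):
        sums[idx] += value      # like sums.index_add_(0, index, values)
        counts[idx] += 1        # like counts.index_add_(0, index, ones)
    return [s / c if c else None for s, c in zip(sums, counts)]


times = [0.0, 0.02, 1.0]
losses = [1.0, 3.0, 10.0]
print(bucket_means(losses, [sigma_interval(t) for t in times], 10))
```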
def inference_epoch(model, complex_graphs, device, t_to_sigma, args):
t_schedule = get_t_schedule(inference_steps=args.inference_steps)
def inference_epoch_fix(model, complex_graphs, device, t_to_sigma, args):
t_schedule = get_t_schedule(sigma_schedule='expbeta', inference_steps=args.inference_steps,
inf_sched_alpha=1, inf_sched_beta=1)
tr_schedule, rot_schedule, tor_schedule = t_schedule, t_schedule, t_schedule
dataset = ListDataset(complex_graphs)
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)
rmsds = []
rmsds, min_rmsds = [], []
for orig_complex_graph in tqdm(loader):
data_list = [copy.deepcopy(orig_complex_graph)]
data_list = [copy.deepcopy(orig_complex_graph) for _ in range(args.inference_samples)]
randomize_position(data_list, args.no_torsion, False, args.tr_sigma_max)
predictions_list = None
failed_convergence_counter = 0
while predictions_list is None:
try:
predictions_list, confidences = sampling(data_list=data_list, model=model.module if device.type=='cuda' else model,
predictions_list, confidences = sampling(data_list=data_list, model=model.module if device.type == 'cuda' else model,
inference_steps=args.inference_steps,
tr_schedule=tr_schedule, rot_schedule=rot_schedule,
tor_schedule=tor_schedule,
device=device, t_to_sigma=t_to_sigma, model_args=args)
device=device, t_to_sigma=t_to_sigma, model_args=args,
t_schedule=t_schedule)
except Exception as e:
if 'failed to converge' in str(e):
failed_convergence_counter += 1
if failed_convergence_counter > 5:
print('| WARNING: SVD failed to converge 5 times - skipping the complex')
break
print('| WARNING: SVD failed to converge - trying again with a new sample')
else:
raise e
if failed_convergence_counter > 5: continue
failed_convergence_counter += 1
if failed_convergence_counter > 5:
print('failed 5 times - skipping the complex')
break
print("Exception while running inference on complex:", e)
if failed_convergence_counter > 5:
rmsds.extend([100] * args.inference_samples)
min_rmsds.append(100)
continue
if args.no_torsion:
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph['ligand'].pos.cpu().numpy() +
orig_complex_graph.original_center.cpu().numpy())
orig_complex_graph['ligand'].orig_pos = (orig_complex_graph[
'ligand'].pos.cpu().numpy() + orig_complex_graph.original_center.cpu().numpy())
filterHs = torch.not_equal(predictions_list[0]['ligand'].x[:, 0], 0).cpu().numpy()
if isinstance(orig_complex_graph['ligand'].orig_pos, list):
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
# if len(orig_complex_graph['ligand'].orig_pos.shape) == 3:
# orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[0]
ligand_pos = np.asarray(
[complex_graph['ligand'].pos.cpu().numpy()[filterHs] for complex_graph in predictions_list])
orig_ligand_pos = np.expand_dims(
orig_complex_graph['ligand'].orig_pos[filterHs] - orig_complex_graph.original_center.cpu().numpy(), axis=0)
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos) ** 2).sum(axis=2).mean(axis=1))
rmsds.append(rmsd)
if len(orig_complex_graph['ligand'].orig_pos.shape) == 2:
orig_complex_graph['ligand'].orig_pos = orig_complex_graph['ligand'].orig_pos[None, :, :]
try:
orig_ligand_pos = orig_complex_graph['ligand'].orig_pos[:, filterHs] - orig_complex_graph.original_center.cpu().numpy()
except Exception as e:
print("problem with orig_pos which is of shape:", orig_complex_graph['ligand'].orig_pos.shape, e)
continue
mol = RemoveAllHs(orig_complex_graph.mol[0])
complex_rmsds = []
for i in range(len(orig_ligand_pos)):
try:
rmsd = get_symmetry_rmsd(mol, orig_ligand_pos[i], [l for l in ligand_pos])
except Exception as e:
print("Using non-corrected RMSD because of the error:", e)
rmsd = np.sqrt(((ligand_pos - orig_ligand_pos[i]) ** 2).sum(axis=2).mean(axis=1))
complex_rmsds.append(rmsd)
complex_rmsds = np.asarray(complex_rmsds)
rmsd = np.min(complex_rmsds, axis=0)
rmsds.extend([r for r in rmsd])
min_rmsds.append(rmsd.min(axis=0))
rmsds = np.array(rmsds)
min_rmsds = np.array(min_rmsds)
losses = {'rmsds_lt2': (100 * (rmsds < 2).sum() / len(rmsds)),
'rmsds_lt5': (100 * (rmsds < 5).sum() / len(rmsds))}
'rmsds_lt5': (100 * (rmsds < 5).sum() / len(rmsds)),
'min_rmsds_lt2': (100 * (min_rmsds < 2).sum() / len(min_rmsds)),
'min_rmsds_lt5': (100 * (min_rmsds < 5).sum() / len(min_rmsds)),}
return losses
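
The metrics returned above report the percentage of all sampled poses under an RMSD threshold and, through `min_rmsds`, the percentage of complexes whose best sample is under it. The sketch below reproduces both on toy numbers. It uses the plain coordinate RMSD that the diff falls back to when the symmetry-corrected version fails, and the function names are hypothetical.

```python
import math


def plain_rmsd(pred, ref):
    """Root-mean-square deviation between two equally ordered lists of 3D points."""
    squared = [sum((a - b) ** 2 for a, b in zip(p, r)) for p, r in zip(pred, ref)]
    return math.sqrt(sum(squared) / len(squared))


def success_rates(per_complex_rmsds, threshold=2.0):
    """Return (percent of all samples, percent of complexes by best sample) below threshold."""
    all_rmsds = [r for samples in per_complex_rmsds for r in samples]
    best = [min(samples) for samples in per_complex_rmsds]

    def percent_below(xs):
        return 100 * sum(x < threshold for x in xs) / len(xs)

    return percent_below(all_rmsds), percent_below(best)


print(plain_rmsd([[0, 0, 0], [1, 0, 0]], [[0, 0, 0], [1, 0, 2]]))
print(success_rates([[1.0, 3.0], [4.0, 5.0]]))
```

Complexes that fail repeatedly are recorded in the diff with an RMSD of 100, so they count as failures in both percentages instead of being dropped from the denominator.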

@@ -2,19 +2,23 @@ import os
import subprocess
import warnings
from datetime import datetime
import signal
from contextlib import contextmanager
from typing import List
import numpy
import numpy as np
import torch
import yaml
from rdkit import Chem
from rdkit.Chem import RemoveHs, MolToPDBFile
from torch import nn, Tensor
from torch_geometric.nn.data_parallel import DataParallel
from torch_geometric.utils import degree, subgraph
from models.all_atom_score_model import TensorProductScoreModel as AAScoreModel
from models.score_model import TensorProductScoreModel as CGScoreModel
from models.aa_model import AAModel
from models.cg_model import CGModel
from models.old_aa_model import AAOldModel
from models.old_cg_model import CGOldModel
from utils.diffusion_utils import get_timestep_embedding
from spyrmsd import rmsd, molecule
def get_obrmsd(mol1_path, mol2_path, cache_name=None):
@@ -61,6 +65,53 @@ def read_strings_from_txt(path):
return [line.rstrip() for line in lines]
def unbatch(src, batch: Tensor, dim: int = 0) -> List[Tensor]:
r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension
:obj:`dim`.
Args:
src (Tensor): The source tensor.
batch (LongTensor): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
entry in :obj:`src` to a specific example. Must be ordered.
dim (int, optional): The dimension along which to split the :obj:`src`
tensor. (default: :obj:`0`)
:rtype: :class:`List[Tensor]`
"""
sizes = degree(batch, dtype=torch.long).tolist()
if isinstance(src, numpy.ndarray):
return np.split(src, np.array(sizes).cumsum()[:-1], axis=dim)
else:
return src.split(sizes, dim)
def unbatch_edge_index(edge_index: Tensor, batch: Tensor) -> List[Tensor]:
r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector.
Args:
edge_index (Tensor): The edge_index tensor. Must be ordered.
batch (LongTensor): The batch vector
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. Must be ordered.
:rtype: :class:`List[Tensor]`
"""
deg = degree(batch, dtype=torch.int64)
ptr = torch.cat([deg.new_zeros(1), deg.cumsum(dim=0)[:-1]], dim=0)
edge_batch = batch[edge_index[0]]
edge_index = edge_index - ptr[edge_batch]
sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
return edge_index.split(sizes, dim=1)
def unbatch_edge_attributes(edge_attributes, edge_index: Tensor, batch: Tensor) -> List[Tensor]:
edge_batch = batch[edge_index[0]]
sizes = degree(edge_batch, dtype=torch.int64).cpu().tolist()
return edge_attributes.split(sizes, dim=0)
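
The `unbatch` helpers count how many entries belong to each example in a sorted batch vector and split the source along those sizes. A dependency-free sketch of that idea follows, assuming a non-empty, sorted batch vector with contiguous example ids. `unbatch_list` is a hypothetical name chosen to avoid confusion with the function in the diff.

```python
from collections import Counter


def unbatch_list(items, batch):
    """Split `items` into per-example lists using a sorted batch vector."""
    counts = Counter(batch)
    # One size per example id, in order; ids with no entries get size 0.
    sizes = [counts[b] for b in range(max(batch) + 1)]
    out, start = [], 0
    for size in sizes:
        out.append(items[start:start + size])
        start += size
    return out


print(unbatch_list(["a", "b", "c", "d", "e"], [0, 0, 1, 2, 2]))
```

`unbatch_edge_index` in the diff adds one more step: it subtracts each example's starting node offset so the returned edge indices are local to that example.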
def save_yaml_file(path, content):
assert isinstance(path, str), f'path must be a string, got {path} which is a {type(path)}'
content = yaml.dump(data=content)
@@ -70,12 +121,47 @@ def save_yaml_file(path, content):
f.write(content)
def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.w_decay)
def unfreeze_layer(model):
for name, child in (model.named_children()):
#print(name, child.parameters())
for param in child.parameters():
param.requires_grad = True
def get_optimizer_and_scheduler(args, model, scheduler_mode='min', step=0, optimizer=None):
if args.scheduler == 'layer_linear_warmup':
if step == 0:
for name, child in (model.named_children()):
if name.find('batch_norm') == -1:
for name, param in child.named_parameters():
if name.find('batch_norm') == -1:
param.requires_grad = False
for l in [model.center_edge_embedding, model.final_conv, model.tr_final_layer, model.rot_final_layer,
model.final_edge_embedding, model.final_tp_tor, model.tor_bond_conv, model.tor_final_layer]:
unfreeze_layer(l)
elif 0 < step <= args.num_conv_layers:
unfreeze_layer(model.conv_layers[-step])
elif step == args.num_conv_layers + 1:
for l in [model.lig_node_embedding, model.lig_edge_embedding, model.rec_node_embedding, model.rec_edge_embedding,
model.rec_sigma_embedding, model.cross_edge_embedding, model.rec_emb_layers, model.lig_emb_layers]:
unfreeze_layer(l)
if step == 0 or args.scheduler == 'layer_linear_warmup':
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.w_decay)
scheduler_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=scheduler_mode, factor=0.7, patience=args.scheduler_patience, min_lr=args.lr / 100)
if args.scheduler == 'plateau':
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode=scheduler_mode, factor=0.7,
patience=args.scheduler_patience, min_lr=args.lr / 100)
scheduler = scheduler_plateau
elif args.scheduler == 'linear_warmup' or args.scheduler == 'layer_linear_warmup':
if (args.scheduler == 'linear_warmup' and step < 1) or \
(args.scheduler == 'layer_linear_warmup' and step <= args.num_conv_layers + 1):
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=args.lr_start_factor, end_factor=1.0,
total_iters=args.warmup_dur)
else:
scheduler = scheduler_plateau
else:
print('No scheduler')
scheduler = None
@@ -83,63 +169,119 @@ def get_optimizer_and_scheduler(args, model, scheduler_mode='min'):
return optimizer, scheduler
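
With `layer_linear_warmup`, the function above freezes everything except batch-norm parameters at step 0, re-enables the output layers, then releases one convolution layer per step from the last to the first, and finally the embedding layers. The sketch below only models which group is released at each step. The group labels `output_heads` and `embedding_layers` are shorthand invented for this illustration.

```python
def layers_to_unfreeze(step, num_conv_layers):
    """Return labels for the layer group released at a given warm-up step."""
    if step == 0:
        return ["output_heads"]
    if 0 < step <= num_conv_layers:
        # Negative index: the last convolution layer is released first.
        return [f"conv_layers[{-step}]"]
    if step == num_conv_layers + 1:
        return ["embedding_layers"]
    return []  # everything is already trainable


for step in range(6):
    print(step, layers_to_unfreeze(step, num_conv_layers=3))
```

Because the set of trainable parameters changes at each step, the diff rebuilds the Adam optimizer whenever this scheduler is selected.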
def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False):
if 'all_atoms' in args and args.all_atoms:
model_class = AAScoreModel
else:
model_class = CGScoreModel
def get_model(args, device, t_to_sigma, no_parallel=False, confidence_mode=False, old=False):
timestep_emb_func = get_timestep_embedding(
embedding_type=args.embedding_type,
embedding_type=args.embedding_type if 'embedding_type' in args else 'sinusoidal',
embedding_dim=args.sigma_embed_dim,
embedding_scale=args.embedding_scale)
embedding_scale=args.embedding_scale if 'embedding_type' in args else 10000)
lm_embedding_type = None
if args.esm_embeddings_path is not None: lm_embedding_type = 'esm'
if old:
if 'all_atoms' in args and args.all_atoms:
model_class = AAOldModel
else:
model_class = CGOldModel
model = model_class(t_to_sigma=t_to_sigma,
device=device,
no_torsion=args.no_torsion,
timestep_emb_func=timestep_emb_func,
num_conv_layers=args.num_conv_layers,
lig_max_radius=args.max_radius,
scale_by_sigma=args.scale_by_sigma,
sigma_embed_dim=args.sigma_embed_dim,
ns=args.ns, nv=args.nv,
distance_embed_dim=args.distance_embed_dim,
cross_distance_embed_dim=args.cross_distance_embed_dim,
batch_norm=not args.no_batch_norm,
dropout=args.dropout,
use_second_order_repr=args.use_second_order_repr,
cross_max_distance=args.cross_max_distance,
dynamic_max_cross=args.dynamic_max_cross,
lm_embedding_type=lm_embedding_type,
confidence_mode=confidence_mode,
num_confidence_outputs=len(
args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
args.rmsd_classification_cutoff, list) else 1)
lm_embedding_type = None
if args.esm_embeddings_path is not None: lm_embedding_type = 'esm'
if device.type == 'cuda' and not no_parallel:
model = model_class(t_to_sigma=t_to_sigma,
device=device,
no_torsion=args.no_torsion,
timestep_emb_func=timestep_emb_func,
num_conv_layers=args.num_conv_layers,
lig_max_radius=args.max_radius,
scale_by_sigma=args.scale_by_sigma,
sigma_embed_dim=args.sigma_embed_dim,
norm_by_sigma='norm_by_sigma' in args and args.norm_by_sigma,
ns=args.ns, nv=args.nv,
distance_embed_dim=args.distance_embed_dim,
cross_distance_embed_dim=args.cross_distance_embed_dim,
batch_norm=not args.no_batch_norm,
dropout=args.dropout,
use_second_order_repr=args.use_second_order_repr,
cross_max_distance=args.cross_max_distance,
dynamic_max_cross=args.dynamic_max_cross,
smooth_edges=args.smooth_edges if "smooth_edges" in args else False,
odd_parity=args.odd_parity if "odd_parity" in args else False,
lm_embedding_type=lm_embedding_type,
confidence_mode=confidence_mode,
affinity_prediction=args.affinity_prediction if 'affinity_prediction' in args else False,
parallel=args.parallel if "parallel" in args else 1,
num_confidence_outputs=len(
args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
args.rmsd_classification_cutoff, list) else 1,
parallel_aggregators=args.parallel_aggregators if "parallel_aggregators" in args else "",
fixed_center_conv=not args.not_fixed_center_conv if "not_fixed_center_conv" in args else False,
no_aminoacid_identities=args.no_aminoacid_identities if "no_aminoacid_identities" in args else False,
include_miscellaneous_atoms=args.include_miscellaneous_atoms if hasattr(args, 'include_miscellaneous_atoms') else False,
use_old_atom_encoder=args.use_old_atom_encoder if hasattr(args, 'use_old_atom_encoder') else True)
else:
if 'all_atoms' in args and args.all_atoms:
model_class = AAModel
else:
model_class = CGModel
lm_embedding_type = None
if ('moad_esm_embeddings_path' in args and args.moad_esm_embeddings_path is not None) or \
('pdbbind_esm_embeddings_path' in args and args.pdbbind_esm_embeddings_path is not None) or \
        ('pdbsidechain_esm_embeddings_path' in args and args.pdbsidechain_esm_embeddings_path is not None) or \
        ('esm_embeddings_path' in args and args.esm_embeddings_path is not None):
        lm_embedding_type = 'precomputed'
    if 'esm_embeddings_model' in args and args.esm_embeddings_model is not None: lm_embedding_type = args.esm_embeddings_model

    model = model_class(t_to_sigma=t_to_sigma,
                        device=device,
                        no_torsion=args.no_torsion,
                        timestep_emb_func=timestep_emb_func,
                        num_conv_layers=args.num_conv_layers,
                        lig_max_radius=args.max_radius,
                        scale_by_sigma=args.scale_by_sigma,
                        sigma_embed_dim=args.sigma_embed_dim,
                        norm_by_sigma='norm_by_sigma' in args and args.norm_by_sigma,
                        ns=args.ns, nv=args.nv,
                        distance_embed_dim=args.distance_embed_dim,
                        cross_distance_embed_dim=args.cross_distance_embed_dim,
                        batch_norm=not args.no_batch_norm,
                        dropout=args.dropout,
                        use_second_order_repr=args.use_second_order_repr,
                        cross_max_distance=args.cross_max_distance,
                        dynamic_max_cross=args.dynamic_max_cross,
                        smooth_edges=args.smooth_edges if "smooth_edges" in args else False,
                        odd_parity=args.odd_parity if "odd_parity" in args else False,
                        lm_embedding_type=lm_embedding_type,
                        confidence_mode=confidence_mode,
                        affinity_prediction=args.affinity_prediction if 'affinity_prediction' in args else False,
                        parallel=args.parallel if "parallel" in args else 1,
                        num_confidence_outputs=len(
                            args.rmsd_classification_cutoff) + 1 if 'rmsd_classification_cutoff' in args and isinstance(
                            args.rmsd_classification_cutoff, list) else 1,
                        atom_num_confidence_outputs=len(
                            args.atom_rmsd_classification_cutoff) + 1 if 'atom_rmsd_classification_cutoff' in args and isinstance(
                            args.atom_rmsd_classification_cutoff, list) else 1,
                        parallel_aggregators=args.parallel_aggregators if "parallel_aggregators" in args else "",
                        fixed_center_conv=not args.not_fixed_center_conv if "not_fixed_center_conv" in args else False,
                        no_aminoacid_identities=args.no_aminoacid_identities if "no_aminoacid_identities" in args else False,
                        include_miscellaneous_atoms=args.include_miscellaneous_atoms if hasattr(args, 'include_miscellaneous_atoms') else False,
                        sh_lmax=args.sh_lmax if 'sh_lmax' in args else 2,
                        differentiate_convolutions=not args.no_differentiate_convolutions if "no_differentiate_convolutions" in args else True,
                        tp_weights_layers=args.tp_weights_layers if "tp_weights_layers" in args else 2,
                        num_prot_emb_layers=args.num_prot_emb_layers if "num_prot_emb_layers" in args else 0,
                        reduce_pseudoscalars=args.reduce_pseudoscalars if "reduce_pseudoscalars" in args else False,
                        embed_also_ligand=args.embed_also_ligand if "embed_also_ligand" in args else False,
                        atom_confidence=args.atom_confidence_loss_weight > 0.0 if "atom_confidence_loss_weight" in args else False,
                        sidechain_pred=(hasattr(args, 'sidechain_loss_weight') and args.sidechain_loss_weight > 0) or
                                       (hasattr(args, 'backbone_loss_weight') and args.backbone_loss_weight > 0),
                        depthwise_convolution=args.depthwise_convolution if hasattr(args, 'depthwise_convolution') else False)

    if device.type == 'cuda' and not no_parallel and ('dataset' not in args or not args.dataset == 'torsional'):
        model = DataParallel(model)
    model.to(device)
    return model
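
The constructor call above repeats one pattern many times: read an option if the saved arguments contain it, otherwise fall back to a default, so that older checkpoints whose configs predate an option still load. A minimal sketch of that pattern follows; `get_arg` and `num_outputs` are hypothetical helper names that do not appear in the repository.

```python
from argparse import Namespace


def get_arg(args, name, default):
    """Return args.<name> when the attribute exists, otherwise the default."""
    return getattr(args, name, default)


def num_outputs(cutoff):
    """One output for a scalar or missing cutoff, len + 1 for a list of cutoffs."""
    return len(cutoff) + 1 if isinstance(cutoff, list) else 1


# Hypothetical configs: an older one without the newer options, and a newer one.
old_args = Namespace(ns=24, nv=6)
new_args = Namespace(ns=48, nv=10, sh_lmax=1, rmsd_classification_cutoff=[2.0])

print(get_arg(old_args, "sh_lmax", 2))
print(get_arg(new_args, "sh_lmax", 2))
print(num_outputs(get_arg(old_args, "rmsd_classification_cutoff", None)))
print(num_outputs(get_arg(new_args, "rmsd_classification_cutoff", None)))
```

Using `getattr` with a default matches `args.x if 'x' in args else default` whenever the attribute is present, including when it is set to `None`.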


def get_symmetry_rmsd(mol, coords1, coords2, mol2=None):
    # Bound the symmetry-corrected RMSD computation to 10 seconds via time_limit.
    with time_limit(10):
        mol = molecule.Molecule.from_rdkit(mol)
        mol2 = molecule.Molecule.from_rdkit(mol2) if mol2 is not None else mol2
        # Without a second molecule, reuse the first molecule's atoms and bonds.
        mol2_atomicnums = mol2.atomicnums if mol2 is not None else mol.atomicnums
        mol2_adjacency_matrix = mol2.adjacency_matrix if mol2 is not None else mol.adjacency_matrix
        RMSD = rmsd.symmrmsd(
            coords1,
            coords2,
            mol.atomicnums,
            mol2_atomicnums,
            mol.adjacency_matrix,
            mol2_adjacency_matrix,
        )
        return RMSD
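
To illustrate why `get_symmetry_rmsd` asks for a symmetry-corrected RMSD instead of a plain one, here is a small pure-Python sketch. It assumes the symmetry-equivalent atom orderings are supplied by hand; the library call above derives them from the atomic numbers and adjacency matrices. The names `rmsd_plain` and `symmetric_rmsd` are hypothetical.

```python
import math


def rmsd_plain(coords_a, coords_b):
    """Plain RMSD between two equally ordered lists of 3D points."""
    total = sum((x - y) ** 2 for p, q in zip(coords_a, coords_b) for x, y in zip(p, q))
    return math.sqrt(total / len(coords_a))


def symmetric_rmsd(coords_ref, coords, permutations):
    """Minimum RMSD over the supplied symmetry-equivalent atom orderings."""
    return min(rmsd_plain(coords_ref, [coords[i] for i in perm]) for perm in permutations)


# Hypothetical three-atom molecule in which atoms 1 and 2 are chemically equivalent.
ref = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (-1.0, 0.0, 0.0)]
# Same physical pose, but with the two equivalent atoms listed in swapped order.
pose = [(0.0, 0.0, 0.0), (-1.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
perms = [(0, 1, 2), (0, 2, 1)]

naive = rmsd_plain(ref, pose)
corrected = symmetric_rmsd(ref, pose, perms)
print(naive, corrected)
```

The plain RMSD penalizes the relabeling even though the pose is identical, while the minimum over equivalent orderings does not.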
import signal
from contextlib import contextmanager
class TimeoutException(Exception): pass
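
The `time_limit` helper used in `get_symmetry_rmsd` falls outside the lines shown in this hunk. One common way to build such a context manager from `signal`, `contextmanager` and a timeout exception is sketched below. This is an assumption about its shape, not the repository's code, and it only works on Unix in the main thread; `run_with_limit` is a hypothetical helper for demonstration.

```python
import signal
import time
from contextlib import contextmanager


class TimeoutException(Exception):
    pass


@contextmanager
def time_limit(seconds):
    """Raise TimeoutException if the with-block runs longer than `seconds`."""
    def handler(signum, frame):
        raise TimeoutException("Timed out!")

    previous = signal.signal(signal.SIGALRM, handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)  # cancel any pending alarm
        signal.signal(signal.SIGALRM, previous)  # restore the previous handler


def run_with_limit(seconds, work_seconds):
    """Sleep for `work_seconds` under a time limit and report what happened."""
    try:
        with time_limit(seconds):
            time.sleep(work_seconds)
        return "finished"
    except TimeoutException:
        return "timed out"


print(run_with_limit(1.0, 0.01))
print(run_with_limit(0.05, 1.0))
```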
@@ -241,3 +383,31 @@ class ExponentialMovingAverage:
        self.decay = state_dict['decay']
        self.num_updates = state_dict['num_updates']
        self.shadow_params = [tensor.to(device) for tensor in state_dict['shadow_params']]


def crop_beyond(complex_graph, cutoff, all_atoms):
    ligand_pos = complex_graph['ligand'].pos
    receptor_pos = complex_graph['receptor'].pos
    # Keep receptor residues that have at least one ligand atom closer than `cutoff`.
    residues_to_keep = torch.any(torch.sum((ligand_pos.unsqueeze(0) - receptor_pos.unsqueeze(1)) ** 2, -1) < cutoff ** 2, dim=1)

    if all_atoms:
        #print(complex_graph['atom'].x.shape, complex_graph['atom'].pos.shape, complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index.shape)
        atom_to_res_mapping = complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index[1]
        atoms_to_keep = residues_to_keep[atom_to_res_mapping]
        # cumsum - 1 maps each kept residue's old index to its index after cropping.
        rec_remapper = (torch.cumsum(residues_to_keep.long(), dim=0) - 1)
        atom_to_res_new_mapping = rec_remapper[atom_to_res_mapping][atoms_to_keep]
        atom_res_edge_index = torch.stack([torch.arange(len(atom_to_res_new_mapping), device=atom_to_res_new_mapping.device), atom_to_res_new_mapping])

    complex_graph['receptor'].pos = complex_graph['receptor'].pos[residues_to_keep]
    complex_graph['receptor'].x = complex_graph['receptor'].x[residues_to_keep]
    complex_graph['receptor'].side_chain_vecs = complex_graph['receptor'].side_chain_vecs[residues_to_keep]
    complex_graph['receptor', 'rec_contact', 'receptor'].edge_index = \
        subgraph(residues_to_keep, complex_graph['receptor', 'rec_contact', 'receptor'].edge_index, relabel_nodes=True)[0]

    if all_atoms:
        complex_graph['atom'].x = complex_graph['atom'].x[atoms_to_keep]
        complex_graph['atom'].pos = complex_graph['atom'].pos[atoms_to_keep]
        complex_graph['atom', 'atom_contact', 'atom'].edge_index = subgraph(atoms_to_keep, complex_graph['atom', 'atom_contact', 'atom'].edge_index, relabel_nodes=True)[0]
        complex_graph['atom', 'atom_rec_contact', 'receptor'].edge_index = atom_res_edge_index

    #print("cropped", 1-torch.mean(residues_to_keep.float()), 'residues', 1-torch.mean(atoms_to_keep.float()), 'atoms')
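
The index bookkeeping in `crop_beyond` can be followed on plain Python lists. The sketch below mirrors the distance mask and the cumulative-sum remapping without tensors or graph objects; `crop_and_remap` and the toy coordinates are hypothetical and only illustrate the idea.

```python
from itertools import accumulate


def crop_and_remap(ligand_pos, receptor_pos, atom_to_res, cutoff):
    """Keep residues within `cutoff` of any ligand atom and re-index atom-to-residue links."""
    def sq_dist(a, b):
        return sum((x - y) ** 2 for x, y in zip(a, b))

    # A residue survives if at least one ligand atom is closer than the cutoff.
    keep_res = [any(sq_dist(r, l) < cutoff ** 2 for l in ligand_pos) for r in receptor_pos]

    # Running count of kept residues minus one is the new index of each kept residue.
    remap = [c - 1 for c in accumulate(int(k) for k in keep_res)]

    # An atom survives only if its parent residue does.
    keep_atom = [keep_res[r] for r in atom_to_res]
    new_atom_to_res = [remap[r] for r, k in zip(atom_to_res, keep_atom) if k]
    return keep_res, keep_atom, new_atom_to_res


ligand = [(0.0, 0.0, 0.0)]
receptor = [(1.0, 0.0, 0.0), (50.0, 0.0, 0.0), (0.0, 2.0, 0.0)]
atom_to_res = [0, 0, 1, 2, 2]  # parent residue index of each receptor atom

keep_res, keep_atom, new_map = crop_and_remap(ligand, receptor, atom_to_res, cutoff=5.0)
print(keep_res, keep_atom, new_map)
```

The remapped value of a dropped residue is meaningless, which is why the atom mask is applied before the new mapping is used.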


@@ -1,14 +0,0 @@
## Visualizations of complexes that were unseen during training. EquiBind (cyan), DiffDock highest confidence sample (red), all other DiffDock samples (orange), and the crystal structure (green).
Complex 6agt:
![Alt Text](example_6agt_symmetric.gif)
Complex 6dz3:
![Alt Text](example_6dz3_symmetric.gif)
Complex 6gdy:
![Alt Text](example_6gdy_symmetric.gif)
Complex 6ckl:
![Alt Text](example_6ckl_symmetric.gif)

Binary file not shown (16 MiB).

Binary file not shown (14 MiB).

Binary file not shown (23 MiB).

Binary file not shown (12 MiB).


@@ -1,85 +0,0 @@
all_atoms: true
atom_max_neighbors: 8
atom_radius: 5
balance: false
batch_size: 16
best_model_save_frequency: 5
c_alpha_max_neighbors: 24
cache_creation_id: 1
cache_ids_to_combine:
- '1'
- '2'
- '3'
- '4'
cache_path: data/cache
ckpt: best_model.pt
confidence_dropout: 0.0
confidence_loss_weigth: 1
confidence_no_batchnorm: false
confidence_weight: 0.33
config: null
cross_distance_embed_dim: 32
cross_max_distance: 80
data_dir: data/PDBBind_processed/
distance_embed_dim: 32
dropout: 0.1
dynamic_max_cross: true
embedding_scale: 10000
embedding_type: sinusoidal
esm_embeddings_path: data/esm2_3billion_embeddings.pt
high_confidence_threshold: 5.0
include_confidence_prediction: false
inference_steps: 20
limit_complexes: 0
lm_embeddings_path: null
log_dir: workdir
lr: 0.0003
main_metric: loss
main_metric_goal: min
matching_maxiter: 20
matching_popsize: 20
max_lig_size: null
max_radius: 5.0
model_save_frequency: 0
n_epochs: 100
no_batch_norm: false
no_torsion: false
ns: 24
num_conformers: 1
num_conv_layers: 5
num_workers: 1
nv: 6
original_model_dir: workdir/temp_restart_ema_ESM2emb_tr34
project: diffdock_confidence
receptor_radius: 15.0
remove_hs: true
restart_dir: null
rmsd_classification_cutoff:
- 2.0
rmsd_prediction: false
rot_sigma_max: 1.55
rot_sigma_min: 0.03
rot_weight: 0.33
run_name: confidencetrain_samples28_FILTERFROM_ema_ESM2emb_tr34
samples_per_complex: 7
scale_by_sigma: true
scheduler: plateau
scheduler_patience: 50
sigma_embed_dim: 32
split_test: data/splits/timesplit_test
split_train: data/splits/timesplit_no_lig_overlap_train
split_val: data/splits/timesplit_no_lig_overlap_val
tor_sigma_max: 3.14
tor_sigma_min: 0.0314
tor_sigma_schedule: expbeta
tor_weight: 0.33
tr_only_confidence: true
tr_sigma_max: 34.0
tr_sigma_min: 0.1
tr_weight: 0.33
train_sampling: linear
transfer_weights: false
use_original_model_cache: false
use_second_order_repr: false
w_decay: 0.0
wandb: true


@@ -1,83 +0,0 @@
all_atoms: false
atom_max_neighbors: 8
atom_radius: 5
batch_size: 16
c_alpha_max_neighbors: 24
cache_path: data/cacheNew
confidence_dropout: 0.0
confidence_no_batchnorm: false
config: null
cross_distance_embed_dim: 64
cross_max_distance: 80
cudnn_benchmark: true
data_dir: data/PDBBind_processed/
dataset: pdbbind
distance_embed_dim: 64
dropout: 0.1
dynamic_max_cross: true
ema_rate: 0.999
embedding_scale: 10000
embedding_type: sinusoidal
esm_embeddings_path: data/esm2_3billion_embeddings.pt
high_confidence_threshold: 5.0
include_confidence_prediction: false
inf_pocket_cutoff: 5
inf_pocket_knowledge: false
inference_earlystop_goal: max
inference_earlystop_metric: valinf_rmsds_lt2
inference_steps: 20
limit_complexes: 0
lm_embeddings_path: null
log_dir: workdir
lr: 0.001
matching_maxiter: 20
matching_popsize: 20
max_lig_size: null
max_radius: 5.0
multiplicity: 1
n_epochs: 850
no_batch_norm: false
no_torsion: false
norm_by_sigma: false
not_full_dataset: false
ns: 48
num_conformers: 1
num_conv_layers: 6
num_dataloader_workers: 1
num_gpus: 1
num_inference_complexes: 500
num_workers: 1
nv: 10
odd_parity: false
pin_memory: true
pretrained_model: null
project: diffdock_train
receptor_radius: 15.0
remove_hs: true
restart_dir: null
rot_sigma_max: 1.55
rot_sigma_min: 0.03
rot_weight: 0.33
run_name: big_ema_ESM2emb
scale_by_sigma: true
scheduler: plateau
scheduler_patience: 30
sigma_embed_dim: 64
split_test: data/splits/timesplit_test
split_train: data/splits/timesplit_no_lig_overlap_train
split_val: data/splits/timesplit_no_lig_overlap_val
test_sigma_intervals: true
tor_sigma_max: 3.14
tor_sigma_min: 0.0314
tor_weight: 0.33
tr_only_confidence: true
tr_sigma_max: 19.0
tr_sigma_min: 0.1
tr_weight: 0.33
train_inference_freq: null
train_sampling: linear
use_ema: true
use_second_order_repr: false
val_inference_freq: 5
w_decay: 0.0
wandb: true
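
The `*_sigma_min` and `*_sigma_max` entries in the configs above bound the noise scale of each diffusion component. As a rough sketch, assuming the noise level is interpolated geometrically in diffusion time `t` between the two bounds (a common choice for this kind of model, not confirmed by the lines shown here), the translation values from this config would behave as follows. The standalone single-component `t_to_sigma` below is hypothetical.

```python
def t_to_sigma(t, sigma_min, sigma_max):
    """Geometric interpolation: sigma_min at t = 0, sigma_max at t = 1."""
    return sigma_min ** (1 - t) * sigma_max ** t


# Translation bounds taken from the config above (tr_sigma_min, tr_sigma_max).
TR_SIGMA_MIN, TR_SIGMA_MAX = 0.1, 19.0

for t in (0.0, 0.25, 0.5, 0.75, 1.0):
    print(t, t_to_sigma(t, TR_SIGMA_MIN, TR_SIGMA_MAX))
```

Under this assumption the noise scale grows by a constant factor per unit of `t`, so most of the time range is spent at small to moderate perturbations.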