mirror of
https://github.com/samsledje/D-SCRIPT.git
synced 2026-06-04 15:04:24 +08:00
Initial commit from Lynn on contact supervision branch
This commit is contained in:
4
.gitignore
vendored
4
.gitignore
vendored
@@ -1,3 +1,6 @@
|
||||
dscript/proteins
|
||||
dscript/pdbs
|
||||
|
||||
RUN_DSCRIPT.sh
|
||||
build/*
|
||||
scratch/*
|
||||
@@ -8,6 +11,7 @@ dev*/*
|
||||
*.egg-info
|
||||
**/*.sav
|
||||
**/*_log.txt
|
||||
**.DS_Store
|
||||
**/*.h5
|
||||
**/.ipynb_checkpoints/**
|
||||
**/__pycache__/**
|
||||
|
||||
1
dscript/1a0n.tsv
Normal file
1
dscript/1a0n.tsv
Normal file
@@ -0,0 +1 @@
|
||||
1A0N:A UNP:P27986 P85A_HUMAN 1A0N:B UNP:P06241 FYN_HUMAN 1
|
||||
|
@@ -1,4 +1,5 @@
|
||||
__version__ = "0.1.9"
|
||||
# goes to this version with relative imports
|
||||
__version__ = "0.1.9-Lynn"
|
||||
__citation__ = """Sledzieski, Singh, Cowen, Berger. Sequence-based prediction of protein-protein interactions: a structure-aware interpretable deep learning model. Cell Systems, 2021."""
|
||||
from . import alphabets, commands, fasta, language_model, models, pretrained
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
"""
|
||||
D-SCRIPT: Structure Aware PPI Prediction
|
||||
"""
|
||||
import argparse
|
||||
import os
|
||||
import argparse # command line run
|
||||
import os # manages operating system (system software)
|
||||
import sys
|
||||
|
||||
|
||||
# CLASS prints citation (from initial file)
|
||||
class CitationAction(argparse.Action):
|
||||
def __init__(self, option_strings, dest, **kwargs):
|
||||
super(CitationAction, self).__init__(option_strings, dest, **kwargs)
|
||||
@@ -17,8 +18,9 @@ class CitationAction(argparse.Action):
|
||||
setattr(namespace, self.dest, values)
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
# Compiles a summary of the dscript information: version, citation, commands, modules
|
||||
def main():
|
||||
# print("Hello World!")
|
||||
from . import __version__
|
||||
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import print_function, division
|
||||
import numpy as np
|
||||
|
||||
|
||||
# encodes amino acid sequence into mathematical indices and vice versa
|
||||
class Alphabet:
|
||||
"""
|
||||
From `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
@@ -60,6 +61,7 @@ class Alphabet:
|
||||
string = self.chars[x]
|
||||
return string.tobytes()
|
||||
|
||||
# ***
|
||||
def unpack(self, h, k):
|
||||
""" unpack integer h into array of this alphabet with length k """
|
||||
n = self.size
|
||||
|
||||
@@ -5,7 +5,7 @@ Generate new embeddings using pre-trained language model.
|
||||
import argparse
|
||||
from ..language_model import embed_from_fasta
|
||||
|
||||
|
||||
# *** make new embeddings --> this vs. embedding.py (models) vs language_model.py
|
||||
def add_args(parser):
|
||||
"""
|
||||
Create parser for command line utility.
|
||||
|
||||
@@ -44,7 +44,7 @@ def add_args(parser):
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# *** plot positive and negative predictions??
|
||||
def plot_eval_predictions(labels, predictions, path="figure"):
|
||||
"""
|
||||
Plot histogram of positive and negative predictions, precision-recall curve, and receiver operating characteristic curve.
|
||||
|
||||
@@ -44,13 +44,16 @@ def add_args(parser):
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
# *** simple prediction, interpretation?
|
||||
def main(args):
|
||||
"""
|
||||
Run new prediction from arguments.
|
||||
|
||||
:meta private:
|
||||
"""
|
||||
# print("Hello World")
|
||||
# sys.exit(0)
|
||||
|
||||
if args.seqs is None and args.embeddings is None:
|
||||
print("One of --seqs or --embeddings is required.")
|
||||
sys.exit(0)
|
||||
|
||||
@@ -182,7 +182,7 @@ def predict_cmap_interaction(model, n0, n1, tensors, use_cuda):
|
||||
:param use_cuda: Whether to use GPU
|
||||
:type use_cuda: bool
|
||||
"""
|
||||
|
||||
|
||||
b = len(n0)
|
||||
|
||||
p_hat = []
|
||||
@@ -200,7 +200,7 @@ def predict_cmap_interaction(model, n0, n1, tensors, use_cuda):
|
||||
c_map_mag = torch.stack(c_map_mag, 0)
|
||||
return c_map_mag, p_hat
|
||||
|
||||
|
||||
# *** list and list interactions? from cmap predict interactions
|
||||
def predict_interaction(model, n0, n1, tensors, use_cuda):
|
||||
"""
|
||||
Predict whether a list of protein pairs will interact.
|
||||
@@ -243,6 +243,7 @@ def interaction_grad(model, n0, n1, y, tensors, weight=0.35, use_cuda=True):
|
||||
:rtype: (torch.Tensor, int, torch.Tensor, int)
|
||||
"""
|
||||
|
||||
|
||||
c_map_mag, p_hat = predict_cmap_interaction(
|
||||
model, n0, n1, tensors, use_cuda
|
||||
)
|
||||
@@ -252,6 +253,7 @@ def interaction_grad(model, n0, n1, y, tensors, weight=0.35, use_cuda=True):
|
||||
|
||||
p_hat = p_hat.float()
|
||||
bce_loss = F.binary_cross_entropy(p_hat.float(), y.float())
|
||||
|
||||
cmap_loss = torch.mean(c_map_mag)
|
||||
loss = (weight * bce_loss) + ((1 - weight) * cmap_loss)
|
||||
b = len(p_hat)
|
||||
@@ -397,7 +399,7 @@ def train_model(args, output):
|
||||
for prot_name in tqdm(all_proteins):
|
||||
embeddings[prot_name] = torch.from_numpy(h5fi[prot_name][:, :])
|
||||
|
||||
if args.checkpoint is None:
|
||||
if args.checkpoint is None:
|
||||
|
||||
# Create embedding model
|
||||
input_dim = args.input_dim
|
||||
|
||||
76
dscript/contact_map.py
Normal file
76
dscript/contact_map.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import Bio.PDB
|
||||
import numpy
|
||||
import matplotlib.pyplot as plt
|
||||
import h5py
|
||||
from models.interaction import ModelInteraction
|
||||
# from language_model import embed_from_fasta
|
||||
|
||||
# Display the first contact map stored in the predictions file.
# The file is opened read-only inside a context manager so the HDF5 handle
# is released after plotting instead of staying open for the life of the
# process.
with h5py.File("2022-06-27-06:26.predictions.cmaps.h5", "r") as fi:
    ke = list(fi.keys())
    # print(fi[ke[0]])
    # The dataset must be read while the file is still open, so plotting
    # happens inside the `with` block.
    cmap = fi[ke[0]]
    plt.imshow(cmap)
    plt.show()
|
||||
|
||||
# SOURCE CODE: https://warwick.ac.uk/fac/sci/moac/people/students/peter_cock/python/protein_contact_map/
|
||||
# pdb_code = "15c8"
|
||||
# pdb_filename = "dscript/pdbs/15c8.pdb"
|
||||
# pdb_code = "1xi4"
|
||||
# pdb_filename = "dscript/1xi4.pdb" #not the full cage!
|
||||
|
||||
|
||||
# def calc_residue_dist(residue_one, residue_two) :
|
||||
# """Returns the C-alpha distance between two residues"""
|
||||
# diff_vector = residue_one["CA"].coord - residue_two["CA"].coord
|
||||
# return numpy.sqrt(numpy.sum(diff_vector * diff_vector))
|
||||
|
||||
|
||||
# def calc_dist_matrix(chain_one, chain_two) :
|
||||
# """Returns a matrix of C-alpha distances between two chains"""
|
||||
# answer = numpy.zeros((len(chain_one), len(chain_two)), numpy.float)
|
||||
# for row, residue_one in enumerate(chain_one) :
|
||||
# for col, residue_two in enumerate(chain_two) :
|
||||
# answer[row, col] = calc_residue_dist(residue_one, residue_two)
|
||||
# return answer
|
||||
|
||||
|
||||
# structure = Bio.PDB.PDBParser().get_structure(pdb_code, pdb_filename)
|
||||
# print(structure)
|
||||
|
||||
# model = structure[0]
|
||||
# # print(list(model["H"]))
|
||||
|
||||
# # chain1 = model["H"]
|
||||
# # for residue in chain1:
|
||||
# # if residue.id[0] != ' ':
|
||||
# # chain1.detach_child(residue.id)
|
||||
|
||||
# # chain2 = model["L"]
|
||||
# # for residue in chain2:
|
||||
# # if residue.id[0] != ' ':
|
||||
# # chain2.detach_child(residue.id)
|
||||
|
||||
# dist_matrix = calc_dist_matrix(model["D"], model["M"])
|
||||
# contact_map = dist_matrix < 12.0
|
||||
|
||||
# print(contact_map)
|
||||
# print(numpy.min(dist_matrix))
|
||||
# print(numpy.max(dist_matrix))
|
||||
|
||||
# pylab.matshow(numpy.transpose(dist_matrix))
|
||||
# pylab.colorbar()
|
||||
# pylab.show()
|
||||
|
||||
# pylab.imshow(numpy.transpose(contact_map))
|
||||
# pylab.show()
|
||||
|
||||
|
||||
# DSCRIPT CONTACT MAPPING
|
||||
# embed_from_fasta("dscript/proteins/1a0n.fasta", "dscript/1a0n_embed.h5", device=0, verbose=False)
|
||||
# h5fi = h5py.File("dscript/1a0n_embed.h5", "r")
|
||||
# z_a = h5fi["1A0N:A UNP:P27986 P85A_HUMAN"]
|
||||
# z_b = h5fi["1A0N:B UNP:P06241 FYN_HUMAN"]
|
||||
# self = ModelInteraction()
|
||||
# cm, ph = self.map_predict(z_a, z_b)
|
||||
# print(cm)
|
||||
# print(ph)
|
||||
22
dscript/data_collection.py
Normal file
22
dscript/data_collection.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# converting PDB CIF to sequence format
|
||||
import Bio
|
||||
from Bio import SeqIO
|
||||
import os
|
||||
|
||||
# Convert every file in dscript/pdbs into a two-line FASTA file in
# dscript/proteins, writing one record per sequence parsed from the
# "pdb-seqres" reader.
entries = os.listdir('dscript/pdbs')

# CONVERT ALL
for entry in entries:
    # Skip hidden files (names starting with a dot).
    if entry.startswith("."):
        continue
    # Strip the real extension rather than assuming it is four characters.
    stem = os.path.splitext(entry)[0]
    # The context manager closes the output file; no explicit close() needed.
    with open(f'dscript/proteins/{stem}.fasta', 'w') as f:
        for record in SeqIO.parse(f"dscript/pdbs/{entry}", "pdb-seqres"):
            f.write(record.format("fasta-2line"))
|
||||
|
||||
# TEST A SINGLE
|
||||
# with open(f'dscript/proteins/15c8.fasta', 'w') as f:
|
||||
# for record in SeqIO.parse(f"dscript/15c8.pdb", "pdb-seqres"):
|
||||
# # print("Record id %s, chain %s" % (record.id, record.annotations["chain"]))
|
||||
# # print(record.format("fasta"))
|
||||
# f.write(record.format("fasta-2line"))
|
||||
# f.close()
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import print_function, division
|
||||
|
||||
|
||||
# deal with formatting of fasta files (header id : protein sequence)
|
||||
def parse_stream(f, comment=b"#"):
|
||||
|
||||
name = None
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import sys
|
||||
import subprocess as sp
|
||||
import subprocess as sp # handle subprocesses
|
||||
import random
|
||||
import torch
|
||||
import h5py
|
||||
@@ -12,7 +12,7 @@ from .models.embedding import SkipLSTM
|
||||
from .utils import log
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
# using pre-trained language model to embed protein sequences
|
||||
def lm_embed(sequence, use_cuda=False):
|
||||
"""
|
||||
Embed a single sequence using pre-trained language model from `Bepler & Berger <https://github.com/tbepler/protein-sequence-embedding-iclr2019>`_.
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.functional as F
|
||||
|
||||
|
||||
# *** broadcast tensor?
|
||||
class FullyConnected(nn.Module):
|
||||
"""
|
||||
Performs part 1 of Contact Prediction Module. Takes embeddings from Projection module and produces broadcast tensor.
|
||||
@@ -54,7 +54,7 @@ class FullyConnected(nn.Module):
|
||||
|
||||
return c
|
||||
|
||||
|
||||
# *** part 2 of contact model?
|
||||
class ContactCNN(nn.Module):
|
||||
"""
|
||||
Residue Contact Prediction Module. Takes embeddings from Projection module and produces contact map, output of Contact module.
|
||||
|
||||
@@ -5,7 +5,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
|
||||
|
||||
# language model --> embedding --> contact
|
||||
class IdentityEmbed(nn.Module):
|
||||
"""
|
||||
Does not reduce the dimension of the language model embeddings, just passes them through to the contact model.
|
||||
@@ -20,7 +20,7 @@ class IdentityEmbed(nn.Module):
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
# amino acids --> low dimensional projection
|
||||
class FullyConnectedEmbed(nn.Module):
|
||||
"""
|
||||
Protein Projection Module. Takes embedding from language model and outputs low-dimensional interaction aware projection.
|
||||
@@ -57,7 +57,7 @@ class FullyConnectedEmbed(nn.Module):
|
||||
t = self.drop(t)
|
||||
return t
|
||||
|
||||
|
||||
# uses (bidirectional) LSTM as part of embedding into vector process
|
||||
class LSTMEmbed(nn.Module):
|
||||
def __init__(self, nout, activation="ReLU", sparse=False, p=0.5):
|
||||
super(LSTMEmbed, self).__init__()
|
||||
@@ -104,6 +104,7 @@ class SkipLSTM(nn.Module):
|
||||
|
||||
Loaded with pre-trained weights in embedding function.
|
||||
|
||||
# *** 21, not 20? reference said something to explain this
|
||||
:param nin: Input dimension of amino acid one-hot [default: 21]
|
||||
:type nin: int
|
||||
:param nout: Output dimension of final layer [default: 100]
|
||||
@@ -169,6 +170,7 @@ class SkipLSTM(nn.Module):
|
||||
one_hot.scatter_(2, x.unsqueeze(2), 1)
|
||||
return one_hot
|
||||
|
||||
# *** powerpoint slide, what's going on here?
|
||||
def transform(self, x):
|
||||
"""
|
||||
:param x: Input numeric amino acid encoding :math:`(N)`
|
||||
|
||||
@@ -62,6 +62,7 @@ class ModelInteraction(nn.Module):
|
||||
lambda_init=0,
|
||||
gamma_init=0,
|
||||
):
|
||||
# *** pooling operations? reference
|
||||
"""
|
||||
Main D-SCRIPT model. Contains an embedding and contact model and offers access to those models. Computes pooling operations on contact map to generate interaction probability.
|
||||
|
||||
@@ -122,6 +123,7 @@ class ModelInteraction(nn.Module):
|
||||
|
||||
self.gamma.data.clamp_(min=0)
|
||||
|
||||
# *** explain "Project down input language model embeddings into low dimension using projection module"?
|
||||
def embed(self, x):
|
||||
"""
|
||||
Project down input language model embeddings into low dimension using projection module
|
||||
|
||||
@@ -7,12 +7,12 @@ from .models.contact import ContactCNN
|
||||
from .models.embedding import FullyConnectedEmbed, SkipLSTM
|
||||
from .models.interaction import ModelInteraction
|
||||
|
||||
|
||||
# create an lstm and a human_v1 model
|
||||
def build_lm_1(state_dict_path):
|
||||
"""
|
||||
:meta private:
|
||||
"""
|
||||
model = SkipLSTM(21, 100, 1024, 3)
|
||||
model = SkipLSTM(21, 100, 1024, 3) # creates bidirection lstm, read in reference
|
||||
state_dict = torch.load(state_dict_path)
|
||||
model.load_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
@@ -15,7 +15,7 @@ MODEL_VERSIONS = [
|
||||
|
||||
print(dscript.__version__)
|
||||
|
||||
|
||||
# checker method to test if pretrained models exist and are working
|
||||
def test_get_state_dict():
|
||||
for mv in MODEL_VERSIONS:
|
||||
sd = get_state_dict(mv, verbose=True)
|
||||
|
||||
@@ -7,7 +7,7 @@ import numpy as np
|
||||
import pandas as pd
|
||||
import subprocess as sp
|
||||
import sys
|
||||
import gzip as gz
|
||||
import gzip as gz # handles unix/linux gzip files
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ channels:
|
||||
|
||||
dependencies:
|
||||
- pip=20.0
|
||||
- cudatoolkit=10.2
|
||||
- python=3.7
|
||||
- pytorch=1.5
|
||||
- h5py
|
||||
|
||||
18
setup.py
18
setup.py
@@ -19,14 +19,14 @@ setup(
|
||||
},
|
||||
include_package_data=True,
|
||||
install_requires=[
|
||||
"numpy",
|
||||
"scipy",
|
||||
"pandas",
|
||||
"torch",
|
||||
"matplotlib",
|
||||
"seaborn",
|
||||
"tqdm",
|
||||
"scikit-learn",
|
||||
"h5py",
|
||||
"numpy", # mathematical operations/representations
|
||||
"scipy", # linear algebra
|
||||
"pandas", # data analysis tools
|
||||
"torch", #nn training/creation
|
||||
"matplotlib", # visualization
|
||||
"seaborn", # matplotlib 2.0
|
||||
"tqdm", # progress bar
|
||||
"scikit-learn", # machine learning
|
||||
"h5py", # store lots of (binary) data
|
||||
],
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user