Add files via upload

This commit is contained in:
Jun Liu
2025-03-18 19:14:00 +08:00
committed by GitHub
parent 5c5b5efd54
commit d018f9c759
6 changed files with 1335 additions and 0 deletions

344
pplm_contact/LoadHHM.py Normal file
View File

@@ -0,0 +1,344 @@
#
# This file (LoadHHM.py) was downloaded from the RaptorX-Contact
# at https://github.com/j3xugit/RaptorX-Contact
#
import numpy as np
import os
import sys
import pickle as pkl
"""
This script reads an hhm file generated by the HHpred/HHblits package to generate position-specific scoring/frequency matrix.
After reading an hhm file, this script stores the HMM as a python dict().
To use the position-specific frequency matrix, please use the keyword PSFM.
To use the position-specific scoring matrix, please use the keyword PSSM.
Both PSFM and PSSM encode information derived from the profile HMM built by HHpred or HHblits,
so there is no need to directly use the keys containing 'hmm'
PSFM and PSSM columns are arranged by the alphabetical order of amino acids in their 1-letter code
"""
## 8-state secondary structure letter -> integer code ('L' and 'C' share code 7)
SS8Letter2Code = dict(zip('HGIEBTSLC', (0, 1, 2, 3, 4, 5, 6, 7, 7)))
## 8-state letter -> 3-state code; as coded here H/G/I map to 0, E/B map to 1 and T/S/L/C map to 2
SS8Letter2SS3Code = dict(zip('HGIEBTSLC', (0, 0, 0, 1, 1, 2, 2, 2, 2)))

## parallel lists of amino acid names: entry k of one list names the same residue as entry k of the other
## NOTE(review): tryptophan is spelled 'TRY' here; PDB files normally use 'TRP' -- confirm before looking up PDB residue names
AA3LetterCode = 'ALA ARG ASN ASP ASX CYS GLU GLN GLX GLY HIS ILE LEU LYS MET PHE PRO SER THR TRY TYR VAL'.split()
AA1LetterCode = list('ARNDBCEQZGHILKMFPSTWYV')

## only the 20 standard amino acids are allowed in protein sequences
ValidAALetters = set('ARNDCEQGHILKMFPSTWYV')

## rank of every entry of the lists above when the 20 standard residues are sorted by 3-letter code;
## the two ambiguous residues (ASX/B and GLX/Z) are both mapped to 20
AAOrderBy3Letter = [0, 1, 2, 3, 20, 4, 5, 6, 20, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
## same, but ranked by 1-letter code
AAOrderBy1Letter = [0, 14, 11, 2, 20, 1, 3, 13, 20, 5, 6, 7, 9, 8, 10, 4, 12, 15, 16, 18, 19, 17]

## lookup tables filled by the loop below
AA3LetterCode21LetterCode = {}      # 3-letter code -> 1-letter code
AA1LetterCode23LetterCode = {}      # 1-letter code -> 3-letter code
AALetter2OrderOf3LetterCode = {}    # 1- or 3-letter code -> rank by 3-letter code
AALetter2OrderOf1LetterCode = {}    # 1- or 3-letter code -> rank by 1-letter code
AA3LetterOrder21LetterOrder = {}    # rank by 3-letter code -> rank by 1-letter code
AA1LetterOrder23LetterOrder = {}    # rank by 1-letter code -> rank by 3-letter code

for l3, l1, o3, o1 in zip(AA3LetterCode, AA1LetterCode, AAOrderBy3Letter, AAOrderBy1Letter):
    AA1LetterCode23LetterCode.update({l1: l3})
    AA3LetterCode21LetterCode.update({l3: l1})
    AALetter2OrderOf3LetterCode.update({l1: o3, l3: o3})
    AALetter2OrderOf1LetterCode.update({l1: o1, l3: o1})
    AA3LetterOrder21LetterOrder.update({o3: o1})
    AA1LetterOrder23LetterOrder.update({o1: o3})
## the rows and cols of the gonnet matrix are arranged in the order of amino acids by their 3-letter code
## (the 20-residue order encoded by AAOrderBy3Letter above)
## NOTE(review): the entries are not probabilities (several exceed 1); they look like odds ratios
## derived from Gonnet log-odds scores -- confirm against the original RaptorX-Contact source
gonnet = [
[ 1.7378, 0.870964,0.933254,0.933254, 1.12202, 0.954993, 1, 1.12202, 0.831764, 0.831764, 0.758578, 0.912011, 0.851138, 0.588844, 1.07152, 1.28825, 1.14815, 0.436516, 0.60256, 1.02329],
[ 0.870964,2.95121, 1.07152, 0.933254, 0.60256, 1.41254, 1.09648, 0.794328, 1.14815, 0.57544, 0.60256, 1.86209, 0.676083, 0.47863, 0.812831, 0.954993, 0.954993, 0.691831, 0.660693, 0.630957],
[ 0.933254,1.07152, 2.39883, 1.65959, 0.660693, 1.1749, 1.23027, 1.09648, 1.31826, 0.524807, 0.501187, 1.20226, 0.60256, 0.489779, 0.812831, 1.23027, 1.12202, 0.436516, 0.724436, 0.60256],
[ 0.933254,0.933254,1.65959, 2.95121, 0.47863, 1.23027, 1.86209, 1.02329, 1.09648, 0.416869, 0.398107, 1.12202, 0.501187, 0.354813, 0.851138, 1.12202, 1, 0.301995, 0.524807, 0.512861],
[ 1.12202, 0.60256, 0.660693,0.47863, 14.1254, 0.57544, 0.501187, 0.630957, 0.74131, 0.776247, 0.707946, 0.524807, 0.812831, 0.831764, 0.489779, 1.02329, 0.891251, 0.794328, 0.891251, 1],
[ 0.954993,1.41254, 1.1749, 1.23027, 0.57544, 1.86209, 1.47911, 0.794328, 1.31826, 0.645654, 0.691831, 1.41254, 0.794328, 0.549541, 0.954993, 1.04713, 1, 0.537032, 0.676083, 0.707946],
[ 1, 1.09648, 1.23027, 1.86209, 0.501187, 1.47911, 2.29087, 0.831764, 1.09648, 0.537032, 0.524807, 1.31826, 0.630957, 0.40738, 0.891251, 1.04713, 0.977237, 0.371535, 0.537032, 0.645654],
[ 1.12202, 0.794328,1.09648, 1.02329, 0.630957, 0.794328, 0.831764, 4.57088, 0.724436, 0.354813, 0.363078, 0.776247, 0.446684, 0.301995, 0.691831, 1.09648, 0.776247, 0.398107, 0.398107, 0.467735],
[ 0.831764,1.14815, 1.31826, 1.09648, 0.74131, 1.31826, 1.09648, 0.724436, 3.98107, 0.60256, 0.645654, 1.14815, 0.74131, 0.977237, 0.776247, 0.954993, 0.933254, 0.831764, 1.65959, 0.630957],
[ 0.831764,0.57544, 0.524807,0.416869, 0.776247, 0.645654, 0.537032, 0.354813, 0.60256, 2.51189, 1.90546, 0.616595, 1.77828, 1.25893, 0.549541, 0.660693, 0.870964, 0.660693, 0.851138, 2.04174],
[ 0.758578,0.60256, 0.501187,0.398107, 0.707946, 0.691831, 0.524807, 0.363078, 0.645654, 1.90546, 2.51189, 0.616595, 1.90546, 1.58489, 0.588844, 0.616595, 0.74131, 0.851138, 1, 1.51356],
[ 0.912011,1.86209, 1.20226, 1.12202, 0.524807, 1.41254, 1.31826, 0.776247, 1.14815, 0.616595, 0.616595, 2.0893, 0.724436, 0.467735, 0.870964, 1.02329, 1.02329, 0.446684, 0.616595, 0.676083],
[ 0.851138,0.676083,0.60256, 0.501187, 0.812831, 0.794328, 0.630957, 0.446684, 0.74131, 1.77828, 1.90546, 0.724436, 2.69153, 1.44544, 0.57544, 0.724436, 0.870964, 0.794328, 0.954993, 1.44544],
[ 0.588844,0.47863, 0.489779,0.354813, 0.831764, 0.549541, 0.40738, 0.301995, 0.977237, 1.25893, 1.58489, 0.467735, 1.44544, 5.01187, 0.416869, 0.524807, 0.60256, 2.29087, 3.23594, 1.02329],
[ 1.07152, 0.812831,0.812831,0.851138, 0.489779, 0.954993, 0.891251, 0.691831, 0.776247, 0.549541, 0.588844, 0.870964, 0.57544, 0.416869, 5.7544, 1.09648, 1.02329, 0.316228, 0.489779, 0.660693],
[ 1.28825, 0.954993,1.23027, 1.12202, 1.02329, 1.04713, 1.04713, 1.09648, 0.954993, 0.660693, 0.616595, 1.02329, 0.724436, 0.524807, 1.09648, 1.65959, 1.41254, 0.467735, 0.645654, 0.794328],
[ 1.14815, 0.954993,1.12202, 1, 0.891251, 1, 0.977237, 0.776247, 0.933254, 0.870964, 0.74131, 1.02329, 0.870964, 0.60256, 1.02329, 1.41254, 1.77828, 0.446684, 0.645654, 1],
[ 0.436516,0.691831,0.436516,0.301995, 0.794328, 0.537032, 0.371535, 0.398107, 0.831764, 0.660693, 0.851138, 0.446684, 0.794328, 2.29087, 0.316228, 0.467735, 0.446684, 26.3027, 2.5704, 0.549541],
[ 0.60256, 0.660693,0.724436,0.524807, 0.891251, 0.676083, 0.537032, 0.398107, 1.65959, 0.851138, 1, 0.616595, 0.954993, 3.23594, 0.489779, 0.645654, 0.645654, 2.5704, 6.0256, 0.776247],
[ 1.02329, 0.630957,0.60256, 0.512861, 1, 0.707946, 0.645654, 0.467735, 0.630957, 2.04174, 1.51356, 0.676083, 1.44544, 1.02329, 0.660693, 0.794328, 1, 0.549541, 0.776247, 2.18776]
]
gonnet = np.array(gonnet, np.float32)
## column indices into each row of the 'hmm2' array filled by ReadHHM:
## 0-6 are the state transitions (Match/Insert/Delete), 7-9 are the Neff values for match, insertion and deletion
M_M, M_I, M_D, I_M, I_I, D_M, D_D, _NEFF, I_NEFF, D_NEFF = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
## HMMNull is the background score of amino acids, in the alphabetical order of 1-letter code
## ReadHHM converts an entry x to a background probability as 2**(-x/1000) and only uses the first 20 entries;
## the trailing 0 is a 21st slot (presumably for the ambiguous residues mapped to index 20 -- TODO confirm)
HMMNull = [3706,5728,4211,4064,4839,3729,4763,4308,4069,3323,5509,4640,4464,4937,4285,4423,3815,3783,6325,4665,0]
HMMNull = np.array(HMMNull, dtype=np.float32)
## this function reads the HMM block from a profileHMM file (generated by HHpred/HHblits package)
## for the profileHMM file, the header has 4 lines
def ReadHHM(lines, start_position, length, one_protein, numLines4Header=4):
    """Parse the HMM block of an hhm file into emission/transition arrays.

    Args:
        lines: all lines of the hhm file (the caller passes them already stripped).
        start_position: index in ``lines`` of the first header line of the HMM block.
        length: number of residues; every residue occupies 3 consecutive lines.
        one_protein: dict describing the protein. It must already contain
            'sequence' and 'name'; it is updated in place with the keys
            'HMMHeader', 'hmm1', 'hmm2', 'hmm1_prob', 'hmm1_score', 'PSFM' and 'PSSM'.
        numLines4Header: number of header lines to skip before the per-residue lines.

    Returns:
        A tuple ``(next_index, one_protein)`` where ``next_index`` is
        ``start_position + numLines4Header + 3*length``.

    Raises:
        AssertionError: if an emission line does not split into 23 fields.
        SystemExit: (status -1) if the residues read from the HMM block disagree
            with ``one_protein['sequence']`` in length or content ('X' matches anything).
    """
    i = start_position
    one_protein['HMMHeader'] = lines[i: i+ numLines4Header]
    i += numLines4Header
    ## the columns of hmm1 are in the alphabetical order of amino acids in 1-letter code, different from the above PSP and PSM matrices
    one_protein['hmm1'] = np.zeros( (length, 20), np.float32)
    one_protein['hmm2'] = np.zeros( (length, 10), np.float32)
    one_protein['hmm1_prob'] = np.zeros((length, 20), np.float32)
    one_protein['hmm1_score'] = np.zeros((length, 20), np.float32)
    seqStr = ''
    for l in range(length):
        ##this line is for emission score. The amino acids are ordered from left to right alphabetically by their 1-letter code
        ## '*' entries are replaced by 99999 so that they parse as numbers
        fields = lines[ i + l*3 + 0].replace("*", "99999").split()
        assert len(fields) == 23
        ## fields[0] is the residue letter, fields[1] the residue number, fields[2:-1] the 20 emission values;
        ## hmm1 temporarily holds -value/1000, which is used as a log2 probability below
        one_protein['hmm1'][l] = np.array([ - np.int32(num) for num in fields[2: -1] ])/1000.
        aa = fields[0]
        seqStr += aa
        """
        if (l + 1) != np.int32(fields[1]):
        print 'Error: inconsistent residue number in file for protein ', one_protein['name'], ' at line: ', lines[ i + l*3 + 0 ]
        exit(-1)
        """
        ##the first 7 columns of this line is for state transition
        ## exp(-x/1000*0.6931) is 2**(-x/1000) written with ln(2) ~ 0.6931
        one_protein['hmm2'][l][0:7] = [ np.exp(-np.int32(num)/1000.0*0.6931) for num in lines[i + l*3 + 1].replace("*", "99999").split()[0:7]]
        ##the last 3 columns of this line is for Neff of Match, Insertion and Deletion
        one_protein['hmm2'][l][7:10] = [ np.int32(num)/1000.0 for num in lines[i + l*3 + 1].split()[7:10]]
        ## _NEFF is for match, I_NEFF for insertion and D_NEFF for deletion.
        ## each transition probability is blended with a fixed prior (0.6/0.2/0.2 out of Match,
        ## 0.75/0.25 out of Insert and Delete); the observed value is weighted by the Neff and the prior by 0.1
        rm = 0.1
        one_protein['hmm2'][l][M_M] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_M] + rm*0.6)/(rm + one_protein['hmm2'][l][_NEFF])
        one_protein['hmm2'][l][M_I] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_I] + rm*0.2)/(rm + one_protein['hmm2'][l][_NEFF])
        one_protein['hmm2'][l][M_D] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_D] + rm*0.2)/(rm + one_protein['hmm2'][l][_NEFF])
        ri = 0.1
        one_protein['hmm2'][l][I_I] = (one_protein['hmm2'][l][I_NEFF]*one_protein['hmm2'][l][I_I] + ri*0.75)/(ri + one_protein['hmm2'][l][I_NEFF])
        one_protein['hmm2'][l][I_M] = (one_protein['hmm2'][l][I_NEFF]*one_protein['hmm2'][l][I_M] + ri*0.25)/(ri + one_protein['hmm2'][l][I_NEFF])
        rd = 0.1
        one_protein['hmm2'][l][D_D] = (one_protein['hmm2'][l][D_NEFF]*one_protein['hmm2'][l][D_D] + rd*0.75)/(rd + one_protein['hmm2'][l][D_NEFF])
        one_protein['hmm2'][l][D_M] = (one_protein['hmm2'][l][D_NEFF]*one_protein['hmm2'][l][D_M] + rd*0.25)/(rd + one_protein['hmm2'][l][D_NEFF])
        ## convert the log2 emission values into probabilities
        one_protein['hmm1_prob'][l,] = pow(2.0, one_protein['hmm1'][l,])
        wssum = sum(one_protein['hmm1_prob'][l, ])
        #print 'l = ', l, 'sum= ', wssum
        ## renormalize to make wssum = 1
        if wssum > 0 :
            one_protein['hmm1_prob'][l, ] /= wssum
        else:
            ## no probability mass at all: put everything on the residue observed at this position
            one_protein['hmm1_prob'][l, AALetter2OrderOf1LetterCode[aa] ] = 1.
        """
        ## if the probability sum is not equal to 1
        if abs(wssum - 1.) > 0.1 :
        one_protein['hmm1_prob'][l, ] = 0
        one_protein['hmm1_prob'][l, AALetter2OrderOf1LetterCode[aa] ] = 1.
        """
        ## add pseudo count
        ## g[j] = (sum_k prob[k] * gonnet[k, j]) * background[j]; gonnet is indexed in 3-letter order,
        ## so the 1-letter indices j and k are translated first
        g = np.zeros( (20), np.float32)
        for j in range(20):
            orderIn3LetterCode_j = AA1LetterOrder23LetterOrder[j]
            for k in range(20):
                orderIn3LetterCode_k = AA1LetterOrder23LetterOrder[k]
                g[j] += one_protein['hmm1_prob'][l, k] * gonnet[ orderIn3LetterCode_k, orderIn3LetterCode_j ]
            g[j] *= pow(2.0, -1.0*HMMNull[j] / 1000.0)
        #print 'l=', l, ' gsum= ', sum(g)
        ## sum(g) is very close to 1, here we renormalize g to make its sum to be exactly 1
        g = g/sum(g)
        ## blend observed probabilities (weight Neff-1) with the pseudo counts (weight 10);
        ## hmm1 is reused as scratch space for the blended probabilities
        ws_tmp_neff = one_protein['hmm2'][l][_NEFF] - 1
        one_protein['hmm1'][l, ] = (ws_tmp_neff * one_protein['hmm1_prob'][l, ] + g*10) / (ws_tmp_neff+10)
        ## recalculate the emission score and probability after pseudo count is added
        one_protein['hmm1_prob'][l,] = one_protein['hmm1'][l, ]
        one_protein['hmm1'][l, ] = np.log2(one_protein['hmm1_prob'][l, ])
        ## score = log2(probability) + background score in bits
        one_protein['hmm1_score' ][l, ] = one_protein['hmm1'][l, ] + HMMNull[:20]/1000.0
    ## PSFM: position-specific frequency matrix, PSSM: position-specific scoring matrix
    ## both keys reference the same arrays as hmm1_prob / hmm1_score (no copy is made)
    one_protein['PSFM'] = one_protein['hmm1_prob']
    one_protein['PSSM'] = one_protein['hmm1_score']
    #assert ( seqStr == one_protein['sequence'] )
    if len(seqStr) != len(one_protein['sequence']):
        print ('ERROR: inconsistent sequence length in HMM section and orignal sequence for protein: ', one_protein['name'] )
        exit(-1)
    ## positions where either side is 'X' are accepted without comparison
    comparison = [ (aa=='X' or bb=='X' or aa==bb) for aa, bb in zip(seqStr, one_protein['sequence']) ]
    if not all(comparison):
        print ('ERROR: inconsistent sequence between HMM section and orignal sequence for protein: ', one_protein['name'])
        print (' original seq: ', one_protein['sequence'])
        print (' HMM seq: ', seqStr)
        exit(-1)
    return i + 3*length, one_protein
## helper of load_hmm: print an error message followed by the file name and terminate with status 1
def _hhm_abort(message, hmmfile):
    print (message, hmmfile)
    # sys.exit instead of the builtin exit(): the latter is injected by the site module and may be absent
    sys.exit(1)


## helper of load_hmm: index of the first line at or after start that begins with one of stop_prefixes
## (len(content) when no such line exists, so a truncated file cannot cause an IndexError)
def _hhm_section_end(content, start, stop_prefixes):
    end = start
    while end < len(content) and not content[end].startswith(stop_prefixes):
        end += 1
    return end


## helper of load_hmm: the sequence length announced by the LENG line; aborts when it has not been read yet
def _hhm_length(protein, hmmfile):
    if 'length' not in protein:
        _hhm_abort('ERROR: the LENG line is missing or appears too late in profileHMM file: ', hmmfile)
    return protein['length']


## this function reads a profile HMM file generated by HHpred/HHblits package
def load_hmm(hmmfile):
    """Load an hhm file into a dict.

    Args:
        hmmfile: path of the profile HMM file.

    Returns:
        A dict with (at least) the keys listed under 'requiredSections':
        'name', 'length', 'sequence', 'NEFF', 'hmm1', 'hmm2', 'hmm1_prob',
        'hmm1_score', 'PSFM', 'PSSM' and 'DateCreated'.  'SSEseq' and 'SSEconf'
        are added when the file has '>ss_pred' / '>ss_conf' sections.

    Raises:
        SystemExit: (status 1, after printing an error) when the file is empty,
            is not an HHsearch file, is truncated, or lacks a required section.
    """
    with open(hmmfile, 'r') as fh:
        content = [ row.strip() for row in fh ]
    if not content:
        _hhm_abort('ERROR: empty profileHMM file: ', hmmfile)
    if not content[0].startswith('HHsearch'):
        _hhm_abort('ERROR: this file may not be a profileHMM file generated by HHpred/HHblits: ', hmmfile)
    if len(content) < 10:
        _hhm_abort('ERROR: this profileHMM file is too short: ', hmmfile)
    requiredSections = ['name', 'length', 'sequence', 'NEFF', 'hmm1', 'hmm2', 'hmm1_prob', 'hmm1_score', 'PSFM', 'PSSM', 'DateCreated']
    protein = {}
    ## get sequence name
    if not content[1].startswith('NAME '):
        _hhm_abort('ERROR: the protein name shall appear at the second line of profileHMM file: ', hmmfile)
    fields = content[1].split()
    if len(fields) < 2:
        _hhm_abort('ERROR: incorrect name format in profileHMM file: ', hmmfile)
    protein['name'] = fields[1]
    i = 0
    while i < len(content):
        row = content[i]
        if not row:
            i += 1
            continue
        if row.startswith('DATE '):
            protein['DateCreated'] = row[6:]
            i += 1
            continue
        if row.startswith('NEFF '):
            protein['NEFF'] = np.float32(row.split()[1])
            i += 1
            continue
        if row.startswith('LENG '):
            protein['length'] = np.int32(row.split()[1])
            i += 1
            continue
        if row.startswith('>ss_pred'):
            ## read the predicted secondary structure ('C' is rewritten as 'L')
            end = _hhm_section_end(content, i + 1, ('>',))
            protein['SSEseq'] = ''.join(content[i + 1:end]).replace('C', 'L')
            if len(protein['SSEseq']) != _hhm_length(protein, hmmfile):
                _hhm_abort('ERROR: inconsistent sequence length and predicted SS sequence length in hmmfile: ', hmmfile)
            i = end
            continue
        if row.startswith('>ss_conf'):
            ## read the predicted secondary structure confidence score (one digit per residue)
            end = _hhm_section_end(content, i + 1, ('>',))
            SSEconfStr = ''.join(content[i + 1:end])
            protein['SSEconf'] = [ np.int16(score) for score in SSEconfStr ]
            if len(protein['SSEconf']) != _hhm_length(protein, hmmfile):
                _hhm_abort('ERROR: inconsistent sequence length and predicted SS confidence sequence length in hmmfile: ', hmmfile)
            i = end
            continue
        if row.startswith('>' + protein['name']):
            ## read in the sequence in the following lines, up to the next '>' entry or the '#' separator
            end = _hhm_section_end(content, i + 1, ('>', '#'))
            protein['sequence'] = ''.join(content[i + 1:end])
            if len(protein['sequence']) != _hhm_length(protein, hmmfile):
                _hhm_abort('ERROR: inconsistent sequence length in hmmfile: ', hmmfile)
            i = end
            continue
        ## a lone '#' followed by the NULL and HMM lines starts the HMM block
        if row == '#' and i + 2 < len(content) and content[i+1].startswith('NULL') and content[i+2].startswith('HMM'):
            length = _hhm_length(protein, hmmfile)
            if 'sequence' not in protein:
                ## ReadHHM compares the HMM residues against protein['sequence']
                _hhm_abort('ERROR: the sequence section shall appear before the HMM block in profileHMM file: ', hmmfile)
            ## ReadHHM skips 4 header lines and then reads 2 of every 3 lines per residue;
            ## make sure the last line it touches exists
            last_needed = (i + 1) + 4 + 3 * (int(length) - 1) + 1
            if last_needed >= len(content):
                _hhm_abort('ERROR: this profileHMM file is too short: ', hmmfile)
            i, protein = ReadHHM(content, i+1, protein['length'], protein, numLines4Header=4)
            continue
        i += 1
    ## double check to see some required sections are read in
    for section in requiredSections:
        if section not in protein:
            print ('ERROR: one section for ', section, ' is missing in the hmm file: ', hmmfile)
            print ('ERROR: it is also possible that the hmm file has a format incompatible with this script.')
            sys.exit(1)
    protein['requiredSections'] = requiredSections
    return protein
## for test only: convert one .hhm file into <basename>.pkl in the current directory
if __name__ == "__main__":
    if len(sys.argv) < 2:
        print ('python LoadHHM.py hhm_file')
        print ('the input file shall end with .hhm')
        sys.exit(1)
    hhm_path = sys.argv[1]
    if not hhm_path.endswith('.hhm'):
        ## the message now names the suffix that is actually checked (.hhm, not .hmm)
        print ('ERROR: the input file shall have suffix .hhm')
        sys.exit(1)
    protein = load_hmm(hhm_path)
    savefile = os.path.basename(hhm_path) + '.pkl'
    ## the context manager closes the file even if pickling fails
    with open(savefile, 'wb') as fh:
        pkl.dump( protein, fh, protocol=pkl.HIGHEST_PROTOCOL)

325
pplm_contact/Module.py Normal file
View File

@@ -0,0 +1,325 @@
import torch
import torch.nn as nn
class ResNet(nn.Module):
    """Five residual blocks of dilated 3x3 convolutions that keep the channel count.

    Each block is Conv -> InstanceNorm -> LeakyReLU -> Dropout -> Conv ->
    InstanceNorm -> LeakyReLU, and its output is added to its input.  The
    dilation rates of the five blocks are 1, 2, 4, 2, 1; padding equals the
    dilation so the spatial size is preserved.
    """

    def __init__(self, channels, dropout=0.0):
        super().__init__()
        self.channels = channels
        self.n_layers = 5
        self.dilations = [1, 2, 4, 2, 1]
        self.dropout = dropout

        self.blocks = nn.ModuleList()
        for rate in self.dilations:
            layers = [
                nn.Conv2d(channels, channels, kernel_size=3, dilation=rate, padding=rate),
                nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.Dropout(dropout),
                nn.Conv2d(channels, channels, kernel_size=3, dilation=rate, padding=rate),
                nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
            ]
            self.blocks.append(nn.Sequential(*layers))

    def forward(self, x):
        """Apply the residual blocks in order and return a tensor shaped like ``x``."""
        out = x
        for stage in self.blocks:
            # the first layer of a stage is a convolution, so `out` itself is not modified in place
            out = stage(out) + out
        return out
class Intra_ResNet(nn.Module):
    """Fuse per-residue (1-D) and pairwise (2-D) features of one chain into a pair map.

    The 1-D features are tiled along rows and columns and concatenated, both
    inputs are projected to ``channels`` with 1x1 convolutions, the two
    projections are concatenated and projected again, and the result goes
    through ``ResNet``.
    """

    def __init__(self, dim_1d=768+20, dim_2d=144+2+64, channels=64, dropout=0):
        super().__init__()
        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        def pointwise(in_channels):
            # 1x1 convolution without bias, followed by instance norm and LeakyReLU
            return nn.Sequential(
                nn.Conv2d(in_channels, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
                nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
            )

        self.pre_norm_1d = nn.InstanceNorm1d(self.dim_1d)
        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)
        self.pair_conv1 = pointwise(self.dim_1d * 2)
        self.pair_conv2 = pointwise(self.dim_2d)
        self.pair_conv3 = pointwise(self.channels * 2)
        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    def forward(self, x_1d, x_2d):
        """Return the fused pair map.

        ``x_1d`` must be 3-D (batch, dim_1d, L).  ``x_2d`` is concatenated with
        an (batch, channels, L, L) tensor after projection, so it is expected
        to be (batch, dim_2d, L, L).
        """
        feat_1d = self.pre_norm_1d(x_1d)
        feat_2d = self.pre_norm_2d(x_2d)
        _, _, length = feat_1d.size()

        # tile the sequence features so that entry (i, j) carries residue i and residue j
        by_row = feat_1d.unsqueeze(-1).expand(-1, -1, length, length)
        by_col = feat_1d.unsqueeze(-2).expand(-1, -1, length, length)
        seq_pair = self.pair_conv1(torch.cat([by_row, by_col], dim=1))

        map_pair = self.pair_conv2(feat_2d)
        fused = self.pair_conv3(torch.cat([seq_pair, map_pair], dim=1))
        return self.resnet(fused)
class Inter_ResNet(nn.Module):
    """Embed the inter-chain 2-D features into ``channels`` and refine them with ``ResNet``.

    ``dim_1d`` is stored but not used by any layer of this module.
    """

    def __init__(self, dim_1d=768+20, dim_2d=768 + 20, channels=64, dropout=0):
        super().__init__()
        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        def pointwise(in_channels):
            # 1x1 convolution without bias, followed by instance norm and LeakyReLU
            return nn.Sequential(
                nn.Conv2d(in_channels, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
                nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
            )

        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)
        self.pair_conv2 = pointwise(self.dim_2d)
        self.pair_conv3 = pointwise(self.channels)
        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    def forward(self, inter_2d):
        """Normalise ``inter_2d``, project it twice with 1x1 convolutions and apply the ResNet."""
        normed = self.pre_norm_2d(inter_2d)
        embedded = self.pair_conv3(self.pair_conv2(normed))
        return self.resnet(embedded)
class InterTriangleMultiplication_S(nn.Module):
    """Gated multiplicative update of the inter-chain map using an intra-chain map.

    Both inputs are channels-last.  After layer norm and gated linear
    projections, the update is ``sum_k intra[b, i, k, c] * inter[b, k, j, c]``,
    followed by layer norm, a linear layer and an output gate computed from the
    normalised inter-chain input.
    """

    def __init__(self, channel_z=64, channel_c=64, transpose=False):
        super().__init__()
        self.dz = channel_z
        self.dc = channel_c
        self.transpose = transpose

        # one LayerNorm is shared by both inputs
        self.norm_init = nn.LayerNorm(self.dz)

        # gated projections of the inter- and intra-chain maps
        self.Linear_inter = nn.Linear(self.dz, self.dc)
        self.gate_inter = nn.Linear(self.dz, self.dc)
        self.Linear_intra = nn.Linear(self.dz, self.dc)
        self.gate_intra = nn.Linear(self.dz, self.dc)

        # output stage
        self.norm_final = nn.LayerNorm(self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)
        self.gate_final = nn.Linear(self.dz, self.dz)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the update for ``z_inter`` (same shape as ``z_inter``).

        When the module was built with ``transpose=True`` or ``transpose`` is
        true here, axes 1 and 2 of both inputs are swapped before the update
        and the result is swapped back.
        """
        flip = self.transpose or transpose
        if flip:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)

        inter_normed = self.norm_init(z_inter)
        intra_normed = self.norm_init(z_intra)

        inter_proj = self.Linear_inter(inter_normed) * self.gate_inter(inter_normed).sigmoid()
        intra_proj = self.Linear_intra(intra_normed) * self.gate_intra(intra_normed).sigmoid()

        mixed = torch.einsum('bikc,bkjc->bijc', intra_proj, inter_proj)
        mixed = self.norm_final(mixed)
        out = self.Linear_final(mixed) * self.gate_final(inter_normed).sigmoid()

        if flip:
            out = out.permute(0, 2, 1, 3)
        return out
class InterCrossAttention_S(nn.Module):
    """Gated multi-head cross attention: inter-chain map queries an intra-chain map.

    Queries come from ``z_inter``; keys and values come from ``z_intra``.
    Attention runs independently for every index of axis 1, over axis 2 of
    ``z_intra``.  Inputs are channels-last.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=True, transpose=False):
        super().__init__()
        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)
        # registered but not called by forward(), which uses the functional softmax
        self.softmax = nn.Softmax(dim=-1)
        self.gate_final = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis of ``x`` into (num_head, dim_head)."""
        return x.view(*x.size()[:-1], self.num_head, self.dhc)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the attention update for ``z_inter`` (same shape as ``z_inter``).

        With ``transpose`` (argument or constructor flag) axes 1 and 2 of both
        inputs are swapped first and the output is swapped back.
        """
        flip = self.transpose or transpose
        if flip:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)
        _batch, _rows, _cols, _ = z_inter.shape  # requires a 4-D input

        inter = self.norm_init(z_inter)
        intra = self.norm_init(z_intra)

        query = self.reshape_dim(self.Linear_Q(inter))
        key = self.reshape_dim(self.Linear_K(intra))
        value = self.reshape_dim(self.Linear_V(intra))

        # scaled dot product: 1 / sqrt(dim_head)
        scale = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=query.dtype, device=query.device))
        logits = torch.einsum('b l j h c, b l i h c -> b l j h i', query * scale, key)
        if self.bias:
            # one bias per (query position, head), broadcast over the key axis
            logits = logits + self.Linear_bias(inter).unsqueeze(-1)

        if logits.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                weights = torch.nn.functional.softmax(logits, -1)
        else:
            weights = torch.nn.functional.softmax(logits, -1)

        context = torch.einsum('b l j h i, b l i h c -> b l j h c', weights, value)
        gate = self.reshape_dim(self.gate_final(inter)).sigmoid()
        # merge (head, dim_head) back into one channel axis
        merged = (context * gate).contiguous().view(gate.size()[:-2] + (-1,))
        out = self.Linear_final(merged)

        if flip:
            out = out.permute(0, 2, 1, 3)
        return out
class InterSelfAttention_S(nn.Module):
    """Gated multi-head self attention along axis 2 of a channels-last pair map.

    Attention runs independently for every index of axis 1.  An optional
    distance tensor down-weights the attention logits with a Gaussian kernel
    before the softmax.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=False, transpose=False):
        super(InterSelfAttention_S, self).__init__()
        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)
        # registered but not called by forward(), which uses the functional softmax
        self.softmax = nn.Softmax(-1)
        self.gate_v = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis of ``x`` into (num_head, dim_head)."""
        new_shape = x.size()[:-1] + (self.num_head, self.dhc)
        return x.view(*new_shape)

    def forward(self, z_inter, dist=None, transpose=False):
        """Return the attention update for ``z_inter`` (same shape as ``z_inter``).

        Args:
            z_inter: 4-D channels-last tensor (batch, rows, cols, channel_z).
            dist: optional distance tensor; ``exp(-(dist/8)**2/2)`` gets a new
                axis at position 3 and multiplies the (batch, rows, cols, head,
                cols) logits, so it must broadcast against them.
                NOTE(review): a (batch, 1, cols, cols) layout would fit --
                confirm against the caller.
            transpose: attend along axis 1 instead of axis 2 (the constructor
                flag has the same effect).
        """
        flip = self.transpose or transpose
        if flip:
            z_inter = z_inter.permute(0, 2, 1, 3)

        z_inter = self.norm_init(z_inter)
        q = self.reshape_dim(self.Linear_Q(z_inter))
        k = self.reshape_dim(self.Linear_K(z_inter))
        v = self.reshape_dim(self.Linear_V(z_inter))

        # scaled dot product: 1 / sqrt(dim_head)
        scalar = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=q.dtype, device=q.device))
        attn = torch.einsum('b l i h c, b l j h c -> b l i h j', q * scalar, k)

        if self.bias:
            # (b, l, i, h) -> (b, l, i, h, 1) so that the bias broadcasts over the key axis,
            # exactly as InterCrossAttention_S does; without the extra axis the addition
            # fails with a shape mismatch
            bias = self.Linear_bias(z_inter).unsqueeze(-1)
            attn = attn + bias

        # identity test: comparing a tensor/array with None via != is not a reliable truth value
        if dist is not None:
            coef = torch.exp(-(dist/8.0)**2.0/2.0).unsqueeze(3).type_as(q)
            attn = attn * coef

        if attn.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = torch.nn.functional.softmax(attn, -1)
        else:
            attn_weights = torch.nn.functional.softmax(attn, -1)

        v_avg = torch.einsum('b l i h j, b l j h c -> b l i h c', attn_weights, v)
        gate_v = self.reshape_dim(self.gate_v(z_inter)).sigmoid()
        # merge (head, dim_head) back into one channel axis
        z_com = (v_avg * gate_v).contiguous().view( v_avg.size()[:-2] + (-1,) )
        z_final = self.Linear_final(z_com)

        if flip:
            z_final = z_final.permute(0, 2, 1, 3)
        return z_final
class Transition(nn.Module):
    """LayerNorm followed by an expand-then-project pair of linear layers.

    The hidden width is ``channel_z * transition_n``.
    NOTE(review): there is no activation between the two linear layers --
    confirm this is intended.
    """

    def __init__(self, channel_z=64, transition_n=4):
        super().__init__()
        self.dz = channel_z
        self.n = transition_n

        hidden = self.dz * self.n
        self.norm = nn.LayerNorm(self.dz)
        self.transition = nn.Sequential(
            nn.Linear(self.dz, hidden),
            nn.Linear(hidden, self.dz),
        )

    def forward(self, z_com):
        """Normalise ``z_com`` over its last axis and apply the two linear layers."""
        return self.transition(self.norm(z_com))

16
pplm_contact/config.py Normal file
View File

@@ -0,0 +1,16 @@
import os
# Build directory of HH-suite; the executables and scripts below are located relative to it.
hhsuite_dir = "/mnt/rna01/junl/tools/hh-suite/build"
hhblits, hhmake, reformat, hhfilter = (
    os.path.join(hhsuite_dir, relative_path)
    for relative_path in ("bin/hhblits", "bin/hhmake", "scripts/reformat.pl", "bin/hhfilter")
)

# Sequence database prefix passed to hhblits via -d.
UniRef_database = "/mnt/dna01/library2/e2e_folding/alphafold/v2.3/lib/uniref30/UniRef30_2021_03"

# External tools shipped with pplm_contact.
ccmpred = "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/ccmpred"
esm_msa = "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/extract_esm_msa_features.py"

# Weights file of the ESM MSA model.
esm_msa_model = "/mnt/rna01/junl/tools/esm-main/esm_models/esm_msa1_t12_100M_UR50S.pt"

72
pplm_contact/model.py Normal file
View File

@@ -0,0 +1,72 @@
import os
import sys
import torch
import torch.nn as nn
mian_path = os.path.dirname(__file__) + "/../"
sys.path.append(os.path.abspath(mian_path))
from pplm_contact.Module import *
class PPLM_Contact(nn.Module):
    """Inter-protein contact predictor.

    Two intra-chain pair maps (sharing one ``Intra_ResNet``) and one
    inter-chain pair map (``Inter_ResNet``) are computed first.  The
    inter-chain map is then refined by ``num_blocks`` rounds of triangle
    multiplication, cross attention, self attention and a transition layer,
    and finally mapped to a contact probability per residue pair.
    """

    def __init__(self,
                 intra_1d_dim=768+20,
                 intra_2d_dim=144+2+64,
                 inter_2d_dim=144+2+660,
                 channels=64,
                 num_blocks=12,
                 droupout=0.10,
                 ):
        super().__init__()
        self.intra_1d_dim = intra_1d_dim
        self.intra_2d_dim = intra_2d_dim
        self.inter_2d_dim = inter_2d_dim
        self.channels = channels
        self.num_blocks = num_blocks
        self.droupout = droupout

        def stack(factory):
            # one independently initialised module per refinement block
            return nn.ModuleList([factory() for _ in range(num_blocks)])

        ### ResNet
        self.intra_resnet = Intra_ResNet(dim_1d=self.intra_1d_dim, dim_2d=self.intra_2d_dim, channels=self.channels)
        self.inter_resnet = Inter_ResNet(dim_2d=self.inter_2d_dim, channels=self.channels)

        ### Transformer
        self.InterTriangleMulti_R = stack(lambda: InterTriangleMultiplication_S(channel_z=self.channels, channel_c=self.channels, transpose=False))
        self.InterTriangleMulti_L = stack(lambda: InterTriangleMultiplication_S(channel_z=self.channels, channel_c=self.channels, transpose=True))
        self.InterCrossAttn_R = stack(lambda: InterCrossAttention_S(channel_z=self.channels, bias=True, transpose=False))
        self.InterCrossAttn_L = stack(lambda: InterCrossAttention_S(channel_z=self.channels, bias=True, transpose=True))
        self.InterSelfAttn_R = stack(lambda: InterSelfAttention_S(channel_z=self.channels, bias=False, transpose=False))
        self.InterSelfAttn_L = stack(lambda: InterSelfAttention_S(channel_z=self.channels, bias=False, transpose=True))
        self.Transition = stack(lambda: Transition(channel_z=self.channels))

        self.drop = nn.Dropout(self.droupout)
        self.norm_final = nn.LayerNorm(self.channels)
        self.Linear_final = nn.Linear(self.channels, 1)
        self.act_final = nn.Sigmoid()

    def forward(self, intra1_1d, intra1_2d, intra2_1d, intra2_2d, inter_2d, intra1_dist=None, intra2_dist=None, last_layer=False):
        """Predict the inter-chain contact map.

        Args:
            intra1_1d, intra1_2d: 1-D and 2-D features of the first chain.
            intra2_1d, intra2_2d: 1-D and 2-D features of the second chain.
            inter_2d: channels-first inter-chain 2-D features.
            intra1_dist, intra2_dist: optional distance tensors forwarded to
                the self-attention layers.
            last_layer: also return the normalised pair representation.

        Returns:
            The sigmoid output of the first batch element only, shaped
            (1, rows, cols).  With ``last_layer=True`` a tuple of that tensor
            and the channels-first representation of the whole batch.
        """
        def channels_last(t):
            return t.permute(0, 2, 3, 1)

        def channels_first(t):
            return t.permute(0, 3, 1, 2)

        pair_a = channels_last(self.intra_resnet(intra1_1d, intra1_2d))
        pair_b = channels_last(self.intra_resnet(intra2_1d, intra2_2d))
        pair_ab = channels_last(self.inter_resnet(inter_2d))

        for idx in range(self.num_blocks):
            # chain 1 acts on the rows, chain 2 on the columns (transposed pass)
            tri_r = self.drop(self.InterTriangleMulti_R[idx](pair_ab, pair_a))
            tri_l = self.drop(self.InterTriangleMulti_L[idx](pair_ab, pair_b, transpose=True))
            pair_ab = pair_ab + tri_r + tri_l

            cross_r = self.drop(self.InterCrossAttn_R[idx](pair_ab, pair_a))
            cross_l = self.drop(self.InterCrossAttn_L[idx](pair_ab, pair_b, transpose=True))
            pair_ab = pair_ab + cross_r + cross_l

            self_r = self.drop(self.InterSelfAttn_R[idx](pair_ab, intra2_dist))
            self_l = self.drop(self.InterSelfAttn_L[idx](pair_ab, intra1_dist, transpose=True))
            pair_ab = pair_ab + self_r + self_l

            pair_ab = pair_ab + self.Transition[idx](pair_ab)

        pair_ab = self.norm_final(pair_ab)
        representation = channels_first(pair_ab) if last_layer else None

        contact = channels_first(self.act_final(self.Linear_final(pair_ab)))[0]
        if last_layer:
            return contact, representation
        return contact

353
pplm_contact/predict.py Normal file
View File

@@ -0,0 +1,353 @@
import os
import sys
import pathlib
import numpy as np
import pickle
import torch
import subprocess
import argparse
from model import PPLM_Contact
from utils import extract_seq_and_dist_map, pairing_msa, RBF
from LoadHHM import load_hmm
from config import *
def _clean_pdb(src_pdb, dst_pdb):
    """Write the ATOM records of ``src_pdb`` to ``dst_pdb`` with residue names normalised.

    The grep | sed pipeline is kept, but both paths are shell-quoted so that
    file names containing spaces or shell metacharacters are handled safely.
    Because of the pipe, ``check=True`` only reflects the exit status of sed.
    """
    import shlex  # local import: only needed to quote the two paths

    # NOTE(review): TYS is rewritten to TRP here; TYS is usually a modified tyrosine -- confirm this is intended
    sed_program = "s/MEX/CYS/g; s/HID/HIS/g; s/HIE/HIS/g; s/HIP/HIS/g; s/MSE/MET/g; s/ASX/ASN/g; s/GLX/GLN/g; s/TYS/TRP/g"
    command = "grep \"^ATOM\" " + shlex.quote(str(src_pdb)) + " | sed '" + sed_program + "' > " + shlex.quote(str(dst_pdb))
    subprocess.run(command, shell=True, check=True)


def _search_msa(fasta_path, msa_path, n_cpu):
    """Run hhblits on ``fasta_path`` and write the a3m alignment to ``msa_path``."""
    print("Searching MSA for", fasta_path)
    # argument list instead of a shell string: no quoting problems, no shell involved
    subprocess.run(
        [hhblits, "-i", str(fasta_path), "-d", UniRef_database, "-cpu", str(n_cpu),
         "-oa3m", str(msa_path), "-n", "3", "-e", "0.001", "-id", "99", "-cov", "0.4"],
        check=True,
    )


def _write_contact_table(txt_path, contact_map, chain_a_residues, chain_b_residues):
    """Write all residue pairs of ``contact_map`` to ``txt_path``, most probable first.

    ``contact_map`` is indexed as ``contact_map[i, j]``; ``chain_*_residues[k]``
    provides the residue index at position 0 and the residue type at position 1.
    """
    ranked = sorted(
        ((i, j, contact_map[i, j])
         for i in range(contact_map.shape[0])
         for j in range(contact_map.shape[1])),
        key=lambda item: item[2],
        reverse=True,
    )
    rows = ["{:<10}{:<10}{:<10}{:<10}{:<10}{:<20}\n".format(
        "Rank", "ResIdx1", "ResType1", "ResIdx2", "ResType2", "Contact_Probability")]
    for rank, (i, j, prob) in enumerate(ranked, start=1):
        idx_a, type_a = chain_a_residues[i][0], chain_a_residues[i][1]
        idx_b, type_b = chain_b_residues[j][0], chain_b_residues[j][1]
        rows.append("{:<10}{:<10}{:<10}{:<10}{:<10}{:<10}\n".format(
            rank, str(idx_a) + ":A", type_a, str(idx_b) + ":B", type_b, f"{prob:.6g}"))
    with open(txt_path, "w") as fw:
        fw.writelines(rows)


def main():
    """Command-line entry point: predict inter-protein contacts for two PDB files.

    Intermediate files are written to the output folder and reused when they
    already exist.  All paths used below are module-level globals set by
    ``define_param``.
    """
    parser = argparse.ArgumentParser(description="Protein-Protein Contact Prediction",
                                     epilog="v0.0.1")
    parser.add_argument("pdbA_path",
                        type=pathlib.Path,
                        help="Location of pdb A")
    parser.add_argument("pdbB_path",
                        type=pathlib.Path,
                        help="Location of pdb B")
    parser.add_argument("output_folder",
                        type=pathlib.Path,
                        help="Location to store output files")
    parser.add_argument("--n_cpu",
                        type=int,
                        default=8,
                        help="Number of CPU cores for search MSA",
                        )
    parser.add_argument("--gpu_id",
                        type=int,
                        default=0,
                        help="gpu device specified",
                        )
    args = parser.parse_args()

    ### Define parameters ###
    define_param(args)

    ##### Step 0: Process the input pdb (clean pdb & extract sequence & derive monomer distance map) #####
    print("========== Step 0: Processing the intput pdb (Clean pdb, extract seq and dist map) ==========")
    _clean_pdb(args.pdbA_path, target1_pdb)
    target1_res_idx_type = extract_seq_and_dist_map(target1_pdb, target1_seq, target1_monomer_dist)
    if mode == "hetero":
        _clean_pdb(args.pdbB_path, target2_pdb)
        target2_res_idx_type = extract_seq_and_dist_map(target2_pdb, target2_seq, target2_monomer_dist)
    else:
        # homo mode: both chains are the same structure
        target2_res_idx_type = target1_res_idx_type

    ##### Step 1: Search MSA #####
    print("========== Step 1: Searching MSA ==========")
    if not os.path.isfile(target1_msa):
        _search_msa(target1_seq, target1_msa, args.n_cpu)
    if mode == "hetero" and not os.path.isfile(target2_msa):
        _search_msa(target2_seq, target2_msa, args.n_cpu)

    ##### Step 2: Extract Monomer MSA features
    print("========== Step 2: Extract Monomer MSA features ==========")
    if os.path.isfile(target1_msa):
        extract_MSA_features(target1, target1_msa, target1_hhm, target1_aln, target1_dca_di, target1_dca_apc, target1_esm_msa)
    if mode == "hetero" and os.path.isfile(target2_msa):
        extract_MSA_features(target2, target2_msa, target2_hhm, target2_aln, target2_dca_di, target2_dca_apc, target2_esm_msa)

    ##### Step 3: Extract paired MSA features
    print("========== Step 3: Extract paired MSA features ==========")
    # NOTE(review): the nesting of this step was inferred (the indentation of this copy was lost);
    # both calls are assumed to run only in hetero mode -- confirm against the original file
    if mode == "hetero":
        pairing_msa(target1_msa, target2_msa, paired_msa)
        extract_MSA_features(target, paired_msa, paired_hhm, paired_aln, paired_dca_di, paired_dca_apc, paired_esm_msa)

    ##### Step 4: Generate PPLM inter-protein attention matrix
    print("========== Step 4: Generate PPLM features ==========")
    if not os.path.isfile(pplm_feat):
        get_pplm_features(target1_seq, target2_seq, pplm_feat, device='cpu')

    ##### Step 5: Collect all features
    print("========== Step 5: Collect all features ==========")
    feats = collect_all_features()

    ##### Step 6: Predict inter-protein contact
    print("========== Step 6: Predict inter-protein contact ==========")
    if not os.path.isfile(pred_contact_pkl_path) or not os.path.isfile(pred_contact_txt_path):
        pred_contact = predict_contact(feats, mode, device=device)
        with open(pred_contact_pkl_path, "wb") as fw:
            pickle.dump(pred_contact, fw)
        _write_contact_table(pred_contact_txt_path, pred_contact, target1_res_idx_type, target2_res_idx_type)
def define_param(args):
    """Derive the run configuration and all intermediate file paths.

    Every value is published as a module-level global (see the ``global``
    statements below) that the other pipeline steps in this file read.
    The output folder is created if it does not exist yet.

    Args:
        args: parsed command-line arguments exposing ``pdbA_path`` and
            ``pdbB_path`` (objects with a ``.stem`` attribute),
            ``output_folder`` and ``gpu_id``.
    """
    global target1, target2, target, workspace, mode, device
    global target1_pdb, target1_seq, target1_monomer_dist, target2_pdb, target2_seq
    global target2_monomer_dist, target1_msa, target2_msa, target1_hhm, target1_aln
    global target1_dca_di, target1_dca_apc, target1_esm_msa, target2_hhm, target2_aln
    global target2_dca_di, target2_dca_apc, target2_esm_msa, paired_msa, paired_hhm
    global paired_aln, paired_dca_di, paired_dca_apc, paired_esm_msa, pplm_feat
    global pred_contact_pkl_path, pred_contact_txt_path

    target1 = args.pdbA_path.stem
    target2 = args.pdbB_path.stem
    # The complex is named after the last component of the output folder.
    # Trailing slashes are stripped first so that "out/complex/" yields
    # "complex" rather than an empty name (which would produce files such
    # as "_paired.a3m").
    target = str(args.output_folder).rstrip('/').split('/')[-1]
    workspace = args.output_folder
    os.makedirs(workspace, exist_ok=True)

    # Use the requested GPU only when CUDA is actually available.
    assigned_device = "cuda:" + str(args.gpu_id)
    device = assigned_device if torch.cuda.is_available() else "cpu"
    # Passing the same PDB path for both chains selects the homo-oligomer mode.
    mode = "homo" if args.pdbA_path == args.pdbB_path else "hetero"
    print("device:", device, " mode:", mode)

    # Per-chain inputs extracted from the PDB files.
    target1_pdb = os.path.join(workspace, target1 + ".clean.pdb")
    target1_seq = os.path.join(workspace, target1 + ".fasta")
    target1_monomer_dist = os.path.join(workspace, target1 + ".monomer_dist.pkl")
    target2_pdb = os.path.join(workspace, target2 + ".clean.pdb")
    target2_seq = os.path.join(workspace, target2 + ".fasta")
    target2_monomer_dist = os.path.join(workspace, target2 + ".monomer_dist.pkl")

    # Per-chain MSA and the features derived from it.
    target1_msa = os.path.join(workspace, target1 + ".a3m")
    target2_msa = os.path.join(workspace, target2 + ".a3m")
    target1_hhm = os.path.join(workspace, target1 + ".hhm")
    target1_aln = os.path.join(workspace, target1 + ".aln")
    target1_dca_di = os.path.join(workspace, target1 + ".dca_di")
    target1_dca_apc = os.path.join(workspace, target1 + ".dca_apc")
    target1_esm_msa = os.path.join(workspace, target1 + ".esm_msa.pkl")
    target2_hhm = os.path.join(workspace, target2 + ".hhm")
    target2_aln = os.path.join(workspace, target2 + ".aln")
    target2_dca_di = os.path.join(workspace, target2 + ".dca_di")
    target2_dca_apc = os.path.join(workspace, target2 + ".dca_apc")
    target2_esm_msa = os.path.join(workspace, target2 + ".esm_msa.pkl")

    # Paired-MSA features, PPLM attention and final prediction outputs.
    paired_msa = os.path.join(workspace, target + "_paired.a3m")
    paired_hhm = os.path.join(workspace, target + "_paired.hhm")
    paired_aln = os.path.join(workspace, target + "_paired.aln")
    paired_dca_di = os.path.join(workspace, target + "_paired.dca_di")
    paired_dca_apc = os.path.join(workspace, target + "_paired.dca_apc")
    paired_esm_msa = os.path.join(workspace, target + "_paired.esm_msa.pkl")
    pplm_feat = os.path.join(workspace, target + ".pplm.pkl")
    pred_contact_pkl_path = os.path.join(workspace, target + ".pred_contact.pkl")
    pred_contact_txt_path = os.path.join(workspace, target + ".pred_contact.txt")
def extract_MSA_features(name, msa, hhm, aln, dci_di, dci_apc, esm_msa_path):
    """Run the external tools that turn one MSA into HHM, DCA and ESM-MSA files.

    The tool commands (``hhmake``, ``reformat``, ``ccmpred``, ``hhfilter``,
    ``esm_msa``, ``esm_msa_model``) and ``workspace`` are module-level
    settings defined elsewhere in this file.

    Args:
        name: base name used for the intermediate ``<name>.fas`` file in ``workspace``.
        msa: input alignment path (A3M).
        hhm: output path for the hhmake profile.
        aln: output path for the header-less alignment fed to ccmpred.
        dci_di: ccmpred output written with the ``-R -A`` flags.
        dci_apc: ccmpred output written with the ``-R`` flag.
        esm_msa_path: output pickle of the ESM-MSA script.

    Raises:
        subprocess.CalledProcessError: if any external command fails.
    """
    # Local import: only this function needs it.
    import shlex

    # File paths are quoted so that spaces or shell metacharacters in the
    # workspace path cannot break (or alter) the shell command lines.
    quote = shlex.quote
    fas = os.path.join(workspace, name + ".fas")

    print("Generating the HHM file for", msa)
    subprocess.run(hhmake + " -i " + quote(msa) + " -o " + quote(hhm), shell=True, check=True)
    subprocess.run(reformat + " " + quote(msa) + " " + quote(fas) + " -r -l 2000 >/dev/null", shell=True, check=True)
    # Keep only the sequence rows (drop every '>' header line).
    subprocess.run("awk '{if(!($0~/^>/)){print}}' " + quote(fas) + " > " + quote(aln), shell=True, check=True)

    # generate the DCA features
    print("Generating the DCA features for", aln)
    if not os.path.isfile(dci_di) or not os.path.isfile(dci_apc):
        subprocess.run(ccmpred + " " + quote(aln) + " " + quote(dci_di) + " -R -A", shell=True, check=True)
        subprocess.run(ccmpred + " " + quote(aln) + " " + quote(dci_apc) + " -R", shell=True, check=True)

    # generate the ESM-MSA features
    print("Generating the ESM-MSA features for", msa)
    if not os.path.isfile(esm_msa_path):
        filtered = msa + "_filtered"
        subprocess.run(hhfilter + " -i " + quote(msa) + " -o " + quote(filtered) + " -diff 512", shell=True, check=True)
        # check=True so a failing ESM-MSA run stops the pipeline here, like
        # every other step, instead of surfacing later as a missing file.
        subprocess.run(["python", esm_msa, esm_msa_model, filtered, esm_msa_path], check=True)
def _read_fasta_sequence(fasta_path):
    """Return all non-header lines of a FASTA file joined into one string."""
    with open(fasta_path, "r") as fr:
        return "".join(line.strip() for line in fr if not line.startswith(">"))


def get_pplm_features(seqA_path, seqB_path, out_pkl_path, device='cpu'):
    """Compute the PPLM inter-chain attention maps for two sequences and pickle them.

    Both sequences are tokenised separately, concatenated, and passed through
    the PPLM model together with a mask that is 1 for cross-chain position
    pairs and 0 for same-chain pairs. The A->B and B->A attention blocks are
    averaged (after transposing the latter) and written to ``out_pkl_path``.

    Args:
        seqA_path: FASTA file of the first chain.
        seqB_path: FASTA file of the second chain.
        out_pkl_path: destination pickle for the ``(33 * 20, lenA, lenB)`` array.
        device: torch device the model is run on.
    """
    # Resolve the repository root from this file's absolute location so the
    # result does not depend on how the script was invoked.
    main_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
    if main_path not in sys.path:
        sys.path.append(main_path)
    from pplm import PPLM, Alphabet
    model_location = os.path.join(main_path, 'pplm', 'models', 'pplm_t33_650M.pt')

    ##### Loading PPLM Model #####
    alphabet = Alphabet.from_architecture()
    batch_converter = alphabet.get_batch_converter()
    model_data = torch.load(model_location, map_location="cpu")
    model_param = model_data["param"]
    model_state = model_data["model"]
    model = PPLM(
        num_layers=model_param['encoder_layers'],
        embed_dim=model_param['encoder_embed_dim'],
        attention_heads=model_param['encoder_attention_heads'],
        token_dropout=False,
        alphabet=alphabet
    )
    model.to(device)
    model.load_state_dict(model_state, strict=False)
    # A freshly constructed module is in training mode; switch to inference
    # mode so train-only layers (e.g. dropout), if any, are disabled.
    model.eval()

    seqA = _read_fasta_sequence(seqA_path)
    seqB = _read_fasta_sequence(seqB_path)
    len_a, len_b = len(seqA), len(seqB)

    with torch.no_grad():
        _, _, seqA_tokens = batch_converter([('seqA', seqA)])
        _, _, seqB_tokens = batch_converter([('seqB', seqB)])
        tokens = torch.cat([seqA_tokens, seqB_tokens], dim=-1).to(device)

        # Each chain occupies len + 2 token positions
        # (presumably the residues plus one start and one end token -- TODO confirm).
        span_a, span_b = len_a + 2, len_b + 2
        inter_chain_mask = torch.ones((span_a + span_b, span_a + span_b), device=device)
        inter_chain_mask[:span_a, :span_a] = 0
        inter_chain_mask[span_a:, span_a:] = 0

        ##### running PPLM #####
        out = model(tokens, inter_chain_mask, repr_layers=[33], need_head_weights=True, return_contacts=False)
        attentions = out['attentions'].squeeze()
        # Residue positions of each chain inside the concatenated token axis,
        # skipping the extra first/last position of each chain.
        pos_a = slice(1, len_a + 1)
        pos_b = slice(-(len_b + 1), -1)
        # Only the two cross-chain blocks are used; the intra-chain blocks are ignored.
        attn_AB = attentions[:, :, pos_a, pos_b].reshape(33 * 20, len_a, len_b).cpu().numpy()
        attn_BA = attentions[:, :, pos_b, pos_a].reshape(33 * 20, len_b, len_a).cpu().numpy()

    inter_attn = (attn_AB + attn_BA.transpose(0, 2, 1)) / 2
    with open(out_pkl_path, mode='wb') as fw:
        pickle.dump(inter_attn, fw)
def _load_chain_features(dist_pkl, dca_di_path, dca_apc_path, hhm_path, esm_msa_pkl):
    """Load one chain's precomputed inputs and stack them into 1D / 2D features.

    Returns ``(feat_1d, feat_2d, monomer_dist, (dca_di, dca_apc, esm_2d))``.
    """
    with open(dist_pkl, "rb") as handle:
        monomer_dist = pickle.load(handle)
    dca_di = np.expand_dims(np.loadtxt(dca_di_path), 0)
    dca_apc = np.expand_dims(np.loadtxt(dca_apc_path), 0)
    pssm = load_hmm(hhm_path)['PSSM']
    with open(esm_msa_pkl, "rb") as handle:
        esm_msa_data = pickle.load(handle)
    esm_1d = esm_msa_data['esm_msa_1d']
    esm_2d = esm_msa_data['row_attentions']
    print(dca_di.shape, dca_apc.shape, esm_2d.shape, RBF(monomer_dist).shape, monomer_dist.shape)
    feat_1d = np.concatenate([pssm, esm_1d], axis=-1).transpose(1, 0)
    feat_2d = np.concatenate([dca_di, dca_apc, esm_2d, RBF(monomer_dist)], axis=0)
    return feat_1d, feat_2d, monomer_dist, (dca_di, dca_apc, esm_2d)


def collect_all_features():
    """Gather every precomputed feature file into the dict consumed by the predictor.

    Reads the paths and ``mode`` from the module-level globals. In "homo"
    mode the second chain reuses the first chain's features and the
    inter-chain stack is built from the first chain's 2D maps; otherwise the
    second chain is loaded from its own files and the inter-chain stack is
    cut out of the paired-MSA maps. The PPLM attention is appended in both
    cases.

    Returns:
        dict with keys ``intra1_1d``, ``intra1_2d``, ``intra1_Mdist``,
        ``intra2_1d``, ``intra2_2d``, ``intra2_Mdist`` and ``inter_2d``.
    """
    intra1_1d, intra1_2d, intra1_Mdist, chain1_maps = _load_chain_features(
        target1_monomer_dist, target1_dca_di, target1_dca_apc, target1_hhm, target1_esm_msa
    )
    with open(pplm_feat, "rb") as handle:
        inter_pplm_attn = pickle.load(handle)

    if mode != "homo":
        intra2_1d, intra2_2d, intra2_Mdist, _ = _load_chain_features(
            target2_monomer_dist, target2_dca_di, target2_dca_apc, target2_hhm, target2_esm_msa
        )
        paired_di = np.expand_dims(np.loadtxt(paired_dca_di), 0)
        paired_apc = np.expand_dims(np.loadtxt(paired_dca_apc), 0)
        with open(paired_esm_msa, "rb") as handle:
            paired_esm_2d = pickle.load(handle)['row_attentions']
        len1 = inter_pplm_attn.shape[-2]
        # NOTE(review): len2 is read from the paired ESM-MSA map's last axis.
        # If that axis spans both chains, the upper bound len1 + len2 runs
        # past the array and the slice is simply clipped at its end -- confirm
        # this is intended.
        len2 = paired_esm_2d.shape[-1]
        print(paired_di.shape, paired_apc.shape, paired_esm_2d.shape, inter_pplm_attn.shape)
        # Rows of chain 1 against the columns that follow chain 1.
        rows, cols = slice(None, len1), slice(len1, len1 + len2)
        inter_2d = np.concatenate(
            [paired_di[:, rows, cols], paired_apc[:, rows, cols], paired_esm_2d[:, rows, cols], inter_pplm_attn],
            axis=0,
        )
    else:
        intra2_1d, intra2_2d, intra2_Mdist = intra1_1d, intra1_2d, intra1_Mdist
        chain1_di, chain1_apc, chain1_esm_2d = chain1_maps
        inter_2d = np.concatenate([chain1_di, chain1_apc, chain1_esm_2d, inter_pplm_attn], axis=0)

    return {
        "intra1_1d": intra1_1d,
        "intra1_2d": intra1_2d,
        "intra1_Mdist": intra1_Mdist,
        "intra2_1d": intra2_1d,
        "intra2_2d": intra2_2d,
        "intra2_Mdist": intra2_Mdist,
        "inter_2d": inter_2d,
    }
def predict_contact(feats, mode, device="cpu"):
    """Predict the inter-protein contact map as the mean of five model checkpoints.

    Args:
        feats: feature dict with the keys ``intra1_1d``, ``intra1_2d``,
            ``intra1_Mdist``, ``intra2_1d``, ``intra2_2d``, ``intra2_Mdist``
            and ``inter_2d``.
        mode: ``"homo"`` selects the homo checkpoints; any other value
            selects the hetero checkpoints.
        device: torch device used for the model and the inputs.

    Returns:
        numpy array holding the checkpoint-averaged prediction with all
        size-1 axes squeezed out.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    family = "homo" if mode == "homo" else "hetero"
    checkpoint_paths = [
        os.path.join(here, "models/pplm_contact." + family + "_" + str(k) + ".pkl")
        for k in range(1, 6)
    ]

    model = PPLM_Contact()
    model.to(device)

    def as_batch(key):
        # float32 tensor on the target device with a leading batch axis of 1.
        return torch.Tensor(feats[key]).to(device).type(torch.float32).unsqueeze(0)

    feature_keys = (
        "intra1_1d", "intra1_2d", "intra1_Mdist",
        "intra2_1d", "intra2_2d", "intra2_Mdist",
        "inter_2d",
    )
    inputs = {key: as_batch(key) for key in feature_keys}

    per_checkpoint = []
    with torch.no_grad():
        for ckpt_path in checkpoint_paths:
            state = torch.load(ckpt_path, map_location=device)
            model.load_state_dict(state["model_state_dict"])
            model.eval()
            #################################### Network predict #######################################
            per_checkpoint.append(
                model(
                    inputs["intra1_1d"], inputs["intra1_2d"],
                    inputs["intra2_1d"], inputs["intra2_2d"],
                    inputs["inter_2d"],
                    inputs["intra1_Mdist"], inputs["intra2_Mdist"],
                )
            )
        averaged = torch.mean(torch.stack(per_checkpoint), dim=0)
        return averaged.squeeze().cpu().detach().numpy()
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()

225
pplm_contact/utils.py Normal file
View File

@@ -0,0 +1,225 @@
import numpy as np
import pickle
import string
# Three-letter -> one-letter codes of the 20 standard amino acids.
# Written out explicitly so each pair can be checked at a glance: GLN is
# glutamine ('Q') and GLU is glutamate ('E'); zipping an alphabetical
# three-letter list against a differently ordered one-letter string had
# swapped these two.
restype_3to1 = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}
# Atom names treated as heavy atoms when measuring residue-residue distances.
heavy_atoms = ['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG']
def pdb2seq(pdb_path, seq_path):
    """Write the sequence of the first chain of a PDB file as FASTA and return it.

    One residue is taken per ``CA`` ATOM record; a new chain starts whenever
    the chain identifier (column 22) changes. When several chains are found
    a warning is printed and only the first one is used.

    Args:
        pdb_path: input PDB file.
        seq_path: output FASTA file; its header is ``>seq_<chain> <length>``.

    Returns:
        The one-letter sequence of the first chain.

    Raises:
        ValueError: if the file contains no CA ATOM record.
        KeyError: if a residue name is not in ``restype_3to1``.
    """
    chains = []  # [chain_id, sequence] entries in file order
    with open(pdb_path, 'r') as fr:
        for data in fr:
            if not data.startswith("ATOM") or data[13:16].strip() != 'CA':
                continue
            chain_id = data[21]
            residue = restype_3to1[data[17:20].strip()]
            if chains and chains[-1][0] == chain_id:
                chains[-1][1] += residue
            else:
                chains.append([chain_id, residue])

    if not chains:
        raise ValueError("No CA ATOM records found in " + str(pdb_path))
    if len(chains) > 1:
        # The chain ids are collected with a comprehension: a list of lists
        # cannot be sliced with [:, 0].
        print("Warning:", pdb_path, "contain multiple chains:", [cid for cid, _ in chains], "! Only first chain is considered.")

    chain_id, sequence = chains[0]
    with open(seq_path, 'w') as fw:
        fw.write(">seq_" + chain_id + " " + str(len(sequence)) + "\n")
        fw.write(sequence + "\n")
    return sequence
def _read_pdb_residues(pdb_path):
    """Group the ATOM records of a PDB file into residues, in file order.

    Returns a list of ``[res_idx, res_name, atoms]`` entries, where ``atoms``
    maps an atom name to its ``[x, y, z]`` coordinates.
    """
    residues = []
    with open(pdb_path, 'r') as fr:
        for data in fr:
            if not data.startswith("ATOM"):
                continue
            atom_name = data[13:16].strip()
            res_name = data[17:20].strip()
            res_idx = int(data[22:26].strip())
            xyz = [float(data[30:38].strip()), float(data[38:46].strip()), float(data[46:54].strip())]
            # A change of residue number starts a new residue. An empty list
            # (not a magic number) marks "no residue yet", so a residue
            # numbered -1 is handled like any other.
            if not residues or residues[-1][0] != res_idx:
                residues.append([res_idx, res_name, {}])
            current = residues[-1]
            current[1] = res_name
            current[2][atom_name] = xyz
    return residues


def _min_heavy_atom_dist_map(residue_atoms):
    """Return the (1, L, L) map of minimum heavy-atom distances between residues.

    Entries involving a residue without any atom from ``heavy_atoms`` stay ``inf``.
    """
    length = len(residue_atoms)
    dist_map = np.full((1, length, length), np.inf)
    coords = [
        np.array([atoms[name] for name in heavy_atoms if name in atoms], dtype=float).reshape(-1, 3)
        for atoms in residue_atoms
    ]
    present = [k for k, xyz in enumerate(coords) if len(xyz) > 0]
    if not present:
        return dist_map

    # All heavy atoms in one array; starts[k] is where residue present[k] begins.
    all_xyz = np.concatenate([coords[k] for k in present], axis=0)
    sizes = [len(coords[k]) for k in present]
    starts = np.cumsum([0] + sizes[:-1])
    for i in present:
        diff = coords[i][:, None, :] - all_xyz[None, :, :]
        # Closest atom of residue i to every atom, then the minimum within
        # each residue's run of atoms.
        sq_to_atoms = np.sum(diff * diff, axis=-1).min(axis=0)
        dist_map[0, i, present] = np.sqrt(np.minimum.reduceat(sq_to_atoms, starts))
    return dist_map


def extract_seq_and_dist_map(pdb_path, seq_path, dist_path):
    """Extract the sequence and the heavy-atom distance map of a PDB structure.

    Residues are delimited by a change of the residue number in consecutive
    ATOM records. The distance between two residues is the minimum distance
    over all pairs of their atoms listed in ``heavy_atoms``.

    Args:
        pdb_path: input PDB file.
        seq_path: output FASTA file (header ``>seq <length>``).
        dist_path: output pickle holding the ``(1, L, L)`` distance array.

    Returns:
        List of ``[residue_number, residue_name]`` entries, one per residue.

    Raises:
        KeyError: if a residue name is not in ``restype_3to1``.
    """
    ### Load pdb_chain ###
    residues = _read_pdb_residues(pdb_path)
    sequence = "".join(restype_3to1[res_name] for _, res_name, _ in residues)
    res_idx_type = [[res_idx, res_name] for res_idx, res_name, _ in residues]
    length = len(sequence)

    ############## extract heavy atom distance ##################
    heavy_atom_dist_map = _min_heavy_atom_dist_map([atoms for _, _, atoms in residues])

    with open(dist_path, mode='wb') as fw:
        pickle.dump(heavy_atom_dist_map, fw)
    with open(seq_path, 'w') as fw:
        fw.write(">seq " + str(length) + "\n")
        fw.write(sequence + "\n")
    return res_idx_type
def pairing_msa(msa1_path, msa2_path, paired_msa_path):
    """Pair two single-chain MSAs by species and write the result as one MSA file.

    The first record (header ``>target <length>``) is the concatenation of
    the two query sequences; every following record ``>seq<k>`` is one paired
    row returned by ``alignment`` (called with ``top=True``).

    Args:
        msa1_path: MSA of the first chain.
        msa2_path: MSA of the second chain.
        paired_msa_path: output file.
    """
    chain1_rows, chain1_taxids = extract_taxid(msa1_path)
    chain2_rows, chain2_taxids = extract_taxid(msa2_path)
    paired_rows = alignment(chain1_rows, chain1_taxids, chain2_rows, chain2_taxids, top=True)

    query_row, other_rows = paired_rows[0], paired_rows[1:]
    records = [">target " + str(len(query_row)) + "\n", query_row + "\n"]
    for rank, row in enumerate(other_rows, start=1):
        records.append(">seq" + str(rank) + "\n")
        records.append(row + "\n")

    with open(paired_msa_path, 'w') as f:
        f.writelines(records)
def _taxid_from_header(header):
    """Return the integer following ``TaxID=`` (or else ``OX=``) in a header line.

    Returns 0 when neither tag is present or the value is missing or not an integer.
    """
    for tag in ("TaxID=", "OX="):
        if tag in header:
            fields = header.split(tag)[1].split()
            try:
                return int(fields[0])
            except (IndexError, ValueError):
                return 0
    return 0


def extract_taxid(file, gap_cutoff=0.8):
    """Read an A3M-style MSA and return its rows with one taxonomy id per row.

    Lowercase letters, ``.`` and ``*`` are removed from every sequence. The
    second line of the file is taken as the query sequence and gets id 0.
    After that, each ``>`` header is expected to be followed by a single
    sequence line; rows whose gap fraction (relative to the query length)
    exceeds ``gap_cutoff`` are dropped together with their id. Blank lines
    are ignored.

    Args:
        file: path of the MSA file.
        gap_cutoff: maximum allowed fraction of ``-`` characters per row.

    Returns:
        ``(msas, sid)``: the list of kept sequences (query first) and a numpy
        array with the taxonomy id of each (0 when unknown).

    Raises:
        SystemExit: if the numbers of sequences and ids disagree.
    """
    deletekeys = dict.fromkeys(string.ascii_lowercase + ".*")
    translation = str.maketrans(deletekeys)

    with open(file, 'r') as fr:
        lines = fr.readlines()
    query = lines[1].strip().translate(translation)
    seq_len = len(query)

    msas = [query]
    sid = [0]
    for line in lines[2:]:
        if line.startswith(">"):
            # Always record exactly one id per header so ids and sequences
            # cannot drift apart.
            sid.append(_taxid_from_header(line))
            continue
        if not line.strip():
            # A blank line is not a sequence row.
            continue
        seq = line.strip().translate(translation)
        gap_fra = float(seq.count('-')) / seq_len
        if gap_fra <= gap_cutoff:
            msas.append(seq)
        else:
            sid.pop(-1)

    if len(msas) != len(sid):
        print("ERROR: len(msas) != len(sid)")
        print(len(msas), len(sid))
        # Non-zero status: this is a failure, and the interactive-only
        # ``exit`` helper is not guaranteed to exist.
        raise SystemExit(1)
    return msas, np.array(sid)
def cal_identity(query, sub_msas):
    """Fraction of query positions at which each sequence equals the query.

    Args:
        query : str
        sub_msas : List[str], each at least as long as ``query``
    Return:
        identity : np.array with one value per sequence in ``sub_msas``
    """
    n_cols = len(query)
    identity = np.zeros(len(sub_msas))
    for row, candidate in enumerate(sub_msas):
        # Count the columns where the candidate carries the query's character.
        hits = np.float64(sum(1 for pos in range(n_cols) if query[pos] == candidate[pos]))
        identity[row] = hits / n_cols
    return identity
def alignment(msas1, sid1, msas2, sid2, top=True):
    """Pair the rows of two MSAs that come from the same species.

    The first returned row is always the concatenation of the two query
    sequences (``msas1[0] + msas2[0]``). For every species id present in both
    ``sid1`` and ``sid2`` -- except 0, which marks "unknown" -- the rows of
    each MSA are ranked by identity to their query and concatenated pairwise.

    Args:
        msas1: sequences of the first MSA, query first.
        sid1: species id of every row in ``msas1``.
        msas2: sequences of the second MSA, query first.
        sid2: species id of every row in ``msas2``.
        top: if true, pair only the best-ranked row of each MSA per species;
            otherwise pair rows rank by rank up to the smaller group size.

    Returns:
        List of concatenated sequences, the paired query first.
    """
    sid1 = np.asarray(sid1)
    sid2 = np.asarray(sid2)
    # intersect1d already returns the shared ids sorted and unique. Remove the
    # "unknown" id 0 by value: dropping the first element would discard a
    # real species whenever 0 is not shared, and fails on an empty result.
    shared = np.intersect1d(sid1, sid2)
    shared = shared[shared != 0]

    query1 = msas1[0]
    query2 = msas2[0]
    aligns = [query1 + query2]
    for species in shared:
        rows1 = [msas1[idx] for idx in np.where(sid1 == species)[0]]
        order1 = np.argsort(-cal_identity(query1, rows1))
        rows2 = [msas2[idx] for idx in np.where(sid2 == species)[0]]
        order2 = np.argsort(-cal_identity(query2, rows2))

        n_pairs = 1 if top else min(len(rows1), len(rows2))
        for rank in range(n_pairs):
            aligns.append(rows1[order1[rank]] + rows2[order2[rank]])
    return aligns
def RBF(dist_map):
    """Expand a distance map into 64 Gaussian radial-basis channels.

    The centres are 64 evenly spaced values from 2 to 22 (inclusive) and the
    width is ``(22 - 2) / 64``. The input's leading axis is moved to the end
    and broadcast against the centres, so for an input of shape ``(1, L, L)``
    the result has shape ``(64, L, L)``.
    """
    lower, upper, n_centers = 2., 22., 64
    centers = np.linspace(lower, upper, n_centers)[None, :]
    width = (upper - lower) / n_centers
    # Put the size-1 leading axis last so it broadcasts against the centres.
    channel_last = np.transpose(dist_map, (1, 2, 0))
    scaled = (channel_last - centers) / width
    encoded = np.exp(-(scaled ** 2))
    # Bring the basis axis back to the front.
    return np.moveaxis(encoded, -1, 0)