mirror of
https://github.com/junliu621/PPLM.git
synced 2026-06-04 14:24:22 +08:00
Add files via upload
This commit is contained in:
344
pplm_contact/LoadHHM.py
Normal file
344
pplm_contact/LoadHHM.py
Normal file
@@ -0,0 +1,344 @@
|
||||
#
|
||||
# This file (LoadHHM.py) was downloaded from the RaptorX-Contact
|
||||
# at https://github.com/j3xugit/RaptorX-Contact
|
||||
#
|
||||
import numpy as np
|
||||
import os
|
||||
import sys
|
||||
import pickle as pkl
|
||||
|
||||
"""
|
||||
This script reads an hhm file generated by the HHpred/HHblits package to generate position-specific scoring/frequency matrix.
|
||||
After reading an hhm file, this script stores the HMM as a python dict().
|
||||
To use the position-specific frequency matrix, please use the keyword PSFM.
|
||||
To use the position-specific scoring matrix, please use the keyword PSSM.
|
||||
Both PSFM and PSSM encode information derived from the profile HMM built by HHpred or HHblits,
|
||||
so there is no need to directly use the keys containing 'hmm'
|
||||
|
||||
PSFM and PSSM columns are arranged by the alphabetical order of amino acids in their 1-letter code
|
||||
"""
|
||||
|
||||
# 8-state secondary-structure letter -> integer code ('L' and 'C' share code 7).
SS8Letter2Code = {'H': 0, 'G': 1, 'I': 2, 'E': 3, 'B': 4, 'T': 5, 'S': 6, 'L': 7, 'C': 7}

# 8-state -> 3-state secondary-structure conversion:
# helix (H, G, I) -> 0, beta (E, B) -> 1, loop (T, S, L, C) -> 2.
SS8Letter2SS3Code = {'H': 0, 'G': 0, 'I': 0, 'E': 1, 'B': 1, 'T': 2, 'S': 2, 'L': 2, 'C': 2}

# Amino acids (20 standard ones plus the ambiguity codes ASX/B and GLX/Z), listed in the
# same order in both tables. NOTE: tryptophan is spelled 'TRY' here (kept for backward
# compatibility); the standard 3-letter code 'TRP' is registered as an alias below.
AA3LetterCode = ['ALA', 'ARG', 'ASN', 'ASP', 'ASX', 'CYS', 'GLU', 'GLN', 'GLX', 'GLY', 'HIS',
                 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRY', 'TYR', 'VAL']
AA1LetterCode = ['A', 'R', 'N', 'D', 'B', 'C', 'E', 'Q', 'Z', 'G', 'H',
                 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V']

# Only the 20 standard amino acids are allowed in protein sequences.
ValidAALetters = {'A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I',
                  'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'}

# The two ambiguity codes (B, Z) are mapped to index 20 in both orderings.
# AAOrderBy3Letter: rank of each entry in the alphabetical order of 3-letter codes.
AAOrderBy3Letter = [0, 1, 2, 3, 20, 4, 5, 6, 20, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]

# AAOrderBy1Letter: rank of each entry in the alphabetical order of 1-letter codes.
AAOrderBy1Letter = [0, 14, 11, 2, 20, 1, 3, 13, 20, 5, 6, 7, 9, 8, 10, 4, 12, 15, 16, 18, 19, 17]

# Mapping between 1-letter code and 3-letter code.
AA3LetterCode21LetterCode = {}
AA1LetterCode23LetterCode = {}

# Map from one AA (3-letter or 1-letter code) to its rank in the 3-letter ordering.
AALetter2OrderOf3LetterCode = {}

# Map from one AA (3-letter or 1-letter code) to its rank in the 1-letter ordering.
AALetter2OrderOf1LetterCode = {}

# Map between rank in the 3-letter ordering and rank in the 1-letter ordering.
AA3LetterOrder21LetterOrder = {}
AA1LetterOrder23LetterOrder = {}

for l3, l1, o3, o1 in zip(AA3LetterCode, AA1LetterCode, AAOrderBy3Letter, AAOrderBy1Letter):
    AA1LetterCode23LetterCode[l1] = l3
    AA3LetterCode21LetterCode[l3] = l1

    AALetter2OrderOf3LetterCode[l1] = o3
    AALetter2OrderOf3LetterCode[l3] = o3

    AALetter2OrderOf1LetterCode[l1] = o1
    AALetter2OrderOf1LetterCode[l3] = o1

    AA3LetterOrder21LetterOrder[o3] = o1
    AA1LetterOrder23LetterOrder[o1] = o3

# Accept the standard tryptophan code 'TRP' in every lookup keyed by 3-letter code.
# Existing 'TRY' entries are untouched, so current callers keep working.
AA3LetterCode21LetterCode['TRP'] = 'W'
AALetter2OrderOf3LetterCode['TRP'] = AALetter2OrderOf3LetterCode['W']
AALetter2OrderOf1LetterCode['TRP'] = AALetter2OrderOf1LetterCode['W']
|
||||
|
||||
|
||||
## the rows and cols of the gonnet matrix are arranged in the alphabetical order of amino acids in the 3-letter code
|
||||
## the below matrix is a mutation probability matrix?
|
||||
# 20 x 20 Gonnet substitution table stored as float32. Rows and columns follow the
# alphabetical order of the amino acids' 3-letter codes; the table is symmetric.
gonnet = np.array(
    [
        [1.7378, 0.870964, 0.933254, 0.933254, 1.12202, 0.954993, 1.0, 1.12202, 0.831764, 0.831764,
         0.758578, 0.912011, 0.851138, 0.588844, 1.07152, 1.28825, 1.14815, 0.436516, 0.60256, 1.02329],
        [0.870964, 2.95121, 1.07152, 0.933254, 0.60256, 1.41254, 1.09648, 0.794328, 1.14815, 0.57544,
         0.60256, 1.86209, 0.676083, 0.47863, 0.812831, 0.954993, 0.954993, 0.691831, 0.660693, 0.630957],
        [0.933254, 1.07152, 2.39883, 1.65959, 0.660693, 1.1749, 1.23027, 1.09648, 1.31826, 0.524807,
         0.501187, 1.20226, 0.60256, 0.489779, 0.812831, 1.23027, 1.12202, 0.436516, 0.724436, 0.60256],
        [0.933254, 0.933254, 1.65959, 2.95121, 0.47863, 1.23027, 1.86209, 1.02329, 1.09648, 0.416869,
         0.398107, 1.12202, 0.501187, 0.354813, 0.851138, 1.12202, 1.0, 0.301995, 0.524807, 0.512861],
        [1.12202, 0.60256, 0.660693, 0.47863, 14.1254, 0.57544, 0.501187, 0.630957, 0.74131, 0.776247,
         0.707946, 0.524807, 0.812831, 0.831764, 0.489779, 1.02329, 0.891251, 0.794328, 0.891251, 1.0],
        [0.954993, 1.41254, 1.1749, 1.23027, 0.57544, 1.86209, 1.47911, 0.794328, 1.31826, 0.645654,
         0.691831, 1.41254, 0.794328, 0.549541, 0.954993, 1.04713, 1.0, 0.537032, 0.676083, 0.707946],
        [1.0, 1.09648, 1.23027, 1.86209, 0.501187, 1.47911, 2.29087, 0.831764, 1.09648, 0.537032,
         0.524807, 1.31826, 0.630957, 0.40738, 0.891251, 1.04713, 0.977237, 0.371535, 0.537032, 0.645654],
        [1.12202, 0.794328, 1.09648, 1.02329, 0.630957, 0.794328, 0.831764, 4.57088, 0.724436, 0.354813,
         0.363078, 0.776247, 0.446684, 0.301995, 0.691831, 1.09648, 0.776247, 0.398107, 0.398107, 0.467735],
        [0.831764, 1.14815, 1.31826, 1.09648, 0.74131, 1.31826, 1.09648, 0.724436, 3.98107, 0.60256,
         0.645654, 1.14815, 0.74131, 0.977237, 0.776247, 0.954993, 0.933254, 0.831764, 1.65959, 0.630957],
        [0.831764, 0.57544, 0.524807, 0.416869, 0.776247, 0.645654, 0.537032, 0.354813, 0.60256, 2.51189,
         1.90546, 0.616595, 1.77828, 1.25893, 0.549541, 0.660693, 0.870964, 0.660693, 0.851138, 2.04174],
        [0.758578, 0.60256, 0.501187, 0.398107, 0.707946, 0.691831, 0.524807, 0.363078, 0.645654, 1.90546,
         2.51189, 0.616595, 1.90546, 1.58489, 0.588844, 0.616595, 0.74131, 0.851138, 1.0, 1.51356],
        [0.912011, 1.86209, 1.20226, 1.12202, 0.524807, 1.41254, 1.31826, 0.776247, 1.14815, 0.616595,
         0.616595, 2.0893, 0.724436, 0.467735, 0.870964, 1.02329, 1.02329, 0.446684, 0.616595, 0.676083],
        [0.851138, 0.676083, 0.60256, 0.501187, 0.812831, 0.794328, 0.630957, 0.446684, 0.74131, 1.77828,
         1.90546, 0.724436, 2.69153, 1.44544, 0.57544, 0.724436, 0.870964, 0.794328, 0.954993, 1.44544],
        [0.588844, 0.47863, 0.489779, 0.354813, 0.831764, 0.549541, 0.40738, 0.301995, 0.977237, 1.25893,
         1.58489, 0.467735, 1.44544, 5.01187, 0.416869, 0.524807, 0.60256, 2.29087, 3.23594, 1.02329],
        [1.07152, 0.812831, 0.812831, 0.851138, 0.489779, 0.954993, 0.891251, 0.691831, 0.776247, 0.549541,
         0.588844, 0.870964, 0.57544, 0.416869, 5.7544, 1.09648, 1.02329, 0.316228, 0.489779, 0.660693],
        [1.28825, 0.954993, 1.23027, 1.12202, 1.02329, 1.04713, 1.04713, 1.09648, 0.954993, 0.660693,
         0.616595, 1.02329, 0.724436, 0.524807, 1.09648, 1.65959, 1.41254, 0.467735, 0.645654, 0.794328],
        [1.14815, 0.954993, 1.12202, 1.0, 0.891251, 1.0, 0.977237, 0.776247, 0.933254, 0.870964,
         0.74131, 1.02329, 0.870964, 0.60256, 1.02329, 1.41254, 1.77828, 0.446684, 0.645654, 1.0],
        [0.436516, 0.691831, 0.436516, 0.301995, 0.794328, 0.537032, 0.371535, 0.398107, 0.831764, 0.660693,
         0.851138, 0.446684, 0.794328, 2.29087, 0.316228, 0.467735, 0.446684, 26.3027, 2.5704, 0.549541],
        [0.60256, 0.660693, 0.724436, 0.524807, 0.891251, 0.676083, 0.537032, 0.398107, 1.65959, 0.851138,
         1.0, 0.616595, 0.954993, 3.23594, 0.489779, 0.645654, 0.645654, 2.5704, 6.0256, 0.776247],
        [1.02329, 0.630957, 0.60256, 0.512861, 1.0, 0.707946, 0.645654, 0.467735, 0.630957, 2.04174,
         1.51356, 0.676083, 1.44544, 1.02329, 0.660693, 0.794328, 1.0, 0.549541, 0.776247, 2.18776],
    ],
    dtype=np.float32,
)
|
||||
|
||||
M_M, M_I, M_D, I_M, I_I, D_M, D_D, _NEFF, I_NEFF, D_NEFF = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9
|
||||
|
||||
## HMMNull is the background score of amino acids, in the alphabetical order of 1-letter code
|
||||
HMMNull = [3706,5728,4211,4064,4839,3729,4763,4308,4069,3323,5509,4640,4464,4937,4285,4423,3815,3783,6325,4665,0]
|
||||
HMMNull = np.array(HMMNull, dtype=np.float32)
|
||||
|
||||
|
||||
## this function reads the HMM block from a profileHMM file (generated by HHpred/HHblits package)
|
||||
## for the profileHMM file, the header has 4 lines
|
||||
def ReadHHM(lines, start_position, length, one_protein, numLines4Header=4):
    """Parse the HMM block of an .hhm file and store the profile in ``one_protein``.

    The block starts at ``lines[start_position]`` with ``numLines4Header`` header
    lines, followed by three lines per residue: an emission line, a transition
    line, and a third line that is not read here.

    Keys written to ``one_protein``:
      * 'HMMHeader'  - the raw header lines
      * 'hmm1'       - (length, 20) log2 emission probabilities after pseudo counts
      * 'hmm2'       - (length, 10) smoothed transition probabilities (columns 0-6)
                       and Neff values (columns 7-9)
      * 'hmm1_prob'  - (length, 20) emission probabilities after pseudo counts
      * 'hmm1_score' - 'hmm1' plus the background score HMMNull / 1000
      * 'PSFM', 'PSSM' - aliases of 'hmm1_prob' and 'hmm1_score'
    All 20-column matrices use the alphabetical order of 1-letter codes.

    Args:
        lines: text lines of the .hhm file.
        start_position: index of the first header line of the HMM block.
        length: number of residues (match states) to read.
        one_protein: dict to fill; must already hold 'sequence' and 'name'.
        numLines4Header: number of header lines preceding the residue lines.

    Returns:
        (index of the line following the last residue triple, one_protein).

    Raises:
        ValueError: if an emission line does not have exactly 23 fields.
        SystemExit: (code -1) if the sequence read from the HMM block disagrees
            with ``one_protein['sequence']`` ('X' matches anything).
    """
    i = start_position
    one_protein['HMMHeader'] = lines[i: i + numLines4Header]
    i += numLines4Header

    hmm1 = np.zeros((length, 20), np.float32)
    hmm2 = np.zeros((length, 10), np.float32)
    hmm1_prob = np.zeros((length, 20), np.float32)
    hmm1_score = np.zeros((length, 20), np.float32)
    one_protein['hmm1'] = hmm1
    one_protein['hmm2'] = hmm2
    one_protein['hmm1_prob'] = hmm1_prob
    one_protein['hmm1_score'] = hmm1_score

    # Loop invariants for the pseudo-count step. gonnet is indexed by 3-letter order,
    # the profile by 1-letter order, so permute gonnet once:
    # gonnet_1letter[k, j] == gonnet[order3[k], order3[j]].
    order3 = [AA1LetterOrder23LetterOrder[j] for j in range(20)]
    gonnet_1letter = gonnet[np.ix_(order3, order3)]
    null_prob = np.power(2.0, -HMMNull[:20] / 1000.0)
    null_score = HMMNull[:20] / 1000.0

    # Weights of the fixed priors blended into the transition probabilities.
    rm = ri = rd = 0.1

    residues = []
    for l in range(length):
        emission_line = lines[i + l * 3]
        transition_line = lines[i + l * 3 + 1]

        ## emission scores: residue letter, residue index, 20 scores ordered by 1-letter code,
        ## and one trailing field; '*' is read as 99999
        fields = emission_line.replace("*", "99999").split()
        if len(fields) != 23:
            # Explicit check instead of an assert, which is stripped under "python -O".
            raise ValueError('unexpected emission line in HMM block (expected 23 fields, got %d): %s'
                             % (len(fields), emission_line))

        hmm1[l] = np.array([-int(num) for num in fields[2:-1]]) / 1000.
        aa = fields[0]
        residues.append(aa)

        ## the first 7 columns of the transition line are state transitions;
        ## exp(-x/1000 * 0.6931) is approximately 2 ** (-x/1000)
        trans_fields = transition_line.split()
        trans = hmm2[l]
        trans[0:7] = [np.exp(-int(num.replace("*", "99999")) / 1000.0 * 0.6931) for num in trans_fields[0:7]]

        ## the last 3 columns are the Neff of Match, Insertion and Deletion
        trans[7:10] = [int(num) / 1000.0 for num in trans_fields[7:10]]

        # Blend each transition with a fixed prior, weighting the observed value by its Neff.
        neff_m, neff_i, neff_d = trans[_NEFF], trans[I_NEFF], trans[D_NEFF]
        trans[M_M] = (neff_m * trans[M_M] + rm * 0.6) / (rm + neff_m)
        trans[M_I] = (neff_m * trans[M_I] + rm * 0.2) / (rm + neff_m)
        trans[M_D] = (neff_m * trans[M_D] + rm * 0.2) / (rm + neff_m)

        trans[I_I] = (neff_i * trans[I_I] + ri * 0.75) / (ri + neff_i)
        trans[I_M] = (neff_i * trans[I_M] + ri * 0.25) / (ri + neff_i)

        trans[D_D] = (neff_d * trans[D_D] + rd * 0.75) / (rd + neff_d)
        trans[D_M] = (neff_d * trans[D_M] + rd * 0.25) / (rd + neff_d)

        ## emission probabilities, renormalized so that they sum to 1
        prob = hmm1_prob[l]  # view into the output matrix
        prob[:] = np.power(2.0, hmm1[l])
        wssum = prob.sum()
        if wssum > 0:
            prob /= wssum
        else:
            prob[AALetter2OrderOf1LetterCode[aa]] = 1.

        ## add pseudo count: g[j] = (sum_k prob[k] * gonnet_1letter[k, j]) * null_prob[j]
        ## (one matrix product replaces the former 20 x 20 Python loop; results agree
        ## up to float32 rounding)
        g = prob.dot(gonnet_1letter) * null_prob
        ## sum(g) is very close to 1, here we renormalize g to make its sum exactly 1
        g = g / g.sum()

        ws_tmp_neff = trans[_NEFF] - 1
        hmm1_prob[l] = (ws_tmp_neff * prob + g * 10) / (ws_tmp_neff + 10)

        ## recalculate the emission score after the pseudo count is added
        hmm1[l] = np.log2(hmm1_prob[l])
        hmm1_score[l] = hmm1[l] + null_score

    seqStr = ''.join(residues)

    ## PSFM: position-specific frequency matrix, PSSM: position-specific scoring matrix
    one_protein['PSFM'] = hmm1_prob
    one_protein['PSSM'] = hmm1_score

    if len(seqStr) != len(one_protein['sequence']):
        print ('ERROR: inconsistent sequence length in HMM section and orignal sequence for protein: ', one_protein['name'] )
        sys.exit(-1)

    comparison = [(aa == 'X' or bb == 'X' or aa == bb) for aa, bb in zip(seqStr, one_protein['sequence'])]
    if not all(comparison):
        print ('ERROR: inconsistent sequence between HMM section and orignal sequence for protein: ', one_protein['name'])
        print (' original seq: ', one_protein['sequence'])
        print (' HMM seq: ', seqStr)
        sys.exit(-1)

    return i + 3 * length, one_protein
|
||||
|
||||
## this function reads a profile HMM file generated by HHpred/HHblits package
|
||||
def load_hmm(hmmfile):
|
||||
with open(hmmfile, 'r') as fh:
|
||||
content = [ r.strip() for r in list(fh) ]
|
||||
if not bool(content):
|
||||
print ('ERROR: empty profileHMM file: ', hmmfile)
|
||||
exit(1)
|
||||
if not content[0].startswith('HHsearch'):
|
||||
print ('ERROR: this file may not be a profileHMM file generated by HHpred/HHblits: ', hmmfile)
|
||||
exit(1)
|
||||
if len(content) < 10:
|
||||
print ('ERROR: this profileHMM file is too short: ', hmmfile)
|
||||
exit(1)
|
||||
|
||||
requiredSections = ['name', 'length', 'sequence', 'NEFF', 'hmm1', 'hmm2', 'hmm1_prob', 'hmm1_score', 'PSFM', 'PSSM', 'DateCreated']
|
||||
protein = {}
|
||||
|
||||
## get sequence name
|
||||
if not content[1].startswith('NAME '):
|
||||
print ('ERROR: the protein name shall appear at the second line of profileHMM file: ', hmmfile)
|
||||
exit(1)
|
||||
fields = content[1].split()
|
||||
if len(fields) < 2:
|
||||
print ('ERROR: incorrect name format in profileHMM file: ', hmmfile)
|
||||
exit(1)
|
||||
protein['name'] = fields[1]
|
||||
|
||||
i = 0
|
||||
while i < len(content):
|
||||
row = content[i]
|
||||
if len(row)<1:
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if row.startswith('DATE '):
|
||||
protein['DateCreated'] = row[6:]
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if row.startswith('NEFF '):
|
||||
protein['NEFF'] = np.float32(row.split()[1])
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if row.startswith('LENG '):
|
||||
protein['length'] = np.int32(row.split()[1])
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if row.startswith('>ss_pred'):
|
||||
## read the predicted secondary structure
|
||||
start = i+1
|
||||
end = i+1
|
||||
while not content[end].startswith('>'):
|
||||
end += 1
|
||||
protein['SSEseq'] = ''.join(content[start:end]).replace('C', 'L')
|
||||
if len(protein['SSEseq']) != protein['length']:
|
||||
print ('ERROR: inconsistent sequence length and predicted SS sequence length in hmmfile: ', hmmfile)
|
||||
exit(1)
|
||||
i = end
|
||||
continue
|
||||
|
||||
if row.startswith('>ss_conf'):
|
||||
## read the predicted secondary structure confidence score
|
||||
start = i+1
|
||||
end = i+1
|
||||
while not content[end].startswith('>'):
|
||||
end += 1
|
||||
|
||||
SSEconfStr = ''.join(content[start:end])
|
||||
protein['SSEconf'] = [ np.int16(score) for score in SSEconfStr ]
|
||||
|
||||
if len(protein['SSEconf']) != protein['length']:
|
||||
print ('ERROR: inconsistent sequence length and predicted SS confidence sequence length in hmmfile: ', hmmfile)
|
||||
exit(1)
|
||||
|
||||
i = end
|
||||
continue
|
||||
|
||||
|
||||
if row.startswith('>' + protein['name']):
|
||||
## read in the sequence in the following lines
|
||||
start = i+1
|
||||
end = i+1
|
||||
while (not content[end].startswith('>')) and (not content[end].startswith('#')):
|
||||
end += 1
|
||||
|
||||
## at this point, content[end] shall start with >
|
||||
protein['sequence'] = ''.join(content[start:end])
|
||||
if len(protein['sequence']) != protein['length']:
|
||||
print ('ERROR: inconsistent sequence length in hmmfile: ', hmmfile)
|
||||
exit(1)
|
||||
i = end
|
||||
continue
|
||||
|
||||
if len(row) == 1 and row[0]=='#' and content[i+1].startswith('NULL') and content[i+2].startswith('HMM'):
|
||||
i, protein = ReadHHM(content, i+1, protein['length'], protein, numLines4Header=4)
|
||||
continue
|
||||
|
||||
i += 1
|
||||
|
||||
|
||||
## double check to see some required sections are read in
|
||||
for section in requiredSections:
|
||||
#if not protein.has_key(section):
|
||||
if section not in protein.keys():
|
||||
print ('ERROR: one section for ', section, ' is missing in the hmm file: ', hmmfile)
|
||||
print ('ERROR: it is also possible that the hmm file has a format incompatible with this script.')
|
||||
exit(1)
|
||||
|
||||
protein['requiredSections'] = requiredSections
|
||||
|
||||
return protein
|
||||
|
||||
|
||||
|
||||
## for test only
|
||||
if __name__ == "__main__":
    # Command-line self test: parse one .hhm file and pickle the resulting dict
    # to "<basename>.pkl" in the current working directory.
    if len(sys.argv) < 2:
        print ('python LoadHHM.py hhm_file')
        print ('the input file shall end with .hhm')
        sys.exit(1)

    hhm_path = sys.argv[1]

    if not hhm_path.endswith('.hhm'):
        # The message previously named the wrong suffix (.hmm) for this check.
        print ('ERROR: the input file shall have suffix .hhm')
        sys.exit(1)

    protein = load_hmm(hhm_path)

    savefile = os.path.basename(hhm_path) + '.pkl'
    # The context manager closes the file even if pickling fails.
    with open(savefile, 'wb') as fh:
        pkl.dump(protein, fh, protocol=pkl.HIGHEST_PROTOCOL)
|
||||
325
pplm_contact/Module.py
Normal file
325
pplm_contact/Module.py
Normal file
@@ -0,0 +1,325 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class ResNet(nn.Module):
    """Five residual blocks of dilated 3x3 convolutions (dilations 1, 2, 4, 2, 1).

    Every block is Conv -> InstanceNorm -> LeakyReLU -> Dropout -> Conv ->
    InstanceNorm -> LeakyReLU, with padding equal to the dilation so the
    spatial size is preserved and the block input can be added back.
    """

    def __init__(self, channels, dropout=0.0):
        super().__init__()
        self.channels = channels
        self.n_layers = 5
        self.dilations = [1, 2, 4, 2, 1]
        self.dropout = dropout

        self.blocks = nn.ModuleList()
        for index in range(self.n_layers):
            self.blocks.append(self._make_block(channels, self.dilations[index], dropout))

    @staticmethod
    def _make_block(channels, rate, dropout):
        """Build one residual branch with dilation (and padding) ``rate``."""

        def conv():
            return nn.Conv2d(channels, channels, kernel_size=3, dilation=rate, padding=rate)

        def norm():
            return nn.InstanceNorm2d(channels, affine=True, track_running_stats=True)

        def act():
            return nn.LeakyReLU(negative_slope=0.01, inplace=True)

        return nn.Sequential(conv(), norm(), act(), nn.Dropout(dropout), conv(), norm(), act())

    def forward(self, x):
        """Apply every block with a skip connection; output has the shape of ``x``."""
        for branch in self.blocks:
            x = branch(x) + x
        return x
|
||||
|
||||
|
||||
class Intra_ResNet(nn.Module):
    """Fuse per-residue (1D) and pairwise (2D) features into one pair map.

    The 1D features are tiled along rows and columns into a pairwise tensor,
    projected to ``channels`` by a 1x1 convolution, concatenated with the
    projected 2D features, projected again and refined by ``ResNet``.
    """

    def __init__(self, dim_1d=768+20, dim_2d=144+2+64, channels=64, dropout=0):
        super().__init__()
        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        # Input normalisation (default InstanceNorm settings, i.e. no learnable parameters).
        self.pre_norm_1d = nn.InstanceNorm1d(self.dim_1d)
        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)

        self.pair_conv1 = self._pointwise(self.dim_1d * 2, self.channels)   # tiled 1D features
        self.pair_conv2 = self._pointwise(self.dim_2d, self.channels)       # native 2D features
        self.pair_conv3 = self._pointwise(self.channels * 2, self.channels) # fusion of both

        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    @staticmethod
    def _pointwise(in_channels, out_channels):
        """1x1 convolution (no bias) followed by InstanceNorm and LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        )

    def forward(self, x_1d, x_2d):
        """Return the fused pair representation.

        ``x_1d`` is unpacked as (batch, dim, L). ``x_2d`` must be
        concatenable with a (batch, channels, L, L) map after projection,
        i.e. presumably (batch, dim_2d, L, L) - TODO confirm with the caller.
        """
        x_1d = self.pre_norm_1d(x_1d)
        x_2d = self.pre_norm_2d(x_2d)

        _, _, L = x_1d.size()
        # Tile the sequence features along both pair axes: (b, d, L) -> (b, d, L, L).
        by_row = x_1d.unsqueeze(-1).expand(-1, -1, L, L)
        by_col = x_1d.unsqueeze(-2).expand(-1, -1, L, L)

        from_1d = self.pair_conv1(torch.cat([by_row, by_col], dim=1))
        from_2d = self.pair_conv2(x_2d)

        fused = self.pair_conv3(torch.cat([from_1d, from_2d], dim=1))
        return self.resnet(fused)
|
||||
|
||||
|
||||
class Inter_ResNet(nn.Module):
    """Project inter-chain pairwise features to ``channels`` and refine them with ``ResNet``.

    ``dim_1d`` is only stored; no layer of this module uses it.
    """

    def __init__(self, dim_1d=768+20, dim_2d=768 + 20, channels=64, dropout=0):
        super().__init__()
        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        # Input normalisation (default InstanceNorm2d settings, no learnable parameters).
        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)

        self.pair_conv2 = self._pointwise(self.dim_2d, self.channels)
        self.pair_conv3 = self._pointwise(self.channels, self.channels)

        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    @staticmethod
    def _pointwise(in_channels, out_channels):
        """1x1 convolution (no bias) followed by InstanceNorm and LeakyReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(out_channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True),
        )

    def forward(self, inter_2d):
        """Map a channel-first pair tensor with ``dim_2d`` channels to ``channels`` channels."""
        normed = self.pre_norm_2d(inter_2d)
        projected = self.pair_conv3(self.pair_conv2(normed))
        return self.resnet(projected)
|
||||
|
||||
|
||||
class InterTriangleMultiplication_S(nn.Module):
    """Gated multiplicative update of an inter-chain pair map using an intra-chain map.

    Both inputs are channel-last. They are layer-normalised with the same
    ``norm_init``, projected through sigmoid-gated linear layers and combined as
    ``out[b, i, j, c] = sum_k intra[b, i, k, c] * inter[b, k, j, c]``.
    """

    def __init__(self, channel_z=64, channel_c=64, transpose=False):
        super().__init__()
        self.dz = channel_z
        self.dc = channel_c
        self.transpose = transpose

        # Shared input normalisation.
        self.norm_init = nn.LayerNorm(self.dz)

        # Projection and gate for the inter-chain and intra-chain maps.
        self.Linear_inter = nn.Linear(self.dz, self.dc)
        self.gate_inter = nn.Linear(self.dz, self.dc)

        self.Linear_intra = nn.Linear(self.dz, self.dc)
        self.gate_intra = nn.Linear(self.dz, self.dc)

        # Output normalisation, projection back to dz and output gate.
        self.norm_final = nn.LayerNorm(self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)
        self.gate_final = nn.Linear(self.dz, self.dz)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the update for ``z_inter`` (same shape as ``z_inter``).

        When ``self.transpose`` or ``transpose`` is true, axes 1 and 2 of both
        inputs are swapped before the update and the result is swapped back.
        """
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)

        inter_normed = self.norm_init(z_inter)
        intra_normed = self.norm_init(z_intra)

        inter_proj = self.Linear_inter(inter_normed) * self.gate_inter(inter_normed).sigmoid()
        intra_proj = self.Linear_intra(intra_normed) * self.gate_intra(intra_normed).sigmoid()

        update = torch.einsum('bikc,bkjc->bijc', intra_proj, inter_proj)
        update = self.norm_final(update)

        # The output gate is computed from the normalised (un-projected) inter map.
        result = self.Linear_final(update) * self.gate_final(inter_normed).sigmoid()

        if swap_axes:
            result = result.permute(0, 2, 1, 3)
        return result
|
||||
|
||||
|
||||
class InterCrossAttention_S(nn.Module):
    """Multi-head cross attention: inter-chain pair map (queries) over an intra-chain map (keys/values).

    Inputs are channel-last and share their first two axes (b, l). For each
    (b, l) the positions j of ``z_inter`` attend over the positions i of
    ``z_intra``. The output is sigmoid-gated and projected back to ``channel_z``.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=True, transpose=False):
        super().__init__()
        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)

        self.softmax = nn.Softmax(dim=-1)  # kept as an attribute; forward() calls F.softmax

        self.gate_final = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis (num_head * dim_head) into (num_head, dim_head)."""
        return x.view(*x.size()[:-1], self.num_head, self.dhc)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the attention update for ``z_inter`` (same shape as ``z_inter``)."""
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)

        # The unpacking also rejects inputs that are not rank-4.
        _batch, _rows, _cols, _ = z_inter.shape
        z_inter = self.norm_init(z_inter)
        z_intra = self.norm_init(z_intra)

        q = self.reshape_dim(self.Linear_Q(z_inter))
        k = self.reshape_dim(self.Linear_K(z_intra))
        v = self.reshape_dim(self.Linear_V(z_intra))

        scalar = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=q.dtype, device=q.device))
        logits = torch.einsum('b l j h c, b l i h c -> b l j h i', q * scalar, k)

        if self.bias:
            # NOTE(review): the bias is constant along the key axis i, which is the
            # softmax axis - confirm that this is the intended form of the bias.
            logits = logits + self.Linear_bias(z_inter).unsqueeze(-1)

        if logits.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = torch.nn.functional.softmax(logits, -1)
        else:
            attn_weights = torch.nn.functional.softmax(logits, -1)

        attended = torch.einsum('b l j h i, b l i h c -> b l j h c', attn_weights, v)

        gate = self.reshape_dim(self.gate_final(z_inter)).sigmoid()
        gated = (attended * gate).contiguous()
        # Merge (head, dim_head) back into one channel axis before the output projection.
        result = self.Linear_final(gated.view(gate.size()[:-2] + (-1,)))

        if swap_axes:
            result = result.permute(0, 2, 1, 3)
        return result
|
||||
|
||||
|
||||
class InterSelfAttention_S(nn.Module):
    """Multi-head self attention along the last pair axis of an inter-chain pair map.

    The input is channel-last (b, l, i, dz). For every (b, l) the positions i
    attend over the positions j of the same row. The result is sigmoid-gated
    and projected back to ``channel_z``.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=False, transpose=False):
        super(InterSelfAttention_S, self).__init__()

        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)

        self.softmax = nn.Softmax(-1)  # kept as an attribute; forward() calls F.softmax
        self.gate_v = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis (num_head * dim_head) into (num_head, dim_head)."""
        new_shape = x.size()[:-1] + (self.num_head, self.dhc)
        return x.view(*new_shape)

    def forward(self, z_inter, dist=None, transpose=False):
        """Return the attention update for ``z_inter`` (same shape as ``z_inter``).

        Args:
            z_inter: channel-last pair tensor of rank 4.
            dist: optional tensor; exp(-(dist/8)^2 / 2) is unsqueezed at axis 3
                and multiplied into the attention logits, so it must broadcast
                against (b, l, i, head, j) after that unsqueeze - e.g. shape
                (b, 1, i, j); TODO confirm the shape produced by the caller.
            transpose: swap axes 1 and 2 before attending and swap back after
                (also enabled by ``self.transpose``).
        """
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)

        z_inter = self.norm_init(z_inter)

        q = self.reshape_dim(self.Linear_Q(z_inter))
        k = self.reshape_dim(self.Linear_K(z_inter))
        v = self.reshape_dim(self.Linear_V(z_inter))

        scalar = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=q.dtype, device=q.device))

        attn = torch.einsum('b l i h c, b l j h c -> b l i h j', q * scalar, k)

        if self.bias:
            # Linear_bias yields (b, l, i, head); a trailing axis is required to
            # broadcast over the key axis j, as InterCrossAttention_S does.
            # Without it the addition fails with a shape mismatch.
            bias = self.Linear_bias(z_inter).unsqueeze(-1)
            attn = attn + bias

        if dist is not None:
            coef = torch.exp(-(dist / 8.0) ** 2.0 / 2.0).unsqueeze(3).type_as(q)
            attn = attn * coef

        if attn.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = torch.nn.functional.softmax(attn, -1)
        else:
            attn_weights = torch.nn.functional.softmax(attn, -1)

        v_avg = torch.einsum('b l i h j, b l j h c -> b l i h c', attn_weights, v)

        gate_v = self.reshape_dim(self.gate_v(z_inter)).sigmoid()
        # Merge (head, dim_head) back into one channel axis before the output projection.
        z_com = (v_avg * gate_v).contiguous().view(v_avg.size()[:-2] + (-1,))

        z_final = self.Linear_final(z_com)

        if swap_axes:
            z_final = z_final.permute(0, 2, 1, 3)

        return z_final
|
||||
|
||||
|
||||
class Transition(nn.Module):
    """LayerNorm followed by an expand-and-contract pair of linear layers.

    The hidden width is ``channel_z * transition_n``.
    NOTE(review): there is no activation between the two linear layers -
    confirm that this is intentional.
    """

    def __init__(self, channel_z=64, transition_n=4):
        super().__init__()
        self.dz = channel_z
        self.n = transition_n

        hidden = self.dz * self.n
        self.norm = nn.LayerNorm(self.dz)
        expand = nn.Linear(self.dz, hidden)
        contract = nn.Linear(hidden, self.dz)
        self.transition = nn.Sequential(expand, contract)

    def forward(self, z_com):
        """Return the transformed tensor; the last axis must have size ``channel_z``."""
        return self.transition(self.norm(z_com))
|
||||
|
||||
16
pplm_contact/config.py
Normal file
16
pplm_contact/config.py
Normal file
@@ -0,0 +1,16 @@
|
||||
import os
|
||||
|
||||
# Locations of external tools and databases. Every value can be overridden with the
# environment variable named next to it; the defaults are the original hard-coded paths.

# Root of the HH-suite build tree (PPLM_HHSUITE_DIR).
hhsuite_dir = os.environ.get("PPLM_HHSUITE_DIR", "/mnt/rna01/junl/tools/hh-suite/build")

hhblits = os.path.join(hhsuite_dir, "bin/hhblits")
hhmake = os.path.join(hhsuite_dir, "bin/hhmake")
reformat = os.path.join(hhsuite_dir, "scripts/reformat.pl")
hhfilter = os.path.join(hhsuite_dir, "bin/hhfilter")

# UniRef30 database prefix (PPLM_UNIREF_DATABASE).
UniRef_database = os.environ.get(
    "PPLM_UNIREF_DATABASE",
    "/mnt/dna01/library2/e2e_folding/alphafold/v2.3/lib/uniref30/UniRef30_2021_03",
)

# CCMpred path (PPLM_CCMPRED).
ccmpred = os.environ.get(
    "PPLM_CCMPRED",
    "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/ccmpred",
)

# ESM-MSA feature extraction script (PPLM_ESM_MSA) and model weights (PPLM_ESM_MSA_MODEL).
esm_msa = os.environ.get(
    "PPLM_ESM_MSA",
    "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/extract_esm_msa_features.py",
)
esm_msa_model = os.environ.get(
    "PPLM_ESM_MSA_MODEL",
    "/mnt/rna01/junl/tools/esm-main/esm_models/esm_msa1_t12_100M_UR50S.pt",
)
|
||||
|
||||
|
||||
72
pplm_contact/model.py
Normal file
72
pplm_contact/model.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
mian_path = os.path.dirname(__file__) + "/../"
|
||||
sys.path.append(os.path.abspath(mian_path))
|
||||
|
||||
from pplm_contact.Module import *
|
||||
|
||||
class PPLM_Contact(nn.Module):
    """Inter-chain contact predictor: ResNet encoders followed by a stack of pair-update blocks.

    Each of the ``num_blocks`` blocks applies, with residual connections,
    triangle multiplication, cross attention (both driven by the two
    intra-chain maps), self attention (optionally modulated by intra-chain
    distances) and a transition layer to the inter-chain pair map.
    """

    def __init__(self,
                 intra_1d_dim=768+20,
                 intra_2d_dim=144+2+64,
                 inter_2d_dim=144+2+660,
                 channels=64,
                 num_blocks=12,
                 droupout=0.10,
                 ):
        super().__init__()
        self.intra_1d_dim = intra_1d_dim
        self.intra_2d_dim = intra_2d_dim
        self.inter_2d_dim = inter_2d_dim
        self.channels = channels
        self.num_blocks = num_blocks
        self.droupout = droupout

        ### ResNet encoders (one intra-chain encoder shared by both chains)
        self.intra_resnet = Intra_ResNet(dim_1d=self.intra_1d_dim, dim_2d=self.intra_2d_dim, channels=self.channels)
        self.inter_resnet = Inter_ResNet(dim_2d=self.inter_2d_dim, channels=self.channels)

        ### Transformer: one module per block for each operation; the *_L variants work on the transposed map
        def stack(factory):
            return nn.ModuleList([factory() for _ in range(num_blocks)])

        width = self.channels
        self.InterTriangleMulti_R = stack(lambda: InterTriangleMultiplication_S(channel_z=width, channel_c=width, transpose=False))
        self.InterTriangleMulti_L = stack(lambda: InterTriangleMultiplication_S(channel_z=width, channel_c=width, transpose=True))
        self.InterCrossAttn_R = stack(lambda: InterCrossAttention_S(channel_z=width, bias=True, transpose=False))
        self.InterCrossAttn_L = stack(lambda: InterCrossAttention_S(channel_z=width, bias=True, transpose=True))
        self.InterSelfAttn_R = stack(lambda: InterSelfAttention_S(channel_z=width, bias=False, transpose=False))
        self.InterSelfAttn_L = stack(lambda: InterSelfAttention_S(channel_z=width, bias=False, transpose=True))
        self.Transition = stack(lambda: Transition(channel_z=width))
        self.drop = nn.Dropout(self.droupout)

        self.norm_final = nn.LayerNorm(self.channels)
        self.Linear_final = nn.Linear(self.channels, 1)
        self.act_final = nn.Sigmoid()

    def forward(self, intra1_1d, intra1_2d, intra2_1d, intra2_2d, inter_2d, intra1_dist=None, intra2_dist=None, last_layer=False):
        """Predict the inter-chain contact map of the first batch element.

        The encoders take channel-first tensors; their outputs are permuted to
        channel-last for the pair-update blocks.

        Returns:
            The sigmoid contact map of batch element 0, with a leading axis of
            size 1 (the single output channel). When ``last_layer`` is true, a
            tuple of that map and the final layer-normalised pair
            representation permuted back to channel-first.
        """
        # Encode, then move channels last: (b, c, h, w) -> (b, h, w, c).
        intra1_2d = self.intra_resnet(intra1_1d, intra1_2d).permute(0, 2, 3, 1)
        intra2_2d = self.intra_resnet(intra2_1d, intra2_2d).permute(0, 2, 3, 1)
        inter_2d = self.inter_resnet(inter_2d).permute(0, 2, 3, 1)

        for idx in range(self.num_blocks):
            tri_r = self.drop(self.InterTriangleMulti_R[idx](inter_2d, intra1_2d))
            tri_l = self.drop(self.InterTriangleMulti_L[idx](inter_2d, intra2_2d, transpose=True))
            inter_2d = inter_2d + tri_r + tri_l

            cross_r = self.drop(self.InterCrossAttn_R[idx](inter_2d, intra1_2d))
            cross_l = self.drop(self.InterCrossAttn_L[idx](inter_2d, intra2_2d, transpose=True))
            inter_2d = inter_2d + cross_r + cross_l

            self_r = self.drop(self.InterSelfAttn_R[idx](inter_2d, intra2_dist))
            self_l = self.drop(self.InterSelfAttn_L[idx](inter_2d, intra1_dist, transpose=True))
            inter_2d = inter_2d + self_r + self_l

            # The transition update is added without dropout.
            inter_2d = inter_2d + self.Transition[idx](inter_2d)

        inter_2d = self.norm_final(inter_2d)
        inter_contact = self.act_final(self.Linear_final(inter_2d)).permute(0, 3, 1, 2)[0]

        if not last_layer:
            return inter_contact

        representation = inter_2d.permute(0, 3, 1, 2)
        return inter_contact, representation
|
||||
353
pplm_contact/predict.py
Normal file
353
pplm_contact/predict.py
Normal file
@@ -0,0 +1,353 @@
|
||||
import os
|
||||
import sys
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
import subprocess
|
||||
import argparse
|
||||
from model import PPLM_Contact
|
||||
from utils import extract_seq_and_dist_map, pairing_msa, RBF
|
||||
from LoadHHM import load_hmm
|
||||
from config import *
|
||||
|
||||
|
||||
def _clean_pdb(raw_pdb, clean_pdb):
    """Copy the ATOM records of *raw_pdb* to *clean_pdb*, renaming a few non-standard residues."""
    import shlex

    # Paths are shell-quoted so that spaces or shell metacharacters in
    # user-supplied locations cannot break (or hijack) the pipeline.
    # NOTE(review): TYS (sulfated tyrosine) is rewritten to TRP here; TYR
    # looks like the intended target -- confirm before changing, since the
    # released models may have been trained with this mapping.
    subprocess.run(
        "grep \"^ATOM\" " + shlex.quote(str(raw_pdb))
        + " | sed 's/MEX/CYS/g; s/HID/HIS/g; s/HIE/HIS/g; s/HIP/HIS/g; s/MSE/MET/g; s/ASX/ASN/g; s/GLX/GLN/g; s/TYS/TRP/g' > "
        + shlex.quote(str(clean_pdb)),
        shell=True, check=True)


def _search_msa(seq_path, msa_path, n_cpu):
    """Run hhblits on *seq_path* and write the alignment to *msa_path*."""
    import shlex

    print("Searching MSA for", seq_path)
    subprocess.run(
        hhblits + " -i " + shlex.quote(str(seq_path)) + " -d " + UniRef_database
        + " -cpu " + str(n_cpu) + " -oa3m " + shlex.quote(str(msa_path))
        + " -n 3 -e 0.001 -id 99 -cov 0.4",
        shell=True, check=True)


def _write_contact_table(txt_path, contact, res1_idx_type, res2_idx_type):
    """Write every residue pair, ranked by decreasing contact probability, as a fixed-width table."""
    # A stable sort of the negated, row-major flattened map reproduces the
    # order of sorted(..., reverse=True) over pairs enumerated i-major.
    order = np.argsort(-contact, axis=None, kind="stable")
    rows, cols = np.unravel_index(order, contact.shape)

    with open(txt_path, "w") as fw:
        fw.write("{:<10}".format("Rank") + "{:<10}".format("ResIdx1") + "{:<10}".format("ResType1")
                 + "{:<10}".format("ResIdx2") + "{:<10}".format("ResType2")
                 + "{:<20}".format("Contact_Probability") + "\n")
        for rank, (res1, res2) in enumerate(zip(rows, cols), start=1):
            prob = contact[res1, res2]
            res1_idx, res1_type = res1_idx_type[res1][0], res1_idx_type[res1][1]
            res2_idx, res2_type = res2_idx_type[res2][0], res2_idx_type[res2][1]
            fw.write("{:<10}".format(rank) + "{:<10}".format(str(res1_idx) + ":A") + "{:<10}".format(res1_type)
                     + "{:<10}".format(str(res2_idx) + ":B") + "{:<10}".format(res2_type)
                     + "{:<10}".format(f"{prob:.6g}") + "\n")


def main():
    """Command-line entry point: run the full inter-protein contact prediction pipeline.

    Steps: clean the input PDB files and derive sequences / monomer distance
    maps, search MSAs with hhblits, extract monomer and paired MSA features,
    compute PPLM attention features, assemble all features and finally run
    the contact model. Results are written to the output folder as a pickle
    and as a ranked text table. Intermediate files that already exist are
    reused where the corresponding step checks for them.
    """
    parser = argparse.ArgumentParser(description="Protein-Protein Contact Prediction",
                                     epilog="v0.0.1")

    parser.add_argument("pdbA_path",
                        type=pathlib.Path,
                        help="Location of pdb A")

    parser.add_argument("pdbB_path",
                        type=pathlib.Path,
                        help="Location of pdb B")

    parser.add_argument("output_folder",
                        type=pathlib.Path,
                        help="Location to store output files")

    parser.add_argument("--n_cpu",
                        type=int,
                        default=8,
                        help="Number of CPU cores for search MSA",
                        )

    parser.add_argument("--gpu_id",
                        type=int,
                        default=0,
                        help="gpu device specified",
                        )

    args = parser.parse_args()

    ### Define parameters ###
    # Populates the module-level path / mode / device globals used below.
    define_param(args)

    ##### Step 0: Process the input pdb (clean pdb & extract sequence & derive monomer distance map) #####
    print("========== Step 0: Processing the intput pdb (Clean pdb, extract seq and dist map) ==========")

    _clean_pdb(args.pdbA_path, target1_pdb)
    target1_res_idx_type = extract_seq_and_dist_map(target1_pdb, target1_seq, target1_monomer_dist)
    if mode == "hetero":
        _clean_pdb(args.pdbB_path, target2_pdb)
        target2_res_idx_type = extract_seq_and_dist_map(target2_pdb, target2_seq, target2_monomer_dist)
    else:
        # Homodimer: chain B is the same structure as chain A.
        target2_res_idx_type = target1_res_idx_type

    ##### Step 1: Search MSA #####
    print("========== Step 1: Searching MSA ==========")
    if not os.path.isfile(target1_msa):
        _search_msa(target1_seq, target1_msa, args.n_cpu)
    if mode == "hetero" and not os.path.isfile(target2_msa):
        _search_msa(target2_seq, target2_msa, args.n_cpu)

    ##### Step 2: Extract Monomer MSA features
    print("========== Step 2: Extract Monomer MSA features ==========")
    if os.path.isfile(target1_msa):
        extract_MSA_features(target1, target1_msa, target1_hhm, target1_aln, target1_dca_di, target1_dca_apc, target1_esm_msa)
    if mode == "hetero" and os.path.isfile(target2_msa):
        extract_MSA_features(target2, target2_msa, target2_hhm, target2_aln, target2_dca_di, target2_dca_apc, target2_esm_msa)

    ##### Step 3: Extract paired MSA features
    print("========== Step 3: Extract paired MSA features ==========")
    if mode == "hetero":
        pairing_msa(target1_msa, target2_msa, paired_msa)
        extract_MSA_features(target, paired_msa, paired_hhm, paired_aln, paired_dca_di, paired_dca_apc, paired_esm_msa)

    ##### Step 4: Generate PPLM inter-protein attention matrix
    print("========== Step 4: Generate PPLM features ==========")
    if not os.path.isfile(pplm_feat):
        get_pplm_features(target1_seq, target2_seq, pplm_feat, device='cpu')

    ##### Step 5: Collect all features
    print("========== Step 5: Collect all features ==========")
    feats = collect_all_features()

    ##### Step 6: Predict inter-protein contact
    print("========== Step 6: Predict inter-protein contact ==========")
    if not os.path.isfile(pred_contact_pkl_path) or not os.path.isfile(pred_contact_txt_path):
        pred_inter_contatct = predict_contact(feats, mode, device=device)

        with open(pred_contact_pkl_path, "wb") as fw:
            pickle.dump(pred_inter_contatct, fw)

        _write_contact_table(pred_contact_txt_path, pred_inter_contatct, target1_res_idx_type, target2_res_idx_type)
|
||||
|
||||
|
||||
|
||||
def define_param(args):
    """Derive every run-wide setting from the parsed CLI arguments.

    Sets module-level globals: the target names, the workspace directory
    (created if missing), the run mode (``"homo"`` or ``"hetero"``), the torch
    device and the paths of all intermediate / output files inside the
    workspace.

    Args:
        args: namespace with ``pdbA_path`` and ``pdbB_path`` (``pathlib.Path``),
            ``output_folder`` (path-like) and ``gpu_id`` (int).
    """
    global target1, target2, target, workspace, mode, device
    global target1_pdb, target1_seq, target1_monomer_dist, target2_pdb, target2_seq
    global target2_monomer_dist, target1_msa, target2_msa, target1_hhm, target1_aln
    global target1_dca_di, target1_dca_apc, target1_esm_msa, target2_hhm, target2_aln
    global target2_dca_di, target2_dca_apc, target2_esm_msa, paired_msa, paired_hhm
    global paired_aln, paired_dca_di, paired_dca_apc, paired_esm_msa, pplm_feat
    global pred_contact_pkl_path, pred_contact_txt_path

    # Compare resolved locations so the same file reached through different
    # spellings (relative vs. absolute, "..", symlinks) is still a homodimer.
    same_input = os.path.realpath(str(args.pdbA_path)) == os.path.realpath(str(args.pdbB_path))
    mode = "homo" if same_input else "hetero"

    target1 = args.pdbA_path.stem
    if mode == "homo":
        # Chain B shares every intermediate file with chain A.
        target2 = target1
    else:
        target2 = args.pdbB_path.stem
        if target2 == target1:
            # Two different files with the same stem would otherwise write to
            # the same intermediate files and silently overwrite each other.
            target2 = target2 + "_B"
            print("Warning: both inputs are named", target1, "- using", target2, "for the second chain.")

    # Last component of the output folder; going through abspath also handles
    # "." / ".." and platform-specific separators.
    target = os.path.basename(os.path.abspath(str(args.output_folder)))  # target1 + '-' + target2
    workspace = args.output_folder
    os.makedirs(workspace, exist_ok=True)

    assigned_device = "cuda:" + str(args.gpu_id)
    device = assigned_device if torch.cuda.is_available() else "cpu"

    print("device:", device, " mode:", mode)

    def in_workspace(file_name):
        return os.path.join(workspace, file_name)

    target1_pdb = in_workspace(target1 + ".clean.pdb")
    target1_seq = in_workspace(target1 + ".fasta")
    target1_monomer_dist = in_workspace(target1 + ".monomer_dist.pkl")
    target2_pdb = in_workspace(target2 + ".clean.pdb")
    target2_seq = in_workspace(target2 + ".fasta")
    target2_monomer_dist = in_workspace(target2 + ".monomer_dist.pkl")
    target1_msa = in_workspace(target1 + ".a3m")
    target2_msa = in_workspace(target2 + ".a3m")
    target1_hhm = in_workspace(target1 + ".hhm")
    target1_aln = in_workspace(target1 + ".aln")
    target1_dca_di = in_workspace(target1 + ".dca_di")
    target1_dca_apc = in_workspace(target1 + ".dca_apc")
    target1_esm_msa = in_workspace(target1 + ".esm_msa.pkl")
    target2_hhm = in_workspace(target2 + ".hhm")
    target2_aln = in_workspace(target2 + ".aln")
    target2_dca_di = in_workspace(target2 + ".dca_di")
    target2_dca_apc = in_workspace(target2 + ".dca_apc")
    target2_esm_msa = in_workspace(target2 + ".esm_msa.pkl")
    paired_msa = in_workspace(target + "_paired.a3m")
    paired_hhm = in_workspace(target + "_paired.hhm")
    paired_aln = in_workspace(target + "_paired.aln")
    paired_dca_di = in_workspace(target + "_paired.dca_di")
    paired_dca_apc = in_workspace(target + "_paired.dca_apc")
    paired_esm_msa = in_workspace(target + "_paired.esm_msa.pkl")
    pplm_feat = in_workspace(target + ".pplm.pkl")
    pred_contact_pkl_path = in_workspace(target + ".pred_contact.pkl")
    pred_contact_txt_path = in_workspace(target + ".pred_contact.txt")
|
||||
|
||||
def extract_MSA_features(name, msa, hhm, aln, dci_di, dci_apc, esm_msa_path):
    """Derive the HHM profile, DCA maps and ESM-MSA features from one a3m alignment.

    Args:
        name: base name used for the intermediate ``<name>.fas`` file in the workspace.
        msa: path of the input a3m alignment.
        hhm: output path of the hhmake profile.
        aln: output path of the header-less alignment fed to CCMpred.
        dci_di: output path of the CCMpred run with ``-R -A``.
        dci_apc: output path of the CCMpred run with ``-R``.
        esm_msa_path: output path of the ESM-MSA feature pickle.

    Raises:
        subprocess.CalledProcessError: if any external tool exits non-zero.
    """
    import shlex

    # File paths are quoted for the shell; tool names come from config and
    # are passed through unchanged.
    msa_q = shlex.quote(str(msa))
    aln_q = shlex.quote(str(aln))
    fas_q = shlex.quote(os.path.join(workspace, name + ".fas"))

    print("Generating the HHM file for", msa)
    subprocess.run(hhmake + " -i " + msa_q + " -o " + shlex.quote(str(hhm)), shell=True, check=True)
    subprocess.run(reformat + " " + msa_q + " " + fas_q + " -r -l 2000 >/dev/null", shell=True, check=True)
    # Drop the FASTA header lines, leaving one aligned sequence per line.
    subprocess.run("awk '{if(!($0~/^>/)){print}}' " + fas_q + " > " + aln_q, shell=True, check=True)

    # generate the DCA features
    print("Generating the DCA features for", aln)
    if not os.path.isfile(dci_di) or not os.path.isfile(dci_apc):
        subprocess.run(ccmpred + " " + aln_q + " " + shlex.quote(str(dci_di)) + " -R -A", shell=True, check=True)
        subprocess.run(ccmpred + " " + aln_q + " " + shlex.quote(str(dci_apc)) + " -R", shell=True, check=True)

    # generate the ESM-MSA features
    print("Generating the ESM-MSA features for", msa)
    if not os.path.isfile(esm_msa_path):
        filtered_msa = msa + "_filtered"
        subprocess.run(hhfilter + " -i " + msa_q + " -o " + shlex.quote(str(filtered_msa)) + " -diff 512", shell=True, check=True)
        # check=True so a failed ESM run stops here instead of surfacing
        # later as a missing feature file.
        subprocess.run(["python", esm_msa, esm_msa_model, filtered_msa, esm_msa_path], check=True)
|
||||
|
||||
def _read_single_sequence(fasta_path):
    """Return all non-header lines of *fasta_path* joined into one string."""
    with open(fasta_path, "r") as fr:
        return "".join(line.strip() for line in fr if not line.startswith(">"))


def get_pplm_features(seqA_path, seqB_path, out_pkl_path, device='cpu'):
    """Compute PPLM inter-chain attention maps for two sequences and pickle them.

    The two tokenised sequences are concatenated, run through PPLM with a
    mask that is 0 inside each chain and 1 between chains, and the A->B and
    B->A attention blocks are averaged into one ``(maps, len(A), len(B))``
    array written to *out_pkl_path*.

    Args:
        seqA_path: FASTA file of the first chain.
        seqB_path: FASTA file of the second chain.
        out_pkl_path: destination of the pickled numpy array.
        device: torch device the model runs on.
    """
    main_path = os.path.dirname(__file__) + "/../"
    repo_root = os.path.abspath(main_path)
    if repo_root not in sys.path:
        sys.path.append(repo_root)

    from pplm import PPLM, Alphabet
    model_location = os.path.join(main_path, 'pplm/models/', 'pplm_t33_650M.pt')

    ##### Loading PPLM Model #####
    alphabet = Alphabet.from_architecture()
    batch_converter = alphabet.get_batch_converter()
    model_data = torch.load(model_location, map_location="cpu")
    model_param = model_data["param"]
    model_state = model_data["model"]

    model = PPLM(
        num_layers=model_param['encoder_layers'],
        embed_dim=model_param['encoder_embed_dim'],
        attention_heads=model_param['encoder_attention_heads'],
        token_dropout=False,
        alphabet=alphabet
    )
    model.to(device)
    # strict=False: keys missing from / unexpected in the checkpoint are ignored.
    model.load_state_dict(model_state, strict=False)
    # Inference only: put dropout-style layers into evaluation mode.
    model.eval()

    with torch.no_grad():
        seqA = _read_single_sequence(seqA_path)
        seqB = _read_single_sequence(seqB_path)
        len_a, len_b = len(seqA), len(seqB)

        seqA_labels, seqA_strs, seqA_tokens = batch_converter([('seqA', seqA)])
        seqB_labels, seqB_strs, seqB_tokens = batch_converter([('seqB', seqB)])
        tokens = torch.cat([seqA_tokens, seqB_tokens], dim=-1).to(device)

        # Each chain occupies len + 2 token positions (presumably one start
        # and one end token added by the batch converter -- TODO confirm).
        span_a = len_a + 2
        total = span_a + len_b + 2
        inter_chain_mask = torch.ones((total, total), device=device)
        inter_chain_mask[:span_a, :span_a] = 0
        inter_chain_mask[span_a:, span_a:] = 0

        ##### running PPLM #####
        out = model(tokens, inter_chain_mask, repr_layers=[33], need_head_weights=True, return_contacts=False)

        # The two leading axes (33 x 20 for the released model) are flattened
        # into one channel axis; the slices drop the first/last token of each chain.
        attentions = out['attentions'].squeeze()
        attn_AB = attentions[:, :, 1:(len_a + 1), -(len_b + 1):-1].reshape(-1, len_a, len_b).cpu().numpy()
        attn_BA = attentions[:, :, -(len_b + 1):-1, 1:(len_a + 1)].reshape(-1, len_b, len_a).cpu().numpy()
        # Symmetrise: average A->B attention with the transposed B->A attention.
        inter_attn = (attn_AB + attn_BA.transpose(0, 2, 1)) / 2

    with open(out_pkl_path, mode='wb') as fw:
        pickle.dump(inter_attn, fw)
|
||||
|
||||
def _load_monomer_features(dist_path, dca_di_path, dca_apc_path, hhm_path, esm_msa_pkl):
    """Load one chain's features and return ``(feat_1d, feat_2d, dist_map, (dca_di, dca_apc, esm_2d))``."""
    with open(dist_path, "rb") as fr:
        dist_map = pickle.load(fr)

    dca_di = np.expand_dims(np.loadtxt(dca_di_path), 0)
    dca_apc = np.expand_dims(np.loadtxt(dca_apc_path), 0)
    pssm = load_hmm(hhm_path)['PSSM']

    with open(esm_msa_pkl, "rb") as fr:
        esm_msa_data = pickle.load(fr)
    esm_1d = esm_msa_data['esm_msa_1d']
    esm_2d = esm_msa_data['row_attentions']

    # Computed once and reused for both the shape report and the features.
    dist_rbf = RBF(dist_map)
    print(dca_di.shape, dca_apc.shape, esm_2d.shape, dist_rbf.shape, dist_map.shape)

    feat_1d = np.concatenate([pssm, esm_1d], axis=-1).transpose(1, 0)
    feat_2d = np.concatenate([dca_di, dca_apc, esm_2d, dist_rbf], axis=0)
    return feat_1d, feat_2d, dist_map, (dca_di, dca_apc, esm_2d)


def collect_all_features():
    """Assemble the model inputs from the feature files named by the module-level path globals.

    Returns:
        dict with keys ``intra1_1d``, ``intra1_2d``, ``intra1_Mdist``,
        ``intra2_1d``, ``intra2_2d``, ``intra2_Mdist`` and ``inter_2d``.
        In ``"homo"`` mode the second chain reuses the first chain's arrays
        and ``inter_2d`` is built from the monomer DCA / ESM maps; otherwise
        it is built from the inter-chain block of the paired-MSA maps. In
        both modes the PPLM attention maps are appended.
    """
    intra1_1d, intra1_2d, intra1_Mdist, monomer1_maps = _load_monomer_features(
        target1_monomer_dist, target1_dca_di, target1_dca_apc, target1_hhm, target1_esm_msa)

    with open(pplm_feat, "rb") as fr:
        inter_pplm_attn = pickle.load(fr)

    if mode == "homo":
        intra2_1d = intra1_1d
        intra2_2d = intra1_2d
        intra2_Mdist = intra1_Mdist
        inter_2d = np.concatenate([*monomer1_maps, inter_pplm_attn], axis=0)

    else:
        intra2_1d, intra2_2d, intra2_Mdist, _ = _load_monomer_features(
            target2_monomer_dist, target2_dca_di, target2_dca_apc, target2_hhm, target2_esm_msa)

        inter_DCA_DI = np.expand_dims(np.loadtxt(paired_dca_di), 0)
        inter_DCA_APC = np.expand_dims(np.loadtxt(paired_dca_apc), 0)
        with open(paired_esm_msa, "rb") as fr:
            esm_msa_data = pickle.load(fr)
        inter_esm_msa_2d = esm_msa_data['row_attentions']

        # Both chain lengths come from the PPLM map, whose last two axes are
        # (len1, len2). The paired maps cover the concatenated sequence, so
        # their last axis is len1 + len2 and must not be used as len2.
        len1, len2 = inter_pplm_attn.shape[-2], inter_pplm_attn.shape[-1]

        print(inter_DCA_DI.shape, inter_DCA_APC.shape, inter_esm_msa_2d.shape, inter_pplm_attn.shape)
        # Keep only the chain-A-rows x chain-B-columns block of each paired map.
        inter_2d = np.concatenate([inter_DCA_DI[:, :len1, len1:len1+len2],
                                   inter_DCA_APC[:, :len1, len1:len1+len2],
                                   inter_esm_msa_2d[:, :len1, len1:len1+len2],
                                   inter_pplm_attn], axis=0)

    feats = {"intra1_1d": intra1_1d, "intra1_2d": intra1_2d, "intra1_Mdist": intra1_Mdist,
             "intra2_1d": intra2_1d, "intra2_2d": intra2_2d, "intra2_Mdist": intra2_Mdist,
             "inter_2d": inter_2d}

    return feats
|
||||
|
||||
def predict_contact(feats, mode, device="cpu"):
    """Average the contact maps predicted by the five checkpoints of the chosen mode.

    Args:
        feats: dict of numpy feature arrays as produced by ``collect_all_features``.
        mode: ``"homo"`` selects the homo-oligomer checkpoints, anything else
            the hetero ones.
        device: torch device used for the model and the inputs.

    Returns:
        numpy array: the ensemble-mean prediction with singleton axes removed.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    stem = "models/pplm_contact.homo_" if mode == "homo" else "models/pplm_contact.hetero_"
    checkpoint_files = [os.path.join(here, stem + str(number) + ".pkl") for number in range(1, 6)]

    net = PPLM_Contact()
    net.to(device)

    def batched(key):
        # float32 tensor on the target device with a leading batch axis of 1
        return torch.Tensor(feats[key]).to(device).type(torch.float32).unsqueeze(0)

    # Argument order expected by PPLM_Contact.forward.
    inputs = [batched(key) for key in ("intra1_1d", "intra1_2d", "intra2_1d", "intra2_2d",
                                      "inter_2d", "intra1_Mdist", "intra2_Mdist")]

    member_predictions = []
    with torch.no_grad():
        for checkpoint_file in checkpoint_files:
            state = torch.load(checkpoint_file, map_location=device)
            net.load_state_dict(state["model_state_dict"])
            net.eval()

            #################################### Network predict #######################################
            member_predictions.append(net(*inputs))

    stacked = torch.stack(member_predictions)
    return torch.mean(stacked, dim=0).squeeze().cpu().detach().numpy()
|
||||
|
||||
|
||||
|
||||
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
225
pplm_contact/utils.py
Normal file
225
pplm_contact/utils.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import numpy as np
|
||||
import pickle
|
||||
import string
|
||||
|
||||
# Three-letter residue name -> one-letter code for the 20 standard amino acids.
# Written out explicitly: the previous zip() paired the alphabetical names
# (..., GLN, GLU, ...) with the letters "...EQ...", which swapped GLN and GLU.
restype_3to1 = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}

# Atom names that take part in the minimum heavy-atom distance between residues.
heavy_atoms = ['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG']
|
||||
|
||||
def pdb2seq(pdb_path, seq_path):
    """Extract the sequence of the first chain of a PDB file and save it as FASTA.

    The sequence is read from the CA atoms of the ATOM records. If the file
    holds several chains a warning is printed and only the first is used.

    Args:
        pdb_path: input PDB file.
        seq_path: FASTA file to write (header ``>seq_<chain> <length>``).

    Returns:
        str: one-letter sequence of the first chain.

    Raises:
        ValueError: if the file contains no CA ATOM record.
        KeyError: if a residue name is not one of the 20 standard amino acids.
    """
    sequence_list = []      # [[chain_id, sequence], ...] in file order
    letters = []            # residues of the chain currently being read
    last_chain_id = ""
    with open(pdb_path, 'r') as fr:
        for data in fr:
            if not (data.startswith("ATOM") and data[13:16].strip() == 'CA'):
                continue
            chain_id = data[21]
            res_name = data[17:20].strip()

            # A new chain id closes the chain collected so far.
            if chain_id != last_chain_id and letters:
                sequence_list.append([last_chain_id, "".join(letters)])
                letters = []
            letters.append(restype_3to1[res_name])
            last_chain_id = chain_id

    if letters:
        sequence_list.append([last_chain_id, "".join(letters)])

    if not sequence_list:
        raise ValueError("No CA ATOM records found in " + str(pdb_path))

    if len(sequence_list) > 1:
        chain_ids = [entry[0] for entry in sequence_list]
        print("Warning:", pdb_path, "contain multiple chains:", chain_ids, "! Only first chain is considered.")

    chain_id, sequence = sequence_list[0]
    with open(seq_path, 'w') as fw:
        fw.write(">seq_" + chain_id + " " + str(len(sequence)) + "\n")
        fw.write(sequence + "\n")

    return sequence
|
||||
|
||||
def extract_seq_and_dist_map(pdb_path, seq_path, dist_path):
    """Read a PDB file, save its sequence and its minimum heavy-atom distance map.

    Consecutive ATOM records belong to the same residue while chain id,
    residue number and insertion code stay unchanged. For every residue pair
    the smallest distance between any two atoms listed in ``heavy_atoms`` is
    stored; pairs where either residue has no such atom keep ``inf``.

    Args:
        pdb_path: input PDB file (only ATOM records are read).
        seq_path: FASTA file to write (header ``>seq <length>``).
        dist_path: pickle file receiving the ``(1, L, L)`` float distance map.

    Returns:
        list of ``[residue_number, residue_name]``, one entry per residue.

    Raises:
        KeyError: if a residue name is not one of the 20 standard amino acids.
    """
    ### Load pdb_chain ###
    pdb_res_coordis = []    # one {atom_name: [x, y, z]} dict per residue
    res_idx_type = []
    letters = []

    def close_residue(atoms, idx, name):
        pdb_res_coordis.append(atoms)
        letters.append(restype_3to1[name])
        res_idx_type.append([idx, name])

    current_key = None      # (chain id, residue number, insertion code)
    res_atom_coordis = {}
    last_res_idx = None
    last_res_name = ''
    with open(pdb_path, 'r') as fr:
        for data in fr:
            if not data.startswith("ATOM"):
                continue
            atom_name = data[13:16].strip()
            res_name = data[17:20].strip()
            res_idx = int(data[22:26].strip())
            coordi = [float(data[30:38].strip()), float(data[38:46].strip()), float(data[46:54].strip())]

            # Residue boundaries are detected on the full residue identifier,
            # not the number alone: this keeps residues such as 100 / 100A
            # apart and does not treat the legal residue number -1 as a
            # "nothing read yet" marker.
            res_key = (data[21:22], res_idx, data[26:27])
            if current_key is not None and res_key != current_key:
                close_residue(res_atom_coordis, last_res_idx, last_res_name)
                res_atom_coordis = {}
            res_atom_coordis[atom_name] = coordi
            current_key = res_key
            last_res_idx = res_idx
            last_res_name = res_name

    if len(res_atom_coordis) != 0:
        close_residue(res_atom_coordis, last_res_idx, last_res_name)

    sequence = "".join(letters)
    length = len(sequence)

    ############## extract heavy atom distance ##################
    heavy_atom_dist_map = np.full((1, length, length), np.inf)

    # Heavy-atom coordinates per residue, each of shape (n_atoms, 3).
    per_residue = [
        np.array([atoms[name] for name in heavy_atoms if name in atoms], dtype=float).reshape(-1, 3)
        for atoms in pdb_res_coordis
    ]
    max_atoms = max((len(points) for points in per_residue), default=0)
    if max_atoms > 0:
        # Pad every residue to the same atom count with inf: padded slots give
        # an infinite distance and therefore never win the minimum.
        padded = np.full((length, max_atoms, 3), np.inf)
        for i, points in enumerate(per_residue):
            padded[i, :len(points)] = points

        for i, points in enumerate(per_residue):
            if len(points) == 0:
                continue    # no heavy atom: the whole row stays inf
            # (n_i, L, max_atoms, 3): atoms of residue i against all atoms.
            delta = padded[None, :, :, :] - points[:, None, None, :]
            dist = np.sqrt((delta ** 2).sum(axis=-1))
            heavy_atom_dist_map[0, i, :] = dist.min(axis=(0, 2))

    with open(dist_path, mode='wb') as fw:
        pickle.dump(heavy_atom_dist_map, fw)

    with open(seq_path, 'w') as fw:
        fw.write(">seq " + str(length) + "\n")
        fw.write(sequence + "\n")

    return res_idx_type
|
||||
|
||||
def pairing_msa(msa1_path, msa2_path, paired_msa_path):
    """Pair two monomer MSAs by species and write the concatenated alignment.

    The first record (``>target <length>``) is the concatenated query pair;
    the remaining records are numbered ``>seq1``, ``>seq2``, ...
    """
    first_seqs, first_species = extract_taxid(msa1_path)
    second_seqs, second_species = extract_taxid(msa2_path)
    paired = alignment(first_seqs, first_species, second_seqs, second_species, top=True)

    records = [">target " + str(len(paired[0])) + "\n", paired[0] + "\n"]
    for number, row in enumerate(paired[1:], start=1):
        records.append(">seq" + str(number) + "\n")
        records.append(row + "\n")

    with open(paired_msa_path, 'w') as f:
        f.writelines(records)
|
||||
def extract_taxid(file, gap_cutoff=0.8):
    """Read an a3m alignment and return its sequences with their species ids.

    Lowercase letters, ``.`` and ``*`` are removed from every sequence. The
    second line of the file is taken as the query and always gets species
    id 0. A hit is kept when its gap fraction (relative to the query length)
    is at most *gap_cutoff*. Its species id is the integer following
    ``TaxID=`` (or, if that tag is absent, ``OX=``) in the header, and 0 when
    no id can be parsed.

    Args:
        file: path of the a3m file (one header line + one sequence line per record).
        gap_cutoff: maximum allowed gap fraction of a kept sequence.

    Returns:
        (list of str, numpy int array): kept sequences and their species ids.

    Raises:
        ValueError: if the file has no query sequence.
        SystemExit: if headers and sequences do not pair up one-to-one.
    """
    translation = str.maketrans(dict.fromkeys(string.ascii_lowercase + ".*"))

    def species_of(header):
        # Every header yields exactly one id, so ids and sequences stay in step.
        for tag in ("TaxID=", "OX="):
            if tag in header:
                try:
                    return int(header.split(tag)[1].split()[0])
                except (IndexError, ValueError):
                    return 0
        return 0

    with open(file, 'r') as fr:
        lines = fr.readlines()
    if len(lines) < 2:
        raise ValueError("No query sequence found in " + str(file))

    query = lines[1].strip().translate(translation)
    seq_len = len(query)
    if seq_len == 0:
        raise ValueError("Empty query sequence in " + str(file))

    msas = [query]
    sid = [0]
    for line in lines[2:]:
        if line.startswith(">"):
            sid.append(species_of(line))
            continue
        if not line.strip():
            continue    # blank line, not a sequence record

        seq = line.strip().translate(translation)
        gap_fra = float(seq.count('-')) / seq_len
        if gap_fra <= gap_cutoff:
            msas.append(seq)
        else:
            # Too gappy: drop the id recorded for this sequence's header.
            sid.pop(-1)

    if len(msas) != len(sid):
        raise SystemExit("ERROR: len(msas) != len(sid) (" + str(len(msas)) + " vs " + str(len(sid)) + ")")

    return msas, np.array(sid)
|
||||
|
||||
def cal_identity(query, sub_msas):
    """Fraction of query positions at which each aligned sequence matches the query.

    Args:
        query : str
        sub_msas : List[str], each at least as long as ``query``
    Return:
        identity : np.array with one value per sequence in ``sub_msas``
    """
    n_positions = len(query)
    identity = np.zeros((len(sub_msas)))
    for row, aligned in enumerate(sub_msas):
        n_same = sum(1 for pos in range(n_positions) if query[pos] == aligned[pos])
        # numpy float division, as in the count-via-mask formulation
        identity[row] = np.float64(n_same) / n_positions

    return identity
|
||||
|
||||
def alignment(msas1, sid1, msas2, sid2, top=True):
    """Pair sequences of two MSAs that come from the same species.

    Args:
        msas1, msas2: aligned sequences; element 0 of each is the query.
        sid1, sid2: species id of every sequence (0 = unknown / query).
        top: if true, pair only the most query-like sequence of each shared
            species; otherwise pair the i-th best of one side with the i-th
            best of the other, for as many pairs as both sides provide.

    Returns:
        list of str: the concatenated query pair followed by one concatenated
        row per paired hit, in ascending species-id order.
    """
    sid1 = np.asarray(sid1)
    sid2 = np.asarray(sid2)

    # Species present in both MSAs (intersect1d returns them sorted and
    # unique). Id 0 marks "unknown" and is removed by value, not by position,
    # so a real species is never dropped when 0 is not the smallest shared id.
    shared = np.intersect1d(sid1, sid2)
    shared = shared[shared != 0]

    query1 = msas1[0]
    query2 = msas2[0]
    aligns = [query1 + query2]

    for species in shared:
        # Candidates of this species on each side, ranked by identity to the query.
        sub_msas1 = [msas1[idx] for idx in np.where(sid1 == species)[0]]
        sort_idx1 = np.argsort(-cal_identity(query1, sub_msas1))

        sub_msas2 = [msas2[idx] for idx in np.where(sid2 == species)[0]]
        sort_idx2 = np.argsort(-cal_identity(query2, sub_msas2))

        if top:
            aligns.append(sub_msas1[sort_idx1[0]] + sub_msas2[sort_idx2[0]])
        else:
            num = min(len(sub_msas1), len(sub_msas2))
            for i in range(num):
                aligns.append(sub_msas1[sort_idx1[i]] + sub_msas2[sort_idx2[i]])

    return aligns
|
||||
|
||||
|
||||
def RBF(dist_map):
    """Expand a distance map into 64 radial-basis-function channels.

    Centres are spaced evenly from 2 to 22; the width is the range divided by
    the number of centres. The leading axis of ``dist_map`` is moved to the
    end before broadcasting against the centres, and the channel axis of the
    result is moved to the front.
    """
    # Radial Basis Function
    lower, upper, n_centres = 2., 22., 64
    centres = np.linspace(lower, upper, n_centres)[None, :]
    width = (upper - lower) / n_centres

    channels_last = dist_map.transpose(1, 2, 0)
    scaled = (channels_last - centres) / width
    encoded = np.exp(-(scaled ** 2))

    return encoded.transpose(2, 0, 1)
|
||||
Reference in New Issue
Block a user