# ---------------------------------------------------------------------------
# Patch scaffolding (git format-patch header), preserved verbatim as comments:
#
# From d018f9c759afdd2a133ad117b8bbcb03c06fbfd1 Mon Sep 17 00:00:00 2001
# From: Jun Liu
# Date: Tue, 18 Mar 2025 19:14:00 +0800
# Subject: [PATCH] Add files via upload
# ---
#  pplm_contact/LoadHHM.py | 344 +++++++++++++++++++++++++++++++++++++++
#  pplm_contact/Module.py  | 325 ++++++++++++++++++++++++++++++++++++
#  pplm_contact/config.py  |  16 ++
#  pplm_contact/model.py   |  72 ++++++++
#  pplm_contact/predict.py | 353 ++++++++++++++++++++++++++++++++++++++++
#  pplm_contact/utils.py   | 225 +++++++++++++++++++++++++
#  6 files changed, 1335 insertions(+)
#  create mode 100644 pplm_contact/LoadHHM.py
#  create mode 100644 pplm_contact/Module.py
#  create mode 100644 pplm_contact/config.py
#  create mode 100644 pplm_contact/model.py
#  create mode 100644 pplm_contact/predict.py
#  create mode 100644 pplm_contact/utils.py
#
# diff --git a/pplm_contact/LoadHHM.py b/pplm_contact/LoadHHM.py
# new file mode 100644
# index 0000000..4dd4469
# --- /dev/null
# +++ b/pplm_contact/LoadHHM.py
# @@ -0,0 +1,344 @@
# ---------------------------------------------------------------------------
#
# This file (LoadHHM.py) was downloaded from the RaptorX-Contact
# at https://github.com/j3xugit/RaptorX-Contact
#
import numpy as np
import os
import sys
import pickle as pkl

"""
This script reads an hhm file generated by the HHpred/HHblits package to generate position-specific scoring/frequency matrix.
After reading an hhm file, this script stores the HMM as a python dict().
To use the position-specfic frequency matrix, please use the keyword PSFM.
To use the position-specific scoring matrix, please use the keyword PSSM.
Both PSFM and PSSM encode information derived from the profile HMM built by HHpred or HHblits,
so there is no need to directly use the keys containing 'hmm'

PSFM and PSSM columns are arranged by the alphabetical order of amino acids in their 1-letter code
"""

## 8-state secondary structure letter -> integer code; 'L' and 'C' share code 7
SS8Letter2Code = {'H':0, 'G':1, 'I':2, 'E':3, 'B':4, 'T':5, 'S':6, 'L':7, 'C':7 }

## secondary structure conversion (8-state -> 3-state); as coded below,
## HELIX (H/G/I), BETA (E/B) and LOOP (T/S/L/C) correspond to 0, 1 and 2, respectively
SS8Letter2SS3Code = {'H':0, 'G':0, 'I':0, 'E':1, 'B':1, 'T':2, 'S':2, 'L':2, 'C':2 }

## the two lists below are parallel: entry k of AA3LetterCode is the 3-letter name of entry k of AA1LetterCode
## NOTE(review): tryptophan is spelled 'TRY' here, so a lookup by the PDB name 'TRP' would raise KeyError
## in the dictionaries built below -- confirm whether any caller indexes them by 3-letter code.
AA3LetterCode=['ALA', 'ARG', 'ASN', 'ASP', 'ASX', 'CYS', 'GLU', 'GLN', 'GLX', 'GLY', 'HIS', 'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRY', 'TYR', 'VAL']
AA1LetterCode=['A', 'R', 'N', 'D', 'B', 'C', 'E', 'Q', 'Z', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V' ]

##we only allow 20 amino acids in our protein sequences
ValidAALetters=set(['A', 'R', 'N', 'D', 'C', 'E', 'Q', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', 'S', 'T', 'W', 'Y', 'V'])

##we map the two rare amino acids (ASX/B and GLX/Z) to 20
## AAOrderBy3Letter is the alphabetical order of amino acids by its 3-letter code
AAOrderBy3Letter=[0, 1, 2, 3, 20, 4, 5, 6, 20, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19 ]

## AAOrderBy1Letter is the alphabetical order of amino acids by its 1-letter code
AAOrderBy1Letter=[0, 14, 11, 2, 20, 1, 3, 13, 20, 5, 6, 7, 9, 8, 10, 4, 12, 15, 16, 18, 19, 17 ]

## mapping between 1-letter code and 3-letter code
AA3LetterCode21LetterCode = {}
AA1LetterCode23LetterCode = {}

##map from one AA (3-letter code or 1-letter code) to its order in terms of 3-letter code
AALetter2OrderOf3LetterCode = {}

## map from one AA (3-letter or 1-letter code) to its order in terms of 1-letter code
AALetter2OrderOf1LetterCode = {}

## map between 3-letter order and 1-letter order
AA3LetterOrder21LetterOrder = {}
AA1LetterOrder23LetterOrder = {}

## fill all lookup tables in one pass over the four parallel lists above
for l3, l1, o3, o1 in zip(AA3LetterCode, AA1LetterCode, AAOrderBy3Letter, AAOrderBy1Letter):
    AA1LetterCode23LetterCode[l1] = l3
    AA3LetterCode21LetterCode[l3] = l1

    AALetter2OrderOf3LetterCode[l1] = o3
    AALetter2OrderOf3LetterCode[l3] = o3

    AALetter2OrderOf1LetterCode[l1] = o1
    AALetter2OrderOf1LetterCode[l3] = o1

    AA3LetterOrder21LetterOrder[o3] = o1
    AA1LetterOrder23LetterOrder[o1] = o3


## the rows and cols of the gonnet matrix are arranged in the alphabetical order of amino acids in the 3-letter code
## the below matrix is a mutation probability matrix?
gonnet = [
[ 1.7378, 0.870964,0.933254,0.933254, 1.12202, 0.954993, 1, 1.12202, 0.831764, 0.831764, 0.758578, 0.912011, 0.851138, 0.588844, 1.07152, 1.28825, 1.14815, 0.436516, 0.60256, 1.02329],
[ 0.870964,2.95121, 1.07152, 0.933254, 0.60256, 1.41254, 1.09648, 0.794328, 1.14815, 0.57544, 0.60256, 1.86209, 0.676083, 0.47863, 0.812831, 0.954993, 0.954993, 0.691831, 0.660693, 0.630957],
[ 0.933254,1.07152, 2.39883, 1.65959, 0.660693, 1.1749, 1.23027, 1.09648, 1.31826, 0.524807, 0.501187, 1.20226, 0.60256, 0.489779, 0.812831, 1.23027, 1.12202, 0.436516, 0.724436, 0.60256],
[ 0.933254,0.933254,1.65959, 2.95121, 0.47863, 1.23027, 1.86209, 1.02329, 1.09648, 0.416869, 0.398107, 1.12202, 0.501187, 0.354813, 0.851138, 1.12202, 1, 0.301995, 0.524807, 0.512861],
[ 1.12202, 0.60256, 0.660693,0.47863, 14.1254, 0.57544, 0.501187, 0.630957, 0.74131, 0.776247, 0.707946, 0.524807, 0.812831, 0.831764, 0.489779, 1.02329, 0.891251, 0.794328, 0.891251, 1],
[ 0.954993,1.41254, 1.1749, 1.23027, 0.57544, 1.86209, 1.47911, 0.794328, 1.31826, 0.645654, 0.691831, 1.41254, 0.794328, 0.549541, 0.954993, 1.04713, 1, 0.537032, 0.676083, 0.707946],
[ 1, 1.09648, 1.23027, 1.86209, 0.501187, 1.47911, 2.29087, 0.831764, 1.09648, 0.537032, 0.524807, 1.31826, 0.630957, 0.40738, 0.891251, 1.04713, 0.977237, 0.371535, 0.537032, 0.645654],
[ 1.12202, 0.794328,1.09648, 1.02329, 0.630957, 0.794328, 0.831764, 4.57088, 0.724436, 0.354813, 0.363078, 0.776247, 0.446684, 0.301995, 0.691831, 1.09648, 0.776247, 0.398107, 0.398107, 0.467735],
[ 0.831764,1.14815, 1.31826, 1.09648, 0.74131, 1.31826, 1.09648, 0.724436, 3.98107, 0.60256, 0.645654, 1.14815, 0.74131, 0.977237, 0.776247, 0.954993, 0.933254, 0.831764, 1.65959, 0.630957],
[ 0.831764,0.57544, 0.524807,0.416869, 0.776247, 0.645654, 0.537032, 0.354813, 0.60256, 2.51189, 1.90546, 0.616595, 1.77828, 1.25893, 0.549541, 0.660693, 0.870964, 0.660693, 0.851138, 2.04174],
[ 0.758578,0.60256, 0.501187,0.398107, 0.707946, 0.691831, 0.524807, 0.363078, 0.645654, 1.90546, 2.51189, 0.616595, 1.90546, 1.58489, 0.588844, 0.616595, 0.74131, 0.851138, 1, 1.51356],
[ 0.912011,1.86209, 1.20226, 1.12202, 0.524807, 1.41254, 1.31826, 0.776247, 1.14815, 0.616595, 0.616595, 2.0893, 0.724436, 0.467735, 0.870964, 1.02329, 1.02329, 0.446684, 0.616595, 0.676083],
[ 0.851138,0.676083,0.60256, 0.501187, 0.812831, 0.794328, 0.630957, 0.446684, 0.74131, 1.77828, 1.90546, 0.724436, 2.69153, 1.44544, 0.57544, 0.724436, 0.870964, 0.794328, 0.954993, 1.44544],
[ 0.588844,0.47863, 0.489779,0.354813, 0.831764, 0.549541, 0.40738, 0.301995, 0.977237, 1.25893, 1.58489, 0.467735, 1.44544, 5.01187, 0.416869, 0.524807, 0.60256, 2.29087, 3.23594, 1.02329],
[ 1.07152, 0.812831,0.812831,0.851138, 0.489779, 0.954993, 0.891251, 0.691831, 0.776247, 0.549541, 0.588844, 0.870964, 0.57544, 0.416869, 5.7544, 1.09648, 1.02329, 0.316228, 0.489779, 0.660693],
[ 1.28825, 0.954993,1.23027, 1.12202, 1.02329, 1.04713, 1.04713, 1.09648, 0.954993, 0.660693, 0.616595, 1.02329, 0.724436, 0.524807, 1.09648, 1.65959, 1.41254, 0.467735, 0.645654, 0.794328],
[ 1.14815, 0.954993,1.12202, 1, 0.891251, 1, 0.977237, 0.776247, 0.933254, 0.870964, 0.74131, 1.02329, 0.870964, 0.60256, 1.02329, 1.41254, 1.77828, 0.446684, 0.645654, 1],
[ 0.436516,0.691831,0.436516,0.301995, 0.794328, 0.537032, 0.371535, 0.398107, 0.831764, 0.660693, 0.851138, 0.446684, 0.794328, 2.29087, 0.316228, 0.467735, 0.446684, 26.3027, 2.5704, 0.549541],
[ 0.60256, 0.660693,0.724436,0.524807, 0.891251, 0.676083, 0.537032, 0.398107, 1.65959, 0.851138, 1, 0.616595, 0.954993, 3.23594, 0.489779, 0.645654, 0.645654, 2.5704, 6.0256, 0.776247],
[ 1.02329, 0.630957,0.60256, 0.512861, 1, 0.707946, 0.645654, 0.467735, 0.630957, 2.04174, 1.51356, 0.676083, 1.44544, 1.02329, 0.660693, 0.794328, 1, 0.549541, 0.776247, 2.18776]
]

gonnet = np.array(gonnet, np.float32)

## column indices into the 10-wide 'hmm2' rows: seven state transitions followed by three Neff values
M_M, M_I, M_D, I_M, I_I, D_M, D_D, _NEFF, I_NEFF, D_NEFF = 0, 1, 2, 3, 4, 5, 6, 7, 8, 9

## HMMNull is the background score of amino acids, in the alphabetical order of 1-letter code
## (21 entries; only the first 20 are used by ReadHHM)
HMMNull = [3706,5728,4211,4064,4839,3729,4763,4308,4069,3323,5509,4640,4464,4937,4285,4423,3815,3783,6325,4665,0]
HMMNull = np.array(HMMNull, dtype=np.float32)


## this function reads the HMM block from a profleHMM file (generated by HHpred/HHblits package)
## for the profileHMM file, the header has 4 lines
def ReadHHM(lines, start_position, length, one_protein, numLines4Header=4):
    """Parse the per-residue HMM block of an .hhm file into ``one_protein``.

    Parameters
    ----------
    lines : sequence of str
        Lines of the .hhm file.
    start_position : int
        Index in ``lines`` of the first of the ``numLines4Header`` header
        lines of the HMM block.
    length : int
        Number of residues; the block is read as ``length`` groups of three
        lines (emission line, transition line, and a third line that is
        skipped).
    one_protein : dict
        Updated in place. Must already contain ``'sequence'`` and ``'name'``.
        Receives ``'HMMHeader'``, ``'hmm1'``, ``'hmm2'``, ``'hmm1_prob'``,
        ``'hmm1_score'``, ``'PSFM'`` and ``'PSSM'``.
    numLines4Header : int
        Number of header lines preceding the first residue line.

    Returns
    -------
    tuple
        ``(index of the first line after the block, one_protein)``.

    Notes
    -----
    Terminates the process with ``exit(-1)`` when the residue letters read
    from the block disagree with ``one_protein['sequence']`` ('X' on either
    side is accepted as a wildcard).
    """

    i = start_position
    one_protein['HMMHeader'] = lines[i: i+ numLines4Header]
    i += numLines4Header

    ## the columns of hmm1 are in the alphabetical order of amino acids in 1-letter code, different from the above PSP and PSM matrices
    one_protein['hmm1'] = np.zeros( (length, 20), np.float32)
    one_protein['hmm2'] = np.zeros( (length, 10), np.float32)
    one_protein['hmm1_prob'] = np.zeros((length, 20), np.float32)
    one_protein['hmm1_score'] = np.zeros((length, 20), np.float32)

    seqStr = ''

    for l in range(length):
        ##this line is for emission score. The amino acids are ordered from left to right alphabetically by their 1-letter code
        ## '*' entries are replaced by 99999 before conversion
        fields = lines[ i + l*3 + 0].replace("*", "99999").split()

        ## expected layout: residue letter, residue index, 20 emission values, one trailing field
        assert len(fields) == 23
        ## stored value is -raw/1000; it is used below as a base-2 exponent
        one_protein['hmm1'][l] = np.array([ - np.int32(num) for num in fields[2: -1] ])/1000.
        aa = fields[0]
        seqStr += aa
        """
        if (l + 1) != np.int32(fields[1]):
            print 'Error: inconsistent residue number in file for protein ', one_protein['name'], ' at line: ', lines[ i + l*3 + 0 ]
            exit(-1)
        """

        ##the first 7 columns of this line is for state transition
        ## exp(-x/1000 * 0.6931) is 2**(-x/1000), since 0.6931 ~= ln 2
        one_protein['hmm2'][l][0:7] = [ np.exp(-np.int32(num)/1000.0*0.6931) for num in lines[i + l*3 + 1].replace("*", "99999").split()[0:7]]

        ##the last 3 columns of this line is for Neff of Match, Insertion and Deletion
        one_protein['hmm2'][l][7:10] = [ np.int32(num)/1000.0 for num in lines[i + l*3 + 1].split()[7:10]]

        ## _NEFF is for match, I_NEFF for insertion and D_NEFF for deletion. More comments are needed for the below code
        ## each transition probability is blended with a fixed prior (0.6/0.2/0.2 for the M_* transitions),
        ## weighting the observed value by the corresponding Neff and the prior by rm
        rm = 0.1
        one_protein['hmm2'][l][M_M] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_M] + rm*0.6)/(rm + one_protein['hmm2'][l][_NEFF])
        one_protein['hmm2'][l][M_I] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_I] + rm*0.2)/(rm + one_protein['hmm2'][l][_NEFF])
        one_protein['hmm2'][l][M_D] = (one_protein['hmm2'][l][_NEFF]*one_protein['hmm2'][l][M_D] + rm*0.2)/(rm + one_protein['hmm2'][l][_NEFF])

        ## same blending for the insertion-state transitions (prior 0.75/0.25)
        ri = 0.1
        one_protein['hmm2'][l][I_I] = (one_protein['hmm2'][l][I_NEFF]*one_protein['hmm2'][l][I_I] + ri*0.75)/(ri + one_protein['hmm2'][l][I_NEFF])
        one_protein['hmm2'][l][I_M] = (one_protein['hmm2'][l][I_NEFF]*one_protein['hmm2'][l][I_M] + ri*0.25)/(ri + one_protein['hmm2'][l][I_NEFF])

        ## same blending for the deletion-state transitions (prior 0.75/0.25)
        rd = 0.1
        one_protein['hmm2'][l][D_D] = (one_protein['hmm2'][l][D_NEFF]*one_protein['hmm2'][l][D_D] + rd*0.75)/(rd + one_protein['hmm2'][l][D_NEFF])
        one_protein['hmm2'][l][D_M] = (one_protein['hmm2'][l][D_NEFF]*one_protein['hmm2'][l][D_M] + rd*0.25)/(rd + one_protein['hmm2'][l][D_NEFF])


        ## emission probabilities: 2 ** (stored score)
        one_protein['hmm1_prob'][l,] = pow(2.0, one_protein['hmm1'][l,])
        wssum = sum(one_protein['hmm1_prob'][l, ])

        #print 'l = ', l, 'sum= ', wssum

        ## renormalize to make wssum = 1
        if wssum > 0 :
            one_protein['hmm1_prob'][l, ] /= wssum
        else:
            ## fall back to a one-hot row for the residue letter read from the file
            ## NOTE(review): 'X' is not a key of AALetter2OrderOf1LetterCode, so this branch
            ## would raise KeyError for an 'X' residue -- confirm whether that can occur.
            one_protein['hmm1_prob'][l, AALetter2OrderOf1LetterCode[aa] ] = 1.

        """
        ## if the probability sum is not equal to 1
        if abs(wssum - 1.) > 0.1 :
            one_protein['hmm1_prob'][l, ] = 0
            one_protein['hmm1_prob'][l, AALetter2OrderOf1LetterCode[aa] ] = 1.
        """


        ## add pseudo count
        ## g[j] = sum_k prob[k] * gonnet[k, j] * 2**(-HMMNull[j]/1000); the gonnet matrix is indexed in
        ## 3-letter order, hence the index translation from 1-letter order
        g = np.zeros( (20), np.float32)
        for j in range(20):
            orderIn3LetterCode_j = AA1LetterOrder23LetterOrder[j]
            for k in range(20):
                orderIn3LetterCode_k = AA1LetterOrder23LetterOrder[k]
                g[j] += one_protein['hmm1_prob'][l, k] * gonnet[ orderIn3LetterCode_k, orderIn3LetterCode_j ]
            g[j] *= pow(2.0, -1.0*HMMNull[j] / 1000.0)

        #print 'l=', l, ' gsum= ', sum(g)
        ## sum(g) is very close to 1, here we renormalize g to make its sum to be exactly 1
        g = g/sum(g)

        ## mix observed probabilities (weight Neff-1) with the pseudo-count distribution (weight 10)
        ws_tmp_neff = one_protein['hmm2'][l][_NEFF] - 1
        one_protein['hmm1'][l, ] = (ws_tmp_neff * one_protein['hmm1_prob'][l, ] + g*10) / (ws_tmp_neff+10)

        ## recalculate the emission score and probability after pseudo count is added
        one_protein['hmm1_prob'][l,] = one_protein['hmm1'][l, ]
        one_protein['hmm1'][l, ] = np.log2(one_protein['hmm1_prob'][l, ])
        one_protein['hmm1_score' ][l, ] = one_protein['hmm1'][l, ] + HMMNull[:20]/1000.0

    ## PSFM: position-specific frequency matrix, PSSM: position-specific scoring matrix
    ## (both are references to the arrays above, not copies)
    one_protein['PSFM'] = one_protein['hmm1_prob']
    one_protein['PSSM'] = one_protein['hmm1_score']

    #assert ( seqStr == one_protein['sequence'] )
    if len(seqStr) != len(one_protein['sequence']):
        print ('ERROR: inconsistent sequence length in HMM section and orignal sequence for protein: ', one_protein['name'] )
        exit(-1)

    ## 'X' on either side is treated as a wildcard
    comparison = [ (aa=='X' or bb=='X' or aa==bb) for aa, bb in zip(seqStr, one_protein['sequence']) ]
    if not all(comparison):
        print ('ERROR: inconsistent sequence between HMM section and orignal sequence for protein: ', one_protein['name'])
        print (' original seq: ', one_protein['sequence'])
        print (' HMM seq: ', seqStr)
        exit(-1)

    return i + 3*length, one_protein

+## this function reads a profile HMM file generated by HHpred/HHblits package +def load_hmm(hmmfile): + with open(hmmfile, 'r') as fh: + content = [ r.strip() for r in list(fh) ] + if not bool(content): + print ('ERROR: empty profileHMM file: ', hmmfile) + exit(1) + if not content[0].startswith('HHsearch'): + print ('ERROR: this file may not be a profileHMM file generated by HHpred/HHblits: ', hmmfile) + exit(1) + if len(content) < 10: + print ('ERROR: this profileHMM file is too short: ', hmmfile) + exit(1) + + requiredSections = ['name', 'length', 'sequence', 'NEFF', 'hmm1', 'hmm2', 'hmm1_prob', 'hmm1_score', 'PSFM', 'PSSM', 'DateCreated'] + protein = {} + + ## get sequence name + if not content[1].startswith('NAME '): + print ('ERROR: the protein name shall appear at the second line of profileHMM file: ', hmmfile) + exit(1) + fields = content[1].split() + if len(fields) < 2: + print ('ERROR: incorrect name format in profileHMM file: ', hmmfile) + exit(1) + protein['name'] = fields[1] + + i = 0 + while i < len(content): + row = content[i] + if len(row)<1: + i += 1 + continue + + if row.startswith('DATE '): + protein['DateCreated'] = row[6:] + i += 1 + continue + + if row.startswith('NEFF '): + protein['NEFF'] = np.float32(row.split()[1]) + i += 1 + continue + + if row.startswith('LENG '): + protein['length'] = np.int32(row.split()[1]) + i += 1 + continue + + if row.startswith('>ss_pred'): + ## read the predicted secondary structure + start = i+1 + end = i+1 + while not content[end].startswith('>'): + end += 1 + protein['SSEseq'] = ''.join(content[start:end]).replace('C', 'L') + if len(protein['SSEseq']) != protein['length']: + print ('ERROR: inconsistent sequence length and predicted SS sequence length in hmmfile: ', hmmfile) + exit(1) + i = end + continue + + if row.startswith('>ss_conf'): + ## read the predicted secondary structure confidence score + start = i+1 + end = i+1 + while not content[end].startswith('>'): + end += 1 + + SSEconfStr = 
''.join(content[start:end]) + protein['SSEconf'] = [ np.int16(score) for score in SSEconfStr ] + + if len(protein['SSEconf']) != protein['length']: + print ('ERROR: inconsistent sequence length and predicted SS confidence sequence length in hmmfile: ', hmmfile) + exit(1) + + i = end + continue + + + if row.startswith('>' + protein['name']): + ## read in the sequence in the following lines + start = i+1 + end = i+1 + while (not content[end].startswith('>')) and (not content[end].startswith('#')): + end += 1 + + ## at this point, content[end] shall start with > + protein['sequence'] = ''.join(content[start:end]) + if len(protein['sequence']) != protein['length']: + print ('ERROR: inconsistent sequence length in hmmfile: ', hmmfile) + exit(1) + i = end + continue + + if len(row) == 1 and row[0]=='#' and content[i+1].startswith('NULL') and content[i+2].startswith('HMM'): + i, protein = ReadHHM(content, i+1, protein['length'], protein, numLines4Header=4) + continue + + i += 1 + + + ## double check to see some required sections are read in + for section in requiredSections: + #if not protein.has_key(section): + if section not in protein.keys(): + print ('ERROR: one section for ', section, ' is missing in the hmm file: ', hmmfile) + print ('ERROR: it is also possible that the hmm file has a format incompatible with this script.') + exit(1) + + protein['requiredSections'] = requiredSections + + return protein + + + +## for test only +if __name__ == "__main__": + if len(sys.argv) < 2: + print ('python LoadHHM.py hhm_file') + print ('the input file shall end with .hhm') + exit(1) + + file = sys.argv[1] + + if file.endswith('.hhm'): + protein = load_hmm(file) + else: + print ('ERROR: the input file shall have suffix .hmm') + exit(1) + + savefile = os.path.basename(file) + '.pkl' + fh = open(savefile, 'wb') + pkl.dump( protein, fh, protocol=pkl.HIGHEST_PROTOCOL) + fh.close() diff --git a/pplm_contact/Module.py b/pplm_contact/Module.py new file mode 100644 index 
# Patch scaffolding, preserved verbatim as a comment:
# 0000000..9e30f1f --- /dev/null +++ b/pplm_contact/Module.py @@ -0,0 +1,325 @@
import torch
import torch.nn as nn


class ResNet(nn.Module):
    """Five residual blocks of dilated 3x3 convolutions over a 2-D map.

    Each block is Conv2d -> InstanceNorm2d -> LeakyReLU -> Dropout -> Conv2d
    -> InstanceNorm2d -> LeakyReLU, with dilations 1, 2, 4, 2, 1. Padding
    equals the dilation, so with a 3x3 kernel the spatial size is unchanged
    and the block output can be added to its input.

    Args:
        channels: number of input and output channels of every convolution.
        dropout: dropout probability used inside each block.
    """

    def __init__(self, channels, dropout=0.0):
        super(ResNet, self).__init__()
        self.channels = channels
        self.n_layers = 5
        self.dilations = [1, 2, 4, 2, 1]
        self.dropout = dropout

        self.blocks = nn.ModuleList()

        for layer in range(self.n_layers):
            block = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, dilation=self.dilations[layer], padding=self.dilations[layer]),
                nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True),
                nn.Dropout(dropout),
                nn.Conv2d(channels, channels, kernel_size=3, dilation=self.dilations[layer], padding=self.dilations[layer]),
                nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
                nn.LeakyReLU(negative_slope=0.01, inplace=True)
            )
            self.blocks.append(block)

    def forward(self, x):
        """Apply the blocks in order, adding each block's input to its output."""
        for block in self.blocks:
            _residual = x
            x = block(x)
            x = x + _residual

        return x


class Intra_ResNet(nn.Module):
    """Fuse per-residue (1-D) and pairwise (2-D) features of one chain into a
    ``channels``-deep pair map and refine it with :class:`ResNet`.

    Args:
        dim_1d: channel count of the 1-D input.
        dim_2d: channel count of the 2-D input.
        channels: channel count of the output pair map.
        dropout: dropout probability passed to the ResNet.
    """

    def __init__(self, dim_1d=768+20, dim_2d=144+2+64, channels=64, dropout=0):
        super(Intra_ResNet, self).__init__()

        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        self.pre_norm_1d = nn.InstanceNorm1d(self.dim_1d)
        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)

        # 1x1 convolution over the outer concatenation of the 1-D features
        self.pair_conv1 = nn.Sequential(
            nn.Conv2d(self.dim_1d * 2, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True)
        )

        # 1x1 convolution over the native 2-D features
        self.pair_conv2 = nn.Sequential(
            nn.Conv2d(self.dim_2d, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True)
        )

        # 1x1 convolution merging the two maps above
        self.pair_conv3 = nn.Sequential(
            nn.Conv2d(self.channels * 2, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True)
        )

        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    def forward(self, x_1d, x_2d):
        """Return the refined pair map.

        Args:
            x_1d: tensor unpacked as ``(b, d, L)``.
            x_2d: 2-D feature tensor; it is concatenated with a
                ``(b, channels, L, L)`` map along dim 1, so it is presumably
                ``(b, dim_2d, L, L)``.
        """
        x_1d = self.pre_norm_1d(x_1d)
        x_2d = self.pre_norm_2d(x_2d)

        b, d, L = x_1d.size()
        # broadcast the 1-D features along rows and along columns, then stack them per pair
        x_1d_row = x_1d.unsqueeze(-1).expand(-1, -1, L, L)
        x_1d_col = x_1d.unsqueeze(-2).expand(-1, -1, L, L)
        pair_1 = torch.cat([x_1d_row, x_1d_col], dim=1)
        pair_1 = self.pair_conv1(pair_1)
        pair_2 = self.pair_conv2(x_2d)

        pair = torch.cat([pair_1, pair_2], dim=1)
        pair = self.pair_conv3(pair)
        pair = self.resnet(pair)

        return pair


class Inter_ResNet(nn.Module):
    """Project inter-chain 2-D features to ``channels`` and refine them with
    :class:`ResNet`.

    ``dim_1d`` is stored but not used by any layer of this module.
    """

    def __init__(self, dim_1d=768+20, dim_2d=768 + 20, channels=64, dropout=0):
        super(Inter_ResNet, self).__init__()

        self.dim_1d = dim_1d
        self.dim_2d = dim_2d
        self.channels = channels
        self.dropout = dropout

        # self.pre_norm_1d = nn.InstanceNorm1d(self.dim_1d)
        self.pre_norm_2d = nn.InstanceNorm2d(self.dim_2d)

        self.pair_conv2 = nn.Sequential(
            nn.Conv2d(self.dim_2d, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True)
        )

        self.pair_conv3 = nn.Sequential(
            nn.Conv2d(self.channels, self.channels, kernel_size=1, stride=1, padding="same", dilation=1, bias=False),
            nn.InstanceNorm2d(self.channels, affine=True, track_running_stats=True),
            nn.LeakyReLU(negative_slope=0.01, inplace=True)
        )

        self.resnet = ResNet(channels=self.channels, dropout=self.dropout)

    def forward(self, inter_2d):
        """Normalise, project twice with 1x1 convolutions, then apply the ResNet."""
        inter_2d = self.pre_norm_2d(inter_2d)
        pair = self.pair_conv2(inter_2d)
        pair = self.pair_conv3(pair)
        pair = self.resnet(pair)

        return pair


class InterTriangleMultiplication_S(nn.Module):
    """Gated multiplicative update of the inter-chain map by an intra-chain map.

    Inputs are channel-last (LayerNorm and Linear act on the last axis). The
    update is ``einsum('bikc,bkjc->bijc', z_intra, z_inter)``. With
    ``transpose`` the two spatial axes of both inputs are swapped before the
    update and the result is swapped back.

    Args:
        channel_z: channel count of the inputs and of the output.
        channel_c: channel count of the gated projections.
        transpose: if True, always operate on the transposed maps.
    """

    def __init__(self, channel_z=64, channel_c=64, transpose=False):
        super(InterTriangleMultiplication_S, self).__init__()

        self.dz = channel_z
        self.dc = channel_c
        self.transpose = transpose

        # init norm (shared by the inter and intra inputs)
        self.norm_init = nn.LayerNorm(self.dz)

        # linear * gate for the inter and intra maps
        self.Linear_inter = nn.Linear(self.dz, self.dc)
        self.gate_inter = nn.Linear(self.dz, self.dc)

        self.Linear_intra = nn.Linear(self.dz, self.dc)
        self.gate_intra = nn.Linear(self.dz, self.dc)

        # final
        self.norm_final = nn.LayerNorm(self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)
        self.gate_final = nn.Linear(self.dz, self.dz)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the update for ``z_inter`` (same layout as the input)."""
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)

        z_inter_init = self.norm_init(z_inter)
        z_intra = self.norm_init(z_intra)

        z_inter = self.Linear_inter(z_inter_init) * self.gate_inter(z_inter_init).sigmoid()
        z_intra = self.Linear_intra(z_intra) * self.gate_intra(z_intra).sigmoid()

        z_inter_update = torch.einsum('bikc,bkjc->bijc', z_intra, z_inter)

        z_inter_update = self.norm_final(z_inter_update)
        # the output gate is computed from the normalised, un-projected input
        z_inter = self.Linear_final(z_inter_update) * self.gate_final(z_inter_init).sigmoid()

        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)

        return z_inter


class InterCrossAttention_S(nn.Module):
    """Gated multi-head attention with queries from the inter-chain map and
    keys/values from an intra-chain map.

    Inputs are channel-last. Queries are indexed ``(b, l, j, head, c)`` and
    keys/values ``(b, l, i, head, c)``; the softmax runs over ``i``.

    Args:
        channel_z: channel count of the inputs and of the output.
        num_head: number of attention heads.
        dim_head: channels per head.
        bias: if True, create ``Linear_bias`` and add its output to the logits.
        transpose: if True, always operate on the transposed maps.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=True, transpose=False):
        super(InterCrossAttention_S, self).__init__()

        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)

        # kept for attribute compatibility; forward() calls functional softmax
        self.softmax = nn.Softmax(dim=-1)

        self.gate_final = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis of ``x`` into ``(num_head, dim_head)``."""
        new_shape = x.size()[:-1] + (self.num_head, self.dhc)
        return x.view(*new_shape)

    def forward(self, z_inter, z_intra, transpose=False):
        """Return the attention update for ``z_inter`` (same layout as the input)."""
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)
            z_intra = z_intra.permute(0, 2, 1, 3)

        z_inter = self.norm_init(z_inter)
        z_intra = self.norm_init(z_intra)

        q = self.reshape_dim(self.Linear_Q(z_inter))
        k = self.reshape_dim(self.Linear_K(z_intra))
        v = self.reshape_dim(self.Linear_V(z_intra))

        scalar = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=q.dtype, device=q.device))

        attn = torch.einsum('b l j h c, b l i h c -> b l j h i', q * scalar, k)

        if self.bias:
            # NOTE(review): this bias is constant along the softmax axis, so it does not
            # change the attention weights -- confirm that this is intended.
            bias = self.Linear_bias(z_inter).unsqueeze(-1)
            attn = attn + bias

        if attn.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = torch.nn.functional.softmax(attn, -1)
        else:
            attn_weights = torch.nn.functional.softmax(attn, -1)

        attn = torch.einsum('b l j h i, b l i h c -> b l j h c', attn_weights, v)

        gate_final = self.reshape_dim(self.gate_final(z_inter)).sigmoid()
        # merge (head, c) back into one channel axis
        z_inter = (attn * gate_final).contiguous().view(gate_final.size()[:-2] + (-1,))

        z_inter = self.Linear_final(z_inter)

        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)

        return z_inter


class InterSelfAttention_S(nn.Module):
    """Gated multi-head self-attention along the second spatial axis of the
    inter-chain map, optionally modulated by a distance map.

    Args:
        channel_z: channel count of the input and of the output.
        num_head: number of attention heads.
        dim_head: channels per head.
        bias: if True, create ``Linear_bias`` and add its output to the logits.
        transpose: if True, always operate on the transposed map.
    """

    def __init__(self, channel_z=64, num_head=4, dim_head=8, bias=False, transpose=False):
        super(InterSelfAttention_S, self).__init__()

        self.dz = channel_z
        self.dhc = dim_head
        self.num_head = num_head
        self.dc = self.num_head * self.dhc
        self.transpose = transpose
        self.bias = bias

        self.norm_init = nn.LayerNorm(self.dz)
        self.Linear_Q = nn.Linear(self.dz, self.dc)
        self.Linear_K = nn.Linear(self.dz, self.dc)
        self.Linear_V = nn.Linear(self.dz, self.dc)
        if self.bias:
            self.Linear_bias = nn.Linear(self.dz, self.num_head)

        # kept for attribute compatibility; forward() calls functional softmax
        self.softmax = nn.Softmax(-1)
        self.gate_v = nn.Linear(self.dz, self.dc)
        self.Linear_final = nn.Linear(self.dc, self.dz)

    def reshape_dim(self, x):
        """Split the last axis of ``x`` into ``(num_head, dim_head)``."""
        new_shape = x.size()[:-1] + (self.num_head, self.dhc)
        return x.view(*new_shape)

    def forward(self, z_inter, dist=None, transpose=False):
        """Return the self-attention update for ``z_inter``.

        Args:
            z_inter: channel-last inter-chain map.
            dist: optional distance tensor. The logits are multiplied by
                ``exp(-(dist / 8)**2 / 2)`` after ``unsqueeze(3)``.
                Presumably shaped ``(b, 1, L, L)`` so that it broadcasts
                against the ``(b, l, i, head, j)`` logits -- TODO confirm
                against the caller.
            transpose: if True, operate on the transposed map.
        """
        swap_axes = self.transpose or transpose
        if swap_axes:
            z_inter = z_inter.permute(0, 2, 1, 3)

        z_inter = self.norm_init(z_inter)

        q = self.reshape_dim(self.Linear_Q(z_inter))
        k = self.reshape_dim(self.Linear_K(z_inter))
        v = self.reshape_dim(self.Linear_V(z_inter))

        scalar = 1.0 / torch.sqrt(torch.tensor(self.dhc, dtype=q.dtype, device=q.device))

        attn = torch.einsum('b l i h c, b l j h c -> b l i h j', q * scalar, k)

        if self.bias:
            # Linear_bias yields (b, l, i, head); the trailing axis is needed to broadcast
            # against the (b, l, i, head, j) logits, as InterCrossAttention_S does.
            bias = self.Linear_bias(z_inter).unsqueeze(-1)
            attn = attn + bias

        # identity test: `dist != None` on a tensor is an element-wise/overloaded comparison
        if dist is not None:
            coef = torch.exp(-(dist/8.0)**2.0/2.0).unsqueeze(3).type_as(q)
            attn = attn * coef

        if attn.dtype is torch.bfloat16:
            with torch.cuda.amp.autocast(enabled=False):
                attn_weights = torch.nn.functional.softmax(attn, -1)
        else:
            attn_weights = torch.nn.functional.softmax(attn, -1)

        v_avg = torch.einsum('b l i h j, b l j h c -> b l i h c', attn_weights, v)

        gate_v = self.reshape_dim(self.gate_v(z_inter)).sigmoid()
        # merge (head, c) back into one channel axis
        z_com = (v_avg * gate_v).contiguous().view( v_avg.size()[:-2] + (-1,) )

        z_final = self.Linear_final(z_com)

        if swap_axes:
            z_final = z_final.permute(0, 2, 1, 3)

        return z_final


class Transition(nn.Module):
    """LayerNorm followed by two Linear layers (expand by ``transition_n``,
    then project back).

    NOTE(review): there is no non-linearity between the two Linear layers, so
    they compose to a single affine map. Left unchanged because adding one
    would alter the outputs of existing checkpoints -- confirm intent.
    """

    def __init__(self, channel_z=64, transition_n=4):
        super(Transition, self).__init__()

        self.dz = channel_z
        self.n = transition_n

        self.norm = nn.LayerNorm(self.dz)
        self.transition = nn.Sequential(
            nn.Linear(self.dz, self.dz*self.n),
            nn.Linear(self.dz*self.n, self.dz)
        )

    def forward(self, z_com):
        """Return ``transition(norm(z_com))``."""
        z_com = self.norm(z_com)
        z_com = self.transition(z_com)

        return z_com


# Patch scaffolding, preserved verbatim as a comment:
# diff --git a/pplm_contact/config.py b/pplm_contact/config.py new file mode 100644 index 0000000..6e0021d --- /dev/null +++ b/pplm_contact/config.py @@ -0,0 +1,16 @@
import os

# Site-specific absolute paths to external tools and databases.
hhsuite_dir = "/mnt/rna01/junl/tools/hh-suite/build"

hhblits = os.path.join(hhsuite_dir, "bin/hhblits")
hhmake = os.path.join(hhsuite_dir, "bin/hhmake")
reformat = os.path.join(hhsuite_dir, "scripts/reformat.pl")
hhfilter = os.path.join(hhsuite_dir, "bin/hhfilter")
UniRef_database = "/mnt/dna01/library2/e2e_folding/alphafold/v2.3/lib/uniref30/UniRef30_2021_03"


ccmpred = "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/ccmpred"
esm_msa = "/mnt/rna01/junl/PPLM/PPLM_source_code/pplm_contact/external_tools/extract_esm_msa_features.py"
esm_msa_model = "/mnt/rna01/junl/tools/esm-main/esm_models/esm_msa1_t12_100M_UR50S.pt"


# Patch scaffolding, preserved verbatim as a comment:
# diff --git a/pplm_contact/model.py b/pplm_contact/model.py new file mode 100644 index 0000000..8daa80c --- /dev/null +++ b/pplm_contact/model.py @@ -0,0 +1,72 @@
import os
import sys
import torch
import torch.nn as nn
# make the parent directory importable so that `pplm_contact.Module` resolves
mian_path = os.path.dirname(__file__) + "/../"
sys.path.append(os.path.abspath(mian_path))

from pplm_contact.Module import *

class PPLM_Contact(nn.Module):
    """Inter-protein contact predictor.

    Two chains are embedded by one shared :class:`Intra_ResNet`; the
    inter-chain features are embedded by :class:`Inter_ResNet`. The
    inter-chain map is then updated by ``num_blocks`` rounds of triangle
    multiplication, cross-attention, self-attention and a transition layer,
    and finally mapped to one sigmoid channel.

    Args:
        intra_1d_dim: channel count of each chain's 1-D features.
        intra_2d_dim: channel count of each chain's 2-D features.
        inter_2d_dim: channel count of the inter-chain 2-D features.
        channels: width of all pair representations.
        num_blocks: number of update rounds.
        droupout: dropout probability applied to each update term (the
            spelling is part of the public keyword interface). It is not
            passed to the ResNets, which keep their default of 0.
    """

    def __init__(self,
                 intra_1d_dim=768+20,
                 intra_2d_dim=144+2+64,
                 inter_2d_dim=144+2+660,
                 channels=64,
                 num_blocks=12,
                 droupout=0.10,
                 ):
        super(PPLM_Contact, self).__init__()
        self.intra_1d_dim = intra_1d_dim
        self.intra_2d_dim = intra_2d_dim
        self.inter_2d_dim = inter_2d_dim
        self.channels = channels
        self.num_blocks = num_blocks
        self.droupout = droupout

        ### ResNet
        self.intra_resnet = Intra_ResNet(dim_1d=self.intra_1d_dim, dim_2d=self.intra_2d_dim, channels=self.channels)
        self.inter_resnet = Inter_ResNet(dim_2d=self.inter_2d_dim, channels=self.channels)

        ### Transformer
        self.InterTriangleMulti_R = nn.ModuleList([InterTriangleMultiplication_S(channel_z=self.channels, channel_c=self.channels, transpose=False) for _ in range(num_blocks)])
        self.InterTriangleMulti_L = nn.ModuleList([InterTriangleMultiplication_S(channel_z=self.channels, channel_c=self.channels, transpose=True) for _ in range(num_blocks)])
        self.InterCrossAttn_R = nn.ModuleList([InterCrossAttention_S(channel_z=self.channels, bias=True, transpose=False) for _ in range(num_blocks)])
        self.InterCrossAttn_L = nn.ModuleList([InterCrossAttention_S(channel_z=self.channels, bias=True, transpose=True) for _ in range(num_blocks)])
        self.InterSelfAttn_R = nn.ModuleList([InterSelfAttention_S(channel_z=self.channels, bias=False, transpose=False) for _ in range(num_blocks)])
        self.InterSelfAttn_L = nn.ModuleList([InterSelfAttention_S(channel_z=self.channels, bias=False, transpose=True) for _ in range(num_blocks)])
        self.Transition = nn.ModuleList([Transition(channel_z=self.channels) for _ in range(num_blocks)])
        self.drop = nn.Dropout(self.droupout)

        self.norm_final = nn.LayerNorm(self.channels)
        self.Linear_final = nn.Linear(self.channels, 1)
        self.act_final = nn.Sigmoid()

    def forward(self, intra1_1d, intra1_2d, intra2_1d, intra2_2d, inter_2d, intra1_dist=None, intra2_dist=None, last_layer=False):
        """Predict the inter-chain contact map.

        Args:
            intra1_1d, intra1_2d: 1-D and 2-D features of chain 1.
            intra2_1d, intra2_2d: 1-D and 2-D features of chain 2.
            inter_2d: inter-chain 2-D features.
            intra1_dist, intra2_dist: optional distance maps forwarded to the
                self-attention layers.
            last_layer: if True, also return the final normalised pair
                representation (channel-first).

        Returns:
            The sigmoid contact map of the first batch element only (the
            batch axis is indexed with ``[0]`` after moving the single output
            channel to axis 1), or a ``(contact, representation)`` tuple when
            ``last_layer`` is True.
        """
        intra1_2d = self.intra_resnet(intra1_1d, intra1_2d)
        intra2_2d = self.intra_resnet(intra2_1d, intra2_2d)
        inter_2d = self.inter_resnet(inter_2d)

        # channel-first -> channel-last for the LayerNorm/Linear based blocks
        intra1_2d = intra1_2d.permute(0, 2, 3, 1)
        intra2_2d = intra2_2d.permute(0, 2, 3, 1)
        inter_2d = inter_2d.permute(0, 2, 3, 1)

        for block in range(self.num_blocks):
            inter_2d = inter_2d + self.drop(self.InterTriangleMulti_R[block](inter_2d, intra1_2d)) + self.drop(self.InterTriangleMulti_L[block](inter_2d, intra2_2d, transpose=True))
            inter_2d = inter_2d + self.drop(self.InterCrossAttn_R[block](inter_2d, intra1_2d)) + self.drop(self.InterCrossAttn_L[block](inter_2d, intra2_2d, transpose=True))
            inter_2d = inter_2d + self.drop(self.InterSelfAttn_R[block](inter_2d, intra2_dist)) + self.drop(self.InterSelfAttn_L[block](inter_2d, intra1_dist, transpose=True))
            inter_2d = inter_2d + self.Transition[block](inter_2d)

        inter_2d = self.norm_final(inter_2d)
        if last_layer:
            representation = inter_2d.permute(0,3,1,2)
        inter_2d = self.Linear_final(inter_2d)
        inter_contact = self.act_final(inter_2d)

        if last_layer:
            return inter_contact.permute(0,3,1,2)[0], representation

        return inter_contact.permute(0,3,1,2)[0]

# Patch scaffolding, preserved verbatim as a comment:
# diff --git a/pplm_contact/predict.py b/pplm_contact/predict.py new file mode 100644 index 0000000..941e8dc --- /dev/null +++ b/pplm_contact/predict.py @@ -0,0 +1,353 @@
import os
import sys
import pathlib
import numpy as np
import pickle
import torch
import subprocess
import argparse
from model import PPLM_Contact
from utils import extract_seq_and_dist_map, pairing_msa, RBF
from LoadHHM import load_hmm
from config import *


# Non-standard residue names rewritten while cleaning an input PDB. Applied in
# order to every ATOM line, anywhere in the line (as the former
# `sed 's/OLD/NEW/g'` chain did). TYS is sulfotyrosine and is therefore mapped
# to TYR; the former mapping to TRP (tryptophan) was a typo.
_PDB_RESNAME_FIXES = (
    (b"MEX", b"CYS"),
    (b"HID", b"HIS"),
    (b"HIE", b"HIS"),
    (b"HIP", b"HIS"),
    (b"MSE", b"MET"),
    (b"ASX", b"ASN"),
    (b"GLX", b"GLN"),
    (b"TYS", b"TYR"),
)


def _clean_pdb(src_pdb, dst_pdb):
    """Copy the lines of *src_pdb* that start with ``ATOM`` to *dst_pdb*,
    rewriting the residue names listed in ``_PDB_RESNAME_FIXES``.

    Pure-Python replacement for the former ``grep "^ATOM" ... | sed ...``
    shell pipeline: paths containing spaces or shell metacharacters are
    handled safely, and a missing input raises FileNotFoundError instead of
    silently producing an empty output file.
    """
    with open(src_pdb, "rb") as fin, open(dst_pdb, "wb") as fout:
        for line in fin:
            if not line.startswith(b"ATOM"):
                continue
            for old, new in _PDB_RESNAME_FIXES:
                line = line.replace(old, new)
            fout.write(line)


def _run_hhblits(query_seq, out_a3m, n_cpu):
    """Run hhblits for *query_seq* against ``UniRef_database`` and write the
    alignment to *out_a3m*.

    The command is passed as an argument list (no shell), with the same
    options as before: ``-n 3 -e 0.001 -id 99 -cov 0.4``. Raises
    subprocess.CalledProcessError on a non-zero exit status.
    """
    subprocess.run(
        [hhblits,
         "-i", str(query_seq),
         "-d", UniRef_database,
         "-cpu", str(n_cpu),
         "-oa3m", str(out_a3m),
         "-n", "3", "-e", "0.001", "-id", "99", "-cov", "0.4"],
        check=True,
    )


def main():
    """Command-line entry point of the contact-prediction pipeline.

    Parses the two input PDB paths, the output folder, ``--n_cpu`` and
    ``--gpu_id``, lets ``define_param`` set the module-level path/mode
    globals, then runs the numbered steps below in order.
    """
    parser = argparse.ArgumentParser(description="Protein-Protein Contact Prediction",
                                     epilog="v0.0.1")

    parser.add_argument("pdbA_path",
                        type=pathlib.Path,
                        help="Location of pdb A")

    parser.add_argument("pdbB_path",
                        type=pathlib.Path,
                        help="Location of pdb B")

    parser.add_argument("output_folder",
                        type=pathlib.Path,
                        help="Location to store output files")

    parser.add_argument("--n_cpu",
                        type=int,
                        default=8,
                        help="Number of CPU cores for search MSA",
                        )

    parser.add_argument("--gpu_id",
                        type=int,
                        default=0,
                        help="gpu device specified",
                        )

    args = parser.parse_args()

    ### Define parameters ###
    define_param(args)

    ##### Step 0: Process the input pdb (clean pdb & extract sequence & derive monomer distance map) #####
    print("========== Step 0: Processing the intput pdb (Clean pdb, extract seq and dist map) ==========")

    _clean_pdb(args.pdbA_path, target1_pdb)
    target1_res_idx_type = extract_seq_and_dist_map(target1_pdb, target1_seq, target1_monomer_dist)
    if mode == "hetero":
        _clean_pdb(args.pdbB_path, target2_pdb)
        target2_res_idx_type = extract_seq_and_dist_map(target2_pdb, target2_seq, target2_monomer_dist)
    else:
        # non-hetero mode: chain B reuses chain A's residue table
        target2_res_idx_type = target1_res_idx_type

    ##### Step 1: Search MSA #####
    print("========== Step 1: Searching MSA ==========")
    # an existing alignment file is reused rather than searched again
    if not os.path.isfile(target1_msa):
        print("Searching MSA for", target1_seq)
        _run_hhblits(target1_seq, target1_msa, args.n_cpu)
    if mode == "hetero" and not os.path.isfile(target2_msa):
        print("Searching MSA for", target2_seq)
        _run_hhblits(target2_seq, target2_msa, args.n_cpu)

    ##### Step 2: Extract Monomer MSA features
    print("========== Step 2: Extract Monomer MSA features ==========")
    if os.path.isfile(target1_msa):
        extract_MSA_features(target1, target1_msa, target1_hhm, target1_aln, target1_dca_di, target1_dca_apc, target1_esm_msa)
    if mode == "hetero" and os.path.isfile(target2_msa):
        extract_MSA_features(target2, target2_msa, target2_hhm, target2_aln, target2_dca_di, target2_dca_apc, target2_esm_msa)

    ##### Step 3: Extract paired MSA features
    print("========== Step 3: Extract paired MSA features ==========")
    if mode == "hetero":
        pairing_msa(target1_msa, target2_msa, paired_msa)
        extract_MSA_features(target, paired_msa, paired_hhm, paired_aln, paired_dca_di, paired_dca_apc, paired_esm_msa)

    ##### Step 4: Generate PPLM
inter-protein attention matrix + print("========== Step 4: Generate PPLM features ==========") + if not os.path.isfile(pplm_feat): + get_pplm_features(target1_seq, target2_seq, pplm_feat, device='cpu') + + ##### Step 5: Collect all features + print("========== Step 5: Collect all features ==========") + feats = collect_all_features() + + ##### Step 6: Predict inter-protein contact + print("========== Step 6: Predict inter-protein contact ==========") + if not os.path.isfile(pred_contact_pkl_path) or not os.path.isfile(pred_contact_txt_path): + pred_inter_contatct = predict_contact(feats, mode, device=device) + + with open(pred_contact_pkl_path, "wb") as fw: + pickle.dump(pred_inter_contatct, fw) + + with open(pred_contact_txt_path, "w") as fw: + prediction_idx_prob = [] + for i in range(pred_inter_contatct.shape[0]): + for j in range(pred_inter_contatct.shape[1]): + prediction_idx_prob.append([i, j, pred_inter_contatct[i, j]]) + + prediction_idx_prob = sorted(prediction_idx_prob, key=lambda x: x[2], reverse=True) + + data = "{:<10}".format("Rank") + "{:<10}".format("ResIdx1") + "{:<10}".format("ResType1") + "{:<10}".format("ResIdx2") + "{:<10}".format("ResType2") + "{:<20}".format("Contact_Probability") + "\n" + fw.write(data) + for k in range(len(prediction_idx_prob)): + res1 = prediction_idx_prob[k][0] + res2 = prediction_idx_prob[k][1] + prob = prediction_idx_prob[k][2] + res1_idx = target1_res_idx_type[res1][0] + res1_type = target1_res_idx_type[res1][1] + res2_idx = target2_res_idx_type[res2][0] + res2_type = target2_res_idx_type[res2][1] + + data = "{:<10}".format(k+1) + "{:<10}".format(str(res1_idx) + ":A") + "{:<10}".format(res1_type) + "{:<10}".format(str(res2_idx) + ":B") + "{:<10}".format(res2_type) + "{:<10}".format(f"{prob:.6g}") + "\n" + fw.write(data) + + + +def define_param(args): + # param = {} + global target1, target2, target, workspace, mode, device + global target1_pdb, target1_seq, target1_monomer_dist, target2_pdb, target2_seq + global 
def extract_MSA_features(name, msa, hhm, aln, dci_di, dci_apc, esm_msa_path):
    """Derive HHM, plain-alignment, DCA and ESM-MSA feature files from one MSA.

    Args:
        name: base name used for the intermediate ``<name>.fas`` file in the workspace.
        msa: input alignment (``.a3m``).
        hhm: output path for the profile written by ``hhmake``.
        aln: output path for the header-free alignment.
        dci_di: output path for the CCMpred run with ``-R -A``.
        dci_apc: output path for the CCMpred run with ``-R``.
        esm_msa_path: output pickle written by the ESM-MSA script.

    Raises:
        subprocess.CalledProcessError: if any external tool exits non-zero.
    """
    # Local import: only needed to quote paths that go through the shell.
    import shlex
    quote = shlex.quote

    fas = os.path.join(workspace, name + ".fas")

    print("Generating the HHM file for", msa)
    subprocess.run(hhmake + " -i " + quote(str(msa)) + " -o " + quote(str(hhm)),
                   shell=True, check=True)
    subprocess.run(reformat + " " + quote(str(msa)) + " " + quote(fas) + " -r -l 2000 >/dev/null",
                   shell=True, check=True)
    # Keep only sequence rows (drop every '>' header line).
    subprocess.run("awk '{if(!($0~/^>/)){print}}' " + quote(fas) + " > " + quote(str(aln)),
                   shell=True, check=True)

    # generate the DCA features (both files are regenerated if either is missing)
    print("Generating the DCA features for", aln)
    if not os.path.isfile(dci_di) or not os.path.isfile(dci_apc):
        subprocess.run(ccmpred + " " + quote(str(aln)) + " " + quote(str(dci_di)) + " -R -A",
                       shell=True, check=True)
        subprocess.run(ccmpred + " " + quote(str(aln)) + " " + quote(str(dci_apc)) + " -R",
                       shell=True, check=True)

    # generate the ESM-MSA features
    print("Generating the ESM-MSA features for", msa)
    if not os.path.isfile(esm_msa_path):
        filtered = str(msa) + "_filtered"
        subprocess.run(hhfilter + " -i " + quote(str(msa)) + " -o " + quote(filtered) + " -diff 512",
                       shell=True, check=True)
        # check=True: fail here, like every other step, instead of later with
        # a confusing "file not found" when the pickle is loaded.
        subprocess.run(["python", esm_msa, esm_msa_model, filtered, esm_msa_path], check=True)
'pplm/models/', 'pplm_t33_650M.pt') + + ##### Loading PPLM Model ##### + alphabet = Alphabet.from_architecture() + batch_converter = alphabet.get_batch_converter() + model_data = torch.load(model_location, map_location="cpu") + model_param = model_data["param"] + model_state = model_data["model"] + + model = PPLM( + num_layers=model_param['encoder_layers'], + embed_dim=model_param['encoder_embed_dim'], + attention_heads=model_param['encoder_attention_heads'], + token_dropout=False, + alphabet=alphabet + ) + model.to(device) + model.load_state_dict(model_state, strict=False) + + with torch.no_grad(): + seqA, seqB = '', '' + for line in open(seqA_path, "r").readlines(): + if not line.startswith(">"): + seqA += line.strip() + for line in open(seqB_path, "r").readlines(): + if not line.startswith(">"): + seqB += line.strip() + + seqA_labels, seqA_strs, seqA_tokens = batch_converter([('seqA', seqA)]) + seqB_labels, seqB_strs, seqB_tokens = batch_converter([('seqB', seqB)]) + tokens = torch.cat([seqA_tokens, seqB_tokens], dim=-1).to(device) + + inter_chain_mask = torch.ones((len(seqA) + 2 + len(seqB) + 2, len(seqA) + 2 + len(seqB) + 2), device=device) + inter_chain_mask[:len(seqA) + 2, :len(seqA) + 2] = 0 + inter_chain_mask[len(seqA) + 2:, len(seqA) + 2:] = 0 + + ##### running PPLM ##### + out = model(tokens, inter_chain_mask, repr_layers=[33], need_head_weights=True, return_contacts=False) + + # attn_AA = out['attentions'].squeeze()[:, :, 1:(len(seqA) + 1), 1:(len(seqA) + 1)].reshape(33 * 20, len(seqA), len(seqA)).cpu().numpy() + attn_AB = out['attentions'].squeeze()[:, :, 1:(len(seqA) + 1), -(len(seqB) + 1):-1].reshape(33 * 20, len(seqA), len(seqB)).cpu().numpy() + attn_BA = out['attentions'].squeeze()[:, :, -(len(seqB) + 1):-1, 1:(len(seqA) + 1)].reshape(33 * 20, len(seqB), len(seqA)).cpu().numpy() + # attn_BB = out['attentions'].squeeze()[:, :, -(len(seqB) + 1):-1, -(len(seqB) + 1):-1].reshape(33 * 20, len(seqB), len(seqB)).cpu().numpy() + inter_attn = (attn_AB + 
def _load_monomer_features(dist_pkl, dca_di_path, dca_apc_path, hhm_path, esm_pkl):
    """Load one chain's feature files and assemble its 1D/2D input arrays."""
    # NOTE: pickle is only safe on files this pipeline wrote itself.
    with open(dist_pkl, "rb") as fr:
        monomer_dist = pickle.load(fr)

    dca_di = np.expand_dims(np.loadtxt(dca_di_path), 0)
    dca_apc = np.expand_dims(np.loadtxt(dca_apc_path), 0)
    pssm = load_hmm(hhm_path)['PSSM']

    with open(esm_pkl, "rb") as fr:
        esm_msa_data = pickle.load(fr)
    esm_1d = esm_msa_data['esm_msa_1d']
    esm_2d = esm_msa_data['row_attentions']

    dist_rbf = RBF(monomer_dist)  # computed once, reused for the log line and the stack
    print(dca_di.shape, dca_apc.shape, esm_2d.shape, dist_rbf.shape, monomer_dist.shape)

    feats_1d = np.concatenate([pssm, esm_1d], axis=-1).transpose(1, 0)
    feats_2d = np.concatenate([dca_di, dca_apc, esm_2d, dist_rbf], axis=0)
    return feats_1d, feats_2d, monomer_dist, (dca_di, dca_apc, esm_2d)


def collect_all_features():
    """Assemble the model inputs from the feature files in the workspace.

    Reads the module-level path globals set by ``define_param`` and ``mode``.

    Returns:
        dict with keys ``intra1_1d``, ``intra1_2d``, ``intra1_Mdist``,
        ``intra2_1d``, ``intra2_2d``, ``intra2_Mdist`` and ``inter_2d``.
        In "homo" mode the chain-2 entries are the chain-1 arrays themselves.
    """
    intra1_1d, intra1_2d, intra1_Mdist, (dca_di_1, dca_apc_1, esm_2d_1) = _load_monomer_features(
        target1_monomer_dist, target1_dca_di, target1_dca_apc, target1_hhm, target1_esm_msa)

    with open(pplm_feat, "rb") as fr:
        inter_pplm_attn = pickle.load(fr)

    if mode == "homo":
        intra2_1d = intra1_1d
        intra2_2d = intra1_2d
        intra2_Mdist = intra1_Mdist
        # For a homo-oligomer the monomer co-evolution maps double as inter-chain maps.
        inter_2d = np.concatenate([dca_di_1, dca_apc_1, esm_2d_1, inter_pplm_attn], axis=0)
    else:
        intra2_1d, intra2_2d, intra2_Mdist, _ = _load_monomer_features(
            target2_monomer_dist, target2_dca_di, target2_dca_apc, target2_hhm, target2_esm_msa)

        inter_DCA_DI = np.expand_dims(np.loadtxt(paired_dca_di), 0)
        inter_DCA_APC = np.expand_dims(np.loadtxt(paired_dca_apc), 0)
        with open(paired_esm_msa, "rb") as fr:
            inter_esm_msa_2d = pickle.load(fr)['row_attentions']

        # Both chain lengths come from the PPLM map, whose last two axes are
        # (chain 1, chain 2). The previous code took len2 from the paired ESM
        # map, which only worked because NumPy clips over-long slices.
        len1 = inter_pplm_attn.shape[-2]
        len2 = inter_pplm_attn.shape[-1]

        print(inter_DCA_DI.shape, inter_DCA_APC.shape, inter_esm_msa_2d.shape, inter_pplm_attn.shape)
        # Upper-right block of each paired map = chain-1 rows x chain-2 columns.
        rows, cols = slice(None, len1), slice(len1, len1 + len2)
        inter_2d = np.concatenate([inter_DCA_DI[:, rows, cols],
                                   inter_DCA_APC[:, rows, cols],
                                   inter_esm_msa_2d[:, rows, cols],
                                   inter_pplm_attn], axis=0)

    feats = {"intra1_1d": intra1_1d, "intra1_2d": intra1_2d, "intra1_Mdist": intra1_Mdist,
             "intra2_1d": intra2_1d, "intra2_2d": intra2_2d, "intra2_Mdist": intra2_Mdist,
             "inter_2d": inter_2d}

    return feats
import numpy as np
import pickle
import string

# Three-letter -> one-letter residue codes for the 20 standard amino acids.
# Written out explicitly: the previous zip()-based table paired an
# alphabetical name list with a differently ordered letter string, which
# swapped GLN and GLU (GLN -> 'E', GLU -> 'Q').
restype_3to1 = {
    'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C',
    'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I',
    'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', 'PRO': 'P',
    'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
}

heavy_atoms = ['C', 'CA', 'CB', 'CD', 'CD1', 'CD2', 'CE', 'CE1', 'CE2', 'CE3', 'CG', 'CG1', 'CG2', 'CH2', 'CZ', 'CZ2', 'CZ3', 'N', 'ND1', 'ND2', 'NE', 'NE1', 'NE2', 'NH1', 'NH2', 'NZ', 'O', 'OD1', 'OD2', 'OE1', 'OE2', 'OG', 'OG1', 'OH', 'OXT', 'SD', 'SG']

# Set form of ``heavy_atoms`` for O(1) membership tests.
_HEAVY_ATOM_SET = frozenset(heavy_atoms)


def pdb2seq(pdb_path, seq_path):
    """Write the first chain's sequence of a PDB file as FASTA and return it.

    One residue is emitted per CA atom. A warning is printed when the file
    holds more than one chain; only the first chain is written and returned.

    Args:
        pdb_path: input PDB file.
        seq_path: output FASTA path (header ``>seq_<chain> <length>``).

    Returns:
        The one-letter sequence of the first chain.

    Raises:
        KeyError: for a residue name missing from ``restype_3to1``.
        ValueError: if the file contains no CA atoms.
    """
    chains = []  # [chain_id, [one-letter codes]] in file order
    with open(pdb_path, 'r') as fr:
        for line in fr:
            # Atom name occupies PDB columns 13-16, i.e. slice [12:16].
            if not line.startswith("ATOM") or line[12:16].strip() != 'CA':
                continue
            chain_id = line[21]
            letter = restype_3to1[line[17:20].strip()]
            if chains and chains[-1][0] == chain_id:
                chains[-1][1].append(letter)
            else:
                chains.append([chain_id, [letter]])

    if not chains:
        raise ValueError("no CA atoms found in " + str(pdb_path))

    if len(chains) > 1:
        print("Warning:", pdb_path, "contain multiple chains:", [chain[0] for chain in chains], "! Only first chain is considered.")

    chain_id, sequence = chains[0][0], ''.join(chains[0][1])
    # Write before returning (the old code returned first, so the file was never created).
    with open(seq_path, 'w') as fw:
        fw.write(">seq_" + chain_id + " " + str(len(sequence)) + "\n")
        fw.write(sequence + "\n")

    return sequence


def _min_heavy_atom_distances(residue_atoms):
    """Return the (L, L) matrix of minimum heavy-atom distances between residues.

    ``residue_atoms`` is a list of ``{atom_name: [x, y, z]}`` dicts. Pairs that
    involve a residue without any listed heavy atom stay at ``inf``.
    """
    length = len(residue_atoms)
    dist = np.full((length, length), np.inf)

    per_res = [
        np.array([xyz for atom, xyz in atoms.items() if atom in _HEAVY_ATOM_SET],
                 dtype=float).reshape(-1, 3)
        for atoms in residue_atoms
    ]
    occupied = np.array([i for i, xyz in enumerate(per_res) if len(xyz)], dtype=int)
    if occupied.size == 0:
        return dist

    all_xyz = np.concatenate([per_res[i] for i in occupied], axis=0)
    counts = np.array([len(per_res[i]) for i in occupied])
    starts = np.concatenate(([0], np.cumsum(counts)[:-1]))  # first atom row of each residue

    for i in occupied:
        # Distances from every atom of residue i to every heavy atom in the
        # structure, reduced first over residue i's atoms, then per target residue.
        delta = per_res[i][:, None, :] - all_xyz[None, :, :]
        nearest = np.sqrt((delta ** 2).sum(axis=-1)).min(axis=0)
        dist[i, occupied] = np.minimum.reduceat(nearest, starts)
    return dist


def extract_seq_and_dist_map(pdb_path, seq_path, dist_path):
    """Extract the sequence and the heavy-atom distance map from a PDB file.

    Consecutive ATOM records belong to one residue while chain ID, residue
    number and insertion code are all unchanged.

    Args:
        pdb_path: input PDB file (ATOM records only are read).
        seq_path: output FASTA path (header ``>seq <length>``).
        dist_path: output pickle holding a float array of shape (1, L, L) with
            the minimum distance between the listed heavy atoms of each pair.

    Returns:
        ``[[residue_number, residue_name], ...]``, one entry per residue.

    Raises:
        KeyError: for a residue name missing from ``restype_3to1``.
    """
    residues = []       # per-residue {atom_name: [x, y, z]}
    res_idx_type = []
    current_key = None  # None sentinel: residue numbers may legitimately be -1
    with open(pdb_path, 'r') as fr:
        for line in fr:
            if not line.startswith("ATOM"):
                continue
            res_name = line[17:20].strip()
            res_idx = int(line[22:26])
            key = (line[21], res_idx, line[26:27])  # chain, number, insertion code
            if key != current_key:
                residues.append({})
                res_idx_type.append([res_idx, res_name])
                current_key = key
            # Full 4-column atom name: with [13:16] a hydrogen written as
            # "HN  " would be read as backbone "N" and overwrite it.
            residues[-1][line[12:16].strip()] = [float(line[30:38]),
                                                 float(line[38:46]),
                                                 float(line[46:54])]

    sequence = ''.join(restype_3to1[res_name] for _, res_name in res_idx_type)
    length = len(sequence)

    heavy_atom_dist_map = _min_heavy_atom_distances(residues)[None, :, :]

    with open(dist_path, mode='wb') as fw:
        pickle.dump(heavy_atom_dist_map, fw)

    with open(seq_path, 'w') as fw:
        fw.write(">seq " + str(length) + "\n")
        fw.write(sequence + "\n")

    return res_idx_type


def pairing_msa(msa1_path, msa2_path, paired_msa_path):
    """Pair two monomer MSAs by taxonomy ID and write the concatenated alignment.

    The first record (``>target <length>``) is the two query sequences joined;
    every further record (``>seq<k>``) joins the best-matching sequences of one
    shared species.
    """
    msas1, sid1 = extract_taxid(msa1_path)
    msas2, sid2 = extract_taxid(msa2_path)
    aligns = alignment(msas1, sid1, msas2, sid2, top=True)

    with open(paired_msa_path, 'w') as f:
        f.write(">target " + str(len(aligns[0])) + "\n")
        f.write(aligns[0] + "\n")
        for idx, aligned_seq in enumerate(aligns[1:], start=1):
            f.write(">seq" + str(idx) + "\n")
            f.write(aligned_seq + "\n")


def _parse_taxid(header):
    """Return the integer after ``TaxID=`` (preferred) or ``OX=``, else 0."""
    for tag in ("TaxID=", "OX="):
        if tag in header:
            fields = header.split(tag)[1].split()
            try:
                return int(fields[0])
            except (IndexError, ValueError):
                # Tag present but empty or non-numeric: treat as unknown species.
                return 0
    return 0


def extract_taxid(file, gap_cutoff=0.8):
    """Read an a3m alignment and return its sequences with their taxonomy IDs.

    Lowercase letters, '.' and '*' are stripped from every sequence. Rows with
    a gap fraction above ``gap_cutoff`` (relative to the query length) are dropped.

    Args:
        file: path of the a3m file; the first record is the query.
        gap_cutoff: maximum allowed fraction of '-' characters per row.

    Returns:
        ``(msas, sid)``: list of sequences (query first) and an integer array
        of taxonomy IDs of the same length (0 = query or unknown).

    Raises:
        ValueError: if the file has no records or the query is empty.
    """
    translation = str.maketrans(dict.fromkeys(string.ascii_lowercase + ".*"))

    # Pair every header with its own sequence line(s) so that headers and
    # sequences can never drift out of step.
    records = []
    with open(file, 'r') as fr:
        for raw in fr:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            if line.startswith(">"):
                records.append([line, []])
            elif records:
                records[-1][1].append(line)

    if not records:
        raise ValueError("no sequences found in " + str(file))

    query = ''.join(records[0][1]).translate(translation)
    seq_len = len(query)
    if seq_len == 0:
        raise ValueError("empty query sequence in " + str(file))

    msas = [query]
    sid = [0]
    for header, chunks in records[1:]:
        if not chunks:
            continue  # header without a sequence
        seq = ''.join(chunks).translate(translation)
        if seq.count('-') / seq_len <= gap_cutoff:
            msas.append(seq)
            sid.append(_parse_taxid(header))

    return msas, np.array(sid)


def cal_identity(query, sub_msas):
    """Return the fraction of query positions each sequence matches.

    Args:
        query : str
        sub_msas : List[str]
    Return:
        identity : np.array of floats, one per sequence. Positions beyond the
        end of a shorter sequence count as mismatches.
    """
    seq_len = len(query)
    identity = np.zeros(len(sub_msas))
    if seq_len == 0:
        return identity
    for idx, seq in enumerate(sub_msas):
        identity[idx] = sum(a == b for a, b in zip(query, seq)) / seq_len
    return identity


def alignment(msas1, sid1, msas2, sid2, top=True):
    """Concatenate sequences of two MSAs that share a taxonomy ID.

    Args:
        msas1, msas2: sequences of each MSA, query first.
        sid1, sid2: taxonomy IDs aligned with ``msas1`` / ``msas2``.
        top: if true, pair only the most query-like sequence of each species;
            otherwise pair the k-th best of one MSA with the k-th best of the
            other for as many ranks as both have.

    Returns:
        List of concatenated rows; the first is ``msas1[0] + msas2[0]``.
    """
    sid1 = np.asarray(sid1)
    sid2 = np.asarray(sid2)

    # Shared species, dropping 0 by value (0 marks the query / unknown species).
    shared = np.intersect1d(sid1, sid2)
    shared = shared[shared != 0]

    query1 = msas1[0]
    query2 = msas2[0]
    aligns = [query1 + query2]

    for taxid in shared:
        sub_msas1 = [msas1[idx] for idx in np.where(sid1 == taxid)[0]]
        order1 = np.argsort(-cal_identity(query1, sub_msas1))

        sub_msas2 = [msas2[idx] for idx in np.where(sid2 == taxid)[0]]
        order2 = np.argsort(-cal_identity(query2, sub_msas2))

        n_pairs = 1 if top else min(len(sub_msas1), len(sub_msas2))
        for rank in range(n_pairs):
            aligns.append(sub_msas1[order1[rank]] + sub_msas2[order2[rank]])

    return aligns


def RBF(dist_map, D_min=2., D_max=22., D_count=64):
    """Expand a distance map with Gaussian radial basis functions.

    Args:
        dist_map: array of shape (1, L1, L2).
        D_min, D_max, D_count: range and number of evenly spaced centres.

    Returns:
        Array of shape (D_count, L1, L2) holding
        ``exp(-((d - mu_k) / sigma) ** 2)`` with ``sigma = (D_max - D_min) / D_count``.
    """
    centers = np.linspace(D_min, D_max, D_count)[None, :]
    sigma = (D_max - D_min) / D_count

    # Move the singleton channel axis last so it broadcasts against the centres.
    dist_map = dist_map.transpose(1, 2, 0)
    rbf = np.exp(-((dist_map - centers) / sigma) ** 2)

    return rbf.transpose(2, 0, 1)