mirror of
https://github.com/jertubiana/ScanNet.git
synced 2026-06-04 13:44:22 +08:00
- Added scripts for visualizing 3D filters. - Added scripts for visualizing 3D neighborhoods. - Changed predict_feature.py API (added single-input support).
818 lines
35 KiB
Python
818 lines
35 KiB
Python
import os
|
|
cores = 5  # Set number of CPUs to use!

if __name__ == '__main__':
    # Cap the thread pools of the numerical backends to `cores` threads.
    # Done here, before the numpy import further down, presumably because these
    # libraries read the variables when they are loaded.
    os.environ.update({
        variable: "%s" % cores
        for variable in (
            "MKL_NUM_THREADS",
            "NUMEXPR_NUM_THREADS",
            "OMP_NUM_THREADS",
            "OPENBLAS_NUM_THREADS",
            "VECLIB_MAXIMUM_THREADS",
            'NUMBA_DEFAULT_NUM_THREADS',
            "NUMBA_NUM_THREADS",
        )
    })
|
|
|
|
from preprocessing import pipelines,PDB_processing,sequence_utils,PDBio
|
|
from utilities import wrappers, chimera
|
|
import numpy as np
|
|
import argparse
|
|
from keras.models import Model
|
|
from utilities.paths import structures_folder,MSA_folder,predictions_folder,path2hhblits,path2sequence_database,model_folder
|
|
|
|
|
|
# Featurization pipelines. They share every setting except the amino-acid
# features: 'pwm' (used together with MSAs) vs. 'sequence' (no MSA needed).
_shared_pipeline_options = dict(
    with_aa=True,
    with_atom=True,
    atom_features='valency',
    aa_frames='triplet_sidechain',
    Beff=500,
)
pipeline_MSA = pipelines.ScanNetPipeline(aa_features='pwm', **_shared_pipeline_options)
pipeline_noMSA = pipelines.ScanNetPipeline(aa_features='sequence', **_shared_pipeline_options)


def _ensemble_names(prefix, n_models=5):
    """Return the names '<prefix>_0' ... '<prefix>_<n_models - 1>' of an ensemble of models."""
    return ['%s_%s' % (prefix, index) for index in range(n_models)]


# Each task: (name used to label outputs, model file name(s) loaded from the model folder).
# A list of model files is an ensemble (one model per fold).
interface_model_name_MSA, interface_model_MSA = 'ScanNet_interface', 'ScanNet_PPI'
interface_model_name_noMSA, interface_model_noMSA = 'ScanNet_interface_noMSA', 'ScanNet_PPI_noMSA'

epitope_model_name_MSA, epitope_model_MSA = 'ScanNet_epitope', _ensemble_names('ScanNet_PAI')
epitope_model_name_noMSA, epitope_model_noMSA = 'ScanNet_epitope_noMSA', _ensemble_names('ScanNet_PAI_noMSA')

idp_model_name_MSA, idp_model_MSA = 'ScanNet_idp', _ensemble_names('ScanNet_PIDPI')
idp_model_name_noMSA, idp_model_noMSA = 'ScanNet_idp_noMSA', _ensemble_names('ScanNet_PIDPI_noMSA')

# All tasks load their models from the same folder.
interface_model_folder = epitope_model_folder = idp_model_folder = model_folder

# Defaults used by predict_interface_residues: interface prediction with MSAs.
default_pipeline = pipeline_MSA
default_model = interface_model_MSA
default_model_name = interface_model_name_MSA
|
|
|
|
def write_predictions(csv_file, residue_ids, sequence, interface_prediction):
    """Save per-residue predictions as a comma-separated table.

    Args:
        csv_file: path of the CSV file to (over)write.
        residue_ids: one row per residue; its first three fields are written
            under the 'Model', 'Chain' and 'Residue Index' headers.
        sequence: indexable residue sequence, one entry per row.
        interface_prediction: numpy array indexed by residue. A 1-D array gives
            a single 'Binding site probability' column; otherwise one
            'Output <n>' column is written per entry of the last axis.
            Values are formatted with three decimals.
    """
    n_residues = len(residue_ids)
    is_single_column = interface_prediction.ndim == 1
    if is_single_column:
        value_headers = ['Binding site probability']
    else:
        value_headers = ['Output %s' % n for n in range(interface_prediction.shape[-1])]
    header = ','.join(['Model', 'Chain', 'Residue Index', 'Sequence'] + value_headers)

    with open(csv_file, 'w') as handle:
        handle.write(header + '\n')
        for k in range(n_residues):
            residue_id = residue_ids[k]
            if is_single_column:
                values = '%.3f' % interface_prediction[k]
            else:
                values = ','.join(['%.3f' % value for value in interface_prediction[k]])
            handle.write('%s,%s,%s,%s,%s\n' % (residue_id[0], residue_id[1], residue_id[2], sequence[k], values))
    return
|
|
|
|
|
|
def _as_query_list(queries):
    """Return `queries` as a new list of queries, wrapping a single query in a list.

    Keeps the historical heuristic: `queries` is already a list of queries when
    its first element has length > 1. A bare string (PDB id, path or sequence)
    has a one-character first element and is therefore wrapped. Implemented
    without `assert` so that it also works under `python -O`.
    """
    try:
        is_list_of_queries = len(queries[0]) > 1
    except Exception:  # e.g. empty input or elements without len(): wrap, as before
        is_list_of_queries = False
    return list(queries) if is_list_of_queries else [queries]


def _discard_index(index, *parallel_lists):
    """Delete position `index` from every list in `parallel_lists` (None entries are skipped).

    Used to drop one query from all per-query lists at once so that they stay aligned.
    """
    for values in parallel_lists:
        if values is not None:
            del values[index]


def _stack_or_object_array(arrays):
    """Return np.array(arrays); for ragged input, return a 1-D object array of the arrays.

    NumPy < 1.24 built such an object array implicitly; newer versions raise ValueError.
    """
    try:
        return np.array(arrays)
    except ValueError:
        stacked = np.empty(len(arrays), dtype=object)
        for k, array in enumerate(arrays):
            stacked[k] = array
        return stacked


def _truncate_model_at_layer(model_obj, layer_name):
    """Replace model_obj.model by a Keras Model whose output is layer `layer_name`.

    For 'attention_layer' the second output of that layer (index 1) is used.
    """
    layer_output = model_obj.model.get_layer(layer_name).output
    if layer_name == 'attention_layer':
        layer_output = layer_output[1]
    model_obj.model = Model(inputs=model_obj.model.inputs, outputs=layer_output)


def _average_model_predictions(model_objs, inputs, return_all):
    """Average, element-wise, the predictions of several models (ensemble) on the same inputs."""
    summed = model_objs[0].predict(inputs, batch_size=1, return_all=return_all)
    for other_model in model_objs[1:]:
        other_predictions = other_model.predict(inputs, batch_size=1, return_all=return_all)
        summed = [prediction1 + prediction2 for prediction1, prediction2 in zip(summed, other_predictions)]
    return _stack_or_object_array([prediction / len(model_objs) for prediction in summed])


def _aggregate_attention(inputs, attention_coeffs, K_graph, logfile=None):
    """Aggregate attention coefficients into one score per residue (node importance).

    1. Recompute, for each amino acid, the indices of its K_graph nearest
       neighbours from the Calpha coordinates (inputs[3] indexed by the first
       column of inputs[0]).
    2. Sum, for each amino acid, the positive part (after sign correction) of the
       head-averaged attention that all other amino acids give to it.
    Returns one aggregated vector per entry of `attention_coeffs`.
    """
    calpha_coordinates = [
        # `int` instead of `np.int`: the alias was removed in NumPy 1.24.
        inputs[3][n].astype(np.float32)[inputs[0][n].astype(int)[..., 0]]
        for n in range(len(inputs[0]))
    ]
    neighborhood_graphs = [
        np.argsort(PDB_processing.distance(coordinates, coordinates), axis=1)[:, :K_graph]
        for coordinates in calpha_coordinates
    ]
    sign = np.sign(attention_coeffs[0][:, 0, 0]).mean()
    if sign < 0:
        print('Warning, attention coeffs are flipped', file=logfile)
    aggregated_attention_coeffs = []
    for attention_coeff, neighborhood_graph in zip(attention_coeffs, neighborhood_graphs):
        aggregated = np.zeros(len(attention_coeff), dtype=np.float32)
        for s in range(len(attention_coeff)):
            # Attention coefficient has size [N_aa, K_graph, nheads]: average over heads.
            aggregated[neighborhood_graph[s]] += np.maximum(
                sign * attention_coeff[s][:len(neighborhood_graph[s])], 0).mean(-1)
        aggregated_attention_coeffs.append(aggregated)
    return _stack_or_object_array(aggregated_attention_coeffs)


def _output_file_names(query_output_folder, query_name, layer_name, file_is_cif):
    """Return the (csv, chimera script, annotated structure) paths for one output layer."""
    extension = '.cif' if file_is_cif else '.pdb'
    if layer_name is None:
        return (query_output_folder + 'predictions_' + query_name + '.csv',
                query_output_folder + 'chimera_' + query_name,
                query_output_folder + 'annotated_' + query_name + extension)
    return (query_output_folder + 'activity_%s_' % layer_name + query_name + '.csv',
            query_output_folder + 'chimera_%s' % layer_name + query_name,
            query_output_folder + 'annotated_%s' % layer_name + query_name + extension)


def predict_interface_residues(
        query_pdbs='1a3x',
        query_chain_ids=None,
        query_sequences=None,
        query_names=None,
        pipeline=default_pipeline,
        model=default_model,
        model_name=default_model_name,
        model_folder=model_folder,
        structures_folder=structures_folder,
        predictions_folder=predictions_folder,
        query_MSAs=None,
        query_PWMs=None,
        MSA_folder=MSA_folder,
        logfile=None,
        biounit=True,
        assembly=True,
        layer=None,
        use_MSA=True,
        overwrite_MSA=False,
        Lmin=1,
        output_predictions=True,
        output_chimera='annotation',
        chimera_thresholds=[0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
        permissive=False,
        output_format='numpy'):
    """Predict per-residue binding sites (or intermediate layer activities) with ScanNet.

    Args:
        query_pdbs: PDB id / structure path, or a list of them (resolved with
            PDBio.getPDB). When None, `query_sequences` is used instead.
        query_chain_ids: per query, 'all', 'upper', 'lower' or an explicit list of
            chain ids; None means 'all' for every query.
        query_sequences: sequence(s) to predict from when `query_pdbs` is None.
        query_names: names used for outputs; derived from the file names or the
            sequences when None.
        pipeline: featurization pipeline (its Lmax attributes are overwritten).
        model: model file name, or a list of file names whose predictions are averaged.
        model_name: label used in the log and the output folder names.
        model_folder: prefix prepended to the model file name(s) when loading.
        structures_folder, predictions_folder, MSA_folder: working folders
            (MSA_folder and predictions_folder are created if missing).
        query_MSAs, query_PWMs: optional per-query MSA files / PWMs. When both are
            missing for a query, MSAs are built with HHblits into MSA_folder.
        logfile: file object passed to print(); None prints to stdout.
        biounit: passed to PDBio.getPDB; also adds '_biounit' to output folders
            of 4-character PDB ids.
        assembly: predict each query as a whole (True) or chain by chain (False).
        layer: None for the model output, or a layer name / list of layer names
            whose activities are returned instead ('attention_layer' is
            aggregated per residue).
        use_MSA, overwrite_MSA: use MSAs; rebuild MSA files even if present.
        Lmin: chains shorter than this are discarded (structure input only).
        output_predictions: write CSV files (and Chimera outputs) per query.
        output_chimera: 'script' or 'annotation'.
        chimera_thresholds: thresholds for the Chimera outputs.
        permissive: skip queries that fail instead of returning None.
        output_format: 'dictionary' or 'numpy'.

    Returns:
        None when no input is given, or when a query fails and `permissive` is
        False or it is the last remaining query. Otherwise
        (query_pdbs, query_names, query_dictionary_predictions) when
        output_format == 'dictionary', else
        (query_pdbs, query_names, query_predictions, query_residue_ids, query_sequences).

    Raises:
        AssertionError: if use_MSA is True and HHblits is not found at path2hhblits.
    """
    os.makedirs(MSA_folder, exist_ok=True)
    os.makedirs(predictions_folder, exist_ok=True)
    if use_MSA and not os.path.exists(path2hhblits):
        # Explicit raise rather than `assert`, so that the check survives `python -O`.
        raise AssertionError('HHblits not found at %s!!' % path2hhblits)

    # Work on copies: the filtering below deletes/replaces entries and must not
    # mutate the caller's lists.
    if query_names is not None:
        query_names = list(query_names)
    if query_MSAs is not None:
        query_MSAs = list(query_MSAs)
    if query_PWMs is not None:
        query_PWMs = list(query_PWMs)

    if query_pdbs is not None:
        query_pdbs = _as_query_list(query_pdbs)
        print('Predicting binding sites from pdb structures with %s' % model_name, file=logfile)
        predict_from_pdb = True
        npdbs = len(query_pdbs)
        if query_chain_ids is None:
            query_chain_ids = ['all' for _ in query_pdbs]
        else:
            if not ((query_chain_ids[0] in ['all', 'upper', 'lower']) or isinstance(query_chain_ids[0], list)):
                query_chain_ids = [query_chain_ids]  # a single chain selection
            query_chain_ids = [list(chain_ids) if isinstance(chain_ids, list) else chain_ids
                               for chain_ids in query_chain_ids]
    elif query_sequences is not None:
        query_sequences = _as_query_list(query_sequences)
        print('Predicting interface residues from sequences using %s' % model_name, file=logfile)
        predict_from_pdb = False
        nqueries = len(query_sequences)
        if query_chain_ids is not None:
            query_chain_ids = list(query_chain_ids)
    else:
        print('No input provided for interface prediction using %s' % model_name, file=logfile)
        return

    if query_names is None:
        if predict_from_pdb:
            query_names = [query_pdb.split('/')[-1].split('.')[0] for query_pdb in query_pdbs]
        else:
            query_names = ['seq_%s_start:%s_L:%s' % (i, sequence[:5], len(sequence))
                           for i, sequence in enumerate(query_sequences)]

    if use_MSA:
        if query_MSAs is None:
            query_MSAs = [None for _ in query_names]
        if query_PWMs is None:
            query_PWMs = [None for _ in query_names]
    # MSA/PWM lists are only kept aligned with the queries when they are used.
    msa_lists = (query_MSAs, query_PWMs) if use_MSA else ()

    # Per-query lists that only exist for structure input.
    pdb_file_locations = None
    query_chain_id_is_alls = None

    if predict_from_pdb:
        # 1. Locate the structure files (PDBio.getPDB may download them).
        pdb_file_locations = []
        i = 0
        while i < npdbs:
            location, _ = PDBio.getPDB(query_pdbs[i], biounit=biounit, structures_folder=structures_folder)
            if os.path.exists(location):
                pdb_file_locations.append(location)
                i += 1
                continue
            print('i=%s,file:%s not found' % (i, location), file=logfile)
            if not (permissive and npdbs > 1):
                return
            _discard_index(i, query_pdbs, query_chain_ids, query_names, *msa_lists)
            npdbs -= 1

        # 2. Parse the structures and resolve 'all' / 'upper' / 'lower' selections.
        query_chain_objs = []
        query_chain_names = []
        query_chain_id_is_alls = [chain_ids == 'all' for chain_ids in query_chain_ids]
        i = 0
        while i < npdbs:
            try:
                _, chain_objs = PDBio.load_chains(chain_ids=query_chain_ids[i], file=pdb_file_locations[i])
                selection = query_chain_ids[i]
                if selection in ['all', 'upper', 'lower']:
                    chain_ids = [(chain_obj.get_full_id()[1], chain_obj.get_full_id()[2])
                                 for chain_obj in chain_objs]
                    if selection == 'upper':
                        chain_ids = [chain_id for chain_id in chain_ids
                                     if chain_id[1].isupper() or chain_id[1] == ' ']
                    elif selection == 'lower':
                        chain_ids = [chain_id for chain_id in chain_ids if chain_id[1].islower()]
                else:
                    chain_ids = selection
                chain_names = [query_names[i] + '_%s_%s' % chain_id for chain_id in chain_ids]
            except Exception:
                print('Failed to parse i=%s,%s, %s' % (i, query_names[i], pdb_file_locations[i]), file=logfile)
                if not (permissive and npdbs > 1):
                    return
                # The file location and 'all' flag are dropped too, so that the next
                # query is not parsed against the failed file.
                _discard_index(i, query_pdbs, pdb_file_locations, query_chain_ids,
                               query_chain_id_is_alls, query_names, *msa_lists)
                npdbs -= 1
                continue
            query_chain_ids[i] = chain_ids
            query_chain_objs.append(chain_objs)
            query_chain_names.append(chain_names)
            i += 1

        # 3. Extract the chain sequences and drop chains shorter than Lmin.
        query_sequences = [[PDB_processing.process_chain(chain_obj)[0] for chain_obj in chain_objs]
                           for chain_objs in query_chain_objs]
        if Lmin > 0:
            for i in range(npdbs):
                j = 0
                while j < len(query_sequences[i]):
                    sequence = query_sequences[i][j]
                    if len(sequence) >= Lmin:
                        j += 1
                        continue
                    print('Chain %s %s from PDB %s is too short (L=%s), discarding.' % (
                        query_chain_ids[i][j][0], query_chain_ids[i][j][1], query_pdbs[i], len(sequence)),
                        file=logfile)
                    # MSA/PWM inputs are assumed to have been provided without the
                    # short chains, so they are not filtered here.
                    _discard_index(j, query_sequences[i], query_chain_ids[i],
                                   query_chain_objs[i], query_chain_names[i])

        # 4. Drop structures left without any chain.
        i = 0
        while i < npdbs:
            if len(query_sequences[i]) > 0:
                i += 1
                continue
            print('PDB %s has no chains remaining!' % (query_pdbs[i]), file=logfile)
            if not (permissive and npdbs > 1):
                return
            _discard_index(i, query_pdbs, pdb_file_locations, query_sequences, query_chain_ids,
                           query_chain_id_is_alls, query_chain_objs, query_names,
                           query_chain_names, *msa_lists)
            npdbs -= 1
        nqueries = npdbs
    else:
        # Sequence input. Copy the names: query_chain_names entries may be wrapped
        # into lists (or deleted) below, and query_names must not change with them.
        query_chain_names = list(query_names)
        query_chain_objs = [None for _ in query_chain_names]
        if query_chain_ids is None:
            query_chain_ids = [('', '') for _ in query_chain_names]
        nqueries = len(query_names)

    print('List of inputs:', file=logfile)
    for i in range(nqueries):
        print(query_chain_names[i], file=logfile)

    if use_MSA:
        i = 0
        while i < nqueries:
            if query_PWMs[i] is not None:
                # User-provided PWM(s): used as is.
                if not isinstance(query_PWMs[i], list):
                    query_PWMs[i] = [query_PWMs[i]]
                i += 1
            elif query_MSAs[i] is not None:
                # User-provided MSA file(s): they must all exist.
                if not isinstance(query_MSAs[i], list):
                    query_MSAs[i] = [query_MSAs[i]]
                missing_MSA = next((MSA_file for MSA_file in query_MSAs[i]
                                    if not os.path.exists(MSA_file)), None)
                if missing_MSA is None:
                    i += 1  # also advances for an empty list (previously an endless loop)
                    continue
                print('i=%s,file:%s not found' % (i, missing_MSA), file=logfile)
                if not (permissive and nqueries > 1):
                    return
                _discard_index(i, query_pdbs, pdb_file_locations, query_chain_id_is_alls,
                               query_sequences, query_chain_ids, query_chain_objs, query_names,
                               query_MSAs, query_PWMs, query_chain_names)
                nqueries -= 1
            else:
                # No MSA given: reuse MSA_<chain name>.fasta from MSA_folder, or build it.
                query_PWMs[i] = []
                query_MSAs[i] = []
                if not isinstance(query_sequences[i], list):
                    query_sequences[i] = [query_sequences[i]]
                if not isinstance(query_chain_names[i], list):
                    query_chain_names[i] = [query_chain_names[i]]
                for j, sequence in enumerate(query_sequences[i]):
                    target_location = MSA_folder + 'MSA_' + query_chain_names[i][j] + '.fasta'
                    if overwrite_MSA or not os.path.exists(target_location):
                        if sequence in query_sequences[i][:j]:
                            # Identical chain seen earlier in this query: share its MSA.
                            jseen = query_sequences[i][:j].index(sequence)
                            target_location = MSA_folder + 'MSA_' + query_chain_names[i][jseen] + '.fasta'
                        else:
                            print('i=%s,%s, no MSA found. Building it using HHblits' %
                                  (i, query_chain_names[i][j]), file=logfile)
                            sequence_utils.call_hhblits(
                                sequence, target_location, path2hhblits=path2hhblits,
                                path2sequence_database=path2sequence_database, cores=cores)
                    query_MSAs[i].append(target_location)
                    query_PWMs[i].append(None)
                i += 1
    else:
        # No MSA: placeholders, one per query (assembly) or one per chain.
        if assembly:
            query_MSAs = [None for _ in range(nqueries)]
            query_PWMs = [None for _ in range(nqueries)]
        else:
            query_MSAs = [[None for _ in query_chain_names[i]] for i in range(nqueries)]
            query_PWMs = [[None for _ in query_chain_names[i]] for i in range(nqueries)]

    sequence_lengths = [[len(sequence) for sequence in sequences] for sequences in query_sequences]
    if assembly:
        assembly_lengths = [sum(lengths) for lengths in sequence_lengths]
        Lmax = max(assembly_lengths)
    else:
        Lmax = max([max(lengths) for lengths in sequence_lengths])
    Lmax = max(Lmax, 32)

    # Keep the per-chain sequences for chain-by-chain processing; outputs use the
    # concatenated sequence of each query.
    query_chain_sequences = query_sequences
    query_sequences = [''.join(sequences) for sequences in query_chain_sequences]

    query_residue_ids = []
    for i, chain_objs in enumerate(query_chain_objs):
        if chain_objs is not None:
            residue_ids = PDB_processing.get_PDB_indices(chain_objs, return_chain=True, return_model=True)
        else:
            # Sequence input: blank model/chain ids, residues numbered from 1.
            n_residues = len(query_sequences[i])
            model_indices = np.array([' ' for _ in range(n_residues)])
            chain_indices = np.array([' ' for _ in range(n_residues)])
            residue_indices = np.array(['%s' % k for k in range(1, n_residues + 1)])
            # np.concatenate takes a sequence of arrays (passing them positionally raised TypeError).
            residue_ids = np.concatenate([
                model_indices[:, np.newaxis],
                chain_indices[:, np.newaxis],
                residue_indices[:, np.newaxis],
            ], axis=1)
        query_residue_ids.append(residue_ids)

    print('Loading model %s' % model_name, file=logfile)
    if isinstance(model, list):
        # Ensemble: the predictions of all models are averaged.
        multi_models, aggregate_models = True, True
        model_objs = [wrappers.load_model(model_folder + model_, Lmax=Lmax) for model_ in model]
    elif isinstance(layer, list):
        # Several output layers: one copy of the model per requested layer.
        multi_models, aggregate_models = True, False
        model_objs = [wrappers.load_model(model_folder + model, Lmax=Lmax) for _ in layer]
    else:
        multi_models, aggregate_models = False, True
        model_obj = wrappers.load_model(model_folder + model, Lmax=Lmax)
        model_objs = [model_obj]

    if layer is None:
        return_all = False
    else:
        if isinstance(layer, list):
            for layer_name, layer_model in zip(layer, model_objs):
                if layer_name is not None:
                    _truncate_model_at_layer(layer_model, layer_name)
        else:
            # Truncate every loaded model, so that a single layer also works with an ensemble.
            for layer_model in model_objs:
                _truncate_model_at_layer(layer_model, layer)
        return_all = True
    # The last loaded model provides K_graph for attention aggregation.
    reference_model = model_objs[-1]

    if hasattr(pipeline, 'Lmax'):
        pipeline.Lmax = Lmax
    if hasattr(pipeline, 'Lmax_aa'):
        pipeline.Lmax_aa = Lmax
    if hasattr(pipeline, 'Lmax_atom'):
        pipeline.Lmax_atom = 9 * Lmax
    padded = getattr(pipeline, 'padded', True)

    has_attention_layer = layer == 'attention_layer' or (isinstance(layer, list) and 'attention_layer' in layer)
    attention_index = layer.index('attention_layer') if (has_attention_layer and isinstance(layer, list)) else None
    K_graph = reference_model.kwargs['K_graph'] if has_attention_layer else None

    if assembly:
        inputs = wrappers.stack_list_of_arrays(
            [pipeline.process_example(chain_obj=chain_obj, sequence=sequence, MSA_file=MSA_file_location, PWM=PWM)[0]
             for chain_obj, sequence, MSA_file_location, PWM
             in zip(query_chain_objs, query_sequences, query_MSAs, query_PWMs)],
            padded=padded)
        if not multi_models:
            query_predictions = model_obj.predict(inputs, batch_size=1, return_all=return_all)
        elif aggregate_models:
            query_predictions = _average_model_predictions(model_objs, inputs, return_all)
        else:
            per_model = [layer_model.predict(inputs, batch_size=1, return_all=return_all)
                         for layer_model in model_objs]
            # Re-index from [model][query] to [query][model].
            query_predictions = [[per_model[k][q] for k in range(len(per_model))]
                                 for q in range(len(per_model[0]))]
        if padded:
            query_predictions = wrappers.truncate_list_of_arrays(query_predictions, assembly_lengths)

        if has_attention_layer:
            if layer == 'attention_layer':
                query_predictions = _aggregate_attention(inputs, query_predictions, K_graph, logfile)
            else:
                aggregated = _aggregate_attention(
                    inputs, [prediction[attention_index] for prediction in query_predictions], K_graph, logfile)
                for query_prediction, aggregated_coeff in zip(query_predictions, aggregated):
                    query_prediction[attention_index] = aggregated_coeff
    else:
        query_predictions = []
        for i in range(nqueries):
            # Pair each chain object with its own chain sequence (not with single
            # characters of the concatenated sequence).
            inputs = wrappers.stack_list_of_arrays(
                [pipeline.process_example(chain_obj=chain_obj, sequence=sequence, MSA_file=MSA_file_location, PWM=PWM)[0]
                 for chain_obj, sequence, MSA_file_location, PWM
                 in zip(query_chain_objs[i], query_chain_sequences[i], query_MSAs[i], query_PWMs[i])],
                padded=padded)
            if not multi_models:
                predictions = model_obj.predict(inputs, batch_size=1, return_all=return_all)
            elif aggregate_models:
                predictions = _average_model_predictions(model_objs, inputs, return_all)
            else:
                predictions = [layer_model.predict(inputs, batch_size=1, return_all=return_all)
                               for layer_model in model_objs]
            if padded:
                predictions = wrappers.truncate_list_of_arrays(predictions, sequence_lengths[i])

            if has_attention_layer:
                if layer == 'attention_layer':
                    predictions = _aggregate_attention(inputs, predictions, K_graph, logfile)
                else:
                    predictions[attention_index] = _aggregate_attention(
                        inputs, predictions[attention_index], K_graph, logfile)

            # Concatenate the chains of the query.
            if aggregate_models:
                query_predictions.append(np.concatenate(predictions, axis=0))
            else:
                query_predictions.append([np.concatenate(prediction, axis=0) for prediction in predictions])

    output_folder = predictions_folder + '/'
    os.makedirs(output_folder, exist_ok=True)

    if output_predictions:
        for i in range(nqueries):
            res_ids = query_residue_ids[i]
            sequence = query_sequences[i]
            predictions = query_predictions[i]
            query_name = query_names[i]

            query_output_folder = output_folder + query_name
            if predict_from_pdb:
                if len(query_pdbs[i]) == 4 and biounit:
                    query_output_folder += '_biounit'
                if not query_chain_id_is_alls[i]:
                    query_output_folder += '_(' + PDBio.format_chain_id(query_chain_ids[i]) + ')'
                file_is_cif = pdb_file_locations[i][-4:] == '.cif'
            else:
                # Sequence input: no PDB id, chain selection or structure file.
                file_is_cif = False
            if not assembly:
                query_output_folder += '_single'
            query_output_folder += '_%s' % model_name
            query_output_folder += '/'
            os.makedirs(query_output_folder, exist_ok=True)

            if aggregate_models:
                layer_outputs = [(layer, predictions)]
            else:
                # Multi-output: one prediction per requested layer; for the model
                # output (layer None) only column 1 is written (presumably the
                # binding probability).
                layer_outputs = [(layer_, prediction[:, 1] if layer_ is None else prediction)
                                 for layer_, prediction in zip(layer, predictions)]

            for layer_, prediction in layer_outputs:
                csv_file, chimera_file, annotated_pdb_file = _output_file_names(
                    query_output_folder, query_name, layer_, file_is_cif)
                write_predictions(csv_file, res_ids, sequence, prediction)
                if not (predict_from_pdb and prediction.ndim == 1):
                    continue
                if output_chimera == 'script':
                    chimera.show_binding_sites(
                        query_pdbs[i], csv_file, chimera_file, biounit=biounit, directory='',
                        thresholds=chimera_thresholds)
                elif output_chimera == 'annotation':
                    if layer_ == 'attention_layer':
                        mini, maxi = 0.5, 2.5
                    else:
                        mini, maxi = 0, chimera_thresholds[-1]
                    chimera.annotate_pdb_file(pdb_file_locations[i], csv_file, annotated_pdb_file,
                                              output_script=True, mini=mini, maxi=maxi)

    if output_format == 'dictionary':
        query_dictionary_predictions = [
            PDB_processing.make_values_dictionary(query_residue_id, query_prediction)
            for query_residue_id, query_prediction in zip(query_residue_ids, query_predictions)]
        return query_pdbs, query_names, query_dictionary_predictions
    else:
        return query_pdbs, query_names, query_predictions, query_residue_ids, query_sequences
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, select the pipeline and the
    # model(s) for the requested mode, and run predict_interface_residues.
    parser = argparse.ArgumentParser(description='Predict binding sites in PDB files using Geometric Neural Network')
    parser.add_argument(
        'input', type=str,
        help='Three input formats. i) A pdb id (1a3x) '
             'ii) Path to pdb file (structures/1a3x.pdb) '
             'iii) Path to text file containing list of pdb files (one per line) (1a3x \n 2kho \n ...) '
             'For performing prediction only on specific chains, append "_" and the list of chains. (e.g. 1a3x_AB)')
    parser.add_argument('--name', dest='name',
                        default='',
                        help='Input name')
    parser.add_argument('--predictions_folder', dest='predictions_folder',
                        default=predictions_folder,
                        help='Folder where the predictions are written')
    parser.add_argument('--mode', dest='mode',
                        default='interface',
                        help='Prediction mode (interface, epitope, epitope1-5, idp, idp1-5)')
    parser.add_argument('--noMSA', dest='use_MSA', action='store_const',
                        const=False, default=True,
                        help='Perform prediction without Multiple Sequence Alignments (less accurate, faster)')
    parser.add_argument('--assembly', dest='assembly', action='store_const',
                        const=True, default=False,
                        help='Perform prediction from single chains or from biological assemblies')
    parser.add_argument('--permissive', dest='permissive', action='store_const',
                        const=True, default=True, help='Permissive prediction')
    # --permissive equals the default (kept for backward compatibility); this is
    # the only way to turn permissive mode off.
    parser.add_argument('--no-permissive', dest='permissive', action='store_const',
                        const=False, default=True,
                        help='Stop at the first input that cannot be processed instead of skipping it')
    parser.add_argument('--layer', dest='layer',
                        default='',
                        help='Choose output layer')
    parser.add_argument('--pdb', dest='biounit', action='store_const',
                        const=False, default=True,
                        help='Predict from pdb file (default= predict from biounit file)')
    args = parser.parse_args()

    input_str = args.input
    query_pdbs = []
    query_chain_ids = []
    if '.txt' in input_str:
        with open(input_str, 'r') as f:
            for line in f:
                # strip() instead of line[:-1]: the last line may lack a newline,
                # and Windows line endings leave a trailing '\r'.
                line = line.strip()
                if not line:
                    continue
                pdb, chain_ids = PDBio.parse_str(line)
                query_pdbs.append(pdb)
                query_chain_ids.append(chain_ids)
    else:
        query_pdbs, query_chain_ids = PDBio.parse_str(input_str)

    if args.name != '':
        query_names = [args.name]
    else:
        query_names = None

    predictions_folder = args.predictions_folder

    if args.use_MSA:
        pipeline = pipeline_MSA
    else:
        pipeline = pipeline_noMSA

    # mode -> (model folder, Chimera thresholds, (name, model) with MSA, (name, model) without MSA)
    mode_settings = {
        'interface': (interface_model_folder, [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7],
                      (interface_model_name_MSA, interface_model_MSA),
                      (interface_model_name_noMSA, interface_model_noMSA)),
        'epitope': (epitope_model_folder, [0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35],
                    (epitope_model_name_MSA, epitope_model_MSA),
                    (epitope_model_name_noMSA, epitope_model_noMSA)),
        'idp': (idp_model_folder, [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
                (idp_model_name_MSA, idp_model_MSA),
                (idp_model_name_noMSA, idp_model_noMSA)),
    }
    base_mode, fold = args.mode, None
    if base_mode not in mode_settings and args.mode[:-1] in ('epitope', 'idp') and args.mode[-1:].isdigit():
        # epitope1, ..., epitope5 / idp1, ..., idp5: a single fold of the ensemble.
        base_mode, fold = args.mode[:-1], int(args.mode[-1]) - 1
    if base_mode not in mode_settings:
        raise ValueError('Mode %s not supported' % args.mode)
    model_folder, chimera_thresholds, MSA_setting, noMSA_setting = mode_settings[base_mode]
    model_name, model = MSA_setting if args.use_MSA else noMSA_setting
    if fold is not None:
        # Reject fold 0 (which used to silently pick the last model) and out-of-range folds.
        if not 0 <= fold < len(model):
            raise ValueError('Mode %s not supported' % args.mode)
        model_name = model_name + args.mode[-1]
        model = model[fold]

    if args.layer == '':
        layer = None
    else:
        layer = args.layer
        if '+' in layer:
            # Several layers: '', 'output' and 'probability' denote the model output.
            layer = [None if layer_name in ['', 'output', 'probability'] else layer_name
                     for layer_name in layer.split('+')]

    predict_interface_residues(
        query_pdbs=query_pdbs,
        query_chain_ids=query_chain_ids,
        query_names=query_names,
        pipeline=pipeline,
        model=model,
        model_name=model_name,
        model_folder=model_folder,
        structures_folder=structures_folder,
        predictions_folder=predictions_folder,
        MSA_folder=MSA_folder,
        biounit=args.biounit,
        assembly=args.assembly,
        overwrite_MSA=False,
        permissive=args.permissive,
        use_MSA=args.use_MSA,
        chimera_thresholds=chimera_thresholds,
        layer=layer
    )
|