mirror of
https://github.com/jertubiana/ScanNet.git
synced 2026-06-04 13:44:22 +08:00
Support for attention coefficient extraction
This commit is contained in:
@@ -74,7 +74,7 @@ def write_predictions(csv_file, residue_ids, sequence, interface_prediction):
|
||||
if interface_prediction.ndim == 1:
|
||||
columns.append('Binding site probability')
|
||||
else:
|
||||
columns += ['Output %s' %i for i in range(len(interface_prediction) )]
|
||||
columns += ['Output %s' %i for i in range(interface_prediction.shape[-1] )]
|
||||
|
||||
with open(csv_file, 'w') as f:
|
||||
f.write(','.join(columns) + '\n' )
|
||||
@@ -492,7 +492,7 @@ def predict_interface_residues(
|
||||
for attention_coeff,neighborhood_graph in zip(attention_coeffs,neighborhood_graphs):
|
||||
aggregated_attention_coeff = np.zeros(len(attention_coeff),dtype=np.float32)
|
||||
for s in range( len(attention_coeff) ):
|
||||
aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s],0).mean(
|
||||
aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s][:len(neighborhood_graph[s])],0).mean(
|
||||
-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads.
|
||||
# aggregated_attention_coeff[neighborhood_graph[s]] += np.abs(attention_coeff[s]).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads.
|
||||
aggregated_attention_coeffs.append(aggregated_attention_coeff)
|
||||
@@ -556,7 +556,7 @@ def predict_interface_residues(
|
||||
for attention_coeff,neighborhood_graph in zip(attention_coeffs,neighborhood_graphs):
|
||||
aggregated_attention_coeff = np.zeros(len(attention_coeff),dtype=np.float32)
|
||||
for s in range( len(attention_coeff) ):
|
||||
aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s],0).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads.
|
||||
aggregated_attention_coeff[neighborhood_graph[s]] += np.maximum(sign*attention_coeff[s][:len(neighborhood_graph[s])],0).mean(-1) # Attention coefficient has size [N_aa,K_graph,nheads]. average over heads.
|
||||
aggregated_attention_coeffs.append(aggregated_attention_coeff)
|
||||
aggregated_attention_coeffs = np.array(aggregated_attention_coeffs)
|
||||
if layer == 'attention_layer':
|
||||
@@ -611,6 +611,7 @@ def predict_interface_residues(
|
||||
if not aggregate_models: # multioutput.
|
||||
for layer_,prediction in zip(layer,predictions):
|
||||
if layer_ is None:
|
||||
prediction = prediction[:,1]
|
||||
csv_file = query_output_folder + 'predictions_' + query_name + '.csv'
|
||||
chimera_file = query_output_folder + 'chimera_' + query_names[i]
|
||||
annotated_pdb_file = query_output_folder + 'annotated_' + query_names[i] + ('.cif' if file_is_cif else '.pdb')
|
||||
|
||||
Reference in New Issue
Block a user