# File: ScanNet/network/scannet.py
# Last modified: 2022-05-25 11:30:35 +03:00
# Size: 917 lines, 40 KiB (Python)

from keras.models import Model
from keras.layers import Input, Dense, Masking, TimeDistributed, Concatenate, Activation, Embedding, Dropout, Lambda
from keras.initializers import Zeros, Ones,RandomUniform
import numpy as np
import keras.regularizers
from . import neighborhoods, attention, embeddings, utils
from functools import partial
import keras.backend as K
import tensorflow as tf
from utilities import wrappers,io_utils
from preprocessing import PDB_processing
def l1_regularization(W, l1):
    """Return the L1 penalty of a weight tensor.

    The penalty is ``l1`` times the sum of the absolute values of ``W``.
    """
    absolute_sum = K.sum(K.abs(W))
    return l1 * absolute_sum
def l12_regularization(W, l12, ndims=2, order='gaussian_feature_filter'):
    """Return the L1/L2 penalty of a weight tensor.

    The absolute values of ``W`` are averaged over the "group" axes, the
    averages are squared and summed over the remaining axis, and the result
    is scaled by ``l12 / 2`` times the number of entries per group.

    Args:
        W: weight tensor with ``ndims`` axes.
        l12: penalty strength.
        ndims: 2 (e.g. gaussian-dependent bias) or 3.
        order: order of the tensor indices. With 'gaussian_feature_filter'
            (default) the group axes are the leading ones; with
            'filter_gaussian_feature' they are the trailing ones.

    Returns:
        The scalar penalty tensor, or None when ``ndims`` / ``order`` is not
        one of the supported values.
    """
    # Index of the first axis that is averaged over.
    if order == 'filter_gaussian_feature':
        first_axis = 1
    elif order == 'gaussian_feature_filter':
        first_axis = 0
    else:
        return None

    if ndims == 2:
        group_size = K.shape(W)[first_axis]
        group_axes = first_axis
    elif ndims == 3:
        group_size = K.shape(W)[first_axis] * K.shape(W)[first_axis + 1]
        group_axes = (first_axis, first_axis + 1)
    else:
        return None

    mean_absolute = K.mean(K.abs(W), axis=group_axes)
    return l12 / 2 * tf.cast(group_size, tf.float32) * K.sum(K.square(mean_absolute))
def l12group_regularization(W, l12group, ndims=2, order='gaussian_feature_filter'):
    """Return the grouped L1/L2 penalty of a weight tensor.

    For ``ndims == 2`` this is exactly ``l12_regularization``. For
    ``ndims == 3`` the root-mean-square of ``W`` is taken over the feature
    axis, averaged over the gaussian axis, squared and summed over filters,
    then scaled by ``l12group / 2`` times the number of entries per filter.

    Returns:
        The scalar penalty tensor, or None when ``ndims`` / ``order`` is not
        one of the supported values.
    """
    if ndims == 2:  # For gaussian-dependent bias, same as l12.
        return l12_regularization(W, l12group, ndims=ndims, order=order)
    if ndims != 3:
        return None

    # (first group axis, axis of the RMS, axis of the subsequent mean)
    if order == 'filter_gaussian_feature':
        first_axis, rms_axis, mean_axis = 1, -1, -1
    elif order == 'gaussian_feature_filter':
        first_axis, rms_axis, mean_axis = 0, 1, 0
    else:
        return None

    group_size = tf.cast(K.shape(W)[first_axis] * K.shape(W)[first_axis + 1], tf.float32)
    feature_rms = K.sqrt(K.mean(K.square(W), axis=rms_axis))
    mean_rms = K.mean(feature_rms, axis=mean_axis)
    return l12group / 2 * group_size * K.sum(K.square(mean_rms))
def addNonLinearity(input_layer, activation, name=None):
    """Apply masked batch normalization followed by a non-linearity.

    Args:
        input_layer: tensor to normalize and activate.
        activation: activation identifier. 'multitanh<N>' (e.g. 'multitanh5')
            selects ``embeddings.MultiTanh`` with N components; None or
            'linear' gives a plain normalization; anything else is handed to
            the Keras ``Activation`` layer.
        name: name of the output layer (default 'activity_nonlinear'). The
            normalization layer is named ``name + '_normalization'``.

    Returns:
        The activated tensor.
    """
    if name is None:
        name = 'activity_nonlinear'

    # `activation` may be None (or a callable); the substring test is only
    # valid on strings and would otherwise raise TypeError.
    is_multitanh = isinstance(activation, str) and 'multitanh' in activation

    # Choose which batch-normalization parameters (offset / gain) are learnt.
    if is_multitanh or activation in [None, 'linear']:
        center, scale = False, False
    elif activation == 'relu':
        center, scale = True, False
    else:  # 'tanh', 'elu' and any other activation.
        center, scale = True, True

    # Custom batch norm layer that takes into account the mask.
    input_layer_normalized = embeddings.MaskedBatchNormalization(
        epsilon=1e-3, axis=-1, center=center, scale=scale,
        name=name + '_normalization')(input_layer)

    if is_multitanh:
        # Parse every digit after the prefix so 'multitanh12' gives 12, not 2.
        ntanh = int(activation.rsplit('multitanh', 1)[-1])
        output_layer = embeddings.MultiTanh(ntanh, use_bias=True, name=name)(
            input_layer_normalized)
    else:
        output_layer = TimeDistributed(
            Activation(activation), name=name)(input_layer_normalized)
    return output_layer
def attribute_embedding(attribute_layer, output_dim, activation, name=None):
    """Embed per-position attributes with a bias-free dense projection and a non-linearity.

    When ``output_dim`` is None nothing is learnt: the attributes go through a
    mask-preserving identity layer whose only purpose is to carry ``name``.
    Otherwise the projection layer is named ``name + '_projection'`` and the
    activated output is named ``name`` (default 'attribute_embedding').
    """
    layer_name = 'attribute_embedding' if name is None else name

    if output_dim is not None:
        projection = TimeDistributed(
            Dense(output_dim, use_bias=False, activation=None),
            name=layer_name + '_projection')
        projected = projection(attribute_layer)
        return addNonLinearity(projected, activation, name=layer_name)

    identity = TimeDistributed(Activation('linear'), name=layer_name)
    return identity(attribute_layer)
def attribute_normalizer(attribute_layer, name=None, normalizer=1.0):
    """Multiply per-position attributes by the constant ``normalizer``.

    The resulting layer is named ``name`` (default 'attribute_normalizer').
    """
    layer_name = name
    if layer_name is None:
        layer_name = 'attribute_normalizer'
    rescale = Lambda(lambda x: normalizer * x)
    return TimeDistributed(rescale, name=layer_name)(attribute_layer)
def neighborhood_embedding(
        point_clouds,
        frame_indices,
        attributes,
        sequence_indices=None,
        order_frame='2',
        dipole_frame=False,
        Kmax=16,
        coordinates=['euclidian'],
        index_distance_max=8,
        nrotations=1,
        Ngaussians=32,
        covariance_type='full',
        initial_gaussian_values=None,
        l1=0.,
        l12=0.,
        l12group=0.,
        nfilters=64,
        activation='relu',
        scale='atom',
        return_frames=False):
    """Build the spatio-chemical filter stack for one scale ('atom' or 'aa').

    The pipeline is: local frames -> K-nearest-neighbor local coordinates and
    attributes -> Gaussian embedding of the coordinates -> outer product with
    the attributes (the spatio-chemical filters) -> non-linearity.

    Args:
        point_clouds: point coordinates used to build the frames.
        frame_indices: indices of the points defining each frame.
        attributes: per-element attributes gathered over each neighborhood.
        sequence_indices: sequence positions; required when 'index_distance'
            is one of ``coordinates``.
        order_frame: permutation used to build the frame (for '2', the z-axis
            is defined from X_1 -> X_3).
        dipole_frame: only used with a quadruplet of indices; the frame then
            has a fourth additional direction (the dipole).
        Kmax: number of neighbors per neighborhood.
        coordinates: list of local coordinate types. The default list is
            only read here, never mutated.
        index_distance_max: forwarded to ``LocalNeighborhood``.
        nrotations: forwarded to ``LocalNeighborhood``; when > 1 the filter
            activities are max-abs pooled over axis -2.
        Ngaussians: number of Gaussian kernels.
        covariance_type: covariance type of the Gaussian kernels.
        initial_gaussian_values: initial values of the Gaussian kernels.
        l1, l12, l12group: regularization strengths; the first strictly
            positive one (in that order) selects the regularizer.
        nfilters: number of spatio-chemical filters.
        activation: activation passed to ``addNonLinearity``.
        scale: suffix appended to every layer name.
        return_frames: also return the frames tensor.

    Returns:
        The filter activities, or ``(activities, frames)`` if
        ``return_frames`` is True.

    Raises:
        ValueError: if 'index_distance' is requested without
            ``sequence_indices``.
    """
    # Build frames.
    frames = neighborhoods.FrameBuilder(
        name='frames_%s' % scale,
        order=order_frame,
        dipole=dipole_frame,
    )([point_clouds, frame_indices])

    # Compute local neighborhoods.
    if 'index_distance' in coordinates:  # Add sequence position distance.
        if sequence_indices is None:
            raise ValueError(
                "sequence_indices is required when 'index_distance' is in coordinates")
        input2localneighborhood = [frames, sequence_indices, attributes]
    else:
        input2localneighborhood = [frames, attributes]

    local_coordinates, local_attributes = neighborhoods.LocalNeighborhood(
        Kmax=Kmax,
        coordinates=coordinates,
        index_distance_max=index_distance_max,
        nrotations=nrotations,
        name='neighborhood_%s' % scale)(input2localneighborhood)

    # Gaussian embedding of local coordinates.
    embedded_local_coordinates = embeddings.GaussianKernel(
        Ngaussians,
        initial_values=initial_gaussian_values,
        covariance_type=covariance_type,
        name='embedded_local_coordinates_%s' % scale)(local_coordinates)

    # Pick the regularizers of the spatio-chemical filters.
    if l1 > 0:
        kernel_regularizer = partial(l1_regularization, l1=l1)
        single1_regularizer = kernel_regularizer
    elif l12 > 0:
        kernel_regularizer = partial(l12_regularization, l12=l12, ndims=3)
        single1_regularizer = partial(l12_regularization, l12=l12, ndims=2)
    elif l12group > 0:
        kernel_regularizer = partial(l12group_regularization, l12group=l12group, ndims=3)
        single1_regularizer = partial(l12group_regularization, l12group=l12group, ndims=2)
    else:
        kernel_regularizer = None
        single1_regularizer = None
    # The filter norm is only fixed when some regularization is active.
    fixednorm = np.sqrt(Ngaussians / Kmax) if kernel_regularizer is not None else None

    # Apply spatio-chemical filters.
    spatiochemical_filters_input = embeddings.OuterProduct(
        nfilters,
        use_single1=True, use_single2=False,
        use_bias=False,
        kernel_regularizer=kernel_regularizer,
        single1_regularizer=single1_regularizer,
        fixednorm=fixednorm,
        non_negative=False,
        non_negative_initial=False,
        sum_axis=2,
        name='SCAN_filter_input_%s' % scale)(
        [embedded_local_coordinates, local_attributes])

    if nrotations > 1:
        # Non-linearity first, then pooling over axis -2.
        prepooling_activity = addNonLinearity(
            spatiochemical_filters_input, activation,
            name='SCAN_filter_activity_prepooling_%s' % scale)
        # Pool the activated tensor: pooling the raw input instead would
        # silently drop the normalization and non-linearity built just above.
        spatiochemical_filters_activity = utils.MaxAbsPooling(
            axis=-2, name='SCAN_filter_activity_%s' % scale)(prepooling_activity)
    else:
        spatiochemical_filters_activity = addNonLinearity(
            spatiochemical_filters_input, activation,
            name='SCAN_filter_activity_%s' % scale)

    if return_frames:
        return spatiochemical_filters_activity, frames
    return spatiochemical_filters_activity
def _gaussian_kernel_coordinate_limits(coordinates):
    """Return one [min, max] interval per coordinate dimension, for random Gaussian kernel initialization."""
    xlims = []
    for coordinate in coordinates:
        if coordinate == 'euclidian':
            mini, maxi, ndims = -8, 8, 3
        elif coordinate == 'dipole_spherical':
            mini, maxi, ndims = -np.pi, np.pi, 2
        elif coordinate == 'ZdotDelta':
            mini, maxi, ndims = -1, 1, 2
        else:
            mini, maxi, ndims = -1, 1, 1
        xlims.extend([mini, maxi] for _ in range(ndims))
    return xlims


def ScanNet(
        Lmax_aa=800,
        Lmax_atom=None,
        Lmax_aa_points=None,
        Lmax_atom_points=None,
        with_atom=True,
        K_aa=16,
        K_atom=16,
        K_graph=32,
        N_aa=32,
        N_atom=16,
        N_graph=16,
        nfeatures_atom=4,
        nfeatures_aa=20,
        nembedding_atom=4,
        nembedding_aa=None,
        nembedding_graph=1,
        nfilters_atom=16,
        nfilters_aa=64,
        nfilters_graph=2,
        dense_pooling=None,
        nattentionheads_pooling=1,
        filter_MLP=[],
        nattentionheads_graph=1,
        initial_values={'GaussianKernel_atom': None, 'GaussianKernel_aa': None, 'GaussianKernel_graph': None,
                        'dense_graph': None},
        covariance_type_atom='full',
        covariance_type_aa='full',
        covariance_type_graph='diag',
        activation='multitanh5',
        coordinates_atom=['euclidian'],
        coordinates_aa=['euclidian'],
        frame_aa='triplet_backbone',
        coordinates_graph=['distance', 'ZdotZ', 'ZdotDelta', 'index_distance'],
        index_distance_max_graph=16,
        nrotations=1,
        l1_aa=0.,
        l12_aa=0.,
        l12group_aa=0.,
        l1_atom=0.,
        l12_atom=0.,
        l12group_atom=0.,
        l1_pool=0.,
        l12_pool=0.,
        l12group_pool=0.,
        dropout=0.,
        optimizer='adam',
        output='classification'
):
    """Build and compile the ScanNet Keras model.

    The network embeds atomic neighborhoods (optional, ``with_atom``), pools
    them to the amino acid scale by attention, embeds amino acid
    neighborhoods, and finishes with a graph attention layer.

    Args:
        Lmax_aa: maximum number of amino acids. ``Lmax_atom``,
            ``Lmax_aa_points`` and ``Lmax_atom_points`` are derived from it
            when left to None.
        with_atom: include the atomic-scale inputs and module.
        K_*: number of neighbors at each scale (aa / atom / graph).
        N_*: number of Gaussian kernels at each scale.
        nfeatures_*, nembedding_*, nfilters_*: input, embedding and filter
            dimensions at each scale.
        dense_pooling: width of the pooled atom features (defaults to
            ``nfilters_atom``).
        filter_MLP: widths of optional dense layers applied to the amino acid
            filters.
        initial_values: dict with keys 'GaussianKernel_atom',
            'GaussianKernel_aa', 'GaussianKernel_graph' and 'dense_graph'.
            Missing or None entries are initialized randomly. The dict is
            copied, never modified.
        frame_aa: one of 'triplet_backbone', 'triplet_sidechain',
            'triplet_cbeta', 'quadruplet'.
        l1_*, l12_*, l12group_*: regularization strengths. Of the ``*_pool``
            ones only ``l12_pool`` is used by this function.
        output: 'classification' (softmax output, cross-entropy loss) or
            anything else (raw graph attention output, MSE loss).

    Returns:
        The compiled ``keras.models.Model``, or None (after printing a
        message) when ``frame_aa`` is not recognized.
    """
    if frame_aa == 'triplet_backbone':
        order_aa, dipole_aa = '3', False
    elif frame_aa in ['triplet_sidechain', 'triplet_cbeta']:
        order_aa, dipole_aa = '2', False
    elif frame_aa == 'quadruplet':
        order_aa, dipole_aa = '3', True
    else:
        print('Incorrect frame_aa')
        return
    order_atom = '2'
    dipole_atom = False

    # Work on a copy: the default argument is a shared dict, and filling it in
    # place would leak random initial values into later calls (and into the
    # caller's own dict).
    initial_values = dict(initial_values) if initial_values is not None else {}

    # Check initial values for GaussianKernel, and initialize to random uniform if needed.
    kernel_settings = [
        ('aa', coordinates_aa, N_aa, covariance_type_aa),
        ('atom', coordinates_atom, N_atom, covariance_type_atom),
        ('graph', coordinates_graph, N_graph, covariance_type_graph),
    ]
    for component, coordinates, N, covariance_type in kernel_settings:
        key = 'GaussianKernel_%s' % component
        if initial_values.get(key) is None:
            print('Initial values for %s GaussianKernel not found, random initialization...' % component)
            xlims = _gaussian_kernel_coordinate_limits(coordinates)
            print(N, covariance_type, key)
            initial_values[key] = embeddings.initialize_GaussianKernelRandom(xlims, N, covariance_type)

    if initial_values.get('dense_graph') is None:
        print('Initial values for graph dense not found, random initialization...')
        initial_values['dense_graph'] = [np.random.rand(N_graph, nembedding_graph)]

    # Given maximum number of amino acids, define other maximum values.
    if Lmax_atom is None:
        Lmax_atom = 9 * Lmax_aa
    if Lmax_aa_points is None:
        if frame_aa == 'triplet_backbone':
            Lmax_aa_points = Lmax_aa + 2
        elif frame_aa in ['triplet_sidechain', 'triplet_cbeta']:
            Lmax_aa_points = 2 * Lmax_aa + 1
        elif frame_aa == 'quadruplet':
            Lmax_aa_points = 2 * Lmax_aa + 2
    if Lmax_atom_points is None:
        Lmax_atom_points = 11 * Lmax_aa

    ### Input layers.
    nindex = 4 if frame_aa == 'quadruplet' else 3

    # The 4 input arrays at amino acid scale.
    # List of triplet/quadruplet of the indices of the points for building the frames.
    frame_indices_aa = Input(shape=[Lmax_aa, nindex], name='frame_indices_aa', dtype="int32")
    # The attributes of the amino acids (PWM / one-hot encoded sequence).
    attributes_aa = Input(shape=[Lmax_aa, nfeatures_aa], name='attributes_aa', dtype="float32")
    # The position of each amino acid on the sequence.
    sequence_indices_aa = Input(shape=[Lmax_aa, 1], name='sequence_indices_aa', dtype="int32")
    # 3D coordinates of the amino acids point clouds (Calpha coordinates, side chain center of mass,...).
    point_clouds_aa = Input(shape=[Lmax_aa_points, 3], name='point_clouds_aa', dtype="float32")

    # Same, after masking.
    masked_frame_indices_aa = Masking(mask_value=-1, name='masked_frame_indices_aa')(frame_indices_aa)
    masked_attributes_aa = Masking(mask_value=0.0, name='masked_attributes_aa')(attributes_aa)
    masked_sequence_indices_aa = Masking(mask_value=-1, name='masked_sequence_indices_aa')(sequence_indices_aa)
    masked_point_clouds_aa = Masking(mask_value=0.0, name='masked_point_clouds_aa')(point_clouds_aa)

    # Same at atomic scale.
    if with_atom:
        frame_indices_atom = Input(shape=[Lmax_atom, 3], name='frame_indices_atom', dtype="int32")
        attributes_atom = Input(shape=[Lmax_atom], name='attributes_atom', dtype="int32")
        sequence_indices_atom = Input(shape=[Lmax_atom, 1], name='sequence_indices_atom', dtype="int32")
        point_clouds_atom = Input(shape=[Lmax_atom_points, 3], name='point_clouds_atom', dtype="float32")

        masked_frame_indices_atom = Masking(mask_value=-1, name='masked_frame_indices_atom')(frame_indices_atom)
        masked_sequence_indices_atom = Masking(mask_value=-1, name='masked_sequence_indices_atom')(
            sequence_indices_atom)
        masked_point_clouds_atom = Masking(mask_value=0.0, name='masked_point_clouds_atom')(point_clouds_atom)

    ## Embed attributes.
    if nembedding_aa is not None:  # Apply point-wise dense embedding.
        embedded_attributes_aa = attribute_embedding(
            masked_attributes_aa, nembedding_aa, activation, name='embedded_attributes_aa')
    else:  # Normalize such that the variance is approximately one.
        normalizer = np.sqrt(nfeatures_aa)
        embedded_attributes_aa = TimeDistributed(
            Lambda(lambda x: normalizer * x), name='embedded_attributes_aa')(masked_attributes_aa)

    if with_atom:
        # Same number of categories as output dimension: use a fixed one-hot encoding.
        trainable = nfeatures_atom != nembedding_atom
        embedded_attributes_atom = Embedding(
            nfeatures_atom + 1, nembedding_atom, mask_zero=True,
            embeddings_initializer=utils.embeddings_initializer,
            embeddings_constraint=utils.FixedNorm(axis=0, value=np.sqrt(nfeatures_atom)),
            trainable=trainable,
            name='embedded_attributes_atom')(attributes_atom)

        ## If atomic coordinates are included, compute embeddings of atomic neighborhoods.
        SCAN_filters_atom = neighborhood_embedding(
            masked_point_clouds_atom,
            masked_frame_indices_atom,
            embedded_attributes_atom,
            sequence_indices=masked_sequence_indices_atom,
            order_frame=order_atom,
            dipole_frame=dipole_atom,
            Kmax=K_atom,
            coordinates=coordinates_atom,
            Ngaussians=N_atom,
            covariance_type=covariance_type_atom,
            initial_gaussian_values=initial_values['GaussianKernel_atom'],
            l1=l1_atom,
            l12=l12_atom,
            l12group=l12group_atom,
            nfilters=nfilters_atom,
            activation=activation,
            scale='atom')

        # Pool at amino acid level by attention-pooling.
        if dense_pooling is None:
            dense_pooling = nfilters_atom

        if l12_pool > 0:
            kernel_regularizer = partial(l12_regularization, l12=l12_pool)
            kernel_constraint = utils.FixedNorm(1.0, axis=0)
        else:
            kernel_regularizer = None
            kernel_constraint = None

        # Compute attention coefficients (before softmax) for each atom.
        pooling_attention = TimeDistributed(
            Dense(nattentionheads_pooling, use_bias=False,
                  kernel_regularizer=kernel_regularizer,
                  kernel_initializer=Zeros()),
            name='attention_pooling_coefficients_atom')(SCAN_filters_atom)

        # Compute output coefficients (before averaging) for each atom.
        pooling_features = TimeDistributed(
            Dense(dense_pooling, activation=None, use_bias=False,
                  kernel_regularizer=kernel_regularizer,
                  kernel_constraint=kernel_constraint),
            name='attention_pooling_features_atom')(SCAN_filters_atom)

        # Build a binary bipartite graph from amino acid to atoms.
        # For each amino acid we look for the 14 closest atoms in terms of sequence distance.
        # Here pooling_mask is a binary matrix of size [L_aa, L_atom], with M_{ij} = 0 iff atom j
        # belongs to amino acid i.
        pooling_mask, pooling_attention_local, pooling_features_local = neighborhoods.LocalNeighborhood(
            Kmax=14,
            coordinates=['index_distance'],
            self_neighborhood=False,
            index_distance_max=1,
            name='pooling_intermediate_atom')(
            [masked_sequence_indices_aa, masked_sequence_indices_atom, pooling_attention, pooling_features])

        # Reverse the pooling mask such that M_{ij} = 1 iff atom j belongs to amino acid i.
        pooling_mask = Lambda(lambda x: 1 - x, name='pooling_edges_atom')(pooling_mask)

        # Attention-based aggregation of atom features to amino acid scale.
        SCAN_filters_atom_aggregated_input, _ = attention.AttentionLayer(
            self_attention=False, beta=False,
            name='SCAN_filters_atom_aggregated_input')(
            [pooling_attention_local, pooling_features_local, pooling_mask])

        # Add final non-linearity.
        SCAN_filters_atom_aggregated_activity = addNonLinearity(
            SCAN_filters_atom_aggregated_input, activation,
            name='SCAN_filters_atom_aggregated_activity')

        all_embedded_attributes_aa = Concatenate(name='all_embedded_attributes_aa', axis=-1)(
            [embedded_attributes_aa, SCAN_filters_atom_aggregated_activity])
    else:
        all_embedded_attributes_aa = Activation(
            'linear', name='all_embedded_attributes_aa')(embedded_attributes_aa)

    # Compute embeddings of amino acid neighborhoods.
    SCAN_filters_aa, frames_aa = neighborhood_embedding(
        masked_point_clouds_aa,
        masked_frame_indices_aa,
        all_embedded_attributes_aa,
        sequence_indices=masked_sequence_indices_aa,
        order_frame=order_aa,
        dipole_frame=dipole_aa,
        Kmax=K_aa,
        coordinates=coordinates_aa,
        nrotations=nrotations,
        Ngaussians=N_aa,
        covariance_type=covariance_type_aa,
        initial_gaussian_values=initial_values['GaussianKernel_aa'],
        l1=l1_aa,
        l12=l12_aa,
        l12group=l12group_aa,
        nfilters=nfilters_aa,
        activation=activation,
        scale='aa',
        return_frames=True)

    if dropout > 0:
        SCAN_filters_aa = Dropout(dropout, noise_shape=(None, 1, None), name='dropout')(SCAN_filters_aa)

    embedded_filter = SCAN_filters_aa
    for k, n_feature_filter in enumerate(filter_MLP or []):
        embedded_filter = attribute_embedding(
            embedded_filter, n_feature_filter, activation, name='SCAN_filters_aa_embedded_%s' % (k + 1))

    # Final graph attention layer. Propagates label information from "hotspots" to passengers
    # to obtain spatially consistent labels.
    nattention_graph = nembedding_graph * nattentionheads_graph
    beta = TimeDistributed(
        Dense(nattention_graph, use_bias=True, bias_initializer=Ones(),
              kernel_initializer=Zeros(), activation='relu'),
        name='beta')(embedded_filter)
    self_attention = TimeDistributed(
        Dense(nattention_graph, use_bias=True, bias_initializer=Zeros(),
              kernel_initializer=Zeros()),
        name='self_attention')(embedded_filter)
    cross_attention = TimeDistributed(
        Dense(nattention_graph, use_bias=False, kernel_initializer=Zeros()),
        name='cross_attention')(embedded_filter)

    if output == 'classification' and nfilters_graph > 2:
        node_features_activation = 'relu'
    else:
        node_features_activation = None
    node_features = TimeDistributed(
        Dense(nattentionheads_graph * nfilters_graph, activation=node_features_activation, use_bias=True),
        name='node_output')(embedded_filter)

    graph_weights, attention_local, node_features_local = neighborhoods.LocalNeighborhood(
        Kmax=K_graph, coordinates=coordinates_graph,
        index_distance_max=index_distance_max_graph, name='neighborhood_graph')(
        [frames_aa, masked_sequence_indices_aa, cross_attention, node_features])

    embedded_graph_weights = embeddings.GaussianKernel(
        N_graph, initial_values=initial_values['GaussianKernel_graph'],
        covariance_type=covariance_type_graph,
        name='embedded_local_coordinates_graph')(graph_weights)
    embedded_graph_weights = TimeDistributed(
        Dense(nembedding_graph, use_bias=False), name='edges_graph')(embedded_graph_weights)

    graph_attention_output, attention_coefficients = attention.AttentionLayer(name='attention_layer')(
        [beta, self_attention, attention_local, node_features_local, embedded_graph_weights])

    if output == 'classification':
        if nattentionheads_graph * nfilters_graph > 2:
            classifier_output = TimeDistributed(
                Dense(2, use_bias=True, activation='softmax'),
                name='classifier_output')(graph_attention_output)
        else:
            classifier_output = TimeDistributed(
                Activation('softmax'), name='classifier_output')(graph_attention_output)
    else:
        classifier_output = graph_attention_output

    inputs = [frame_indices_aa, attributes_aa, sequence_indices_aa, point_clouds_aa]
    if with_atom:
        inputs += [frame_indices_atom, attributes_atom, sequence_indices_atom, point_clouds_atom]

    model = Model(inputs=inputs, outputs=classifier_output)
    if output == 'classification':
        model.compile(loss='categorical_crossentropy', optimizer=optimizer,
                      metrics=['categorical_crossentropy', 'categorical_accuracy'])
    else:
        model.compile(loss='MSE', optimizer=optimizer, metrics=['MSE'])

    # The dense layer applied to the graph edges starts from the provided (or random) values.
    model.get_layer('edges_graph').set_weights(initial_values['dense_graph'])
    model.summary()
    return model
def _load_cached_initial_values(location, fresh_initial_values, key=None):
    """Return the initial values pickled at ``location`` (optionally entry ``key``), or None if they must be recomputed."""
    if fresh_initial_values:
        return None
    try:
        cached = io_utils.load_pickle(location)
        return cached if key is None else cached[key]
    except Exception:  # Missing, unreadable or incomplete cache file: recompute.
        return None


def initialize_ScanNet(
        inputs,
        outputs,
        with_atom=True,
        Lmax_aa=1024,
        K_aa=16,
        K_atom=16,
        K_graph=32,
        Dmax_aa=13.,
        Dmax_atom=4.,
        Dmax_graph=13.,
        N_aa=32,
        N_atom=32,
        N_graph=32,
        nfeatures_aa=21,
        nfeatures_atom=12,
        nembedding_atom=12,
        nembedding_aa=16,
        nembedding_graph=1,
        dense_pooling=None,
        nattentionheads_pooling=1,
        nfilters_atom=16,
        nfilters_aa=64,
        nfilters_graph=2,
        nattentionheads_graph=1,
        filter_MLP=[],
        covariance_type_atom='full',
        covariance_type_aa='full',
        covariance_type_graph='full',
        activation='relu',
        coordinates_atom=['euclidian'],
        coordinates_aa=['euclidian'],
        frame_aa='triplet_sidechain',
        coordinates_graph=['distance', 'ZdotZ', 'ZdotDelta', 'index_distance'],
        index_distance_max_graph=16,
        nrotations=1,
        l1_aa=0.,
        l12_aa=0.,
        l12group_aa=0.,
        l1_atom=0.,
        l12_atom=0.,
        l12group_atom=0.,
        l1_pool=0.,
        l12_pool=0.,
        l12group_pool=0.,
        dropout=0.,
        optimizer='adam',
        initial_values_folder='/specific/netapp5_2/iscb/wolfson/jeromet/Data/InterfacePrediction/initial_values/',
        fresh_initial_values=False,
        save_initial_values=True,
        output='classification',
        n_init=10,
        epochs=100,
        batch_size=1):
    """Compute (or load cached) initial values and wrap a ScanNet model with them.

    Initial values of the Gaussian kernels (amino acid, atom, graph) and of
    the graph dense layer are loaded from pickles under
    ``initial_values_folder`` unless ``fresh_initial_values`` is set or the
    file cannot be loaded; in that case they are computed from ``inputs`` /
    ``outputs`` and, if ``save_initial_values``, written back.

    Args:
        inputs: list of input arrays ordered as
            [triplets, attributes, indices, points] for amino acids, followed
            by the same four for atoms when ``with_atom``.
        outputs: labels, used to initialize the graph attention embedding.
        initial_values_folder: prefix prepended verbatim to the cache file
            names (so it is expected to end with a path separator).
        n_init: forwarded to the initialization routines.
        epochs, batch_size: returned in ``extra_params``; not used to build
            the model.
        Remaining arguments are forwarded to ``ScanNet``.

    Returns:
        ``(model, extra_params)`` where ``model`` is the
        ``wrappers.grouped_Predictor_wrapper`` around ``ScanNet`` and
        ``extra_params`` is ``{'epochs': epochs, 'batch_size': batch_size}``;
        None (after printing a message) when ``frame_aa`` is not recognized.
    """
    # Validate frame_aa first: everything below depends on the values derived from it.
    if frame_aa == 'triplet_backbone':
        order_aa, dipole_aa = '3', False
        Lmax_aa_points = Lmax_aa + 2
    elif frame_aa in ['triplet_sidechain', 'triplet_cbeta']:
        order_aa, dipole_aa = '2', False
        Lmax_aa_points = 2 * Lmax_aa + 1
    elif frame_aa == 'quadruplet':
        order_aa, dipole_aa = '3', True
        Lmax_aa_points = 2 * Lmax_aa + 2
    else:
        print('Incorrect frame_aa')
        return

    Lmax_atom = 9 * Lmax_aa
    Lmax_atom_points = 11 * Lmax_aa
    order_atom = '2'
    dipole_atom = False
    frame_atom = 'covalent'

    initial_values = {'GaussianKernel_aa': None, 'GaussianKernel_atom': None,
                      'GaussianKernel_graph': None, 'dense_graph': None}

    input_type = ['triplets', 'attributes', 'indices', 'points']
    Lmaxs = [Lmax_aa, Lmax_aa, Lmax_aa, Lmax_aa_points]
    if with_atom:
        input_type = input_type + ['triplets', 'attributes', 'indices', 'points']
        Lmaxs = Lmaxs + [Lmax_atom, Lmax_atom, Lmax_atom, Lmax_atom_points]

    # --- Gaussian kernels of the amino acid neighborhood.
    c_aa = ''.join([c[0] for c in coordinates_aa])
    location_aa = initial_values_folder + 'initial_GaussianKernel_aa_N_%s_Kmax_%s_Dmax_%s_frames_%s_coords_%s_nrotations_%s_cov_%s_tripletsorder_%s_dipole_%s.data' % (
        N_aa, K_aa, Dmax_aa, frame_aa, c_aa, nrotations, covariance_type_aa, order_aa, dipole_aa)
    initial_values['GaussianKernel_aa'] = _load_cached_initial_values(
        location_aa, fresh_initial_values, key='GaussianKernel_aa')
    if initial_values['GaussianKernel_aa'] is None:
        print(
            'Initializing the Gaussian kernels for the amino acid neighborhood (takes a few minutes to do it robustly, be patient!) Reduce n_init from 10 to 1 if speed needed')
        initial_values['GaussianKernel_aa'] = neighborhoods.initialize_GaussianKernel_for_NeighborhoodEmbedding(
            [inputs[0], inputs[3]],
            N_aa,
            covariance_type_aa,
            neighborhood_params={'Kmax': K_aa,
                                 'coordinates': coordinates_aa,
                                 'nrotations': nrotations,
                                 'index_distance_max': None,
                                 'self_neighborhood': True,
                                 },
            n_samples=100,
            Dmax=Dmax_aa, padded=False, from_triplets=True, order=order_aa, dipole=dipole_aa, n_init=n_init)
        if save_initial_values:
            io_utils.save_pickle(
                {'GaussianKernel_aa': initial_values['GaussianKernel_aa']}, location_aa)

    # --- Gaussian kernels of the atomic neighborhood.
    if with_atom:
        c_atom = ''.join([c[0] for c in coordinates_atom])
        location_atom = initial_values_folder + 'initial_GaussianKernel_atom_N_%s_Kmax_%s_Dmax_%s_frames_%s_coords_%s_nrotations_%s_cov_%s_tripletsorder_%s_dipole_%s.data' % (
            N_atom, K_atom, Dmax_atom, frame_atom, c_atom, 1, covariance_type_atom, order_atom, dipole_atom)
        initial_values['GaussianKernel_atom'] = _load_cached_initial_values(
            location_atom, fresh_initial_values, key='GaussianKernel_atom')
        if initial_values['GaussianKernel_atom'] is None:
            print(
                'Initializing the Gaussian kernels for the atomic neighborhood (takes a few minutes to do it robustly, be patient!). Reduce n_init from 10 to 1 if speed needed')
            if 'index_distance' in coordinates_atom:
                inputs_ = [inputs[4], inputs[7], inputs[6]]
            else:
                inputs_ = [inputs[4], inputs[7]]
            initial_values['GaussianKernel_atom'] = neighborhoods.initialize_GaussianKernel_for_NeighborhoodEmbedding(
                inputs_,
                N_atom,
                covariance_type_atom,
                neighborhood_params={'Kmax': K_atom,
                                     'coordinates': coordinates_atom,
                                     'index_distance_max': None,
                                     'self_neighborhood': True
                                     },
                n_samples=100,
                Dmax=Dmax_atom, padded=False, from_triplets=True, order=order_atom, dipole=dipole_atom,
                n_init=n_init)
            if save_initial_values:
                io_utils.save_pickle(
                    {'GaussianKernel_atom': initial_values['GaussianKernel_atom']}, location_atom)

    # --- Gaussian and dense kernels of the graph neighborhood.
    c_graph = ''.join([c[0] for c in coordinates_graph])
    location_graph = initial_values_folder + 'initial_GaussianKernel_graph_N_%s_%s_Kmax_%s_Dmax_%s_coords_%s_indexmax_%s_cov_%s_tripletsorder_%s_dipole_%s.data' % (
        N_graph, nembedding_graph, K_graph, Dmax_graph, c_graph, index_distance_max_graph, covariance_type_graph,
        order_aa, dipole_aa)
    initial_values_graph = _load_cached_initial_values(location_graph, fresh_initial_values)
    if initial_values_graph is None:
        print(
            'Initializing the Gaussian and dense kernels for the graph neighborhood (takes a few minutes to do it robustly, be patient!) Reduce n_init from 10 to 1 if speed needed')
        initial_values_graph = neighborhoods.initialize_Embedding_for_NeighborhoodAttention(
            [inputs[0], inputs[3], inputs[2]], outputs,
            neighborhood_params={'Kmax': K_graph,
                                 'coordinates': coordinates_graph,
                                 'index_distance_max': index_distance_max_graph,
                                 'self_neighborhood': True
                                 },
            N=N_graph, dense=nembedding_graph, covariance_type=covariance_type_graph,
            Dmax=Dmax_graph,
            epochs=10,
            padded=False,
            from_triplets=True,
            n_samples=100, order=order_aa, dipole=dipole_aa, n_init=n_init)
        if save_initial_values:
            io_utils.save_pickle(initial_values_graph, location_graph)
    initial_values['GaussianKernel_graph'] = initial_values_graph['graph_embedding_GaussianKernel']
    initial_values['dense_graph'] = initial_values_graph['graph_embedding_dense']

    model = wrappers.grouped_Predictor_wrapper(
        ScanNet,
        with_atom=with_atom,
        Lmax_aa=Lmax_aa,
        Lmax_atom=Lmax_atom,
        Lmax_aa_points=Lmax_aa_points,
        Lmax_atom_points=Lmax_atom_points,
        K_aa=K_aa,
        K_atom=K_atom,
        K_graph=K_graph,
        N_aa=N_aa,
        N_atom=N_atom,
        N_graph=N_graph,
        nfeatures_aa=nfeatures_aa,
        nfeatures_atom=nfeatures_atom,
        nembedding_atom=nembedding_atom,
        nembedding_aa=nembedding_aa,
        nembedding_graph=nembedding_graph,
        dense_pooling=dense_pooling,
        nattentionheads_pooling=nattentionheads_pooling,
        nfilters_aa=nfilters_aa,
        nfilters_atom=nfilters_atom,
        nfilters_graph=nfilters_graph,
        nattentionheads_graph=nattentionheads_graph,
        filter_MLP=filter_MLP,
        initial_values=initial_values,
        covariance_type_aa=covariance_type_aa,
        covariance_type_atom=covariance_type_atom,
        covariance_type_graph=covariance_type_graph,
        activation=activation,
        frame_aa=frame_aa,
        coordinates_aa=coordinates_aa,
        coordinates_atom=coordinates_atom,
        coordinates_graph=coordinates_graph,
        index_distance_max_graph=index_distance_max_graph,
        l1_aa=l1_aa,
        l12_aa=l12_aa,
        l12group_aa=l12group_aa,
        l1_atom=l1_atom,
        l12_atom=l12_atom,
        l12group_atom=l12group_atom,
        l1_pool=l1_pool,
        l12_pool=l12_pool,
        l12group_pool=l12group_pool,
        nrotations=nrotations,
        dropout=dropout,
        optimizer=optimizer,
        input_type=input_type,
        Lmaxs=Lmaxs,
        output=output,
        multi_inputs=True,
        multi_outputs=False,
    )
    extra_params = {'epochs': epochs, 'batch_size': batch_size}
    return model, extra_params
if __name__ == '__main__':
    # Small end-to-end check: download a few structures, build the inputs,
    # initialize a reduced-size ScanNet and fit it for a few epochs.
    import Bio.PDB
    from preprocessing import PDBio, pipelines
    import numpy as np

    with_atom = True
    frames_aa = 'triplet_sidechain'
    pipeline = pipelines.ScanNetPipeline(
        with_aa=True,
        with_atom=with_atom,
        aa_features='sequence',
        atom_features='valency',
        aa_frames=frames_aa,
    )

    # Machine-specific locations; on the cluster these were
    # '/home/iscb/wolfson/jeromet/Data/PDB_files/' and
    # '/home/iscb/wolfson/jeromet/Data/PPI_binding_database/CATH/interface_labels_train.txt'.
    PDB_folder = '/Users/jerometubiana/PDB/'
    label_file = '/Users/jerometubiana/Downloads/interface_labels_train.txt'

    pdblist = Bio.PDB.PDBList()
    nmax = 11
    list_origins, list_sequences, list_resids, list_labels = io_utils.read_labels(
        label_file, nmax=nmax, label_type='int')

    examples = []
    for origin in list_origins:
        pdb = origin[:4]
        chain = origin.split('_')[-1]
        # Called for its side effect: makes sure the structure file is in PDB_folder.
        pdblist.retrieve_pdb_file(pdb, pdir=PDB_folder)
        struct, chains = PDBio.load_chains(
            pdb_id=pdb, chain_ids=[(0, chain)], file=PDB_folder + '%s.cif' % pdb)
        examples.append(pipeline.process_example(chains))

    # Regroup from one tuple per example to one array per input type.
    inputs = [np.array([example[k] for example in examples])
              for k in range(len(examples[0]))]
    outputs = [np.stack([label < 5, label >= 5], axis=-1) for label in list_labels]

    hyperparameters = dict(
        with_atom=with_atom,
        Lmax_aa=max([len(o) for o in outputs]),
        K_aa=8,
        K_atom=8,
        K_graph=16,
        Dmax_aa=13.,
        Dmax_atom=4.,
        Dmax_graph=13.,
        N_aa=16,
        N_atom=16,
        N_graph=16,
        nfeatures_aa=20,
        nfeatures_atom=12,
        nembedding_atom=4,
        nembedding_aa=8,
        nembedding_graph=1,
        dense_pooling=None,
        nattentionheads_pooling=2,
        nfilters_atom=32,
        nfilters_aa=32,
        nfilters_graph=2,
        nattentionheads_graph=1,
        filter_MLP=[],
        covariance_type_atom='full',
        covariance_type_aa='full',
        covariance_type_graph='full',
        activation='relu',
        coordinates_atom=['euclidian'],
        coordinates_aa=['euclidian'],  # ,'dipole_spherical']
        frame_aa=frames_aa,
        coordinates_graph=['distance', 'ZdotZ', 'ZdotDelta', 'index_distance'],
        index_distance_max_graph=16,
        l1_aa=0.,
        l12_aa=0.,
        l12group_aa=1e-3,
        l1_atom=0.,
        l12_atom=1e-3,
        l12group_atom=0.,
        nrotations=1,
        dropout=0.,
        optimizer='adam',
        batch_size=1,
        fresh_initial_values=True,
        save_initial_values=False,
        n_init=1,
    )
    model, extra_params = initialize_ScanNet(inputs, outputs, **hyperparameters)
    model.fit(inputs, outputs, batch_size=1, epochs=10)