mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
83 lines
2.4 KiB
Python
83 lines
2.4 KiB
Python
"""
Deep Graph Infomax in DGL

References
----------
Paper: https://arxiv.org/abs/1809.10341
Author's code: https://github.com/PetarV-/DGI
"""
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from gcn import GCN
|
|
from tensorflow.keras import layers
|
|
|
|
|
|
class Encoder(layers.Layer):
    """GCN-based node encoder for DGI, with an optional feature-corruption step.

    Parameters
    ----------
    g :
        Graph passed to ``GCN``; its node count also sizes the row permutation
        used for corruption.
    in_feats, n_hidden, n_layers, activation, dropout :
        Forwarded to ``GCN``. ``n_hidden`` is passed twice (presumably as both
        the hidden and the output width — confirm against ``GCN``'s signature).
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super().__init__()
        self.g = g
        gcn_args = (g, in_feats, n_hidden, n_hidden, n_layers, activation, dropout)
        self.conv = GCN(*gcn_args)

    def call(self, features, corrupt=False):
        """Encode ``features`` with the GCN.

        When ``corrupt`` is true, the rows of ``features`` are first reordered by
        a random permutation of length ``g.number_of_nodes()``. This keeps the
        graph fixed but shuffles which node gets which feature row, producing
        the negative sample. It assumes ``features`` has one row per node —
        TODO confirm with callers.
        """
        inputs = features
        if corrupt:
            # Uses NumPy's global RNG, not TensorFlow's, so tf.random seeds do
            # not control this shuffle.
            shuffled_rows = np.random.permutation(self.g.number_of_nodes())
            inputs = tf.gather(inputs, shuffled_rows)
        return self.conv(inputs)
|
|
|
|
|
|
class Discriminator(layers.Layer):
    """Bilinear scorer between node embeddings and a graph summary vector.

    Holds a trainable ``(n_hidden, n_hidden)`` float32 matrix ``weight``. It is
    initialised uniformly in ``[-1/sqrt(n_hidden), 1/sqrt(n_hidden)]``.
    """

    def __init__(self, n_hidden):
        super().__init__()
        bound = 1.0 / math.sqrt(n_hidden)
        init = tf.keras.initializers.RandomUniform(-bound, bound)
        initial = init(shape=(n_hidden, n_hidden), dtype="float32")
        # Assigned as an attribute so Keras auto-tracks it as a layer weight.
        self.weight = tf.Variable(initial_value=initial, trainable=True)

    def call(self, features, summary):
        """Return ``features @ (weight @ summary[:, None])``.

        Assumes ``features`` is ``(N, n_hidden)`` and ``summary`` is a rank-1
        vector of length ``n_hidden``; the result is then ``(N, 1)``, i.e. one
        unnormalised score (logit) per row. TODO confirm shapes with callers.
        """
        summary_col = tf.expand_dims(summary, -1)
        projected = tf.matmul(self.weight, summary_col)
        scores = tf.matmul(features, projected)
        return scores
|
|
|
|
|
|
class DGI(tf.keras.Model):
    """Deep Graph Infomax model whose forward pass returns the training loss.

    Parameters
    ----------
    g, in_feats, n_hidden, n_layers, activation, dropout :
        Forwarded to ``Encoder``; ``n_hidden`` also sizes the ``Discriminator``.
    """

    def __init__(self, g, in_feats, n_hidden, n_layers, activation, dropout):
        super(DGI, self).__init__()
        self.encoder = Encoder(
            g, in_feats, n_hidden, n_layers, activation, dropout
        )
        self.discriminator = Discriminator(n_hidden)
        # Kept as an attribute for backward compatibility with existing code
        # that may read ``model.loss``.
        self.loss = tf.nn.sigmoid_cross_entropy_with_logits

    def call(self, features):
        """Compute the DGI contrastive loss for ``features``.

        Steps:
        1. Encode the clean features (positive) and the row-shuffled features
           (negative).
        2. Build a graph summary as ``sigmoid(mean(positive, axis=0))``.
        3. Score both embeddings against the summary with the discriminator.

        Returns the scalar sum of:
        - the mean sigmoid cross-entropy of the positive scores against
          label 1, and
        - the mean sigmoid cross-entropy of the negative scores against
          label 0.
        """
        positive = self.encoder(features, corrupt=False)
        negative = self.encoder(features, corrupt=True)
        summary = tf.nn.sigmoid(tf.reduce_mean(positive, axis=0))

        positive = self.discriminator(positive, summary)
        negative = self.discriminator(negative, summary)

        # ones_like/zeros_like instead of tf.ones(x.shape): this works when the
        # static shape has unknown (None) dims, e.g. under tf.function, and
        # matches the logits' dtype. Keyword arguments avoid relying on
        # positional order (TF1 had a leading `_sentinel` parameter).
        l1 = self.loss(labels=tf.ones_like(positive), logits=positive)
        l2 = self.loss(labels=tf.zeros_like(negative), logits=negative)

        return tf.reduce_mean(l1) + tf.reduce_mean(l2)
|
|
|
|
|
|
class Classifier(layers.Layer):
    """Linear classification head: one ``Dense(n_classes)`` layer, no activation.

    ``n_hidden`` is accepted but not used. ``layers.Dense`` infers its input
    width from the first batch it sees.
    """

    def __init__(self, n_hidden, n_classes):
        super().__init__()
        self.fc = layers.Dense(n_classes)

    def call(self, features):
        """Return the unnormalised class scores for ``features``."""
        logits = self.fc(features)
        return logits
|