# Mirror of https://github.com/dmlc/dgl.git (synced 2026-06-04 19:44:23 +08:00).
# File size: 226 lines, 7.1 KiB. Language: Python.
import argparse
|
|
import time
|
|
|
|
import dgl
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import tensorflow as tf
|
|
from dgi import Classifier, DGI
|
|
from dgl.data import (
|
|
CiteseerGraphDataset,
|
|
CoraGraphDataset,
|
|
PubmedGraphDataset,
|
|
register_data_args,
|
|
)
|
|
from tensorflow.keras import layers
|
|
|
|
|
|
def evaluate(model, features, labels, mask):
    """Return the classification accuracy of ``model`` on the masked rows.

    The model is called in inference mode (``training=False``), its output
    and ``labels`` are both indexed with ``mask``, and the per-row argmax
    over axis 1 is compared with the selected labels.

    Args:
        model: Callable accepting ``(features, training=False)`` and
            returning per-row class scores.
        features: Input passed straight to ``model``.
        labels: Ground-truth class ids, indexable by ``mask``.
        mask: Row selector applied to both the scores and the labels
            (presumably a boolean node mask -- confirm against the caller).

    Returns:
        The mean of the element-wise "prediction equals label" indicator,
        converted to a plain Python scalar.
    """
    selected_scores = model(features, training=False)[mask]
    selected_labels = labels[mask]
    predicted = tf.math.argmax(selected_scores, axis=1)
    hits = tf.cast(predicted == selected_labels, dtype=tf.float32)
    return tf.reduce_mean(hits).numpy().item()
|
|
|
|
|
|
def _timing_stats(durations, n_edges):
    """Return ``(mean_epoch_seconds, kteps)`` for the epochs timed so far.

    While ``durations`` is still empty (the untimed warm-up epochs) both
    values are NaN. Calling ``np.mean`` on an empty list would yield the same
    NaN but also emit a RuntimeWarning on every warm-up epoch, so that case
    is handled explicitly.

    Args:
        durations: List of per-epoch wall-clock durations in seconds.
        n_edges: Number of edges in the graph, used for the throughput figure.

    Returns:
        Tuple of the mean epoch time and the throughput in thousands of
        traversed edges per second.
    """
    if not durations:
        nan = float("nan")
        return nan, nan
    mean_time = np.mean(durations)
    return mean_time, n_edges / mean_time / 1000


def main(args):
    """Train Deep Graph Infomax, then fit and evaluate a classifier on top.

    Stage 1 trains the ``DGI`` model on the node features with early
    stopping on the training loss; the best weights are written to
    ``best_dgi.pkl`` in the current working directory. Stage 2 reloads those
    weights, computes the encoder embeddings with gradients stopped, trains a
    ``Classifier`` on the training-mask nodes, and prints validation accuracy
    per epoch and test accuracy at the end.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``dataset``, ``gpu``, ``self_loop``, ``n_hidden``, ``n_layers``,
            ``dropout``, ``dgi_lr``, ``n_dgi_epochs``, ``weight_decay``,
            ``patience``, ``classifier_lr`` and ``n_classifier_epochs``.

    Raises:
        ValueError: If ``args.dataset`` is not ``cora``, ``citeseer`` or
            ``pubmed``.
    """
    # The first few epochs are excluded from the timing statistics.
    warmup_epochs = 3

    # load and preprocess dataset
    if args.dataset == "cora":
        data = CoraGraphDataset()
    elif args.dataset == "citeseer":
        data = CiteseerGraphDataset()
    elif args.dataset == "pubmed":
        data = PubmedGraphDataset()
    else:
        raise ValueError("Unknown dataset: {}".format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        device = "/cpu:0"
    else:
        device = "/gpu:{}".format(args.gpu)
        g = g.to(device)

    with tf.device(device):
        features = g.ndata["feat"]
        labels = g.ndata["label"]
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        in_feats = features.shape[1]
        n_classes = data.num_classes

        # add self loop (existing ones are removed first to avoid duplicates)
        if args.self_loop:
            g = dgl.remove_self_loop(g)
            g = dgl.add_self_loop(g)
        n_edges = g.number_of_edges()

        # create DGI model
        dgi = DGI(
            g,
            in_feats,
            args.n_hidden,
            args.n_layers,
            tf.keras.layers.PReLU(
                alpha_initializer=tf.constant_initializer(0.25)
            ),
            args.dropout,
        )

        dgi_optimizer = tf.keras.optimizers.Adam(learning_rate=args.dgi_lr)

        # train deep graph infomax
        cnt_wait = 0
        best = 1e9
        best_t = 0
        dur = []
        for epoch in range(args.n_dgi_epochs):
            t0 = time.time()

            with tf.GradientTape() as tape:
                loss = dgi(features)
                # Manually Weight Decay
                # We found Tensorflow has a different implementation on
                # weight decay of Adam(W) optimizer with PyTorch. And this
                # results in worse results. Manually adding weights to the
                # loss to do weight decay solves this problem.
                for weight in dgi.trainable_weights:
                    loss = loss + args.weight_decay * tf.nn.l2_loss(weight)
            grads = tape.gradient(loss, dgi.trainable_weights)
            dgi_optimizer.apply_gradients(zip(grads, dgi.trainable_weights))

            if loss < best:
                best = loss
                best_t = epoch
                cnt_wait = 0
                # NOTE(review): newer Keras releases may require a
                # ".weights.h5" suffix for save_weights -- confirm against
                # the installed TensorFlow/Keras version.
                dgi.save_weights("best_dgi.pkl")
            else:
                cnt_wait += 1

            if cnt_wait == args.patience:
                print("Early stopping!")
                break

            if epoch >= warmup_epochs:
                dur.append(time.time() - t0)

            mean_time, kteps = _timing_stats(dur, n_edges)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    mean_time,
                    loss.numpy().item(),
                    kteps,
                )
            )

        # create classifier model
        classifier = Classifier(args.n_hidden, n_classes)

        classifier_optimizer = tf.keras.optimizers.Adam(
            learning_rate=args.classifier_lr
        )

        # train classifier
        print("Loading {}th epoch".format(best_t))
        dgi.load_weights("best_dgi.pkl")
        embeds = dgi.encoder(features, corrupt=False)
        # The encoder is frozen: only the classifier is optimised below.
        embeds = tf.stop_gradient(embeds)
        dur = []
        loss_fcn = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True
        )
        for epoch in range(args.n_classifier_epochs):
            t0 = time.time()
            with tf.GradientTape() as tape:
                preds = classifier(embeds)
                loss = loss_fcn(labels[train_mask], preds[train_mask])
                # Unlike the DGI stage, no manual weight decay is added
                # here: the original code applies none in this part.
                # link: https://github.com/PetarV-/DGI/blob/master/execute.py#L121
            grads = tape.gradient(loss, classifier.trainable_weights)
            classifier_optimizer.apply_gradients(
                zip(grads, classifier.trainable_weights)
            )
            if epoch >= warmup_epochs:
                dur.append(time.time() - t0)

            acc = evaluate(classifier, embeds, labels, val_mask)
            mean_time, kteps = _timing_stats(dur, n_edges)
            print(
                "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
                "ETputs(KTEPS) {:.2f}".format(
                    epoch,
                    mean_time,
                    loss.numpy().item(),
                    acc,
                    kteps,
                )
            )

        print()
        acc = evaluate(classifier, embeds, labels, test_mask)
        print("Test Accuracy {:.4f}".format(acc))
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: build the parser, parse, echo the
    # configuration, then run training.
    parser = argparse.ArgumentParser(description="DGI")
    # Dataset-related options come from DGL's shared helper.
    register_data_args(parser)

    # (flag, add_argument keyword arguments), in the order they are registered.
    option_specs = [
        (
            "--dropout",
            dict(type=float, default=0.0, help="dropout probability"),
        ),
        ("--gpu", dict(type=int, default=-1, help="gpu")),
        (
            "--dgi-lr",
            dict(type=float, default=1e-3, help="dgi learning rate"),
        ),
        (
            "--classifier-lr",
            dict(type=float, default=1e-2, help="classifier learning rate"),
        ),
        (
            "--n-dgi-epochs",
            dict(type=int, default=300, help="number of training epochs"),
        ),
        (
            "--n-classifier-epochs",
            dict(type=int, default=300, help="number of training epochs"),
        ),
        (
            "--n-hidden",
            dict(type=int, default=512, help="number of hidden gcn units"),
        ),
        (
            "--n-layers",
            dict(type=int, default=1, help="number of hidden gcn layers"),
        ),
        (
            "--weight-decay",
            dict(type=float, default=0.0, help="Weight for L2 loss"),
        ),
        (
            "--patience",
            dict(type=int, default=20, help="early stop patience condition"),
        ),
        (
            "--self-loop",
            dict(action="store_true", help="graph self-loop (default=False)"),
        ),
    ]
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)

    main(args)
|