# Mirror of https://github.com/dmlc/dgl.git (synced 2026-06-03 19:34:33 +08:00).
# Repository-viewer metadata: 382 lines, 12 KiB, Python, executable file.
import argparse
|
|
import os
|
|
import random
|
|
import time
|
|
|
|
import dgl
|
|
import dgl.function as fn
|
|
|
|
import networkx as nx
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.utils.data
|
|
from data_utils import pre_process
|
|
from dgl import DGLGraph
|
|
from dgl.data import tu
|
|
from model.encoder import DiffPool
|
|
|
|
# Wall-clock duration, in seconds, of every training epoch. train() appends
# one entry per epoch and main() reports the average.
global_train_time_per_epoch = []
|
|
|
|
|
|
def arg_parse():
    """
    Build the DiffPool argument parser and parse the command line.

    Every option falls back to the value registered through ``set_defaults``,
    so the script can be launched without any argument.  ``method`` has no
    command-line flag of its own and is only reachable through its default.

    Returns:
        argparse.Namespace: the parsed program arguments.
    """
    parser = argparse.ArgumentParser(description="DiffPool arguments")
    parser.add_argument("--dataset", dest="dataset", help="Input Dataset")
    parser.add_argument(
        "--pool_ratio", dest="pool_ratio", type=float, help="pooling ratio"
    )
    parser.add_argument(
        "--num_pool", dest="num_pool", type=int, help="num_pooling layer"
    )
    # store_false: ``linkpred`` is True unless --no_link_pred is given.
    parser.add_argument(
        "--no_link_pred",
        dest="linkpred",
        action="store_false",
        help="switch of link prediction object",
    )
    parser.add_argument("--cuda", dest="cuda", type=int, help="switch cuda")
    parser.add_argument("--lr", dest="lr", type=float, help="learning rate")
    parser.add_argument(
        "--clip", dest="clip", type=float, help="gradient clipping"
    )
    parser.add_argument(
        "--batch-size", dest="batch_size", type=int, help="batch size"
    )
    parser.add_argument("--epochs", dest="epoch", type=int, help="num-of-epoch")
    parser.add_argument(
        "--train-ratio",
        dest="train_ratio",
        type=float,
        help="ratio of training dataset split",
    )
    parser.add_argument(
        "--test-ratio",
        dest="test_ratio",
        type=float,
        help="ratio of testing dataset split",
    )
    parser.add_argument(
        "--num_workers",
        dest="n_worker",
        type=int,
        help="number of workers when dataloading",
    )
    parser.add_argument(
        "--gc-per-block",
        dest="gc_per_block",
        type=int,
        help="number of graph conv layer per block",
    )
    # NOTE: const and default are both True, so passing --bn (or --bias below)
    # never changes the parsed value; the flags are kept for compatibility.
    parser.add_argument(
        "--bn",
        dest="bn",
        action="store_const",
        const=True,
        default=True,
        help="switch for bn",
    )
    parser.add_argument(
        "--dropout", dest="dropout", type=float, help="dropout rate"
    )
    parser.add_argument(
        "--bias",
        dest="bias",
        action="store_const",
        const=True,
        default=True,
        help="switch for bias",
    )
    parser.add_argument(
        "--save_dir",
        dest="save_dir",
        help="model saving directory: SAVE_DICT/DATASET",
    )
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, which would splice the source indentation (or no
    # separator at all) into the help text.
    parser.add_argument(
        "--load_epoch",
        dest="load_epoch",
        type=int,
        help=(
            "load trained model params from "
            "SAVE_DICT/DATASET/model-LOAD_EPOCH"
        ),
    )
    parser.add_argument(
        "--data_mode",
        dest="data_mode",
        help=(
            "data preprocessing mode: default, id, degree, or one-hot "
            "vector of degree number"
        ),
        choices=["default", "id", "deg", "deg_num"],
    )

    parser.set_defaults(
        dataset="ENZYMES",
        pool_ratio=0.15,
        num_pool=1,
        cuda=1,
        lr=1e-3,
        clip=2.0,
        batch_size=20,
        epoch=4000,
        train_ratio=0.7,
        test_ratio=0.1,
        n_worker=1,
        gc_per_block=3,
        dropout=0.0,
        method="diffpool",
        bn=True,
        bias=True,
        save_dir="./model_param",
        load_epoch=-1,
        data_mode="default",
    )
    return parser.parse_args()
|
|
|
|
|
|
def prepare_data(dataset, prog_args, train=False, pre_process=None):
    """
    Wrap ``dataset`` in a DGL ``GraphDataLoader``.

    When a ``pre_process`` callable is supplied it is first invoked as
    ``pre_process(dataset, prog_args)``.  Batch size and worker count come
    from ``prog_args``; batches are shuffled only when ``train`` is truthy.
    """
    if pre_process:
        pre_process(dataset, prog_args)

    loader_options = {
        "batch_size": prog_args.batch_size,
        "shuffle": True if train else False,
        "num_workers": prog_args.n_worker,
    }
    return dgl.dataloading.GraphDataLoader(dataset, **loader_options)
|
|
|
|
|
|
def graph_classify_task(prog_args):
    """
    Run one graph-classification experiment from start to finish.

    Loads the TU dataset named by ``prog_args.dataset``, splits it randomly
    into train / validation / test parts, builds a DiffPool model (optionally
    restoring saved parameters), trains it and prints the test accuracy.
    """
    dataset = tu.LegacyTUDataset(name=prog_args.dataset)

    # The validation split receives whatever remains after train and test.
    num_graphs = len(dataset)
    train_size = int(prog_args.train_ratio * num_graphs)
    test_size = int(prog_args.test_ratio * num_graphs)
    val_size = int(num_graphs - train_size - test_size)
    splits = torch.utils.data.random_split(
        dataset, (train_size, val_size, test_size)
    )

    # Splits come back in train, val, test order; train=True is passed only
    # for the first one.
    train_dataloader, val_dataloader, test_dataloader = [
        prepare_data(
            subset, prog_args, train=is_train, pre_process=pre_process
        )
        for subset, is_train in zip(splits, (True, False, False))
    ]

    input_dim, label_dim, max_num_node = dataset.statistics()
    print("++++++++++STATISTICS ABOUT THE DATASET")
    print("dataset feature dimension is", input_dim)
    print("dataset label dimension is", label_dim)
    print("the max num node is", max_num_node)
    print("number of graphs is", len(dataset))

    hidden_dim = 64
    embedding_dim = 64

    # Assignment dimension: pool_ratio times the largest node count reported
    # by the dataset statistics.
    assign_dim = int(max_num_node * prog_args.pool_ratio)
    print("++++++++++MODEL STATISTICS++++++++")
    print("model hidden dim is", hidden_dim)
    print("model embedding dim for graph instance embedding", embedding_dim)
    print("initial batched pool graph dim is", assign_dim)

    # Build the DiffPool model; arguments are positional, so order matters.
    model = DiffPool(
        input_dim,
        hidden_dim,
        embedding_dim,
        label_dim,
        F.relu,
        prog_args.gc_per_block,
        prog_args.dropout,
        prog_args.num_pool,
        prog_args.linkpred,
        prog_args.batch_size,
        "meanpool",
        assign_dim,
        prog_args.pool_ratio,
    )

    # Optionally resume from the parameters saved for epoch ``load_epoch``.
    resume = prog_args.load_epoch >= 0 and prog_args.save_dir is not None
    if resume:
        checkpoint_path = (
            prog_args.save_dir
            + "/"
            + prog_args.dataset
            + "/model.iter-"
            + str(prog_args.load_epoch)
        )
        saved_state = torch.load(checkpoint_path, weights_only=False)
        model.load_state_dict(saved_state)

    print("model init finished")
    print("MODEL:::::::", prog_args.method)
    if prog_args.cuda:
        model = model.cuda()

    # train() returns the early-stopping record; passing it to evaluate()
    # makes the test run use the checkpoint of the best epoch.
    logger = train(
        train_dataloader, model, prog_args, val_dataset=val_dataloader
    )
    test_accuracy = evaluate(test_dataloader, model, prog_args, logger)
    print("test accuracy {:.2f}%".format(test_accuracy * 100))
|
|
|
|
|
|
def train(dataset, model, prog_args, same_feat=True, val_dataset=None):
    """
    Train ``model`` and checkpoint it according to validation accuracy.

    Args:
        dataset: training dataloader yielding ``(batch_graph, graph_labels)``.
        model: network to optimise; must provide ``loss(ypred, labels)``.
        prog_args: parsed arguments; ``save_dir``, ``dataset``, ``lr``,
            ``cuda``, ``epoch`` and ``clip`` are read here.
        same_feat: unused; kept so existing callers keep working.
        val_dataset: optional validation dataloader, evaluated after every
            epoch.

    Returns:
        dict: ``{"best_epoch": ..., "val_acc": ...}``.  Both stay ``-1`` when
        no epoch was ever selected (and therefore no checkpoint was written).
    """
    # Only touch the file system when a save directory is configured; the
    # checkpoint code below already treats ``save_dir=None`` as "do not save".
    checkpoint_dir = None
    if prog_args.save_dir is not None:
        checkpoint_dir = prog_args.save_dir + "/" + prog_args.dataset
        os.makedirs(checkpoint_dir, exist_ok=True)

    dataloader = dataset
    # Honour the --lr option (its default, 1e-3, equals the value that used
    # to be hard-coded here).
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=prog_args.lr,
    )
    early_stopping_logger = {"best_epoch": -1, "val_acc": -1}

    if prog_args.cuda > 0:
        torch.cuda.set_device(0)

    # Batches must live on the same device as the model, whether or not CUDA
    # happens to be available on this machine.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    for epoch in range(prog_args.epoch):
        begin_time = time.time()
        model.train()
        accum_correct = 0
        total = 0
        loss = None
        print("\nEPOCH ###### {} ######".format(epoch))
        computation_time = 0.0
        for batch_graph, graph_labels in dataloader:
            for key, value in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if device is not None:
                batch_graph = batch_graph.to(device)
                graph_labels = graph_labels.to(device)

            model.zero_grad()
            # Only forward + backward count as "computation time".
            compute_start = time.time()
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            accum_correct += torch.sum(indi == graph_labels).item()
            total += graph_labels.size()[0]
            loss = model.loss(ypred, graph_labels)
            loss.backward()
            computation_time += time.time() - compute_start
            nn.utils.clip_grad_norm_(model.parameters(), prog_args.clip)
            optimizer.step()

        # An empty dataloader yields no samples and no loss; report neutral
        # values instead of failing on a division by zero / unbound name.
        train_accu = accum_correct / total if total else 0.0
        last_loss = loss.item() if loss is not None else float("nan")
        print(
            "train accuracy for this epoch {} is {:.2f}%".format(
                epoch, train_accu * 100
            )
        )
        elapsed_time = time.time() - begin_time
        print(
            "loss {:.4f} with epoch time {:.4f} s & computation time {:.4f} s ".format(
                last_loss, elapsed_time, computation_time
            )
        )
        global_train_time_per_epoch.append(elapsed_time)
        if val_dataset is not None:
            result = evaluate(val_dataset, model, prog_args)
            print("validation accuracy {:.2f}%".format(result * 100))
            # An epoch is selected when its validation accuracy is at least
            # the best so far and does not exceed the training accuracy.
            if (
                result >= early_stopping_logger["val_acc"]
                and result <= train_accu
            ):
                early_stopping_logger.update(best_epoch=epoch, val_acc=result)
                if checkpoint_dir is not None:
                    torch.save(
                        model.state_dict(),
                        checkpoint_dir
                        + "/model.iter-"
                        + str(early_stopping_logger["best_epoch"]),
                    )
            print(
                "best epoch is EPOCH {}, val_acc is {:.2f}%".format(
                    early_stopping_logger["best_epoch"],
                    early_stopping_logger["val_acc"] * 100,
                )
            )
        torch.cuda.empty_cache()
    return early_stopping_logger
|
|
|
|
|
|
def evaluate(dataloader, model, prog_args, logger=None):
    """
    Return the classification accuracy of ``model`` over ``dataloader``.

    Args:
        dataloader: iterable yielding ``(batch_graph, graph_labels)``.
        model: network to evaluate; it is switched to eval mode.
        prog_args: parsed arguments; ``save_dir`` and ``dataset`` locate the
            checkpoint when ``logger`` is given.
        logger: optional early-stopping record returned by ``train``.  When
            it names a saved epoch (``best_epoch >= 0``) and a save directory
            is configured, that checkpoint is loaded into ``model`` first.

    Returns:
        float: correct predictions divided by the number of evaluated
        samples, or ``0.0`` when the dataloader yields nothing.
    """
    # ``best_epoch`` is -1 when training never saved a checkpoint; in that
    # case there is no file to load and the model is evaluated as it is.
    if (
        logger is not None
        and prog_args.save_dir is not None
        and logger["best_epoch"] >= 0
    ):
        checkpoint_path = (
            prog_args.save_dir
            + "/"
            + prog_args.dataset
            + "/model.iter-"
            + str(logger["best_epoch"])
        )
        model.load_state_dict(torch.load(checkpoint_path, weights_only=False))

    # Batches must live on the same device as the model, whether or not CUDA
    # happens to be available on this machine.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    model.eval()
    correct_label = 0
    total = 0
    with torch.no_grad():
        for batch_graph, graph_labels in dataloader:
            for key, value in batch_graph.ndata.items():
                batch_graph.ndata[key] = value.float()
            graph_labels = graph_labels.long()
            if device is not None:
                batch_graph = batch_graph.to(device)
                graph_labels = graph_labels.to(device)
            ypred = model(batch_graph)
            indi = torch.argmax(ypred, dim=1)
            correct_label += torch.sum(indi == graph_labels).item()
            # Count the samples actually seen: the last batch may hold fewer
            # than ``batch_size`` graphs.
            total += graph_labels.size(0)
    return correct_label / total if total else 0.0
|
|
|
|
|
|
def main():
    """
    Script entry point.

    Parses the command line, runs the graph-classification task, then prints
    the average training time per epoch and the peak CUDA memory allocated on
    device 0 (bytes divided by 1024 * 1024).
    """
    prog_args = arg_parse()
    print(prog_args)
    graph_classify_task(prog_args)

    # No epoch is timed when training ran zero epochs (e.g. --epochs 0);
    # averaging an empty list would raise ZeroDivisionError.
    if global_train_time_per_epoch:
        print(
            "Train time per epoch: {:.4f}".format(
                sum(global_train_time_per_epoch)
                / len(global_train_time_per_epoch)
            )
        )
    else:
        print("Train time per epoch: n/a (no epochs were run)")
    print(
        "Max memory usage: {:.4f}".format(
            torch.cuda.max_memory_allocated(0) / (1024 * 1024)
        )
    )
|
|
|
|
|
|
# Run the experiment only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|