mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
464 lines
16 KiB
Python
464 lines
16 KiB
Python
"""
|
|
This script trains and tests a GraphSAGE model for link prediction on
|
|
large graphs using graphbolt dataloader. It is the PyG counterpart of the
|
|
example in `examples/graphbolt/link_prediction.py`.
|
|
|
|
Paper: [Inductive Representation Learning on Large Graphs]
|
|
(https://arxiv.org/abs/1706.02216)
|
|
|
|
While node classification predicts labels for nodes based on their
|
|
local neighborhoods, link prediction assesses the likelihood of an edge
|
|
existing between two nodes, necessitating different sampling strategies
|
|
that account for pairs of nodes and their joint neighborhoods.
|
|
|
|
This flowchart describes the main functional sequence of the provided example.
|
|
main
|
|
│
|
|
├───> OnDiskDataset pre-processing
|
|
│
|
|
├───> Instantiate SAGE model
|
|
│
|
|
├───> train
|
|
│ │
|
|
│ ├───> Get graphbolt dataloader (HIGHLIGHT)
|
|
| |
|
|
| |───> Define a PyG GNN model for link prediction (HIGHLIGHT)
|
|
│ │
|
|
│ └───> Training loop
|
|
│ │
|
|
│ ├───> SAGE.forward
|
|
│
|
|
└───> Validation and test set evaluation
|
|
"""
|
|
import argparse
|
|
import time
|
|
from functools import partial
|
|
|
|
import dgl.graphbolt as gb
|
|
import torch
|
|
|
|
# For torch.compile until https://github.com/pytorch/pytorch/issues/121197 is
|
|
# resolved.
|
|
import torch._inductor.codecache
|
|
|
|
torch._dynamo.config.cache_size_limit = 32
|
|
|
|
import torch.nn.functional as F
|
|
from torch_geometric.nn import SAGEConv
|
|
from torchmetrics.retrieval import RetrievalMRR
|
|
from tqdm import tqdm, trange
|
|
|
|
|
|
class GraphSAGE(torch.nn.Module):
    """GraphSAGE encoder plus an MLP that scores node pairs.

    (HIGHLIGHT) Define the GraphSAGE model architecture.

    - This class inherits from `torch.nn.Module`.
    - `n_layers` convolutional layers are created using the SAGEConv class
      from PyG; the first maps `in_size` to `hidden_size` and every later
      one maps `hidden_size` to `hidden_size`.
    - `predictor` is a Linear/ReLU/Linear/ReLU/Linear stack that reduces a
      `hidden_size` vector to a single value.
    - The forward method defines the computation performed at every call.
    """

    def __init__(self, in_size, hidden_size, n_layers):
        super().__init__()
        # Consecutive entries of `widths` are the (input, output) sizes of
        # each convolution, so `n_layers + 1` entries give `n_layers` layers.
        widths = [in_size] + [hidden_size] * n_layers
        self.layers = torch.nn.ModuleList(
            [
                SAGEConv(fan_in, fan_out)
                for fan_in, fan_out in zip(widths[:-1], widths[1:])
            ]
        )
        self.hidden_size = hidden_size
        self.predictor = torch.nn.Sequential(
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, 1),
        )

    def forward(self, subgraphs, x):
        """Run the convolution stack over the sampled subgraphs.

        Layers and subgraphs are paired positionally. ReLU is applied after
        every layer except the one paired with the last subgraph.
        """
        hidden = x
        for depth, (conv, block) in enumerate(zip(self.layers, subgraphs)):
            #####################################################################
            # (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
            #
            # PyG layers have two modes, bipartite and normal. We slice the
            # given features to get src and dst features to use the PyG layers
            # in the more efficient bipartite mode.
            #####################################################################
            hidden, edge_index, size = block.to_pyg(hidden)
            hidden = conv(hidden, edge_index, size=size)
            is_final = depth == len(subgraphs) - 1
            if not is_final:
                hidden = F.relu(hidden)
        return hidden

    def inference(self, graph, features, dataloader, storage_device):
        """Conduct layer-wise inference to get all the node embeddings.

        For every layer a `(graph.total_num_nodes, hidden_size)` buffer is
        allocated on `storage_device` (pinned CPU memory when it equals
        "pinned") and filled batch by batch. After each non-final layer the
        buffer replaces the "feat" node feature in `features`, so the next
        layer reads the previous layer's output. The buffer written by the
        final layer is returned.
        """
        use_pinned = storage_device == "pinned"
        out_device = torch.device("cpu" if use_pinned else storage_device)

        for depth, conv in enumerate(self.layers):
            final = depth == len(self.layers) - 1

            embeddings = torch.empty(
                graph.total_num_nodes,
                self.hidden_size,
                dtype=torch.float32,
                device=out_device,
                pin_memory=use_pinned,
            )
            for batch in tqdm(dataloader, "Inferencing"):
                # len(batch.sampled_subgraphs) = 1
                only_subgraph = batch.sampled_subgraphs[0]
                inputs, edge_index, size = only_subgraph.to_pyg(
                    batch.node_features["feat"]
                )
                out = conv(inputs, edge_index, size=size)
                if not final:
                    out = F.relu(out)
                # By design, our output nodes are contiguous.
                first, last = batch.seeds[0], batch.seeds[-1]
                embeddings[first : last + 1] = out.to(out_device)
            if not final:
                # Feed this layer's output to the next layer's feature fetch.
                features.update("node", None, "feat", embeddings)

        return embeddings
|
|
|
|
|
|
def create_dataloader(
    graph, features, itemset, batch_size, fanout, device, job
):
    """(HIGHLIGHT) Create a data loader for efficiently loading graph data.

    The pipeline is assembled in this order:

    - `ItemSampler` draws mini-batches from `itemset` (shuffled, with the
      last partial batch dropped, only when `job == "train"`).
    - `sample_uniform_negative` adds negative edges (training only).
    - `args.sample_mode` (`sample_neighbor` or `sample_layer_neighbor`)
      performs neighbor sampling on the graph. `fanout` gives the number of
      neighbors sampled per layer, e.g. [10, 10, 10]; for `job == "infer"`
      it is replaced by `[-1]`.
    - `exclude_seed_edges` is applied when training with
      `args.exclude_edges` set.
    - `fetch_feature` fetches the "feat" node feature for the sampled
      subgraph.
    - `copy_to` moves the data to `device` exactly once, as early as the
      storage locations allow: before sampling when the graph is not on the
      CPU, otherwise before feature fetching when the features are not on
      the CPU, otherwise after feature fetching.

    Note: `sample_neighbor` is the default for consistency with DGL's
    sampling API. GraphBolt also offers `sample_layer_neighbor`, which users
    are encouraged to try for potentially improved performance.

    Relies on the module-level `args` namespace for `graph_device`,
    `feature_device`, `neg_ratio`, `sample_mode`, `exclude_edges`,
    `overlap_graph_fetch`, `overlap_feature_fetch` and `num_workers`.
    """
    is_train = job == "train"
    graph_off_cpu = args.graph_device != "cpu"

    # Initialize an ItemSampler to sample mini-batches from the dataset.
    pipe = gb.ItemSampler(
        itemset,
        batch_size=batch_size,
        shuffle=is_train,
        drop_last=is_train,
    )

    # Tracks whether the single host-to-device copy has been inserted yet.
    copied = False
    if graph_off_cpu:
        pipe = pipe.copy_to(device=device)
        copied = True

    # Sample negative edges.
    if is_train:
        pipe = pipe.sample_uniform_negative(graph, args.neg_ratio)

    # Sample neighbors for each node in the mini-batch.
    neighbor_sampler = getattr(pipe, args.sample_mode)
    pipe = neighbor_sampler(
        graph,
        [-1] if job == "infer" else fanout,
        overlap_fetch=args.overlap_graph_fetch,
        asynchronous=graph_off_cpu,
    )

    if is_train and args.exclude_edges:
        pipe = pipe.exclude_seed_edges(
            include_reverse_edges=True,
            asynchronous=graph_off_cpu,
        )

    # Copy before feature fetching when the features live off the CPU.
    if not copied and args.feature_device != "cpu":
        pipe = pipe.copy_to(device=device)
        copied = True

    # Fetch node features for the sampled subgraph.
    pipe = pipe.fetch_feature(
        features,
        node_feature_keys=["feat"],
        overlap_fetch=args.overlap_feature_fetch,
    )

    # Everything stayed on the CPU so far: copy the finished mini-batch.
    if not copied:
        pipe = pipe.copy_to(device=device)

    # Create and return a DataLoader to handle data loading.
    return gb.DataLoader(pipe, num_workers=args.num_workers)
|
|
|
|
|
|
@torch.compile
def predictions_step(model, h_src, h_dst):
    """Score node pairs from their source and destination embeddings.

    The element-wise product of `h_src` and `h_dst` is passed through
    `model.predictor` and the trailing size-1 output dimension is removed.

    Returns one score per pair; for 2-D inputs of shape (N, H) the result
    has shape (N,), including when N == 1.
    """
    # squeeze(-1) rather than a bare squeeze(): the latter would also drop
    # the batch dimension for a single pair and return a 0-d tensor.
    return model.predictor(h_src * h_dst).squeeze(-1)
|
|
|
|
|
|
def compute_predictions(model, node_emb, seeds, device):
    """Compute the predictions for given source and destination nodes.

    `seeds` is split into its source and destination columns and processed
    in chunks, so that only one chunk of embeddings is moved to `device` at
    a time. Returns a 1-D tensor on `device` with one score per row of
    `seeds`.
    """
    num_pairs = seeds.shape[0]
    scores = torch.empty(num_pairs, device=device)
    src_ids, dst_ids = seeds.T
    # The constant number is 1001, because the negative ratio in the
    # `ogbl-citation2` dataset is 1000.
    chunk = args.eval_batch_size * 1001

    # Loop over node pairs in batches.
    for lo in trange(0, num_pairs, chunk, desc="Evaluate"):
        window = slice(lo, min(lo + chunk, num_pairs))

        # Fetch embeddings for current batch of source and destination nodes.
        src_emb = node_emb[src_ids[window]].to(device, non_blocking=True)
        dst_emb = node_emb[dst_ids[window]].to(device, non_blocking=True)

        # Compute prediction scores using the model.
        scores[window] = predictions_step(model, src_emb, dst_emb)
    return scores
|
|
|
|
|
|
@torch.no_grad()
def evaluate(model, graph, features, all_nodes_set, valid_set, test_set):
    """Evaluate the model on validation and test sets.

    Node embeddings for the whole graph are first computed layer by layer
    with full-neighborhood sampling (fanout `[-1]`), then every candidate
    pair of each split is scored and the MRR is computed.

    Note that `model.inference` overwrites the "feat" node feature in
    `features` with intermediate embeddings when the model has more than
    one layer.

    Returns:
        A two-element list `[validation_mrr, test_mrr]`.
    """
    model.eval()

    dataloader = create_dataloader(
        graph,
        features,
        all_nodes_set,
        args.eval_batch_size,
        [-1],
        args.device,
        job="infer",
    )

    # Compute node embeddings for the entire graph.
    node_emb = model.inference(graph, features, dataloader, args.feature_device)
    results = []

    # Loop over both validation and test sets.
    for split in [valid_set, test_set]:
        # Unpack the item set.
        # NOTE(review): the (seeds, labels, indexes) order is taken from how
        # the three items are used below - confirm against the dataset.
        # The seeds index `node_emb`, so they stay on its device.
        seeds = split._items[0].to(node_emb.device)
        # Scores are produced by `model.predictor`, which lives on
        # `args.device`. `node_emb` is kept on the feature storage device
        # (CPU for "cpu"/"pinned"), so scoring there would feed CPU tensors
        # to a CUDA model. Score on the model's device and keep the metric
        # inputs alongside the scores.
        labels = split._items[1].to(args.device)
        indexes = split._items[2].to(args.device)

        preds = compute_predictions(model, node_emb, seeds, args.device)
        # Compute MRR values for the current split.
        results.append(RetrievalMRR()(preds, labels, indexes))
    return results
|
|
|
|
|
|
@torch.compile
def train_step(minibatch, optimizer, model):
    """Run one optimization step on a mini-batch.

    Reads `node_features["feat"]`, `compacted_seeds`, `labels` and
    `sampled_subgraphs` from `minibatch`, scores every seed pair with
    `model.predictor`, and minimizes the binary cross-entropy between the
    logits and the labels.

    Returns:
        A tuple `(loss, num_samples)`: the detached loss tensor and the
        number of labels in the mini-batch.
    """
    node_features = minibatch.node_features["feat"]
    # Transposed so that row 0 holds source and row 1 destination indices.
    compacted_seeds = minibatch.compacted_seeds.T
    labels = minibatch.labels
    optimizer.zero_grad()
    y = model(minibatch.sampled_subgraphs, node_features)
    # squeeze(-1) keeps the batch dimension even for a single pair; a bare
    # squeeze() would yield a 0-d tensor whose shape no longer matches
    # `labels`, which the loss below rejects.
    logits = model.predictor(
        y[compacted_seeds[0]] * y[compacted_seeds[1]]
    ).squeeze(-1)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    loss.backward()
    optimizer.step()
    return loss.detach(), labels.size(0)
|
|
|
|
|
|
def train_helper(dataloader, model, optimizer, device):
    """Train for one pass over `dataloader`.

    Stops early after `args.early_stop` steps when that value is non-zero.

    Returns:
        A tuple `(train_loss, seconds)`: the sample-weighted mean loss as a
        one-element tensor on `device`, and the elapsed wall-clock time.
    """
    model.train()  # Set the model to training mode
    loss_sum = torch.zeros(1, device=device)  # Sum of per-sample losses
    seen = 0  # Number of samples processed so far
    started = time.time()
    # Steps are counted from 1 so that an `early_stop` of 0 never matches.
    for step, minibatch in tqdm(enumerate(dataloader, start=1), "Training"):
        batch_loss, batch_samples = train_step(minibatch, optimizer, model)
        loss_sum += batch_loss * batch_samples
        seen += batch_samples
        if step == args.early_stop:
            break
    mean_loss = loss_sum / seen
    elapsed = time.time() - started
    return mean_loss, elapsed
|
|
|
|
|
|
def train(dataloader, model, device):
    #####################################################################
    # (HIGHLIGHT) Train the model for `args.epochs` epochs.
    #
    # - Creates an Adam optimizer over the model parameters with learning
    #   rate `args.lr`.
    # - Each epoch iterates over the data loader; for each mini-batch a
    #   forward pass, loss computation and parameter update are performed
    #   (see `train_helper`).
    # - The average loss and the duration of every epoch are printed.
    #   Nothing is returned.
    #
    # Parameters:
    #   dataloader: DataLoader that provides mini-batches of graph data.
    #   model: The GraphSAGE model.
    #   device: The device (CPU/GPU) to run the training on.
    #####################################################################

    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.epochs):
        epoch_result = train_helper(dataloader, model, optimizer, device)
        train_loss, duration = epoch_result
        report = (
            f"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, "
            f"Time: {duration}s"
        )
        print(report)
|
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token such as "true" or "0" into a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in {"true", "t", "yes", "y", "1", "on"}:
        return True
    # The empty string is kept as False, matching what `bool("")` returned
    # when the option was declared with `type=bool`.
    if token in {"false", "f", "no", "n", "0", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value, got {value!r}"
    )


def parse_args():
    """Parse the command-line options of this example.

    Returns:
        The `argparse.Namespace` holding every option. `fanout` is returned
        as the raw comma-separated string and `mode` as the raw
        "graph-feature-device" string.
    """
    parser = argparse.ArgumentParser(
        description="Which dataset are you going to use?"
    )
    parser.add_argument(
        "--epochs", type=int, default=10, help="Number of training epochs."
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.003,
        help="Learning rate for optimization.",
    )
    parser.add_argument("--neg-ratio", type=int, default=1)
    parser.add_argument("--train-batch-size", type=int, default=512)
    parser.add_argument("--eval-batch-size", type=int, default=1024)
    parser.add_argument(
        "--batch-size", type=int, default=1024, help="Batch size for training."
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=0,
        help="Number of workers for data loading.",
    )
    parser.add_argument(
        "--early-stop",
        type=int,
        default=0,
        help="0 means no early stop, otherwise stop at the input-th step",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default="ogbl-citation2",
        choices=["ogbl-citation2"],
        help="The dataset we can use for link prediction. Currently"
        " only ogbl-citation2 dataset is supported.",
    )
    parser.add_argument(
        "--fanout",
        type=str,
        default="10,10,10",
        help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
        " identical with the number of layers in your model. Default: 10,10,10",
    )
    parser.add_argument(
        "--exclude-edges",
        # `type=bool` would turn every non-empty string, including "False",
        # into True; parse the text explicitly instead.
        type=_str2bool,
        default=True,
        help="Whether to exclude reverse edges during sampling. Default: True",
    )
    parser.add_argument(
        "--mode",
        default="pinned-pinned-cuda",
        choices=[
            "cpu-cpu-cpu",
            "cpu-cpu-cuda",
            "cpu-pinned-cuda",
            "pinned-pinned-cuda",
            "cuda-pinned-cuda",
            "cuda-cuda-cuda",
        ],
        help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
        " 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
    )
    parser.add_argument(
        "--gpu-cache-size",
        type=int,
        default=0,
        help="The capacity of the GPU cache in bytes.",
    )
    parser.add_argument(
        "--sample-mode",
        default="sample_neighbor",
        choices=["sample_neighbor", "sample_layer_neighbor"],
        help="The sampling function when doing layerwise sampling.",
    )
    parser.add_argument("--precision", type=str, default="high")
    return parser.parse_args()
|
|
|
|
|
|
def main():
    """Load the dataset, train the GraphSAGE model and report the MRR.

    Reads its configuration from the module-level `args` namespace and
    stores derived settings (`graph_device`, `feature_device`, `device`,
    the overlap flags and the parsed `fanout` list) back onto it.
    """
    torch.set_float32_matmul_precision(args.precision)
    if not torch.cuda.is_available():
        # Without CUDA only the all-CPU placement is usable.
        args.mode = "cpu-cpu-cpu"
    print(f"Training in {args.mode} mode.")
    graph_device, feature_device, train_device = args.mode.split("-")
    args.graph_device = graph_device
    args.feature_device = feature_device
    args.device = train_device
    args.overlap_feature_fetch = args.feature_device == "pinned"
    args.overlap_graph_fetch = args.graph_device == "pinned"

    # Load and preprocess dataset.
    print("Loading data...")
    dataset = gb.BuiltinDataset(args.dataset).load()

    # Move the dataset to the selected storage.
    if args.graph_device == "pinned":
        graph = dataset.graph.pin_memory_()
    else:
        graph = dataset.graph.to(args.graph_device)
    if args.feature_device == "pinned":
        features = dataset.feature.pin_memory_()
    else:
        features = dataset.feature.to(args.feature_device)

    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    test_set = dataset.tasks[0].test_set
    all_nodes_set = dataset.all_nodes_set
    args.fanout = [int(count) for count in args.fanout.split(",")]

    if args.gpu_cache_size > 0 and args.feature_device != "cuda":
        # Wrap the node feature in a GPU cache of the requested capacity.
        feat_key = ("node", None, "feat")
        features._features[feat_key] = gb.gpu_cached_feature(
            features._features[feat_key],
            args.gpu_cache_size,
        )

    train_dataloader = create_dataloader(
        graph=graph,
        features=features,
        itemset=train_set,
        batch_size=args.train_batch_size,
        fanout=args.fanout,
        device=args.device,
        job="train",
    )

    in_channels = features.size("node", None, "feat")[0]
    hidden_channels = 256
    # One convolution layer per fanout entry.
    model = GraphSAGE(in_channels, hidden_channels, len(args.fanout))
    model = model.to(args.device)
    assert len(args.fanout) == len(model.layers)

    train(train_dataloader, model, args.device)

    # Test the model.
    print("Testing...")
    valid_mrr, test_mrr = evaluate(
        model,
        graph,
        features,
        all_nodes_set,
        valid_set,
        test_set,
    )
    print(
        f"Validation MRR {valid_mrr.item():.4f}, Test MRR {test_mrr.item():.4f}"
    )
|
|
|
|
|
|
if __name__ == "__main__":
    # `args` is bound at module scope on purpose: the functions in this
    # file read it as a global instead of receiving it as a parameter.
    args = parse_args()
    main()
|