mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][PyG] Heterogenous example. (#7722)
This commit is contained in:
committed by
GitHub
parent
c45d299c1d
commit
1d378f8f83
548
examples/graphbolt/pyg/hetero/node_classification.py
Normal file
548
examples/graphbolt/pyg/hetero/node_classification.py
Normal file
@@ -0,0 +1,548 @@
|
||||
"""
|
||||
This script is a PyG counterpart of ``/examples/graphbolt/rgcn/hetero_rgcn.py``.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
|
||||
import dgl.graphbolt as gb
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch_geometric.nn import SimpleConv
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def accuracy(out, labels):
|
||||
assert out.ndim == 2
|
||||
assert out.size(0) == labels.size(0)
|
||||
assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)
|
||||
labels = labels.flatten()
|
||||
predictions = torch.argmax(out, 1)
|
||||
return (labels == predictions).sum(dtype=torch.float64) / labels.size(0)
|
||||
|
||||
|
||||
def create_dataloader(
|
||||
graph,
|
||||
features,
|
||||
itemset,
|
||||
batch_size,
|
||||
fanout,
|
||||
device,
|
||||
job,
|
||||
):
|
||||
"""Create a GraphBolt dataloader for training, validation or testing."""
|
||||
datapipe = gb.ItemSampler(
|
||||
itemset,
|
||||
batch_size=batch_size,
|
||||
shuffle=(job == "train"),
|
||||
drop_last=(job == "train"),
|
||||
)
|
||||
need_copy = True
|
||||
# Copy the data to the specified device.
|
||||
if args.graph_device != "cpu" and need_copy:
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
need_copy = False
|
||||
# Sample neighbors for each node in the mini-batch.
|
||||
datapipe = getattr(datapipe, args.sample_mode)(
|
||||
graph,
|
||||
fanout if job != "infer" else [-1],
|
||||
overlap_fetch=args.overlap_graph_fetch,
|
||||
num_gpu_cached_edges=args.num_gpu_cached_edges,
|
||||
gpu_cache_threshold=args.gpu_graph_caching_threshold,
|
||||
asynchronous=args.graph_device != "cpu",
|
||||
)
|
||||
# Copy the data to the specified device.
|
||||
if args.feature_device != "cpu" and need_copy:
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
need_copy = False
|
||||
|
||||
if args.dataset == "ogb-lsc-mag240m":
|
||||
node_feature_keys = {
|
||||
"paper": ["feat"],
|
||||
"author": ["feat"],
|
||||
"institution": ["feat"],
|
||||
}
|
||||
# Fetch node features for the sampled subgraph.
|
||||
datapipe = datapipe.fetch_feature(features, node_feature_keys)
|
||||
|
||||
# Copy the data to the specified device.
|
||||
if need_copy:
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
# Create and return a DataLoader to handle data loading.
|
||||
return gb.DataLoader(datapipe, num_workers=args.num_workers)
|
||||
|
||||
|
||||
def convert_to_pyg(h, subgraph):
|
||||
#####################################################################
|
||||
# (HIGHLIGHT) Convert given features to be consumed by a PyG layer.
|
||||
#
|
||||
# We convert the provided sampled edges in CSC format from GraphBolt and
|
||||
# convert to COO via using gb.expand_indptr.
|
||||
#####################################################################
|
||||
h_dst_dict = {}
|
||||
edge_index_dict = {}
|
||||
sizes_dict = {}
|
||||
for etype, sampled_csc in subgraph.sampled_csc.items():
|
||||
src = sampled_csc.indices
|
||||
dst = gb.expand_indptr(
|
||||
sampled_csc.indptr,
|
||||
dtype=src.dtype,
|
||||
output_size=src.size(0),
|
||||
)
|
||||
edge_index = torch.stack([src, dst], dim=0).long()
|
||||
dst_size = sampled_csc.indptr.size(0) - 1
|
||||
# h and h[:dst_size] correspond to source and destination features resp.
|
||||
src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
|
||||
h_dst_dict[dst_ntype] = h[dst_ntype][:dst_size]
|
||||
edge_index_dict[etype] = edge_index
|
||||
sizes_dict[etype] = (h[src_ntype].size(0), dst_size)
|
||||
|
||||
return (h, h_dst_dict), edge_index_dict, sizes_dict
|
||||
|
||||
|
||||
class RelGraphConvLayer(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
out_size,
|
||||
ntypes,
|
||||
etypes,
|
||||
activation,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.in_size = in_size
|
||||
self.out_size = out_size
|
||||
self.activation = activation
|
||||
|
||||
# Create a separate convolution layer for each relationship. PyG's
|
||||
# SimpleConv does not have any weights and only performs message passing
|
||||
# and aggregation.
|
||||
self.convs = nn.ModuleDict(
|
||||
{etype: SimpleConv(aggr="mean") for etype in etypes}
|
||||
)
|
||||
|
||||
# Create a separate Linear layer for each relationship. Each
|
||||
# relationship has its own weights which will be applied to the node
|
||||
# features before performing convolution.
|
||||
self.weight = nn.ModuleDict(
|
||||
{
|
||||
etype: nn.Linear(in_size, out_size, bias=False)
|
||||
for etype in etypes
|
||||
}
|
||||
)
|
||||
|
||||
# Create a separate Linear layer for each node type.
|
||||
# loop_weights are used to update the output embedding of each target node
|
||||
# based on its own features, thereby allowing the model to refine the node
|
||||
# representations. Note that this does not imply the existence of self-loop
|
||||
# edges in the graph. It is similar to residual connection.
|
||||
self.loop_weights = nn.ModuleDict(
|
||||
{ntype: nn.Linear(in_size, out_size, bias=True) for ntype in ntypes}
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
def forward(self, subgraph, x):
|
||||
# Create a dictionary of node features for the destination nodes in
|
||||
# the graph. We slice the node features according to the number of
|
||||
# destination nodes of each type. This is necessary because when
|
||||
# incorporating the effect of self-loop edges, we perform computations
|
||||
# only on the destination nodes' features. By doing so, we ensure the
|
||||
# feature dimensions match and prevent any misuse of incorrect node
|
||||
# features.
|
||||
(h, h_dst), edge_index, size = convert_to_pyg(x, subgraph)
|
||||
|
||||
h_out = {}
|
||||
for etype in edge_index:
|
||||
src_ntype, _, dst_ntype = gb.etype_str_to_tuple(etype)
|
||||
# h_dst is unused in SimpleConv.
|
||||
t = self.convs[etype](
|
||||
(h[src_ntype], h_dst[dst_ntype]),
|
||||
edge_index[etype],
|
||||
size=size[etype],
|
||||
)
|
||||
t = self.weight[etype](t)
|
||||
if dst_ntype in h_out:
|
||||
h_out[dst_ntype] += t
|
||||
else:
|
||||
h_out[dst_ntype] = t
|
||||
|
||||
def _apply(ntype, x):
|
||||
# Apply the `loop_weight` to the input node features, effectively
|
||||
# acting as a residual connection. This allows the model to refine
|
||||
# node embeddings based on its current features.
|
||||
x = x + self.loop_weights[ntype](h_dst[ntype])
|
||||
return self.dropout(self.activation(x))
|
||||
|
||||
# Apply the function defined above for each node type. This will update
|
||||
# the node features using the `loop_weights`, apply the activation
|
||||
# function and dropout.
|
||||
return {ntype: _apply(ntype, h) for ntype, h in h_out.items()}
|
||||
|
||||
|
||||
class EntityClassify(nn.Module):
|
||||
def __init__(self, graph, in_size, hidden_size, out_size, n_layers):
|
||||
super(EntityClassify, self).__init__()
|
||||
self.layers = nn.ModuleList()
|
||||
sizes = [in_size] + [hidden_size] * (n_layers - 1) + [out_size]
|
||||
for i in range(n_layers):
|
||||
self.layers.append(
|
||||
RelGraphConvLayer(
|
||||
sizes[i],
|
||||
sizes[i + 1],
|
||||
graph.node_type_to_id.keys(),
|
||||
graph.edge_type_to_id.keys(),
|
||||
activation=F.relu if i != n_layers - 1 else lambda x: x,
|
||||
dropout=0.5,
|
||||
)
|
||||
)
|
||||
|
||||
def forward(self, subgraphs, h):
|
||||
for layer, subgraph in zip(self.layers, subgraphs):
|
||||
h = layer(subgraph, h)
|
||||
return h
|
||||
|
||||
|
||||
@torch.compile
|
||||
def evaluate_step(minibatch, model):
|
||||
category = "paper"
|
||||
node_features = {
|
||||
ntype: feat.float()
|
||||
for (ntype, name), feat in minibatch.node_features.items()
|
||||
if name == "feat"
|
||||
}
|
||||
labels = minibatch.labels[category].long()
|
||||
out = model(minibatch.sampled_subgraphs, node_features)[category]
|
||||
num_correct = accuracy(out, labels) * labels.size(0)
|
||||
return num_correct, labels.size(0)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(
|
||||
model,
|
||||
dataloader,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
):
|
||||
model.eval()
|
||||
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
|
||||
total_samples = 0
|
||||
dataloader = tqdm(dataloader, desc="Evaluating")
|
||||
for step, minibatch in enumerate(dataloader):
|
||||
num_correct, num_samples = evaluate_step(minibatch, model)
|
||||
total_correct += num_correct
|
||||
total_samples += num_samples
|
||||
if step % 15 == 0:
|
||||
num_nodes = sum(id.size(0) for id in minibatch.node_ids().values())
|
||||
dataloader.set_postfix(
|
||||
{
|
||||
"num_nodes": num_nodes,
|
||||
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
|
||||
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
|
||||
}
|
||||
)
|
||||
|
||||
return total_correct / total_samples
|
||||
|
||||
|
||||
@torch.compile
|
||||
def train_step(minibatch, optimizer, model, loss_fn):
|
||||
category = "paper"
|
||||
node_features = {
|
||||
ntype: feat.float()
|
||||
for (ntype, name), feat in minibatch.node_features.items()
|
||||
if name == "feat"
|
||||
}
|
||||
labels = minibatch.labels[category].long()
|
||||
optimizer.zero_grad()
|
||||
out = model(minibatch.sampled_subgraphs, node_features)[category]
|
||||
loss = loss_fn(out, labels)
|
||||
# https://github.com/pytorch/pytorch/issues/133942
|
||||
# num_correct = accuracy(out, labels) * labels.size(0)
|
||||
num_correct = torch.zeros(1, dtype=torch.float64, device=out.device)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
return loss.detach(), num_correct, labels.size(0)
|
||||
|
||||
|
||||
def train_helper(
|
||||
dataloader,
|
||||
model,
|
||||
optimizer,
|
||||
loss_fn,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
):
|
||||
model.train()
|
||||
total_loss = torch.zeros(1, device=device)
|
||||
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
|
||||
total_samples = 0
|
||||
start = time.time()
|
||||
dataloader = tqdm(dataloader, "Training")
|
||||
for step, minibatch in enumerate(dataloader):
|
||||
loss, num_correct, num_samples = train_step(
|
||||
minibatch, optimizer, model, loss_fn
|
||||
)
|
||||
total_loss += loss * num_samples
|
||||
total_correct += num_correct
|
||||
total_samples += num_samples
|
||||
if step % 15 == 0:
|
||||
# log every 15 steps for performance.
|
||||
num_nodes = sum(id.size(0) for id in minibatch.node_ids().values())
|
||||
dataloader.set_postfix(
|
||||
{
|
||||
"num_nodes": num_nodes,
|
||||
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
|
||||
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
|
||||
}
|
||||
)
|
||||
loss = total_loss / total_samples
|
||||
acc = total_correct / total_samples
|
||||
end = time.time()
|
||||
return loss, acc, end - start
|
||||
|
||||
|
||||
def train(
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
model,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
loss_fn = nn.CrossEntropyLoss()
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_loss, train_acc, duration = train_helper(
|
||||
train_dataloader,
|
||||
model,
|
||||
optimizer,
|
||||
loss_fn,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
)
|
||||
val_acc = evaluate(
|
||||
model,
|
||||
valid_dataloader,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
)
|
||||
print(
|
||||
f"Epoch: {epoch:02d}, Loss: {train_loss.item():.4f}, "
|
||||
f"Approx. Train: {train_acc.item():.4f}, "
|
||||
f"Approx. Val: {val_acc.item():.4f}, "
|
||||
f"Time: {duration}s"
|
||||
)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="GraphBolt PyG R-SAGE")
|
||||
parser.add_argument(
|
||||
"--epochs", type=int, default=10, help="Number of training epochs."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr",
|
||||
type=float,
|
||||
default=0.001,
|
||||
help="Learning rate for optimization.",
|
||||
)
|
||||
parser.add_argument("--num-hidden", type=int, default=1024)
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=1024, help="Batch size for training."
|
||||
)
|
||||
parser.add_argument("--num_workers", type=int, default=0)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
default="ogb-lsc-mag240m",
|
||||
choices=["ogb-lsc-mag240m"],
|
||||
help="Dataset name. Possible values: ogb-lsc-mag240m",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fanout",
|
||||
type=str,
|
||||
default="25,10",
|
||||
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
|
||||
" identical with the number of layers in your model. Default: 25,10",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
default="pinned-pinned-cuda",
|
||||
choices=[
|
||||
"cpu-cpu-cpu",
|
||||
"cpu-cpu-cuda",
|
||||
"cpu-pinned-cuda",
|
||||
"pinned-pinned-cuda",
|
||||
"cuda-pinned-cuda",
|
||||
"cuda-cuda-cuda",
|
||||
],
|
||||
help="Graph storage - feature storage - Train device: 'cpu' for CPU and RAM,"
|
||||
" 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--sample-mode",
|
||||
default="sample_neighbor",
|
||||
choices=["sample_neighbor", "sample_layer_neighbor"],
|
||||
help="The sampling function when doing layerwise sampling.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu-feature-cache-policy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["s3-fifo", "sieve", "lru", "clock"],
|
||||
help="The cache policy for the CPU feature cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpu-cache-size",
|
||||
type=float,
|
||||
default=0,
|
||||
help="The capacity of the CPU feature cache in GiB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-cache-size",
|
||||
type=float,
|
||||
default=0,
|
||||
help="The capacity of the GPU feature cache in GiB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-gpu-cached-edges",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The number of edges to be cached from the graph on the GPU.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-graph-caching-threshold",
|
||||
type=int,
|
||||
default=1,
|
||||
help="The number of accesses after which a vertex neighborhood will be cached.",
|
||||
)
|
||||
parser.add_argument("--precision", type=str, default="high")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_float32_matmul_precision(args.precision)
|
||||
if not torch.cuda.is_available():
|
||||
args.mode = "cpu-cpu-cpu"
|
||||
print(f"Training in {args.mode} mode.")
|
||||
args.graph_device, args.feature_device, args.device = args.mode.split("-")
|
||||
args.overlap_feature_fetch = args.feature_device == "pinned"
|
||||
args.overlap_graph_fetch = args.graph_device == "pinned"
|
||||
|
||||
# Load dataset.
|
||||
dataset = gb.BuiltinDataset(args.dataset).load()
|
||||
print("Dataset loaded")
|
||||
|
||||
# Move the dataset to the selected storage.
|
||||
graph = (
|
||||
dataset.graph.pin_memory_()
|
||||
if args.graph_device == "pinned"
|
||||
else dataset.graph.to(args.graph_device)
|
||||
)
|
||||
features = (
|
||||
dataset.feature.pin_memory_()
|
||||
if args.feature_device == "pinned"
|
||||
else dataset.feature.to(args.feature_device)
|
||||
)
|
||||
|
||||
train_set = dataset.tasks[0].train_set
|
||||
valid_set = dataset.tasks[0].validation_set
|
||||
test_set = dataset.tasks[0].test_set
|
||||
args.fanout = list(map(int, args.fanout.split(",")))
|
||||
|
||||
num_classes = dataset.tasks[0].metadata["num_classes"]
|
||||
num_etypes = len(graph.num_edges)
|
||||
|
||||
feats_on_disk = {
|
||||
k: features[k]
|
||||
for k in features.keys()
|
||||
if k[2] == "feat" and isinstance(features[k], gb.DiskBasedFeature)
|
||||
}
|
||||
|
||||
if args.cpu_cache_size > 0 and len(feats_on_disk) > 0:
|
||||
cached_features = gb.cpu_cached_feature(
|
||||
feats_on_disk,
|
||||
int(args.cpu_cache_size * (2**30)),
|
||||
args.cpu_feature_cache_policy,
|
||||
args.feature_device == "pinned",
|
||||
)
|
||||
for k, cpu_cached_feature in cached_features.items():
|
||||
features[k] = cpu_cached_feature
|
||||
cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate
|
||||
else:
|
||||
cpu_cache_miss_rate_fn = lambda: 1
|
||||
|
||||
if args.gpu_cache_size > 0 and args.feature_device != "cuda":
|
||||
feats = {k: features[k] for k in features.keys() if k[2] == "feat"}
|
||||
cached_features = gb.gpu_cached_feature(
|
||||
feats,
|
||||
int(args.gpu_cache_size * (2**30)),
|
||||
)
|
||||
for k, gpu_cached_feature in cached_features.items():
|
||||
features[k] = gpu_cached_feature
|
||||
gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate
|
||||
else:
|
||||
gpu_cache_miss_rate_fn = lambda: 1
|
||||
|
||||
train_dataloader, valid_dataloader, test_dataloader = (
|
||||
create_dataloader(
|
||||
graph=graph,
|
||||
features=features,
|
||||
itemset=itemset,
|
||||
batch_size=args.batch_size,
|
||||
fanout=[
|
||||
torch.full((num_etypes,), fanout) for fanout in args.fanout
|
||||
],
|
||||
device=args.device,
|
||||
job=job,
|
||||
)
|
||||
for itemset, job in zip(
|
||||
[train_set, valid_set, test_set], ["train", "evaluate", "evaluate"]
|
||||
)
|
||||
)
|
||||
|
||||
feat_size = features.size("node", "paper", "feat")[0]
|
||||
hidden_channels = args.num_hidden
|
||||
|
||||
# Initialize the entity classification model.
|
||||
model = EntityClassify(
|
||||
graph, feat_size, hidden_channels, num_classes, 3
|
||||
).to(args.device)
|
||||
|
||||
print(
|
||||
"Number of model parameters: "
|
||||
f"{sum(p.numel() for p in model.parameters())}"
|
||||
)
|
||||
|
||||
train(
|
||||
train_dataloader,
|
||||
valid_dataloader,
|
||||
model,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
args.device,
|
||||
)
|
||||
|
||||
# Labels are currently unavailable for mag240M so the test acc will be 0.
|
||||
print("Testing...")
|
||||
test_acc = evaluate(
|
||||
model,
|
||||
test_dataloader,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
args.device,
|
||||
)
|
||||
print(f"Test accuracy {test_acc.item():.4f}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main()
|
||||
Reference in New Issue
Block a user