Files
dgl/examples/graphbolt/disk_based_feature/node_classification.py
Bowen Yao f9097ee438 [Graphbolt][Dataset] Add igb-hom dataset (#7781)
Co-authored-by: Muhammed Fatih BALIN <m.f.balin@gmail.com>
2024-09-06 01:04:53 -04:00

541 lines
18 KiB
Python

"""
This example references examples/graphbolt/pyg/labor/node_classification.py
"""
import argparse
import time
from copy import deepcopy
import dgl.graphbolt as gb
import dgl.nn as dglnn
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
def accuracy(out, labels):
    """Return the fraction of rows of ``out`` whose argmax matches ``labels``.

    Args:
        out: 2-D tensor of per-class scores, one row per sample.
        labels: Class indices with shape ``(N,)`` or ``(N, 1)``, where ``N``
            equals the number of rows of ``out``.

    Returns:
        A float64 scalar tensor: number of correct predictions divided by
        ``N``.
    """
    assert out.ndim == 2
    assert out.size(0) == labels.size(0)
    assert labels.ndim == 1 or (labels.ndim == 2 and labels.size(1) == 1)
    flat_labels = labels.flatten()
    hits = out.argmax(dim=1) == flat_labels
    num_hits = hits.sum(dtype=torch.float64)
    return num_hits / flat_labels.size(0)
class SAGE(nn.Module):
    """Stack of mean-aggregator GraphSAGE (``dglnn.SAGEConv``) layers.

    Layer widths are ``in_size -> hidden_size -> ... -> hidden_size ->
    out_size`` with exactly ``num_layers`` layers; a single layer maps
    ``in_size`` directly to ``out_size``. ReLU followed by dropout is applied
    after every layer except the last.

    Args:
        in_size: Input node feature dimension.
        hidden_size: Width of every intermediate layer.
        out_size: Output dimension (e.g. the number of classes).
        num_layers: Number of ``SAGEConv`` layers; must be at least 1. It
            should equal the number of sampled blocks fed to ``forward``.
        dropout: Dropout probability.

    Raises:
        ValueError: If ``num_layers`` is smaller than 1.
    """

    def __init__(self, in_size, hidden_size, out_size, num_layers, dropout):
        super().__init__()
        if num_layers < 1:
            raise ValueError(
                f"num_layers must be at least 1, got {num_layers}."
            )
        self.layers = nn.ModuleList()
        # Consecutive (input, output) widths, one pair per layer. Building
        # exactly `num_layers` layers keeps the model depth equal to the
        # number of sampled blocks, including the single-layer case.
        widths = [in_size] + [hidden_size] * (num_layers - 1) + [out_size]
        for layer_in, layer_out in zip(widths[:-1], widths[1:]):
            self.layers.append(dglnn.SAGEConv(layer_in, layer_out, "mean"))
        self.dropout = nn.Dropout(dropout)
        self.hidden_size = hidden_size
        self.out_size = out_size
        # Set the dtype for the layers manually.
        self.set_layer_dtype(torch.float32)

    def set_layer_dtype(self, _dtype):
        """Cast every parameter of every layer to ``_dtype`` in place."""
        for layer in self.layers:
            for param in layer.parameters():
                param.data = param.data.to(_dtype)

    def forward(self, blocks, x):
        """Apply the layers to ``blocks`` in order, one block per layer.

        Args:
            blocks: Sampled message-flow blocks, consumed in order.
            x: Input node features for the first block.

        Returns:
            Output of the last layer (no activation or dropout applied).
        """
        hidden_x = x
        for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
            hidden_x = layer(block, hidden_x)
            is_last_layer = layer_idx == len(self.layers) - 1
            if not is_last_layer:
                hidden_x = F.relu(hidden_x)
                hidden_x = self.dropout(hidden_x)
        return hidden_x

    @torch.no_grad()
    def inference(self, graph, features, dataloader, storage_device):
        """Conduct layer-wise inference to get all the node embeddings.

        Each layer is applied to every node before moving to the next layer.
        Between layers the intermediate embeddings are written into
        ``features`` under ``("node", None, "feat")``, which overwrites the
        original input features. NOTE(review): this assumes ``dataloader``
        fetches ``"feat"`` from this same ``features`` store, so the next pass
        reads the updated values.

        Args:
            graph: Graph whose ``total_num_nodes`` sizes the output buffer.
            features: Feature store updated in place between layers.
            dataloader: Yields minibatches with a single block
                (``len(blocks) == 1``) and contiguous ``seeds``.
            storage_device: ``"pinned"`` to stage results in pinned host
                memory, otherwise the device the result buffer lives on.

        Returns:
            Tensor of shape ``(graph.total_num_nodes, out_size)`` holding the
            last layer's outputs.
        """
        pin_memory = storage_device == "pinned"
        buffer_device = torch.device("cpu" if pin_memory else storage_device)
        for layer_idx, layer in enumerate(self.layers):
            is_last_layer = layer_idx == len(self.layers) - 1
            y = torch.empty(
                graph.total_num_nodes,
                self.out_size if is_last_layer else self.hidden_size,
                dtype=torch.float32,
                device=buffer_device,
                pin_memory=pin_memory,
            )
            for data in tqdm(dataloader):
                # len(blocks) = 1
                hidden_x = layer(data.blocks[0], data.node_features["feat"])
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
                    hidden_x = self.dropout(hidden_x)
                # By design, our output nodes are contiguous.
                y[data.seeds[0] : data.seeds[-1] + 1] = hidden_x.to(
                    buffer_device
                )
            if not is_last_layer:
                features.update("node", None, "feat", y)
        return y
def create_dataloader(
    graph, features, itemset, batch_size, fanout, device, job
):
    """Build a GraphBolt pipeline (seeds -> neighbors -> features) and loader.

    Reads the module-level ``args`` for the storage devices, the sampler
    choice, the overlap flags and the worker count. ``copy_to(device)`` is
    inserted at a point that depends on where the graph and features live.

    Args:
        graph: Graph to sample neighbors from.
        features: Feature store providing the node ``"feat"`` feature.
        itemset: Seed items to iterate over.
        batch_size: Number of seeds per minibatch.
        fanout: Per-layer fan-out; replaced by ``[-1]`` when ``job`` is
            ``"infer"``.
        device: Device minibatches are copied to.
        job: ``"train"`` shuffles and drops the last partial batch,
            ``"infer"`` uses fan-out ``[-1]``; other values do neither.

    Returns:
        A ``gb.DataLoader`` over the assembled pipeline.
    """
    is_training = job == "train"
    pipe = gb.ItemSampler(
        itemset,
        batch_size=batch_size,
        shuffle=is_training,
        drop_last=is_training,
    )
    # When the graph is not CPU-resident, move the seeds before sampling.
    if args.graph_device != "cpu":
        pipe = pipe.copy_to(device=device)

    sampler_kwargs = {}
    if args.sample_mode == "sample_layer_neighbor":
        # Layer dependency makes it so that the sampled neighborhoods across
        # layers become correlated, reducing the total number of sampled
        # unique nodes in a minibatch, thus reducing the amount of feature
        # data requested.
        sampler_kwargs["layer_dependency"] = args.layer_dependency
        # Batch dependency makes it so that the sampled neighborhoods across
        # minibatches become correlated, reducing the total number of sampled
        # unique nodes across minibatches, thus increasing temporal locality
        # and reducing cache miss rates.
        sampler_kwargs["batch_dependency"] = args.batch_dependency
    sample_fn = getattr(pipe, args.sample_mode)
    layer_fanout = [-1] if job == "infer" else fanout
    pipe = sample_fn(
        graph,
        layer_fanout,
        overlap_fetch=args.overlap_graph_fetch,
        **sampler_kwargs,
    )

    features_on_cpu = args.feature_device == "cpu"
    # Non-CPU features: move the sampled minibatch before fetching features.
    if not features_on_cpu:
        pipe = pipe.copy_to(device=device)
    pipe = pipe.fetch_feature(
        features,
        node_feature_keys=["feat"],
        overlap_fetch=args.overlap_feature_fetch,
    )
    # CPU features: fetch on the host, then move the whole minibatch.
    if features_on_cpu:
        pipe = pipe.copy_to(device=device)
    return gb.DataLoader(pipe, num_workers=args.num_workers)
def train_step(minibatch, optimizer, model, loss_fn):
    """Run a single optimizer step on ``minibatch``.

    Returns:
        ``(loss, num_correct, num_samples)``: the loss detached from the
        autograd graph, the number of correct predictions (float64 tensor)
        and the number of labelled samples in the batch.
    """
    batch_labels = minibatch.labels
    batch_size = batch_labels.size(0)
    optimizer.zero_grad()
    logits = model(minibatch.blocks, minibatch.node_features["feat"])
    batch_loss = loss_fn(logits, batch_labels)
    correct = accuracy(logits, batch_labels) * batch_size
    batch_loss.backward()
    optimizer.step()
    return batch_loss.detach(), correct, batch_size
def train_helper(
    dataloader,
    model,
    optimizer,
    loss_fn,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Train ``model`` for one pass over ``dataloader``.

    Returns:
        ``(mean_loss, accuracy, seconds)``: the mean per-batch loss (tensor),
        the training accuracy over all samples (float64 tensor) and the
        elapsed wall-clock time of the epoch.
    """
    model.train()
    loss_sum = torch.zeros(1, device=device)
    correct_sum = torch.zeros(1, dtype=torch.float64, device=device)
    sample_count = 0
    batch_count = 0
    t_begin = time.time()
    progress = tqdm(dataloader, "Training")
    for step, minibatch in enumerate(progress):
        batch_loss, batch_correct, batch_samples = train_step(
            minibatch, optimizer, model, loss_fn
        )
        loss_sum += batch_loss
        correct_sum += batch_correct
        sample_count += batch_samples
        batch_count += 1
        # Refresh the progress-bar statistics only every 25 steps for
        # performance.
        if step % 25 == 0:
            progress.set_postfix(
                {
                    "num_nodes": minibatch.node_ids().size(0),
                    "gpu_cache_miss": gpu_cache_miss_rate_fn(),
                    "cpu_cache_miss": cpu_cache_miss_rate_fn(),
                }
            )
    mean_loss = loss_sum / batch_count
    epoch_acc = correct_sum / sample_count
    t_end = time.time()
    return mean_loss, epoch_acc, t_end - t_begin
def train(
    train_dataloader,
    valid_dataloader,
    model,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Train ``model`` with Adam, keeping the best validation checkpoint.

    Reads ``args.lr``, ``args.epochs`` and ``args.early_stopping_patience``
    from the module-level ``args``. Training stops early once more than
    ``early_stopping_patience`` epochs have passed since the best epoch.

    Returns:
        A deep copy of the ``state_dict`` with the highest validation
        accuracy. If no epoch ran (e.g. ``--epochs 0``), a copy of the
        current weights is returned. The result is therefore never ``None``
        and can always be passed to ``load_state_dict``.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    loss_fn = nn.CrossEntropyLoss()
    best_model = None
    best_model_acc = 0
    best_model_epoch = -1
    for epoch in range(args.epochs):
        train_loss, train_acc, duration = train_helper(
            train_dataloader,
            model,
            optimizer,
            loss_fn,
            gpu_cache_miss_rate_fn,
            cpu_cache_miss_rate_fn,
            device,
        )
        val_acc = evaluate(
            model,
            valid_dataloader,
            gpu_cache_miss_rate_fn,
            cpu_cache_miss_rate_fn,
            device,
        )
        # The first epoch always seeds the checkpoint, so a best model exists
        # even when validation accuracy stays at 0 or is NaN (empty split).
        if best_model is None or val_acc > best_model_acc:
            best_model_acc = val_acc
            best_model = deepcopy(model.state_dict())
            best_model_epoch = epoch
        print(
            f"Epoch {epoch:02d}, Loss: {train_loss.item():.4f}, "
            f"Approx. Train: {train_acc.item():.4f}, "
            f"Approx. Val: {val_acc.item():.4f}, "
            f"Time: {duration}s"
        )
        if best_model_epoch + args.early_stopping_patience < epoch:
            break
    if best_model is None:
        # No epoch ran; fall back to the current weights.
        best_model = deepcopy(model.state_dict())
    return best_model
@torch.no_grad()
def layerwise_infer(
    args,
    graph,
    features,
    itemsets,
    all_nodes_set,
    model,
):
    """Run full-neighborhood layer-wise inference and score each split.

    Builds an ``"infer"`` dataloader over ``all_nodes_set`` and predicts
    every node with ``model.inference``. Note that ``SAGE.inference`` writes
    intermediate layer outputs back into ``features``.

    Args:
        args: Parsed options; ``batch_size``, ``device`` and
            ``feature_device`` are read.
        itemsets: Mapping from split name to an itemset whose ``[:]`` yields
            ``(node_ids, labels)``.

    Returns:
        dict mapping each split name to its accuracy as a Python float.
    """
    model.eval()
    infer_loader = create_dataloader(
        graph=graph,
        features=features,
        itemset=all_nodes_set,
        batch_size=args.batch_size,
        fanout=[-1],
        device=args.device,
        job="infer",
    )
    predictions = model.inference(
        graph, features, infer_loader, args.feature_device
    )
    target_device = predictions.device

    def _split_accuracy(split_itemset):
        # Score one split against the full prediction matrix.
        node_ids, split_labels = split_itemset[:]
        split_acc = accuracy(
            predictions[node_ids.to(target_device)],
            split_labels.to(target_device),
        )
        return split_acc.item()

    return {name: _split_accuracy(split) for name, split in itemsets.items()}
def evaluate_step(minibatch, model):
    """Score one minibatch.

    Returns:
        ``(num_correct, num_samples)`` where ``num_correct`` is a float64
        tensor and ``num_samples`` is the number of labelled seeds.
    """
    feats = minibatch.node_features["feat"]
    batch_labels = minibatch.labels
    logits = model(minibatch.blocks, feats)
    batch_size = batch_labels.size(0)
    return accuracy(logits, batch_labels) * batch_size, batch_size
@torch.no_grad()
def evaluate(
    model,
    dataloader,
    gpu_cache_miss_rate_fn,
    cpu_cache_miss_rate_fn,
    device,
):
    """Compute the accuracy of ``model`` over every minibatch in ``dataloader``.

    Returns:
        A float64 tensor of shape ``(1,)``: correct predictions divided by the
        number of evaluated samples.
    """
    model.eval()
    correct_sum = torch.zeros(1, dtype=torch.float64, device=device)
    sample_count = 0
    progress = tqdm(dataloader, "Evaluating")
    for step, minibatch in enumerate(progress):
        batch_correct, batch_samples = evaluate_step(minibatch, model)
        correct_sum += batch_correct
        sample_count += batch_samples
        # Refresh the progress-bar statistics only every 25 steps.
        if step % 25 != 0:
            continue
        progress.set_postfix(
            {
                "num_nodes": minibatch.node_ids().size(0),
                "gpu_cache_miss": gpu_cache_miss_rate_fn(),
                "cpu_cache_miss": cpu_cache_miss_rate_fn(),
            }
        )
    return correct_sum / sample_count
def parse_args():
    """Define this example's command-line options and parse ``sys.argv``.

    Returns:
        An ``argparse.Namespace``. ``fanout`` and ``mode`` are returned as
        raw strings; ``main`` splits them.
    """
    dataset_names = [
        "ogbn-arxiv",
        "ogbn-products",
        "ogbn-papers100M",
        "igb-hom-tiny",
        "igb-hom-small",
        "igb-hom-medium",
        "igb-hom-large",
        "igb-hom",
    ]
    storage_modes = [
        "cpu-cpu-cpu",
        "cpu-cpu-cuda",
        "cpu-pinned-cuda",
        "pinned-pinned-cuda",
        "cuda-pinned-cuda",
        "cuda-cuda-cuda",
    ]
    cache_policies = ["s3-fifo", "sieve", "lru", "clock"]
    sample_modes = ["sample_neighbor", "sample_layer_neighbor"]

    parser = argparse.ArgumentParser(
        description="Which dataset are you going to use?"
    )
    add = parser.add_argument

    # Optimization and model hyper-parameters.
    add(
        "--epochs", type=int, default=9999999, help="Number of training epochs."
    )
    add(
        "--lr",
        type=float,
        default=0.001,
        help="Learning rate for optimization.",
    )
    add("--num-hidden", type=int, default=256)
    add("--dropout", type=float, default=0.2)
    add("--batch-size", type=int, default=1024, help="Batch size for training.")
    add(
        "--num-workers",
        type=int,
        default=0,
        help="Number of workers for data loading.",
    )

    # Dataset and sampling setup.
    add("--dataset", type=str, default="ogbn-products", choices=dataset_names)
    add("--root", type=str, default="datasets")
    add(
        "--fanout",
        type=str,
        default="10,10,10",
        help="Fan-out of neighbor sampling. len(fanout) determines the number of"
        " GNN layers in your model. Default: 10,10,10",
    )
    add(
        "--mode",
        default="pinned-pinned-cuda",
        choices=storage_modes,
        help="Graph storage - feature storage - Train device: 'cpu' for CPU and"
        " RAM, 'pinned' for pinned memory in RAM, 'cuda' for GPU and GPU memory.",
    )
    add("--layer-dependency", action="store_true")
    add("--batch-dependency", type=int, default=1)

    # Feature caches.
    add(
        "--cpu-feature-cache-policy",
        type=str,
        default=None,
        choices=cache_policies,
        help="The cache policy for the CPU feature cache.",
    )
    add(
        "--cpu-cache-size-in-gigabytes",
        type=float,
        default=0,
        help="The capacity of the CPU cache in GiB.",
    )
    add(
        "--gpu-cache-size-in-gigabytes",
        type=float,
        default=0,
        help="The capacity of the GPU cache in GiB.",
    )

    # Miscellaneous.
    add("--early-stopping-patience", type=int, default=25)
    add(
        "--sample-mode",
        default="sample_neighbor",
        choices=sample_modes,
        help="The sampling function when doing layerwise sampling.",
    )
    add("--precision", type=str, default="high")
    add("--enable-inference", action="store_true")

    return parser.parse_args()
def main():
    """Run the disk-based-feature node classification example end to end.

    Reads the module-level ``args`` produced by ``parse_args`` and adds the
    derived fields ``graph_device``, ``feature_device``, ``device``,
    ``overlap_feature_fetch`` and ``overlap_graph_fetch``. ``args.fanout`` is
    converted in place from a comma-separated string to a list of ints.
    """
    torch.set_float32_matmul_precision(args.precision)
    if not torch.cuda.is_available():
        args.mode = "cpu-cpu-cpu"
    print(f"Training in {args.mode} mode.")
    args.graph_device, args.feature_device, args.device = args.mode.split("-")
    # Overlapped fetching is enabled only for data kept in pinned memory.
    args.overlap_feature_fetch = args.feature_device == "pinned"
    args.overlap_graph_fetch = args.graph_device == "pinned"

    # Load and preprocess the on-disk dataset. When a CPU cache is requested,
    # set `in_memory` to False for the node "feat" entry of the dataset's
    # loaded YAML metadata (the YAML file itself is not modified). This makes
    # sure the feature is loaded as a DiskBasedFeature.
    print("Loading data...")
    feat_key = ("node", None, "feat")
    disk_based_feature_keys = (
        {feat_key} if args.cpu_cache_size_in_gigabytes > 0 else set()
    )
    dataset = gb.BuiltinDataset(args.dataset, root=args.root)
    for feature in dataset.yaml_data["feature_data"]:
        feature_key = (feature["domain"], feature["type"], feature["name"])
        if feature_key in disk_based_feature_keys:
            feature["in_memory"] = False
    dataset = dataset.load()

    # Move the dataset to the selected storage.
    graph = (
        dataset.graph.pin_memory_()
        if args.graph_device == "pinned"
        else dataset.graph.to(args.graph_device)
    )
    features = (
        dataset.feature.pin_memory_()
        if args.feature_device == "pinned"
        else dataset.feature.to(args.feature_device)
    )
    train_set = dataset.tasks[0].train_set
    valid_set = dataset.tasks[0].validation_set
    test_set = dataset.tasks[0].test_set
    all_nodes_set = dataset.all_nodes_set
    args.fanout = list(map(int, args.fanout.split(",")))
    num_classes = dataset.tasks[0].metadata["num_classes"]

    # If the CPU cache size is greater than 0, wrap the DiskBasedFeature in a
    # CPUCachedFeature managed by the chosen cache replacement policy. This
    # reduces the amount of data read from disk for this feature.
    # Note: It is advised to set the CPU cache size to be at least 4 times the
    # number of sampled nodes in a mini-batch, otherwise the feature fetcher
    # might get into a deadlock, causing a hang.
    if args.cpu_cache_size_in_gigabytes > 0 and isinstance(
        features[feat_key], gb.DiskBasedFeature
    ):
        features[feat_key] = gb.cpu_cached_feature(
            features[feat_key],
            int(args.cpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),
            args.cpu_feature_cache_policy,
            args.feature_device == "pinned",
        )
        cpu_cached_feature = features[feat_key]
        cpu_cache_miss_rate_fn = lambda: cpu_cached_feature.miss_rate
    else:
        cpu_cache_miss_rate_fn = lambda: 1

    # If the GPU cache size is greater than 0, wrap the underlying feature in
    # a GPUCachedFeature to reduce host-to-device copies. This only applies
    # when training on CUDA and the features are not already on the GPU; a
    # CPU training run (including the automatic cpu-cpu-cpu fallback) must
    # not create a GPU cache.
    wants_gpu_cache = args.gpu_cache_size_in_gigabytes > 0
    if wants_gpu_cache and args.device != "cuda":
        print(
            "Ignoring --gpu-cache-size-in-gigabytes: the training device is "
            f"{args.device}, not cuda."
        )
    if (
        wants_gpu_cache
        and args.device == "cuda"
        and args.feature_device != "cuda"
    ):
        features[feat_key] = gb.gpu_cached_feature(
            features[feat_key],
            int(args.gpu_cache_size_in_gigabytes * 1024 * 1024 * 1024),
        )
        gpu_cached_feature = features[feat_key]
        gpu_cache_miss_rate_fn = lambda: gpu_cached_feature.miss_rate
    else:
        gpu_cache_miss_rate_fn = lambda: 1

    train_dataloader = create_dataloader(
        graph=graph,
        features=features,
        itemset=train_set,
        batch_size=args.batch_size,
        fanout=args.fanout,
        device=args.device,
        job="train",
    )
    valid_dataloader = create_dataloader(
        graph=graph,
        features=features,
        itemset=valid_set,
        batch_size=args.batch_size,
        fanout=args.fanout,
        device=args.device,
        job="evaluate",
    )

    in_channels = features.size("node", None, "feat")[0]
    model = SAGE(
        in_channels,
        args.num_hidden,
        num_classes,
        len(args.fanout),
        args.dropout,
    ).to(args.device)
    assert len(args.fanout) == len(model.layers)

    best_model = train(
        train_dataloader,
        valid_dataloader,
        model,
        gpu_cache_miss_rate_fn,
        cpu_cache_miss_rate_fn,
        args.device,
    )
    # Guard against a missing checkpoint (e.g. no epoch ran) instead of
    # crashing inside load_state_dict.
    if best_model is not None:
        model.load_state_dict(best_model)

    if args.enable_inference:
        # Test the model.
        print("Testing...")
        itemsets = {"train": train_set, "val": valid_set, "test": test_set}
        final_acc = layerwise_infer(
            args,
            graph,
            features,
            itemsets,
            all_nodes_set,
            model,
        )
        print("Final accuracy values:")
        print(final_acc)
if __name__ == "__main__":
    # Bind the parsed options to the module-level name `args`. main(),
    # train() and create_dataloader() read this global directly instead of
    # receiving it as a parameter, so it must be set before main() runs.
    args = parse_args()
    print(args)
    main()