mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Refine examples. (#7733)
This commit is contained in:
committed by
GitHub
parent
1d378f8f83
commit
b3eacd22d7
@@ -376,13 +376,13 @@ def parse_args():
|
||||
"--cpu-cache-size-in-gigabytes",
|
||||
type=float,
|
||||
default=0,
|
||||
help="The capacity of the CPU cache, the number of features to store.",
|
||||
help="The capacity of the CPU cache in GiB.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gpu-cache-size-in-gigabytes",
|
||||
type=float,
|
||||
default=0,
|
||||
help="The capacity of the GPU cache, the number of features to store.",
|
||||
help="The capacity of the GPU cache in GiB.",
|
||||
)
|
||||
parser.add_argument("--early-stopping-patience", type=int, default=25)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -333,13 +333,13 @@ def evaluate(
|
||||
model.eval()
|
||||
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
|
||||
total_samples = 0
|
||||
val_dataloader_tqdm = tqdm(dataloader, "Evaluating")
|
||||
for step, minibatch in enumerate(val_dataloader_tqdm):
|
||||
dataloader = tqdm(dataloader, "Evaluating")
|
||||
for step, minibatch in enumerate(dataloader):
|
||||
num_correct, num_samples = evaluate_step(minibatch, model, eval_fn)
|
||||
total_correct += num_correct
|
||||
total_samples += num_samples
|
||||
if step % 25 == 0:
|
||||
val_dataloader_tqdm.set_postfix(
|
||||
dataloader.set_postfix(
|
||||
{
|
||||
"num_nodes": minibatch.node_ids().size(0),
|
||||
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
|
||||
|
||||
@@ -232,7 +232,7 @@ def train_step(minibatch, optimizer, model, loss_fn):
|
||||
return loss.detach(), num_correct, labels.size(0)
|
||||
|
||||
|
||||
def train_helper(dataloader, model, optimizer, loss_fn, num_classes, device):
|
||||
def train_helper(dataloader, model, optimizer, loss_fn, device):
|
||||
model.train() # Set the model to training mode
|
||||
total_loss = torch.zeros(1, device=device) # Accumulator for the total loss
|
||||
# Accumulator for the total number of correct predictions
|
||||
@@ -254,7 +254,7 @@ def train_helper(dataloader, model, optimizer, loss_fn, num_classes, device):
|
||||
return train_loss, train_acc, end - start
|
||||
|
||||
|
||||
def train(train_dataloader, valid_dataloader, num_classes, model, device):
|
||||
def train(train_dataloader, valid_dataloader, model, device):
|
||||
#####################################################################
|
||||
# (HIGHLIGHT) Train the model for one epoch.
|
||||
#
|
||||
@@ -276,7 +276,7 @@ def train(train_dataloader, valid_dataloader, num_classes, model, device):
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
train_loss, train_acc, duration = train_helper(
|
||||
train_dataloader, model, optimizer, loss_fn, num_classes, device
|
||||
train_dataloader, model, optimizer, loss_fn, device
|
||||
)
|
||||
val_acc = evaluate(model, valid_dataloader, device)
|
||||
print(
|
||||
@@ -363,7 +363,7 @@ def parse_args():
|
||||
type=str,
|
||||
default="10,10,10",
|
||||
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
|
||||
" identical with the number of layers in your model. Default: 5,10,15",
|
||||
" identical with the number of layers in your model. Default: 10,10,10",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mode",
|
||||
@@ -466,7 +466,7 @@ def main():
|
||||
).to(args.device)
|
||||
assert len(args.fanout) == len(model.layers)
|
||||
|
||||
train(train_dataloader, valid_dataloader, num_classes, model, args.device)
|
||||
train(train_dataloader, valid_dataloader, model, args.device)
|
||||
|
||||
# Test the model.
|
||||
print("Testing...")
|
||||
|
||||
Reference in New Issue
Block a user