[GraphBolt] Refine examples. (#7733)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-22 18:22:10 -04:00
committed by GitHub
parent 1d378f8f83
commit b3eacd22d7
3 changed files with 10 additions and 10 deletions

View File

@@ -376,13 +376,13 @@ def parse_args():
"--cpu-cache-size-in-gigabytes",
type=float,
default=0,
help="The capacity of the CPU cache, the number of features to store.",
help="The capacity of the CPU cache in GiB.",
)
parser.add_argument(
"--gpu-cache-size-in-gigabytes",
type=float,
default=0,
help="The capacity of the GPU cache, the number of features to store.",
help="The capacity of the GPU cache in GiB.",
)
parser.add_argument("--early-stopping-patience", type=int, default=25)
parser.add_argument(

View File

@@ -333,13 +333,13 @@ def evaluate(
model.eval()
total_correct = torch.zeros(1, dtype=torch.float64, device=device)
total_samples = 0
val_dataloader_tqdm = tqdm(dataloader, "Evaluating")
for step, minibatch in enumerate(val_dataloader_tqdm):
dataloader = tqdm(dataloader, "Evaluating")
for step, minibatch in enumerate(dataloader):
num_correct, num_samples = evaluate_step(minibatch, model, eval_fn)
total_correct += num_correct
total_samples += num_samples
if step % 25 == 0:
val_dataloader_tqdm.set_postfix(
dataloader.set_postfix(
{
"num_nodes": minibatch.node_ids().size(0),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),

View File

@@ -232,7 +232,7 @@ def train_step(minibatch, optimizer, model, loss_fn):
return loss.detach(), num_correct, labels.size(0)
def train_helper(dataloader, model, optimizer, loss_fn, num_classes, device):
def train_helper(dataloader, model, optimizer, loss_fn, device):
model.train() # Set the model to training mode
total_loss = torch.zeros(1, device=device) # Accumulator for the total loss
# Accumulator for the total number of correct predictions
@@ -254,7 +254,7 @@ def train_helper(dataloader, model, optimizer, loss_fn, num_classes, device):
return train_loss, train_acc, end - start
def train(train_dataloader, valid_dataloader, num_classes, model, device):
def train(train_dataloader, valid_dataloader, model, device):
#####################################################################
# (HIGHLIGHT) Train the model for one epoch.
#
@@ -276,7 +276,7 @@ def train(train_dataloader, valid_dataloader, num_classes, model, device):
for epoch in range(args.epochs):
train_loss, train_acc, duration = train_helper(
train_dataloader, model, optimizer, loss_fn, num_classes, device
train_dataloader, model, optimizer, loss_fn, device
)
val_acc = evaluate(model, valid_dataloader, device)
print(
@@ -363,7 +363,7 @@ def parse_args():
type=str,
default="10,10,10",
help="Fan-out of neighbor sampling. It is IMPORTANT to keep len(fanout)"
" identical with the number of layers in your model. Default: 5,10,15",
" identical with the number of layers in your model. Default: 10,10,10",
)
parser.add_argument(
"--mode",
@@ -466,7 +466,7 @@ def main():
).to(args.device)
assert len(args.fanout) == len(model.layers)
train(train_dataloader, valid_dataloader, num_classes, model, args.device)
train(train_dataloader, valid_dataloader, model, args.device)
# Test the model.
print("Testing...")