mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt][Example] Slightly improve examples. (#7563)
This commit is contained in:
committed by
GitHub
parent
afcf65c0a2
commit
d11cb874c3
@@ -122,7 +122,11 @@ def create_dataloader(
|
||||
if args.feature_device != "cpu":
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
# Fetch node features for the sampled subgraph.
|
||||
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
|
||||
datapipe = datapipe.fetch_feature(
|
||||
features,
|
||||
node_feature_keys=["feat"],
|
||||
overlap_fetch=args.overlap_feature_fetch,
|
||||
)
|
||||
# Copy the data to the specified device.
|
||||
if args.feature_device == "cpu":
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
@@ -130,7 +134,6 @@ def create_dataloader(
|
||||
return gb.DataLoader(
|
||||
datapipe,
|
||||
num_workers=args.num_workers,
|
||||
overlap_feature_fetch=args.overlap_feature_fetch,
|
||||
overlap_graph_fetch=args.overlap_graph_fetch,
|
||||
)
|
||||
|
||||
@@ -141,7 +144,8 @@ def train(
|
||||
model,
|
||||
multilabel,
|
||||
kwargs,
|
||||
cache_miss_rate_fn,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
device,
|
||||
):
|
||||
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
|
||||
@@ -176,13 +180,20 @@ def train(
|
||||
train_dataloader_tqdm.set_postfix(
|
||||
{
|
||||
"num_nodes": node_features.size(0),
|
||||
"cache_miss": cache_miss_rate_fn(),
|
||||
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
|
||||
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
|
||||
}
|
||||
)
|
||||
train_loss = total_loss / num_batches
|
||||
train_acc = total_correct / total_samples
|
||||
end = time.time()
|
||||
val_acc = evaluate(model, valid_dataloader, kwargs, cache_miss_rate_fn)
|
||||
val_acc = evaluate(
|
||||
model,
|
||||
valid_dataloader,
|
||||
kwargs,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
)
|
||||
if val_acc > best_model_acc:
|
||||
best_model_acc = val_acc
|
||||
best_model = deepcopy(model.state_dict())
|
||||
@@ -233,7 +244,9 @@ def layerwise_infer(
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(model, dataloader, kwargs, cache_miss_rate_fn):
|
||||
def evaluate(
|
||||
model, dataloader, kwargs, gpu_cache_miss_rate_fn, cpu_cache_miss_rate_fn
|
||||
):
|
||||
model.eval()
|
||||
y_hats = []
|
||||
ys = []
|
||||
@@ -247,7 +260,8 @@ def evaluate(model, dataloader, kwargs, cache_miss_rate_fn):
|
||||
val_dataloader_tqdm.set_postfix(
|
||||
{
|
||||
"num_nodes": node_features.size(0),
|
||||
"cache_miss": cache_miss_rate_fn(),
|
||||
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
|
||||
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -314,6 +328,19 @@ def parse_args():
|
||||
)
|
||||
parser.add_argument("--layer-dependency", action="store_true")
|
||||
parser.add_argument("--batch-dependency", type=int, default=1)
|
||||
parser.add_argument(
|
||||
"--cpu-feature-cache-policy",
|
||||
type=str,
|
||||
default="sieve",
|
||||
choices=["s3-fifo", "sieve", "lru", "clock"],
|
||||
help="The cache policy for the CPU feature cache.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-cpu-cached-features",
|
||||
type=int,
|
||||
default=0,
|
||||
help="The capacity of the CPU cache, the number of features to store.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-gpu-cached-features",
|
||||
type=int,
|
||||
@@ -375,17 +402,31 @@ def main():
|
||||
|
||||
num_classes = dataset.tasks[0].metadata["num_classes"]
|
||||
|
||||
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
|
||||
feature = features._features[("node", None, "feat")]
|
||||
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
|
||||
feature,
|
||||
args.num_gpu_cached_features * feature._tensor[:1].nbytes,
|
||||
feature_num_bytes = (
|
||||
features[("node", None, "feat")].read(torch.zeros(1).long()).nbytes
|
||||
)
|
||||
if args.num_cpu_cached_features > 0 and isinstance(
|
||||
features[("node", None, "feat")], gb.DiskBasedFeature
|
||||
):
|
||||
features[("node", None, "feat")] = gb.CPUCachedFeature(
|
||||
features[("node", None, "feat")],
|
||||
args.num_cpu_cached_features * feature_num_bytes,
|
||||
args.cpu_feature_cache_policy,
|
||||
args.feature_device == "pinned",
|
||||
)
|
||||
cache_miss_rate_fn = lambda: features._features[
|
||||
("node", None, "feat")
|
||||
]._feature.miss_rate
|
||||
cpu_cached_feature = features[("node", None, "feat")]
|
||||
cpu_cache_miss_rate_fn = lambda: cpu_cached_feature._feature.miss_rate
|
||||
else:
|
||||
cache_miss_rate_fn = lambda: 1
|
||||
cpu_cache_miss_rate_fn = lambda: 1
|
||||
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
|
||||
features[("node", None, "feat")] = gb.GPUCachedFeature(
|
||||
features[("node", None, "feat")],
|
||||
args.num_gpu_cached_features * feature_num_bytes,
|
||||
)
|
||||
gpu_cached_feature = features[("node", None, "feat")]
|
||||
gpu_cache_miss_rate_fn = lambda: gpu_cached_feature._feature.miss_rate
|
||||
else:
|
||||
gpu_cache_miss_rate_fn = lambda: 1
|
||||
|
||||
train_dataloader, valid_dataloader = (
|
||||
create_dataloader(
|
||||
@@ -425,7 +466,8 @@ def main():
|
||||
model,
|
||||
multilabel,
|
||||
kwargs,
|
||||
cache_miss_rate_fn,
|
||||
gpu_cache_miss_rate_fn,
|
||||
cpu_cache_miss_rate_fn,
|
||||
args.device,
|
||||
)
|
||||
model.load_state_dict(best_model)
|
||||
|
||||
@@ -189,7 +189,11 @@ def create_dataloader(
|
||||
if args.feature_device != "cpu":
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
# Fetch node features for the sampled subgraph.
|
||||
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
|
||||
datapipe = datapipe.fetch_feature(
|
||||
features,
|
||||
node_feature_keys=["feat"],
|
||||
overlap_fetch=args.overlap_feature_fetch,
|
||||
)
|
||||
# Copy the data to the specified device.
|
||||
if args.feature_device == "cpu":
|
||||
datapipe = datapipe.copy_to(device=device)
|
||||
@@ -197,7 +201,6 @@ def create_dataloader(
|
||||
return gb.DataLoader(
|
||||
datapipe,
|
||||
num_workers=args.num_workers,
|
||||
overlap_feature_fetch=args.overlap_feature_fetch,
|
||||
overlap_graph_fetch=args.overlap_graph_fetch,
|
||||
num_gpu_cached_edges=args.num_gpu_cached_edges,
|
||||
gpu_cache_threshold=args.gpu_graph_caching_threshold,
|
||||
@@ -385,10 +388,12 @@ def parse_args():
|
||||
help="Disables torch.compile() on the trained GNN model because it is "
|
||||
"enabled by default for torch>=2.2.0 without this option.",
|
||||
)
|
||||
parser.add_argument("--precision", type=str, default="high")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
torch.set_float32_matmul_precision(args.precision)
|
||||
if not torch.cuda.is_available():
|
||||
args.mode = "cpu-cpu-cpu"
|
||||
print(f"Training in {args.mode} mode.")
|
||||
|
||||
@@ -297,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
|
||||
out_size = num_classes
|
||||
|
||||
if args.gpu_cache_size > 0 and args.storage_device != "cuda":
|
||||
feature._features[("node", None, "feat")] = gb.GPUCachedFeature(
|
||||
feature._features[("node", None, "feat")],
|
||||
feature[("node", None, "feat")] = gb.GPUCachedFeature(
|
||||
feature[("node", None, "feat")],
|
||||
args.gpu_cache_size,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user