diff --git a/examples/graphbolt/pyg/labor/node_classification.py b/examples/graphbolt/pyg/labor/node_classification.py index d907748d98..a8faded50c 100644 --- a/examples/graphbolt/pyg/labor/node_classification.py +++ b/examples/graphbolt/pyg/labor/node_classification.py @@ -402,8 +402,13 @@ def main(): num_classes = dataset.tasks[0].metadata["num_classes"] + feature_index_device = ( + args.feature_device if args.feature_device != "pinned" else None + ) feature_num_bytes = ( - features[("node", None, "feat")].read(torch.zeros(1).long()).nbytes + features[("node", None, "feat")] + # Read a single row to query its size in bytes. + .read(torch.zeros(1, device=feature_index_device).long()).nbytes ) if args.num_cpu_cached_features > 0 and isinstance( features[("node", None, "feat")], gb.DiskBasedFeature