[GraphBolt][Example] Slightly improve examples. (#7563)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-22 06:39:18 -04:00
committed by GitHub
parent afcf65c0a2
commit d11cb874c3
3 changed files with 68 additions and 21 deletions

View File

@@ -122,7 +122,11 @@ def create_dataloader(
if args.feature_device != "cpu":
datapipe = datapipe.copy_to(device=device)
# Fetch node features for the sampled subgraph.
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(
features,
node_feature_keys=["feat"],
overlap_fetch=args.overlap_feature_fetch,
)
# Copy the data to the specified device.
if args.feature_device == "cpu":
datapipe = datapipe.copy_to(device=device)
@@ -130,7 +134,6 @@ def create_dataloader(
return gb.DataLoader(
datapipe,
num_workers=args.num_workers,
overlap_feature_fetch=args.overlap_feature_fetch,
overlap_graph_fetch=args.overlap_graph_fetch,
)
@@ -141,7 +144,8 @@ def train(
model,
multilabel,
kwargs,
cache_miss_rate_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
device,
):
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
@@ -176,13 +180,20 @@ def train(
train_dataloader_tqdm.set_postfix(
{
"num_nodes": node_features.size(0),
"cache_miss": cache_miss_rate_fn(),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)
train_loss = total_loss / num_batches
train_acc = total_correct / total_samples
end = time.time()
val_acc = evaluate(model, valid_dataloader, kwargs, cache_miss_rate_fn)
val_acc = evaluate(
model,
valid_dataloader,
kwargs,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
)
if val_acc > best_model_acc:
best_model_acc = val_acc
best_model = deepcopy(model.state_dict())
@@ -233,7 +244,9 @@ def layerwise_infer(
@torch.no_grad()
def evaluate(model, dataloader, kwargs, cache_miss_rate_fn):
def evaluate(
model, dataloader, kwargs, gpu_cache_miss_rate_fn, cpu_cache_miss_rate_fn
):
model.eval()
y_hats = []
ys = []
@@ -247,7 +260,8 @@ def evaluate(model, dataloader, kwargs, cache_miss_rate_fn):
val_dataloader_tqdm.set_postfix(
{
"num_nodes": node_features.size(0),
"cache_miss": cache_miss_rate_fn(),
"gpu_cache_miss": gpu_cache_miss_rate_fn(),
"cpu_cache_miss": cpu_cache_miss_rate_fn(),
}
)
@@ -314,6 +328,19 @@ def parse_args():
)
parser.add_argument("--layer-dependency", action="store_true")
parser.add_argument("--batch-dependency", type=int, default=1)
parser.add_argument(
"--cpu-feature-cache-policy",
type=str,
default="sieve",
choices=["s3-fifo", "sieve", "lru", "clock"],
help="The cache policy for the CPU feature cache.",
)
parser.add_argument(
"--num-cpu-cached-features",
type=int,
default=0,
help="The capacity of the CPU cache, the number of features to store.",
)
parser.add_argument(
"--num-gpu-cached-features",
type=int,
@@ -375,17 +402,31 @@ def main():
num_classes = dataset.tasks[0].metadata["num_classes"]
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
feature = features._features[("node", None, "feat")]
features._features[("node", None, "feat")] = gb.GPUCachedFeature(
feature,
args.num_gpu_cached_features * feature._tensor[:1].nbytes,
feature_num_bytes = (
features[("node", None, "feat")].read(torch.zeros(1).long()).nbytes
)
if args.num_cpu_cached_features > 0 and isinstance(
features[("node", None, "feat")], gb.DiskBasedFeature
):
features[("node", None, "feat")] = gb.CPUCachedFeature(
features[("node", None, "feat")],
args.num_cpu_cached_features * feature_num_bytes,
args.cpu_feature_cache_policy,
args.feature_device == "pinned",
)
cache_miss_rate_fn = lambda: features._features[
("node", None, "feat")
]._feature.miss_rate
cpu_cached_feature = features[("node", None, "feat")]
cpu_cache_miss_rate_fn = lambda: cpu_cached_feature._feature.miss_rate
else:
cache_miss_rate_fn = lambda: 1
cpu_cache_miss_rate_fn = lambda: 1
if args.num_gpu_cached_features > 0 and args.feature_device != "cuda":
features[("node", None, "feat")] = gb.GPUCachedFeature(
features[("node", None, "feat")],
args.num_gpu_cached_features * feature_num_bytes,
)
gpu_cached_feature = features[("node", None, "feat")]
gpu_cache_miss_rate_fn = lambda: gpu_cached_feature._feature.miss_rate
else:
gpu_cache_miss_rate_fn = lambda: 1
train_dataloader, valid_dataloader = (
create_dataloader(
@@ -425,7 +466,8 @@ def main():
model,
multilabel,
kwargs,
cache_miss_rate_fn,
gpu_cache_miss_rate_fn,
cpu_cache_miss_rate_fn,
args.device,
)
model.load_state_dict(best_model)

View File

@@ -189,7 +189,11 @@ def create_dataloader(
if args.feature_device != "cpu":
datapipe = datapipe.copy_to(device=device)
# Fetch node features for the sampled subgraph.
datapipe = datapipe.fetch_feature(features, node_feature_keys=["feat"])
datapipe = datapipe.fetch_feature(
features,
node_feature_keys=["feat"],
overlap_fetch=args.overlap_feature_fetch,
)
# Copy the data to the specified device.
if args.feature_device == "cpu":
datapipe = datapipe.copy_to(device=device)
@@ -197,7 +201,6 @@ def create_dataloader(
return gb.DataLoader(
datapipe,
num_workers=args.num_workers,
overlap_feature_fetch=args.overlap_feature_fetch,
overlap_graph_fetch=args.overlap_graph_fetch,
num_gpu_cached_edges=args.num_gpu_cached_edges,
gpu_cache_threshold=args.gpu_graph_caching_threshold,
@@ -385,10 +388,12 @@ def parse_args():
help="Disables torch.compile() on the trained GNN model because it is "
"enabled by default for torch>=2.2.0 without this option.",
)
parser.add_argument("--precision", type=str, default="high")
return parser.parse_args()
def main():
torch.set_float32_matmul_precision(args.precision)
if not torch.cuda.is_available():
args.mode = "cpu-cpu-cpu"
print(f"Training in {args.mode} mode.")

View File

@@ -297,8 +297,8 @@ def run(rank, world_size, args, devices, dataset):
out_size = num_classes
if args.gpu_cache_size > 0 and args.storage_device != "cuda":
feature._features[("node", None, "feat")] = gb.GPUCachedFeature(
feature._features[("node", None, "feat")],
feature[("node", None, "feat")] = gb.GPUCachedFeature(
feature[("node", None, "feat")],
args.gpu_cache_size,
)