mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt] Update notebook examples. (#7266)
Co-authored-by: Ubuntu <ubuntu@ip-172-31-0-133.us-west-2.compute.internal>
This commit is contained in:
@@ -77,7 +77,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = gb.BuiltinDataset(\"cora\").load()"
|
||||
"dataset = gb.BuiltinDataset(\"cora-seeds\").load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -255,7 +255,8 @@
|
||||
" total_loss = 0\n",
|
||||
" for step, data in tqdm(enumerate(create_train_dataloader())):\n",
|
||||
" # Get node pairs with labels for loss calculation.\n",
|
||||
" compacted_pairs, labels = data.node_pairs_with_labels\n",
|
||||
" compacted_seeds = data.compacted_seeds.T\n",
|
||||
" labels = data.labels\n",
|
||||
" node_feature = data.node_features[\"feat\"]\n",
|
||||
" # Convert sampled subgraphs to DGL blocks.\n",
|
||||
" blocks = data.blocks\n",
|
||||
@@ -263,7 +264,7 @@
|
||||
" # Get the embeddings of the input nodes.\n",
|
||||
" y = model(blocks, node_feature)\n",
|
||||
" logits = model.predictor(\n",
|
||||
" y[compacted_pairs[0]] * y[compacted_pairs[1]]\n",
|
||||
" y[compacted_seeds[0]] * y[compacted_seeds[1]]\n",
|
||||
" ).squeeze()\n",
|
||||
"\n",
|
||||
" # Compute loss.\n",
|
||||
@@ -308,7 +309,8 @@
|
||||
"labels = []\n",
|
||||
"for step, data in tqdm(enumerate(eval_dataloader)):\n",
|
||||
" # Get node pairs with labels for loss calculation.\n",
|
||||
" compacted_pairs, label = data.node_pairs_with_labels\n",
|
||||
" compacted_seeds = data.compacted_seeds.T\n",
|
||||
" label = data.labels\n",
|
||||
"\n",
|
||||
" # The features of sampled nodes.\n",
|
||||
" x = data.node_features[\"feat\"]\n",
|
||||
@@ -316,7 +318,7 @@
|
||||
" # Forward.\n",
|
||||
" y = model(data.blocks, x)\n",
|
||||
" logit = (\n",
|
||||
" model.predictor(y[compacted_pairs[0]] * y[compacted_pairs[1]])\n",
|
||||
" model.predictor(y[compacted_seeds[0]] * y[compacted_seeds[1]])\n",
|
||||
" .squeeze()\n",
|
||||
" .detach()\n",
|
||||
" )\n",
|
||||
|
||||
@@ -78,7 +78,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = gb.BuiltinDataset(\"ogbn-arxiv\").load()"
|
||||
"dataset = gb.BuiltinDataset(\"ogbn-arxiv-seeds\").load()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -143,7 +143,7 @@
|
||||
"source": [
|
||||
"def create_dataloader(itemset, shuffle):\n",
|
||||
" datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=shuffle)\n",
|
||||
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seed_nodes\"])\n",
|
||||
" datapipe = datapipe.copy_to(device, extra_attrs=[\"seeds\"])\n",
|
||||
" datapipe = datapipe.sample_neighbor(graph, [4, 4])\n",
|
||||
" datapipe = datapipe.fetch_feature(feature, node_feature_keys=[\"feat\"])\n",
|
||||
" return gb.DataLoader(datapipe)"
|
||||
@@ -375,4 +375,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user