mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
Gb to dgl nb (#6764)
This commit is contained in:
@@ -218,29 +218,6 @@
|
||||
"print(next(iter(datapipe)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "Gt059n1xrmj-"
|
||||
},
|
||||
"source": [
|
||||
"After retrieving the required data, Graphbolt provides helper methods to convert it to the output format needed for subsequent GNN training.\n",
|
||||
"\n",
|
||||
"* Convert to **DGLMiniBatch** format for training with DGL."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "o8Yoi8BeqSdu"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"datapipe = datapipe.to_dgl()\n",
|
||||
"print(next(iter(datapipe)))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
|
||||
@@ -172,27 +172,6 @@
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"In order to train with DGL, you need to convert `MiniBatch` to `DGLMiniBatch` like below:"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "IpAgrEp_cdEP"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"data = data.to_dgl()\n",
|
||||
"print(f\"DGLMiniBatch: {data}\")"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "KQgxFUyCcjVT"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
@@ -263,38 +242,6 @@
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"Define utility function to vonvert the minibatch to a training pair and a label tensor.\n",
|
||||
"\n"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "J9K1GUs4ZDYw"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):\n",
|
||||
" \"\"\"Convert the minibatch to a training pair and a label tensor.\"\"\"\n",
|
||||
" pos_src, pos_dst = data.positive_node_pairs\n",
|
||||
" neg_src, neg_dst = data.negative_node_pairs\n",
|
||||
" node_pairs = (\n",
|
||||
" torch.cat((pos_src, neg_src), dim=0),\n",
|
||||
" torch.cat((pos_dst, neg_dst), dim=0),\n",
|
||||
" )\n",
|
||||
" pos_label = torch.ones_like(pos_src)\n",
|
||||
" neg_label = torch.zeros_like(neg_src)\n",
|
||||
" labels = torch.cat([pos_label, neg_label], dim=0)\n",
|
||||
" return (node_pairs, labels.float())"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "wvIJBPb7ZNUv"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
@@ -313,11 +260,8 @@
|
||||
" model.train()\n",
|
||||
" total_loss = 0\n",
|
||||
" for step, data in tqdm.tqdm(enumerate(train_dataloader)):\n",
|
||||
" # Convert to DGL format.\n",
|
||||
" data = data.to_dgl()\n",
|
||||
"\n",
|
||||
" # Unpack DGLMiniBatch.\n",
|
||||
" compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)\n",
|
||||
" # Get node pairs with labels for loss calculation.\n",
|
||||
" compacted_pairs, labels = data.node_pairs_with_labels\n",
|
||||
" node_feature = data.node_features[\"feat\"]\n",
|
||||
" # Convert sampled subgraphs to DGL blocks.\n",
|
||||
" blocks = data.blocks\n",
|
||||
@@ -369,11 +313,8 @@
|
||||
"logits = []\n",
|
||||
"labels = []\n",
|
||||
"for step, data in enumerate(eval_dataloader):\n",
|
||||
" # Convert to DGL format.\n",
|
||||
" data = data.to_dgl()\n",
|
||||
"\n",
|
||||
" # Unpack MiniBatch.\n",
|
||||
" compacted_pairs, label = to_binary_link_dgl_computing_pack(data)\n",
|
||||
" # Get node pairs with labels for loss calculation.\n",
|
||||
" compacted_pairs, label = data.node_pairs_with_labels\n",
|
||||
"\n",
|
||||
" # The features of sampled nodes.\n",
|
||||
" x = data.node_features[\"feat\"]\n",
|
||||
|
||||
@@ -183,27 +183,6 @@
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"In order to train with DGL, you need to convert `MiniBatch` to `DGLMiniBatch` like below:"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "FwDJf1AJbNtt"
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"source": [
|
||||
"data = data.to_dgl()\n",
|
||||
"print(data)"
|
||||
],
|
||||
"metadata": {
|
||||
"id": "3Tzfp6A8bdWv"
|
||||
},
|
||||
"execution_count": null,
|
||||
"outputs": []
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
@@ -336,7 +315,6 @@
|
||||
"\n",
|
||||
" with tqdm.tqdm(train_dataloader) as tq:\n",
|
||||
" for step, data in enumerate(tq):\n",
|
||||
" data = data.to_dgl()\n",
|
||||
" x = data.node_features[\"feat\"]\n",
|
||||
" labels = data.labels\n",
|
||||
"\n",
|
||||
@@ -363,7 +341,6 @@
|
||||
" labels = []\n",
|
||||
" with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n",
|
||||
" for data in tq:\n",
|
||||
" data = data.to_dgl()\n",
|
||||
" x = data.node_features[\"feat\"]\n",
|
||||
" labels.append(data.labels.cpu().numpy())\n",
|
||||
" predictions.append(model(data.blocks, x).argmax(1).cpu().numpy())\n",
|
||||
|
||||
Reference in New Issue
Block a user