mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[Bug fix] Various fix from bug bash (#3133)
* Update
* Update
* Update dependencies
* Update
* Update
* Fix ogbn-products gat
* Update
* Update
* Reformat
* Fix typo in node2vec_random_walk
* Specify file encoding
* Working for 6.7
* Update
* Fix subgraph
* Fix doc for sample_neighbors_biased
* Fix hyperlink
* Add example for udf cross reducer
* Fix
* Add example for slice_batch
* Replace dgl.bipartite
* Fix GATConv
* Fix math rendering
* Fix doc

Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-17.us-west-2.compute.internal>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: Ubuntu <ubuntu@ip-172-31-22-156.us-west-2.compute.internal>
@@ -49,7 +49,7 @@ the same as the other user guides and tutorials.
|
||||
|
||||
GPU-based neighbor sampling also works for custom neighborhood samplers as long as
|
||||
(1) your sampler is subclassed from :class:`~dgl.dataloading.BlockSampler`, and (2)
|
||||
your code in the sampler entirely works on GPU.
|
||||
your sampler entirely works on GPU.
|
||||
|
||||
.. note::
|
||||
|
||||
|
||||
@@ -153,7 +153,7 @@ of MFGs, we:
|
||||
training.
|
||||
|
||||
If the features are stored in ``g.ndata``, then the labels
|
||||
can be loaded by accessing the features in ``blocks[-1].srcdata``,
|
||||
can be loaded by accessing the features in ``blocks[-1].dstdata``,
|
||||
the features of destination nodes of the last MFG, which are identical to
the nodes for which we wish to compute the final representation.
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ DGL provides several neighbor sampler classes, which generate, for the nodes to be computed, at each
|
||||
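3. Load the node labels corresponding to the output nodes onto the GPU. As before, the node labels can be stored in memory or in external storage.
Again, note that users only need to load the labels of the output nodes, not the labels of all nodes as in full-graph training.

If the features are stored in ``g.ndata``, then the labels can be loaded by accessing the features in ``blocks[-1].srcdata``,
If the features are stored in ``g.ndata``, then the labels can be loaded by accessing the features in ``blocks[-1].dstdata``,
which are the features of the output nodes of the last block; these nodes are the same as the nodes for which the user wishes to compute the final representation.

4. Compute the loss and backpropagate.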
|
||||
@@ -208,4 +208,4 @@ Some of the sampling methods provided by DGL also support heterogeneous graphs. For example, users can still
|
||||
opt.step()
|
||||
|
||||
DGL provides an end-to-end stochastic minibatch training
`implementation of RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__.
`implementation of RGCN <https://github.com/dmlc/dgl/blob/master/examples/pytorch/rgcn-hetero/entity_classify_mb.py>`__.
|
||||
|
||||
@@ -51,7 +51,7 @@ def main(args):
|
||||
tr_loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
train_loss = np.sum(train_loss)
|
||||
train_loss = torch.stack(train_loss).sum().cpu().item()
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
@@ -66,8 +66,8 @@ def main(args):
|
||||
val_acc = evaluate_acc(logits.detach().cpu().numpy(), label.detach().cpu().numpy())
|
||||
validate_loss.append(val_loss)
|
||||
validate_acc.append(val_acc)
|
||||
|
||||
validate_loss = np.sum(validate_loss)
|
||||
|
||||
validate_loss = torch.stack(validate_loss).sum().cpu().item()
|
||||
validate_acc = np.mean(validate_acc)
|
||||
|
||||
#validate
|
||||
@@ -111,8 +111,8 @@ def main(args):
|
||||
test_auc.append(auc)
|
||||
test_f1.append(f1)
|
||||
test_logloss.append(log_loss)
|
||||
|
||||
test_loss = np.sum(test_loss)
|
||||
|
||||
test_loss = torch.stack(test_loss).sum().cpu().item()
|
||||
test_acc = np.mean(test_acc)
|
||||
test_auc = np.mean(test_auc)
|
||||
test_f1 = np.mean(test_f1)
|
||||
@@ -146,4 +146,4 @@ if __name__ == '__main__':
|
||||
print(args)
|
||||
|
||||
main(args)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ Dependencies
|
||||
----------------------
|
||||
- pytorch 1.7.1
|
||||
- dgl 0.6.0
|
||||
- sklearn 0.22.1
|
||||
|
||||
Datasets
|
||||
---------------------------------------
|
||||
|
||||
@@ -466,11 +466,11 @@ class MovieLens(object):
|
||||
file_path = os.path.join(self._dir, 'u.item')
|
||||
self.movie_info = pd.read_csv(file_path, sep='|', header=None,
|
||||
names=['id', 'title', 'release_date', 'video_release_date', 'url'] + GENRES,
|
||||
engine='python')
|
||||
encoding='iso-8859-1')
|
||||
elif self._name == 'ml-1m' or self._name == 'ml-10m':
|
||||
file_path = os.path.join(self._dir, 'movies.dat')
|
||||
movie_info = pd.read_csv(file_path, sep='::', header=None,
|
||||
names=['id', 'title', 'genres'], engine='python')
|
||||
names=['id', 'title', 'genres'], encoding='iso-8859-1')
|
||||
genre_map = {ele: i for i, ele in enumerate(GENRES)}
|
||||
genre_map['Children\'s'] = genre_map['Children']
|
||||
genre_map['Childrens'] = genre_map['Children']
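The `encoding='iso-8859-1'` argument in the hunk above presumably matters because some movie titles contain accented characters stored as single bytes. The following is a minimal standard-library sketch of that failure mode, not the loader's actual code: the sample row and the `decode_rows` helper are hypothetical, and pandas is deliberately left out.

```python
# Hypothetical sample row in the style of 'movies.dat'; the accented character
# is stored as the single ISO-8859-1 byte 0xE9.
raw = "1::Caf\u00e9 Society (1995)::Drama\n".encode("iso-8859-1")


def decode_rows(data, encoding):
    """Decode bytes with the given encoding and split each line on '::'."""
    text = data.decode(encoding)
    return [line.split("::") for line in text.splitlines()]


try:
    decode_rows(raw, "utf-8")
    utf8_ok = True
except UnicodeDecodeError:
    # 0xE9 followed by a space is not a valid UTF-8 sequence.
    utf8_ok = False

# Decoding with the encoding the file was written in recovers the title.
rows = decode_rows(raw, "iso-8859-1")
print(utf8_ok, rows)
```

Passing the encoding explicitly keeps the result independent of the platform's default encoding, which is the likely motivation for the change.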
|
||||
|
||||
@@ -12,6 +12,7 @@ This example was implemented by [Hengrui Zhang](https://github.com/hengruizhang9
|
||||
- Python 3.7
|
||||
- PyTorch 1.7.1
|
||||
- dgl 0.6.0
|
||||
- sklearn 0.22.1
|
||||
|
||||
## Datasets
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ Requirements
|
||||
------------
|
||||
- requests
|
||||
|
||||
``bash
|
||||
```bash
|
||||
pip install requests
|
||||
``
|
||||
```
|
||||
|
||||
|
||||
Results
|
||||
@@ -34,10 +34,14 @@ Train w/ mini-batch sampling (on the Reddit dataset)
|
||||
```bash
|
||||
python3 train_sampling.py --num-epochs 30 # neighbor sampling
|
||||
python3 train_sampling.py --num-epochs 30 --inductive # inductive learning with neighbor sampling
|
||||
python3 train_sampling_multi_gpu.py --num-epochs 30 # neighbor sampling with multi GPU
|
||||
python3 train_sampling_multi_gpu.py --num-epochs 30 --inductive # inductive learning with neighbor sampling, multi GPU
|
||||
python3 train_cv.py --num-epochs 30 # control variate sampling
|
||||
python3 train_cv_multi_gpu.py --num-epochs 30 # control variate sampling with multi GPU
|
||||
```
|
||||
|
||||
For multi-gpu training
|
||||
```bash
|
||||
python3 train_sampling_multi_gpu.py --num-epochs 30 --gpu 0,1,... # neighbor sampling
|
||||
python3 train_sampling_multi_gpu.py --num-epochs 30 --inductive --gpu 0,1,... # inductive learning
|
||||
python3 train_cv_multi_gpu.py --num-epochs 30 --gpu 0,1,... # control variate sampling
|
||||
```
|
||||
|
||||
Accuracy:
|
||||
|
||||
@@ -24,6 +24,8 @@ All datasets used are provided by Author's [code](https://github.com/GraphSAINT/
|
||||
| PPI | 14,755 | 225,270 | 15 | 50 | 121(m) | 0.66/0.12/0.22 |
|
||||
| Flickr | 89,250 | 899,756 | 10 | 500 | 7(s) | 0.50/0.25/0.25 |
|
||||
|
||||
Note that the PPI dataset here is different from DGL's built-in variant.
|
||||
|
||||
## Minibatch training
|
||||
|
||||
Run with the following:
|
||||
@@ -94,4 +96,4 @@ python train_sampling.py --gpu 0 --dataset flickr --sampler rw --num-roots 6000
|
||||
| Sampling(Running) | 0.83 | 1.22 |
|
||||
| Sampling(DGL) | 0.28 | 0.63 |
|
||||
| Normalization(Running) | 0.87 | 2.60 |
|
||||
| Normalization(DGL) | 0.70 | 0.42 |
|
||||
| Normalization(DGL) | 0.70 | 0.42 |
|
||||
|
||||
@@ -11,6 +11,8 @@ The authors' implementation can be found [here](https://github.com/Jhy1993/HAN).
|
||||
[here](https://github.com/Jhy1993/HAN/tree/master/data/acm). The dataset is noisy
|
||||
because the same author can occur multiple times as different nodes.
|
||||
|
||||
For sampling-based training, run `python train_sampling.py`.
|
||||
|
||||
## Performance
|
||||
|
||||
Reference performance numbers for the ACM dataset:
|
||||
|
||||
@@ -92,7 +92,7 @@ def gen_model(args):
|
||||
input_drop=args.input_drop,
|
||||
attn_drop=args.attn_dropout,
|
||||
edge_drop=args.edge_drop,
|
||||
use_attn_dst=not args.use_attn_dst,
|
||||
use_attn_dst=not args.no_attn_dst,
|
||||
allow_zero_in_degree=True,
|
||||
residual=False,
|
||||
)
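The change from `args.use_attn_dst` to `args.no_attn_dst` makes sense if the script exposes a negative switch. The sketch below assumes the flag is registered as `--no-attn-dst` with `action="store_true"`; that registration is not shown in this hunk, so the flag spelling and help text are hypothetical.

```python
import argparse

parser = argparse.ArgumentParser()
# Assumed flag spelling; argparse maps '--no-attn-dst' to 'args.no_attn_dst'.
parser.add_argument("--no-attn-dst", action="store_true",
                    help="disable the destination attention term")

default_args = parser.parse_args([])
disabled_args = parser.parse_args(["--no-attn-dst"])

# The model option is the negation of the command-line switch.
use_attn_dst_default = not default_args.no_attn_dst
use_attn_dst_disabled = not disabled_args.no_attn_dst

# The namespace has no 'use_attn_dst' attribute under this assumption,
# so reading it would raise AttributeError.
has_old_name = hasattr(default_args, "use_attn_dst")
print(use_attn_dst_default, use_attn_dst_disabled, has_old_name)
```

With a switch defined this way, the option is enabled by default and only turned off when the flag is passed.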
|
||||
|
||||
@@ -433,6 +433,28 @@ def slice_batch(g, gid, store_ids=False):
|
||||
-------
|
||||
DGLGraph
|
||||
Retrieved graph.
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
The following example uses PyTorch backend.
|
||||
|
||||
>>> import dgl
|
||||
>>> import torch
|
||||
|
||||
Create a batched graph.
|
||||
|
||||
>>> g1 = dgl.graph(([0, 1], [2, 3]))
|
||||
>>> g2 = dgl.graph(([1], [2]))
|
||||
>>> bg = dgl.batch([g1, g2])
|
||||
|
||||
Get the second component graph.
|
||||
|
||||
>>> g = dgl.slice_batch(bg, 1)
|
||||
>>> print(g)
|
||||
Graph(num_nodes=3, num_edges=1,
|
||||
ndata_schemes={}
|
||||
edata_schemes={})
|
||||
"""
|
||||
start_nid = []
|
||||
num_nodes = []
|
||||
|
||||
@@ -4951,6 +4951,18 @@ class DGLHeteroGraph(object):
|
||||
>>> g.nodes['user'].data['h']
|
||||
tensor([[0.],
|
||||
[4.]])
|
||||
|
||||
User-defined cross reducer equivalent to "sum".
|
||||
|
||||
>>> def cross_sum(flist):
|
||||
... return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]
|
||||
|
||||
Use the user-defined cross reducer.
|
||||
|
||||
>>> g.multi_update_all(
|
||||
... {'follows': (fn.copy_src('h', 'm'), fn.sum('m', 'h')),
|
||||
... 'attracts': (fn.copy_src('h', 'm'), fn.sum('m', 'h'))},
|
||||
... cross_sum)
|
||||
"""
|
||||
all_out = defaultdict(list)
|
||||
merge_order = defaultdict(list)
|
||||
|
||||
@@ -117,7 +117,7 @@ class GATConv(nn.Block):
|
||||
>>> # Case 2: Unidirectional bipartite graph
|
||||
>>> u = [0, 1, 0, 0, 1]
|
||||
>>> v = [0, 1, 2, 3, 2]
|
||||
>>> g = dgl.bipartite((u, v))
|
||||
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
|
||||
>>> u_feat = mx.nd.random.randn(2, 5)
|
||||
>>> v_feat = mx.nd.random.randn(4, 10)
|
||||
>>> gatconv = GATConv((5,10), 2, 3)
|
||||
|
||||
@@ -21,7 +21,7 @@ from .densesageconv import DenseSAGEConv
|
||||
from .atomicconv import AtomicConv
|
||||
from .cfconv import CFConv
|
||||
from .dotgatconv import DotGatConv
|
||||
from .twirlsconv import TWIRLSConv, UnfoldingAndAttention as TWIRLSUnfoldingAndAttention
|
||||
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
|
||||
from .gcn2conv import GCN2Conv
|
||||
|
||||
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'TAGConv', 'RelGraphConv', 'SAGEConv',
|
||||
|
||||
@@ -115,7 +115,7 @@ class GATConv(nn.Module):
|
||||
>>> # Case 2: Unidirectional bipartite graph
|
||||
>>> u = [0, 1, 0, 0, 1]
|
||||
>>> v = [0, 1, 2, 3, 2]
|
||||
>>> g = dgl.bipartite((u, v))
|
||||
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
|
||||
>>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
|
||||
>>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
|
||||
>>> gatconv = GATConv((5,10), 2, 3)
|
||||
|
||||
@@ -20,7 +20,9 @@ class GCN2Conv(nn.Module):
|
||||
and Identity mapping (GCNII) was introduced in `"Simple and Deep Graph Convolutional
|
||||
Networks" <https://arxiv.org/abs/2007.02133>`_ paper.
|
||||
It is mathematically defined as follows:
|
||||
|
||||
.. math::
|
||||
|
||||
\mathbf{h}^{(l+1)} =\left( (1 - \alpha)(\mathbf{D}^{-1/2} \mathbf{\hat{A}}
|
||||
\mathbf{D}^{-1/2})\mathbf{h}^{(l)} + \alpha {\mathbf{h}^{(0)}} \right)
|
||||
\left( (1 - \beta_l) \mathbf{I} + \beta_l \mathbf{W} \right)
|
||||
|
||||
@@ -444,7 +444,7 @@ def D_power_bias_X(graph, X, power, coeff, bias):
|
||||
return Y
|
||||
|
||||
|
||||
class UnfoldingAndAttention(nn.Module):
|
||||
class TWIRLSUnfoldingAndAttention(nn.Module):
|
||||
r"""
|
||||
|
||||
Description
|
||||
|
||||
@@ -117,7 +117,7 @@ class GATConv(layers.Layer):
|
||||
>>> # Case 2: Unidirectional bipartite graph
|
||||
>>> u = [0, 1, 0, 0, 1]
|
||||
>>> v = [0, 1, 2, 3, 2]
|
||||
>>> g = dgl.bipartite((u, v))
|
||||
>>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
|
||||
>>> with tf.device("CPU:0"):
|
||||
>>> u_feat = tf.convert_to_tensor(np.random.rand(2, 5))
|
||||
>>> v_feat = tf.convert_to_tensor(np.random.rand(4, 10))
|
||||
|
||||
@@ -302,8 +302,9 @@ def sample_neighbors_biased(g, nodes, fanout, bias, edge_dir='in',
|
||||
[0, 1, 2]])
|
||||
|
||||
Set the probability of each tag:
|
||||
|
||||
>>> bias = torch.tensor([1.0, 0.001])
|
||||
# node 2 is almost impossible to be sampled because it has tag 1.
|
||||
>>> # node 2 is almost impossible to be sampled because it has tag 1.
|
||||
|
||||
To sample one outbound edge for node 0 and node 2:
|
||||
|
||||
|
||||
@@ -63,13 +63,13 @@ def node2vec_random_walk(g, nodes, p, q, walk_length, prob=None, return_eids=Fal
|
||||
Examples
|
||||
--------
|
||||
>>> g1 = dgl.graph(([0, 1, 1, 2, 3], [1, 2, 3, 0, 0]))
|
||||
>>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, length=4)
|
||||
>>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, walk_length=4)
|
||||
tensor([[0, 1, 3, 0, 1],
|
||||
[1, 2, 0, 1, 3],
|
||||
[2, 0, 1, 3, 0],
|
||||
[0, 1, 2, 0, 1]])
|
||||
|
||||
>>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, length=4, return_eids=True)
|
||||
>>> dgl.sampling.node2vec_random_walk(g1, [0, 1, 2, 0], 1, 1, walk_length=4, return_eids=True)
|
||||
(tensor([[0, 1, 3, 0, 1],
|
||||
[1, 2, 0, 1, 2],
|
||||
[2, 0, 1, 2, 0],
|
||||
|
||||
@@ -119,9 +119,9 @@ def node_subgraph(graph, nodes, *, relabel_nodes=True, store_ids=True):
|
||||
>>> })
|
||||
>>> sub_g = dgl.node_subgraph(g, {'user': [1, 2]})
|
||||
>>> sub_g
|
||||
Graph(num_nodes={'user': 2, 'game': 0},
|
||||
num_edges={('user', 'plays', 'game'): 0, ('user', 'follows', 'user'): 2},
|
||||
metagraph=[('user', 'game'), ('user', 'user')])
|
||||
Graph(num_nodes={'game': 0, 'user': 2},
|
||||
num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 0},
|
||||
metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])
|
||||
|
||||
See Also
|
||||
--------
|
||||
@@ -266,9 +266,9 @@ def edge_subgraph(graph, edges, *, relabel_nodes=True, store_ids=True, **depreca
|
||||
>>> sub_g = dgl.edge_subgraph(g, {('user', 'follows', 'user'): [1, 2],
|
||||
... ('user', 'plays', 'game'): [2]})
|
||||
>>> print(sub_g)
|
||||
Graph(num_nodes={'user': 2, 'game': 1},
|
||||
num_edges={('user', 'plays', 'game'): 1, ('user', 'follows', 'user'): 2},
|
||||
metagraph=[('user', 'game'), ('user', 'user')])
|
||||
Graph(num_nodes={'game': 1, 'user': 2},
|
||||
num_edges={('user', 'follows', 'user'): 2, ('user', 'plays', 'game'): 1},
|
||||
metagraph=[('user', 'user', 'follows'), ('user', 'game', 'plays')])
|
||||
|
||||
See Also
|
||||
--------
|
||||
|
||||
@@ -2536,8 +2536,6 @@ def adj_product_graph(A, B, weight_name, etype='_E'):
|
||||
>>> B = dgl.heterograph({
|
||||
... ('B', 'BA', 'A'): ([0, 3, 2, 1, 3, 3], [1, 2, 0, 2, 1, 0])},
|
||||
... num_nodes_dict={'A': 3, 'B': 4})
|
||||
>>> A.edata['w'] = torch.randn(6).requires_grad_()
|
||||
>>> B.edata['w'] = torch.randn(6).requires_grad_()
|
||||
|
||||
If your graph is a multigraph, you will need to call :func:`dgl.to_simple`
|
||||
to convert it into a simple graph first.
|
||||
@@ -2545,6 +2543,13 @@ def adj_product_graph(A, B, weight_name, etype='_E'):
|
||||
>>> A = dgl.to_simple(A)
|
||||
>>> B = dgl.to_simple(B)
|
||||
|
||||
Initialize learnable edge weights.
|
||||
|
||||
>>> A.edata['w'] = torch.randn(6).requires_grad_()
|
||||
>>> B.edata['w'] = torch.randn(6).requires_grad_()
|
||||
|
||||
Take the product.
|
||||
|
||||
>>> C = dgl.adj_product_graph(A, B, 'w')
|
||||
>>> C.edges()
|
||||
(tensor([0, 0, 1, 2, 2, 2]), tensor([0, 1, 0, 0, 2, 1]))
|
||||
@@ -2660,12 +2665,19 @@ def adj_sum_graph(graphs, weight_name):
|
||||
>>> A.edata['w'] = torch.randn(6).requires_grad_()
|
||||
>>> B.edata['w'] = torch.randn(6).requires_grad_()
|
||||
|
||||
If your graph is a multigraph, you will need to call :func:`dgl.to_simple`
|
||||
If your graph is a multigraph, call :func:`dgl.to_simple`
|
||||
to convert it into a simple graph first.
|
||||
|
||||
>>> A = dgl.to_simple(A)
|
||||
>>> B = dgl.to_simple(B)
|
||||
|
||||
Initialize learnable edge weights.
|
||||
|
||||
>>> A.edata['w'] = torch.randn(6).requires_grad_()
|
||||
>>> B.edata['w'] = torch.randn(6).requires_grad_()
|
||||
|
||||
Take the sum.
|
||||
|
||||
>>> C = dgl.adj_sum_graph([A, B], 'w')
|
||||
>>> C.edges()
|
||||
(tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2]),
|
||||
@@ -2930,7 +2942,7 @@ def reorder_graph(g, node_permute_algo='rcmk', edge_permute_algo='src',
|
||||
generated/scipy.sparse.csgraph.reverse_cuthill_mckee.html#
|
||||
scipy-sparse-csgraph-reverse-cuthill-mckee>`__ from ``scipy`` to generate the node
|
||||
permutation.
|
||||
* ``metis``: Use the :func:`~dgl.partition.metis_partition_assignment` function
|
||||
* ``metis``: Use the :func:`~dgl.metis_partition_assignment` function
|
||||
to partition the input graph, which gives a cluster assignment of each node.
|
||||
DGL then sorts the assignment array so the new node order will put nodes of
|
||||
the same cluster together.
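The ``metis`` option described above turns a per-node cluster assignment into a node order by sorting. The sketch below illustrates that idea in plain Python; it is not DGL's implementation, and both the `order_by_cluster` helper and the stable tie-breaking are assumptions made for the example.

```python
def order_by_cluster(assignment):
    """Return a node permutation that groups nodes of the same cluster together."""
    # sorted() is stable, so nodes keep their original relative order
    # inside each cluster.
    return sorted(range(len(assignment)), key=lambda nid: assignment[nid])


# Hypothetical assignment: node i belongs to cluster assignment[i].
assignment = [1, 0, 2, 0, 1, 2]
new_order = order_by_cluster(assignment)
print(new_order)
```

Here nodes 1 and 3 (cluster 0) come first, followed by nodes 0 and 4 (cluster 1), then nodes 2 and 5 (cluster 2).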
|
||||
|
||||
@@ -11,66 +11,64 @@ knowledge in GNNs for graph classification and we recommend you to check
|
||||
|
||||
To use a single GPU in training a GNN, we need to put the model, graph(s), and other
|
||||
tensors (e.g. labels) on the same GPU:
|
||||
"""
|
||||
|
||||
"""
|
||||
import torch
|
||||
.. code:: python
|
||||
|
||||
# Use the first GPU
|
||||
device = torch.device("cuda:0")
|
||||
model = model.to(device)
|
||||
graph = graph.to(device)
|
||||
labels = labels.to(device)
|
||||
"""
|
||||
import torch
|
||||
|
||||
###############################################################################
|
||||
# The node and edge features in the graphs, if any, will also be on the GPU.
|
||||
# After that, the forward computation, backward computation and parameter
|
||||
# update will take place on the GPU. For graph classification, this repeats
|
||||
# for each minibatch gradient descent.
|
||||
#
|
||||
# Using multiple GPUs allows performing more computation per unit of time. It
|
||||
# is like having a team work together, where each GPU is a team member. We need
|
||||
# to distribute the computation workload across GPUs and let them synchronize
|
||||
# the efforts regularly. PyTorch provides convenient APIs for this task with
|
||||
# multiple processes, one per GPU, and we can use them in conjunction with DGL.
|
||||
#
|
||||
# Intuitively, we can distribute the workload along the dimension of data. This
|
||||
# allows multiple GPUs to perform the forward and backward computation of
|
||||
# multiple gradient descents in parallel. To distribute a dataset across
|
||||
# multiple GPUs, we need to partition it into multiple mutually exclusive
|
||||
# subsets of a similar size, one per GPU. We need to repeat the random
|
||||
# partition every epoch to guarantee randomness. We can use
|
||||
# :func:`~dgl.dataloading.pytorch.GraphDataLoader`, which wraps some PyTorch
|
||||
# APIs and does the job for graph classification in data loading.
|
||||
#
|
||||
# Once all GPUs have finished the backward computation for their minibatches,
|
||||
# we need to synchronize the model parameter update across them. Specifically,
|
||||
# this involves collecting gradients from all GPUs, averaging them and updating
|
||||
# the model parameters on each GPU. We can wrap a PyTorch model with
|
||||
# :func:`~torch.nn.parallel.DistributedDataParallel` so that the model
|
||||
# parameter update will invoke gradient synchronization first under the hood.
|
||||
#
|
||||
# .. image:: https://data.dgl.ai/tutorial/mgpu_gc.png
|
||||
# :width: 450px
|
||||
# :align: center
|
||||
#
|
||||
# That’s the core idea behind this tutorial. We will explore it in more detail with
|
||||
# a complete example below.
|
||||
#
|
||||
# .. note::
|
||||
#
|
||||
# See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
|
||||
# from PyTorch for general multi-GPU training with ``DistributedDataParallel``.
|
||||
#
|
||||
# Distributed Process Group Initialization
|
||||
# ----------------------------------------
|
||||
#
|
||||
# For communication between multiple processes in multi-gpu training, we need
|
||||
# to start the distributed backend at the beginning of each process. We use
|
||||
# `world_size` to refer to the number of processes and `rank` to refer to the
|
||||
# process ID, which should be an integer from `0` to `world_size - 1`.
|
||||
#
|
||||
# Use the first GPU
|
||||
device = torch.device("cuda:0")
|
||||
model = model.to(device)
|
||||
graph = graph.to(device)
|
||||
labels = labels.to(device)
|
||||
|
||||
The node and edge features in the graphs, if any, will also be on the GPU.
|
||||
After that, the forward computation, backward computation and parameter
|
||||
update will take place on the GPU. For graph classification, this repeats
|
||||
for each minibatch gradient descent.
|
||||
|
||||
Using multiple GPUs allows performing more computation per unit of time. It
|
||||
is like having a team work together, where each GPU is a team member. We need
|
||||
to distribute the computation workload across GPUs and let them synchronize
|
||||
the efforts regularly. PyTorch provides convenient APIs for this task with
|
||||
multiple processes, one per GPU, and we can use them in conjunction with DGL.
|
||||
|
||||
Intuitively, we can distribute the workload along the dimension of data. This
|
||||
allows multiple GPUs to perform the forward and backward computation of
|
||||
multiple gradient descents in parallel. To distribute a dataset across
|
||||
multiple GPUs, we need to partition it into multiple mutually exclusive
|
||||
subsets of a similar size, one per GPU. We need to repeat the random
|
||||
partition every epoch to guarantee randomness. We can use
|
||||
:func:`~dgl.dataloading.pytorch.GraphDataLoader`, which wraps some PyTorch
|
||||
APIs and does the job for graph classification in data loading.
|
||||
|
||||
Once all GPUs have finished the backward computation for their minibatches,
|
||||
we need to synchronize the model parameter update across them. Specifically,
|
||||
this involves collecting gradients from all GPUs, averaging them and updating
|
||||
the model parameters on each GPU. We can wrap a PyTorch model with
|
||||
:func:`~torch.nn.parallel.DistributedDataParallel` so that the model
|
||||
parameter update will invoke gradient synchronization first under the hood.
|
||||
|
||||
.. image:: https://data.dgl.ai/tutorial/mgpu_gc.png
|
||||
:width: 450px
|
||||
:align: center
|
||||
|
||||
That’s the core idea behind this tutorial. We will explore it in more detail with
|
||||
a complete example below.
|
||||
|
||||
.. note::
|
||||
|
||||
See `this tutorial <https://pytorch.org/tutorials/intermediate/ddp_tutorial.html>`__
|
||||
from PyTorch for general multi-GPU training with ``DistributedDataParallel``.
|
||||
|
||||
Distributed Process Group Initialization
|
||||
----------------------------------------
|
||||
|
||||
For communication between multiple processes in multi-gpu training, we need
|
||||
to start the distributed backend at the beginning of each process. We use
|
||||
`world_size` to refer to the number of processes and `rank` to refer to the
|
||||
process ID, which should be an integer from `0` to `world_size - 1`.
|
||||
"""
|
||||
|
||||
import torch.distributed as dist
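The tutorial text in this hunk describes partitioning the dataset into mutually exclusive subsets of similar size, one per GPU, and repeating the random partition every epoch. A rough standard-library sketch of that scheme follows. It is a simplified stand-in for what a distributed sampler does, not the code behind ``GraphDataLoader``; the `partition_indices` helper and the `seed + epoch` convention are assumptions.

```python
import random


def partition_indices(num_samples, world_size, rank, seed, epoch):
    """Return the sample indices assigned to `rank` for the given epoch."""
    indices = list(range(num_samples))
    # Every process shuffles with the same seed, so all ranks agree on the order.
    random.Random(seed + epoch).shuffle(indices)
    # Strided slicing gives mutually exclusive subsets of near-equal size.
    return indices[rank::world_size]


world_size = 4
parts = [partition_indices(10, world_size, r, seed=0, epoch=0)
         for r in range(world_size)]
# Changing the epoch changes the shuffle seed, so the partition is redrawn.
next_epoch = [partition_indices(10, world_size, r, seed=0, epoch=1)
              for r in range(world_size)]
print(parts)
```

Because each rank derives its subset from the same shuffled order, no communication is needed to agree on the partition. Each process only needs to know `rank`, `world_size`, and the current epoch.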
|
||||
|
||||
@@ -193,9 +191,7 @@ def main(rank, world_size, dataset, seed=0):
|
||||
optimizer = Adam(model.parameters(), lr=0.01)
|
||||
|
||||
train_loader, val_loader, test_loader = get_dataloaders(dataset,
|
||||
seed,
|
||||
world_size,
|
||||
rank)
|
||||
seed)
|
||||
for epoch in range(5):
|
||||
model.train()
|
||||
# The line below ensures all processes use a different
|
||||
|
||||