mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[NN][Model] GATv2 (#3473)
* [Model][Core] GATv2 * lint * gatv2conv.py * lint * lint * style and docs * lint * gatv2conv fix Co-authored-by: Shaked Brody shakedbr@campus.technion.ac.il <shakedbr@tangerine.cslcs.technion.ac.il> Co-authored-by: Mufei Li <mufeili1996@gmail.com>
This commit is contained in:
@@ -288,6 +288,8 @@ Take the survey [here](https://forms.gle/Ej3jHCocACmb49Gp8) and leave any feedba
|
||||
|
||||
1. [**GNNLens: A Visual Analytics Approach for Prediction Error Diagnosis of Graph Neural Networks**](https://arxiv.org/abs/2011.11048v5), *Zhihua Jin, Yong Wang, Qianwen Wang, Yao Ming, Tengfei Ma, Huamin Qu*
|
||||
|
||||
1. [**How Attentive are Graph Attention Networks?**](https://arxiv.org/pdf/2105.14491.pdf), *Shaked Brody, Uri Alon, Eran Yahav*, [code](https://github.com/tech-srl/how_attentive_are_gats)
|
||||
|
||||
</details>
|
||||
|
||||
## Contributing
|
||||
|
||||
@@ -45,7 +45,13 @@ GATConv
|
||||
:members: forward
|
||||
:show-inheritance:
|
||||
|
||||
|
||||
GATv2Conv
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
.. autoclass:: dgl.nn.pytorch.conv.GATv2Conv
|
||||
:members: forward
|
||||
:show-inheritance:
|
||||
|
||||
EGATConv
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ To quickly locate the examples of your interest, search for the tagged keywords
|
||||
- <a name="pct"></a> Guo et al. PCT: Point cloud transformer. [Paper link](http://arxiv.org/abs/2012.09688).
|
||||
- Example code: [PyTorch](../examples/pytorch/pointcloud/pct)
|
||||
- Tags: point cloud classification, point cloud part-segmentation
|
||||
- <a name='gatv2'></a> Brody et al. How Attentive are Graph Attention Networks? [Paper link](https://arxiv.org/abs/2105.14491).
|
||||
- Example code: [PyTorch](../examples/pytorch/gatv2)
|
||||
- Tags: graph attention, gat, gatv2, attention
|
||||
|
||||
## 2020
|
||||
- <a name="eeg-gcnn"></a> Wagh et al. EEG-GCNN: Augmenting Electroencephalogram-based Neurological Disease Diagnosis using a Domain-guided Graph Convolutional Neural Network. [Paper link](http://proceedings.mlr.press/v136/wagh20a.html).
|
||||
|
||||
40
examples/pytorch/gatv2/README.md
Normal file
40
examples/pytorch/gatv2/README.md
Normal file
@@ -0,0 +1,40 @@
|
||||
Graph Attention Networks v2 (GATv2)
|
||||
============
|
||||
|
||||
- Paper link: [How Attentive are Graph Attention Networks?](https://arxiv.org/pdf/2105.14491.pdf)
|
||||
- Author's code repo: [https://github.com/tech-srl/how_attentive_are_gats](https://github.com/tech-srl/how_attentive_are_gats).
|
||||
- Annotated implementation: [https://nn.labml.ai/graphs/gatv2/index.html](https://nn.labml.ai/graphs/gatv2/index.html)
|
||||
|
||||
Dependencies
|
||||
------------
|
||||
- torch
|
||||
- requests
|
||||
- sklearn
|
||||
|
||||
How to run
|
||||
----------
|
||||
|
||||
Run with one of the following commands:
|
||||
|
||||
```bash
|
||||
python3 train.py --dataset=cora
|
||||
```
|
||||
|
||||
```bash
|
||||
python3 train.py --dataset=citeseer
|
||||
```
|
||||
|
||||
```bash
|
||||
python3 train.py --dataset=pubmed
|
||||
```
|
||||
|
||||
Results
|
||||
-------
|
||||
|
||||
| Dataset | Test Accuracy |
|
||||
| -------- | ------------- |
|
||||
| Cora | 82.10 |
|
||||
| Citeseer | 70.00 |
|
||||
| Pubmed | 77.2 |
|
||||
|
||||
* All the accuracy numbers are obtained after 200 epochs.
|
||||
51
examples/pytorch/gatv2/gatv2.py
Normal file
51
examples/pytorch/gatv2/gatv2.py
Normal file
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
|
||||
References
|
||||
----------
|
||||
Paper: https://arxiv.org/pdf/2105.14491.pdf
|
||||
Author's code: https://github.com/tech-srl/how_attentive_are_gats
|
||||
"""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from dgl.nn import GATv2Conv
|
||||
|
||||
|
||||
class GATv2(nn.Module):
    """Multi-layer GATv2 network built by stacking ``GATv2Conv`` layers.

    The network consists of ``num_layers`` hidden attention layers whose
    per-head outputs are concatenated, followed by one output layer whose
    per-head outputs are averaged to produce the class logits.

    Parameters
    ----------
    num_layers : int
        Number of hidden ``GATv2Conv`` layers (the output layer is extra).
    in_dim : int
        Size of the input node features.
    num_hidden : int
        Per-head output size of every hidden layer.
    num_classes : int
        Per-head output size of the output layer.
    heads : sequence of int
        Number of attention heads per layer. Entry ``l`` is used for hidden
        layer ``l``; the last entry is used for the output layer, so the
        sequence is expected to hold ``num_layers + 1`` entries.
    activation : callable or None
        Activation applied by every hidden layer (not by the output layer).
    feat_drop : float
        Feature dropout rate passed to every layer.
    attn_drop : float
        Attention dropout rate passed to every layer.
    negative_slope : float
        Negative slope of the LeakyReLU used inside every layer.
    residual : bool
        Whether hidden layers after the first, and the output layer, use a
        residual connection. The first layer never does.
    """

    def __init__(self,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 heads,
                 activation,
                 feat_drop,
                 attn_drop,
                 negative_slope,
                 residual):
        super(GATv2, self).__init__()
        self.num_layers = num_layers
        self.gatv2_layers = nn.ModuleList()
        self.activation = activation
        # input projection (no residual)
        self.gatv2_layers.append(GATv2Conv(
            in_dim, num_hidden, heads[0],
            feat_drop, attn_drop, negative_slope, False, self.activation,
            bias=False, share_weights=True))
        # hidden layers
        for l in range(1, num_layers):
            # due to multi-head, the in_dim = num_hidden * num_heads
            self.gatv2_layers.append(GATv2Conv(
                num_hidden * heads[l-1], num_hidden, heads[l],
                feat_drop, attn_drop, negative_slope, residual, self.activation,
                bias=False, share_weights=True))
        # output projection
        self.gatv2_layers.append(GATv2Conv(
            num_hidden * heads[-2], num_classes, heads[-1],
            feat_drop, attn_drop, negative_slope, residual, None,
            bias=False, share_weights=True))

    def forward(self, g, inputs):
        """Compute class logits for every node of ``g``.

        Parameters
        ----------
        g : DGLGraph
            The graph the attention layers operate on.
        inputs : torch.Tensor
            Input node features.

        Returns
        -------
        torch.Tensor
            Logits obtained by averaging the output layer over its heads.
        """
        h = inputs
        for l in range(self.num_layers):
            # every GATv2Conv layer needs the graph as its first argument;
            # flatten(1) concatenates the per-head outputs
            h = self.gatv2_layers[l](g, h).flatten(1)
        # output projection: average over the output heads
        logits = self.gatv2_layers[-1](g, h).mean(1)
        return logits
|
||||
198
examples/pytorch/gatv2/train.py
Normal file
198
examples/pytorch/gatv2/train.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Graph Attention Networks v2 (GATv2) in DGL using SPMV optimization.
|
||||
Multiple heads are also batched together for faster training.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import numpy as np
|
||||
import time
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import dgl
|
||||
from dgl.data import register_data_args
|
||||
from dgl.data import CoraGraphDataset, CiteseerGraphDataset, PubmedGraphDataset
|
||||
|
||||
from gatv2 import GATv2
|
||||
|
||||
|
||||
class EarlyStopping:
    """Stop training once the tracked score has not improved for a while.

    The score passed to :meth:`step` is treated as "higher is better". Each
    time the score is at least as good as the best one seen so far, the
    model's ``state_dict`` is written to ``es_checkpoint.pt`` in the current
    working directory.

    Parameters
    ----------
    patience : int, optional
        Number of non-improving calls to :meth:`step` after which
        ``early_stop`` becomes ``True``. Defaults: ``10``.
    """

    def __init__(self, patience=10):
        self.best_score = None   # best score observed so far
        self.counter = 0         # non-improving steps since the last best
        self.early_stop = False  # latched once patience is exhausted
        self.patience = patience

    def step(self, acc, model):
        """Record a new score and report whether training should stop.

        Parameters
        ----------
        acc : float
            Score of the current epoch (higher is better).
        model : torch.nn.Module
            Model to checkpoint when the score does not get worse.

        Returns
        -------
        bool
            The current value of ``early_stop``.
        """
        got_worse = self.best_score is not None and acc < self.best_score
        if got_worse:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return self.early_stop

        is_first_score = self.best_score is None
        self.best_score = acc
        self.save_checkpoint(model)
        if not is_first_score:
            # a later non-worse score restarts the patience window
            self.counter = 0
        return self.early_stop

    def save_checkpoint(self, model):
        """Write the model's ``state_dict`` to ``es_checkpoint.pt``."""
        torch.save(model.state_dict(), 'es_checkpoint.pt')
|
||||
|
||||
def accuracy(logits, labels):
    """Return the fraction of rows whose highest-scoring class equals the label.

    Parameters
    ----------
    logits : torch.Tensor
        Per-class scores; the maximum is taken along dimension 1.
    labels : torch.Tensor
        Ground-truth class index for every row of ``logits``.

    Returns
    -------
    float
        Number of correct predictions divided by ``len(labels)``.
    """
    predicted = torch.max(logits, dim=1)[1]
    num_correct = (predicted == labels).sum().item()
    return num_correct * 1.0 / len(labels)
|
||||
|
||||
|
||||
def evaluate(model, g, features, labels, mask):
    """Compute the accuracy of ``model`` on the nodes selected by ``mask``.

    The model is switched to evaluation mode and run without gradient
    tracking; it is not switched back to training mode afterwards.

    Parameters
    ----------
    model : torch.nn.Module
        Model called as ``model(g, features)``.
    g : DGLGraph
        Graph passed to the model.
    features : torch.Tensor
        Node features passed to the model.
    labels : torch.Tensor
        Ground-truth labels for all nodes.
    mask : torch.Tensor
        Index/mask selecting the nodes to score.

    Returns
    -------
    float
        Accuracy on the selected nodes.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(g, features)[mask]
        selected_labels = labels[mask]
        return accuracy(selected_logits, selected_labels)
|
||||
|
||||
|
||||
def main(args):
    """Train a GATv2 model on a citation dataset and report test accuracy.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. The attributes read here are
        ``dataset``, ``gpu``, ``num_heads``, ``num_layers``,
        ``num_out_heads``, ``num_hidden``, ``in_drop``, ``attn_drop``,
        ``negative_slope``, ``residual``, ``early_stop``, ``lr``,
        ``weight_decay``, ``epochs`` and ``fastmode``.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of ``cora``, ``citeseer``, ``pubmed``.
    """
    # load and preprocess dataset
    if args.dataset == 'cora':
        data = CoraGraphDataset()
    elif args.dataset == 'citeseer':
        data = CiteseerGraphDataset()
    elif args.dataset == 'pubmed':
        data = PubmedGraphDataset()
    else:
        raise ValueError('Unknown dataset: {}'.format(args.dataset))

    g = data[0]
    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        g = g.int().to(args.gpu)

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    val_mask = g.ndata['val_mask']
    test_mask = g.ndata['test_mask']
    num_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
           train_mask.int().sum().item(),
           val_mask.int().sum().item(),
           test_mask.int().sum().item()))

    # add self loop (remove existing ones first so none is duplicated)
    g = dgl.remove_self_loop(g)
    g = dgl.add_self_loop(g)
    n_edges = g.number_of_edges()
    # create model: one head count per hidden layer plus one for the output layer
    heads = ([args.num_heads] * args.num_layers) + [args.num_out_heads]
    model = GATv2(args.num_layers,
                  num_feats,
                  args.num_hidden,
                  n_classes,
                  heads,
                  F.elu,
                  args.in_drop,
                  args.attn_drop,
                  args.negative_slope,
                  args.residual)
    print(model)
    if args.early_stop:
        stopper = EarlyStopping(patience=100)
    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    # per-epoch training durations; the first three epochs are not timed
    dur = []
    for epoch in range(args.epochs):
        model.train()
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(g, features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        train_acc = accuracy(logits[train_mask], labels[train_mask])

        if args.fastmode:
            # reuse the training-mode logits instead of a second forward pass
            val_acc = accuracy(logits[val_mask], labels[val_mask])
        else:
            # evaluate() expects the model first, then the graph
            val_acc = evaluate(model, g, features, labels, val_mask)
            if args.early_stop:
                if stopper.step(val_acc, model):
                    break

        # no timings exist during the untimed warm-up epochs; report NaN
        # without asking numpy for the mean of an empty list
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | TrainAcc {:.4f} |"
              " ValAcc {:.4f} | ETputs(KTEPS) {:.2f}".
              format(epoch, mean_dur, loss.item(), train_acc,
                     val_acc, n_edges / mean_dur / 1000))

    print()
    if args.early_stop:
        # restore the best checkpoint written by the early stopper
        model.load_state_dict(torch.load('es_checkpoint.pt'))
    acc = evaluate(model, g, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Command-line entry point: build the option parser, echo the parsed
    # options, then hand them to main().
    arg_parser = argparse.ArgumentParser(description='GAT')
    register_data_args(arg_parser)

    # device and training length
    arg_parser.add_argument(
        "--gpu", type=int, default=-1,
        help="which GPU to use. Set -1 to use CPU.")
    arg_parser.add_argument(
        "--epochs", type=int, default=200,
        help="number of training epochs")

    # model architecture
    arg_parser.add_argument(
        "--num-heads", type=int, default=8,
        help="number of hidden attention heads")
    arg_parser.add_argument(
        "--num-out-heads", type=int, default=1,
        help="number of output attention heads")
    arg_parser.add_argument(
        "--num-layers", type=int, default=1,
        help="number of hidden layers")
    arg_parser.add_argument(
        "--num-hidden", type=int, default=8,
        help="number of hidden units")
    arg_parser.add_argument(
        "--residual", action="store_true", default=False,
        help="use residual connection")

    # regularisation and optimisation
    arg_parser.add_argument(
        "--in-drop", type=float, default=0.7,
        help="input feature dropout")
    arg_parser.add_argument(
        "--attn-drop", type=float, default=0.7,
        help="attention dropout")
    arg_parser.add_argument(
        "--lr", type=float, default=0.005,
        help="learning rate")
    arg_parser.add_argument(
        "--weight-decay", type=float, default=5e-4,
        help="weight decay")
    arg_parser.add_argument(
        "--negative-slope", type=float, default=0.2,
        help="the negative slope of leaky relu")

    # training-loop switches
    arg_parser.add_argument(
        "--early-stop", action="store_true", default=False,
        help="indicates whether to use early stop or not")
    arg_parser.add_argument(
        "--fastmode", action="store_true", default=False,
        help="skip re-evaluate the validation set")

    args = arg_parser.parse_args()
    print(args)

    main(args)
|
||||
@@ -6,6 +6,7 @@ from .appnpconv import APPNPConv
|
||||
from .chebconv import ChebConv
|
||||
from .edgeconv import EdgeConv
|
||||
from .gatconv import GATConv
|
||||
from .gatv2conv import GATv2Conv
|
||||
from .egatconv import EGATConv
|
||||
from .ginconv import GINConv
|
||||
from .gmmconv import GMMConv
|
||||
@@ -25,8 +26,8 @@ from .dotgatconv import DotGatConv
|
||||
from .twirlsconv import TWIRLSConv, TWIRLSUnfoldingAndAttention
|
||||
from .gcn2conv import GCN2Conv
|
||||
|
||||
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'EGATConv', 'TAGConv', 'RelGraphConv',
|
||||
'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv', 'GMMConv',
|
||||
'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
|
||||
__all__ = ['GraphConv', 'EdgeWeightNorm', 'GATConv', 'GATv2Conv', 'EGATConv', 'TAGConv',
|
||||
'RelGraphConv', 'SAGEConv', 'SGConv', 'APPNPConv', 'GINConv', 'GatedGraphConv',
|
||||
'GMMConv', 'ChebConv', 'AGNNConv', 'NNConv', 'DenseGraphConv', 'DenseSAGEConv',
|
||||
'DenseChebConv', 'EdgeConv', 'AtomicConv', 'CFConv', 'DotGatConv', 'TWIRLSConv',
|
||||
'TWIRLSUnfoldingAndAttention', 'GCN2Conv']
|
||||
|
||||
312
python/dgl/nn/pytorch/conv/gatv2conv.py
Normal file
312
python/dgl/nn/pytorch/conv/gatv2conv.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Torch modules for graph attention networks v2 (GATv2)."""
|
||||
# pylint: disable= no-member, arguments-differ, invalid-name
|
||||
import torch as th
|
||||
from torch import nn
|
||||
|
||||
from .... import function as fn
|
||||
from ...functional import edge_softmax
|
||||
from ....base import DGLError
|
||||
from ..utils import Identity
|
||||
from ....utils import expand_as_pair
|
||||
|
||||
# pylint: enable=W0235
|
||||
class GATv2Conv(nn.Module):
    r"""

    Description
    -----------
    Apply GATv2 from
    `How Attentive are Graph Attention Networks? <https://arxiv.org/pdf/2105.14491.pdf>`__
    over an input signal.

    .. math::
        h_i^{(l+1)} = \sum_{j\in \mathcal{N}(i)} \alpha_{ij}^{l} W^{(l)}_{right} h_j^{(l)}

    where :math:`\alpha_{ij}` is the attention score between node :math:`i` and
    node :math:`j`:

    .. math::
        \alpha_{ij}^{l} &= \mathrm{softmax_i} (e_{ij}^{l})

        e_{ij}^{l} &= \vec{a}^T\mathrm{LeakyReLU}\left(
            W^{(l)}_{left} h_{i} + W^{(l)}_{right} h_{j}\right)

    Parameters
    ----------
    in_feats : int, or pair of ints
        Input feature size; i.e, the number of dimensions of :math:`h_i^{(l)}`.
        If the layer is to be applied to a unidirectional bipartite graph, `in_feats`
        specifies the input feature size on both the source and destination nodes.
        If a scalar is given, the source and destination node feature size
        would take the same value.
    out_feats : int
        Output feature size; i.e, the number of dimensions of :math:`h_i^{(l+1)}`.
    num_heads : int
        Number of heads in Multi-Head Attention.
    feat_drop : float, optional
        Dropout rate on feature. Defaults: ``0``.
    attn_drop : float, optional
        Dropout rate on attention weight. Defaults: ``0``.
    negative_slope : float, optional
        LeakyReLU angle of negative slope. Defaults: ``0.2``.
    residual : bool, optional
        If True, use residual connection. Defaults: ``False``.
    activation : callable activation function/layer or None, optional.
        If not None, applies an activation function to the updated node features.
        Default: ``None``.
    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.
    bias : bool, optional
        If set to :obj:`False`, the layer will not learn
        an additive bias. (default: :obj:`True`)
    share_weights : bool, optional
        If set to :obj:`True`, the same matrix for :math:`W_{left}` and :math:`W_{right}` in
        the above equations, will be applied to the source and the target node of every edge.
        The layers are only shared when ``in_feats`` is a single int; a pair of sizes
        always gets two separate linear layers.
        (default: :obj:`False`)

    Note
    ----
    Zero in-degree nodes will lead to invalid output value. This is because no message
    will be passed to those nodes, the aggregation function will be applied on empty input.
    A common practice to avoid this is to add a self-loop for each node in the graph if
    it is homogeneous, which can be achieved by:

    >>> g = ... # a DGLGraph
    >>> g = dgl.add_self_loop(g)

    Calling ``add_self_loop`` will not work for some graphs, for example, heterogeneous graph
    since the edge type can not be decided for self_loop edges. Set ``allow_zero_in_degree``
    to ``True`` for those cases to unblock the code and handle zero-in-degree nodes manually.
    A common practice to handle this is to filter out the nodes with zero-in-degree when use
    after conv.

    Examples
    --------
    >>> import dgl
    >>> import numpy as np
    >>> import torch as th
    >>> from dgl.nn import GATv2Conv

    >>> # Case 1: Homogeneous graph
    >>> g = dgl.graph(([0,1,2,3,2,5], [1,2,3,4,0,3]))
    >>> g = dgl.add_self_loop(g)
    >>> feat = th.ones(6, 10)
    >>> gatv2conv = GATv2Conv(10, 2, num_heads=3)
    >>> res = gatv2conv(g, feat)
    >>> res
    tensor([[[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]],
            [[ 1.9599,  1.0239],
            [ 3.2015, -0.5512],
            [ 2.3700, -2.2182]]], grad_fn=<GSpMMBackward>)

    >>> # Case 2: Unidirectional bipartite graph
    >>> u = [0, 1, 0, 0, 1]
    >>> v = [0, 1, 2, 3, 2]
    >>> g = dgl.heterograph({('A', 'r', 'B'): (u, v)})
    >>> u_feat = th.tensor(np.random.rand(2, 5).astype(np.float32))
    >>> v_feat = th.tensor(np.random.rand(4, 10).astype(np.float32))
    >>> gatv2conv = GATv2Conv((5,10), 2, 3)
    >>> res = gatv2conv(g, (u_feat, v_feat))
    >>> res
    tensor([[[-0.0935, -0.4273],
            [-1.1850,  0.1123],
            [-0.2002,  0.1155]],
            [[ 0.1908, -1.2095],
            [-0.0129,  0.6408],
            [-0.8135,  0.1157]],
            [[ 0.0596, -0.8487],
            [-0.5421,  0.4022],
            [-0.4805,  0.1156]],
            [[-0.0935, -0.4273],
            [-1.1850,  0.1123],
            [-0.2002,  0.1155]]], grad_fn=<GSpMMBackward>)
    """
    def __init__(self,
                 in_feats,
                 out_feats,
                 num_heads,
                 feat_drop=0.,
                 attn_drop=0.,
                 negative_slope=0.2,
                 residual=False,
                 activation=None,
                 allow_zero_in_degree=False,
                 bias=True,
                 share_weights=False):
        super(GATv2Conv, self).__init__()
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        if isinstance(in_feats, tuple):
            # distinct source/destination sizes: always two separate projections
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=bias)
            self.fc_dst = nn.Linear(
                self._in_dst_feats, out_feats * num_heads, bias=bias)
        else:
            self.fc_src = nn.Linear(
                self._in_src_feats, out_feats * num_heads, bias=bias)
            if share_weights:
                # same module object, so both sides use one weight matrix
                self.fc_dst = self.fc_src
            else:
                self.fc_dst = nn.Linear(
                    self._in_src_feats, out_feats * num_heads, bias=bias)
        self.attn = nn.Parameter(th.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(
                    self._in_dst_feats, num_heads * out_feats, bias=bias)
            else:
                # matching sizes: the input is added as-is; forward() views it
                # as (N, 1, out_feats) so it broadcasts over the heads
                self.res_fc = Identity()
        else:
            self.register_buffer('res_fc', None)
        self.activation = activation
        self.share_weights = share_weights
        self.bias = bias
        self.reset_parameters()

    def reset_parameters(self):
        """
        Description
        -----------
        Reinitialize learnable parameters.

        Note
        ----
        The fc weights :math:`W^{(l)}` and the attention weights are initialized
        using Xavier (Glorot) normal initialization with the ReLU gain.
        Biases, when present, are set to zero.
        """
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
        if self.bias:
            nn.init.constant_(self.fc_src.bias, 0)
        if not self.share_weights:
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
            if self.bias:
                nn.init.constant_(self.fc_dst.bias, 0)
        nn.init.xavier_normal_(self.attn, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
            if self.bias:
                nn.init.constant_(self.res_fc.bias, 0)

    def set_allow_zero_in_degree(self, set_value):
        r"""
        Description
        -----------
        Set allow_zero_in_degree flag.

        Parameters
        ----------
        set_value : bool
            The value to be set to the flag.
        """
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat, get_attention=False):
        r"""
        Description
        -----------
        Compute graph attention network layer.

        Parameters
        ----------
        graph : DGLGraph
            The graph.
        feat : torch.Tensor or pair of torch.Tensor
            If a torch.Tensor is given, the input feature of shape :math:`(N, D_{in})` where
            :math:`D_{in}` is size of input feature, :math:`N` is the number of nodes.
            If a pair of torch.Tensor is given, the pair must contain two tensors of shape
            :math:`(N_{in}, D_{in_{src}})` and :math:`(N_{out}, D_{in_{dst}})`.
        get_attention : bool, optional
            Whether to return the attention values. Default to False.

        Returns
        -------
        torch.Tensor
            The output feature of shape :math:`(N, H, D_{out})` where :math:`H`
            is the number of heads, and :math:`D_{out}` is size of output feature.
        torch.Tensor, optional
            The attention values of shape :math:`(E, H, 1)`, where :math:`E` is the number of
            edges. This is returned only when :attr:`get_attention` is ``True``.

        Raises
        ------
        DGLError
            If there are 0-in-degree nodes in the input graph, it will raise DGLError
            since no message will be passed to those nodes. This will cause invalid output.
            The error can be ignored by setting ``allow_zero_in_degree`` parameter to ``True``.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError('There are 0-in-degree nodes in the graph, '
                                   'output for those nodes will be invalid. '
                                   'This is harmful for some applications, '
                                   'causing silent performance regression. '
                                   'Adding self-loop on the input graph by '
                                   'calling `g = dgl.add_self_loop(g)` will resolve '
                                   'the issue. Setting ``allow_zero_in_degree`` '
                                   'to be `True` when constructing this module will '
                                   'suppress the check and let the code run.')

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = self.fc_src(h_src).view(
                    -1, self._num_heads, self._out_feats)
                if self.share_weights:
                    feat_dst = feat_src
                else:
                    feat_dst = self.fc_dst(h_src).view(
                        -1, self._num_heads, self._out_feats)
                if graph.is_block:
                    # Only the leading rows belong to destination nodes. Slice
                    # the destination projection itself (not feat_src) so that
                    # fc_dst is honoured when weights are not shared, and slice
                    # h_dst so the residual term lines up with the output rows.
                    num_dst = graph.number_of_dst_nodes()
                    feat_dst = feat_dst[:num_dst]
                    h_dst = h_dst[:num_dst]
            graph.srcdata.update({'el': feat_src})  # (num_src_nodes, num_heads, out_dim)
            graph.dstdata.update({'er': feat_dst})
            graph.apply_edges(fn.u_add_v('el', 'er', 'e'))
            e = self.leaky_relu(graph.edata.pop('e'))  # (num_edge, num_heads, out_dim)
            e = (e * self.attn).sum(dim=-1).unsqueeze(dim=2)  # (num_edge, num_heads, 1)
            # compute softmax
            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))  # (num_edge, num_heads, 1)
            # message passing
            graph.update_all(fn.u_mul_e('el', 'a', 'm'),
                             fn.sum('m', 'ft'))
            rst = graph.dstdata['ft']
            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            # activation
            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata['a']
            else:
                return rst
|
||||
@@ -564,6 +564,45 @@ def test_gat_conv_bi(g, idtype, out_dim, num_heads):
|
||||
_, a = gat(g, feat, get_attention=True)
|
||||
assert a.shape == (g.number_of_edges(), num_heads, 1)
|
||||
|
||||
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['homo', 'block-bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 5])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv(g, idtype, out_dim, num_heads):
    """Check GATv2Conv output/attention shapes, pickling and the residual path."""
    g = g.astype(idtype).to(F.ctx())
    ctx = F.ctx()
    gat = nn.GATv2Conv(5, out_dim, num_heads)
    feat = F.randn((g.number_of_src_nodes(), 5))
    gat = gat.to(ctx)
    h = gat(g, feat)

    # test pickle
    th.save(gat, tmp_buffer)

    assert h.shape == (g.number_of_dst_nodes(), num_heads, out_dim)
    _, a = gat(g, feat, get_attention=True)
    assert a.shape == (g.number_of_edges(), num_heads, 1)

    # test residual connection (must exercise GATv2Conv, not GATConv)
    gat = nn.GATv2Conv(5, out_dim, num_heads, residual=True)
    gat = gat.to(ctx)
    h = gat(g, feat)
|
||||
|
||||
@parametrize_dtype
@pytest.mark.parametrize('g', get_cases(['bipartite'], exclude=['zero-degree']))
@pytest.mark.parametrize('out_dim', [1, 2])
@pytest.mark.parametrize('num_heads', [1, 4])
def test_gatv2_conv_bi(g, idtype, out_dim, num_heads):
    """Check GATv2Conv shapes when given a (source, destination) feature pair."""
    g = g.astype(idtype).to(F.ctx())
    device = F.ctx()
    conv = nn.GATv2Conv(5, out_dim, num_heads)
    src_feat = F.randn((g.number_of_src_nodes(), 5))
    dst_feat = F.randn((g.number_of_dst_nodes(), 5))
    feat_pair = (src_feat, dst_feat)
    conv = conv.to(device)

    out = conv(g, feat_pair)
    expected_out_shape = (g.number_of_dst_nodes(), num_heads, out_dim)
    assert out.shape == expected_out_shape

    _, attn = conv(g, feat_pair, get_attention=True)
    expected_attn_shape = (g.number_of_edges(), num_heads, 1)
    assert attn.shape == expected_attn_shape
|
||||
|
||||
@parametrize_dtype
|
||||
@pytest.mark.parametrize('g', get_cases(['homo'], exclude=['zero-degree']))
|
||||
@pytest.mark.parametrize('out_node_feats', [1, 5])
|
||||
@@ -1159,6 +1198,7 @@ if __name__ == '__main__':
|
||||
test_rgcn_sorted()
|
||||
test_tagconv()
|
||||
test_gat_conv()
|
||||
test_gatv2_conv()
|
||||
test_egat_conv()
|
||||
test_sage_conv()
|
||||
test_sgc_conv()
|
||||
|
||||
Reference in New Issue
Block a user