mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
* update * graph Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
145 lines
4.4 KiB
Python
145 lines
4.4 KiB
Python
"""
|
|
Hypergraph Convolution and Hypergraph Attention
|
|
(https://arxiv.org/pdf/1901.08150.pdf).
|
|
"""
|
|
import argparse
|
|
|
|
import dgl.sparse as dglsp
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import tqdm
|
|
from dgl.data import CoraGraphDataset
|
|
from torchmetrics.functional import accuracy
|
|
|
|
|
|
def hypergraph_laplacian(H):
    """Build the normalized hypergraph propagation operator.

    Computes ``D_V^{-1/2} H W D_E^{-1} H^T D_V^{-1/2}`` with unit hyperedge
    weights (``W = I``), where ``D_V`` and ``D_E`` are the diagonal matrices of
    row sums (node degrees) and column sums (hyperedge degrees) of ``H``.
    Despite the function name, this is the smoothing operator used for message
    passing, not ``I`` minus it.

    Args:
        H: ``dgl.sparse`` incidence matrix whose rows index nodes and whose
            columns index hyperedges. Its stored values are summed as
            (possibly non-binary) incidence weights.

    Returns:
        A ``dgl.sparse`` matrix of shape (num_nodes, num_nodes).
    """
    ###########################################################
    # (HIGHLIGHT) Compute the Laplacian with Sparse Matrix API
    ###########################################################
    d_V = H.sum(1)  # node degree
    d_E = H.sum(0)  # edge degree
    D_V_invsqrt = dglsp.diag(d_V**-0.5)  # D_V ** (-1/2)
    D_E_inv = dglsp.diag(d_E**-1)  # D_E ** (-1)
    # The hyperedge weight matrix W is the identity here, so the paper's
    # "@ W" factor is omitted: it would cost an extra sparse matmul while
    # leaving the product unchanged.
    return D_V_invsqrt @ H @ D_E_inv @ H.T @ D_V_invsqrt
|
|
|
|
|
|
class HypergraphAttention(nn.Module):
    """Hypergraph Attention module as in the paper
    `Hypergraph Convolution and Hypergraph Attention
    <https://arxiv.org/pdf/1901.08150.pdf>`_.

    For every stored entry (node ``v``, hyperedge ``e``) of the incidence
    matrix, a score ``LeakyReLU_0.2(a([P x_v ; P x_e]))`` is computed. The
    scores replace the incidence values and are normalized with
    ``SparseMatrix.softmax``. The re-weighted incidence matrix then drives
    ``hypergraph_laplacian`` to propagate the projected node features.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        # One projection shared by node features and hyperedge features.
        self.P = nn.Linear(in_size, out_size)
        # Maps a concatenated (node, hyperedge) pair to a scalar score.
        self.a = nn.Linear(2 * out_size, 1)

    def forward(self, H, X, X_edges):
        """Apply one attention-weighted hypergraph convolution.

        Args:
            H: ``dgl.sparse`` incidence matrix (rows: nodes, cols: hyperedges).
            X: Node features; row ``i`` belongs to node ``i``.
            X_edges: Hyperedge features; row ``j`` belongs to hyperedge ``j``.

        Returns:
            Propagated node features with ``out_size`` columns.
        """
        node_proj = self.P(X)
        edge_proj = self.P(X_edges)
        # Gather the endpoints of every stored incidence entry.
        pairs = torch.cat((node_proj[H.row], edge_proj[H.col]), dim=1)
        scores = F.leaky_relu(self.a(pairs), 0.2).squeeze(1)
        # Reassign the hypergraph new weights.
        attention = dglsp.val_like(H, scores).softmax()
        return hypergraph_laplacian(attention) @ node_proj
|
|
|
|
|
|
class Net(nn.Module):
    """Two-layer hypergraph attention network for node classification.

    The same feature matrix is passed as both node and hyperedge features.
    ``load_data`` builds one hyperedge per node, so both use the same index
    space.
    """

    def __init__(self, in_size, out_size, hidden_size=16):
        super().__init__()
        self.layer1 = HypergraphAttention(in_size, hidden_size)
        self.layer2 = HypergraphAttention(hidden_size, out_size)

    def forward(self, H, X):
        """Return per-node class logits for incidence matrix ``H``."""
        hidden = F.elu(self.layer1(H, X, X))
        return self.layer2(H, hidden, hidden)
|
|
|
|
|
|
def train(model, optimizer, H, X, Y, train_mask):
    """Run one full-batch optimization step on the training nodes.

    Args:
        model: Module called as ``model(H, X)`` that returns per-node logits.
        optimizer: Optimizer over ``model``'s parameters.
        H, X: Incidence matrix and node features forwarded to ``model``.
        Y: Integer class labels per node.
        train_mask: Boolean mask that selects the nodes contributing to the loss.

    Returns:
        The cross-entropy loss of this step as a Python float.
    """
    model.train()
    # Clearing gradients before the forward pass is equivalent to clearing
    # them just before backward: the forward pass does not touch .grad.
    optimizer.zero_grad()
    masked_logits = model(H, X)[train_mask]
    masked_labels = Y[train_mask]
    step_loss = F.cross_entropy(masked_logits, masked_labels)
    step_loss.backward()
    optimizer.step()
    return step_loss.item()
|
|
|
|
|
|
def evaluate(model, H, X, Y, val_mask, test_mask, num_classes):
    """Compute validation and test accuracy of ``model``.

    Args:
        model: Module called as ``model(H, X)`` that returns per-node logits.
        H, X: Incidence matrix and node features forwarded to ``model``.
        Y: Integer class labels per node.
        val_mask, test_mask: Boolean masks that select the evaluated nodes.
        num_classes: Number of classes, passed to torchmetrics ``accuracy``.

    Returns:
        ``(val_acc, test_acc)`` as returned by torchmetrics ``accuracy``.
    """
    model.eval()
    # Inference only: skip building an autograd graph for this forward pass.
    with torch.no_grad():
        Y_hat = model(H, X)

    def _masked_accuracy(mask):
        # Multiclass accuracy restricted to the nodes selected by ``mask``.
        return accuracy(
            Y_hat[mask], Y[mask], task="multiclass", num_classes=num_classes
        )

    return _masked_accuracy(val_mask), _masked_accuracy(test_mask)
|
|
|
|
|
|
def load_data():
    """Load Cora and turn it into a hypergraph.

    Returns:
        A tuple ``(H, X, Y, num_classes, train_mask, val_mask, test_mask)``.
        ``H`` is the ``dgl.sparse`` incidence matrix (graph adjacency plus
        identity). ``X`` and ``Y`` are the node features and labels, and the
        masks are the dataset's train/val/test node masks.
    """
    dataset = CoraGraphDataset()

    graph = dataset[0]
    # The paper created a hypergraph from the original graph. For each node in
    # the original graph, a hyperedge in the hypergraph is created to connect
    # its neighbors and itself. In this case, the incidence matrix of the
    # hypergraph is the same as the adjacency matrix of the original graph (with
    # self-loops).
    # We follow the paper and assume that the rows of the incidence matrix
    # are for nodes and the columns are for edges.
    num_nodes = graph.num_nodes()
    indices = torch.stack(graph.edges())
    # Pass the shape explicitly. Otherwise spmatrix infers it from the edge
    # indices, and trailing nodes without edges would be dropped, leaving H
    # smaller than the feature matrix.
    H = dglsp.spmatrix(indices, shape=(num_nodes, num_nodes))
    H = H + dglsp.identity(H.shape)

    X = graph.ndata["feat"]
    Y = graph.ndata["label"]
    train_mask = graph.ndata["train_mask"]
    val_mask = graph.ndata["val_mask"]
    test_mask = graph.ndata["test_mask"]
    return H, X, Y, dataset.num_classes, train_mask, val_mask, test_mask
|
|
|
|
|
|
def main(args):
    """Train the hypergraph attention network on Cora and print test accuracy.

    Args:
        args: Parsed command-line namespace. Only ``args.epochs`` is read.
    """
    H, X, Y, num_classes, train_mask, val_mask, test_mask = load_data()
    model = Net(X.shape[1], num_classes)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Remains None only if the training loop runs zero epochs.
    test_acc = None
    with tqdm.trange(args.epochs) as tq:
        for _ in tq:
            loss = train(model, optimizer, H, X, Y, train_mask)
            val_acc, test_acc = evaluate(
                model, H, X, Y, val_mask, test_mask, num_classes
            )
            tq.set_postfix(
                {
                    "Loss": f"{loss:.5f}",
                    "Val acc": f"{val_acc:.5f}",
                    "Test acc": f"{test_acc:.5f}",
                },
                refresh=False,
            )

    # With zero epochs nothing was evaluated above. Evaluate the untrained
    # model so the final report always has a value to print.
    if test_acc is None:
        _, test_acc = evaluate(model, H, X, Y, val_mask, test_mask, num_classes)

    print(f"Test acc: {test_acc:.3f}")
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, then run training.
    cli = argparse.ArgumentParser(description="Hypergraph Attention Example")
    cli.add_argument(
        "--epochs",
        type=int,
        default=500,
        help="Number of training epochs.",
    )
    main(cli.parse_args())
|