mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[Bugfix][nn/PyTorch]: add checks to avoid view/reshape (0, -1, *) on empty tensors (#7894)
This commit is contained in:
@@ -368,10 +368,11 @@ class EdgeGATConv(nn.Module):
|
||||
# Residual.
|
||||
if self.res_fc is not None:
|
||||
# Use -1 rather than self._num_heads to handle broadcasting.
|
||||
resval = self.res_fc(h_dst).view(
|
||||
*dst_prefix_shape, -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
if h_dst.numel() != 0:
|
||||
resval = self.res_fc(h_dst).view(
|
||||
*dst_prefix_shape, -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
# Bias.
|
||||
if self.bias is not None:
|
||||
rst = rst + self.bias.view(
|
||||
|
||||
@@ -348,10 +348,11 @@ class GATConv(nn.Module):
|
||||
# residual
|
||||
if self.res_fc is not None:
|
||||
# Use -1 rather than self._num_heads to handle broadcasting
|
||||
resval = self.res_fc(h_dst).view(
|
||||
*dst_prefix_shape, -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
if h_dst.numel() != 0:
|
||||
resval = self.res_fc(h_dst).view(
|
||||
*dst_prefix_shape, -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
# bias
|
||||
if self.has_explicit_bias:
|
||||
rst = rst + self.bias.view(
|
||||
|
||||
@@ -320,10 +320,11 @@ class GATv2Conv(nn.Module):
|
||||
rst = graph.dstdata["ft"]
|
||||
# residual
|
||||
if self.res_fc is not None:
|
||||
resval = self.res_fc(h_dst).view(
|
||||
h_dst.shape[0], -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
if h_dst.numel() != 0:
|
||||
resval = self.res_fc(h_dst).view(
|
||||
h_dst.shape[0], -1, self._out_feats
|
||||
)
|
||||
rst = rst + resval
|
||||
# activation
|
||||
if self.activation:
|
||||
rst = self.activation(rst)
|
||||
|
||||
@@ -2680,3 +2680,76 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
|
||||
encoding3d_2 = model_3(coord, node_type)
|
||||
assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
|
||||
assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("residual", [True, False])
def test_conv_with_zero_nodes_bugfix_7894(residual):
    """Regression test for PR #7894: attention convs on an empty node type.

    Builds a heterograph whose "tag" node type has zero nodes (so the
    "likes" relation has zero edges) and runs GATConv, GATv2Conv and
    EdgeGATConv through HeteroGraphConv. Each forward pass must finish
    without raising and yield the expected output shapes, including a
    zero-length leading dimension for "tag".

    Parameters
    ----------
    residual : bool
        Forwarded to every conv module's ``residual`` argument.
    """
    buys = ("user", "buys", "item")
    likes = ("user", "likes", "tag")
    relations = (buys, likes)

    # "tag" is deliberately empty: zero nodes and zero incoming edges.
    num_nodes_dict = {"user": 5, "item": 10, "tag": 0}
    data_dict = {
        buys: (torch.tensor([0, 1, 2]), torch.tensor([4, 5, 6])),
        likes: (
            torch.tensor([], dtype=torch.int64),
            torch.tensor([], dtype=torch.int64),
        ),
    }
    g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)

    feat_dim = 16
    # Dict order (user, item, tag) fixes the order of the randn calls.
    node_features = {
        ntype: torch.randn(count, feat_dim)
        for ntype, count in num_nodes_dict.items()
    }
    edge_features = {
        rel: torch.randn(g.num_edges(rel), feat_dim) for rel in relations
    }

    def check_output(result):
        # Only destination node types ("item", "tag") appear in the output.
        assert result["item"].shape == (10, 2, 2)
        assert result["tag"].shape == (0, 2, 2)
        assert "user" not in result

    # Test GATConv with zero nodes in "tag" type
    conv = nn.HeteroGraphConv(
        {
            rel: nn.GATConv(feat_dim, 2, num_heads=2, residual=residual)
            for rel in relations
        },
        aggregate="sum",
    )
    check_output(conv(g, node_features))

    # Test GATv2Conv with zero nodes in "tag" type
    conv_v2 = nn.HeteroGraphConv(
        {
            rel: nn.GATv2Conv(feat_dim, 2, num_heads=2, residual=residual)
            for rel in relations
        },
        aggregate="sum",
    )
    check_output(conv_v2(g, node_features))

    # Test EdgeGATConv with zero nodes in "tag" type
    edge_conv = nn.HeteroGraphConv(
        {
            rel: nn.EdgeGATConv(
                feat_dim, feat_dim, 2, num_heads=2, residual=residual
            )
            for rel in relations
        },
        aggregate="sum",
    )
    # HeteroGraphConv looks up per-module kwargs by edge type name.
    mod_kwargs = {
        rel[1]: {"edge_feat": edge_features[rel]} for rel in relations
    }
    check_output(edge_conv(g, node_features, mod_kwargs=mod_kwargs))
|
||||
|
||||
Reference in New Issue
Block a user