[Bugfix][nn/PyTorch]: add checks to avoid view/reshape (0, -1, *) on empty tensors (#7894)

This commit is contained in:
ikun
2025-08-01 07:53:10 +08:00
committed by GitHub
parent 743e65fa3b
commit 3d16000b41
4 changed files with 88 additions and 12 deletions

View File

@@ -368,10 +368,11 @@ class EdgeGATConv(nn.Module):
# Residual.
if self.res_fc is not None:
# Use -1 rather than self._num_heads to handle broadcasting.
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
if h_dst.numel() != 0:
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
# Bias.
if self.bias is not None:
rst = rst + self.bias.view(

View File

@@ -348,10 +348,11 @@ class GATConv(nn.Module):
# residual
if self.res_fc is not None:
# Use -1 rather than self._num_heads to handle broadcasting
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
if h_dst.numel() != 0:
resval = self.res_fc(h_dst).view(
*dst_prefix_shape, -1, self._out_feats
)
rst = rst + resval
# bias
if self.has_explicit_bias:
rst = rst + self.bias.view(

View File

@@ -320,10 +320,11 @@ class GATv2Conv(nn.Module):
rst = graph.dstdata["ft"]
# residual
if self.res_fc is not None:
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
if h_dst.numel() != 0:
resval = self.res_fc(h_dst).view(
h_dst.shape[0], -1, self._out_feats
)
rst = rst + resval
# activation
if self.activation:
rst = self.activation(rst)

View File

@@ -2680,3 +2680,76 @@ def test_SpatialEncoder(max_dist, num_kernels, num_heads):
encoding3d_2 = model_3(coord, node_type)
assert encoding3d_1.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
assert encoding3d_2.shape == (bsz, max_num_nodes, max_num_nodes, num_heads)
@pytest.mark.parametrize("residual", [True, False])
def test_conv_with_zero_nodes_bugfix_7894(residual):
"""Test for PR #7894 in DGL where HeteroGraphConv with zero nodes in a
specific node type would cause an error due to empty tensors.
This test ensures that GATConv, GATv2Conv, and EdgeGATConv can handle
such cases without raising errors.
"""
# Create a heterogeneous graph with zero nodes in the "tag" type
user_item_src = torch.tensor([0, 1, 2])
user_item_dst = torch.tensor([4, 5, 6])
user_tag_src = torch.tensor([], dtype=torch.int64)
user_tag_dst = torch.tensor([], dtype=torch.int64)
num_nodes_dict = {
"user": 5,
"item": 10,
"tag": 0,
}
data_dict = {
("user", "buys", "item"): (user_item_src, user_item_dst),
("user", "likes", "tag"): (user_tag_src, user_tag_dst),
}
g = dgl.heterograph(data_dict, num_nodes_dict=num_nodes_dict)
feat_dim = 16
node_features = {
"user": torch.randn(num_nodes_dict["user"], feat_dim),
"item": torch.randn(num_nodes_dict["item"], feat_dim),
"tag": torch.randn(num_nodes_dict["tag"], feat_dim),
}
edge_features = {
("user", "buys", "item"): torch.randn(g.num_edges(("user", "buys", "item")), feat_dim),
("user", "likes", "tag"): torch.randn(g.num_edges(("user", "likes", "tag")), feat_dim),
}
# Test GATConv with zero nodes in "tag" type
conv = nn.HeteroGraphConv({
("user", "buys", "item"): nn.GATConv(16, 2, num_heads=2, residual=residual),
("user", "likes", "tag"): nn.GATConv(16, 2, num_heads=2, residual=residual),
}, aggregate="sum")
out = conv(g, node_features)
assert out["item"].shape == (10, 2, 2)
assert out["tag"].shape == (0, 2, 2)
assert "user" not in out
# Test GATv2Conv with zero nodes in "tag" type
conv_v2 = nn.HeteroGraphConv({
("user", "buys", "item"): nn.GATv2Conv(16, 2, num_heads=2, residual=residual),
("user", "likes", "tag"): nn.GATv2Conv(16, 2, num_heads=2, residual=residual),
}, aggregate="sum")
out_v2 = conv_v2(g, node_features)
assert out_v2["item"].shape == (10, 2, 2)
assert out_v2["tag"].shape == (0, 2, 2)
assert "user" not in out_v2
# Test EdgeGATConv with zero nodes in "tag" type
edge_conv = nn.HeteroGraphConv({
("user", "buys", "item"): nn.EdgeGATConv(16, 16, 2, num_heads=2, residual=residual),
("user", "likes", "tag"): nn.EdgeGATConv(16, 16, 2, num_heads=2, residual=residual),
}, aggregate="sum")
mod_kwargs = {
"buys": {"edge_feat": edge_features[("user", "buys", "item")]},
"likes": {"edge_feat": edge_features[("user", "likes", "tag")]},
}
out_edge = edge_conv(g, node_features, mod_kwargs=mod_kwargs)
assert out_edge["item"].shape == (10, 2, 2)
assert out_edge["tag"].shape == (0, 2, 2)
assert "user" not in out_edge