mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[Test] Fixing issues related to SampledHeteroCSC and StaticHeteroCSC deprecation. (#7799)
This commit is contained in:
@@ -9,7 +9,7 @@ from torch import nn
|
||||
from .cugraph_base import CuGraphBaseConv
|
||||
|
||||
try:
|
||||
from pylibcugraphops.pytorch import SampledHeteroCSC, StaticHeteroCSC
|
||||
from pylibcugraphops.pytorch import HeteroCSC
|
||||
from pylibcugraphops.pytorch.operators import (
|
||||
agg_hg_basis_n2n_post as RelGraphConvAgg,
|
||||
)
|
||||
@@ -188,27 +188,28 @@ class CuGraphRelGraphConv(CuGraphBaseConv):
|
||||
max_in_degree = g.in_degrees().max().item()
|
||||
|
||||
if max_in_degree < self.MAX_IN_DEGREE_MFG:
|
||||
_graph = SampledHeteroCSC(
|
||||
_graph = HeteroCSC(
|
||||
offsets,
|
||||
indices,
|
||||
edge_types_perm,
|
||||
max_in_degree,
|
||||
g.num_src_nodes(),
|
||||
self.num_rels,
|
||||
)
|
||||
else:
|
||||
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
|
||||
_graph = StaticHeteroCSC(
|
||||
_graph = HeteroCSC(
|
||||
offsets_fg,
|
||||
indices,
|
||||
edge_types_perm,
|
||||
g.num_src_nodes(),
|
||||
self.num_rels,
|
||||
)
|
||||
else:
|
||||
_graph = StaticHeteroCSC(
|
||||
_graph = HeteroCSC(
|
||||
offsets,
|
||||
indices,
|
||||
edge_types_perm,
|
||||
g.num_src_nodes(),
|
||||
self.num_rels,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user