[Test] Fixing issues related to SampledHeteroCSC and StaticHeteroCSC deprecation. (#7799)

This commit is contained in:
Andrei Ivanov
2025-01-05 18:34:28 -08:00
committed by GitHub
parent 88f109f173
commit 275183b16b

View File

@@ -9,7 +9,7 @@ from torch import nn
from .cugraph_base import CuGraphBaseConv
try:
from pylibcugraphops.pytorch import SampledHeteroCSC, StaticHeteroCSC
from pylibcugraphops.pytorch import HeteroCSC
from pylibcugraphops.pytorch.operators import (
agg_hg_basis_n2n_post as RelGraphConvAgg,
)
@@ -188,27 +188,28 @@ class CuGraphRelGraphConv(CuGraphBaseConv):
max_in_degree = g.in_degrees().max().item()
if max_in_degree < self.MAX_IN_DEGREE_MFG:
_graph = SampledHeteroCSC(
_graph = HeteroCSC(
offsets,
indices,
edge_types_perm,
max_in_degree,
g.num_src_nodes(),
self.num_rels,
)
else:
offsets_fg = self.pad_offsets(offsets, g.num_src_nodes() + 1)
_graph = StaticHeteroCSC(
_graph = HeteroCSC(
offsets_fg,
indices,
edge_types_perm,
g.num_src_nodes(),
self.num_rels,
)
else:
_graph = StaticHeteroCSC(
_graph = HeteroCSC(
offsets,
indices,
edge_types_perm,
g.num_src_nodes(),
self.num_rels,
)