[GraphBolt] Rename sampled_edge_ids for clarity. (#7704)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-16 00:10:25 -04:00
committed by GitHub
parent 1eb0f9c116
commit bca5924296
3 changed files with 35 additions and 17 deletions

View File

@@ -599,7 +599,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
indices = C_sampled_subgraph.indices
type_per_edge = C_sampled_subgraph.type_per_edge
column = C_sampled_subgraph.original_column_node_ids
sampled_edge_ids = C_sampled_subgraph.original_edge_ids
edge_ids_in_fused_csc_sampling_graph = (
C_sampled_subgraph.original_edge_ids
)
etype_offsets = C_sampled_subgraph.etype_offsets
if etype_offsets is not None:
etype_offsets = etype_offsets.tolist()
@@ -610,17 +612,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
)
original_edge_ids = (
torch.ops.graphbolt.index_select(
self.edge_attributes[ORIGINAL_EDGE_ID], sampled_edge_ids
self.edge_attributes[ORIGINAL_EDGE_ID],
edge_ids_in_fused_csc_sampling_graph,
)
if has_original_eids
else sampled_edge_ids
else edge_ids_in_fused_csc_sampling_graph
)
if type_per_edge is None and etype_offsets is None:
# The sampled graph is already a homogeneous graph.
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
if indices is not None:
# Only needed to fetch indices.
sampled_edge_ids = None
edge_ids_in_fused_csc_sampling_graph = None
else:
offset = self._node_type_offset_list
@@ -660,9 +663,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
original_hetero_edge_ids[etype] = original_edge_ids[
eids
]
sampled_hetero_edge_ids = None
sampled_hetero_edge_ids_in_fused_csc_sampling_graph = None
else:
sampled_hetero_edge_ids = {}
sampled_hetero_edge_ids_in_fused_csc_sampling_graph = {}
edge_offsets = [0]
for etype, etype_id in self.edge_type_to_id.items():
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
@@ -693,14 +696,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
]
if indices is None:
# Only needed to fetch indices.
sampled_hetero_edge_ids[etype] = sampled_edge_ids[
sampled_hetero_edge_ids_in_fused_csc_sampling_graph[
etype
] = edge_ids_in_fused_csc_sampling_graph[
etype_offsets[etype_id] : etype_offsets[
etype_id + 1
]
]
original_edge_ids = original_hetero_edge_ids
sampled_edge_ids = sampled_hetero_edge_ids
edge_ids_in_fused_csc_sampling_graph = (
sampled_hetero_edge_ids_in_fused_csc_sampling_graph
)
sampled_csc = {
etype: CSCFormatBase(
indptr=sub_indptr[etype],
@@ -711,7 +718,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
return SampledSubgraphImpl(
sampled_csc=sampled_csc,
original_edge_ids=original_edge_ids,
_sampled_edge_ids=sampled_edge_ids,
_edge_ids_in_fused_csc_sampling_graph=edge_ids_in_fused_csc_sampling_graph,
)
def sample_neighbors(

View File

@@ -357,21 +357,27 @@ class SamplePerLayer(MiniBatchTransformer):
if isinstance(subgraph.sampled_csc, dict):
for etype, pair in subgraph.sampled_csc.items():
if pair.indices is None:
edge_ids = subgraph._sampled_edge_ids[etype]
edge_ids = (
subgraph._edge_ids_in_fused_csc_sampling_graph[
etype
]
)
edge_ids.record_stream(torch.cuda.current_stream())
pair.indices = record_stream(
index_select(indices, edge_ids)
)
minibatch._indices_needs_offset_subtraction = True
elif subgraph.sampled_csc.indices is None:
subgraph._sampled_edge_ids.record_stream(
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
torch.cuda.current_stream()
)
subgraph.sampled_csc.indices = record_stream(
index_select(indices, subgraph._sampled_edge_ids)
index_select(
indices, subgraph._edge_ids_in_fused_csc_sampling_graph
)
)
minibatch._indices_needs_offset_subtraction = True
subgraph._sampled_edge_ids = None
subgraph._edge_ids_in_fused_csc_sampling_graph = None
minibatch.wait = torch.cuda.current_stream().record_event().wait
return minibatch

View File

@@ -46,7 +46,9 @@ class SampledSubgraphImpl(SampledSubgraph):
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
# Used to fetch sampled_csc.indices if it is missing.
_sampled_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
_edge_ids_in_fused_csc_sampling_graph: Union[
Dict[str, torch.Tensor], torch.Tensor
] = None
def __post_init__(self):
if isinstance(self.sampled_csc, dict):
@@ -65,7 +67,10 @@ class SampledSubgraphImpl(SampledSubgraph):
), "Node pair should be have indices of type torch.Tensor."
else:
assert isinstance(
self._sampled_edge_ids.get(etype, None), torch.Tensor
self._edge_ids_in_fused_csc_sampling_graph.get(
etype, None
),
torch.Tensor,
), "When indices is missing, sampled edge ids needs to be provided."
else:
assert self.sampled_csc.indptr is not None and isinstance(
@@ -81,7 +86,7 @@ class SampledSubgraphImpl(SampledSubgraph):
), "Node pair should have a torch.Tensor indices."
else:
assert isinstance(
self._sampled_edge_ids, torch.Tensor
self._edge_ids_in_fused_csc_sampling_graph, torch.Tensor
), "When indices is missing, sampled edge ids needs to be provided."
def __repr__(self) -> str:
@@ -95,7 +100,7 @@ def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
attributes.reverse()
for name in attributes:
if name in "_sampled_edge_ids":
if name in "_edge_ids_in_fused_csc_sampling_graph":
continue
val = getattr(sampled_subgraph, name)