mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Rename sampled_edge_ids for clarity. (#7704)
This commit is contained in:
committed by
GitHub
parent
1eb0f9c116
commit
bca5924296
@@ -599,7 +599,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
indices = C_sampled_subgraph.indices
|
||||
type_per_edge = C_sampled_subgraph.type_per_edge
|
||||
column = C_sampled_subgraph.original_column_node_ids
|
||||
sampled_edge_ids = C_sampled_subgraph.original_edge_ids
|
||||
edge_ids_in_fused_csc_sampling_graph = (
|
||||
C_sampled_subgraph.original_edge_ids
|
||||
)
|
||||
etype_offsets = C_sampled_subgraph.etype_offsets
|
||||
if etype_offsets is not None:
|
||||
etype_offsets = etype_offsets.tolist()
|
||||
@@ -610,17 +612,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
)
|
||||
original_edge_ids = (
|
||||
torch.ops.graphbolt.index_select(
|
||||
self.edge_attributes[ORIGINAL_EDGE_ID], sampled_edge_ids
|
||||
self.edge_attributes[ORIGINAL_EDGE_ID],
|
||||
edge_ids_in_fused_csc_sampling_graph,
|
||||
)
|
||||
if has_original_eids
|
||||
else sampled_edge_ids
|
||||
else edge_ids_in_fused_csc_sampling_graph
|
||||
)
|
||||
if type_per_edge is None and etype_offsets is None:
|
||||
# The sampled graph is already a homogeneous graph.
|
||||
sampled_csc = CSCFormatBase(indptr=indptr, indices=indices)
|
||||
if indices is not None:
|
||||
# Only needed to fetch indices.
|
||||
sampled_edge_ids = None
|
||||
edge_ids_in_fused_csc_sampling_graph = None
|
||||
else:
|
||||
offset = self._node_type_offset_list
|
||||
|
||||
@@ -660,9 +663,9 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
original_hetero_edge_ids[etype] = original_edge_ids[
|
||||
eids
|
||||
]
|
||||
sampled_hetero_edge_ids = None
|
||||
sampled_hetero_edge_ids_in_fused_csc_sampling_graph = None
|
||||
else:
|
||||
sampled_hetero_edge_ids = {}
|
||||
sampled_hetero_edge_ids_in_fused_csc_sampling_graph = {}
|
||||
edge_offsets = [0]
|
||||
for etype, etype_id in self.edge_type_to_id.items():
|
||||
src_ntype, _, dst_ntype = etype_str_to_tuple(etype)
|
||||
@@ -693,14 +696,18 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
]
|
||||
if indices is None:
|
||||
# Only needed to fetch indices.
|
||||
sampled_hetero_edge_ids[etype] = sampled_edge_ids[
|
||||
sampled_hetero_edge_ids_in_fused_csc_sampling_graph[
|
||||
etype
|
||||
] = edge_ids_in_fused_csc_sampling_graph[
|
||||
etype_offsets[etype_id] : etype_offsets[
|
||||
etype_id + 1
|
||||
]
|
||||
]
|
||||
|
||||
original_edge_ids = original_hetero_edge_ids
|
||||
sampled_edge_ids = sampled_hetero_edge_ids
|
||||
edge_ids_in_fused_csc_sampling_graph = (
|
||||
sampled_hetero_edge_ids_in_fused_csc_sampling_graph
|
||||
)
|
||||
sampled_csc = {
|
||||
etype: CSCFormatBase(
|
||||
indptr=sub_indptr[etype],
|
||||
@@ -711,7 +718,7 @@ class FusedCSCSamplingGraph(SamplingGraph):
|
||||
return SampledSubgraphImpl(
|
||||
sampled_csc=sampled_csc,
|
||||
original_edge_ids=original_edge_ids,
|
||||
_sampled_edge_ids=sampled_edge_ids,
|
||||
_edge_ids_in_fused_csc_sampling_graph=edge_ids_in_fused_csc_sampling_graph,
|
||||
)
|
||||
|
||||
def sample_neighbors(
|
||||
|
||||
@@ -357,21 +357,27 @@ class SamplePerLayer(MiniBatchTransformer):
|
||||
if isinstance(subgraph.sampled_csc, dict):
|
||||
for etype, pair in subgraph.sampled_csc.items():
|
||||
if pair.indices is None:
|
||||
edge_ids = subgraph._sampled_edge_ids[etype]
|
||||
edge_ids = (
|
||||
subgraph._edge_ids_in_fused_csc_sampling_graph[
|
||||
etype
|
||||
]
|
||||
)
|
||||
edge_ids.record_stream(torch.cuda.current_stream())
|
||||
pair.indices = record_stream(
|
||||
index_select(indices, edge_ids)
|
||||
)
|
||||
minibatch._indices_needs_offset_subtraction = True
|
||||
elif subgraph.sampled_csc.indices is None:
|
||||
subgraph._sampled_edge_ids.record_stream(
|
||||
subgraph._edge_ids_in_fused_csc_sampling_graph.record_stream(
|
||||
torch.cuda.current_stream()
|
||||
)
|
||||
subgraph.sampled_csc.indices = record_stream(
|
||||
index_select(indices, subgraph._sampled_edge_ids)
|
||||
index_select(
|
||||
indices, subgraph._edge_ids_in_fused_csc_sampling_graph
|
||||
)
|
||||
)
|
||||
minibatch._indices_needs_offset_subtraction = True
|
||||
subgraph._sampled_edge_ids = None
|
||||
subgraph._edge_ids_in_fused_csc_sampling_graph = None
|
||||
minibatch.wait = torch.cuda.current_stream().record_event().wait
|
||||
|
||||
return minibatch
|
||||
|
||||
@@ -46,7 +46,9 @@ class SampledSubgraphImpl(SampledSubgraph):
|
||||
original_row_node_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
|
||||
original_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
|
||||
# Used to fetch sampled_csc.indices if it is missing.
|
||||
_sampled_edge_ids: Union[Dict[str, torch.Tensor], torch.Tensor] = None
|
||||
_edge_ids_in_fused_csc_sampling_graph: Union[
|
||||
Dict[str, torch.Tensor], torch.Tensor
|
||||
] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.sampled_csc, dict):
|
||||
@@ -65,7 +67,10 @@ class SampledSubgraphImpl(SampledSubgraph):
|
||||
), "Node pair should be have indices of type torch.Tensor."
|
||||
else:
|
||||
assert isinstance(
|
||||
self._sampled_edge_ids.get(etype, None), torch.Tensor
|
||||
self._edge_ids_in_fused_csc_sampling_graph.get(
|
||||
etype, None
|
||||
),
|
||||
torch.Tensor,
|
||||
), "When indices is missing, sampled edge ids needs to be provided."
|
||||
else:
|
||||
assert self.sampled_csc.indptr is not None and isinstance(
|
||||
@@ -81,7 +86,7 @@ class SampledSubgraphImpl(SampledSubgraph):
|
||||
), "Node pair should have a torch.Tensor indices."
|
||||
else:
|
||||
assert isinstance(
|
||||
self._sampled_edge_ids, torch.Tensor
|
||||
self._edge_ids_in_fused_csc_sampling_graph, torch.Tensor
|
||||
), "When indices is missing, sampled edge ids needs to be provided."
|
||||
|
||||
def __repr__(self) -> str:
|
||||
@@ -95,7 +100,7 @@ def _sampled_subgraph_str(sampled_subgraph: SampledSubgraph, classname) -> str:
|
||||
attributes.reverse()
|
||||
|
||||
for name in attributes:
|
||||
if name in "_sampled_edge_ids":
|
||||
if name in "_edge_ids_in_fused_csc_sampling_graph":
|
||||
continue
|
||||
val = getattr(sampled_subgraph, name)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user