[distGB] GraphBolt graph edges' mask will be filled with 0 if these edges have no initial mask (#7846)

This commit is contained in:
Wenxuan Cao
2025-01-09 19:52:10 +08:00
committed by GitHub
parent 540dd2ba4d
commit 17017c2899
3 changed files with 110 additions and 14 deletions

View File

@@ -311,7 +311,7 @@ class Collator(ABC):
raise NotImplementedError
@staticmethod
def add_edge_attribute_to_graph(g, data_name):
def add_edge_attribute_to_graph(g, data_name, gb_padding):
"""Add data into the graph as an edge attribute.
For some cases such as prob/mask-based sampling on GraphBolt partitions,
@@ -327,9 +327,11 @@ class Collator(ABC):
The graph.
data_name : str
The name of data that's stored in DistGraph.ndata/edata.
gb_padding : int
The padding value for GraphBolt partitions' new edge_attributes.
"""
if g._use_graphbolt and data_name:
g.add_edge_attribute(data_name)
g.add_edge_attribute(data_name, gb_padding)
class NodeCollator(Collator):
@@ -344,6 +346,11 @@ class NodeCollator(Collator):
The node set to compute outputs.
graph_sampler : dgl.dataloading.BlockSampler
The neighborhood sampler.
gb_padding : int, optional
The padding value used for GraphBolt partitions' new edge_attributes when the
corresponding attributes in DistGraph are None, e.g. for prob/mask-based sampling.
dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors samples an edge only when
its mask is set to 1.
This argument is passed to add_edge_attribute_to_graph to add the new
edge_attributes in GraphBolt. Default: 1.
Examples
--------
@@ -366,7 +373,7 @@ class NodeCollator(Collator):
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
"""
def __init__(self, g, nids, graph_sampler):
def __init__(self, g, nids, graph_sampler, gb_padding=1):
self.g = g
if not isinstance(nids, Mapping):
assert (
@@ -380,7 +387,7 @@ class NodeCollator(Collator):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)
@property
@@ -508,8 +515,11 @@ class EdgeCollator(Collator):
A set of builtin negative samplers are provided in
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
Examples
gb_padding : int, optional
The padding value used for GraphBolt partitions' new edge_attributes when the
corresponding attributes in DistGraph are None, e.g. for prob/mask-based sampling.
dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors samples an edge only when
its mask is set to 1.
This argument is passed to add_edge_attribute_to_graph to add the new
edge_attributes in GraphBolt. Default: 1.
--------
The following example shows how to train a 3-layer GNN for edge classification on a
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
@@ -612,6 +622,7 @@ class EdgeCollator(Collator):
reverse_eids=None,
reverse_etypes=None,
negative_sampler=None,
gb_padding=1,
):
self.g = g
if not isinstance(eids, Mapping):
@@ -642,7 +653,7 @@ class EdgeCollator(Collator):
# Add prob/mask into graphbolt partition's edge attributes if needed.
if hasattr(self.graph_sampler, "prob"):
Collator.add_edge_attribute_to_graph(
self.g, self.graph_sampler.prob
self.g, self.graph_sampler.prob, gb_padding
)
@property

View File

@@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
class AddEdgeAttributeFromKVRequest(rpc.Request):
"""Add edge attribute from kvstore to local GraphBolt partition."""
def __init__(self, name, kv_names):
def __init__(self, name, kv_names, padding):
self._name = name
self._kv_names = kv_names
self._padding = padding
def __getstate__(self):
return self._name, self._kv_names
return self._name, self._kv_names, self._padding
def __setstate__(self, state):
self._name, self._kv_names = state
self._name, self._kv_names, self._padding = state
def process_request(self, server_state):
# For now, this is only used to add prob/mask data to the graph.
@@ -169,7 +170,13 @@ class AddEdgeAttributeFromKVRequest(rpc.Request):
gpb = server_state.partition_book
# Initialize the edge attribute.
num_edges = g.total_num_edges
attr_data = torch.zeros(num_edges, dtype=data_type)
# Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
# In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.
# In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.
# To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),
# allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
# Map data from kvstore to the local partition for inner edges only.
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
homo_eids = g.edge_attributes[EID][:num_inner_edges]
@@ -1620,13 +1627,15 @@ class DistGraph:
edata_names.append(name)
return edata_names
def add_edge_attribute(self, name):
def add_edge_attribute(self, name, padding):
"""Add an edge attribute into GraphBolt partition from edge data.
Parameters
----------
name : str
The name of the edge attribute.
padding : int
The padding value for the new edge attribute.
"""
# Sanity checks.
if not self._use_graphbolt:
@@ -1643,7 +1652,7 @@ class DistGraph:
]
rpc.send_request(
self._client._main_server_id,
AddEdgeAttributeFromKVRequest(name, kv_names),
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
)
# Wait for the response.
assert rpc.recv_response()._name == name

View File

@@ -7,8 +7,9 @@ import traceback
import unittest
from pathlib import Path
import backend as F
import dgl
import dgl.backend as F
import numpy as np
import pytest
import torch
@@ -1858,6 +1859,81 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
)
def check_hetero_dist_edge_dataloader_gb(
tmpdir, num_server, use_graphbolt=True
):
# Partition a random heterograph, start `num_server` server processes, build a
# DistEdgeDataLoader over ten "r23" edges with a mask-based neighbor sampler
# (mask="mask"), and assert that the first block of the first batch contains
# "n1" source nodes and edges of types "r12", "r13" and "r23".
#
# tmpdir : directory for the partition output; must support the `/` operator.
# num_server : number of server processes (also used as the number of partitions).
# use_graphbolt : forwarded to partition_graph only; the client is always
# initialized with use_graphbolt=True below.
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
g = create_random_hetero()
# Ten randomly chosen "r23" edge IDs; these are the seed edges given to the loader.
eids = torch.randperm(g.num_edges("r23"))[:10]
# NOTE(review): `mask` is built here but is never attached to `g` or read again
# in this function -- confirm whether that is intentional.
mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
mask[eids] = True
num_parts = num_server
# NOTE(review): the returned ID mappings are not used afterwards in this function.
orig_nid_map, orig_eid_map = partition_graph(
g,
"test_sampling",
num_parts,
tmpdir,
num_hops=1,
part_method="metis",
return_mapping=True,
use_graphbolt=use_graphbolt,
store_eids=True,
)
# Presumably the partition config produced by partition_graph above -- TODO confirm.
part_config = tmpdir / "test_sampling.json"
pserver_list = []
# Servers are launched from a "spawn" multiprocessing context.
ctx = mp.get_context("spawn")
for i in range(num_server):
p = ctx.Process(
target=start_server,
args=(
i,
tmpdir,
num_server > 1,
"test_sampling",
["csc", "coo"],
True,
),
)
p.start()
# Pause for one second after starting each server process.
time.sleep(1)
pserver_list.append(p)
dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
dist_graph = DistGraph("test_sampling", part_config=part_config)
os.environ["DGL_DIST_DEBUG"] = "1"
edges = {("n2", "r23", "n3"): eids}
# Two-layer neighbor sampler with fanout 10 per layer, sampling by the edge
# attribute named "mask".
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
loader = dgl.dataloading.DistEdgeDataLoader(
dist_graph, edges, sampler, batch_size=64
)
# NOTE(review): the client exits and the servers are joined before the loader
# is iterated below -- confirm this ordering is intended.
dgl.distributed.exit_client()
for p in pserver_list:
p.join()
assert p.exitcode == 0
# Element [2][0] of the first batch; presumably the first block of the
# sampled blocks list -- TODO confirm against DistEdgeDataLoader's output.
block = next(iter(loader))[2][0]
assert block.num_src_nodes("n1") > 0
assert block.num_edges("r12") > 0
assert block.num_edges("r13") > 0
assert block.num_edges("r23") > 0
def test_hetero_dist_edge_dataloader_gb(
num_server=1,
):
# Test entry point: resets the environment, sets DGL_DIST_MODE to
# "distributed", and runs check_hetero_dist_edge_dataloader_gb inside a
# temporary directory that is removed when the `with` block exits.
reset_envs()
os.environ["DGL_DIST_MODE"] = "distributed"
with tempfile.TemporaryDirectory() as tmpdirname:
# The check function joins paths with `/`, so pass a pathlib.Path.
check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server)
if __name__ == "__main__":
import tempfile