mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[distGB] graphbolt graph edge's mask will be filled with 0 if these edges have no mask initial (#7846)
This commit is contained in:
@@ -311,7 +311,7 @@ class Collator(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def add_edge_attribute_to_graph(g, data_name):
|
||||
def add_edge_attribute_to_graph(g, data_name, gb_padding):
|
||||
"""Add data into the graph as an edge attribute.
|
||||
|
||||
For some cases such as prob/mask-based sampling on GraphBolt partitions,
|
||||
@@ -327,9 +327,11 @@ class Collator(ABC):
|
||||
The graph.
|
||||
data_name : str
|
||||
The name of data that's stored in DistGraph.ndata/edata.
|
||||
gb_padding : int, optional
|
||||
The padding value for GraphBolt partitions' new edge_attributes.
|
||||
"""
|
||||
if g._use_graphbolt and data_name:
|
||||
g.add_edge_attribute(data_name)
|
||||
g.add_edge_attribute(data_name, gb_padding)
|
||||
|
||||
|
||||
class NodeCollator(Collator):
|
||||
@@ -344,6 +346,11 @@ class NodeCollator(Collator):
|
||||
The node set to compute outputs.
|
||||
graph_sampler : dgl.dataloading.BlockSampler
|
||||
The neighborhood sampler.
|
||||
gb_padding : int, optional
|
||||
The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.
|
||||
e.g. prob/mask-based sampling.
|
||||
Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.
|
||||
The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -366,7 +373,7 @@ class NodeCollator(Collator):
|
||||
:doc:`Minibatch Training Tutorials <tutorials/large/L0_neighbor_sampling_overview>`.
|
||||
"""
|
||||
|
||||
def __init__(self, g, nids, graph_sampler):
|
||||
def __init__(self, g, nids, graph_sampler, gb_padding=1):
|
||||
self.g = g
|
||||
if not isinstance(nids, Mapping):
|
||||
assert (
|
||||
@@ -380,7 +387,7 @@ class NodeCollator(Collator):
|
||||
# Add prob/mask into graphbolt partition's edge attributes if needed.
|
||||
if hasattr(self.graph_sampler, "prob"):
|
||||
Collator.add_edge_attribute_to_graph(
|
||||
self.g, self.graph_sampler.prob
|
||||
self.g, self.graph_sampler.prob, gb_padding
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -508,8 +515,11 @@ class EdgeCollator(Collator):
|
||||
|
||||
A set of builtin negative samplers are provided in
|
||||
:ref:`the negative sampling module <api-dataloading-negative-sampling>`.
|
||||
|
||||
Examples
|
||||
gb_padding : int, optional
|
||||
The padding value for GraphBolt partitions' new edge_attributes if the attributes in DistGraph are None.
|
||||
e.g. prob/mask-based sampling.
|
||||
Only when the mask of one edge is set as 1, an edge will be sampled in dgl.graphbolt.FusedCSCSamplingGraph.sample_neighbors.
|
||||
The argument will be used in add_edge_attribute_to_graph to add new edge_attributes in graphbolt.
|
||||
--------
|
||||
The following example shows how to train a 3-layer GNN for edge classification on a
|
||||
set of edges ``train_eid`` on a homogeneous undirected graph. Each node takes
|
||||
@@ -612,6 +622,7 @@ class EdgeCollator(Collator):
|
||||
reverse_eids=None,
|
||||
reverse_etypes=None,
|
||||
negative_sampler=None,
|
||||
gb_padding=1,
|
||||
):
|
||||
self.g = g
|
||||
if not isinstance(eids, Mapping):
|
||||
@@ -642,7 +653,7 @@ class EdgeCollator(Collator):
|
||||
# Add prob/mask into graphbolt partition's edge attributes if needed.
|
||||
if hasattr(self.graph_sampler, "prob"):
|
||||
Collator.add_edge_attribute_to_graph(
|
||||
self.g, self.graph_sampler.prob
|
||||
self.g, self.graph_sampler.prob, gb_padding
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -143,15 +143,16 @@ def _copy_data_from_shared_mem(name, shape):
|
||||
class AddEdgeAttributeFromKVRequest(rpc.Request):
|
||||
"""Add edge attribute from kvstore to local GraphBolt partition."""
|
||||
|
||||
def __init__(self, name, kv_names):
|
||||
def __init__(self, name, kv_names, padding):
|
||||
self._name = name
|
||||
self._kv_names = kv_names
|
||||
self._padding = padding
|
||||
|
||||
def __getstate__(self):
|
||||
return self._name, self._kv_names
|
||||
return self._name, self._kv_names, self._padding
|
||||
|
||||
def __setstate__(self, state):
|
||||
self._name, self._kv_names = state
|
||||
self._name, self._kv_names, self._padding = state
|
||||
|
||||
def process_request(self, server_state):
|
||||
# For now, this is only used to add prob/mask data to the graph.
|
||||
@@ -169,7 +170,13 @@ class AddEdgeAttributeFromKVRequest(rpc.Request):
|
||||
gpb = server_state.partition_book
|
||||
# Initialize the edge attribute.
|
||||
num_edges = g.total_num_edges
|
||||
attr_data = torch.zeros(num_edges, dtype=data_type)
|
||||
|
||||
# Padding is used to fill missing edge attributes (e.g., 'prob' or 'mask') for certain edge types.
|
||||
# In DGLGraph, some edges may lack these attributes or have them set to None, but DGL will still sample these edges.
|
||||
# In contrast, GraphBolt samples edges based on specific attributes (e.g., 'mask' == 1) and will skip edges with missing attributes.
|
||||
# To ensure consistent sampling behavior in GraphBolt, we pad missing attributes with default values (e.g., 'mask' = 1),
|
||||
# allowing all edges to be sampled, even if their attributes were missing or None in DGLGraph.
|
||||
attr_data = torch.full((num_edges,), self._padding, dtype=data_type)
|
||||
# Map data from kvstore to the local partition for inner edges only.
|
||||
num_inner_edges = gpb.metadata()[gpb.partid]["num_edges"]
|
||||
homo_eids = g.edge_attributes[EID][:num_inner_edges]
|
||||
@@ -1620,13 +1627,15 @@ class DistGraph:
|
||||
edata_names.append(name)
|
||||
return edata_names
|
||||
|
||||
def add_edge_attribute(self, name):
|
||||
def add_edge_attribute(self, name, padding):
|
||||
"""Add an edge attribute into GraphBolt partition from edge data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
The name of the edge attribute.
|
||||
padding : int, optional
|
||||
The padding value for the new edge attribute.
|
||||
"""
|
||||
# Sanity checks.
|
||||
if not self._use_graphbolt:
|
||||
@@ -1643,7 +1652,7 @@ class DistGraph:
|
||||
]
|
||||
rpc.send_request(
|
||||
self._client._main_server_id,
|
||||
AddEdgeAttributeFromKVRequest(name, kv_names),
|
||||
AddEdgeAttributeFromKVRequest(name, kv_names, padding),
|
||||
)
|
||||
# Wait for the response.
|
||||
assert rpc.recv_response()._name == name
|
||||
|
||||
@@ -7,8 +7,9 @@ import traceback
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import backend as F
|
||||
import dgl
|
||||
|
||||
import dgl.backend as F
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
@@ -1858,6 +1859,81 @@ def test_local_sampling_heterograph(num_parts, use_graphbolt, prob_or_mask):
|
||||
)
|
||||
|
||||
|
||||
def check_hetero_dist_edge_dataloader_gb(
|
||||
tmpdir, num_server, use_graphbolt=True
|
||||
):
|
||||
generate_ip_config("rpc_ip_config.txt", num_server, num_server)
|
||||
|
||||
g = create_random_hetero()
|
||||
eids = torch.randperm(g.num_edges("r23"))[:10]
|
||||
mask = torch.zeros(g.num_edges("r23"), dtype=torch.bool)
|
||||
mask[eids] = True
|
||||
|
||||
num_parts = num_server
|
||||
|
||||
orig_nid_map, orig_eid_map = partition_graph(
|
||||
g,
|
||||
"test_sampling",
|
||||
num_parts,
|
||||
tmpdir,
|
||||
num_hops=1,
|
||||
part_method="metis",
|
||||
return_mapping=True,
|
||||
use_graphbolt=use_graphbolt,
|
||||
store_eids=True,
|
||||
)
|
||||
|
||||
part_config = tmpdir / "test_sampling.json"
|
||||
|
||||
pserver_list = []
|
||||
ctx = mp.get_context("spawn")
|
||||
for i in range(num_server):
|
||||
p = ctx.Process(
|
||||
target=start_server,
|
||||
args=(
|
||||
i,
|
||||
tmpdir,
|
||||
num_server > 1,
|
||||
"test_sampling",
|
||||
["csc", "coo"],
|
||||
True,
|
||||
),
|
||||
)
|
||||
p.start()
|
||||
time.sleep(1)
|
||||
pserver_list.append(p)
|
||||
|
||||
dgl.distributed.initialize("rpc_ip_config.txt", use_graphbolt=True)
|
||||
dist_graph = DistGraph("test_sampling", part_config=part_config)
|
||||
|
||||
os.environ["DGL_DIST_DEBUG"] = "1"
|
||||
|
||||
edges = {("n2", "r23", "n3"): eids}
|
||||
sampler = dgl.dataloading.MultiLayerNeighborSampler([10, 10], mask="mask")
|
||||
loader = dgl.dataloading.DistEdgeDataLoader(
|
||||
dist_graph, edges, sampler, batch_size=64
|
||||
)
|
||||
dgl.distributed.exit_client()
|
||||
for p in pserver_list:
|
||||
p.join()
|
||||
assert p.exitcode == 0
|
||||
|
||||
block = next(iter(loader))[2][0]
|
||||
assert block.num_src_nodes("n1") > 0
|
||||
assert block.num_edges("r12") > 0
|
||||
assert block.num_edges("r13") > 0
|
||||
assert block.num_edges("r23") > 0
|
||||
|
||||
|
||||
def test_hetero_dist_edge_dataloader_gb(
|
||||
num_server=1,
|
||||
):
|
||||
reset_envs()
|
||||
os.environ["DGL_DIST_MODE"] = "distributed"
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
check_hetero_dist_edge_dataloader_gb(Path(tmpdirname), num_server)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import tempfile
|
||||
|
||||
|
||||
Reference in New Issue
Block a user