From 3bc8e228fc87bb371d42cf97e8e8eb0159c5f8ae Mon Sep 17 00:00:00 2001 From: Wenxuan Cao <90617523+CfromBU@users.noreply.github.com> Date: Thu, 19 Sep 2024 17:05:11 +0800 Subject: [PATCH] [DistGB] enable dist partition pipeline to save FusedCSCSamplingGraph partition directly (#7728) Co-authored-by: Ubuntu Co-authored-by: Ubuntu Co-authored-by: Rhett Ying <85214957+Rhett-Ying@users.noreply.github.com> --- python/dgl/distributed/partition.py | 52 +- tests/tools/test_dist_partition_graphbolt.py | 1023 ++++++++++++++++++ tools/dispatch_data.py | 28 + tools/distpartitioning/convert_partition.py | 409 ++++++- tools/distpartitioning/data_proc_pipeline.py | 25 +- tools/distpartitioning/data_shuffle.py | 27 +- tools/distpartitioning/utils.py | 44 +- 7 files changed, 1520 insertions(+), 88 deletions(-) create mode 100644 tests/tools/test_dist_partition_graphbolt.py diff --git a/python/dgl/distributed/partition.py b/python/dgl/distributed/partition.py index 079ed8806a..48005ffb4d 100644 --- a/python/dgl/distributed/partition.py +++ b/python/dgl/distributed/partition.py @@ -1600,8 +1600,6 @@ def _save_graph_gb(part_config, part_id, csc_graph): def cast_various_to_minimum_dtype_gb( - graph, - part_meta, num_parts, indptr, indices, @@ -1610,25 +1608,43 @@ def cast_various_to_minimum_dtype_gb( ntypes, node_attributes, edge_attributes, + part_meta=None, + graph=None, + edge_count=None, + node_count=None, + tot_edge_count=None, + tot_node_count=None, ): """Cast various data to minimum dtype.""" + if graph is not None: + assert part_meta is not None + tot_edge_count = graph.num_edges() + tot_node_count = graph.num_nodes() + node_count = part_meta["num_nodes"] + edge_count = part_meta["num_edges"] + else: + assert tot_edge_count is not None + assert tot_node_count is not None + assert edge_count is not None + assert node_count is not None + # Cast 1: indptr. - indptr = _cast_to_minimum_dtype(graph.num_edges(), indptr) + indptr = _cast_to_minimum_dtype(tot_edge_count, indptr) # Cast 2: indices. - indices = _cast_to_minimum_dtype(graph.num_nodes(), indices) + indices = _cast_to_minimum_dtype(tot_node_count, indices) # Cast 3: type_per_edge. type_per_edge = _cast_to_minimum_dtype( len(etypes), type_per_edge, field=ETYPE ) # Cast 4: node/edge_attributes. predicates = { - NID: part_meta["num_nodes"], + NID: node_count, "part_id": num_parts, NTYPE: len(ntypes), - EID: part_meta["num_edges"], + EID: edge_count, ETYPE: len(etypes), - DGL2GB_EID: part_meta["num_edges"], - GB_DST_ID: part_meta["num_nodes"], + DGL2GB_EID: edge_count, + GB_DST_ID: node_count, } for attributes in [node_attributes, edge_attributes]: for key in attributes: @@ -1779,16 +1795,16 @@ def gb_convert_single_dgl_partition( ) indptr, indices, type_per_edge = cast_various_to_minimum_dtype_gb( - graph, - part_meta, - num_parts, - indptr, - indices, - type_per_edge, - etypes, - ntypes, - node_attributes, - edge_attributes, + graph=graph, + part_meta=part_meta, + num_parts=num_parts, + indptr=indptr, + indices=indices, + type_per_edge=type_per_edge, + etypes=etypes, + ntypes=ntypes, + node_attributes=node_attributes, + edge_attributes=edge_attributes, ) csc_graph = gb.fused_csc_sampling_graph( diff --git a/tests/tools/test_dist_partition_graphbolt.py b/tests/tools/test_dist_partition_graphbolt.py new file mode 100644 index 0000000000..81c16f8809 --- /dev/null +++ b/tests/tools/test_dist_partition_graphbolt.py @@ -0,0 +1,1023 @@ +import json +import os +import tempfile + +import dgl +import dgl.backend as F +import dgl.graphbolt as gb + +import numpy as np +import pyarrow.parquet as pq +import pytest +import torch +from dgl.data.utils import load_graphs, load_tensors +from dgl.distributed.partition import ( + _etype_str_to_tuple, + _etype_tuple_to_str, + _get_inner_edge_mask, + _get_inner_node_mask, + load_partition, + RESERVED_FIELD_DTYPE, +) + +from distpartitioning import array_readwriter +from distpartitioning.utils import generate_read_list +from pytest_utils import create_chunked_dataset + + +def _verify_metadata_gb(gpb, g, num_parts, part_id, part_sizes): + """ + check list: + make sure the number of nodes and edges is correct. + make sure the number of parts is correct. + make sure the number of nodes and edges in each part is corrcet. + """ + assert gpb._num_nodes() == g.num_nodes() + assert gpb._num_edges() == g.num_edges() + + assert gpb.num_partitions() == num_parts + gpb_meta = gpb.metadata() + assert len(gpb_meta) == num_parts + assert len(gpb.partid2nids(part_id)) == gpb_meta[part_id]["num_nodes"] + assert len(gpb.partid2eids(part_id)) == gpb_meta[part_id]["num_edges"] + part_sizes.append( + (gpb_meta[part_id]["num_nodes"], gpb_meta[part_id]["num_edges"]) + ) + + +def _verify_local_id_gb(part_g, part_id, gpb): + """ + check list: + make sure the type of local id is correct. + make sure local id have a right order. + """ + nid = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + local_nid = gpb.nid2localnid(nid, part_id) + assert F.dtype(local_nid) in (F.int64, F.int32) + assert np.all(F.asnumpy(local_nid) == np.arange(0, len(local_nid))) + eid = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + local_eid = gpb.eid2localeid(eid, part_id) + assert F.dtype(local_eid) in (F.int64, F.int32) + assert np.all(np.sort(F.asnumpy(local_eid)) == np.arange(0, len(local_eid))) + return local_nid, local_eid + + +def _verify_map_gb( + part_g, + part_id, + gpb, +): + """ + check list: + make sure the map node and its data type is correct. + """ + # Check the node map. + local_nodes = F.boolean_mask( + part_g.node_attributes[dgl.NID], + part_g.node_attributes["inner_node"], + ) + inner_node_index = F.nonzero_1d(part_g.node_attributes["inner_node"]) + mapping_nodes = gpb.partid2nids(part_id) + assert F.dtype(mapping_nodes) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_nodes)) == np.sort(F.asnumpy(mapping_nodes)) + ) + assert np.all( + F.asnumpy(inner_node_index) == np.arange(len(inner_node_index)) + ) + + # Check the edge map. + + local_edges = F.boolean_mask( + part_g.edge_attributes[dgl.EID], + part_g.edge_attributes["inner_edge"], + ) + inner_edge_index = F.nonzero_1d(part_g.edge_attributes["inner_edge"]) + mapping_edges = gpb.partid2eids(part_id) + assert F.dtype(mapping_edges) in (F.int32, F.int64) + assert np.all( + np.sort(F.asnumpy(local_edges)) == np.sort(F.asnumpy(mapping_edges)) + ) + assert np.all( + F.asnumpy(inner_edge_index) == np.arange(len(inner_edge_index)) + ) + return local_nodes, local_edges + + +def _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, +): + """ + check list: + make sure local id are correct. + make sure mapping id are correct. + """ + if store_inner_node and store_inner_edge and store_eids: + _verify_local_id_gb(part_g, part_id, gpb) + _verify_map_gb(part_g, part_id, gpb) + + +def _get_part_IDs(part_g): + # These are partition-local IDs. + num_columns = part_g.csc_indptr.diff() + part_src_ids = part_g.indices + part_dst_ids = torch.arange(part_g.total_num_nodes).repeat_interleave( + num_columns + ) + # These are reshuffled global homogeneous IDs. + part_src_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_src_ids) + part_dst_ids = F.gather_row(part_g.node_attributes[dgl.NID], part_dst_ids) + return part_src_ids, part_dst_ids + + +def _verify_node_type_ID_gb(part_g, gpb): + """ + check list: + make sure ntype id have correct data type + """ + part_src_ids, part_dst_ids = _get_part_IDs(part_g) + # These are reshuffled per-type IDs. + src_ntype_ids, part_src_ids = gpb.map_to_per_ntype(part_src_ids) + dst_ntype_ids, part_dst_ids = gpb.map_to_per_ntype(part_dst_ids) + # `IdMap` is in int64 by default. + assert src_ntype_ids.dtype == F.int64 + assert dst_ntype_ids.dtype == F.int64 + + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_ntype(F.tensor([0], F.int32)) + with pytest.raises(dgl.utils.internal.InconsistentDtypeException): + gpb.map_to_per_etype(F.tensor([0], F.int32)) + return ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) + + +def _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids, + part_src_ids, + part_dst_ids, + src_ntype=None, + dst_ntype=None, + etype=None, +): + """ + check list: + make sure orig edge id are correct after + """ + if src_ntype is not None and dst_ntype is not None: + orig_src_nid = orig_nids[src_ntype] + orig_dst_nid = orig_nids[dst_ntype] + else: + orig_src_nid = orig_nids + orig_dst_nid = orig_nids + orig_src_ids = F.gather_row(orig_src_nid, part_src_ids) + orig_dst_ids = F.gather_row(orig_dst_nid, part_dst_ids) + if etype is not None: + orig_eids = orig_eids[etype] + orig_eids1 = F.gather_row(orig_eids, part_eids) + orig_eids2 = g.edge_ids(orig_src_ids, orig_dst_ids, etype=etype) + assert len(orig_eids1) == len(orig_eids2) + assert np.all(F.asnumpy(orig_eids1) == F.asnumpy(orig_eids2)) + + +def _verify_orig_IDs_gb( + part_g, + gpb, + g, + is_homo=False, + part_src_ids=None, + part_dst_ids=None, + src_ntype_ids=None, + dst_ntype_ids=None, + orig_nids=None, + orig_eids=None, +): + """ + check list: + make sure orig edge id are correct. + make sure hetero ntype id are correct. + """ + part_eids = part_g.edge_attributes[dgl.EID] + if is_homo: + _verify_orig_edge_IDs_gb( + g, orig_nids, orig_eids, part_eids, part_src_ids, part_dst_ids + ) + local_orig_nids = orig_nids[part_g.node_attributes[dgl.NID]] + local_orig_eids = orig_eids[part_g.edge_attributes[dgl.EID]] + part_g.node_attributes["feats"] = F.gather_row( + g.ndata["feats"], local_orig_nids + ) + part_g.edge_attributes["feats"] = F.gather_row( + g.edata["feats"], local_orig_eids + ) + else: + etype_ids, part_eids = gpb.map_to_per_etype(part_eids) + # `IdMap` is in int64 by default. + assert etype_ids.dtype == F.int64 + + # These are original per-type IDs. + for etype_id, etype in enumerate(g.canonical_etypes): + part_src_ids1 = F.boolean_mask(part_src_ids, etype_ids == etype_id) + src_ntype_ids1 = F.boolean_mask( + src_ntype_ids, etype_ids == etype_id + ) + part_dst_ids1 = F.boolean_mask(part_dst_ids, etype_ids == etype_id) + dst_ntype_ids1 = F.boolean_mask( + dst_ntype_ids, etype_ids == etype_id + ) + part_eids1 = F.boolean_mask(part_eids, etype_ids == etype_id) + assert np.all(F.asnumpy(src_ntype_ids1 == src_ntype_ids1[0])) + assert np.all(F.asnumpy(dst_ntype_ids1 == dst_ntype_ids1[0])) + src_ntype = g.ntypes[F.as_scalar(src_ntype_ids1[0])] + dst_ntype = g.ntypes[F.as_scalar(dst_ntype_ids1[0])] + + _verify_orig_edge_IDs_gb( + g, + orig_nids, + orig_eids, + part_eids1, + part_src_ids1, + part_dst_ids1, + src_ntype, + dst_ntype, + etype, + ) + + +def _verify_constructed_id_gb(part_sizes, gpb): + """ + verify the part id of each node by constructed nids. + check list: + make sure each node' part id and its type are corect + """ + node_map = [] + edge_map = [] + for part_i, (num_nodes, num_edges) in enumerate(part_sizes): + node_map.append(np.ones(num_nodes) * part_i) + edge_map.append(np.ones(num_edges) * part_i) + node_map = np.concatenate(node_map) + edge_map = np.concatenate(edge_map) + nid2pid = gpb.nid2partid(F.arange(0, len(node_map))) + assert F.dtype(nid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(nid2pid) == node_map) + eid2pid = gpb.eid2partid(F.arange(0, len(edge_map))) + assert F.dtype(eid2pid) in (F.int32, F.int64) + assert np.all(F.asnumpy(eid2pid) == edge_map) + + +def _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, +): + # verify local id and mapping id + _verify_local_and_map_id_gb( + part_g, + part_id, + gpb, + store_inner_node, + store_inner_edge, + store_eids, + ) + + # Verify the mapping between the reshuffled IDs and the original IDs. + ( + part_src_ids, + part_dst_ids, + src_ntype_ids, + part_src_ids, + dst_ntype_ids, + ) = _verify_node_type_ID_gb(part_g, gpb) + + if store_eids: + _verify_orig_IDs_gb( + part_g, + gpb, + g, + part_src_ids=part_src_ids, + part_dst_ids=part_dst_ids, + src_ntype_ids=src_ntype_ids, + dst_ntype_ids=dst_ntype_ids, + orig_nids=orig_nids, + orig_eids=orig_eids, + is_homo=is_homo, + ) + _verify_constructed_id_gb(part_sizes, gpb) + + +def _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, +): + if test_ntype != None: + shuffled_labels.append(node_feats[test_ntype + "/label"]) + shuffled_edata.append( + edge_feats[_etype_tuple_to_str(test_etype) + "/count"] + ) + else: + shuffled_labels.append(node_feats["_N/labels"]) + shuffled_edata.append(edge_feats["_N:_E:_N/feats"]) + parts.append(part_g) + gpbs.append(gpb) + tot_node_feats.append(node_feats) + tot_edge_feats.append(edge_feats) + + +def _verify_node_feats(g, part, gpb, orig_nids, node_feats, is_homo=False): + for ntype in g.ntypes: + ndata = ( + part.node_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.ndata + ) + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask( + part, + ntype_id, + (gpb if isinstance(part, gb.FusedCSCSamplingGraph) else None), + ) + inner_nids = F.boolean_mask(ndata[dgl.NID], inner_node_mask) + ntype_ids, inner_type_nids = gpb.map_to_per_ntype(inner_nids) + partid = gpb.nid2partid(inner_type_nids, ntype) + if is_homo: + assert np.all(F.asnumpy(ntype_ids) == ntype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + if is_homo: + orig_id = orig_nids[inner_type_nids] + else: + orig_id = orig_nids[ntype][inner_type_nids] + local_nids = gpb.nid2localnid(inner_type_nids, gpb.partid, ntype) + + for name in g.nodes[ntype].data: + if name in [dgl.NID, "inner_node"]: + continue + true_feats = F.gather_row(g.nodes[ntype].data[name], orig_id) + ndata = F.gather_row(node_feats[ntype + "/" + name], local_nids) + assert np.all(F.asnumpy(ndata == true_feats)) + + +def _verify_edge_feats(g, part, gpb, orig_eids, edge_feats, is_homo=False): + for etype in g.canonical_etypes: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = F.boolean_mask(edata[dgl.EID], inner_edge_mask) + etype_ids, inner_type_eids = gpb.map_to_per_etype(inner_eids) + partid = gpb.eid2partid(inner_type_eids, etype) + assert np.all(F.asnumpy(etype_ids) == etype_id) + assert np.all(F.asnumpy(partid) == gpb.partid) + + if is_homo: + orig_id = orig_eids[inner_type_eids] + else: + orig_id = orig_eids[etype][inner_type_eids] + local_eids = gpb.eid2localeid(inner_type_eids, gpb.partid, etype) + + for name in g.edges[etype].data: + if name in [dgl.EID, "inner_edge"]: + continue + true_feats = F.gather_row(g.edges[etype].data[name], orig_id) + edata = F.gather_row( + edge_feats[_etype_tuple_to_str(etype) + "/" + name], + local_eids, + ) + assert np.all(F.asnumpy(edata == true_feats)) + + +def _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype=None, + test_etype=None, +): + """ + check list: + make sure node data are correct. + make sure edge data are correct. + """ + shuffled_labels = F.asnumpy(F.cat(shuffled_labels, 0)) + shuffled_edata = F.asnumpy(F.cat(shuffled_edata, 0)) + orig_labels = np.zeros(shuffled_labels.shape, dtype=shuffled_labels.dtype) + orig_edata = np.zeros(shuffled_edata.shape, dtype=shuffled_edata.dtype) + + orig_nid = orig_nids if test_ntype is None else orig_nids[test_ntype] + orig_eid = orig_eids if test_etype is None else orig_eids[test_etype] + nlabel = ( + g.ndata["labels"] + if test_ntype is None + else g.nodes[test_ntype].data["label"] + ) + edata = ( + g.edata["feats"] + if test_etype is None + else g.edges[test_etype].data["count"] + ) + + orig_labels[F.asnumpy(orig_nid)] = shuffled_labels + orig_edata[F.asnumpy(orig_eid)] = shuffled_edata + assert np.all(orig_labels == F.asnumpy(nlabel)) + assert np.all(orig_edata == F.asnumpy(edata)) + + +def verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + store_inner_node=False, + store_inner_edge=False, + store_eids=False, + is_homo=False, +): + """ + check list: + make sure the feats of nodes and edges are correct + """ + for part_id in range(len(parts)): + part = parts[part_id] + gpb = gpbs[part_id] + node_feats = tot_node_feats[part_id] + edge_feats = tot_edge_feats[part_id] + if store_inner_node: + _verify_node_feats( + g, + part, + gpb, + orig_nids, + node_feats, + is_homo=is_homo, + ) + if store_inner_edge and store_eids: + _verify_edge_feats( + g, + part, + gpb, + orig_eids, + edge_feats, + is_homo=is_homo, + ) + + _verify_shuffled_labels_gb( + g, + shuffled_labels, + shuffled_edata, + orig_nids, + orig_eids, + test_ntype, + test_etype, + ) + + +def _verify_graphbolt_attributes( + parts, store_inner_node, store_inner_edge, store_eids +): + """ + check list: + make sure arguments work. + """ + for part in parts: + assert store_inner_edge == ("inner_edge" in part.edge_attributes) + assert store_inner_node == ("inner_node" in part.node_attributes) + assert store_eids == (dgl.EID in part.edge_attributes) + + +def _verify_graphbolt_part( + g, + test_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + part_config=None, + test_ntype=None, + test_etype=None, + is_homo=False, +): + """ + check list: + _verify_metadata_gb: + data type, ID's order and ID's number of edges and nodes + _verify_IDs_gb: + local id, mapping id,node type id, orig edge, hetero ntype id + verify_graph_feats_gb: + nodes and edges' feats + _verify_graphbolt_attributes: + arguments + """ + parts = [] + tot_node_feats = [] + tot_edge_feats = [] + shuffled_labels = [] + shuffled_edata = [] + part_sizes = [] + gpbs = [] + if part_config is None: + part_config = os.path.join(test_dir, f"{graph_name}.json") + # test each part + for part_id in range(num_parts): + part_g, node_feats, edge_feats, gpb, _, _, _ = load_partition( + part_config, part_id, load_feats=True, use_graphbolt=True + ) + # verify metadata + _verify_metadata_gb( + gpb, + g, + num_parts, + part_id, + part_sizes, + ) + + # verify eid and nid + _verify_IDs_gb( + g, + part_g, + part_id, + gpb, + part_sizes, + orig_nids, + orig_eids, + store_inner_node, + store_inner_edge, + store_eids, + is_homo, + ) + + # collect shuffled data and parts + _collect_data_gb( + parts, + part_g, + gpbs, + gpb, + tot_node_feats, + node_feats, + tot_edge_feats, + edge_feats, + shuffled_labels, + shuffled_edata, + test_ntype, + test_etype, + ) + + # verify graph feats + verify_graph_feats_gb( + g, + gpbs, + parts, + tot_node_feats, + tot_edge_feats, + orig_nids, + orig_eids, + shuffled_labels=shuffled_labels, + shuffled_edata=shuffled_edata, + test_ntype=test_ntype, + test_etype=test_etype, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + store_eids=store_eids, + is_homo=is_homo, + ) + + _verify_graphbolt_attributes( + parts, store_inner_node, store_inner_edge, store_eids + ) + + return parts + + +def _verify_hetero_graph_node_edge_num( + g, + parts, + store_inner_edge, + debug_mode, +): + """ + check list: + make sure edge type are correct. + make sure the number of nodes in each node type are correct. + make sure the number of nodes in each node type are correct. + """ + num_nodes = {ntype: 0 for ntype in g.ntypes} + num_edges = {etype: 0 for etype in g.canonical_etypes} + for part in parts: + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + if dgl.ETYPE in edata: + assert len(g.canonical_etypes) == len(F.unique(edata[dgl.ETYPE])) + if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + inner_node_mask = _get_inner_node_mask(part, ntype_id) + num_inner_nodes = F.sum(F.astype(inner_node_mask, F.int64), 0) + num_nodes[ntype] += num_inner_nodes + if store_inner_edge or isinstance(part, dgl.DGLGraph): + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + num_inner_edges = F.sum(F.astype(inner_edge_mask, F.int64), 0) + num_edges[etype] += num_inner_edges + + # Verify the number of nodes are correct. + if debug_mode or isinstance(part, dgl.DGLGraph): + for ntype in g.ntypes: + print( + "node {}: {}, {}".format( + ntype, g.num_nodes(ntype), num_nodes[ntype] + ) + ) + assert g.num_nodes(ntype) == num_nodes[ntype] + # Verify the number of edges are correct. + if store_inner_edge or isinstance(part, dgl.DGLGraph): + for etype in g.canonical_etypes: + print( + "edge {}: {}, {}".format( + etype, g.num_edges(etype), num_edges[etype] + ) + ) + assert g.num_edges(etype) == num_edges[etype] + + +def _verify_edge_id_range_hetero( + g, + part, + eids, +): + """ + check list: + make sure inner_eids fall into a range. + make sure all edges are included. + """ + edata = ( + part.edge_attributes + if isinstance(part, gb.FusedCSCSamplingGraph) + else part.edata + ) + etype = ( + part.type_per_edge + if isinstance(part, gb.FusedCSCSamplingGraph) + else edata[dgl.ETYPE] + ) + eid = torch.arange(len(edata[dgl.EID])) + etype_arr = F.gather_row(etype, eid) + eid_arr = F.gather_row(edata[dgl.EID], eid) + for etype in g.canonical_etypes: + etype_id = g.get_etype_id(etype) + eids[etype].append(F.boolean_mask(eid_arr, etype_arr == etype_id)) + # Make sure edge Ids fall into a range. + inner_edge_mask = _get_inner_edge_mask(part, etype_id) + inner_eids = np.sort( + F.asnumpy(F.boolean_mask(edata[dgl.EID], inner_edge_mask)) + ) + assert np.all( + inner_eids == np.arange(inner_eids[0], inner_eids[-1] + 1) + ) + return eids + + +def _verify_node_id_range_hetero(g, part, nids): + """ + check list: + make sure inner nodes have Ids fall into a range. + """ + for ntype in g.ntypes: + ntype_id = g.get_ntype_id(ntype) + # Make sure inner nodes have Ids fall into a range. + inner_node_mask = _get_inner_node_mask(part, ntype_id) + inner_nids = F.boolean_mask( + part.node_attributes[dgl.NID], inner_node_mask + ) + assert np.all( + F.asnumpy( + inner_nids + == F.arange( + F.as_scalar(inner_nids[0]), + F.as_scalar(inner_nids[-1]) + 1, + ) + ) + ) + nids[ntype].append(inner_nids) + return nids + + +def _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge, + store_inner_node, +): + """ + check list: + make sure edge ids fall into a range. + make sure inner nodes have Ids fall into a range. + make sure all nodes is included. + make sure all edges is included. + """ + nids = {ntype: [] for ntype in g.ntypes} + eids = {etype: [] for etype in g.canonical_etypes} + # check edge id. + if store_inner_edge or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + # collect eids + eids = _verify_edge_id_range_hetero(g, part, eids) + for etype in eids: + eids_type = F.cat(eids[etype], 0) + uniq_ids = F.unique(eids_type) + # We should get all nodes. + assert len(uniq_ids) == g.num_edges(etype) + + # check node id. + if store_inner_node or isinstance(parts[0], dgl.DGLGraph): + for part in parts: + nids = _verify_node_id_range_hetero(g, part, nids) + for ntype in nids: + nids_type = F.cat(nids[ntype], 0) + uniq_ids = F.unique(nids_type) + # We should get all nodes. + assert len(uniq_ids) == g.num_nodes(ntype) + + +def _verify_hetero_graph( + g, + parts, + store_eids=False, + store_inner_edge=False, + store_inner_node=False, + debug_mode=False, +): + _verify_hetero_graph_node_edge_num( + g, + parts, + store_inner_edge=store_inner_edge, + debug_mode=debug_mode, + ) + if store_eids: + _verify_graph_attributes_hetero( + g, + parts, + store_inner_edge=store_inner_edge, + store_inner_node=store_inner_node, + ) + + +def _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + graph_formats=None, + data_fmt="numpy", + num_chunks_nodes=None, + num_chunks_edges=None, + num_chunks_node_data=None, + num_chunks_edge_data=None, + use_verify_partitions=False, + store_eids=True, + store_inner_edge=True, + store_inner_node=True, +): + if num_parts % world_size != 0: + # num_parts should be a multiple of world_size + return + + with tempfile.TemporaryDirectory() as root_dir: + g = create_chunked_dataset( + root_dir, + num_chunks, + data_fmt=data_fmt, + num_chunks_nodes=num_chunks_nodes, + num_chunks_edges=num_chunks_edges, + num_chunks_node_data=num_chunks_node_data, + num_chunks_edge_data=num_chunks_edge_data, + ) + graph_name = "test" + test_ntype = "paper" + test_etype = ("paper", "cites", "paper") + + # Step1: graph partition + in_dir = os.path.join(root_dir, "chunked-data") + output_dir = os.path.join(root_dir, "parted_data") + os.system( + "python3 tools/partition_algo/random_partition.py " + "--in_dir {} --out_dir {} --num_partitions {}".format( + in_dir, output_dir, num_parts + ) + ) + for ntype in ["author", "institution", "paper"]: + fname = os.path.join(output_dir, "{}.txt".format(ntype)) + with open(fname, "r") as f: + header = f.readline().rstrip() + assert isinstance(int(header), int) + + # Step2: data dispatch + partition_dir = os.path.join(root_dir, "parted_data") + out_dir = os.path.join(root_dir, "partitioned") + ip_config = os.path.join(root_dir, "ip_config.txt") + with open(ip_config, "w") as f: + for i in range(world_size): + f.write(f"127.0.0.{i + 1}\n") + + cmd = "python3 tools/dispatch_data.py " + cmd += f" --in-dir {in_dir} " + cmd += f" --partitions-dir {partition_dir} " + cmd += f" --out-dir {out_dir} " + cmd += f" --ip-config {ip_config} " + cmd += " --ssh-port 22 " + cmd += " --process-group-timeout 60 " + cmd += " --save-orig-nids " + cmd += " --save-orig-eids " + cmd += " --use-graphbolt " + cmd += f" --graph-formats {graph_formats} " if graph_formats else "" + + if store_eids: + cmd += " --store-eids " + if store_inner_edge: + cmd += " --store-inner-edge " + if store_inner_node: + cmd += " --store-inner-node " + os.system(cmd) + + # check if verify_partitions.py is used for validation. + if use_verify_partitions: + cmd = "python3 tools/verify_partitions.py " + cmd += f" --orig-dataset-dir {in_dir}" + cmd += f" --part-graph {out_dir}" + cmd += f" --partitions-dir {output_dir}" + os.system(cmd) + return + + # read original node/edge IDs + def read_orig_ids(fname): + orig_ids = {} + for i in range(num_parts): + ids_path = os.path.join(out_dir, f"part{i}", fname) + part_ids = load_tensors(ids_path) + for type, data in part_ids.items(): + if type not in orig_ids: + orig_ids[type] = data + else: + orig_ids[type] = torch.cat((orig_ids[type], data)) + return orig_ids + + orig_nids, orig_eids = None, None + orig_nids = read_orig_ids("orig_nids.dgl") + + orig_eids_str = read_orig_ids("orig_eids.dgl") + + orig_eids = {} + # transmit etype from string to tuple. + for etype, eids in orig_eids_str.items(): + orig_eids[_etype_str_to_tuple(etype)] = eids + + # load partitions and verify + part_config = os.path.join(out_dir, "metadata.json") + parts = _verify_graphbolt_part( + g, + root_dir, + orig_nids, + orig_eids, + graph_name, + num_parts, + store_inner_node, + store_inner_edge, + store_eids, + test_ntype=test_ntype, + test_etype=test_etype, + part_config=part_config, + is_homo=False, + ) + _verify_hetero_graph( + g, + parts, + store_eids=store_eids, + store_inner_edge=store_inner_edge, + ) + + +@pytest.mark.parametrize( + "num_chunks, num_parts, world_size", + [[4, 4, 4], [8, 4, 2], [8, 4, 4], [9, 6, 3], [11, 11, 1], [11, 4, 1]], +) +def test_pipeline_basics(num_chunks, num_parts, world_size): + _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + ) + _test_pipeline_graphbolt( + num_chunks, num_parts, world_size, use_verify_partitions=False + ) + + +@pytest.mark.parametrize("store_inner_node", [True, False]) +@pytest.mark.parametrize("store_inner_edge", [True, False]) +@pytest.mark.parametrize("store_eids", [True, False]) +def test_pipeline_attributes(store_inner_node, store_inner_edge, store_eids): + _test_pipeline_graphbolt( + 4, + 4, + 4, + store_inner_node=store_inner_node, + store_inner_edge=store_inner_edge, + store_eids=store_eids, + ) + + +@pytest.mark.parametrize( + "num_chunks, " + "num_parts, " + "world_size, " + "num_chunks_node_data, " + "num_chunks_edge_data", + [ + # Test cases where no. of chunks more than + # no. of partitions + [8, 4, 4, 8, 8], + [8, 4, 2, 8, 8], + [9, 7, 5, 9, 9], + [8, 8, 4, 8, 8], + # Test cases where no. of chunks smaller + # than no. of partitions + [7, 8, 4, 7, 7], + [1, 8, 4, 1, 1], + [1, 4, 4, 1, 1], + [3, 4, 4, 3, 3], + [1, 4, 2, 1, 1], + [3, 4, 2, 3, 3], + [1, 5, 3, 1, 1], + ], +) +def test_pipeline_arbitrary_chunks( + num_chunks, + num_parts, + world_size, + num_chunks_node_data, + num_chunks_edge_data, +): + + _test_pipeline_graphbolt( + num_chunks, + num_parts, + world_size, + num_chunks_node_data=num_chunks_node_data, + num_chunks_edge_data=num_chunks_edge_data, + ) + + +@pytest.mark.parametrize("data_fmt", ["numpy", "parquet"]) +def test_pipeline_feature_format(data_fmt): + _test_pipeline_graphbolt(4, 4, 4, data_fmt=data_fmt) diff --git a/tools/dispatch_data.py b/tools/dispatch_data.py index 3cf1d0fbf2..b2b54e51a6 100644 --- a/tools/dispatch_data.py +++ b/tools/dispatch_data.py @@ -75,6 +75,10 @@ def submit_jobs(args) -> str: argslist += "--log-level {} ".format(args.log_level) argslist += "--save-orig-nids " if args.save_orig_nids else "" argslist += "--save-orig-eids " if args.save_orig_eids else "" + argslist += "--use-graphbolt " if args.use_graphbolt else "" + argslist += "--store-eids " if args.store_eids else "" + argslist += "--store-inner-node " if args.store_inner_node else "" + argslist += "--store-inner-edge " if args.store_inner_edge else "" argslist += ( f"--graph-formats {args.graph_formats} " if args.graph_formats else "" ) @@ -159,6 +163,30 @@ def main(): action="store_true", help="Save original edge IDs into files", ) + parser.add_argument( + "--use-graphbolt", + action="store_true", + help="Use GraphBolt for distributed partition.", + ) + parser.add_argument( + "--store-inner-node", + action="store_true", + default=False, + help="Store inner nodes.", + ) + + parser.add_argument( + "--store-inner-edge", + action="store_true", + default=False, + help="Store inner edges.", + ) + parser.add_argument( + "--store-eids", + action="store_true", + default=False, + help="Store edge IDs.", + ) parser.add_argument( "--graph-formats", type=str, diff --git a/tools/distpartitioning/convert_partition.py b/tools/distpartitioning/convert_partition.py index a169589a3f..5013b6d40f 100644 --- a/tools/distpartitioning/convert_partition.py +++ b/tools/distpartitioning/convert_partition.py @@ -1,24 +1,25 @@ -import argparse +import copy import gc -import json import logging import os -import time import constants - import dgl +import dgl.backend as F +import dgl.graphbolt as gb import numpy as np -import pandas as pd -import pyarrow import torch as th +from dgl import EID, ETYPE, NID, NTYPE + +from dgl.distributed.constants import DGL2GB_EID, GB_DST_ID from dgl.distributed.partition import ( + _cast_to_minimum_dtype, _etype_str_to_tuple, _etype_tuple_to_str, + cast_various_to_minimum_dtype_gb, RESERVED_FIELD_DTYPE, ) -from pyarrow import csv -from utils import get_idranges, memory_snapshot, read_json +from utils import get_idranges, memory_snapshot def _get_unique_invidx(srcids, dstids, nids, low_mem=True): @@ -164,7 +165,202 @@ def _get_unique_invidx(srcids, dstids, nids, low_mem=True): return uniques, idxes, srcids, dstids -def create_dgl_object( +# Utility functions. +def _is_homogeneous(ntypes, etypes): + """Checks if the provided ntypes and etypes form a homogeneous graph.""" + return len(ntypes) == 1 and len(etypes) == 1 + + +def _coo2csc(src_ids, dst_ids): + src_ids, dst_ids = th.tensor(src_ids, dtype=th.int64), th.tensor( + dst_ids, dtype=th.int64 + ) + num_nodes = th.max(th.stack([src_ids, dst_ids], dim=0)).item() + 1 + dst, idx = dst_ids.sort() + indptr = th.searchsorted(dst, th.arange(num_nodes + 1)) + indices = src_ids[idx] + return indptr, indices, idx + + +def _create_edge_data(edgeid_offset, etype_ids, num_edges): + eid = th.arange( + edgeid_offset, + edgeid_offset + num_edges, + dtype=RESERVED_FIELD_DTYPE[dgl.EID], + ) + etype = th.as_tensor(etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE]) + inner_edge = th.ones(num_edges, dtype=RESERVED_FIELD_DTYPE["inner_edge"]) + return eid, etype, inner_edge + + +def _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes): + node_type = th.as_tensor(ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE]) + node_id = th.as_tensor(uniq_ids[reshuffle_nodes]) + inner_node = th.as_tensor( + inner_nodes[reshuffle_nodes], + dtype=RESERVED_FIELD_DTYPE["inner_node"], + ) + return node_type, node_id, inner_node + + +def _compute_node_ntype( + global_src_id, global_dst_id, global_homo_nid, idx, reshuffle_nodes, id_map +): + global_ids = np.concatenate([global_src_id, global_dst_id, global_homo_nid]) + part_global_ids = global_ids[idx] + part_global_ids = part_global_ids[reshuffle_nodes] + ntype, per_type_ids = id_map(part_global_ids) + return ntype, per_type_ids + + +def _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + node_attr, + edge_attr, + per_type_ids, + type_per_edge, + global_edge_id, +): + orig_nids = None + orig_eids = None + if return_orig_nids: + orig_nids = {} + for ntype, ntype_id in ntypes_map.items(): + mask = th.logical_and( + node_attr[dgl.NTYPE] == ntype_id, + node_attr["inner_node"], + ) + orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) + if return_orig_eids: + orig_eids = {} + for etype, etype_id in etypes_map.items(): + mask = th.logical_and( + type_per_edge == etype_id, + edge_attr["inner_edge"], + ) + orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( + global_edge_id[mask] + ) + return orig_nids, orig_eids + + +def _create_edge_attr_gb( + part_local_dst_id, edgeid_offset, etype_ids, ntypes, etypes, etypes_map +): + edge_attr = {} + # create edge data in graph. + num_edges = len(part_local_dst_id) + ( + edge_attr[dgl.EID], + type_per_edge, + edge_attr["inner_edge"], + ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) + assert "inner_edge" in edge_attr + + is_homo = _is_homogeneous(ntypes, etypes) + + edge_type_to_id = ( + {gb.etype_tuple_to_str(("_N", "_E", "_N")): 0} + if is_homo + else { + gb.etype_tuple_to_str(etype): etid + for etype, etid in etypes_map.items() + } + ) + return edge_attr, type_per_edge, edge_type_to_id + + +def _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, +): + # compute per_type_ids and ntype for all the nodes in the graph. + ntype, per_type_ids = _compute_node_ntype( + global_src_id, + global_dst_id, + global_homo_nid, + idx, + reshuffle_nodes, + id_map, + ) + + # create node data in graph. + node_attr = {} + ( + node_attr[dgl.NTYPE], + node_attr[dgl.NID], + node_attr["inner_node"], + ) = _create_node_data(ntype, uniq_ids, reshuffle_nodes, inner_nodes) + return node_attr, per_type_ids + + +def remove_attr_gb( + edge_attr, node_attr, store_inner_node, store_inner_edge, store_eids +): + edata, ndata = copy.deepcopy(edge_attr), copy.deepcopy(node_attr) + if not store_inner_edge: + assert "inner_edge" in edata + edata.pop("inner_edge") + + if not store_eids: + assert dgl.EID in edata + edata.pop(dgl.EID) + + if not store_inner_node: + assert "inner_node" in ndata + ndata.pop("inner_node") + return edata, ndata + + +def _process_partition_gb( + node_attr, + edge_attr, + type_per_edge, + src_ids, + dst_ids, + sort_etypes, +): + """Preprocess partitions before saving: + 1. format data types. + 2. sort csc/csr by tag. + """ + for k, dtype in RESERVED_FIELD_DTYPE.items(): + if k in node_attr: + node_attr[k] = F.astype(node_attr[k], dtype) + if k in edge_attr: + edge_attr[k] = F.astype(edge_attr[k], dtype) + + indptr, indices, edge_ids = _coo2csc(src_ids, dst_ids) + if sort_etypes: + split_size = th.diff(indptr) + split_indices = th.split(type_per_edge, tuple(split_size), dim=0) + sorted_idxs = [] + for split_indice in split_indices: + sorted_idxs.append(split_indice.sort()[1]) + + sorted_idx = th.cat(sorted_idxs, dim=0) + sorted_idx = ( + th.repeat_interleave(indptr[:-1], split_size, dim=0) + sorted_idx + ) + + return indptr, indices[sorted_idx], edge_ids[sorted_idx] + + +def create_graph_object( + tot_node_count, + tot_edge_count, + node_count, + edge_count, + num_parts, schema, part_id, node_data, @@ -174,6 +370,8 @@ def create_dgl_object( edge_typecounts, return_orig_nids=False, return_orig_eids=False, + use_graphbolt=False, + **kwargs, ): """ This function creates dgl objects for a given graph partition, as in function @@ -223,6 +421,18 @@ def create_dgl_object( Parameters: ----------- + tot_node_count : int + the number of all nodes + tot_edge_count : int + the number of all edges + node_count : int + the number of nodes in partition + edge_count : int + the number of edges in partition + graph_formats : str + the format of graph + num_parts : int + the number of parts schame : json object json object created by reading the graph metadata json file part_id : int @@ -449,58 +659,134 @@ def create_dgl_object( nid_map[part_local_dst_id], ) + """ + Creating attributes for graphbolt and DGLGraph is as follows. + + node attributes: + this part is implemented in _create_node_attr. + compute the ntype and per type ids for each node with global node type id. + create ntype, nid and inner node with orig ntype and inner nodes + this part is shared by graphbolt and DGLGraph. + + the attributes created for graphbolt are as follows: + + edge attributes: + this part is implemented in _create_edge_attr_gb. + create eid, type per edge and inner edge with edgeid_offset. + create edge_type_to_id with etypes_map. + + The process to remove extra attribute is implemented in remove_attr_gb. + the unused attributes like inner_node, inner_edge, eids will be removed following the arguments in kwargs. + edge_attr, node_attr are the variable that have removed extra attributes to construct csc_graph. + edata, ndata are the variable that reserve extra attributes to be used to generate orig_nid and orig_eid. + + the src_ids and dst_ids will be transformed into indptr and indices in _coo2csc. + + all variable mentioned above will be casted to minimum data type in cast_various_to_minimum_dtype_gb. + + orig_nids and orig_eids will be generated in _graph_orig_ids with ndata and edata. + """ # create the graph here now. - part_graph = dgl.graph( - data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) - ) - part_graph.edata[dgl.EID] = th.arange( - edgeid_offset, - edgeid_offset + part_graph.num_edges(), - dtype=th.int64, - ) - part_graph.edata[dgl.ETYPE] = th.as_tensor( - etype_ids, dtype=RESERVED_FIELD_DTYPE[dgl.ETYPE] - ) - part_graph.edata["inner_edge"] = th.ones( - part_graph.num_edges(), dtype=RESERVED_FIELD_DTYPE["inner_edge"] + ndata, per_type_ids = _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, ) + if use_graphbolt: + edata, type_per_edge, edge_type_to_id = _create_edge_attr_gb( + part_local_dst_id, + edgeid_offset, + etype_ids, + ntypes, + etypes, + etypes_map, + ) - # compute per_type_ids and ntype for all the nodes in the graph. - global_ids = np.concatenate([global_src_id, global_dst_id, global_homo_nid]) - part_global_ids = global_ids[idx] - part_global_ids = part_global_ids[reshuffle_nodes] - ntype, per_type_ids = id_map(part_global_ids) + assert edata is not None + assert ndata is not None - # continue with the graph creation - part_graph.ndata[dgl.NTYPE] = th.as_tensor( - ntype, dtype=RESERVED_FIELD_DTYPE[dgl.NTYPE] + sort_etypes = len(etypes_map) > 1 + indptr, indices, csc_edge_ids = _process_partition_gb( + ndata, + edata, + type_per_edge, + part_local_src_id, + part_local_dst_id, + sort_etypes, + ) + edge_attr, node_attr = remove_attr_gb( + edge_attr=edata, node_attr=ndata, **kwargs + ) + edge_attr = { + attr: edge_attr[attr][csc_edge_ids] for attr in edge_attr.keys() + } + cast_various_to_minimum_dtype_gb( + node_count=node_count, + edge_count=edge_count, + tot_node_count=tot_node_count, + tot_edge_count=tot_edge_count, + num_parts=num_parts, + indptr=indptr, + indices=indices, + type_per_edge=type_per_edge, + etypes=etypes, + ntypes=ntypes, + node_attributes=node_attr, + edge_attributes=edge_attr, + ) + part_graph = gb.fused_csc_sampling_graph( + csc_indptr=indptr, + indices=indices, + node_type_offset=None, + type_per_edge=type_per_edge[csc_edge_ids], + node_attributes=node_attr, + edge_attributes=edge_attr, + node_type_to_id=ntypes_map, + edge_type_to_id=edge_type_to_id, + ) + else: + num_edges = len(part_local_dst_id) + part_graph = dgl.graph( + data=(part_local_src_id, part_local_dst_id), num_nodes=len(uniq_ids) + ) + # create edge data in graph. + ( + part_graph.edata[dgl.EID], + part_graph.edata[dgl.ETYPE], + part_graph.edata["inner_edge"], + ) = _create_edge_data(edgeid_offset, etype_ids, num_edges) + + ndata, per_type_ids = _create_node_attr( + idx, + global_src_id, + global_dst_id, + global_homo_nid, + uniq_ids, + reshuffle_nodes, + id_map, + inner_nodes, + ) + for attr_name, node_attributes in ndata.items(): + part_graph.ndata[attr_name] = node_attributes + type_per_edge = part_graph.edata[dgl.ETYPE] + ndata, edata = part_graph.ndata, part_graph.edata + # get the original node ids and edge ids from original graph. + orig_nids, orig_eids = _graph_orig_ids( + return_orig_nids, + return_orig_eids, + ntypes_map, + etypes_map, + ndata, + edata, + per_type_ids, + type_per_edge, + global_edge_id, ) - part_graph.ndata[dgl.NID] = th.as_tensor(uniq_ids[reshuffle_nodes]) - part_graph.ndata["inner_node"] = th.as_tensor( - inner_nodes[reshuffle_nodes], dtype=RESERVED_FIELD_DTYPE["inner_node"] - ) - - orig_nids = None - orig_eids = None - if return_orig_nids: - orig_nids = {} - for ntype, ntype_id in ntypes_map.items(): - mask = th.logical_and( - part_graph.ndata[dgl.NTYPE] == ntype_id, - part_graph.ndata["inner_node"], - ) - orig_nids[ntype] = th.as_tensor(per_type_ids[mask]) - if return_orig_eids: - orig_eids = {} - for etype, etype_id in etypes_map.items(): - mask = th.logical_and( - part_graph.edata[dgl.ETYPE] == etype_id, - part_graph.edata["inner_edge"], - ) - orig_eids[_etype_tuple_to_str(etype)] = th.as_tensor( - global_edge_id[mask] - ) - return ( part_graph, node_map_val, @@ -523,6 +809,7 @@ def create_metadata_json( ntypes_map, etypes_map, output_dir, + use_graphbolt, ): """ Auxiliary function to create json file for the graph partition metadata @@ -549,6 +836,8 @@ def create_metadata_json( map between edge type(string) and edge_type_id(int) output_dir : string directory where the output files are to be stored + use_graphbolt : bool + whether to use graphbolt or not Returns: -------- @@ -572,10 +861,14 @@ def create_metadata_json( part_dir = "part" + str(part_id) node_feat_file = os.path.join(part_dir, "node_feat.dgl") edge_feat_file = os.path.join(part_dir, "edge_feat.dgl") - part_graph_file = os.path.join(part_dir, "graph.dgl") + if use_graphbolt: + part_graph_file = os.path.join(part_dir, "fused_csc_sampling_graph.pt") + else: + part_graph_file = os.path.join(part_dir, "graph.dgl") + part_graph_type = "part_graph_graphbolt" if use_graphbolt else "part_graph" part_metadata["part-{}".format(part_id)] = { "node_feats": node_feat_file, "edge_feats": edge_feat_file, - "part_graph": part_graph_file, + part_graph_type: part_graph_file, } return part_metadata diff --git a/tools/distpartitioning/data_proc_pipeline.py b/tools/distpartitioning/data_proc_pipeline.py index 4c249a34b6..e0159f55b9 100644 --- a/tools/distpartitioning/data_proc_pipeline.py +++ b/tools/distpartitioning/data_proc_pipeline.py @@ -94,6 +94,30 @@ if __name__ == "__main__": action="store_true", help="Save original edge IDs into files", ) + parser.add_argument( + "--use-graphbolt", + action="store_true", + help="Use GraphBolt for distributed partition.", + ) + parser.add_argument( + "--store-inner-node", + action="store_true", + default=False, + help="Store inner nodes.", + ) + + parser.add_argument( + "--store-inner-edge", + action="store_true", + default=False, + help="Store inner edges.", + ) + parser.add_argument( + "--store-eids", + action="store_true", + default=False, + help="Store edge IDs.", + ) parser.add_argument( "--graph-formats", default=None, @@ -101,7 +125,6 @@ if __name__ == "__main__": help="Save partitions in specified formats.", ) params = parser.parse_args() - # invoke the pipeline function numeric_level = getattr(logging, params.log_level.upper(), None) logging.basicConfig( diff --git a/tools/distpartitioning/data_shuffle.py b/tools/distpartitioning/data_shuffle.py index 7cba2cbeec..6800064a2b 100644 --- a/tools/distpartitioning/data_shuffle.py +++ b/tools/distpartitioning/data_shuffle.py @@ -13,7 +13,7 @@ import numpy as np import torch import torch.distributed as dist import torch.multiprocessing as mp -from convert_partition import create_dgl_object, create_metadata_json +from convert_partition import create_graph_object, create_metadata_json from dataset_utils import get_dataset from dist_lookup import DistLookupService from globalids import ( @@ -1121,7 +1121,6 @@ def gen_dist_partitions(rank, world_size, params): ) id_map = dgl.distributed.id_map.IdMap(global_nid_ranges) id_lookup.set_idMap(id_map) - # read input graph files and augment these datastructures with # appropriate information (global_nid and owner process) for node and edge data ( @@ -1315,6 +1314,8 @@ def gen_dist_partitions(rank, world_size, params): ) local_node_data = prepare_local_data(node_data, local_part_id) local_edge_data = prepare_local_data(edge_data, local_part_id) + tot_node_count = sum(schema_map["num_nodes_per_type"]) + tot_edge_count = sum(schema_map["num_edges_per_type"]) ( graph_obj, ntypes_map_val, @@ -1323,7 +1324,12 @@ def gen_dist_partitions(rank, world_size, params): etypes_map, orig_nids, orig_eids, - ) = create_dgl_object( + ) = create_graph_object( + tot_node_count, + tot_edge_count, + node_count, + edge_count, + params.num_parts, schema_map, rank + local_part_id * world_size, local_node_data, @@ -1334,8 +1340,12 @@ def gen_dist_partitions(rank, world_size, params): schema_map[constants.STR_NUM_NODES_PER_TYPE], ), edge_typecounts, - params.save_orig_nids, - params.save_orig_eids, + return_orig_nids=params.save_orig_nids, + return_orig_eids=params.save_orig_eids, + use_graphbolt=params.use_graphbolt, + store_inner_node=params.store_inner_node, + store_inner_edge=params.store_inner_edge, + store_eids=params.store_eids, ) sort_etypes = len(etypes_map) > 1 local_node_features = prepare_local_data( @@ -1354,8 +1364,12 @@ def gen_dist_partitions(rank, world_size, params): orig_eids, graph_formats, sort_etypes, + params.use_graphbolt, ) - memory_snapshot("DiskWriteDGLObjectsComplete: ", rank) + if params.use_graphbolt: + memory_snapshot("DiskWriteGrapgboltObjectsComplete: ", rank) + else: + memory_snapshot("DiskWriteDGLObjectsComplete: ", rank) # get the meta-data json_metadata = create_metadata_json( @@ -1369,6 +1383,7 @@ def gen_dist_partitions(rank, world_size, params): ntypes_map, etypes_map, params.output, + params.use_graphbolt, ) output_meta_json[ "local-part-id-" + str(local_part_id * world_size + rank) diff --git a/tools/distpartitioning/utils.py b/tools/distpartitioning/utils.py index cdb984be37..32292a843b 100644 --- a/tools/distpartitioning/utils.py +++ b/tools/distpartitioning/utils.py @@ -504,6 +504,20 @@ def write_edge_features(edge_features, edge_file): dgl.data.utils.save_tensors(edge_file, edge_features) +def write_graph_graghbolt(graph_file, graph_obj): + """ + Utility function to serialize FusedCSCSamplingGraph + + Parameters: + ----------- + graph_obj : FusedCSCSamplingGraph + FusedCSCSamplingGraph, as created in convert_partition.py, which is to be serialized + graph_file : string + File name in which graph object is serialized + """ + torch.save(graph_obj, graph_file) + + def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): """ Utility function to serialize graph dgl objects @@ -519,9 +533,23 @@ def write_graph_dgl(graph_file, graph_obj, formats, sort_etypes): sort_etypes : bool Whether to sort etypes in csc/csr. """ - dgl.distributed.partition._save_graphs( - graph_file, [graph_obj], formats, sort_etypes + dgl.distributed.partition.process_partitions( + graph_obj, formats, sort_etypes ) + dgl.save_graphs(graph_file, [graph_obj], formats=formats) + + +def _write_graph( + part_dir, graph_obj, formats=None, sort_etypes=None, use_graphbolt=False +): + if use_graphbolt: + write_graph_graghbolt( + os.path.join(part_dir, "fused_csc_sampling_graph.pt"), graph_obj + ) + else: + write_graph_dgl( + os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes + ) def write_dgl_objects( @@ -534,6 +562,7 @@ def write_dgl_objects( orig_eids, formats, sort_etypes, + use_graphbolt, ): """ Wrapper function to write graph, node/edge feature, original node/edge IDs. @@ -558,13 +587,18 @@ def write_dgl_objects( Save graph in formats. sort_etypes : bool Whether to sort etypes in csc/csr. + use_graphbolt : bool + Whether to use graphbolt or not. """ part_dir = output_dir + "/part" + str(part_id) os.makedirs(part_dir, exist_ok=True) - write_graph_dgl( - os.path.join(part_dir, "graph.dgl"), graph_obj, formats, sort_etypes + _write_graph( + part_dir, + graph_obj, + formats=formats, + sort_etypes=sort_etypes, + use_graphbolt=use_graphbolt, ) - if node_features != None: write_node_features( node_features, os.path.join(part_dir, "node_feat.dgl")