mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[Performance] Fused sampling with compaction (#5924)
Co-authored-by: Hesham Mostafa <hesham.mostafa@intel.com>
This commit is contained in:
39
benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
Normal file
39
benchmarks/benchmarks/api/bench_fused_sample_neighbors.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import time
|
||||
|
||||
import dgl
|
||||
import dgl.function as fn
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import utils
|
||||
|
||||
|
||||
@utils.benchmark("time")
|
||||
@utils.parametrize_cpu("graph_name", ["livejournal", "reddit"])
|
||||
@utils.parametrize_gpu("graph_name", ["ogbn-arxiv", "reddit"])
|
||||
@utils.parametrize("format", ["csr", "csc"])
|
||||
@utils.parametrize("seed_nodes_num", [200, 5000, 20000])
|
||||
@utils.parametrize("fanout", [5, 20, 40])
|
||||
def track_time(graph_name, format, seed_nodes_num, fanout):
|
||||
device = utils.get_bench_device()
|
||||
graph = utils.get_graph(graph_name, format).to(device)
|
||||
|
||||
edge_dir = "in" if format == "csc" else "out"
|
||||
seed_nodes = np.random.randint(0, graph.num_nodes(), seed_nodes_num)
|
||||
seed_nodes = torch.from_numpy(seed_nodes).to(device)
|
||||
|
||||
# dry run
|
||||
for i in range(3):
|
||||
dgl.sampling.sample_neighbors_fused(
|
||||
graph, seed_nodes, fanout, edge_dir=edge_dir
|
||||
)
|
||||
|
||||
# timing
|
||||
with utils.Timer() as t:
|
||||
for i in range(50):
|
||||
dgl.sampling.sample_neighbors_fused(
|
||||
graph, seed_nodes, fanout, edge_dir=edge_dir
|
||||
)
|
||||
|
||||
return t.elapsed_secs / 50
|
||||
@@ -572,6 +572,72 @@ COOMatrix CSRRowWiseSampling(
|
||||
CSRMatrix mat, IdArray rows, int64_t num_samples,
|
||||
NDArray prob_or_mask = NDArray(), bool replace = true);
|
||||
|
||||
/*!
|
||||
* @brief Randomly select a fixed number of non-zero entries along each given
|
||||
* row independently.
|
||||
*
|
||||
* The function performs random choices along each row independently.
|
||||
* The picked indices are returned in the form of a CSR matrix, with
|
||||
* additional IdArray that is an extended version of CSR's index pointers.
|
||||
*
|
||||
* With template parameter set to True rows are also saved as new seed nodes and
|
||||
* mapped
|
||||
*
|
||||
* If replace is false and a row has fewer non-zero values than num_samples,
|
||||
* all the values are picked.
|
||||
*
|
||||
* Examples:
|
||||
*
|
||||
* // csr.num_rows = 4;
|
||||
* // csr.num_cols = 4;
|
||||
* // csr.indptr = [0, 2, 3, 3, 5]
|
||||
* // csr.indices = [0, 1, 1, 2, 3]
|
||||
* // csr.data = [2, 3, 0, 1, 4]
|
||||
* CSRMatrix csr = ...;
|
||||
* IdArray rows = ... ; // [1, 3]
|
||||
* IdArray seed_mapping = [-1, -1, -1, -1];
|
||||
* std::vector<IdType> new_seed_nodes = {};
|
||||
*
|
||||
* std::pair<CSRMatrix, IdArray> sampled = CSRRowWiseSamplingFused<
|
||||
* typename IdType, True>(
|
||||
* csr, rows, seed_mapping,
|
||||
* new_seed_nodes, 2,
|
||||
* FloatArray(), false);
|
||||
* // possible sampled csr matrix:
|
||||
* // sampled.first.num_rows = 2
|
||||
* // sampled.first.num_cols = 3
|
||||
* // sampled.first.indptr = [0, 1, 3]
|
||||
* // sampled.first.indices = [1, 2, 3]
|
||||
* // sampled.first.data = [0, 1, 4]
|
||||
* // sampled.second = [0, 1, 1]
|
||||
* // seed_mapping = [-1, 0, -1, 1];
|
||||
* // new_seed_nodes = {1, 3};
|
||||
*
|
||||
* @tparam IdType Graph's index data type, can be int32_t or int64_t
|
||||
* @tparam map_seed_nodes If set for true we map and copy rows to new_seed_nodes
|
||||
* @param mat Input CSR matrix.
|
||||
* @param rows Rows to sample from.
|
||||
* @param seed_mapping Mapping array used if map_seed_nodes=true. If so each row
|
||||
* from rows will be set to its position e.g. mapping[rows[i]] = i.
|
||||
* @param new_seed_nodes Vector used if map_seed_nodes=true. If so it will
|
||||
* contain rows.
|
||||
* @param rows Rows to sample from.
|
||||
* @param num_samples Number of samples
|
||||
* @param prob_or_mask Unnormalized probability array or mask array.
|
||||
* Should be of the same length as the data array.
|
||||
* If an empty array is provided, assume uniform.
|
||||
* @param replace True if sample with replacement
|
||||
* @return A CSRMatrix storing the picked row, col and data indices,
|
||||
* COO version of picked rows
|
||||
* @note The edges of the entire graph must be ordered by their edge types,
|
||||
* rows must be unique
|
||||
*/
|
||||
template <typename IdType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdType>* new_seed_nodes, int64_t num_samples,
|
||||
NDArray prob_or_mask = NDArray(), bool replace = true);
|
||||
|
||||
/**
|
||||
* @brief Randomly select a fixed number of non-zero entries for each edge type
|
||||
* along each given row independently.
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include <dgl/array.h>
|
||||
#include <dgl/base_heterograph.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
namespace dgl {
|
||||
@@ -47,6 +48,55 @@ HeteroSubgraph SampleNeighbors(
|
||||
const std::vector<FloatArray>& probability,
|
||||
const std::vector<IdArray>& exclude_edges, bool replace = true);
|
||||
|
||||
/**
|
||||
* @brief Sample from the neighbors of the given nodes and convert a graph into
|
||||
* a bipartite-structured graph for message passing.
|
||||
*
|
||||
* Specifically, we create one node type \c ntype_l on the "left" side and
|
||||
* another node type \c ntype_r on the "right" side for each node type \c ntype.
|
||||
* The nodes of type \c ntype_r would contain the nodes designated by the
|
||||
* caller, and node type \c ntype_l would contain the nodes that has an edge
|
||||
* connecting to one of the designated nodes.
|
||||
*
|
||||
* The nodes of \c ntype_l would also contain the nodes in node type \c ntype_r.
|
||||
* When sampling with replacement, the sampled subgraph could have parallel
|
||||
* edges.
|
||||
*
|
||||
* For sampling without replace, if fanout > the number of neighbors, all the
|
||||
* neighbors will be sampled.
|
||||
*
|
||||
* Non-deterministic algorithm, requires nodes parameter to store unique Node
|
||||
* IDs.
|
||||
*
|
||||
* @tparam IdType Graph's index data type, can be int32_t or int64_t
|
||||
* @param hg The input graph.
|
||||
* @param nodes Node IDs of each type. The vector length must be equal to the
|
||||
* number of node types. Empty array is allowed.
|
||||
* @param mapping External parameter that should be set to a vector of IdArrays
|
||||
* filled with -1, required for mapping of nodes in returned
|
||||
* graph
|
||||
* @param fanouts Number of sampled neighbors for each edge type. The vector
|
||||
* length should be equal to the number of edge types, or one if they all have
|
||||
* the same fanout.
|
||||
* @param dir Edge direction.
|
||||
* @param probability A vector of 1D float arrays, indicating the transition
|
||||
* probability of each edge by edge type. An empty float array assumes uniform
|
||||
* transition.
|
||||
* @param exclude_edges Edges IDs of each type which will be excluded during
|
||||
* sampling. The vector length must be equal to the number of edges types. Empty
|
||||
* array is allowed.
|
||||
* @param replace If true, sample with replacement.
|
||||
* @return Sampled neighborhoods as a graph. The return graph has the same
|
||||
* schema as the original one.
|
||||
*/
|
||||
template <typename IdType>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
SampleNeighborsFused(
|
||||
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
|
||||
const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
|
||||
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
|
||||
const std::vector<IdArray>& exclude_edges, bool replace = true);
|
||||
|
||||
/**
|
||||
* Select the neighbors with k-largest weights on the connecting edges for each
|
||||
* given node.
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Data loading components for neighbor sampling"""
|
||||
from .. import backend as F
|
||||
from ..base import EID, NID
|
||||
from ..heterograph import DGLGraph
|
||||
from ..transforms import to_block
|
||||
from .base import BlockSampler
|
||||
|
||||
@@ -54,6 +56,9 @@ class NeighborSampler(BlockSampler):
|
||||
output_device : device, optional
|
||||
The device of the output subgraphs or MFGs. Default is the same as the
|
||||
minibatch of seed nodes.
|
||||
fused : bool, default True
|
||||
If True and device is CPU fused sample neighbors is invoked. This version
|
||||
requires seed_nodes to be unique
|
||||
|
||||
Examples
|
||||
--------
|
||||
@@ -120,6 +125,7 @@ class NeighborSampler(BlockSampler):
|
||||
prefetch_labels=None,
|
||||
prefetch_edge_feats=None,
|
||||
output_device=None,
|
||||
fused=True,
|
||||
):
|
||||
super().__init__(
|
||||
prefetch_node_feats=prefetch_node_feats,
|
||||
@@ -137,10 +143,43 @@ class NeighborSampler(BlockSampler):
|
||||
)
|
||||
self.prob = prob or mask
|
||||
self.replace = replace
|
||||
self.fused = fused
|
||||
self.mapping = {}
|
||||
self.g = None
|
||||
|
||||
def sample_blocks(self, g, seed_nodes, exclude_eids=None):
|
||||
output_nodes = seed_nodes
|
||||
blocks = []
|
||||
|
||||
if self.fused:
|
||||
cpu = F.device_type(g.device) == "cpu"
|
||||
if isinstance(seed_nodes, dict):
|
||||
for ntype in list(seed_nodes.keys()):
|
||||
if not cpu:
|
||||
break
|
||||
cpu = (
|
||||
cpu and F.device_type(seed_nodes[ntype].device) == "cpu"
|
||||
)
|
||||
else:
|
||||
cpu = cpu and F.device_type(seed_nodes.device) == "cpu"
|
||||
if cpu and isinstance(g, DGLGraph) and F.backend_name == "pytorch":
|
||||
if self.g != g:
|
||||
self.mapping = {}
|
||||
self.g = g
|
||||
for fanout in reversed(self.fanouts):
|
||||
block = g.sample_neighbors_fused(
|
||||
seed_nodes,
|
||||
fanout,
|
||||
edge_dir=self.edge_dir,
|
||||
prob=self.prob,
|
||||
replace=self.replace,
|
||||
exclude_edges=exclude_eids,
|
||||
mapping=self.mapping,
|
||||
)
|
||||
seed_nodes = block.srcdata[NID]
|
||||
blocks.insert(0, block)
|
||||
return seed_nodes, output_nodes, blocks
|
||||
|
||||
for fanout in reversed(self.fanouts):
|
||||
frontier = g.sample_neighbors(
|
||||
seed_nodes,
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""Neighbor sampling APIs"""
|
||||
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from .. import backend as F, ndarray as nd, utils
|
||||
from .._ffi.function import _init_api
|
||||
from ..base import DGLError, EID
|
||||
from ..heterograph import DGLGraph
|
||||
from ..heterograph import DGLBlock, DGLGraph
|
||||
from .utils import EidExcluder
|
||||
|
||||
__all__ = [
|
||||
"sample_etype_neighbors",
|
||||
"sample_neighbors",
|
||||
"sample_neighbors_fused",
|
||||
"sample_neighbors_biased",
|
||||
"select_topk",
|
||||
]
|
||||
@@ -379,6 +384,126 @@ def sample_neighbors(
|
||||
return frontier if output_device is None else frontier.to(output_device)
|
||||
|
||||
|
||||
def sample_neighbors_fused(
|
||||
g,
|
||||
nodes,
|
||||
fanout,
|
||||
edge_dir="in",
|
||||
prob=None,
|
||||
replace=False,
|
||||
copy_ndata=True,
|
||||
copy_edata=True,
|
||||
exclude_edges=None,
|
||||
mapping=None,
|
||||
):
|
||||
"""Sample neighboring edges of the given nodes and return the induced subgraph.
|
||||
|
||||
For each node, a number of inbound (or outbound when ``edge_dir == 'out'``) edges
|
||||
will be randomly chosen. The graph returned will then contain all the nodes in the
|
||||
original graph, but only the sampled edges. Nodes will be renumbered starting from id 0,
|
||||
which would be new node id of first seed node.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
g : DGLGraph
|
||||
The graph. Can be either on CPU or GPU.
|
||||
nodes : tensor or dict
|
||||
Node IDs to sample neighbors from.
|
||||
|
||||
This argument can take a single ID tensor or a dictionary of node types and ID tensors.
|
||||
If a single tensor is given, the graph must only have one type of nodes.
|
||||
fanout : int or dict[etype, int]
|
||||
The number of edges to be sampled for each node on each edge type.
|
||||
|
||||
This argument can take a single int or a dictionary of edge types and ints.
|
||||
If a single int is given, DGL will sample this number of edges for each node for
|
||||
every edge type.
|
||||
|
||||
If -1 is given for a single edge type, all the neighboring edges with that edge
|
||||
type and non-zero probability will be selected.
|
||||
edge_dir : str, optional
|
||||
Determines whether to sample inbound or outbound edges.
|
||||
|
||||
Can take either ``in`` for inbound edges or ``out`` for outbound edges.
|
||||
prob : str, optional
|
||||
Feature name used as the (unnormalized) probabilities associated with each
|
||||
neighboring edge of a node. The feature must have only one element for each
|
||||
edge.
|
||||
|
||||
The features must be non-negative floats or boolean. Otherwise, the result
|
||||
will be undefined.
|
||||
exclude_edges: tensor or dict
|
||||
Edge IDs to exclude during sampling neighbors for the seed nodes.
|
||||
|
||||
This argument can take a single ID tensor or a dictionary of edge types and ID tensors.
|
||||
If a single tensor is given, the graph must only have one type of nodes.
|
||||
replace : bool, optional
|
||||
If True, sample with replacement.
|
||||
copy_ndata: bool, optional
|
||||
If True, the node features of the new graph are copied from
|
||||
the original graph. If False, the new graph will not have any
|
||||
node features.
|
||||
|
||||
(Default: True)
|
||||
copy_edata: bool, optional
|
||||
If True, the edge features of the new graph are copied from
|
||||
the original graph. If False, the new graph will not have any
|
||||
edge features.
|
||||
|
||||
(Default: False)
|
||||
|
||||
mapping : dictionary, optional
|
||||
Used by fused version of NeighborSampler. To avoid constant data allocation
|
||||
provide empty dictionary ({}) that will be allocated once with proper data and reused
|
||||
by each function call
|
||||
|
||||
(Default: None)
|
||||
Returns
|
||||
-------
|
||||
DGLGraph
|
||||
A sampled subgraph containing only the sampled neighboring edges.
|
||||
|
||||
Notes
|
||||
-----
|
||||
If :attr:`copy_ndata` or :attr:`copy_edata` is True, same tensors are used as
|
||||
the node or edge features of the original graph and the new graph.
|
||||
As a result, users should avoid performing in-place operations
|
||||
on the node features of the new graph to avoid feature corruption.
|
||||
|
||||
"""
|
||||
if not g.is_pinned():
|
||||
frontier = _sample_neighbors(
|
||||
g,
|
||||
nodes,
|
||||
fanout,
|
||||
edge_dir=edge_dir,
|
||||
prob=prob,
|
||||
replace=replace,
|
||||
copy_ndata=copy_ndata,
|
||||
copy_edata=copy_edata,
|
||||
exclude_edges=exclude_edges,
|
||||
fused=True,
|
||||
mapping=mapping,
|
||||
)
|
||||
else:
|
||||
frontier = _sample_neighbors(
|
||||
g,
|
||||
nodes,
|
||||
fanout,
|
||||
edge_dir=edge_dir,
|
||||
prob=prob,
|
||||
replace=replace,
|
||||
copy_ndata=copy_ndata,
|
||||
copy_edata=copy_edata,
|
||||
fused=True,
|
||||
mapping=mapping,
|
||||
)
|
||||
if exclude_edges is not None:
|
||||
eid_excluder = EidExcluder(exclude_edges)
|
||||
frontier = eid_excluder(frontier)
|
||||
return frontier
|
||||
|
||||
|
||||
def _sample_neighbors(
|
||||
g,
|
||||
nodes,
|
||||
@@ -390,6 +515,8 @@ def _sample_neighbors(
|
||||
copy_edata=True,
|
||||
_dist_training=False,
|
||||
exclude_edges=None,
|
||||
fused=False,
|
||||
mapping=None,
|
||||
):
|
||||
if not isinstance(nodes, dict):
|
||||
if len(g.ntypes) > 1:
|
||||
@@ -446,17 +573,64 @@ def _sample_neighbors(
|
||||
else:
|
||||
excluded_edges_all_t.append(nd.array([], ctx=ctx))
|
||||
|
||||
subgidx = _CAPI_DGLSampleNeighbors(
|
||||
g._graph,
|
||||
nodes_all_types,
|
||||
fanout_array,
|
||||
edge_dir,
|
||||
prob_arrays,
|
||||
excluded_edges_all_t,
|
||||
replace,
|
||||
)
|
||||
induced_edges = subgidx.induced_edges
|
||||
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
|
||||
if fused:
|
||||
if _dist_training:
|
||||
raise DGLError(
|
||||
"distributed training not supported in fused sampling"
|
||||
)
|
||||
cpu = F.device_type(g.device) == "cpu"
|
||||
if isinstance(nodes, dict):
|
||||
for ntype in list(nodes.keys()):
|
||||
if not cpu:
|
||||
break
|
||||
cpu = cpu and F.device_type(nodes[ntype].device) == "cpu"
|
||||
else:
|
||||
cpu = cpu and F.device_type(nodes.device) == "cpu"
|
||||
if not cpu or F.backend_name != "pytorch":
|
||||
raise DGLError(
|
||||
"Only PyTorch backend and cpu is supported in fused sampling"
|
||||
)
|
||||
|
||||
if mapping is None:
|
||||
mapping = {}
|
||||
mapping_name = "__mapping" + str(os.getpid())
|
||||
if mapping_name not in mapping.keys():
|
||||
mapping[mapping_name] = [
|
||||
torch.LongTensor(g.num_nodes(ntype)).fill_(-1)
|
||||
for ntype in g.ntypes
|
||||
]
|
||||
|
||||
subgidx, induced_nodes, induced_edges = _CAPI_DGLSampleNeighborsFused(
|
||||
g._graph,
|
||||
nodes_all_types,
|
||||
[F.to_dgl_nd(m) for m in mapping[mapping_name]],
|
||||
fanout_array,
|
||||
edge_dir,
|
||||
prob_arrays,
|
||||
excluded_edges_all_t,
|
||||
replace,
|
||||
)
|
||||
for mapping_vector, src_nodes in zip(
|
||||
mapping[mapping_name], induced_nodes
|
||||
):
|
||||
mapping_vector[F.from_dgl_nd(src_nodes).type(F.int64)] = -1
|
||||
|
||||
new_ntypes = (g.ntypes, g.ntypes)
|
||||
ret = DGLBlock(subgidx, new_ntypes, g.etypes)
|
||||
assert ret.is_unibipartite
|
||||
|
||||
else:
|
||||
subgidx = _CAPI_DGLSampleNeighbors(
|
||||
g._graph,
|
||||
nodes_all_types,
|
||||
fanout_array,
|
||||
edge_dir,
|
||||
prob_arrays,
|
||||
excluded_edges_all_t,
|
||||
replace,
|
||||
)
|
||||
ret = DGLGraph(subgidx.graph, g.ntypes, g.etypes)
|
||||
induced_edges = subgidx.induced_edges
|
||||
|
||||
# handle features
|
||||
# (TODO) (BarclayII) DGL distributed fails with bus error, freezes, or other
|
||||
@@ -465,12 +639,31 @@ def _sample_neighbors(
|
||||
# only set the edge IDs.
|
||||
if not _dist_training:
|
||||
if copy_ndata:
|
||||
node_frames = utils.extract_node_subframes(g, device)
|
||||
utils.set_new_frames(ret, node_frames=node_frames)
|
||||
if fused:
|
||||
src_node_ids = [F.from_dgl_nd(src) for src in induced_nodes]
|
||||
dst_node_ids = [
|
||||
utils.toindex(
|
||||
nodes.get(ntype, []), g._idtype_str
|
||||
).tousertensor(ctx=F.to_backend_ctx(g._graph.ctx))
|
||||
for ntype in g.ntypes
|
||||
]
|
||||
node_frames = utils.extract_node_subframes_for_block(
|
||||
g, src_node_ids, dst_node_ids
|
||||
)
|
||||
utils.set_new_frames(ret, node_frames=node_frames)
|
||||
else:
|
||||
node_frames = utils.extract_node_subframes(g, device)
|
||||
utils.set_new_frames(ret, node_frames=node_frames)
|
||||
|
||||
if copy_edata:
|
||||
edge_frames = utils.extract_edge_subframes(g, induced_edges)
|
||||
utils.set_new_frames(ret, edge_frames=edge_frames)
|
||||
if fused:
|
||||
edge_ids = [F.from_dgl_nd(eid) for eid in induced_edges]
|
||||
edge_frames = utils.extract_edge_subframes(g, edge_ids)
|
||||
utils.set_new_frames(ret, edge_frames=edge_frames)
|
||||
else:
|
||||
edge_frames = utils.extract_edge_subframes(g, induced_edges)
|
||||
utils.set_new_frames(ret, edge_frames=edge_frames)
|
||||
|
||||
else:
|
||||
for i, etype in enumerate(ret.canonical_etypes):
|
||||
ret.edges[etype].data[EID] = induced_edges[i]
|
||||
@@ -479,6 +672,7 @@ def _sample_neighbors(
|
||||
|
||||
|
||||
DGLGraph.sample_neighbors = utils.alias_func(sample_neighbors)
|
||||
DGLGraph.sample_neighbors_fused = utils.alias_func(sample_neighbors_fused)
|
||||
|
||||
|
||||
def sample_neighbors_biased(
|
||||
|
||||
@@ -597,6 +597,47 @@ COOMatrix CSRRowWiseSampling(
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename IdType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdType>* new_seed_nodes, int64_t num_samples,
|
||||
NDArray prob_or_mask, bool replace) {
|
||||
std::pair<CSRMatrix, IdArray> ret;
|
||||
if (IsNullArray(prob_or_mask)) {
|
||||
ATEN_XPU_SWITCH(
|
||||
rows->ctx.device_type, XPU, "CSRRowWiseSamplingUniformFused", {
|
||||
ret =
|
||||
impl::CSRRowWiseSamplingUniformFused<XPU, IdType, map_seed_nodes>(
|
||||
mat, rows, seed_mapping, new_seed_nodes, num_samples,
|
||||
replace);
|
||||
});
|
||||
} else {
|
||||
CHECK_VALID_CONTEXT(prob_or_mask, rows);
|
||||
ATEN_XPU_SWITCH(rows->ctx.device_type, XPU, "CSRRowWiseSamplingFused", {
|
||||
ATEN_FLOAT_INT8_UINT8_TYPE_SWITCH(
|
||||
prob_or_mask->dtype, FloatType, "probability or mask", {
|
||||
ret = impl::CSRRowWiseSamplingFused<
|
||||
XPU, IdType, FloatType, map_seed_nodes>(
|
||||
mat, rows, seed_mapping, new_seed_nodes, num_samples,
|
||||
prob_or_mask, replace);
|
||||
});
|
||||
});
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
|
||||
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int64_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
|
||||
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
|
||||
template std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused<int32_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
|
||||
COOMatrix CSRRowWisePerEtypeSampling(
|
||||
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
|
||||
const std::vector<int64_t>& num_samples,
|
||||
|
||||
@@ -178,6 +178,14 @@ COOMatrix CSRRowWiseSampling(
|
||||
CSRMatrix mat, IdArray rows, int64_t num_samples, NDArray prob_or_mask,
|
||||
bool replace);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <
|
||||
DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
|
||||
NDArray prob_or_mask, bool replace);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
COOMatrix CSRRowWisePerEtypeSampling(
|
||||
@@ -190,6 +198,11 @@ template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRRowWiseSamplingUniform(
|
||||
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
|
||||
|
||||
template <DGLDeviceType XPU, typename IdType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdType>* new_seed_nodes, int64_t num_samples, bool replace);
|
||||
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRRowWisePerEtypeSamplingUniform(
|
||||
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
|
||||
|
||||
@@ -223,5 +223,27 @@ ConcurrentIdHashMap<IdType>::AttemptInsertAt(int64_t pos, IdType key) {
|
||||
template class ConcurrentIdHashMap<int32_t>;
|
||||
template class ConcurrentIdHashMap<int64_t>;
|
||||
|
||||
template <typename IdType>
|
||||
bool BoolCompareAndSwap(IdType* ptr) {
|
||||
#ifdef _MSC_VER
|
||||
if (sizeof(IdType) == 4) {
|
||||
return _InterlockedCompareExchange(reinterpret_cast<LONG*>(ptr), 0, -1) ==
|
||||
-1;
|
||||
} else if (sizeof(IdType) == 8) {
|
||||
return _InterlockedCompareExchange64(
|
||||
reinterpret_cast<LONGLONG*>(ptr), 0, -1) == -1;
|
||||
} else {
|
||||
LOG(FATAL) << "ID can only be int32 or int64";
|
||||
}
|
||||
#elif __GNUC__ // _MSC_VER
|
||||
return __sync_bool_compare_and_swap(ptr, -1, 0);
|
||||
#else // _MSC_VER
|
||||
#error "CompareAndSwap is not supported on this platform."
|
||||
#endif // _MSC_VER
|
||||
}
|
||||
|
||||
template bool BoolCompareAndSwap<int32_t>(int32_t*);
|
||||
template bool BoolCompareAndSwap<int64_t>(int64_t*);
|
||||
|
||||
} // namespace aten
|
||||
} // namespace dgl
|
||||
|
||||
@@ -195,6 +195,9 @@ class ConcurrentIdHashMap {
|
||||
IdType mask_;
|
||||
};
|
||||
|
||||
template <typename IdType>
|
||||
bool BoolCompareAndSwap(IdType* ptr);
|
||||
|
||||
} // namespace aten
|
||||
} // namespace dgl
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
namespace dgl {
|
||||
@@ -94,6 +95,115 @@ using EtypeRangePickFn = std::function<void(
|
||||
const std::vector<IdxType>& et_idx, const std::vector<IdxType>& et_eid,
|
||||
const IdxType* eid, IdxType* out_idx)>;
|
||||
|
||||
template <typename IdxType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWisePickFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdxType>* new_seed_nodes, int64_t num_picks, bool replace,
|
||||
PickFn<IdxType> pick_fn, NumPicksFn<IdxType> num_picks_fn) {
|
||||
using namespace aten;
|
||||
|
||||
const IdxType* indptr = static_cast<IdxType*>(mat.indptr->data);
|
||||
const IdxType* indices = static_cast<IdxType*>(mat.indices->data);
|
||||
const IdxType* data =
|
||||
CSRHasData(mat) ? static_cast<IdxType*>(mat.data->data) : nullptr;
|
||||
const IdxType* rows_data = static_cast<IdxType*>(rows->data);
|
||||
const int64_t num_rows = rows->shape[0];
|
||||
const auto& ctx = mat.indptr->ctx;
|
||||
const auto& idtype = mat.indptr->dtype;
|
||||
IdxType* seed_mapping_data = nullptr;
|
||||
if (map_seed_nodes) seed_mapping_data = seed_mapping.Ptr<IdxType>();
|
||||
|
||||
const int num_threads = runtime::compute_num_threads(0, num_rows, 1);
|
||||
std::vector<int64_t> global_prefix(num_threads + 1, 0);
|
||||
|
||||
IdArray picked_col, picked_idx, picked_coo_rows;
|
||||
|
||||
IdArray block_csr_indptr = IdArray::Empty({num_rows + 1}, idtype, ctx);
|
||||
IdxType* block_csr_indptr_data = block_csr_indptr.Ptr<IdxType>();
|
||||
|
||||
#pragma omp parallel num_threads(num_threads)
|
||||
{
|
||||
const int thread_id = omp_get_thread_num();
|
||||
|
||||
const int64_t start_i =
|
||||
thread_id * (num_rows / num_threads) +
|
||||
std::min(static_cast<int64_t>(thread_id), num_rows % num_threads);
|
||||
const int64_t end_i =
|
||||
(thread_id + 1) * (num_rows / num_threads) +
|
||||
std::min(static_cast<int64_t>(thread_id + 1), num_rows % num_threads);
|
||||
assert(thread_id + 1 < num_threads || end_i == num_rows);
|
||||
|
||||
const int64_t num_local = end_i - start_i;
|
||||
|
||||
std::unique_ptr<int64_t[]> local_prefix(new int64_t[num_local + 1]);
|
||||
local_prefix[0] = 0;
|
||||
for (int64_t i = start_i; i < end_i; ++i) {
|
||||
// build prefix-sum
|
||||
const int64_t local_i = i - start_i;
|
||||
const IdxType rid = rows_data[i];
|
||||
if (map_seed_nodes) seed_mapping_data[rid] = i;
|
||||
|
||||
IdxType len = num_picks_fn(
|
||||
rid, indptr[rid], indptr[rid + 1] - indptr[rid], indices, data);
|
||||
local_prefix[local_i + 1] = local_prefix[local_i] + len;
|
||||
}
|
||||
global_prefix[thread_id + 1] = local_prefix[num_local];
|
||||
|
||||
#pragma omp barrier
|
||||
#pragma omp master
|
||||
{
|
||||
for (int t = 0; t < num_threads; ++t) {
|
||||
global_prefix[t + 1] += global_prefix[t];
|
||||
}
|
||||
picked_col = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
|
||||
picked_idx = IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
|
||||
picked_coo_rows =
|
||||
IdArray::Empty({global_prefix[num_threads]}, idtype, ctx);
|
||||
}
|
||||
|
||||
#pragma omp barrier
|
||||
IdxType* picked_cdata = picked_col.Ptr<IdxType>();
|
||||
IdxType* picked_idata = picked_idx.Ptr<IdxType>();
|
||||
IdxType* picked_rows = picked_coo_rows.Ptr<IdxType>();
|
||||
|
||||
const IdxType thread_offset = global_prefix[thread_id];
|
||||
|
||||
for (int64_t i = start_i; i < end_i; ++i) {
|
||||
const IdxType rid = rows_data[i];
|
||||
const int64_t local_i = i - start_i;
|
||||
block_csr_indptr_data[i] = local_prefix[local_i] + thread_offset;
|
||||
|
||||
const IdxType off = indptr[rid];
|
||||
const IdxType len = indptr[rid + 1] - off;
|
||||
if (len == 0) continue;
|
||||
|
||||
const int64_t row_offset = local_prefix[local_i] + thread_offset;
|
||||
const int64_t num_picks =
|
||||
local_prefix[local_i + 1] + thread_offset - row_offset;
|
||||
|
||||
pick_fn(
|
||||
rid, off, len, num_picks, indices, data, picked_idata + row_offset);
|
||||
for (int64_t j = 0; j < num_picks; ++j) {
|
||||
const IdxType picked = picked_idata[row_offset + j];
|
||||
picked_cdata[row_offset + j] = indices[picked];
|
||||
picked_idata[row_offset + j] = data ? data[picked] : picked;
|
||||
picked_rows[row_offset + j] = i;
|
||||
}
|
||||
}
|
||||
}
|
||||
block_csr_indptr_data[num_rows] = global_prefix.back();
|
||||
|
||||
const IdxType num_cols = picked_col->shape[0];
|
||||
if (map_seed_nodes) {
|
||||
(*new_seed_nodes).resize(num_rows);
|
||||
memcpy((*new_seed_nodes).data(), rows_data, sizeof(IdxType) * num_rows);
|
||||
}
|
||||
|
||||
return std::make_pair(
|
||||
CSRMatrix(num_rows, num_cols, block_csr_indptr, picked_col, picked_idx),
|
||||
picked_coo_rows);
|
||||
}
|
||||
|
||||
// Template for picking non-zero values row-wise. The implementation utilizes
|
||||
// OpenMP parallelization on rows because each row performs computation
|
||||
// independently.
|
||||
|
||||
@@ -225,6 +225,74 @@ template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, uint8_t>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, uint8_t>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
|
||||
template <
|
||||
DGLDeviceType XPU, typename IdxType, typename DType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdxType>* new_seed_nodes, int64_t num_samples,
|
||||
NDArray prob_or_mask, bool replace) {
|
||||
// If num_samples is -1, select all neighbors without replacement.
|
||||
replace = (replace && num_samples != -1);
|
||||
CHECK(prob_or_mask.defined());
|
||||
auto num_picks_fn =
|
||||
GetSamplingNumPicksFn<IdxType, DType>(num_samples, prob_or_mask, replace);
|
||||
auto pick_fn =
|
||||
GetSamplingPickFn<IdxType, DType>(num_samples, prob_or_mask, replace);
|
||||
return CSRRowWisePickFused<IdxType, map_seed_nodes>(
|
||||
mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
|
||||
num_picks_fn);
|
||||
}
|
||||
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, float, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, float, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, double, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, double, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, int8_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, int8_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int32_t, uint8_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, NDArray, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingFused<kDGLCPU, int64_t, uint8_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, NDArray, bool);
|
||||
|
||||
template <DGLDeviceType XPU, typename IdxType, typename DType>
|
||||
COOMatrix CSRRowWisePerEtypeSampling(
|
||||
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
|
||||
@@ -283,6 +351,33 @@ template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
|
||||
CSRMatrix, IdArray, int64_t, bool);
|
||||
|
||||
template <DGLDeviceType XPU, typename IdxType, bool map_seed_nodes>
|
||||
std::pair<CSRMatrix, IdArray> CSRRowWiseSamplingUniformFused(
|
||||
CSRMatrix mat, IdArray rows, IdArray seed_mapping,
|
||||
std::vector<IdxType>* new_seed_nodes, int64_t num_samples, bool replace) {
|
||||
// If num_samples is -1, select all neighbors without replacement.
|
||||
replace = (replace && num_samples != -1);
|
||||
auto num_picks_fn =
|
||||
GetSamplingUniformNumPicksFn<IdxType>(num_samples, replace);
|
||||
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
|
||||
return CSRRowWisePickFused<IdxType, map_seed_nodes>(
|
||||
mat, rows, seed_mapping, new_seed_nodes, num_samples, replace, pick_fn,
|
||||
num_picks_fn);
|
||||
}
|
||||
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, true>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingUniformFused<kDGLCPU, int32_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int32_t>*, int64_t, bool);
|
||||
template std::pair<CSRMatrix, IdArray>
|
||||
CSRRowWiseSamplingUniformFused<kDGLCPU, int64_t, false>(
|
||||
CSRMatrix, IdArray, IdArray, std::vector<int64_t>*, int64_t, bool);
|
||||
|
||||
template <DGLDeviceType XPU, typename IdxType>
|
||||
COOMatrix CSRRowWisePerEtypeSamplingUniform(
|
||||
CSRMatrix mat, IdArray rows, const std::vector<int64_t>& eid2etype_offset,
|
||||
|
||||
@@ -6,13 +6,16 @@
|
||||
|
||||
#include <dgl/array.h>
|
||||
#include <dgl/aten/macro.h>
|
||||
#include <dgl/immutable_graph.h>
|
||||
#include <dgl/packed_func_ext.h>
|
||||
#include <dgl/runtime/container.h>
|
||||
#include <dgl/runtime/parallel_for.h>
|
||||
#include <dgl/sampling/neighbor.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
|
||||
#include "../../../array/cpu/concurrent_id_hash_map.h"
|
||||
#include "../../../c_api_common.h"
|
||||
#include "../../unit_graph.h"
|
||||
|
||||
@@ -22,6 +25,76 @@ using namespace dgl::aten;
|
||||
namespace dgl {
|
||||
namespace sampling {
|
||||
|
||||
template <typename IdType>
|
||||
void ExcludeCertainEdgesFused(
|
||||
std::vector<CSRMatrix>* sampled_graphs, std::vector<IdArray>* induced_edges,
|
||||
std::vector<IdArray>* sampled_coo_rows,
|
||||
const std::vector<IdArray>& exclude_edges,
|
||||
std::vector<FloatArray>* weights = nullptr) {
|
||||
int etypes = (*sampled_graphs).size();
|
||||
std::vector<IdArray> remain_induced_edges(etypes);
|
||||
std::vector<IdArray> remain_indptrs(etypes);
|
||||
std::vector<IdArray> remain_indices(etypes);
|
||||
std::vector<IdArray> remain_coo_rows(etypes);
|
||||
std::vector<FloatArray> remain_weights(etypes);
|
||||
for (int etype = 0; etype < etypes; ++etype) {
|
||||
if (exclude_edges[etype].GetSize() == 0 ||
|
||||
(*sampled_graphs)[etype].num_rows == 0) {
|
||||
remain_induced_edges[etype] = (*induced_edges)[etype];
|
||||
if (weights) remain_weights[etype] = (*weights)[etype];
|
||||
continue;
|
||||
}
|
||||
const auto dtype = weights && (*weights)[etype]->shape[0]
|
||||
? (*weights)[etype]->dtype
|
||||
: DGLDataType{kDGLFloat, 8 * sizeof(float), 1};
|
||||
ATEN_FLOAT_TYPE_SWITCH(dtype, FloatType, "weights", {
|
||||
IdType* indptr = (*sampled_graphs)[etype].indptr.Ptr<IdType>();
|
||||
IdType* indices = (*sampled_graphs)[etype].indices.Ptr<IdType>();
|
||||
IdType* coo_rows = (*sampled_coo_rows)[etype].Ptr<IdType>();
|
||||
IdType* induced_edges_data = (*induced_edges)[etype].Ptr<IdType>();
|
||||
FloatType* weights_data = weights && (*weights)[etype]->shape[0]
|
||||
? (*weights)[etype].Ptr<FloatType>()
|
||||
: nullptr;
|
||||
const IdType exclude_edges_len = exclude_edges[etype]->shape[0];
|
||||
std::sort(
|
||||
exclude_edges[etype].Ptr<IdType>(),
|
||||
exclude_edges[etype].Ptr<IdType>() + exclude_edges_len);
|
||||
const IdType* exclude_edges_data = exclude_edges[etype].Ptr<IdType>();
|
||||
IdType outIndices = 0;
|
||||
for (IdType row = 0; row < (*sampled_graphs)[etype].indptr->shape[0] - 1;
|
||||
++row) {
|
||||
auto tmp_row = indptr[row];
|
||||
if (outIndices != indptr[row]) indptr[row] = outIndices;
|
||||
for (IdType col = tmp_row; col < indptr[row + 1]; ++col) {
|
||||
if (!std::binary_search(
|
||||
exclude_edges_data, exclude_edges_data + exclude_edges_len,
|
||||
induced_edges_data[col])) {
|
||||
indices[outIndices] = indices[col];
|
||||
induced_edges_data[outIndices] = induced_edges_data[col];
|
||||
coo_rows[outIndices] = coo_rows[col];
|
||||
if (weights_data) weights_data[outIndices] = weights_data[col];
|
||||
++outIndices;
|
||||
}
|
||||
}
|
||||
}
|
||||
indptr[(*sampled_graphs)[etype].indptr->shape[0] - 1] = outIndices;
|
||||
remain_induced_edges[etype] =
|
||||
aten::IndexSelect((*induced_edges)[etype], 0, outIndices);
|
||||
remain_weights[etype] =
|
||||
weights_data ? aten::IndexSelect((*weights)[etype], 0, outIndices)
|
||||
: NullArray();
|
||||
remain_indices[etype] =
|
||||
aten::IndexSelect((*sampled_graphs)[etype].indices, 0, outIndices);
|
||||
(*sampled_coo_rows)[etype] =
|
||||
aten::IndexSelect((*sampled_coo_rows)[etype], 0, outIndices);
|
||||
(*sampled_graphs)[etype] = CSRMatrix(
|
||||
(*sampled_graphs)[etype].num_rows, outIndices,
|
||||
(*sampled_graphs)[etype].indptr, remain_indices[etype],
|
||||
remain_induced_edges[etype]);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
std::pair<HeteroSubgraph, std::vector<FloatArray>> ExcludeCertainEdges(
|
||||
const HeteroSubgraph& sg, const std::vector<IdArray>& exclude_edges,
|
||||
const std::vector<FloatArray>* weights = nullptr) {
|
||||
@@ -266,6 +339,242 @@ HeteroSubgraph SampleNeighbors(
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename IdType>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
SampleNeighborsFused(
|
||||
const HeteroGraphPtr hg, const std::vector<IdArray>& nodes,
|
||||
const std::vector<IdArray>& mapping, const std::vector<int64_t>& fanouts,
|
||||
EdgeDir dir, const std::vector<NDArray>& prob_or_mask,
|
||||
const std::vector<IdArray>& exclude_edges, bool replace) {
|
||||
CHECK_EQ(nodes.size(), hg->NumVertexTypes())
|
||||
<< "Number of node ID tensors must match the number of node types.";
|
||||
CHECK_EQ(fanouts.size(), hg->NumEdgeTypes())
|
||||
<< "Number of fanout values must match the number of edge types.";
|
||||
CHECK_EQ(prob_or_mask.size(), hg->NumEdgeTypes())
|
||||
<< "Number of probability tensors must match the number of edge types.";
|
||||
|
||||
DGLContext ctx = aten::GetContextOf(nodes);
|
||||
|
||||
std::vector<CSRMatrix> sampled_graphs;
|
||||
std::vector<IdArray> sampled_coo_rows;
|
||||
std::vector<IdArray> induced_edges;
|
||||
std::vector<IdArray> induced_vertices;
|
||||
std::vector<int64_t> num_nodes_per_type;
|
||||
std::vector<std::vector<IdType>> new_nodes_vec(hg->NumVertexTypes());
|
||||
std::vector<int> seed_nodes_mapped(hg->NumVertexTypes(), 0);
|
||||
|
||||
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
|
||||
auto pair = hg->meta_graph()->FindEdge(etype);
|
||||
const dgl_type_t src_vtype = pair.first;
|
||||
const dgl_type_t dst_vtype = pair.second;
|
||||
const dgl_type_t rhs_node_type =
|
||||
(dir == EdgeDir::kOut) ? src_vtype : dst_vtype;
|
||||
const IdArray nodes_ntype = nodes[rhs_node_type];
|
||||
const int64_t num_nodes = nodes_ntype->shape[0];
|
||||
|
||||
if (num_nodes == 0 || fanouts[etype] == 0) {
|
||||
// Nothing to sample for this etype, create a placeholder
|
||||
sampled_graphs.push_back(CSRMatrix());
|
||||
sampled_coo_rows.push_back(IdArray());
|
||||
induced_edges.push_back(aten::NullArray(hg->DataType(), ctx));
|
||||
} else {
|
||||
bool map_seed_nodes = !seed_nodes_mapped[rhs_node_type];
|
||||
// sample from one relation graph
|
||||
std::pair<CSRMatrix, IdArray> sampled_graph;
|
||||
auto sampling_fn = map_seed_nodes
|
||||
? aten::CSRRowWiseSamplingFused<IdType, true>
|
||||
: aten::CSRRowWiseSamplingFused<IdType, false>;
|
||||
auto req_fmt = (dir == EdgeDir::kOut) ? CSR_CODE : CSC_CODE;
|
||||
auto avail_fmt = hg->SelectFormat(etype, req_fmt);
|
||||
switch (avail_fmt) {
|
||||
case SparseFormat::kCSR:
|
||||
CHECK(dir == EdgeDir::kOut)
|
||||
<< "Cannot sample out edges on CSC matrix.";
|
||||
// In heterographs nodes of two diffrent types can be connected
|
||||
// therefore two diffrent mappings and node vectors are needed
|
||||
sampled_graph = sampling_fn(
|
||||
hg->GetCSRMatrix(etype), nodes_ntype, mapping[src_vtype],
|
||||
&new_nodes_vec[src_vtype], fanouts[etype], prob_or_mask[etype],
|
||||
replace);
|
||||
break;
|
||||
case SparseFormat::kCSC:
|
||||
CHECK(dir == EdgeDir::kIn) << "Cannot sample in edges on CSR matrix.";
|
||||
sampled_graph = sampling_fn(
|
||||
hg->GetCSCMatrix(etype), nodes_ntype, mapping[dst_vtype],
|
||||
&new_nodes_vec[dst_vtype], fanouts[etype], prob_or_mask[etype],
|
||||
replace);
|
||||
break;
|
||||
default:
|
||||
LOG(FATAL) << "Unsupported sparse format.";
|
||||
}
|
||||
seed_nodes_mapped[rhs_node_type]++;
|
||||
sampled_graphs.push_back(sampled_graph.first);
|
||||
if (sampled_graph.first.data.defined())
|
||||
induced_edges.push_back(sampled_graph.first.data);
|
||||
else
|
||||
induced_edges.push_back(
|
||||
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx));
|
||||
sampled_coo_rows.push_back(sampled_graph.second);
|
||||
}
|
||||
}
|
||||
|
||||
if (!exclude_edges.empty()) {
|
||||
ExcludeCertainEdgesFused<IdType>(
|
||||
&sampled_graphs, &induced_edges, &sampled_coo_rows, exclude_edges);
|
||||
for (size_t i = 0; i < hg->NumEdgeTypes(); i++) {
|
||||
if (sampled_graphs[i].data.defined())
|
||||
induced_edges[i] = std::move(sampled_graphs[i].data);
|
||||
else
|
||||
induced_edges[i] =
|
||||
aten::NullArray(DGLDataType{kDGLInt, sizeof(IdType) * 8, 1}, ctx);
|
||||
}
|
||||
}
|
||||
|
||||
// map indices
|
||||
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
|
||||
auto pair = hg->meta_graph()->FindEdge(etype);
|
||||
const dgl_type_t src_vtype = pair.first;
|
||||
const dgl_type_t dst_vtype = pair.second;
|
||||
const dgl_type_t lhs_node_type =
|
||||
(dir == EdgeDir::kIn) ? src_vtype : dst_vtype;
|
||||
if (sampled_graphs[etype].num_cols != 0) {
|
||||
auto num_cols = sampled_graphs[etype].num_cols;
|
||||
int num_threads_col = runtime::compute_num_threads(0, num_cols, 1);
|
||||
std::vector<IdType> global_prefix_col(num_threads_col + 1, 0);
|
||||
std::vector<std::vector<IdType>> src_nodes_local(num_threads_col);
|
||||
IdType* mapping_data_dst = mapping[lhs_node_type].Ptr<IdType>();
|
||||
IdType* cdata = sampled_graphs[etype].indices.Ptr<IdType>();
|
||||
#pragma omp parallel num_threads(num_threads_col)
|
||||
{
|
||||
const int thread_id = omp_get_thread_num();
|
||||
num_threads_col = omp_get_num_threads();
|
||||
|
||||
const int64_t start_i =
|
||||
thread_id * (num_cols / num_threads_col) +
|
||||
std::min(
|
||||
static_cast<int64_t>(thread_id), num_cols % num_threads_col);
|
||||
const int64_t end_i = (thread_id + 1) * (num_cols / num_threads_col) +
|
||||
std::min(
|
||||
static_cast<int64_t>(thread_id + 1),
|
||||
num_cols % num_threads_col);
|
||||
assert(thread_id + 1 < num_threads_col || end_i == num_cols);
|
||||
for (int64_t i = start_i; i < end_i; ++i) {
|
||||
int64_t picked_idx = cdata[i];
|
||||
bool spot_claimed =
|
||||
BoolCompareAndSwap<IdType>(&mapping_data_dst[picked_idx]);
|
||||
if (spot_claimed) src_nodes_local[thread_id].push_back(picked_idx);
|
||||
}
|
||||
global_prefix_col[thread_id + 1] = src_nodes_local[thread_id].size();
|
||||
|
||||
#pragma omp barrier
|
||||
#pragma omp master
|
||||
{
|
||||
global_prefix_col[0] = new_nodes_vec[lhs_node_type].size();
|
||||
for (int t = 0; t < num_threads_col; ++t) {
|
||||
global_prefix_col[t + 1] += global_prefix_col[t];
|
||||
}
|
||||
}
|
||||
|
||||
#pragma omp barrier
|
||||
int64_t mapping_shift = global_prefix_col[thread_id];
|
||||
for (size_t i = 0; i < src_nodes_local[thread_id].size(); ++i)
|
||||
mapping_data_dst[src_nodes_local[thread_id][i]] = mapping_shift + i;
|
||||
|
||||
#pragma omp barrier
|
||||
for (int64_t i = start_i; i < end_i; ++i) {
|
||||
IdType picked_idx = cdata[i];
|
||||
IdType mapped_idx = mapping_data_dst[picked_idx];
|
||||
cdata[i] = mapped_idx;
|
||||
}
|
||||
}
|
||||
IdType offset = new_nodes_vec[lhs_node_type].size();
|
||||
new_nodes_vec[lhs_node_type].resize(global_prefix_col.back());
|
||||
for (int thread_id = 0; thread_id < num_threads_col; ++thread_id) {
|
||||
memcpy(
|
||||
new_nodes_vec[lhs_node_type].data() + offset,
|
||||
&src_nodes_local[thread_id][0],
|
||||
src_nodes_local[thread_id].size() * sizeof(IdType));
|
||||
offset += src_nodes_local[thread_id].size();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// counting how many nodes of each ntype were sampled
|
||||
num_nodes_per_type.resize(2 * hg->NumVertexTypes());
|
||||
for (size_t i = 0; i < hg->NumVertexTypes(); i++) {
|
||||
num_nodes_per_type[i] = new_nodes_vec[i].size();
|
||||
num_nodes_per_type[hg->NumVertexTypes() + i] = nodes[i]->shape[0];
|
||||
induced_vertices.push_back(
|
||||
VecToIdArray(new_nodes_vec[i], sizeof(IdType) * 8));
|
||||
}
|
||||
|
||||
std::vector<HeteroGraphPtr> subrels(hg->NumEdgeTypes());
|
||||
for (dgl_type_t etype = 0; etype < hg->NumEdgeTypes(); ++etype) {
|
||||
auto pair = hg->meta_graph()->FindEdge(etype);
|
||||
const dgl_type_t src_vtype = pair.first;
|
||||
const dgl_type_t dst_vtype = pair.second;
|
||||
if (sampled_graphs[etype].num_rows == 0) {
|
||||
subrels[etype] = UnitGraph::Empty(
|
||||
2, new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],
|
||||
hg->DataType(), ctx);
|
||||
} else {
|
||||
CSRMatrix graph = sampled_graphs[etype];
|
||||
if (dir == EdgeDir::kOut) {
|
||||
subrels[etype] = UnitGraph::CreateFromCSRAndCOO(
|
||||
2,
|
||||
CSRMatrix(
|
||||
nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),
|
||||
graph.indptr, graph.indices,
|
||||
Range(
|
||||
0, graph.indices->shape[0], graph.indices->dtype.bits,
|
||||
ctx)),
|
||||
COOMatrix(
|
||||
nodes[src_vtype]->shape[0], new_nodes_vec[dst_vtype].size(),
|
||||
sampled_coo_rows[etype], graph.indices),
|
||||
ALL_CODE);
|
||||
} else {
|
||||
subrels[etype] = UnitGraph::CreateFromCSCAndCOO(
|
||||
2,
|
||||
CSRMatrix(
|
||||
nodes[dst_vtype]->shape[0], new_nodes_vec[src_vtype].size(),
|
||||
graph.indptr, graph.indices,
|
||||
Range(
|
||||
0, graph.indices->shape[0], graph.indices->dtype.bits,
|
||||
ctx)),
|
||||
COOMatrix(
|
||||
new_nodes_vec[src_vtype].size(), nodes[dst_vtype]->shape[0],
|
||||
graph.indices, sampled_coo_rows[etype]),
|
||||
ALL_CODE);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HeteroSubgraph ret;
|
||||
|
||||
const auto meta_graph = hg->meta_graph();
|
||||
const EdgeArray etypes = meta_graph->Edges("eid");
|
||||
const IdArray new_dst = Add(etypes.dst, hg->NumVertexTypes());
|
||||
|
||||
const auto new_meta_graph = ImmutableGraph::CreateFromCOO(
|
||||
hg->NumVertexTypes() * 2, etypes.src, new_dst);
|
||||
|
||||
HeteroGraphPtr new_graph =
|
||||
CreateHeteroGraph(new_meta_graph, subrels, num_nodes_per_type);
|
||||
return std::make_tuple(new_graph, induced_edges, induced_vertices);
|
||||
}
|
||||
|
||||
template std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
SampleNeighborsFused<int64_t>(
|
||||
const HeteroGraphPtr, const std::vector<IdArray>&,
|
||||
const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,
|
||||
const std::vector<NDArray>&, const std::vector<IdArray>&, bool);
|
||||
|
||||
template std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
SampleNeighborsFused<int32_t>(
|
||||
const HeteroGraphPtr, const std::vector<IdArray>&,
|
||||
const std::vector<IdArray>&, const std::vector<int64_t>&, EdgeDir,
|
||||
const std::vector<NDArray>&, const std::vector<IdArray>&, bool);
|
||||
|
||||
HeteroSubgraph SampleNeighborsEType(
|
||||
const HeteroGraphPtr hg, const IdArray nodes,
|
||||
const std::vector<int64_t>& eid2etype_offset,
|
||||
@@ -568,6 +877,47 @@ DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighbors")
|
||||
*rv = HeteroSubgraphRef(subg);
|
||||
});
|
||||
|
||||
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsFused")
|
||||
.set_body([](DGLArgs args, DGLRetValue* rv) {
|
||||
HeteroGraphRef hg = args[0];
|
||||
const auto& nodes = ListValueToVector<IdArray>(args[1]);
|
||||
auto mapping = ListValueToVector<IdArray>(args[2]);
|
||||
IdArray fanouts_array = args[3];
|
||||
const auto& fanouts = fanouts_array.ToVector<int64_t>();
|
||||
const std::string dir_str = args[4];
|
||||
const auto& prob_or_mask = ListValueToVector<NDArray>(args[5]);
|
||||
const auto& exclude_edges = ListValueToVector<IdArray>(args[6]);
|
||||
const bool replace = args[7];
|
||||
|
||||
CHECK(dir_str == "in" || dir_str == "out")
|
||||
<< "Invalid edge direction. Must be \"in\" or \"out\".";
|
||||
EdgeDir dir = (dir_str == "in") ? EdgeDir::kIn : EdgeDir::kOut;
|
||||
|
||||
HeteroGraphPtr new_graph;
|
||||
std::vector<IdArray> induced_edges;
|
||||
std::vector<IdArray> induced_vertices;
|
||||
|
||||
ATEN_ID_TYPE_SWITCH(hg->DataType(), IdType, {
|
||||
std::tie(new_graph, induced_edges, induced_vertices) =
|
||||
SampleNeighborsFused<IdType>(
|
||||
hg.sptr(), nodes, mapping, fanouts, dir, prob_or_mask,
|
||||
exclude_edges, replace);
|
||||
});
|
||||
|
||||
List<Value> lhs_nodes_ref;
|
||||
for (IdArray& array : induced_vertices)
|
||||
lhs_nodes_ref.push_back(Value(MakeValue(array)));
|
||||
List<Value> induced_edges_ref;
|
||||
for (IdArray& array : induced_edges)
|
||||
induced_edges_ref.push_back(Value(MakeValue(array)));
|
||||
List<ObjectRef> ret;
|
||||
ret.push_back(HeteroGraphRef(new_graph));
|
||||
ret.push_back(lhs_nodes_ref);
|
||||
ret.push_back(induced_edges_ref);
|
||||
|
||||
*rv = ret;
|
||||
});
|
||||
|
||||
DGL_REGISTER_GLOBAL("sampling.neighbor._CAPI_DGLSampleNeighborsTopk")
|
||||
.set_body([](DGLArgs args, DGLRetValue* rv) {
|
||||
HeteroGraphRef hg = args[0];
|
||||
|
||||
@@ -1218,6 +1218,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSR(
|
||||
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csr, nullptr, formats));
|
||||
}
|
||||
|
||||
HeteroGraphPtr UnitGraph::CreateFromCSRAndCOO(
|
||||
int64_t num_vtypes, const aten::CSRMatrix& csr, const aten::COOMatrix& coo,
|
||||
dgl_format_code_t formats) {
|
||||
CHECK(num_vtypes == 1 || num_vtypes == 2);
|
||||
CHECK_EQ(coo.num_rows, csr.num_rows);
|
||||
CHECK_EQ(coo.num_cols, csr.num_cols);
|
||||
if (num_vtypes == 1) {
|
||||
CHECK_EQ(csr.num_rows, csr.num_cols);
|
||||
}
|
||||
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
|
||||
CSRPtr csrPtr(new CSR(mg, csr));
|
||||
COOPtr cooPtr(new COO(mg, coo));
|
||||
return HeteroGraphPtr(new UnitGraph(mg, nullptr, csrPtr, cooPtr, formats));
|
||||
}
|
||||
|
||||
HeteroGraphPtr UnitGraph::CreateFromCSC(
|
||||
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
|
||||
IdArray indices, IdArray edge_ids, dgl_format_code_t formats) {
|
||||
@@ -1237,6 +1252,21 @@ HeteroGraphPtr UnitGraph::CreateFromCSC(
|
||||
return HeteroGraphPtr(new UnitGraph(mg, csc, nullptr, nullptr, formats));
|
||||
}
|
||||
|
||||
HeteroGraphPtr UnitGraph::CreateFromCSCAndCOO(
|
||||
int64_t num_vtypes, const aten::CSRMatrix& csc, const aten::COOMatrix& coo,
|
||||
dgl_format_code_t formats) {
|
||||
CHECK(num_vtypes == 1 || num_vtypes == 2);
|
||||
CHECK_EQ(coo.num_rows, csc.num_cols);
|
||||
CHECK_EQ(coo.num_cols, csc.num_rows);
|
||||
if (num_vtypes == 1) {
|
||||
CHECK_EQ(csc.num_rows, csc.num_cols);
|
||||
}
|
||||
auto mg = CreateUnitGraphMetaGraph(num_vtypes);
|
||||
CSRPtr cscPtr(new CSR(mg, csc));
|
||||
COOPtr cooPtr(new COO(mg, coo));
|
||||
return HeteroGraphPtr(new UnitGraph(mg, cscPtr, nullptr, cooPtr, formats));
|
||||
}
|
||||
|
||||
HeteroGraphPtr UnitGraph::AsNumBits(HeteroGraphPtr g, uint8_t bits) {
|
||||
if (g->NumBits() == bits) {
|
||||
return g;
|
||||
|
||||
@@ -190,6 +190,12 @@ class UnitGraph : public BaseHeteroGraph {
|
||||
int64_t num_vtypes, const aten::CSRMatrix& mat,
|
||||
dgl_format_code_t formats = ALL_CODE);
|
||||
|
||||
/** @brief Create a graph from (out) CSR and COO arrays, both representing the
|
||||
* same graph */
|
||||
static HeteroGraphPtr CreateFromCSRAndCOO(
|
||||
int64_t num_vtypes, const aten::CSRMatrix& csr,
|
||||
const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);
|
||||
|
||||
/** @brief Create a graph from (in) CSC arrays */
|
||||
static HeteroGraphPtr CreateFromCSC(
|
||||
int64_t num_vtypes, int64_t num_src, int64_t num_dst, IdArray indptr,
|
||||
@@ -199,6 +205,12 @@ class UnitGraph : public BaseHeteroGraph {
|
||||
int64_t num_vtypes, const aten::CSRMatrix& mat,
|
||||
dgl_format_code_t formats = ALL_CODE);
|
||||
|
||||
/** @brief Create a graph from (in) CSC and COO arrays, both representing the
|
||||
* same graph */
|
||||
static HeteroGraphPtr CreateFromCSCAndCOO(
|
||||
int64_t num_vtypes, const aten::CSRMatrix& csc,
|
||||
const aten::COOMatrix& coo, dgl_format_code_t formats = ALL_CODE);
|
||||
|
||||
/** @brief Convert the graph to use the given number of bits for storage */
|
||||
static HeteroGraphPtr AsNumBits(HeteroGraphPtr g, uint8_t bits);
|
||||
|
||||
|
||||
@@ -7,6 +7,11 @@ import dgl
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
sample_neighbors_fusing_mode = {
|
||||
True: dgl.sampling.sample_neighbors_fused,
|
||||
False: dgl.sampling.sample_neighbors,
|
||||
}
|
||||
|
||||
|
||||
def check_random_walk(g, metapath, traces, ntypes, prob=None, trace_eids=None):
|
||||
traces = F.asnumpy(traces)
|
||||
@@ -555,15 +560,18 @@ def _gen_neighbor_topk_test_graph(hypersparse, reverse):
|
||||
return g, hg
|
||||
|
||||
|
||||
def _test_sample_neighbors(hypersparse, prob):
|
||||
def _test_sample_neighbors(hypersparse, prob, fused):
|
||||
g, hg = _gen_neighbor_sampling_test_graph(hypersparse, False)
|
||||
|
||||
def _test1(p, replace):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 1], -1, prob=p, replace=replace
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
|
||||
u_ans, v_ans, e_ans = g.in_edges([0, 1], form="all")
|
||||
if p is not None:
|
||||
emask = F.gather_row(g.edata[p], e_ans)
|
||||
@@ -576,12 +584,17 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
assert uv == uv_ans
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 1], 2, prob=p, replace=replace
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
|
||||
assert subg.num_edges() == 4
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
|
||||
|
||||
assert set(F.asnumpy(F.unique(v))) == {0, 1}
|
||||
assert F.array_equal(
|
||||
F.astype(g.has_edges_between(u, v), F.int64),
|
||||
@@ -600,11 +613,14 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
_test1(prob, False) # w/o replacement, uniform
|
||||
|
||||
def _test2(p, replace): # fanout > #neighbors
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 2], -1, prob=p, replace=replace
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
|
||||
u_ans, v_ans, e_ans = g.in_edges([0, 2], form="all")
|
||||
if p is not None:
|
||||
emask = F.gather_row(g.edata[p], e_ans)
|
||||
@@ -617,13 +633,16 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
assert uv == uv_ans
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 2], 2, prob=p, replace=replace
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
num_edges = 4 if replace else 3
|
||||
assert subg.num_edges() == num_edges
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.srcdata[dgl.NID][u], subg.dstdata[dgl.NID][v]
|
||||
assert set(F.asnumpy(F.unique(v))) == {0, 2}
|
||||
assert F.array_equal(
|
||||
F.astype(g.has_edges_between(u, v), F.int64),
|
||||
@@ -641,10 +660,13 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
_test2(prob, False) # w/o replacement, uniform
|
||||
|
||||
def _test3(p, replace):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
hg, {"user": [0, 1], "game": 0}, -1, prob=p, replace=replace
|
||||
)
|
||||
assert len(subg.ntypes) == 3
|
||||
if not fused:
|
||||
assert len(subg.ntypes) == 3
|
||||
assert len(subg.srctypes) == 3
|
||||
assert len(subg.dsttypes) == 3
|
||||
assert len(subg.etypes) == 4
|
||||
assert subg["follow"].num_edges() == 6 if p is None else 4
|
||||
assert subg["play"].num_edges() == 1
|
||||
@@ -652,10 +674,13 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
assert subg["flips"].num_edges() == 0
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
hg, {"user": [0, 1], "game": 0}, 2, prob=p, replace=replace
|
||||
)
|
||||
assert len(subg.ntypes) == 3
|
||||
if not fused:
|
||||
assert len(subg.ntypes) == 3
|
||||
assert len(subg.srctypes) == 3
|
||||
assert len(subg.dsttypes) == 3
|
||||
assert len(subg.etypes) == 4
|
||||
assert subg["follow"].num_edges() == 4
|
||||
assert subg["play"].num_edges() == 2 if replace else 1
|
||||
@@ -667,13 +692,16 @@ def _test_sample_neighbors(hypersparse, prob):
|
||||
|
||||
# test different fanouts for different relations
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
hg,
|
||||
{"user": [0, 1], "game": 0, "coin": 0},
|
||||
{"follow": 1, "play": 2, "liked-by": 0, "flips": -1},
|
||||
replace=True,
|
||||
)
|
||||
assert len(subg.ntypes) == 3
|
||||
if not fused:
|
||||
assert len(subg.ntypes) == 3
|
||||
assert len(subg.srctypes) == 3
|
||||
assert len(subg.dsttypes) == 3
|
||||
assert len(subg.etypes) == 4
|
||||
assert subg["follow"].num_edges() == 2
|
||||
assert subg["play"].num_edges() == 2
|
||||
@@ -795,15 +823,19 @@ def _test_sample_labors(hypersparse, prob):
|
||||
assert subg["flips"].num_edges() == 4
|
||||
|
||||
|
||||
def _test_sample_neighbors_outedge(hypersparse):
|
||||
def _test_sample_neighbors_outedge(hypersparse, fused):
|
||||
g, hg = _gen_neighbor_sampling_test_graph(hypersparse, True)
|
||||
|
||||
def _test1(p, replace):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 1], -1, prob=p, replace=replace, edge_dir="out"
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
|
||||
u_ans, v_ans, e_ans = g.out_edges([0, 1], form="all")
|
||||
if p is not None:
|
||||
emask = F.gather_row(g.edata[p], e_ans)
|
||||
@@ -816,12 +848,15 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
assert uv == uv_ans
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 1], 2, prob=p, replace=replace, edge_dir="out"
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
assert subg.num_edges() == 4
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
|
||||
assert set(F.asnumpy(F.unique(u))) == {0, 1}
|
||||
assert F.array_equal(
|
||||
F.astype(g.has_edges_between(u, v), F.int64),
|
||||
@@ -842,11 +877,14 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
_test1("prob", False) # w/o replacement
|
||||
|
||||
def _test2(p, replace): # fanout > #neighbors
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 2], -1, prob=p, replace=replace, edge_dir="out"
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
|
||||
u_ans, v_ans, e_ans = g.out_edges([0, 2], form="all")
|
||||
if p is not None:
|
||||
emask = F.gather_row(g.edata[p], e_ans)
|
||||
@@ -859,13 +897,17 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
assert uv == uv_ans
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
g, [0, 2], 2, prob=p, replace=replace, edge_dir="out"
|
||||
)
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
if not fused:
|
||||
assert subg.num_nodes() == g.num_nodes()
|
||||
num_edges = 4 if replace else 3
|
||||
assert subg.num_edges() == num_edges
|
||||
u, v = subg.edges()
|
||||
if fused:
|
||||
u, v = subg.dstdata[dgl.NID][u], subg.srcdata[dgl.NID][v]
|
||||
|
||||
assert set(F.asnumpy(F.unique(u))) == {0, 2}
|
||||
assert F.array_equal(
|
||||
F.astype(g.has_edges_between(u, v), F.int64),
|
||||
@@ -885,7 +927,7 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
_test2("prob", False) # w/o replacement
|
||||
|
||||
def _test3(p, replace):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
hg,
|
||||
{"user": [0, 1], "game": 0},
|
||||
-1,
|
||||
@@ -893,7 +935,11 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
replace=replace,
|
||||
edge_dir="out",
|
||||
)
|
||||
assert len(subg.ntypes) == 3
|
||||
|
||||
if not fused:
|
||||
assert len(subg.ntypes) == 3
|
||||
assert len(subg.srctypes) == 3
|
||||
assert len(subg.dsttypes) == 3
|
||||
assert len(subg.etypes) == 4
|
||||
assert subg["follow"].num_edges() == 6 if p is None else 4
|
||||
assert subg["play"].num_edges() == 1
|
||||
@@ -901,7 +947,7 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
assert subg["flips"].num_edges() == 0
|
||||
|
||||
for i in range(10):
|
||||
subg = dgl.sampling.sample_neighbors(
|
||||
subg = sample_neighbors_fusing_mode[fused](
|
||||
hg,
|
||||
{"user": [0, 1], "game": 0},
|
||||
2,
|
||||
@@ -909,7 +955,10 @@ def _test_sample_neighbors_outedge(hypersparse):
|
||||
replace=replace,
|
||||
edge_dir="out",
|
||||
)
|
||||
assert len(subg.ntypes) == 3
|
||||
if not fused:
|
||||
assert len(subg.ntypes) == 3
|
||||
assert len(subg.srctypes) == 3
|
||||
assert len(subg.dsttypes) == 3
|
||||
assert len(subg.etypes) == 4
|
||||
assert subg["follow"].num_edges() == 4
|
||||
assert subg["play"].num_edges() == 2 if replace else 1
|
||||
@@ -1077,7 +1126,9 @@ def _test_sample_neighbors_topk_outedge(hypersparse):
|
||||
|
||||
|
||||
def test_sample_neighbors_noprob():
|
||||
_test_sample_neighbors(False, None)
|
||||
_test_sample_neighbors(False, None, False)
|
||||
if F._default_context_str != "gpu" and F.backend_name == "pytorch":
|
||||
_test_sample_neighbors(False, None, True)
|
||||
# _test_sample_neighbors(True)
|
||||
|
||||
|
||||
@@ -1086,7 +1137,9 @@ def test_sample_labors_noprob():
|
||||
|
||||
|
||||
def test_sample_neighbors_prob():
|
||||
_test_sample_neighbors(False, "prob")
|
||||
_test_sample_neighbors(False, "prob", False)
|
||||
if F._default_context_str != "gpu" and F.backend_name == "pytorch":
|
||||
_test_sample_neighbors(False, "prob", True)
|
||||
# _test_sample_neighbors(True)
|
||||
|
||||
|
||||
@@ -1095,7 +1148,9 @@ def test_sample_labors_prob():
|
||||
|
||||
|
||||
def test_sample_neighbors_outedge():
|
||||
_test_sample_neighbors_outedge(False)
|
||||
_test_sample_neighbors_outedge(False, False)
|
||||
if F._default_context_str != "gpu" and F.backend_name == "pytorch":
|
||||
_test_sample_neighbors_outedge(False, True)
|
||||
# _test_sample_neighbors_outedge(True)
|
||||
|
||||
|
||||
@@ -1107,7 +1162,9 @@ def test_sample_neighbors_outedge():
|
||||
reason="GPU sample neighbors with mask not implemented",
|
||||
)
|
||||
def test_sample_neighbors_mask():
|
||||
_test_sample_neighbors(False, "mask")
|
||||
_test_sample_neighbors(False, "mask", False)
|
||||
if F._default_context_str != "gpu" and F.backend_name == "pytorch":
|
||||
_test_sample_neighbors(False, "mask", True)
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
@@ -1128,21 +1185,26 @@ def test_sample_neighbors_topk_outedge():
|
||||
# _test_sample_neighbors_topk_outedge(True)
|
||||
|
||||
|
||||
def test_sample_neighbors_with_0deg():
|
||||
@pytest.mark.parametrize("fused", [False, True])
|
||||
def test_sample_neighbors_with_0deg(fused):
|
||||
if fused and (
|
||||
F._default_context_str == "gpu" or F.backend_name != "pytorch"
|
||||
):
|
||||
pytest.skip("Fused sampling support CPU with backend PyTorch.")
|
||||
g = dgl.graph(([], []), num_nodes=5).to(F.ctx())
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=False
|
||||
)
|
||||
assert sg.num_edges() == 0
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="in", replace=True
|
||||
)
|
||||
assert sg.num_edges() == 0
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=False
|
||||
)
|
||||
assert sg.num_edges() == 0
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g, F.tensor([1, 2], dtype=F.int64), 2, edge_dir="out", replace=True
|
||||
)
|
||||
assert sg.num_edges() == 0
|
||||
@@ -1274,7 +1336,7 @@ def test_sample_neighbors_biased_homogeneous():
|
||||
)
|
||||
def test_sample_neighbors_biased_bipartite():
|
||||
g = create_test_graph(100, 30, True)
|
||||
num_dst = g.number_of_dst_nodes()
|
||||
num_dst = g.num_dst_nodes()
|
||||
bias = F.tensor([0, 0.01, 10, 10], dtype=F.float32)
|
||||
|
||||
def check_num(nodes, tag):
|
||||
@@ -1492,7 +1554,12 @@ def test_sample_neighbors_etype_sorted_homogeneous(format_, direction):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["int32", "int64"])
|
||||
def test_sample_neighbors_exclude_edges_heteroG(dtype):
|
||||
@pytest.mark.parametrize("fused", [False, True])
|
||||
def test_sample_neighbors_exclude_edges_heteroG(dtype, fused):
|
||||
if fused and (
|
||||
F._default_context_str == "gpu" or F.backend_name != "pytorch"
|
||||
):
|
||||
pytest.skip("Fused sampling support CPU with backend PyTorch.")
|
||||
d_i_d_u_nodes = F.zerocopy_from_numpy(
|
||||
np.unique(np.random.randint(300, size=100, dtype=dtype))
|
||||
)
|
||||
@@ -1565,7 +1632,7 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
|
||||
("drug", "treats", "disease"): excluded_d_t_d_edges,
|
||||
}
|
||||
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g,
|
||||
{
|
||||
"drug": sampled_drug_node,
|
||||
@@ -1576,37 +1643,84 @@ def test_sample_neighbors_exclude_edges_heteroG(dtype):
|
||||
exclude_edges=excluded_edges,
|
||||
)
|
||||
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
did_excluded_nodes_U,
|
||||
did_excluded_nodes_V,
|
||||
etype=("drug", "interacts", "drug"),
|
||||
if fused:
|
||||
|
||||
def contain_edge(g, sg, etype, u, v):
|
||||
# set of subgraph graph edges deduced from original graph
|
||||
org_edges = set(
|
||||
map(
|
||||
tuple,
|
||||
np.stack(
|
||||
g.find_edges(sg.edges[etype].data[dgl.EID], etype),
|
||||
axis=1,
|
||||
),
|
||||
)
|
||||
)
|
||||
# set of excluded edges
|
||||
excluded_edges = set(map(tuple, np.stack((u, v), axis=1)))
|
||||
|
||||
diff_set = org_edges - excluded_edges
|
||||
|
||||
return len(diff_set) != len(org_edges)
|
||||
|
||||
assert not contain_edge(
|
||||
g,
|
||||
sg,
|
||||
("drug", "interacts", "drug"),
|
||||
did_excluded_nodes_U,
|
||||
did_excluded_nodes_V,
|
||||
)
|
||||
assert not contain_edge(
|
||||
g,
|
||||
sg,
|
||||
("drug", "interacts", "gene"),
|
||||
dig_excluded_nodes_U,
|
||||
dig_excluded_nodes_V,
|
||||
)
|
||||
assert not contain_edge(
|
||||
g,
|
||||
sg,
|
||||
("drug", "treats", "disease"),
|
||||
dtd_excluded_nodes_U,
|
||||
dtd_excluded_nodes_V,
|
||||
)
|
||||
else:
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
did_excluded_nodes_U,
|
||||
did_excluded_nodes_V,
|
||||
etype=("drug", "interacts", "drug"),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
dig_excluded_nodes_U,
|
||||
dig_excluded_nodes_V,
|
||||
etype=("drug", "interacts", "gene"),
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
dig_excluded_nodes_U,
|
||||
dig_excluded_nodes_V,
|
||||
etype=("drug", "interacts", "gene"),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
dtd_excluded_nodes_U,
|
||||
dtd_excluded_nodes_V,
|
||||
etype=("drug", "treats", "disease"),
|
||||
assert not np.any(
|
||||
F.asnumpy(
|
||||
sg.has_edges_between(
|
||||
dtd_excluded_nodes_U,
|
||||
dtd_excluded_nodes_V,
|
||||
etype=("drug", "treats", "disease"),
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["int32", "int64"])
|
||||
def test_sample_neighbors_exclude_edges_homoG(dtype):
|
||||
@pytest.mark.parametrize("fused", [False, True])
|
||||
def test_sample_neighbors_exclude_edges_homoG(dtype, fused):
|
||||
if fused and (
|
||||
F._default_context_str == "gpu" or F.backend_name != "pytorch"
|
||||
):
|
||||
pytest.skip("Fused sampling support CPU with backend PyTorch.")
|
||||
u_nodes = F.zerocopy_from_numpy(
|
||||
np.unique(np.random.randint(300, size=100, dtype=dtype))
|
||||
)
|
||||
@@ -1629,13 +1743,33 @@ def test_sample_neighbors_exclude_edges_homoG(dtype):
|
||||
excluded_nodes_U = g_edges[U][b_idx:e_idx]
|
||||
excluded_nodes_V = g_edges[V][b_idx:e_idx]
|
||||
|
||||
sg = dgl.sampling.sample_neighbors(
|
||||
sg = sample_neighbors_fusing_mode[fused](
|
||||
g, sampled_node, sampled_amount, exclude_edges=excluded_edges
|
||||
)
|
||||
if fused:
|
||||
|
||||
assert not np.any(
|
||||
F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V))
|
||||
)
|
||||
def contain_edge(g, sg, u, v):
|
||||
# set of subgraph graph edges deduced from original graph
|
||||
org_edges = set(
|
||||
map(
|
||||
tuple,
|
||||
np.stack(
|
||||
g.find_edges(sg.edges["_E"].data[dgl.EID]), axis=1
|
||||
),
|
||||
)
|
||||
)
|
||||
# set of excluded edges
|
||||
excluded_edges = set(map(tuple, np.stack((u, v), axis=1)))
|
||||
|
||||
diff_set = org_edges - excluded_edges
|
||||
|
||||
return len(diff_set) != len(org_edges)
|
||||
|
||||
assert not contain_edge(g, sg, excluded_nodes_U, excluded_nodes_V)
|
||||
else:
|
||||
assert not np.any(
|
||||
F.asnumpy(sg.has_edges_between(excluded_nodes_U, excluded_nodes_V))
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", ["int32", "int64"])
|
||||
|
||||
Reference in New Issue
Block a user