mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[Sampling] Implement dgl.to_block() for the GPU (#2339)
* Add start of to_block gpu implementation * Pull in more changes from 0.4.2 cuda_to_block * Move more code to IdArray * Refactor DeviceNodeMapMaker * Updates * get compiling * Integrate to_block * Fix ID allocation * Minor fixes * Cleanup cuda calls to use cuda_common * Reduce kernel calls * Lint cleanup * Expand documentation * Remove unused function * Rename variables for consistency * Add doxygen comments * Fix file extension * Remove raw asynccopy for deviceapi * Remove unused function * Fix block/tile configuration * Add cuda_device_common.cuh * Add basic hashtable * Migrate part of hashtable * Refactor to use external hashtable * Make functions members * Format hash table functions * Migrate duplicate filling * Move last function over * Refactor with cu file * lint c++ code * Move context check to C++ code * Use macro switch * Add missing files * Update docstring * update docs * Move atomic functions * Refactor hashtable * Fix linting * Expand docs * Fix mismatched argument names * Switch doxygen comments from using @param to \param Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com> Co-authored-by: Minjie Wang <wmjlyjemaine@gmail.com>
This commit is contained in:
@@ -236,7 +236,9 @@ macro(dgl_config_cuda out_variable)
|
||||
src/kernel/cuda/*.cc
|
||||
src/kernel/cuda/*.cu
|
||||
src/runtime/cuda/*.cc
|
||||
src/runtime/cuda/*.cu
|
||||
src/geometry/cuda/*.cu
|
||||
src/graph/transform/cuda/*.cu
|
||||
)
|
||||
|
||||
# NVCC flags
|
||||
|
||||
@@ -1604,7 +1604,7 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
|
||||
Parameters
|
||||
----------
|
||||
graph : DGLGraph
|
||||
The graph. Must be on CPU.
|
||||
The graph.
|
||||
dst_nodes : Tensor or dict[str, Tensor], optional
|
||||
The list of output nodes.
|
||||
|
||||
@@ -1633,6 +1633,9 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
|
||||
If :attr:`dst_nodes` is specified but it is not a superset of all the nodes that
|
||||
have at least one inbound edge.
|
||||
|
||||
If :attr:`dst_nodes` is not None, and :attr:`g` and :attr:`dst_nodes`
|
||||
are not in the same context.
|
||||
|
||||
Notes
|
||||
-----
|
||||
:func:`to_block` is most commonly used in customizing neighborhood sampling
|
||||
@@ -1715,8 +1718,6 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
|
||||
--------
|
||||
create_block
|
||||
"""
|
||||
assert g.device == F.cpu(), 'the graph must be on CPU'
|
||||
|
||||
if dst_nodes is None:
|
||||
# Find all nodes that appeared as destinations
|
||||
dst_nodes = defaultdict(list)
|
||||
@@ -1727,15 +1728,20 @@ def to_block(g, dst_nodes=None, include_dst_in_src=True):
|
||||
elif not isinstance(dst_nodes, Mapping):
|
||||
# dst_nodes is a Tensor, check if the g has only one type.
|
||||
if len(g.ntypes) > 1:
|
||||
raise ValueError(
|
||||
raise DGLError(
|
||||
'Graph has more than one node type; please specify a dict for dst_nodes.')
|
||||
dst_nodes = {g.ntypes[0]: dst_nodes}
|
||||
|
||||
dst_node_ids = [
|
||||
utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor()
|
||||
utils.toindex(dst_nodes.get(ntype, []), g._idtype_str).tousertensor(
|
||||
ctx=F.to_backend_ctx(g._graph.ctx))
|
||||
for ntype in g.ntypes]
|
||||
dst_node_ids_nd = [F.to_dgl_nd(nodes) for nodes in dst_node_ids]
|
||||
|
||||
for d in dst_node_ids_nd:
|
||||
if g._graph.ctx != d.ctx:
|
||||
raise ValueError('g and dst_nodes need to have the same context.')
|
||||
|
||||
new_graph_index, src_nodes_nd, induced_edges_nd = _CAPI_DGLToBlock(
|
||||
g._graph, dst_node_ids_nd, include_dst_in_src)
|
||||
|
||||
|
||||
539
src/graph/transform/cuda/cuda_to_block.cu
Normal file
539
src/graph/transform/cuda/cuda_to_block.cu
Normal file
@@ -0,0 +1,539 @@
|
||||
/*!
|
||||
* Copyright (c) 2020 by Contributors
|
||||
 * \file graph/transform/cuda/cuda_to_block.cu
|
||||
* \brief Functions to convert a set of edges into a graph block with local
|
||||
* ids.
|
||||
*/
|
||||
|
||||
|
||||
#include <dgl/runtime/device_api.h>
|
||||
#include <dgl/immutable_graph.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "../../../runtime/cuda/cuda_common.h"
|
||||
#include "../../../runtime/cuda/cuda_hashtable.cuh"
|
||||
#include "../../heterograph.h"
|
||||
#include "../to_bipartite.h"
|
||||
|
||||
using namespace dgl::aten;
|
||||
using namespace dgl::runtime::cuda;
|
||||
|
||||
namespace dgl {
|
||||
namespace transform {
|
||||
|
||||
namespace {
|
||||
|
||||
/**
 * \brief Translate one tile of global vertex ids into their local ids.
 *
 * Each thread block handles the TILE_SIZE consecutive entries starting at
 * `TILE_SIZE*blockIdx.x`, and the threads of the block step through that tile
 * BLOCK_SIZE entries at a time.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of each thread block; must equal `blockDim.x`.
 * \tparam TILE_SIZE The number of ids to process per thread block.
 * \param global The ids to map.
 * \param new_global The mapped ids (output), written at the same positions.
 * \param num_vertices The number of ids to map.
 * \param table The hash table to look the ids up in. The result of `Search`
 * is dereferenced unconditionally, so every id is expected to be present.
 */
template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__device__ void map_vertex_ids(
    const IdType * const global,
    IdType * const new_global,
    const IdType num_vertices,
    const DeviceOrderedHashTable<IdType>& table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  // the half-open range of entries owned by this thread block
  const IdType first = TILE_SIZE*blockIdx.x;
  const IdType last = min(TILE_SIZE*(blockIdx.x+1), num_vertices);

  IdType idx = threadIdx.x+first;
  while (idx < last) {
    const Mapping& entry = *table.Search(global[idx]);
    new_global[idx] = entry.local;
    idx += BLOCK_SIZE;
  }
}
|
||||
|
||||
/**
 * \brief Generate mapped edge endpoint ids.
 *
 * Expects a two dimensional grid with `gridDim.y == 2`: blocks with
 * `blockIdx.y == 0` map the source ids, all other blocks map the destination
 * ids. Along x, each block processes one tile of TILE_SIZE edges.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of each thread block; must equal `blockDim.x`.
 * \tparam TILE_SIZE The number of edges to process per thread block.
 * \param global_srcs_device The source ids to map.
 * \param new_global_srcs_device The mapped source ids (output).
 * \param global_dsts_device The destination ids to map.
 * \param new_global_dsts_device The mapped destination ids (output).
 * \param num_edges The number of edges to map.
 * \param src_mapping The mapping of sources ids.
 * \param dst_mapping The mapping of destination ids.
 */
template<typename IdType, int BLOCK_SIZE, IdType TILE_SIZE>
__global__ void map_edge_ids(
    const IdType * const global_srcs_device,
    IdType * const new_global_srcs_device,
    const IdType * const global_dsts_device,
    IdType * const new_global_dsts_device,
    const IdType num_edges,
    DeviceOrderedHashTable<IdType> src_mapping,
    DeviceOrderedHashTable<IdType> dst_mapping) {
  assert(BLOCK_SIZE == blockDim.x);
  assert(2 == gridDim.y);

  // the y coordinate of the block selects which endpoint it translates
  const bool map_sources = (blockIdx.y == 0);

  const IdType * const input =
      map_sources ? global_srcs_device : global_dsts_device;
  IdType * const output =
      map_sources ? new_global_srcs_device : new_global_dsts_device;
  const DeviceOrderedHashTable<IdType>& table =
      map_sources ? src_mapping : dst_mapping;

  map_vertex_ids<IdType, BLOCK_SIZE, TILE_SIZE>(
      input,
      output,
      num_edges,
      table);
}
|
||||
|
||||
/**
 * \brief Divide `num` by `divisor`, rounding the result up.
 *
 * The arithmetic is carried out entirely in `size_t`, so the quotient is never
 * narrowed back to `IdType` on its way to the `size_t` result.
 *
 * \tparam IdType The integer type of the dividend.
 * \param num The dividend; expected to be non-negative.
 * \param divisor The divisor; must be non-zero.
 *
 * \return The smallest integer `q` such that `q*divisor >= num`.
 */
template<typename IdType>
inline size_t RoundUpDiv(
    const IdType num,
    const size_t divisor) {
  const size_t value = static_cast<size_t>(num);
  return value/divisor + (value % divisor == 0 ? 0 : 1);
}


/**
 * \brief Round `num` up to the nearest multiple of `unit`.
 *
 * \tparam IdType The integer type of the value.
 * \param num The value to round; expected to be non-negative.
 * \param unit The unit to round to a multiple of; must be non-zero.
 *
 * \return The smallest multiple of `unit` that is not less than `num`.
 */
template<typename IdType>
inline IdType RoundUp(
    const IdType num,
    const size_t unit) {
  return static_cast<IdType>(RoundUpDiv(num, unit)*unit);
}
|
||||
|
||||
|
||||
/**
 * \brief The set of device hash tables used to map global node ids to local
 * ids. One table is created per entry of the node-count vector the map is
 * constructed from.
 *
 * The tables in the first half are addressed as the left-hand-side (source)
 * tables, and the tables in the second half as the right-hand-side
 * (destination) tables.
 *
 * \tparam IdType The type of id.
 */
template<typename IdType>
class DeviceNodeMap {
 public:
  using Mapping = typename OrderedHashTable<IdType>::Mapping;

  /**
   * \brief Create one hash table per entry of `num_nodes`.
   *
   * \param num_nodes The size passed to each hash table; the lhs entries come
   * first, followed by the rhs entries.
   * \param ctx The context to create the hash tables in.
   * \param stream The stream passed to the hash table constructors.
   */
  DeviceNodeMap(
      const std::vector<int64_t>& num_nodes,
      DGLContext ctx,
      cudaStream_t stream) :
    num_types_(num_nodes.size()),
    rhs_offset_(num_types_/2),
    hash_tables_(),
    ctx_(ctx) {
    hash_tables_.reserve(num_types_);
    for (int64_t i = 0; i < num_types_; ++i) {
      hash_tables_.emplace_back(
          new OrderedHashTable<IdType>(
            num_nodes[i],
            ctx_,
            stream));
    }
  }

  /**
   * \brief Get the hash table of the lhs nodes of the given type.
   *
   * \param index The node type.
   */
  OrderedHashTable<IdType>& LhsHashTable(
      const size_t index) {
    return HashData(index);
  }

  /**
   * \brief Get the hash table of the rhs nodes of the given type.
   *
   * \param index The node type.
   */
  OrderedHashTable<IdType>& RhsHashTable(
      const size_t index) {
    return HashData(index+rhs_offset_);
  }

  const OrderedHashTable<IdType>& LhsHashTable(
      const size_t index) const {
    return HashData(index);
  }

  const OrderedHashTable<IdType>& RhsHashTable(
      const size_t index) const {
    return HashData(index+rhs_offset_);
  }

  /**
   * \brief Get the size reported by the lhs hash table of the given type.
   *
   * \param index The node type.
   */
  IdType LhsHashSize(
      const size_t index) const {
    return HashSize(index);
  }

  /**
   * \brief Get the size reported by the rhs hash table of the given type.
   *
   * \param index The node type.
   */
  IdType RhsHashSize(
      const size_t index) const {
    return HashSize(rhs_offset_+index);
  }

  /**
   * \brief Get the total number of hash tables (lhs and rhs together).
   */
  size_t Size() const {
    return hash_tables_.size();
  }

 private:
  int64_t num_types_;
  // index of the first rhs table within hash_tables_
  size_t rhs_offset_;
  std::vector<std::unique_ptr<OrderedHashTable<IdType>>> hash_tables_;
  DGLContext ctx_;

  // bounds-checked access to the table at the given absolute index
  inline OrderedHashTable<IdType>& HashData(
      const size_t index) {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline const OrderedHashTable<IdType>& HashData(
      const size_t index) const {
    CHECK_LT(index, hash_tables_.size());
    return *hash_tables_[index];
  }

  inline IdType HashSize(
      const size_t index) const {
    return HashData(index).size();
  }
};
|
||||
|
||||
/**
 * \brief Helper that fills the hash tables of a DeviceNodeMap from the lists
 * of source and destination nodes.
 *
 * \tparam IdType The type of id.
 */
template<typename IdType>
class DeviceNodeMapMaker {
 public:
  /**
   * \brief Create a new maker.
   *
   * \param maxNodesPerType The maximum number of nodes of each type; only the
   * largest entry is recorded. May be empty, in which case the recorded
   * maximum is 0.
   */
  DeviceNodeMapMaker(
      const std::vector<int64_t>& maxNodesPerType) :
      max_num_nodes_(0) {
    // std::max_element returns end() for an empty range, which must not be
    // dereferenced
    const auto largest = std::max_element(maxNodesPerType.begin(),
        maxNodesPerType.end());
    if (largest != maxNodesPerType.end()) {
      max_num_nodes_ = *largest;
    }
  }

  /**
   * \brief This function builds node maps for each node type, preserving the
   * order of the input nodes.
   *
   * \param lhs_nodes The set of source input nodes.
   * \param rhs_nodes The set of destination input nodes.
   * \param node_maps The node maps to be constructed.
   * \param count_lhs_device The number of unique source nodes (on the GPU).
   * The buffer is zeroed here and must hold
   * `lhs_nodes.size() + rhs_nodes.size()` entries.
   * \param lhs_device The unique source nodes (on the GPU).
   * \param stream The stream to operate on.
   */
  void Make(
      const std::vector<IdArray>& lhs_nodes,
      const std::vector<IdArray>& rhs_nodes,
      DeviceNodeMap<IdType> * const node_maps,
      int64_t * const count_lhs_device,
      std::vector<IdArray>* const lhs_device,
      cudaStream_t stream) {
    const int64_t num_ntypes = lhs_nodes.size() + rhs_nodes.size();

    CUDA_CALL(cudaMemsetAsync(
      count_lhs_device,
      0,
      num_ntypes*sizeof(*count_lhs_device),
      stream));

    // possibly duplicate lhs nodes
    const int64_t lhs_num_ntypes = static_cast<int64_t>(lhs_nodes.size());
    for (int64_t ntype = 0; ntype < lhs_num_ntypes; ++ntype) {
      const IdArray& nodes = lhs_nodes[ntype];
      // node types without any nodes leave their hash table untouched
      if (nodes->shape[0] > 0) {
        CHECK_EQ(nodes->ctx.device_type, kDLGPU);
        node_maps->LhsHashTable(ntype).FillWithDuplicates(
            nodes.Ptr<IdType>(),
            nodes->shape[0],
            (*lhs_device)[ntype].Ptr<IdType>(),
            count_lhs_device+ntype,
            stream);
      }
    }

    // unique rhs nodes
    const int64_t rhs_num_ntypes = static_cast<int64_t>(rhs_nodes.size());
    for (int64_t ntype = 0; ntype < rhs_num_ntypes; ++ntype) {
      const IdArray& nodes = rhs_nodes[ntype];
      if (nodes->shape[0] > 0) {
        node_maps->RhsHashTable(ntype).FillWithUnique(
            nodes.Ptr<IdType>(),
            nodes->shape[0],
            stream);
      }
    }
  }

 private:
  // the largest entry of the per-type node counts given at construction
  IdType max_num_nodes_;
};
|
||||
|
||||
/**
 * \brief Map the endpoints of every edge set from global ids to local ids.
 *
 * \tparam IdType The type of id.
 * \param graph The graph the edge sets belong to; provides the context to
 * allocate in and the endpoint types of each edge type.
 * \param edge_sets The edges of each edge type. Edge sets whose id array is
 * not defined are skipped.
 * \param node_map The node map holding the lhs/rhs hash tables to look the
 * endpoints up in.
 * \param stream The stream to launch the mapping kernels on.
 *
 * \return The mapped source ids and the mapped destination ids, each with one
 * array per edge type. Skipped edge sets get a null array in both.
 */
template<typename IdType>
std::tuple<std::vector<IdArray>, std::vector<IdArray>>
MapEdges(
    HeteroGraphPtr graph,
    const std::vector<EdgeArray>& edge_sets,
    const DeviceNodeMap<IdType>& node_map,
    cudaStream_t stream) {
  constexpr const int BLOCK_SIZE = 128;
  constexpr const size_t TILE_SIZE = 1024;

  const auto& ctx = graph->Context();

  std::vector<IdArray> new_lhs;
  new_lhs.reserve(edge_sets.size());
  std::vector<IdArray> new_rhs;
  new_rhs.reserve(edge_sets.size());

  // The next performance optimization here, is to perform mapping of all edge
  // types in a single kernel launch.
  const int64_t num_edge_sets = static_cast<int64_t>(edge_sets.size());
  for (int64_t etype = 0; etype < num_edge_sets; ++etype) {
    const EdgeArray& edges = edge_sets[etype];
    if (edges.id.defined()) {
      const int64_t num_edges = edges.src->shape[0];

      new_lhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8));
      new_rhs.emplace_back(NewIdArray(num_edges, ctx, sizeof(IdType)*8));

      // An edge set without edges would give a grid with zero blocks along x,
      // which is an invalid launch configuration. The (empty) output arrays
      // are already complete in that case, so there is nothing to launch.
      if (num_edges > 0) {
        const auto src_dst_types = graph->GetEndpointTypes(etype);
        const int src_type = src_dst_types.first;
        const int dst_type = src_dst_types.second;

        // y == 0 maps the sources, y == 1 maps the destinations
        const dim3 grid(RoundUpDiv(num_edges, TILE_SIZE), 2);
        const dim3 block(BLOCK_SIZE);

        // map the srcs and the dsts
        map_edge_ids<IdType, BLOCK_SIZE, TILE_SIZE><<<
          grid,
          block,
          0,
          stream>>>(
          edges.src.Ptr<IdType>(),
          new_lhs.back().Ptr<IdType>(),
          edges.dst.Ptr<IdType>(),
          new_rhs.back().Ptr<IdType>(),
          num_edges,
          node_map.LhsHashTable(src_type).DeviceHandle(),
          node_map.RhsHashTable(dst_type).DeviceHandle());
        CUDA_CALL(cudaGetLastError());
      }
    } else {
      new_lhs.emplace_back(aten::NullArray());
      new_rhs.emplace_back(aten::NullArray());
    }
  }

  return std::tuple<std::vector<IdArray>, std::vector<IdArray>>(
      std::move(new_lhs), std::move(new_rhs));
}
|
||||
|
||||
|
||||
// Since partial specialization is not allowed for functions, use this as an
// intermediate for ToBlock where XPU = kDLGPU.
/**
 * \brief Convert a graph on the GPU into a block with locally numbered nodes.
 *
 * \tparam IdType The type of id.
 * \param graph The graph to convert; must be on the GPU.
 * \param rhs_nodes The destination nodes, one array per node type, on the same
 * device type as the graph.
 * \param include_rhs_in_lhs Whether the destination nodes are also placed, in
 * the same order, at the front of the source nodes.
 *
 * \return The new graph, the source nodes of each node type, and the induced
 * edge ids of each edge type.
 */
template<typename IdType>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlockGPU(
    HeteroGraphPtr graph,
    const std::vector<IdArray> &rhs_nodes,
    const bool include_rhs_in_lhs) {
  cudaStream_t stream = 0;
  const auto& ctx = graph->Context();
  auto device = runtime::DeviceAPI::Get(ctx);

  CHECK_EQ(ctx.device_type, kDLGPU);
  for (const auto& nodes : rhs_nodes) {
    CHECK_EQ(ctx.device_type, nodes->ctx.device_type);
  }

  // Since DST nodes are included in SRC nodes, a common requirement is to fetch
  // the DST node features from the SRC nodes features. To avoid expensive sparse
  // lookup, the DST nodes are placed at the front of the SRC nodes below, so
  // that they are the first ids handed to the order-preserving SRC node map.

  const int64_t num_etypes = graph->NumEdgeTypes();
  const int64_t num_ntypes = graph->NumVertexTypes();

  CHECK(rhs_nodes.size() == static_cast<size_t>(num_ntypes))
    << "rhs_nodes not given for every node type";

  // only fetch the edges of edge types that have destination nodes
  std::vector<EdgeArray> edge_arrays(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t dsttype = src_dst_types.second;
    if (!aten::IsNullArray(rhs_nodes[dsttype])) {
      edge_arrays[etype] = graph->Edges(etype);
    }
  }

  // count lhs and rhs nodes; entries [0, num_ntypes) are the lhs counts and
  // entries [num_ntypes, 2*num_ntypes) are the rhs counts
  std::vector<int64_t> maxNodesPerType(num_ntypes*2, 0);
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    maxNodesPerType[ntype+num_ntypes] += rhs_nodes[ntype]->shape[0];

    if (include_rhs_in_lhs) {
      maxNodesPerType[ntype] += rhs_nodes[ntype]->shape[0];
    }
  }
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    if (edge_arrays[etype].src.defined()) {
      maxNodesPerType[srctype] += edge_arrays[etype].src->shape[0];
    }
  }

  // gather lhs_nodes; src_node_offsets holds the write position of each node
  // type in bytes
  std::vector<int64_t> src_node_offsets(num_ntypes, 0);
  std::vector<IdArray> src_nodes(num_ntypes);
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    src_nodes[ntype] = NewIdArray(maxNodesPerType[ntype], ctx,
        sizeof(IdType)*8);
    if (include_rhs_in_lhs) {
      // place rhs nodes first
      device->CopyDataFromTo(rhs_nodes[ntype].Ptr<IdType>(), 0,
          src_nodes[ntype].Ptr<IdType>(), src_node_offsets[ntype],
          sizeof(IdType)*rhs_nodes[ntype]->shape[0],
          rhs_nodes[ntype]->ctx, src_nodes[ntype]->ctx,
          rhs_nodes[ntype]->dtype,
          stream);
      src_node_offsets[ntype] += sizeof(IdType)*rhs_nodes[ntype]->shape[0];
    }
  }
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    if (edge_arrays[etype].src.defined()) {
      const IdArray& edge_srcs = edge_arrays[etype].src;
      // describe the copy with the context and type of the array that is
      // actually being read, rather than those of an unrelated rhs array
      device->CopyDataFromTo(
          edge_srcs.Ptr<IdType>(), 0,
          src_nodes[srctype].Ptr<IdType>(),
          src_node_offsets[srctype],
          sizeof(IdType)*edge_srcs->shape[0],
          edge_srcs->ctx,
          src_nodes[srctype]->ctx,
          edge_srcs->dtype,
          stream);

      src_node_offsets[srctype] += sizeof(IdType)*edge_srcs->shape[0];
    }
  }

  // allocate space for map creation process
  DeviceNodeMapMaker<IdType> maker(maxNodesPerType);

  DeviceNodeMap<IdType> node_maps(maxNodesPerType, ctx, stream);

  std::vector<IdArray> lhs_nodes;
  lhs_nodes.reserve(num_ntypes);
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    lhs_nodes.emplace_back(NewIdArray(
        maxNodesPerType[ntype], ctx, sizeof(IdType)*8));
  }

  // populate the mappings
  int64_t * count_lhs_device = static_cast<int64_t*>(
      device->AllocWorkspace(ctx, sizeof(int64_t)*num_ntypes*2));
  maker.Make(
      src_nodes,
      rhs_nodes,
      &node_maps,
      count_lhs_device,
      &lhs_nodes,
      stream);

  std::vector<IdArray> induced_edges;
  induced_edges.reserve(num_etypes);
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    if (edge_arrays[etype].id.defined()) {
      induced_edges.push_back(edge_arrays[etype].id);
    } else {
      induced_edges.push_back(
            aten::NullArray(DLDataType{kDLInt, sizeof(IdType)*8, 1}, ctx));
    }
  }

  // build metagraph -- small enough to be done on CPU
  const auto meta_graph = graph->meta_graph();
  const EdgeArray etypes = meta_graph->Edges("eid");
  const IdArray new_dst = Add(etypes.dst, num_ntypes);
  const auto new_meta_graph = ImmutableGraph::CreateFromCOO(
      num_ntypes * 2, etypes.src, new_dst);

  // allocate vector for graph relations while GPU is busy
  std::vector<HeteroGraphPtr> rel_graphs;
  rel_graphs.reserve(num_etypes);

  std::vector<int64_t> num_nodes_per_type(num_ntypes*2);
  // populate RHS nodes from what we already know
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    num_nodes_per_type[num_ntypes+ntype] = rhs_nodes[ntype]->shape[0];
  }
  // fetch the number of unique LHS nodes of each type from the device
  device->CopyDataFromTo(
      count_lhs_device, 0,
      num_nodes_per_type.data(), 0,
      sizeof(*num_nodes_per_type.data())*num_ntypes,
      ctx,
      DGLContext{kDLCPU, 0},
      DGLType{kDLInt, 64, 1},
      stream);

  // wait for the node counts to finish transferring
  device->StreamSync(ctx, stream);

  // the counts are on the host now, so the device buffer can be released
  device->FreeWorkspace(ctx, count_lhs_device);

  // map node numberings from global to local, and build pointer for CSR
  std::vector<IdArray> new_lhs;
  std::vector<IdArray> new_rhs;
  std::tie(new_lhs, new_rhs) = MapEdges(graph, edge_arrays, node_maps, stream);

  // resize lhs nodes: the arrays were allocated for the maximum possible
  // number of nodes, so shrink their length to the number of unique nodes
  for (int64_t ntype = 0; ntype < num_ntypes; ++ntype) {
    lhs_nodes[ntype]->shape[0] = num_nodes_per_type[ntype];
  }

  // build the heterograph
  for (int64_t etype = 0; etype < num_etypes; ++etype) {
    const auto src_dst_types = graph->GetEndpointTypes(etype);
    const dgl_type_t srctype = src_dst_types.first;
    const dgl_type_t dsttype = src_dst_types.second;

    if (rhs_nodes[dsttype]->shape[0] == 0) {
      // No rhs nodes are given for this edge type. Create an empty graph.
      rel_graphs.push_back(CreateFromCOO(
          2, lhs_nodes[srctype]->shape[0], rhs_nodes[dsttype]->shape[0],
          aten::NullArray(DLDataType{kDLInt, sizeof(IdType)*8, 1}, ctx),
          aten::NullArray(DLDataType{kDLInt, sizeof(IdType)*8, 1}, ctx)));
    } else {
      rel_graphs.push_back(CreateFromCOO(
          2,
          lhs_nodes[srctype]->shape[0],
          rhs_nodes[dsttype]->shape[0],
          new_lhs[etype],
          new_rhs[etype]));
    }
  }

  HeteroGraphPtr new_graph = CreateHeteroGraph(
      new_meta_graph, rel_graphs, num_nodes_per_type);

  // return the new graph, the new src nodes, and new edges
  return std::make_tuple(new_graph, lhs_nodes, induced_edges);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
 * \brief ToBlock for graphs on the GPU with 32 bit ids; forwards to
 * ToBlockGPU.
 */
template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock<kDLGPU, int32_t>(HeteroGraphPtr graph,
                         const std::vector<IdArray> &rhs_nodes,
                         bool include_rhs_in_lhs) {
  return ToBlockGPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs);
}
|
||||
|
||||
/**
 * \brief ToBlock for graphs on the GPU with 64 bit ids; forwards to
 * ToBlockGPU.
 */
template<>
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
ToBlock<kDLGPU, int64_t>(HeteroGraphPtr graph,
                         const std::vector<IdArray> &rhs_nodes,
                         bool include_rhs_in_lhs) {
  return ToBlockGPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs);
}
|
||||
|
||||
} // namespace transform
|
||||
} // namespace dgl
|
||||
43
src/graph/transform/cuda/cuda_to_block.h
Normal file
43
src/graph/transform/cuda/cuda_to_block.h
Normal file
@@ -0,0 +1,43 @@
|
||||
/*!
|
||||
* Copyright (c) 2020 by Contributors
|
||||
 * \file graph/transform/cuda/cuda_to_block.h
|
||||
* \brief Functions to convert a set of edges into a graph block with local
|
||||
* ids.
|
||||
*/
|
||||
|
||||
|
||||
#ifndef DGL_GRAPH_TRANSFORM_CUDA_CUDA_TO_BLOCK_H_
|
||||
#define DGL_GRAPH_TRANSFORM_CUDA_CUDA_TO_BLOCK_H_
|
||||
|
||||
#include <dgl/array.h>
|
||||
#include <dgl/base_heterograph.h>
|
||||
#include <vector>
|
||||
#include <tuple>
|
||||
|
||||
namespace dgl {
|
||||
namespace transform {
|
||||
namespace cuda {
|
||||
|
||||
/**
 * \brief Generate a subgraph with locally numbered vertices, from the given
 * edge set.
 *
 * NOTE(review): only this declaration is visible here; the GPU implementation
 * in cuda_to_block.cu is provided as ToBlock<kDLGPU, IdType> -- confirm that a
 * definition of CudaToBlock exists, or that this declaration is still needed.
 *
 * \param graph The set of edges to construct the subgraph from.
 * \param rhs_nodes The unique set of destination vertices.
 * \param include_rhs_in_lhs Whether or not to include the `rhs_nodes` in the
 * set of source vertices for purposes of local numbering.
 *
 * \return The subgraph, the unique set of source nodes, and the mapping of
 * subgraph edges to global edges.
 */
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
CudaToBlock(
|
||||
HeteroGraphPtr graph,
|
||||
const std::vector<IdArray>& rhs_nodes,
|
||||
const bool include_rhs_in_lhs);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace transform
|
||||
} // namespace dgl
|
||||
|
||||
#endif // DGL_GRAPH_TRANSFORM_CUDA_CUDA_TO_BLOCK_H_
|
||||
@@ -4,6 +4,8 @@
|
||||
* \brief Convert a graph to a bipartite-structured graph.
|
||||
*/
|
||||
|
||||
#include "to_bipartite.h"
|
||||
|
||||
#include <dgl/base_heterograph.h>
|
||||
#include <dgl/transform.h>
|
||||
#include <dgl/array.h>
|
||||
@@ -26,9 +28,11 @@ namespace transform {
|
||||
|
||||
namespace {
|
||||
|
||||
// Since partial specialization is not allowed for functions, use this as an
|
||||
// intermediate for ToBlock where XPU = kDLCPU.
|
||||
template<typename IdType>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
|
||||
ToBlockCPU(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
|
||||
const int64_t num_etypes = graph->NumEdgeTypes();
|
||||
const int64_t num_ntypes = graph->NumVertexTypes();
|
||||
std::vector<EdgeArray> edge_arrays(num_etypes);
|
||||
@@ -107,15 +111,22 @@ ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool includ
|
||||
return std::make_tuple(new_graph, lhs_nodes, induced_edges);
|
||||
}
|
||||
|
||||
}; // namespace
|
||||
} // namespace
|
||||
|
||||
template<>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes, bool include_rhs_in_lhs) {
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>> ret;
|
||||
ATEN_ID_TYPE_SWITCH(graph->DataType(), IdType, {
|
||||
ret = ToBlock<IdType>(graph, rhs_nodes, include_rhs_in_lhs);
|
||||
});
|
||||
return ret;
|
||||
ToBlock<kDLCPU, int32_t>(HeteroGraphPtr graph,
|
||||
const std::vector<IdArray> &rhs_nodes,
|
||||
bool include_rhs_in_lhs) {
|
||||
return ToBlockCPU<int32_t>(graph, rhs_nodes, include_rhs_in_lhs);
|
||||
}
|
||||
|
||||
template<>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
ToBlock<kDLCPU, int64_t>(HeteroGraphPtr graph,
|
||||
const std::vector<IdArray> &rhs_nodes,
|
||||
bool include_rhs_in_lhs) {
|
||||
return ToBlockCPU<int64_t>(graph, rhs_nodes, include_rhs_in_lhs);
|
||||
}
|
||||
|
||||
DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
|
||||
@@ -127,8 +138,13 @@ DGL_REGISTER_GLOBAL("transform._CAPI_DGLToBlock")
|
||||
HeteroGraphPtr new_graph;
|
||||
std::vector<IdArray> lhs_nodes;
|
||||
std::vector<IdArray> induced_edges;
|
||||
std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock(
|
||||
graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs);
|
||||
|
||||
ATEN_XPU_SWITCH_CUDA(graph_ref->Context().device_type, XPU, "ToBlock", {
|
||||
ATEN_ID_TYPE_SWITCH(graph_ref->DataType(), IdType, {
|
||||
std::tie(new_graph, lhs_nodes, induced_edges) = ToBlock<XPU, IdType>(
|
||||
graph_ref.sptr(), rhs_nodes, include_rhs_in_lhs);
|
||||
});
|
||||
});
|
||||
|
||||
List<Value> lhs_nodes_ref;
|
||||
for (IdArray &array : lhs_nodes)
|
||||
|
||||
27
src/graph/transform/to_bipartite.h
Normal file
27
src/graph/transform/to_bipartite.h
Normal file
@@ -0,0 +1,27 @@
|
||||
/*!
|
||||
* Copyright (c) 2021 by Contributors
|
||||
* \file graph/transform/to_bipartite.h
|
||||
 * \brief Declaration of ToBlock, which converts a graph to a bipartite-structured graph.
|
||||
*/
|
||||
|
||||
#ifndef DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
|
||||
#define DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
|
||||
|
||||
#include <dgl/array.h>
|
||||
#include <dgl/base_heterograph.h>
|
||||
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
namespace dgl {
|
||||
namespace transform {
|
||||
|
||||
/**
 * \brief Convert a graph into a block whose nodes are numbered locally.
 *
 * Only the declaration lives here; the specializations for each device type
 * and id type are defined in the corresponding source files.
 *
 * \tparam XPU The type of device the graph is on.
 * \tparam IdType The type of id.
 * \param graph The graph to convert.
 * \param rhs_nodes The destination nodes, one array per node type.
 * \param include_rhs_in_lhs Whether or not to include the `rhs_nodes` in the
 * set of source nodes.
 *
 * \return The new graph, the source nodes of each node type, and the induced
 * edges of each edge type.
 */
template<DLDeviceType XPU, typename IdType>
|
||||
std::tuple<HeteroGraphPtr, std::vector<IdArray>, std::vector<IdArray>>
|
||||
ToBlock(HeteroGraphPtr graph, const std::vector<IdArray> &rhs_nodes,
|
||||
bool include_rhs_in_lhs);
|
||||
|
||||
} // namespace transform
|
||||
} // namespace dgl
|
||||
|
||||
#endif // DGL_GRAPH_TRANSFORM_TO_BIPARTITE_H_
|
||||
@@ -118,6 +118,58 @@ __device__ __forceinline__ __half AtomicAdd<__half>(__half* addr, __half val) {
|
||||
DEFINE_ATOMIC(Mul)
|
||||
#undef OP
|
||||
|
||||
/**
 * \brief Performs an atomic compare-and-swap on 64 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * \param address The address to perform the atomic operation on.
 * \param compare The value to compare to.
 * \param val The new value to conditionally store.
 *
 * \return The old value at the address.
 */
inline __device__ int64_t AtomicCAS(
    int64_t * const address,
    const int64_t compare,
    const int64_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = unsigned long long int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  // reinterpret the operands as the type "::atomicCAS" is declared for
  Type * const target = reinterpret_cast<Type*>(address);
  const Type expected = static_cast<Type>(compare);
  const Type desired = static_cast<Type>(val);

  return atomicCAS(target, expected, desired);
}
|
||||
|
||||
/**
 * \brief Performs an atomic compare-and-swap on 32 bit integers. That is,
 * it reads the word `old` at the memory location `address`, computes
 * `(old == compare ? val : old)`, and stores the result back to memory at
 * the same address.
 *
 * \param address The address to perform the atomic operation on.
 * \param compare The value to compare to.
 * \param val The new value to conditionally store.
 *
 * \return The old value at the address.
 */
inline __device__ int32_t AtomicCAS(
    int32_t * const address,
    const int32_t compare,
    const int32_t val) {
  // match the type of "::atomicCAS", so ignore lint warning
  using Type = int; // NOLINT

  static_assert(sizeof(Type) == sizeof(*address), "Type width must match");

  // reinterpret the operands as the type "::atomicCAS" is declared for
  Type * const target = reinterpret_cast<Type*>(address);
  const Type expected = static_cast<Type>(compare);
  const Type desired = static_cast<Type>(val);

  return atomicCAS(target, expected, desired);
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace kernel
|
||||
} // namespace dgl
|
||||
|
||||
488
src/runtime/cuda/cuda_hashtable.cu
Normal file
488
src/runtime/cuda/cuda_hashtable.cu
Normal file
@@ -0,0 +1,488 @@
|
||||
/*!
|
||||
* Copyright (c) 2021 by Contributors
|
||||
 * \file runtime/cuda/cuda_hashtable.cu
|
||||
 * \brief Implementation of the ordered hash table for use in cuda kernels.
|
||||
*/
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cassert>
|
||||
|
||||
#include "cuda_hashtable.cuh"
|
||||
#include "../../kernel/cuda/atomic.cuh"
|
||||
|
||||
using namespace dgl::kernel::cuda;
|
||||
|
||||
namespace dgl {
|
||||
namespace runtime {
|
||||
namespace cuda {
|
||||
|
||||
namespace {
|
||||
|
||||
constexpr static const int BLOCK_SIZE = 256;
|
||||
constexpr static const size_t TILE_SIZE = 1024;
|
||||
|
||||
/**
 * \brief This is the mutable version of the DeviceOrderedHashTable, for use in
 * inserting elements into the hashtable.
 *
 * Members inherited from the (template-dependent) base class are always
 * accessed through `this->` so that they are found by standard two-phase name
 * lookup.
 *
 * \tparam IdType The type of ID to store in the hashtable.
 */
template<typename IdType>
class MutableDeviceOrderedHashTable : public DeviceOrderedHashTable<IdType> {
 public:
  typedef typename DeviceOrderedHashTable<IdType>::Mapping* Iterator;
  static constexpr IdType kEmptyKey = DeviceOrderedHashTable<IdType>::kEmptyKey;

  /**
   * \brief Create a new mutable hashtable for use on the device.
   *
   * \param hostTable The original hash table on the host.
   */
  explicit MutableDeviceOrderedHashTable(
      OrderedHashTable<IdType>* const hostTable) :
    DeviceOrderedHashTable<IdType>(hostTable->DeviceHandle()) {
  }

  /**
   * \brief Find the mutable mapping of a given key within the hash table.
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
   * \param id The key to search for.
   *
   * \return The mapping.
   */
  inline __device__ Iterator Search(
      const IdType id) {
    // SearchForPosition() lives in the dependent base class, so it has to be
    // named through `this->`.
    const IdType pos = this->SearchForPosition(id);

    return GetMutable(pos);
  }

  /**
   * \brief Attempt to insert into the hash table at a specific location.
   *
   * The bucket is claimed if it is empty or already holds `id`; in both cases
   * the bucket's stored index is lowered to `index` if `index` is smaller.
   *
   * \param pos The position to insert at.
   * \param id The ID to insert into the hash table.
   * \param index The original index of the item being inserted.
   *
   * \return True, if the insertion was successful.
   */
  inline __device__ bool AttemptInsertAt(
      const size_t pos,
      const IdType id,
      const size_t index) {
    const IdType key = AtomicCAS(&GetMutable(pos)->key, kEmptyKey, id);
    if (key == kEmptyKey || key == id) {
      // we either set a match key, or found a matching key, so then place the
      // minimum index in position. Match the type of atomicMin, so ignore
      // linting
      atomicMin(reinterpret_cast<unsigned long long*>(&GetMutable(pos)->index), // NOLINT
                static_cast<unsigned long long>(index)); // NOLINT
      return true;
    } else {
      // the bucket belongs to a different key, we need to search elsewhere
      return false;
    }
  }

  /**
   * \brief Insert key-index pair into the hashtable.
   *
   * \param id The ID to insert.
   * \param index The index at which the ID occurred.
   *
   * \return An iterator to inserted mapping.
   */
  inline __device__ Iterator Insert(
      const IdType id,
      const size_t index) {
    size_t pos = this->Hash(id);

    // Probe for an empty slot or a matching entry. The step grows by one
    // after every collision (quadratic probing), mirroring the probe sequence
    // used by DeviceOrderedHashTable::SearchForPosition().
    IdType delta = 1;
    while (!AttemptInsertAt(pos, id, index)) {
      pos = this->Hash(pos+delta);
      delta += 1;
    }

    return GetMutable(pos);
  }

 private:
  /**
   * \brief Get a mutable iterator to the given bucket in the hashtable.
   *
   * \param pos The given bucket.
   *
   * \return The iterator.
   */
  inline __device__ Iterator GetMutable(const size_t pos) {
    assert(pos < this->size_);
    // The parent class Device is read-only, but we ensure this can only be
    // constructed from a mutable version of OrderedHashTable, making this
    // a safe cast to perform.
    return const_cast<Iterator>(this->table_+pos);
  }
};
|
||||
|
||||
/**
 * \brief Calculate the number of buckets in the hashtable. To guarantee we can
 * fill the hashtable in the worst case, we must use a number of buckets which
 * is a power of two.
 * https://en.wikipedia.org/wiki/Quadratic_probing#Limitations
 *
 * The base size is the largest power of two that does not exceed `num`
 * (1 when `num` is 0 or 1), which is then multiplied by 2^`scale`. The result
 * is therefore only guaranteed to exceed `num` when `scale` is at least 1.
 *
 * The computation is done purely with integer shifts, so it is well defined
 * for `num` of 0 or 1 and does not overflow an `int` for large tables.
 *
 * \param num The number of items to insert (should be an upper bound on the
 * number of unique keys).
 * \param scale The power of two by which the number of buckets should exceed
 * the base size.
 *
 * \return The number of buckets the table should contain.
 */
size_t TableSize(
    const size_t num,
    const int scale) {
  // Double once for every bit of `num` above the lowest one: this yields the
  // largest power of two <= num, and leaves 1 for num of 0 or 1.
  size_t base_pow2 = 1;
  for (size_t remaining = num >> 1; remaining != 0; remaining >>= 1) {
    base_pow2 <<= 1;
  }
  return base_pow2 << scale;
}
|
||||
|
||||
/**
 * \brief Stateful functor passed to cub's block-level prefix scan so that a
 * running sum is carried across successive scans.
 *
 * Every invocation returns the total accumulated before the call and then
 * adds the supplied aggregate to that total.
 *
 * \tparam IdType The type to perform the prefixsum on.
 */
template<typename IdType>
struct BlockPrefixCallbackOp {
  IdType running_total_;

  __device__ BlockPrefixCallbackOp(
      const IdType running_total)
    : running_total_(running_total) {}

  __device__ IdType operator()(const IdType block_aggregate) {
    const IdType prefix_before = running_total_;
    running_total_ = static_cast<IdType>(prefix_before + block_aggregate);
    return prefix_before;
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/**
 * \brief Kernel that inserts every input item, together with its position in
 * the input, into the hash table. The input may contain duplicates.
 *
 * Each thread block handles one tile of TILE_SIZE consecutive items and must
 * be launched with exactly BLOCK_SIZE threads.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The items to insert.
 * \param num_items The number of items to insert.
 * \param table The hash table.
 */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_duplicates(
    const IdType * const items,
    const int64_t num_items,
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

  // this block owns the input range [tile_first, tile_last)
  const size_t tile_first = TILE_SIZE*blockIdx.x;
  const size_t tile_last = tile_first + TILE_SIZE;

  #pragma unroll
  for (size_t i = tile_first + threadIdx.x; i < tile_last; i += BLOCK_SIZE) {
    // the tile may reach past the end of the input
    if (i >= num_items) {
      continue;
    }
    table.Insert(items[i], i);
  }
}
|
||||
|
||||
/**
 * \brief Kernel that inserts items which are all known to be unique into the
 * hash table, and records each item's input position as its local id.
 *
 * Each thread block handles one tile of TILE_SIZE consecutive items and must
 * be launched with exactly BLOCK_SIZE threads.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The unique items to insert.
 * \param num_items The number of items to insert.
 * \param table The hash table.
 */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void generate_hashmap_unique(
    const IdType * const items,
    const int64_t num_items,
    MutableDeviceOrderedHashTable<IdType> table) {
  assert(BLOCK_SIZE == blockDim.x);

  using Iterator = typename MutableDeviceOrderedHashTable<IdType>::Iterator;

  // this block owns the input range [tile_first, tile_last)
  const size_t tile_first = TILE_SIZE*blockIdx.x;
  const size_t tile_last = tile_first + TILE_SIZE;

  #pragma unroll
  for (size_t i = tile_first + threadIdx.x; i < tile_last; i += BLOCK_SIZE) {
    // the tile may reach past the end of the input
    if (i >= num_items) {
      continue;
    }

    const Iterator slot = table.Insert(items[i], i);

    // no duplicates are present, so the local id of an item is simply its
    // position in the input
    slot->local = static_cast<IdType>(i);
  }
}
|
||||
|
||||
/**
 * \brief Kernel that counts, per thread block, how many input positions are
 * the ones recorded in the hash table for their item.
 *
 * An input position is counted when the `index` stored in the table for its
 * item equals that position. Block `b` writes its count to `num_unique[b]`;
 * block 0 additionally zeroes `num_unique[gridDim.x]`, so the output must
 * hold `gridDim.x + 1` entries.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread block.
 * \tparam TILE_SIZE The number of entries each thread block will process.
 * \param items The nodes that were inserted.
 * \param num_items The number of nodes.
 * \param table The hash table.
 * \param num_unique The number of nodes counted per thread block (output).
 */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void count_hashmap(
    const IdType * items,
    const size_t num_items,
    DeviceOrderedHashTable<IdType> table,
    IdType * const num_unique) {
  assert(BLOCK_SIZE == blockDim.x);

  using BlockReduce = typename cub::BlockReduce<IdType, BLOCK_SIZE>;
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  __shared__ typename BlockReduce::TempStorage temp_space;

  // this block owns the input range [tile_first, tile_last)
  const size_t tile_first = TILE_SIZE*blockIdx.x;
  const size_t tile_last = tile_first + TILE_SIZE;

  // number of positions in this thread's share that match the table
  IdType owned = 0;

  #pragma unroll
  for (size_t i = tile_first + threadIdx.x; i < tile_last; i += BLOCK_SIZE) {
    if (i >= num_items) {
      continue;
    }
    const Mapping& entry = *table.Search(items[i]);
    owned += (entry.index == i) ? 1 : 0;
  }

  const IdType block_total = BlockReduce(temp_space).Sum(owned);

  if (threadIdx.x == 0) {
    num_unique[blockIdx.x] = block_total;
    if (blockIdx.x == 0) {
      // trailing slot past the last block's count
      num_unique[gridDim.x] = 0;
    }
  }
}
|
||||
|
||||
|
||||
/**
 * \brief Kernel that assigns the local numbering of the elements in the
 * hashmap and writes out the list of unique items.
 *
 * An input position "owns" its item when the `index` stored in the table for
 * that item equals the position. Every owned item receives the local id
 * `num_items_prefix[blockIdx.x]` plus its rank among the owned positions of
 * this block, and is copied to `unique_items` at that id.
 *
 * \tparam IdType The type of id.
 * \tparam BLOCK_SIZE The size of the thread blocks.
 * \tparam TILE_SIZE The number of elements each thread block works on.
 * \param items The set of non-unique items to update from.
 * \param num_items The number of non-unique items.
 * \param table The hash table.
 * \param num_items_prefix The number of unique items preceding each thread
 * block, with the overall total stored at position `gridDim.x`.
 * \param unique_items The set of unique items (output).
 * \param num_unique_items The number of unique items (output).
 */
template<typename IdType, int BLOCK_SIZE, size_t TILE_SIZE>
__global__ void compact_hashmap(
    const IdType * const items,
    const size_t num_items,
    MutableDeviceOrderedHashTable<IdType> table,
    const IdType * const num_items_prefix,
    IdType * const unique_items,
    int64_t * const num_unique_items) {
  assert(BLOCK_SIZE == blockDim.x);

  using FlagType = uint16_t;
  using BlockScan = typename cub::BlockScan<FlagType, BLOCK_SIZE>;
  using Mapping = typename DeviceOrderedHashTable<IdType>::Mapping;

  constexpr const int32_t VALS_PER_THREAD = TILE_SIZE / BLOCK_SIZE;

  __shared__ typename BlockScan::TempStorage temp_space;

  // unique items contributed by all preceding blocks
  const IdType block_base = num_items_prefix[blockIdx.x];

  // carries the count of owned items over from one round to the next
  BlockPrefixCallbackOp<FlagType> running_prefix(0);

  for (int32_t round = 0; round < VALS_PER_THREAD; ++round) {
    const IdType index = threadIdx.x + round*BLOCK_SIZE + blockIdx.x*TILE_SIZE;

    // stays null unless this position owns its item
    Mapping * owned = nullptr;
    if (index < num_items) {
      Mapping * const candidate = table.Search(items[index]);
      if (candidate->index == index) {
        owned = candidate;
      }
    }

    // 1 for an owned position, 0 otherwise; replaced by the exclusive rank
    FlagType rank = owned ? 1 : 0;
    BlockScan(temp_space).ExclusiveSum(rank, rank, running_prefix);
    // barrier before temp_space is used again by the next round's scan
    __syncthreads();

    if (owned) {
      const IdType pos = block_base + rank;
      owned->local = pos;
      unique_items[pos] = items[index];
    }
  }

  if (threadIdx.x == 0 && blockIdx.x == 0) {
    *num_unique_items = num_items_prefix[gridDim.x];
  }
}
|
||||
|
||||
// DeviceOrderedHashTable implementation
|
||||
|
||||
// Builds a device-side handle over an existing bucket array; the handle only
// stores the pointer and the bucket count.
template<typename IdType>
DeviceOrderedHashTable<IdType>::DeviceOrderedHashTable(
    const Mapping* const table,
    const size_t size)
  : table_(table),
    size_(size) {}
|
||||
|
||||
// Returns a handle referring to this table's buckets, for passing to kernels.
template<typename IdType>
DeviceOrderedHashTable<IdType> OrderedHashTable<IdType>::DeviceHandle() const {
  DeviceOrderedHashTable<IdType> handle(table_, size_);
  return handle;
}
|
||||
|
||||
// OrderedHashTable implementation
|
||||
|
||||
// Allocates the bucket array as a workspace on `ctx` and marks every bucket
// as empty, asynchronously on `stream`.
template<typename IdType>
OrderedHashTable<IdType>::OrderedHashTable(
    const size_t size,
    DGLContext ctx,
    cudaStream_t stream,
    const int scale) :
  table_(nullptr),
  size_(TableSize(size, scale)),
  ctx_(ctx) {
  // make sure we will have at least as many buckets as items.
  CHECK_GT(scale, 0);

  const size_t table_bytes = sizeof(Mapping)*size_;

  table_ = static_cast<Mapping*>(
      runtime::DeviceAPI::Get(ctx_)->AllocWorkspace(ctx_, table_bytes));

  // The fill value is applied to every byte of the table, which is why the
  // empty key must consist of uniform bytes.
  CUDA_CALL(cudaMemsetAsync(
      table_,
      DeviceOrderedHashTable<IdType>::kEmptyKey,
      table_bytes,
      stream));
}
|
||||
|
||||
// Hands the bucket array back to the device workspace it was taken from.
template<typename IdType>
OrderedHashTable<IdType>::~OrderedHashTable() {
  runtime::DeviceAPI::Get(ctx_)->FreeWorkspace(ctx_, table_);
}
|
||||
|
||||
// Fills the table from an array that may contain duplicate IDs, and produces
// the list of unique IDs together with their count.
//
// Steps, all enqueued on `stream`:
//   1. insert every (id, position) pair into the table,
//   2. count per thread block the positions recorded in the table,
//   3. exclusive-scan those counts into per-block offsets,
//   4. assign local ids and write out the unique IDs.
//
// `unique` and `num_unique` are written from device code and therefore have
// to be device-accessible. An empty input only sets `*num_unique` to zero.
template<typename IdType>
void OrderedHashTable<IdType>::FillWithDuplicates(
    const IdType * const input,
    const size_t num_input,
    IdType * const unique,
    int64_t * const num_unique,
    cudaStream_t stream) {
  if (num_input == 0) {
    // An empty input would produce an empty grid, which cannot be launched.
    // There is nothing to insert, so only report that no unique IDs exist.
    CUDA_CALL(cudaMemsetAsync(num_unique, 0, sizeof(*num_unique), stream));
    return;
  }

  auto device = runtime::DeviceAPI::Get(ctx_);

  const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE;

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

  generate_hashmap_duplicates<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
      input,
      num_input,
      device_table);
  CUDA_CALL(cudaGetLastError());

  // One count per thread block plus one trailing slot that receives the
  // overall total after the scan.
  IdType * item_prefix = static_cast<IdType*>(
      device->AllocWorkspace(ctx_, sizeof(IdType)*(grid.x+1)));

  count_hashmap<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
      input,
      num_input,
      device_table,
      item_prefix);
  CUDA_CALL(cudaGetLastError());

  // The first call, made with a null workspace, only reports the amount of
  // temporary storage required.
  size_t workspace_bytes = 0;
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      nullptr,
      workspace_bytes,
      static_cast<IdType*>(nullptr),
      static_cast<IdType*>(nullptr),
      grid.x+1));
  void * workspace = device->AllocWorkspace(ctx_, workspace_bytes);

  // In-place scan: turns the per-block counts into per-block offsets.
  CUDA_CALL(cub::DeviceScan::ExclusiveSum(
      workspace,
      workspace_bytes,
      item_prefix,
      item_prefix,
      grid.x+1, stream));
  device->FreeWorkspace(ctx_, workspace);

  compact_hashmap<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
      input,
      num_input,
      device_table,
      item_prefix,
      unique,
      num_unique);
  CUDA_CALL(cudaGetLastError());
  device->FreeWorkspace(ctx_, item_prefix);
}
|
||||
|
||||
// Fills the table from an array of IDs that are all unique. Every ID gets its
// position in `input` as local id. The work is enqueued on `stream`; an empty
// input is a no-op.
template<typename IdType>
void OrderedHashTable<IdType>::FillWithUnique(
    const IdType * const input,
    const size_t num_input,
    cudaStream_t stream) {
  if (num_input == 0) {
    // An empty input would produce an empty grid, which cannot be launched;
    // there is nothing to insert.
    return;
  }

  const int64_t num_tiles = (num_input+TILE_SIZE-1)/TILE_SIZE;

  const dim3 grid(num_tiles);
  const dim3 block(BLOCK_SIZE);

  auto device_table = MutableDeviceOrderedHashTable<IdType>(this);

  generate_hashmap_unique<IdType, BLOCK_SIZE, TILE_SIZE><<<grid, block, 0, stream>>>(
      input,
      num_input,
      device_table);
  CUDA_CALL(cudaGetLastError());
}
|
||||
|
||||
template class OrderedHashTable<int32_t>;
|
||||
template class OrderedHashTable<int64_t>;
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace runtime
|
||||
} // namespace dgl
|
||||
282
src/runtime/cuda/cuda_hashtable.cuh
Normal file
282
src/runtime/cuda/cuda_hashtable.cuh
Normal file
@@ -0,0 +1,282 @@
|
||||
/*!
|
||||
* Copyright (c) 2021 by Contributors
|
||||
* \file runtime/cuda/cuda_hashtable.cuh
|
||||
* \brief Host-side and device-side handles for the GPU ordered hash table.
|
||||
*/
|
||||
|
||||
#ifndef DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
|
||||
#define DGL_RUNTIME_CUDA_CUDA_HASHTABLE_CUH_
|
||||
|
||||
#include <dgl/runtime/c_runtime_api.h>
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cuda_common.h"
|
||||
|
||||
namespace dgl {
|
||||
namespace runtime {
|
||||
namespace cuda {
|
||||
|
||||
template<typename>
|
||||
class OrderedHashTable;
|
||||
|
||||
/*!
|
||||
* \brief A device-side handle for a GPU hashtable for mapping items to the
|
||||
* first index at which they appear in the provided data array.
|
||||
*
|
||||
* For any ID array A, one can view it as a mapping from the index `i`
|
||||
* (continuous integer range from zero) to its element `A[i]`. This hashtable
|
||||
* serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
|
||||
* Quadratic probing is used for collision resolution. See
|
||||
* DeviceOrderedHashTable's documentation for how the Mapping structure is
|
||||
* used.
|
||||
*
|
||||
* The hash table should be used in two phases, with the first being populating
|
||||
* the hash table with the OrderedHashTable object, and then generating this
|
||||
* handle from it. This object can then be used to search the hash table,
|
||||
* to find mappings, from within CUDA code.
|
||||
*
|
||||
* If a device-side handle is created from a hash table with the following
|
||||
* entries:
|
||||
* [
|
||||
* {key: 0, local: 0, index: 0},
|
||||
* {key: 3, local: 1, index: 1},
|
||||
* {key: 2, local: 2, index: 2},
|
||||
* {key: 8, local: 3, index: 4},
|
||||
* {key: 4, local: 4, index: 5},
|
||||
* {key: 1, local: 5, index: 8}
|
||||
* ]
|
||||
* The array [0, 3, 2, 0, 8, 4, 3, 2, 1, 8] could have `Search()` called on
|
||||
* each id, to be mapped via:
|
||||
* ```
|
||||
* __global__ void map(int32_t * array,
|
||||
* size_t size,
|
||||
* DeviceOrderedHashTable<int32_t> table) {
|
||||
* int idx = threadIdx.x + blockIdx.x*blockDim.x;
|
||||
* if (idx < size) {
|
||||
* array[idx] = table.Search(array[idx])->local;
|
||||
* }
|
||||
* }
|
||||
* ```
|
||||
* to get the remapped array:
|
||||
* [0, 1, 2, 0, 3, 4, 1, 2, 5, 3]
|
||||
*
|
||||
* \tparam IdType The type of the IDs.
|
||||
*/
|
||||
template<typename IdType>
class DeviceOrderedHashTable {
 public:
  /**
   * \brief An entry in the hashtable.
   */
  struct Mapping {
    /** \brief The ID of the item inserted. */
    IdType key;
    /** \brief The index of the item in the unique list. */
    IdType local;
    /**
     * \brief The index of the item when inserted into the hashtable (e.g.,
     * the index within the array passed into FillWithDuplicates()).
     */
    int64_t index;
  };

  typedef const Mapping* ConstIterator;

  DeviceOrderedHashTable(const DeviceOrderedHashTable& other) = default;
  DeviceOrderedHashTable& operator=(const DeviceOrderedHashTable& other) = default;

  /**
   * \brief Look up the read-only mapping of a key that is present in the
   * hash table.
   *
   * WARNING: The key must exist within the hashtable. Searching for a key not
   * in the hashtable is undefined behavior.
   *
   * \param id The key to search for.
   *
   * \return An iterator to the mapping.
   */
  inline __device__ ConstIterator Search(
      const IdType id) const {
    return table_ + SearchForPosition(id);
  }

 protected:
  // Must be uniform bytes for memset to work
  static constexpr IdType kEmptyKey = static_cast<IdType>(-1);

  const Mapping * table_;
  size_t size_;

  /**
   * \brief Create a new device-side handle to the hash table.
   *
   * \param table The table stored in GPU memory.
   * \param size The size of the table.
   */
  explicit DeviceOrderedHashTable(
      const Mapping * table,
      size_t size);

  /**
   * \brief Locate the bucket of an item which is known to exist.
   *
   * Starting from the item's hash, the probe step grows by one after every
   * mismatch until a bucket holding `id` is reached.
   *
   * WARNING: If the ID searched for does not exist within the hashtable, this
   * function will never return.
   *
   * \param id The ID of the item to search for.
   *
   * \return The position of the item in the hashtable.
   */
  inline __device__ IdType SearchForPosition(
      const IdType id) const {
    IdType pos = Hash(id);

    for (IdType delta = 1; table_[pos].key != id; ++delta) {
      // an empty bucket on the probe path means the key was never inserted
      assert(table_[pos].key != kEmptyKey);
      pos = Hash(pos+delta);
    }
    assert(pos < size_);

    return pos;
  }

  /**
   * \brief Map an ID to a position in the hash table.
   *
   * \param id The ID to hash.
   *
   * \return The ID reduced modulo the number of buckets.
   */
  inline __device__ size_t Hash(
      const IdType id) const {
    return static_cast<size_t>(id) % size_;
  }

  friend class OrderedHashTable<IdType>;
};
|
||||
|
||||
/*!
|
||||
* \brief A host-side handle for a GPU hashtable for mapping items to the
|
||||
* first index at which they appear in the provided data array. This host-side
|
||||
* handle is responsible for allocating and freeing the GPU memory of the
|
||||
* hashtable.
|
||||
*
|
||||
* For any ID array A, one can view it as a mapping from the index `i`
|
||||
* (continuous integer range from zero) to its element `A[i]`. This hashtable
|
||||
* serves as a reverse mapping, i.e., from element `A[i]` to its index `i`.
|
||||
* Quadratic probing is used for collision resolution.
|
||||
*
|
||||
* The hash table should be used in two phases, the first is filling the hash
|
||||
* table via 'FillWithDuplicates()' or 'FillWithUnique()'. Then, the
|
||||
* 'DeviceHandle()' method can be called, to get a version suitable for
|
||||
* searching from device and kernel functions.
|
||||
*
|
||||
* If 'FillWithDuplicates()' was called with an array of:
|
||||
* [0, 3, 2, 0, 8, 4, 3, 2, 1, 8]
|
||||
*
|
||||
* The resulting entries in the hash-table would be:
|
||||
* [
|
||||
* {key: 0, local: 0, index: 0},
|
||||
* {key: 3, local: 1, index: 1},
|
||||
* {key: 2, local: 2, index: 2},
|
||||
* {key: 8, local: 3, index: 4},
|
||||
* {key: 4, local: 4, index: 5},
|
||||
* {key: 1, local: 5, index: 8}
|
||||
* ]
|
||||
*
|
||||
* \tparam IdType The type of the IDs.
|
||||
*/
|
||||
template<typename IdType>
class OrderedHashTable {
 public:
  static constexpr int kDefaultScale = 3;

  typedef typename DeviceOrderedHashTable<IdType>::Mapping Mapping;

  /**
   * \brief Create a new ordered hash table. The amount of GPU memory
   * consumed by the resulting hashtable is O(`size` * 2^`scale`).
   *
   * \param size The number of items to insert into the hashtable.
   * \param ctx The device context to store the hashtable on.
   * \param stream The stream to use for initializing the hashtable.
   * \param scale The power of two times larger the number of buckets should
   * be than the number of items.
   */
  OrderedHashTable(
      const size_t size,
      DGLContext ctx,
      cudaStream_t stream,
      const int scale = kDefaultScale);

  /**
   * \brief Release the GPU memory held by the hashtable.
   */
  ~OrderedHashTable();

  // The table owns its GPU memory, so copying is disabled.
  OrderedHashTable(const OrderedHashTable& other) = delete;
  OrderedHashTable& operator=(const OrderedHashTable& other) = delete;

  /**
   * \brief Fill the hashtable with the array containing possibly duplicate
   * IDs.
   *
   * \param input The array of IDs to insert.
   * \param num_input The number of IDs to insert.
   * \param unique The list of unique IDs inserted (output).
   * \param num_unique The number of unique IDs inserted (output).
   * \param stream The stream to perform operations on.
   */
  void FillWithDuplicates(
      const IdType * const input,
      const size_t num_input,
      IdType * const unique,
      int64_t * const num_unique,
      cudaStream_t stream);

  /**
   * \brief Fill the hashtable with an array of unique keys.
   *
   * \param input The array of unique IDs.
   * \param num_input The number of keys.
   * \param stream The stream to perform operations on.
   */
  void FillWithUnique(
      const IdType * const input,
      const size_t num_input,
      cudaStream_t stream);

  /**
   * \brief Get a version of the hashtable usable from device functions.
   *
   * \return A device-side handle to this hashtable.
   */
  DeviceOrderedHashTable<IdType> DeviceHandle() const;

 private:
  // bucket array in GPU memory
  Mapping * table_;
  // number of buckets in table_
  size_t size_;
  // device context the buckets were allocated on
  DGLContext ctx_;
};
|
||||
|
||||
|
||||
} // cuda
|
||||
} // runtime
|
||||
} // dgl
|
||||
|
||||
#endif
|
||||
@@ -821,7 +821,6 @@ def test_to_simple(idtype):
|
||||
assert 'h' not in sg.nodes['user'].data
|
||||
assert 'hh' not in sg.nodes['user'].data
|
||||
|
||||
@unittest.skipIf(F._default_context_str == 'gpu', reason="GPU compaction not implemented")
|
||||
@parametrize_dtype
|
||||
def test_to_block(idtype):
|
||||
def check(g, bg, ntype, etype, dst_nodes, include_dst_in_src=True):
|
||||
@@ -838,6 +837,7 @@ def test_to_block(idtype):
|
||||
induced_src = bg.srcdata[dgl.NID]
|
||||
induced_dst = bg.dstdata[dgl.NID]
|
||||
induced_eid = bg.edata[dgl.EID]
|
||||
|
||||
bg_src, bg_dst = bg.all_edges(order='eid')
|
||||
src_ans, dst_ans = g.all_edges(order='eid')
|
||||
|
||||
@@ -860,7 +860,7 @@ def test_to_block(idtype):
|
||||
g = dgl.heterograph({
|
||||
('A', 'AA', 'A'): ([0, 2, 1, 3], [1, 3, 2, 4]),
|
||||
('A', 'AB', 'B'): ([0, 1, 3, 1], [1, 3, 5, 6]),
|
||||
('B', 'BA', 'A'): ([2, 3], [3, 2])}, idtype=idtype)
|
||||
('B', 'BA', 'A'): ([2, 3], [3, 2])}, idtype=idtype, device=F.ctx())
|
||||
g.nodes['A'].data['x'] = F.randn((5, 10))
|
||||
g.nodes['B'].data['x'] = F.randn((7, 5))
|
||||
g.edges['AA'].data['x'] = F.randn((4, 3))
|
||||
|
||||
Reference in New Issue
Block a user