mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)
* rename `DLContext` to `DGLContext` * rename `kDLGPU` to `kDLCUDA` * replace DLTensor with DGLArray * fix linting * Unify DGLType and DLDataType to DGLDataType * Fix FFI * rename DLDeviceType to DGLDeviceType * decouple dlpack from the core library * fix bug * fix lint * fix merge * fix build * address comments * rename dl_converter to dlpack_convert * remove redundant comments
This commit is contained in:
@@ -24,8 +24,8 @@ namespace aten {
|
||||
//////////////////////////////////////////////////////////////////////
|
||||
|
||||
/*! \return A special array to represent null. */
|
||||
inline NDArray NullArray(const DLDataType& dtype = DLDataType{kDLInt, 64, 1},
|
||||
const DLContext& ctx = DLContext{kDLCPU, 0}) {
|
||||
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
|
||||
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
|
||||
return NDArray::Empty({0}, dtype, ctx);
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ inline bool IsNullArray(NDArray array) {
|
||||
* \return id array
|
||||
*/
|
||||
IdArray NewIdArray(int64_t length,
|
||||
DLContext ctx = DLContext{kDLCPU, 0},
|
||||
DGLContext ctx = DGLContext{kDGLCPU, 0},
|
||||
uint8_t nbits = 64);
|
||||
|
||||
/*!
|
||||
@@ -57,7 +57,7 @@ IdArray NewIdArray(int64_t length,
|
||||
template <typename T>
|
||||
IdArray VecToIdArray(const std::vector<T>& vec,
|
||||
uint8_t nbits = 64,
|
||||
DLContext ctx = DLContext{kDLCPU, 0});
|
||||
DGLContext ctx = DGLContext{kDGLCPU, 0});
|
||||
|
||||
/*!
|
||||
* \brief Return an array representing a 1D range.
|
||||
@@ -67,7 +67,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
|
||||
* \param ctx Device context
|
||||
* \return range array
|
||||
*/
|
||||
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
|
||||
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
|
||||
|
||||
/*!
|
||||
* \brief Return an array full of the given value
|
||||
@@ -77,7 +77,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
|
||||
* \param ctx Device context
|
||||
* \return the result array
|
||||
*/
|
||||
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
|
||||
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
|
||||
|
||||
/*!
|
||||
* \brief Return an array full of the given value with the given type.
|
||||
@@ -87,7 +87,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
|
||||
* \return the result array
|
||||
*/
|
||||
template <typename DType>
|
||||
NDArray Full(DType val, int64_t length, DLContext ctx);
|
||||
NDArray Full(DType val, int64_t length, DGLContext ctx);
|
||||
|
||||
/*! \brief Create a deep copy of the given array */
|
||||
IdArray Clone(IdArray arr);
|
||||
@@ -226,7 +226,7 @@ NDArray Concat(const std::vector<IdArray>& arrays);
|
||||
|
||||
/*!\brief Return whether the array is a valid 1D int array*/
|
||||
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
|
||||
return arr->ndim == 1 && arr->dtype.code == kDLInt;
|
||||
return arr->ndim == 1 && arr->dtype.code == kDGLInt;
|
||||
}
|
||||
|
||||
/*!
|
||||
@@ -343,8 +343,8 @@ std::string ToDebugString(NDArray array);
|
||||
template <typename T>
|
||||
IdArray VecToIdArray(const std::vector<T>& vec,
|
||||
uint8_t nbits,
|
||||
DLContext ctx) {
|
||||
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
|
||||
DGLContext ctx) {
|
||||
IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
|
||||
if (nbits == 32) {
|
||||
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
|
||||
} else if (nbits == 64) {
|
||||
@@ -359,9 +359,9 @@ IdArray VecToIdArray(const std::vector<T>& vec,
|
||||
* \brief Get the context of the first array, and check if the non-null arrays'
|
||||
* contexts are the same.
|
||||
*/
|
||||
inline DLContext GetContextOf(const std::vector<IdArray>& arrays) {
|
||||
inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
|
||||
bool first = true;
|
||||
DLContext result;
|
||||
DGLContext result;
|
||||
for (auto& array : arrays) {
|
||||
if (first) {
|
||||
first = false;
|
||||
|
||||
@@ -122,11 +122,10 @@ struct COOMatrix {
|
||||
}
|
||||
|
||||
/*! \brief Return a copy of this matrix on the given device context. */
|
||||
inline COOMatrix CopyTo(const DLContext &ctx) const {
|
||||
inline COOMatrix CopyTo(const DGLContext &ctx) const {
|
||||
if (ctx == row->ctx)
|
||||
return *this;
|
||||
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx),
|
||||
col.CopyTo(ctx),
|
||||
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
|
||||
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
|
||||
row_sorted, col_sorted);
|
||||
}
|
||||
@@ -134,9 +133,9 @@ struct COOMatrix {
|
||||
/*!
|
||||
* \brief Pin the row, col and data (if not Null) of the matrix.
|
||||
* \note This is an in-place method. Behavior depends on the current context,
|
||||
* kDLCPU: will be pinned;
|
||||
* kDGLCPU: will be pinned;
|
||||
* IsPinned: directly return;
|
||||
* kDLGPU: invalid, will throw an error.
|
||||
* kDGLCUDA: invalid, will throw an error.
|
||||
* The context check is deferred to pinning the NDArray.
|
||||
*/
|
||||
inline void PinMemory_() {
|
||||
|
||||
@@ -115,21 +115,19 @@ struct CSRMatrix {
|
||||
}
|
||||
|
||||
/*! \brief Return a copy of this matrix on the given device context. */
|
||||
inline CSRMatrix CopyTo(const DLContext &ctx) const {
|
||||
inline CSRMatrix CopyTo(const DGLContext &ctx) const {
|
||||
if (ctx == indptr->ctx)
|
||||
return *this;
|
||||
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx),
|
||||
indices.CopyTo(ctx),
|
||||
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
|
||||
sorted);
|
||||
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx),
|
||||
aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted);
|
||||
}
|
||||
|
||||
/*!
|
||||
* \brief Pin the indptr, indices and data (if not Null) of the matrix.
|
||||
* \note This is an in-place method. Behavior depends on the current context,
|
||||
* kDLCPU: will be pinned;
|
||||
* kDGLCPU: will be pinned;
|
||||
* IsPinned: directly return;
|
||||
* kDLGPU: invalid, will throw an error.
|
||||
* kDGLCUDA: invalid, will throw an error.
|
||||
* The context check is deferred to pinning the NDArray.
|
||||
*/
|
||||
inline void PinMemory_() {
|
||||
|
||||
@@ -18,8 +18,8 @@
|
||||
* });
|
||||
*/
|
||||
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
|
||||
if ((val) == kDLCPU) { \
|
||||
constexpr auto XPU = kDLCPU; \
|
||||
if ((val) == kDGLCPU) { \
|
||||
constexpr auto XPU = kDGLCPU; \
|
||||
{__VA_ARGS__} \
|
||||
} else { \
|
||||
LOG(FATAL) << "Operator " << (op) << " does not support " \
|
||||
@@ -43,11 +43,11 @@
|
||||
*/
|
||||
#ifdef DGL_USE_CUDA
|
||||
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
|
||||
if ((val) == kDLCPU) { \
|
||||
constexpr auto XPU = kDLCPU; \
|
||||
if ((val) == kDGLCPU) { \
|
||||
constexpr auto XPU = kDGLCPU; \
|
||||
{__VA_ARGS__} \
|
||||
} else if ((val) == kDLGPU) { \
|
||||
constexpr auto XPU = kDLGPU; \
|
||||
} else if ((val) == kDGLCUDA) { \
|
||||
constexpr auto XPU = kDGLCUDA; \
|
||||
{__VA_ARGS__} \
|
||||
} else { \
|
||||
LOG(FATAL) << "Operator " << (op) << " does not support " \
|
||||
@@ -69,7 +69,7 @@
|
||||
* });
|
||||
*/
|
||||
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
|
||||
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
|
||||
CHECK_EQ((val).code, kDGLInt) << "ID must be integer type"; \
|
||||
if ((val).bits == 32) { \
|
||||
typedef int32_t IdType; \
|
||||
{__VA_ARGS__} \
|
||||
@@ -114,7 +114,7 @@
|
||||
* });
|
||||
*/
|
||||
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
|
||||
CHECK_EQ((val).code, kDLFloat) \
|
||||
CHECK_EQ((val).code, kDGLFloat) \
|
||||
<< (val_name) << " must be float type"; \
|
||||
if ((val).bits == 32) { \
|
||||
typedef float FloatType; \
|
||||
@@ -128,7 +128,7 @@
|
||||
} while (0)
|
||||
|
||||
#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \
|
||||
CHECK_EQ((val).code, kDLFloat) \
|
||||
CHECK_EQ((val).code, kDGLFloat) \
|
||||
<< (val_name) << " must be float type"; \
|
||||
if ((val).bits == 16) { \
|
||||
constexpr int bits = 16; \
|
||||
@@ -154,16 +154,16 @@
|
||||
* });
|
||||
*/
|
||||
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
|
||||
if ((val).code == kDLInt && (val).bits == 32) { \
|
||||
if ((val).code == kDGLInt && (val).bits == 32) { \
|
||||
typedef int32_t DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else if ((val).code == kDLInt && (val).bits == 64) { \
|
||||
} else if ((val).code == kDGLInt && (val).bits == 64) { \
|
||||
typedef int64_t DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else if ((val).code == kDLFloat && (val).bits == 32) { \
|
||||
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
|
||||
typedef float DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else if ((val).code == kDLFloat && (val).bits == 64) { \
|
||||
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
|
||||
typedef double DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else { \
|
||||
@@ -205,10 +205,10 @@
|
||||
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
|
||||
*/
|
||||
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
|
||||
if ((val).code == kDLInt && (val).bits == 32) { \
|
||||
if ((val).code == kDGLInt && (val).bits == 32) { \
|
||||
typedef int32_t DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else if ((val).code == kDLInt && (val).bits == 64) { \
|
||||
} else if ((val).code == kDGLInt && (val).bits == 64) { \
|
||||
typedef int64_t DType; \
|
||||
{__VA_ARGS__} \
|
||||
} else { \
|
||||
@@ -278,13 +278,13 @@
|
||||
///////////////////////// Array checks //////////////////////////
|
||||
|
||||
#define IS_INT32(a) \
|
||||
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32)
|
||||
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
|
||||
#define IS_INT64(a) \
|
||||
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64)
|
||||
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
|
||||
#define IS_FLOAT32(a) \
|
||||
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32)
|
||||
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
|
||||
#define IS_FLOAT64(a) \
|
||||
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64)
|
||||
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
|
||||
|
||||
#define CHECK_IF(cond, prop, value_name, dtype_name) \
|
||||
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)
|
||||
|
||||
@@ -34,7 +34,7 @@ typedef NDArray TypeArray;
|
||||
|
||||
namespace aten {
|
||||
|
||||
static const DLContext CPU{kDLCPU, 0};
|
||||
static const DGLContext CPU{kDGLCPU, 0};
|
||||
|
||||
} // namespace aten
|
||||
} // namespace dgl
|
||||
|
||||
@@ -104,12 +104,12 @@ class BaseHeteroGraph : public runtime::Object {
|
||||
/*!
|
||||
* \brief Get the data type of node and edge IDs of this graph.
|
||||
*/
|
||||
virtual DLDataType DataType() const = 0;
|
||||
virtual DGLDataType DataType() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Get the device context of this graph.
|
||||
*/
|
||||
virtual DLContext Context() const = 0;
|
||||
virtual DGLContext Context() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Pin graph.
|
||||
|
||||
@@ -89,8 +89,8 @@ class Graph: public GraphInterface {
|
||||
num_edges_ = 0;
|
||||
}
|
||||
|
||||
DLContext Context() const override {
|
||||
return DLContext{kDLCPU, 0};
|
||||
DGLContext Context() const override {
|
||||
return DGLContext{kDGLCPU, 0};
|
||||
}
|
||||
|
||||
uint8_t NumBits() const override {
|
||||
|
||||
@@ -137,7 +137,7 @@ class GraphInterface : public runtime::Object {
|
||||
/*!
|
||||
* \brief Get the device context of this graph.
|
||||
*/
|
||||
virtual DLContext Context() const = 0;
|
||||
virtual DGLContext Context() const = 0;
|
||||
|
||||
/*!
|
||||
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).
|
||||
|
||||
@@ -69,7 +69,7 @@ class CSR : public GraphInterface {
|
||||
LOG(FATAL) << "CSR graph does not allow mutation.";
|
||||
}
|
||||
|
||||
DLContext Context() const override {
|
||||
DGLContext Context() const override {
|
||||
return adj_.indptr->ctx;
|
||||
}
|
||||
|
||||
@@ -214,7 +214,7 @@ class CSR : public GraphInterface {
|
||||
* \param ctx The target context.
|
||||
* \return The graph under another context.
|
||||
*/
|
||||
CSR CopyTo(const DLContext& ctx) const;
|
||||
CSR CopyTo(const DGLContext& ctx) const;
|
||||
|
||||
/*!
|
||||
* \brief Copy data to shared memory.
|
||||
@@ -288,7 +288,7 @@ class COO : public GraphInterface {
|
||||
LOG(FATAL) << "COO graph does not allow mutation.";
|
||||
}
|
||||
|
||||
DLContext Context() const override {
|
||||
DGLContext Context() const override {
|
||||
return adj_.row->ctx;
|
||||
}
|
||||
|
||||
@@ -472,7 +472,7 @@ class COO : public GraphInterface {
|
||||
* \param ctx The target context.
|
||||
* \return The graph under another context.
|
||||
*/
|
||||
COO CopyTo(const DLContext& ctx) const;
|
||||
COO CopyTo(const DGLContext& ctx) const;
|
||||
|
||||
/*!
|
||||
* \brief Copy data to shared memory.
|
||||
@@ -578,7 +578,7 @@ class ImmutableGraph: public GraphInterface {
|
||||
LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
|
||||
}
|
||||
|
||||
DLContext Context() const override {
|
||||
DGLContext Context() const override {
|
||||
return AnyGraph()->Context();
|
||||
}
|
||||
|
||||
@@ -911,7 +911,7 @@ class ImmutableGraph: public GraphInterface {
|
||||
* \param ctx The target context.
|
||||
* \return The graph under another context.
|
||||
*/
|
||||
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);
|
||||
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx);
|
||||
|
||||
/*!
|
||||
* \brief Copy data to shared memory.
|
||||
|
||||
@@ -145,7 +145,7 @@ class RandomEngine {
|
||||
*/
|
||||
template <typename IdxType, typename FloatType>
|
||||
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
|
||||
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
|
||||
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
|
||||
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
|
||||
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace);
|
||||
return ret;
|
||||
@@ -178,9 +178,9 @@ class RandomEngine {
|
||||
*/
|
||||
template <typename IdxType>
|
||||
IdArray UniformChoice(IdxType num, IdxType population, bool replace = true) {
|
||||
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
|
||||
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
|
||||
// TODO(minjie): only CPU implementation right now
|
||||
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
|
||||
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
|
||||
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace);
|
||||
return ret;
|
||||
}
|
||||
@@ -230,8 +230,8 @@ class RandomEngine {
|
||||
template <typename IdxType, typename FloatType>
|
||||
IdArray BiasedChoice(
|
||||
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
|
||||
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
|
||||
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
|
||||
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
|
||||
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
|
||||
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
|
||||
return ret;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2016 by Contributors
|
||||
* Copyright (c) 2016-2022 by Contributors
|
||||
* \file dgl/runtime/c_runtime_api.h
|
||||
* \brief DGL runtime library.
|
||||
*
|
||||
@@ -35,10 +35,6 @@
|
||||
// DGL version
|
||||
#define DGL_VERSION "0.9"
|
||||
|
||||
|
||||
// DGL Runtime is DLPack compatible.
|
||||
#include <dlpack/dlpack.h>
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
@@ -48,28 +44,31 @@ extern "C" {
|
||||
/*! \brief type of array index. */
|
||||
typedef int64_t dgl_index_t;
|
||||
|
||||
/*! \brief Extension device types in DGL */
|
||||
/*!
|
||||
* \brief The device type in DGLContext.
|
||||
*/
|
||||
#ifdef __cplusplus
|
||||
typedef enum : int32_t {
|
||||
#else
|
||||
typedef enum {
|
||||
kDLAOCL = 5,
|
||||
kDLSDAccel = 6,
|
||||
kOpenGL = 11,
|
||||
// Extension DRAM type, used for quickly test extension device
|
||||
// The device api can differ depending on the xpu driver registered.
|
||||
kExtDev = 12,
|
||||
// AddExtraDGLType which is not in DLPack here
|
||||
} DGLDeviceExtType;
|
||||
#endif
|
||||
/*! \brief CPU device */
|
||||
kDGLCPU = 1,
|
||||
/*! \brief CUDA GPU device */
|
||||
kDGLCUDA = 2,
|
||||
// add more devices once supported
|
||||
} DGLDeviceType;
|
||||
|
||||
/*!
|
||||
* \brief The type code in DGLType
|
||||
* \note DGLType is used in two places.
|
||||
* \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python.
|
||||
*/
|
||||
typedef enum {
|
||||
// The type code of other types are compatible with DLPack.
|
||||
// The next few fields are extension types
|
||||
// that is used by DGL API calls.
|
||||
kInt = 0U,
|
||||
kUInt = 1U,
|
||||
kFloat = 2U,
|
||||
kHandle = 3U,
|
||||
kNull = 4U,
|
||||
kDGLType = 5U,
|
||||
kDGLDataType = 5U,
|
||||
kDGLContext = 6U,
|
||||
kArrayHandle = 7U,
|
||||
kObjectHandle = 8U,
|
||||
@@ -88,29 +87,112 @@ typedef enum {
|
||||
// The following section of code is used for non-reserved types.
|
||||
kExtReserveEnd = 64U,
|
||||
kExtEnd = 128U
|
||||
} DGLTypeCode;
|
||||
} DGLObjectTypeCode;
|
||||
|
||||
/*!
|
||||
* \brief The data type used in DGL Runtime.
|
||||
* \brief The type code options of DGLDataType.
|
||||
*/
|
||||
typedef enum {
|
||||
/*! \brief signed integer */
|
||||
kDGLInt = 0U,
|
||||
/*! \brief unsigned integer */
|
||||
kDGLUInt = 1U,
|
||||
/*! \brief IEEE floating point */
|
||||
kDGLFloat = 2U,
|
||||
/*! \brief bfloat16 */
|
||||
kDGLBfloat = 4U,
|
||||
// add more data types if we are going to support them
|
||||
} DGLDataTypeCode;
|
||||
|
||||
/*!
|
||||
* \brief The data type the tensor can hold. The data type is assumed to follow the
|
||||
* native endian-ness. An explicit error message should be raised when attempting to
|
||||
* export an array with non-native endianness
|
||||
*
|
||||
* Examples
|
||||
* - float: type_code = 2, bits = 32, lanes=1
|
||||
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
|
||||
* - int8: type_code = 0, bits = 8, lanes=1
|
||||
*
|
||||
* \note Arguments to DGL API functions always take bits=64 and lanes=1
|
||||
*/
|
||||
typedef DLDataType DGLType;
|
||||
typedef struct {
|
||||
/*!
|
||||
* \brief Type code of base types.
|
||||
* We keep it uint8_t instead of DGLDataTypeCode for minimal memory
|
||||
* footprint, but the value should be one of DGLDataTypeCode enum values.
|
||||
* */
|
||||
uint8_t code;
|
||||
/*!
|
||||
* \brief Number of bits, common choices are 8, 16, 32.
|
||||
*/
|
||||
uint8_t bits;
|
||||
/*! \brief Number of lanes in the type, used for vector types. */
|
||||
uint16_t lanes;
|
||||
} DGLDataType;
|
||||
|
||||
/*!
|
||||
* \brief The Device information, abstract away common device types.
|
||||
*/
|
||||
typedef DLContext DGLContext;
|
||||
typedef struct {
|
||||
/*! \brief The device type used in the device. */
|
||||
DGLDeviceType device_type;
|
||||
/*!
|
||||
* \brief The device index.
|
||||
* For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
|
||||
*/
|
||||
int32_t device_id;
|
||||
} DGLContext;
|
||||
|
||||
/*!
|
||||
* \brief The tensor array structure to DGL API.
|
||||
* The structure is heavily inspired by DLTensor from DLPack.
|
||||
*/
|
||||
typedef DLTensor DGLArray;
|
||||
typedef struct {
|
||||
/*!
|
||||
* \brief The data pointer points to the allocated data.
|
||||
*
|
||||
* Depending on the device context, it can be a CPU pointer, or a CUDA
|
||||
* device pointer or acl_mem handle in OpenCL.
|
||||
* This pointer is always aligned to 256 bytes as in CUDA. Use the
|
||||
* `byte_offset` field to mark the beginning of the actual data (if the address
|
||||
* is not 256 byte aligned).
|
||||
*
|
||||
* Note that as of Nov 2021, multiple libraries (CuPy, PyTorch, TensorFlow,
|
||||
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement
|
||||
* on CPU/CUDA/ROCm, and always use `byte_offset=0`. This is likely to be
|
||||
* fixed in the future; at the moment it is recommended
|
||||
* to not rely on the data pointer being correctly aligned.
|
||||
*
|
||||
* For a DGLArray, the size of memory required to store the contents of
|
||||
* data can be calculated as follows:
|
||||
*
|
||||
* \code{.c}
|
||||
* static inline size_t GetDataSize(const DGLArray* t) {
|
||||
* size_t size = 1;
|
||||
* for (int32_t i = 0; i < t->ndim; ++i) {
|
||||
* size *= t->shape[i];
|
||||
* }
|
||||
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
|
||||
* return size;
|
||||
* }
|
||||
* \endcode
|
||||
*/
|
||||
void* data;
|
||||
/*! \brief The device of the tensor */
|
||||
DGLContext ctx;
|
||||
/*! \brief Number of dimensions */
|
||||
int32_t ndim;
|
||||
/*! \brief The data type of the pointer*/
|
||||
DGLDataType dtype;
|
||||
/*! \brief The shape of the tensor */
|
||||
int64_t* shape;
|
||||
/*!
|
||||
* \brief strides of the tensor (in number of elements, not bytes)
|
||||
* can be NULL, indicating tensor is compact and row-majored.
|
||||
*/
|
||||
int64_t* strides;
|
||||
/*! \brief The offset in bytes to the beginning pointer to data */
|
||||
uint64_t byte_offset;
|
||||
} DGLArray;
|
||||
|
||||
/*! \brief the array handle */
|
||||
typedef DGLArray* DGLArrayHandle;
|
||||
@@ -124,7 +206,7 @@ typedef union {
|
||||
double v_float64;
|
||||
void* v_handle;
|
||||
const char* v_str;
|
||||
DGLType v_type;
|
||||
DGLDataType v_type;
|
||||
DGLContext v_ctx;
|
||||
} DGLValue;
|
||||
|
||||
@@ -455,32 +537,6 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
|
||||
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
|
||||
DGLArrayHandle to);
|
||||
|
||||
/*!
|
||||
* \brief Produce an array from the DLManagedTensor that shares data memory
|
||||
* with the DLManagedTensor.
|
||||
* \param from The source DLManagedTensor.
|
||||
* \param out The output array handle.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
|
||||
DGLArrayHandle* out);
|
||||
|
||||
/*!
|
||||
* \brief Produce a DLMangedTensor from the array that shares data memory with
|
||||
* the array.
|
||||
* \param from The source array.
|
||||
* \param out The DLManagedTensor handle.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
|
||||
int alignment = 0);
|
||||
|
||||
/*!
|
||||
* \brief Delete (free) a DLManagedTensor's data.
|
||||
* \param dltensor Pointer to the DLManagedTensor.
|
||||
*/
|
||||
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
|
||||
|
||||
/*!
|
||||
* \brief Create a new runtime stream.
|
||||
*
|
||||
@@ -557,12 +613,12 @@ DGL_DLL int DGLLoadTensorAdapter(const char *path);
|
||||
/*!
|
||||
* \brief Pin host memory.
|
||||
*/
|
||||
int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);
|
||||
int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx);
|
||||
|
||||
/*!
|
||||
* \brief Unpin host memory.
|
||||
*/
|
||||
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);
|
||||
int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx);
|
||||
|
||||
/*!
|
||||
* \brief Record the stream that's using this tensor.
|
||||
|
||||
@@ -75,7 +75,7 @@ class DeviceAPI {
|
||||
virtual void* AllocDataSpace(DGLContext ctx,
|
||||
size_t nbytes,
|
||||
size_t alignment,
|
||||
DGLType type_hint) = 0;
|
||||
DGLDataType type_hint) = 0;
|
||||
/*!
|
||||
* \brief Free a data space on device.
|
||||
* \param ctx The device context to perform operation.
|
||||
@@ -101,7 +101,7 @@ class DeviceAPI {
|
||||
size_t num_bytes,
|
||||
DGLContext ctx_from,
|
||||
DGLContext ctx_to,
|
||||
DGLType type_hint) = 0;
|
||||
DGLDataType type_hint) = 0;
|
||||
/*!
|
||||
* \brief Create a new stream of execution.
|
||||
*
|
||||
@@ -189,7 +189,7 @@ class DeviceAPI {
|
||||
*/
|
||||
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
|
||||
size_t nbytes,
|
||||
DGLType type_hint = {});
|
||||
DGLDataType type_hint = {});
|
||||
/*!
|
||||
* \brief Free temporal workspace in backend execution.
|
||||
*
|
||||
@@ -213,7 +213,7 @@ class DeviceAPI {
|
||||
* \param allow_missing Whether allow missing
|
||||
* \return The corresponding device API.
|
||||
*/
|
||||
DGL_DLL static DeviceAPI* Get(DLDeviceType dev_type, bool allow_missing = false);
|
||||
DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false);
|
||||
};
|
||||
|
||||
/*! \brief The device type bigger than this is RPC device */
|
||||
|
||||
85
include/dgl/runtime/dlpack_convert.h
Normal file
85
include/dgl/runtime/dlpack_convert.h
Normal file
@@ -0,0 +1,85 @@
|
||||
/*!
|
||||
* Copyright (c) 2022 by Contributors
|
||||
* \file include/dgl/runtime/dlpack_convert.h
|
||||
* \brief Conversion between NDArray and DLPack.
|
||||
*/
|
||||
#ifndef DGL_RUNTIME_DLPACK_CONVERT_H_
|
||||
#define DGL_RUNTIME_DLPACK_CONVERT_H_
|
||||
|
||||
#include "c_runtime_api.h"
|
||||
#include "ndarray.h"
|
||||
|
||||
struct DLManagedTensor;
|
||||
|
||||
namespace dgl {
|
||||
namespace runtime {
|
||||
|
||||
struct DLPackConvert {
|
||||
/*!
|
||||
* \brief Create a DGL NDArray from a DLPack tensor.
|
||||
*
|
||||
* This allows us to create a NDArray using the memory
|
||||
* allocated by an external deep learning framework
|
||||
* that is DLPack compatible.
|
||||
*
|
||||
* The memory is retained until the NDArray goes out of scope.
|
||||
* \param tensor The DLPack tensor to copy from.
|
||||
* \return The created NDArray view.
|
||||
*/
|
||||
static NDArray FromDLPack(DLManagedTensor* tensor);
|
||||
|
||||
/*!
|
||||
* \brief Deleter for NDArray converted from DLPack.
|
||||
*
|
||||
* This is used for data which is passed from external DLPack (DLManagedTensor)
|
||||
* that are not allocated inside of DGL.
|
||||
* This enables us to create NDArray from memory allocated by other
|
||||
* frameworks that are DLPack compatible
|
||||
*/
|
||||
static void DLPackDeleter(NDArray::Container* ptr);
|
||||
|
||||
/*! \brief Convert a DGL NDArray to a DLPack tensor.
|
||||
*
|
||||
* \param from The DGL NDArray.
|
||||
* \return A DLPack tensor.
|
||||
*/
|
||||
static DLManagedTensor* ToDLPack(const NDArray &from);
|
||||
};
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace dgl
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*!
|
||||
* \brief Delete (free) a DLManagedTensor's data.
|
||||
* \param dltensor Pointer to the DLManagedTensor.
|
||||
*/
|
||||
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
|
||||
|
||||
/*!
|
||||
* \brief Produce an array from the DLManagedTensor that shares data memory
|
||||
* with the DLManagedTensor.
|
||||
* \param from The source DLManagedTensor.
|
||||
* \param out The output array handle.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
|
||||
DGLArrayHandle* out);
|
||||
|
||||
/*!
|
||||
* \brief Produce a DLManagedTensor from the array that shares data memory with
|
||||
* the array.
|
||||
* \param from The source array.
|
||||
* \param out The DLManagedTensor handle.
|
||||
* \return 0 when success, -1 when failure happens
|
||||
*/
|
||||
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
|
||||
int alignment = 0);
|
||||
|
||||
#ifdef __cplusplus
|
||||
} // DGL_EXTERN_C
|
||||
#endif
|
||||
#endif // DGL_RUNTIME_DLPACK_CONVERT_H_
|
||||
@@ -1,5 +1,5 @@
|
||||
/*!
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* Copyright (c) 2017-2022 by Contributors
|
||||
* \file dgl/runtime/ndarray.h
|
||||
* \brief Abstract device memory management API
|
||||
*/
|
||||
@@ -14,7 +14,6 @@
|
||||
#include <memory>
|
||||
|
||||
#include "c_runtime_api.h"
|
||||
#include "dlpack/dlpack.h"
|
||||
#include "serializer.h"
|
||||
#include "shared_mem.h"
|
||||
|
||||
@@ -23,44 +22,49 @@
|
||||
#endif
|
||||
|
||||
// forward declaration
|
||||
inline std::ostream& operator << (std::ostream& os, DGLType t);
|
||||
inline std::ostream& operator << (std::ostream& os, DGLDataType t);
|
||||
|
||||
namespace dgl {
|
||||
|
||||
/*!
|
||||
* \brief Type traits that converts a C type to a DLDataType.
|
||||
* \brief Type trait that converts a C type to a DGLDataType.
|
||||
*
|
||||
* Usage:
|
||||
* DLDataTypeTraits<int>::dtype == dtype
|
||||
* DGLDataTypeTraits<int>::dtype == dtype
|
||||
*/
|
||||
template<typename T>
|
||||
struct DLDataTypeTraits {
|
||||
static constexpr DLDataType dtype{0, 0, 0}; // dummy
|
||||
struct DGLDataTypeTraits {
|
||||
static constexpr DGLDataType dtype{0, 0, 0}; // dummy
|
||||
};
|
||||
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \
|
||||
#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits) \
|
||||
template<> \
|
||||
struct DLDataTypeTraits<T> { \
|
||||
static constexpr DLDataType dtype{code, bits, 1}; \
|
||||
struct DGLDataTypeTraits<T> { \
|
||||
static constexpr DGLDataType dtype{code, bits, 1}; \
|
||||
}
|
||||
GEN_DLDATATYPETRAITS_FOR(int8_t, kDLInt, 8);
|
||||
GEN_DLDATATYPETRAITS_FOR(int16_t, kDLInt, 16);
|
||||
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32);
|
||||
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
|
||||
GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
|
||||
GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
|
||||
GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
|
||||
GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
|
||||
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
|
||||
// converting uints to signed DTypes.
|
||||
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
|
||||
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
|
||||
GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
|
||||
GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
|
||||
#ifdef DGL_USE_CUDA
|
||||
#ifdef USE_FP16
|
||||
GEN_DLDATATYPETRAITS_FOR(__half, kDLFloat, 16);
|
||||
GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
|
||||
#endif
|
||||
#endif
|
||||
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
|
||||
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
|
||||
#undef GEN_DLDATATYPETRAITS_FOR
|
||||
GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
|
||||
GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
|
||||
#undef GEN_DGLDATATYPETRAITS_FOR
|
||||
|
||||
namespace runtime {
|
||||
|
||||
/*!
|
||||
* \brief DLPack converter.
|
||||
*/
|
||||
struct DLPackConvert;
|
||||
|
||||
/*!
|
||||
* \brief Managed NDArray.
|
||||
* The array is backed by reference counted blocks.
|
||||
@@ -135,8 +139,8 @@ class NDArray {
|
||||
* \note this number is approximate in multi-threaded setting.
|
||||
*/
|
||||
inline int use_count() const;
|
||||
/*! \return Pointer to content of DLTensor */
|
||||
inline const DLTensor* operator->() const;
|
||||
/*! \return Pointer to content of DGLArray */
|
||||
inline const DGLArray* operator->() const;
|
||||
/*! \return True if the ndarray is contiguous. */
|
||||
bool IsContiguous() const;
|
||||
/*! \return the data pointer with type. */
|
||||
@@ -152,9 +156,9 @@ class NDArray {
|
||||
* \param other The source array to be copied from.
|
||||
* \note The copy runs on the dgl internal stream if it involves a GPU context.
|
||||
*/
|
||||
inline void CopyFrom(DLTensor* other);
|
||||
inline void CopyFrom(DGLArray* other);
|
||||
inline void CopyFrom(const NDArray& other);
|
||||
inline void CopyTo(DLTensor *other) const;
|
||||
inline void CopyTo(DGLArray *other) const;
|
||||
inline void CopyTo(const NDArray &other) const;
|
||||
|
||||
/*!
|
||||
@@ -162,7 +166,7 @@ class NDArray {
|
||||
* \param ctx The target context.
|
||||
* \return The array under another context.
|
||||
*/
|
||||
inline NDArray CopyTo(const DLContext &ctx) const;
|
||||
inline NDArray CopyTo(const DGLContext &ctx) const;
|
||||
/*!
|
||||
* \brief Return a new array with a copy of the content.
|
||||
*/
|
||||
@@ -171,9 +175,9 @@ class NDArray {
|
||||
* \brief In-place method to pin the current array by calling PinContainer
|
||||
* on the underlying NDArray::Container.
|
||||
* \note This is an in-place method. Behavior depends on the current context,
|
||||
* kDLCPU: will be pinned;
|
||||
* kDGLCPU: will be pinned;
|
||||
* IsPinned: directly return;
|
||||
* kDLGPU: invalid, will throw an error.
|
||||
* kDGLCUDA: invalid, will throw an error.
|
||||
*/
|
||||
inline void PinMemory_();
|
||||
/*!
|
||||
@@ -212,13 +216,7 @@ class NDArray {
|
||||
* \note The memory size of new array must be smaller than the current one.
|
||||
*/
|
||||
DGL_DLL NDArray CreateView(
|
||||
std::vector<int64_t> shape, DLDataType dtype, int64_t offset = 0);
|
||||
/*!
|
||||
* \brief Create a reference view of NDArray that
|
||||
* represents as DLManagedTensor.
|
||||
* \return A DLManagedTensor
|
||||
*/
|
||||
DGL_DLL DLManagedTensor* ToDLPack() const;
|
||||
std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
|
||||
/*!
|
||||
* \brief Create an empty NDArray.
|
||||
* \param shape The shape of the new array.
|
||||
@@ -227,8 +225,8 @@ class NDArray {
|
||||
* \return The created Array
|
||||
*/
|
||||
DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
|
||||
DLDataType dtype,
|
||||
DLContext ctx);
|
||||
DGLDataType dtype,
|
||||
DGLContext ctx);
|
||||
/*!
|
||||
* \brief Create an empty NDArray with shared memory.
|
||||
* \param name The name of shared memory.
|
||||
@@ -240,8 +238,8 @@ class NDArray {
|
||||
*/
|
||||
DGL_DLL static NDArray EmptyShared(const std::string &name,
|
||||
std::vector<int64_t> shape,
|
||||
DLDataType dtype,
|
||||
DLContext ctx,
|
||||
DGLDataType dtype,
|
||||
DGLContext ctx,
|
||||
bool is_create);
|
||||
/*!
|
||||
* \brief Get the size of the array in the number of bytes.
|
||||
@@ -253,26 +251,19 @@ class NDArray {
|
||||
*/
|
||||
int64_t NumElements() const;
|
||||
|
||||
/*!
|
||||
* \brief Create a NDArray backed by a dlpack tensor.
|
||||
*
|
||||
* This allows us to create a NDArray using the memory
|
||||
* allocated by an external deep learning framework
|
||||
* that is DLPack compatible.
|
||||
*
|
||||
* The memory is retained until the NDArray went out of scope.
|
||||
* \param tensor The DLPack tensor to copy from.
|
||||
* \return The created NDArray view.
|
||||
*/
|
||||
DGL_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
|
||||
|
||||
/*!
|
||||
* \brief Create a NDArray by copying from std::vector.
|
||||
* \tparam T Type of vector data. Determines the dtype of returned array.
|
||||
*/
|
||||
template<typename T>
|
||||
DGL_DLL static NDArray FromVector(
|
||||
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});
|
||||
const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
|
||||
|
||||
/*!
|
||||
* \brief Create a NDArray from a raw pointer.
|
||||
*/
|
||||
DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape,
|
||||
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free);
|
||||
|
||||
/*!
|
||||
* \brief Create a std::vector from a 1D NDArray.
|
||||
@@ -292,23 +283,23 @@ class NDArray {
|
||||
* \param (optional) stream The stream used in copy.
|
||||
*/
|
||||
DGL_DLL static void CopyFromTo(
|
||||
DLTensor* from, DLTensor* to);
|
||||
DGLArray* from, DGLArray* to);
|
||||
DGL_DLL static void CopyFromTo(
|
||||
DLTensor* from, DLTensor* to, DGLStreamHandle stream);
|
||||
DGLArray* from, DGLArray* to, DGLStreamHandle stream);
|
||||
|
||||
/*!
|
||||
* \brief Function to pin the DLTensor of a Container.
|
||||
* \brief Function to pin the DGLArray of a Container.
|
||||
* \param ptr The container to be pinned.
|
||||
* \note Data of the given array will be pinned inplace.
|
||||
* Behavior depends on the current context,
|
||||
* kDLCPU: will be pinned;
|
||||
* kDGLCPU: will be pinned;
|
||||
* IsPinned: directly return;
|
||||
* kDLGPU: invalid, will throw an error.
|
||||
* kDGLCUDA: invalid, will throw an error.
|
||||
*/
|
||||
DGL_DLL static void PinContainer(Container* ptr);
|
||||
|
||||
/*!
|
||||
* \brief Function to unpin the DLTensor of a Container.
|
||||
* \brief Function to unpin the DGLArray of a Container.
|
||||
* \param ptr The container to be unpinned.
|
||||
* \note Data of the given array will be unpinned inplace.
|
||||
* Behavior depends on the current context,
|
||||
@@ -318,7 +309,7 @@ class NDArray {
|
||||
DGL_DLL static void UnpinContainer(Container* ptr);
|
||||
|
||||
/*!
|
||||
* \brief Function check if the DLTensor of a Container is pinned.
|
||||
* \brief Function to check if the DGLArray of a Container is pinned.
|
||||
* \param ptr The container to be checked.
|
||||
* \return true if pinned.
|
||||
*/
|
||||
@@ -332,45 +323,57 @@ class NDArray {
|
||||
DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);
|
||||
|
||||
// internal namespace
|
||||
struct Internal;
|
||||
struct Internal {
|
||||
// Default deleter for the container
|
||||
static void DefaultDeleter(NDArray::Container* ptr);
|
||||
// Local create function which allocates tensor metadata
|
||||
// but does not allocate space for the data.
|
||||
static NDArray Create(std::vector<int64_t> shape,
|
||||
DGLDataType dtype, DGLContext ctx);
|
||||
// Implementation of API function
|
||||
static DGLArray* MoveAsDGLArray(NDArray arr);
|
||||
};
|
||||
|
||||
private:
|
||||
/*! \brief Internal Data content */
|
||||
Container* data_{nullptr};
|
||||
// enable internal functions
|
||||
friend struct Internal;
|
||||
friend struct DLPackConvert;
|
||||
friend class DGLRetValue;
|
||||
friend class DGLArgsSetter;
|
||||
};
|
||||
|
||||
/*!
|
||||
* \brief Save a DLTensor to stream
|
||||
* \brief Save a DGLArray to stream
|
||||
* \param strm The output stream
|
||||
* \param tensor The tensor to be saved.
|
||||
*/
|
||||
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
|
||||
inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);
|
||||
|
||||
|
||||
/*!
|
||||
* \brief Reference counted Container object used to back NDArray.
|
||||
*
|
||||
* This object is DLTensor compatible:
|
||||
* This object is DGLArray compatible:
|
||||
* the pointer to the NDArrayContainer can be directly
|
||||
* interpreted as a DLTensor*
|
||||
* interpreted as a DGLArray*
|
||||
*
|
||||
* \note: do not use this function directly, use NDArray.
|
||||
*/
|
||||
struct NDArray::Container {
|
||||
public:
|
||||
// NOTE: the first part of this structure is the same as
|
||||
// DLManagedTensor, note that, however, the deleter
|
||||
// is only called when the reference counter goes to 0
|
||||
/*!
|
||||
* \brief The corresponding dl_tensor field.
|
||||
* \note it is important that the first field is DLTensor
|
||||
* So that this data structure is DLTensor compatible.
|
||||
* The head ptr of this struct can be viewed as DLTensor*.
|
||||
/*! NOTE: the first part of this structure is the same as
|
||||
* DLManagedTensor, note that, however, the deleter
|
||||
* is only called when the reference counter goes to 0
|
||||
*/
|
||||
DLTensor dl_tensor;
|
||||
/*!
|
||||
* \brief Tensor structure.
|
||||
* \note it is important that the first field is DGLArray
|
||||
* So that this data structure is DGLArray compatible.
|
||||
* The head ptr of this struct can be viewed as DGLArray*.
|
||||
*/
|
||||
DGLArray dl_tensor;
|
||||
/*!
|
||||
* \brief additional context, reserved for recycling
|
||||
* \note We can attach additional content here
|
||||
@@ -411,6 +414,7 @@ struct NDArray::Container {
|
||||
}
|
||||
|
||||
private:
|
||||
friend struct DLPackConvert;
|
||||
friend class NDArray;
|
||||
friend class RPCWrappedFunc;
|
||||
/*!
|
||||
@@ -450,7 +454,7 @@ inline void NDArray::reset() {
|
||||
}
|
||||
}
|
||||
|
||||
inline void NDArray::CopyFrom(DLTensor* other) {
|
||||
inline void NDArray::CopyFrom(DGLArray* other) {
|
||||
CHECK(data_ != nullptr);
|
||||
CopyFromTo(other, &(data_->dl_tensor));
|
||||
}
|
||||
@@ -460,7 +464,7 @@ inline void NDArray::CopyFrom(const NDArray& other) {
|
||||
CopyFrom(&(other.data_->dl_tensor));
|
||||
}
|
||||
|
||||
inline void NDArray::CopyTo(DLTensor *other) const {
|
||||
inline void NDArray::CopyTo(DGLArray *other) const {
|
||||
CHECK(data_ != nullptr);
|
||||
CopyFromTo(&(data_->dl_tensor), other);
|
||||
}
|
||||
@@ -470,9 +474,9 @@ inline void NDArray::CopyTo(const NDArray &other) const {
|
||||
CopyTo(&(other.data_->dl_tensor));
|
||||
}
|
||||
|
||||
inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
|
||||
inline NDArray NDArray::CopyTo(const DGLContext &ctx) const {
|
||||
CHECK(data_ != nullptr);
|
||||
const DLTensor* dptr = operator->();
|
||||
const DGLArray* dptr = operator->();
|
||||
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
|
||||
dptr->dtype, ctx);
|
||||
this->CopyTo(ret);
|
||||
@@ -481,7 +485,7 @@ inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
|
||||
|
||||
inline NDArray NDArray::Clone() const {
|
||||
CHECK(data_ != nullptr);
|
||||
const DLTensor* dptr = operator->();
|
||||
const DGLArray* dptr = operator->();
|
||||
return this->CopyTo(dptr->ctx);
|
||||
}
|
||||
|
||||
@@ -510,15 +514,15 @@ inline int NDArray::use_count() const {
|
||||
return data_->ref_counter_.load(std::memory_order_relaxed);
|
||||
}
|
||||
|
||||
inline const DLTensor* NDArray::operator->() const {
|
||||
inline const DGLArray* NDArray::operator->() const {
|
||||
return &(data_->dl_tensor);
|
||||
}
|
||||
|
||||
/*! \brief Magic number for NDArray file */
|
||||
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
|
||||
|
||||
inline bool SaveDLTensor(dmlc::Stream* strm,
|
||||
DLTensor* tensor) {
|
||||
inline bool SaveDGLArray(dmlc::Stream* strm,
|
||||
DGLArray* tensor) {
|
||||
uint64_t header = kDGLNDArrayMagic, reserved = 0;
|
||||
strm->Write(header);
|
||||
strm->Write(reserved);
|
||||
@@ -531,8 +535,8 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
|
||||
//
|
||||
// We can always do array.CopyTo(target_ctx) to get a corresponding
|
||||
// array in the target context.
|
||||
DLContext cpu_ctx;
|
||||
cpu_ctx.device_type = kDLCPU;
|
||||
DGLContext cpu_ctx;
|
||||
cpu_ctx.device_type = kDGLCPU;
|
||||
cpu_ctx.device_id = 0;
|
||||
strm->Write(cpu_ctx);
|
||||
strm->Write(tensor->ndim);
|
||||
@@ -548,7 +552,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
|
||||
strm->Write(data_byte_size);
|
||||
|
||||
if (DMLC_IO_NO_ENDIAN_SWAP &&
|
||||
tensor->ctx.device_type == kDLCPU &&
|
||||
tensor->ctx.device_type == kDGLCPU &&
|
||||
tensor->strides == nullptr &&
|
||||
tensor->byte_offset == 0) {
|
||||
// quick path
|
||||
@@ -573,16 +577,16 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
|
||||
*/
|
||||
inline const char* TypeCode2Str(int type_code) {
|
||||
switch (type_code) {
|
||||
case kDLInt: return "int";
|
||||
case kDLUInt: return "uint";
|
||||
case kDLFloat: return "float";
|
||||
case kDGLInt: return "int";
|
||||
case kDGLUInt: return "uint";
|
||||
case kDGLFloat: return "float";
|
||||
case kStr: return "str";
|
||||
case kBytes: return "bytes";
|
||||
case kHandle: return "handle";
|
||||
case kNull: return "NULL";
|
||||
case kObjectHandle: return "ObjectHandle";
|
||||
case kArrayHandle: return "ArrayHandle";
|
||||
case kDGLType: return "DGLType";
|
||||
case kDGLDataType: return "DGLDataType";
|
||||
case kDGLContext: return "DGLContext";
|
||||
case kFuncHandle: return "FunctionHandle";
|
||||
case kModuleHandle: return "ModuleHandle";
|
||||
@@ -597,17 +601,11 @@ inline const char* TypeCode2Str(int type_code) {
|
||||
* \param device_type The device type code.
|
||||
* \return The name of the device.
|
||||
*/
|
||||
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
|
||||
inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
|
||||
switch (device_type) {
|
||||
case kDLCPU: return "cpu";
|
||||
case kDLGPU: return "cuda";
|
||||
case kDLCPUPinned: return "cpu_pinned";
|
||||
case kDLOpenCL: return "opencl";
|
||||
case kDLVulkan: return "vulkan";
|
||||
case kDLMetal: return "metal";
|
||||
case kDLVPI: return "vpi";
|
||||
case kDLROCM: return "rocm";
|
||||
default: LOG(FATAL) << "Unknown device type code="
|
||||
case kDGLCPU: return "cpu";
|
||||
case kDGLCUDA: return "cuda";
|
||||
default: LOG(FATAL) << "Unsupported device type code="
|
||||
<< static_cast<int>(device_type); return "";
|
||||
}
|
||||
}
|
||||
@@ -617,16 +615,16 @@ inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
|
||||
* \param s The string to be converted.
|
||||
* \return The corresponding dgl type.
|
||||
*/
|
||||
inline DGLType String2DGLType(std::string s) {
|
||||
DGLType t;
|
||||
inline DGLDataType String2DGLDataType(std::string s) {
|
||||
DGLDataType t;
|
||||
t.bits = 32; t.lanes = 1;
|
||||
const char* scan;
|
||||
if (s.substr(0, 3) == "int") {
|
||||
t.code = kDLInt; scan = s.c_str() + 3;
|
||||
t.code = kDGLInt; scan = s.c_str() + 3;
|
||||
} else if (s.substr(0, 4) == "uint") {
|
||||
t.code = kDLUInt; scan = s.c_str() + 4;
|
||||
t.code = kDGLUInt; scan = s.c_str() + 4;
|
||||
} else if (s.substr(0, 5) == "float") {
|
||||
t.code = kDLFloat; scan = s.c_str() + 5;
|
||||
t.code = kDGLFloat; scan = s.c_str() + 5;
|
||||
} else if (s.substr(0, 6) == "handle") {
|
||||
t.code = kHandle;
|
||||
t.bits = 64; // handle uses 64 bit by default.
|
||||
@@ -649,7 +647,7 @@ inline DGLType String2DGLType(std::string s) {
|
||||
* \param t The type to be converted.
|
||||
* \return The corresponding dgl type in string.
|
||||
*/
|
||||
inline std::string DGLType2String(DGLType t) {
|
||||
inline std::string DGLDataType2String(DGLDataType t) {
|
||||
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
|
||||
std::ostringstream os;
|
||||
os << t;
|
||||
@@ -728,20 +726,20 @@ dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2)
|
||||
|
||||
std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array);
|
||||
|
||||
///////////////// Operator overloading for DLDataType /////////////////
|
||||
///////////////// Operator overloading for DGLDataType /////////////////
|
||||
|
||||
/*! \brief Check whether two data types are the same.*/
|
||||
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
|
||||
inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) {
|
||||
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
|
||||
}
|
||||
|
||||
/*! \brief Check whether two data types are different.*/
|
||||
inline bool operator != (const DLDataType& ty1, const DLDataType& ty2) {
|
||||
inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) {
|
||||
return !(ty1 == ty2);
|
||||
}
|
||||
|
||||
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
|
||||
inline std::ostream& operator << (std::ostream& os, DGLType t) {
|
||||
inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
|
||||
os << dgl::runtime::TypeCode2Str(t.code);
|
||||
if (t.code == kHandle) return os;
|
||||
os << static_cast<int>(t.bits);
|
||||
@@ -752,20 +750,20 @@ inline std::ostream& operator << (std::ostream& os, DGLType t) {
|
||||
}
|
||||
#endif
|
||||
|
||||
///////////////// Operator overloading for DLContext /////////////////
|
||||
///////////////// Operator overloading for DGLContext /////////////////
|
||||
|
||||
/*! \brief Check whether two device contexts are the same.*/
|
||||
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
|
||||
inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) {
|
||||
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
|
||||
}
|
||||
|
||||
/*! \brief Check whether two device contexts are different.*/
|
||||
inline bool operator != (const DLContext& ctx1, const DLContext& ctx2) {
|
||||
inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) {
|
||||
return !(ctx1 == ctx2);
|
||||
}
|
||||
|
||||
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
|
||||
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
|
||||
inline std::ostream& operator << (std::ostream& os, const DGLContext& ctx) {
|
||||
return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":" << ctx.device_id;
|
||||
}
|
||||
#endif
|
||||
|
||||
@@ -350,28 +350,28 @@ class DGLPODValue_ {
|
||||
// Allow automatic conversion from int to float
|
||||
// This avoids errors when user pass in int from
|
||||
// the frontend while the API expects a float.
|
||||
if (type_code_ == kDLInt) {
|
||||
if (type_code_ == kDGLInt) {
|
||||
return static_cast<double>(value_.v_int64);
|
||||
}
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDLFloat);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLFloat);
|
||||
return value_.v_float64;
|
||||
}
|
||||
operator int64_t() const {
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
|
||||
return value_.v_int64;
|
||||
}
|
||||
operator uint64_t() const {
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
|
||||
return value_.v_int64;
|
||||
}
|
||||
operator int() const {
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
|
||||
CHECK_LE(value_.v_int64,
|
||||
std::numeric_limits<int>::max());
|
||||
return static_cast<int>(value_.v_int64);
|
||||
}
|
||||
operator bool() const {
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
|
||||
return value_.v_int64 != 0;
|
||||
}
|
||||
operator void*() const {
|
||||
@@ -380,14 +380,14 @@ class DGLPODValue_ {
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kHandle);
|
||||
return value_.v_handle;
|
||||
}
|
||||
operator DLTensor*() const {
|
||||
operator DGLArray*() const {
|
||||
if (type_code_ == kArrayHandle ||
|
||||
type_code_ == kNDArrayContainer) {
|
||||
return static_cast<DLTensor*>(value_.v_handle);
|
||||
return static_cast<DGLArray*>(value_.v_handle);
|
||||
} else {
|
||||
if (type_code_ == kNull) return nullptr;
|
||||
LOG(FATAL) << "Expected "
|
||||
<< "DLTensor* or NDArray but get "
|
||||
<< "DGLArray* or NDArray but get "
|
||||
<< TypeCode2Str(type_code_);
|
||||
return nullptr;
|
||||
}
|
||||
@@ -457,14 +457,14 @@ class DGLArgValue : public DGLPODValue_ {
|
||||
using DGLPODValue_::operator int;
|
||||
using DGLPODValue_::operator bool;
|
||||
using DGLPODValue_::operator void*;
|
||||
using DGLPODValue_::operator DLTensor*;
|
||||
using DGLPODValue_::operator DGLArray*;
|
||||
using DGLPODValue_::operator NDArray;
|
||||
using DGLPODValue_::operator DGLContext;
|
||||
|
||||
// conversion operator.
|
||||
operator std::string() const {
|
||||
if (type_code_ == kDGLType) {
|
||||
return DGLType2String(operator DGLType());
|
||||
if (type_code_ == kDGLDataType) {
|
||||
return DGLDataType2String(operator DGLDataType());
|
||||
} else if (type_code_ == kBytes) {
|
||||
DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle);
|
||||
return std::string(arr->data, arr->size);
|
||||
@@ -473,11 +473,11 @@ class DGLArgValue : public DGLPODValue_ {
|
||||
return std::string(value_.v_str);
|
||||
}
|
||||
}
|
||||
operator DGLType() const {
|
||||
operator DGLDataType() const {
|
||||
if (type_code_ == kStr) {
|
||||
return String2DGLType(operator std::string());
|
||||
return String2DGLDataType(operator std::string());
|
||||
}
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
|
||||
return value_.v_type;
|
||||
}
|
||||
operator PackedFunc() const {
|
||||
@@ -549,7 +549,7 @@ class DGLRetValue : public DGLPODValue_ {
|
||||
using DGLPODValue_::operator int;
|
||||
using DGLPODValue_::operator bool;
|
||||
using DGLPODValue_::operator void*;
|
||||
using DGLPODValue_::operator DLTensor*;
|
||||
using DGLPODValue_::operator DGLArray*;
|
||||
using DGLPODValue_::operator DGLContext;
|
||||
using DGLPODValue_::operator NDArray;
|
||||
// Disable copy and assign from another value, but allow move.
|
||||
@@ -558,19 +558,19 @@ class DGLRetValue : public DGLPODValue_ {
|
||||
}
|
||||
// conversion operators
|
||||
operator std::string() const {
|
||||
if (type_code_ == kDGLType) {
|
||||
return DGLType2String(operator DGLType());
|
||||
if (type_code_ == kDGLDataType) {
|
||||
return DGLDataType2String(operator DGLDataType());
|
||||
} else if (type_code_ == kBytes) {
|
||||
return *ptr<std::string>();
|
||||
}
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kStr);
|
||||
return *ptr<std::string>();
|
||||
}
|
||||
operator DGLType() const {
|
||||
operator DGLDataType() const {
|
||||
if (type_code_ == kStr) {
|
||||
return String2DGLType(operator std::string());
|
||||
return String2DGLDataType(operator std::string());
|
||||
}
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
|
||||
DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
|
||||
return value_.v_type;
|
||||
}
|
||||
operator PackedFunc() const {
|
||||
@@ -595,7 +595,7 @@ class DGLRetValue : public DGLPODValue_ {
|
||||
return *this;
|
||||
}
|
||||
DGLRetValue& operator=(double value) {
|
||||
this->SwitchToPOD(kDLFloat);
|
||||
this->SwitchToPOD(kDGLFloat);
|
||||
value_.v_float64 = value;
|
||||
return *this;
|
||||
}
|
||||
@@ -610,17 +610,17 @@ class DGLRetValue : public DGLPODValue_ {
|
||||
return *this;
|
||||
}
|
||||
DGLRetValue& operator=(int64_t value) {
|
||||
this->SwitchToPOD(kDLInt);
|
||||
this->SwitchToPOD(kDGLInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
DGLRetValue& operator=(int value) {
|
||||
this->SwitchToPOD(kDLInt);
|
||||
this->SwitchToPOD(kDGLInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
DGLRetValue& operator=(DGLType t) {
|
||||
this->SwitchToPOD(kDGLType);
|
||||
DGLRetValue& operator=(DGLDataType t) {
|
||||
this->SwitchToPOD(kDGLDataType);
|
||||
value_.v_type = t;
|
||||
return *this;
|
||||
}
|
||||
@@ -630,7 +630,7 @@ class DGLRetValue : public DGLPODValue_ {
|
||||
return *this;
|
||||
}
|
||||
DGLRetValue& operator=(bool value) {
|
||||
this->SwitchToPOD(kDLInt);
|
||||
this->SwitchToPOD(kDGLInt);
|
||||
value_.v_int64 = value;
|
||||
return *this;
|
||||
}
|
||||
@@ -859,17 +859,17 @@ class DGLArgsSetter {
|
||||
std::is_integral<T>::value>::type>
|
||||
void operator()(size_t i, T value) const {
|
||||
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||
type_codes_[i] = kDLInt;
|
||||
type_codes_[i] = kDGLInt;
|
||||
}
|
||||
void operator()(size_t i, uint64_t value) const {
|
||||
values_[i].v_int64 = static_cast<int64_t>(value);
|
||||
CHECK_LE(value,
|
||||
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
|
||||
type_codes_[i] = kDLInt;
|
||||
type_codes_[i] = kDGLInt;
|
||||
}
|
||||
void operator()(size_t i, double value) const {
|
||||
values_[i].v_float64 = value;
|
||||
type_codes_[i] = kDLFloat;
|
||||
type_codes_[i] = kDGLFloat;
|
||||
}
|
||||
void operator()(size_t i, std::nullptr_t value) const {
|
||||
values_[i].v_handle = value;
|
||||
@@ -883,7 +883,7 @@ class DGLArgsSetter {
|
||||
values_[i].v_handle = value;
|
||||
type_codes_[i] = kHandle;
|
||||
}
|
||||
void operator()(size_t i, DLTensor* value) const {
|
||||
void operator()(size_t i, DGLArray* value) const {
|
||||
values_[i].v_handle = value;
|
||||
type_codes_[i] = kArrayHandle;
|
||||
}
|
||||
@@ -891,9 +891,9 @@ class DGLArgsSetter {
|
||||
values_[i].v_ctx = value;
|
||||
type_codes_[i] = kDGLContext;
|
||||
}
|
||||
void operator()(size_t i, DGLType value) const {
|
||||
void operator()(size_t i, DGLDataType value) const {
|
||||
values_[i].v_type = value;
|
||||
type_codes_[i] = kDGLType;
|
||||
type_codes_[i] = kDGLDataType;
|
||||
}
|
||||
void operator()(size_t i, const char* value) const {
|
||||
values_[i].v_str = value;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file dgl/runtime/serializer.h
|
||||
* \brief Serializer extension to support DGL data types
|
||||
* Include this file to enable serialization of DLDataType, DLContext
|
||||
* Include this file to enable serialization of DGLDataType, DGLContext
|
||||
*/
|
||||
#ifndef DGL_RUNTIME_SERIALIZER_H_
|
||||
#define DGL_RUNTIME_SERIALIZER_H_
|
||||
@@ -16,13 +16,13 @@ namespace dmlc {
|
||||
namespace serializer {
|
||||
|
||||
template <>
|
||||
struct Handler<DLDataType> {
|
||||
inline static void Write(Stream *strm, const DLDataType &dtype) {
|
||||
struct Handler<DGLDataType> {
|
||||
inline static void Write(Stream *strm, const DGLDataType &dtype) {
|
||||
Handler<uint8_t>::Write(strm, dtype.code);
|
||||
Handler<uint8_t>::Write(strm, dtype.bits);
|
||||
Handler<uint16_t>::Write(strm, dtype.lanes);
|
||||
}
|
||||
inline static bool Read(Stream *strm, DLDataType *dtype) {
|
||||
inline static bool Read(Stream *strm, DGLDataType *dtype) {
|
||||
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
|
||||
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
|
||||
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
|
||||
@@ -31,16 +31,16 @@ struct Handler<DLDataType> {
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Handler<DLContext> {
|
||||
inline static void Write(Stream *strm, const DLContext &ctx) {
|
||||
struct Handler<DGLContext> {
|
||||
inline static void Write(Stream *strm, const DGLContext &ctx) {
|
||||
int32_t device_type = static_cast<int32_t>(ctx.device_type);
|
||||
Handler<int32_t>::Write(strm, device_type);
|
||||
Handler<int32_t>::Write(strm, ctx.device_id);
|
||||
}
|
||||
inline static bool Read(Stream *strm, DLContext *ctx) {
|
||||
inline static bool Read(Stream *strm, DGLContext *ctx) {
|
||||
int32_t device_type = 0;
|
||||
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
|
||||
ctx->device_type = static_cast<DLDeviceType>(device_type);
|
||||
ctx->device_type = static_cast<DGLDeviceType>(device_type);
|
||||
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* Copyright (c) 2017 by Contributors
|
||||
* \file dgl/runtime/smart_ptr_serializer.h
|
||||
* \brief Serializer extension to support DGL data types
|
||||
* Include this file to enable serialization of DLDataType, DLContext
|
||||
* Include this file to enable serialization of DGLDataType, DGLContext
|
||||
*/
|
||||
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
|
||||
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
|
||||
|
||||
@@ -18,7 +18,7 @@ namespace runtime {
|
||||
* \param bits The number of bits to be matched.
|
||||
* \param lanes The number of lanes in the type.
|
||||
*/
|
||||
inline bool TypeMatch(DGLType t, int code, int bits, int lanes = 1) {
|
||||
inline bool TypeMatch(DGLDataType t, int code, int bits, int lanes = 1) {
|
||||
return t.code == code && t.bits == bits && t.lanes == lanes;
|
||||
}
|
||||
} // namespace runtime
|
||||
|
||||
@@ -10,7 +10,7 @@ from numbers import Number, Integral
|
||||
from ..base import _LIB, check_call
|
||||
from ..base import c_str, string_types
|
||||
from ..object_generic import convert_to_object, ObjectGeneric
|
||||
from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext
|
||||
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
|
||||
from . import ndarray as _nd
|
||||
from .ndarray import NDArrayBase, _make_array
|
||||
from .types import DGLValue, TypeCode
|
||||
@@ -115,7 +115,7 @@ def _make_dgl_args(args, temp_args):
|
||||
elif isinstance(arg, Number):
|
||||
values[i].v_float64 = arg
|
||||
type_codes[i] = TypeCode.FLOAT
|
||||
elif isinstance(arg, DGLType):
|
||||
elif isinstance(arg, DGLDataType):
|
||||
values[i].v_str = c_str(str(arg))
|
||||
type_codes[i] = TypeCode.STR
|
||||
elif isinstance(arg, DGLContext):
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
|
||||
|
||||
import ctypes
|
||||
from ..base import py_str, check_call, _LIB
|
||||
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLType, DGLContext
|
||||
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext
|
||||
|
||||
class DGLValue(ctypes.Union):
|
||||
"""DGLValue in C API"""
|
||||
@@ -12,7 +12,7 @@ class DGLValue(ctypes.Union):
|
||||
("v_float64", ctypes.c_double),
|
||||
("v_handle", ctypes.c_void_p),
|
||||
("v_str", ctypes.c_char_p),
|
||||
("v_type", DGLType),
|
||||
("v_type", DGLDataType),
|
||||
("v_ctx", DGLContext)]
|
||||
|
||||
|
||||
|
||||
@@ -3,16 +3,16 @@ from libcpp.vector cimport vector
|
||||
from libcpp cimport bool
|
||||
from cpython.version cimport PY_MAJOR_VERSION
|
||||
from cpython cimport pycapsule
|
||||
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
|
||||
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
|
||||
import ctypes
|
||||
|
||||
cdef enum DGLTypeCode:
|
||||
cdef enum DGLObjectTypeCode:
|
||||
kInt = 0
|
||||
kUInt = 1
|
||||
kFloat = 2
|
||||
kHandle = 3
|
||||
kNull = 4
|
||||
kDGLType = 5
|
||||
kDGLDataType = 5
|
||||
kDGLContext = 6
|
||||
kArrayHandle = 7
|
||||
kObjectHandle = 8
|
||||
@@ -24,26 +24,26 @@ cdef enum DGLTypeCode:
|
||||
kExtBegin = 15
|
||||
|
||||
cdef extern from "dgl/runtime/c_runtime_api.h":
|
||||
ctypedef struct DLDataType:
|
||||
ctypedef struct DGLDataType:
|
||||
uint8_t code
|
||||
uint8_t bits
|
||||
uint16_t lanes
|
||||
|
||||
ctypedef struct DLContext:
|
||||
int device_type
|
||||
int device_id
|
||||
ctypedef struct DGLContext:
|
||||
int32_t device_type
|
||||
int32_t device_id
|
||||
|
||||
ctypedef struct DLTensor:
|
||||
ctypedef struct DGLArray:
|
||||
void* data
|
||||
DLContext ctx
|
||||
int ndim
|
||||
DLDataType dtype
|
||||
DGLContext ctx
|
||||
int32_t ndim
|
||||
DGLDataType dtype
|
||||
int64_t* shape
|
||||
int64_t* strides
|
||||
uint64_t byte_offset
|
||||
|
||||
ctypedef struct DLManagedTensor:
|
||||
DLTensor dl_tensor
|
||||
DGLArray dl_tensor
|
||||
void* manager_ctx
|
||||
void (*deleter)(DLManagedTensor* self)
|
||||
|
||||
@@ -52,13 +52,11 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
|
||||
double v_float64
|
||||
void* v_handle
|
||||
const char* v_str
|
||||
DLDataType v_type
|
||||
DLContext v_ctx
|
||||
DGLDataType v_type
|
||||
DGLContext v_ctx
|
||||
|
||||
ctypedef int64_t dgl_index_t
|
||||
ctypedef DLTensor* DLTensorHandle
|
||||
ctypedef DLTensor DGLArray
|
||||
ctypedef DGLArray* CDGLArrayHandle
|
||||
ctypedef DGLArray* DGLArrayHandle
|
||||
ctypedef void* DGLStreamHandle
|
||||
ctypedef void* DGLRetValueHandle
|
||||
ctypedef void* DGLFunctionHandle
|
||||
@@ -94,9 +92,9 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
|
||||
int DGLCbArgToReturn(DGLValue* value, int code)
|
||||
int DGLArrayAlloc(dgl_index_t* shape,
|
||||
dgl_index_t ndim,
|
||||
DLDataType dtype,
|
||||
DLContext ctx,
|
||||
DLTensorHandle* out)
|
||||
DGLDataType dtype,
|
||||
DGLContext ctx,
|
||||
DGLArrayHandle* out)
|
||||
int DGLArrayAllocSharedMem(const char *mem_name,
|
||||
const dgl_index_t *shape,
|
||||
int ndim,
|
||||
@@ -104,16 +102,10 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
|
||||
int dtype_bits,
|
||||
int dtype_lanes,
|
||||
bool is_create,
|
||||
CDGLArrayHandle* out)
|
||||
int DGLArrayFree(DLTensorHandle handle)
|
||||
int DGLArrayCopyFromTo(DLTensorHandle src,
|
||||
DLTensorHandle to)
|
||||
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
|
||||
DLTensorHandle* out)
|
||||
int DGLArrayToDLPack(DLTensorHandle arr_from,
|
||||
DLManagedTensor** out,
|
||||
int alignment)
|
||||
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
|
||||
DGLArrayHandle* out)
|
||||
int DGLArrayFree(DGLArrayHandle handle)
|
||||
int DGLArrayCopyFromTo(DGLArrayHandle src,
|
||||
DGLArrayHandle to)
|
||||
|
||||
cdef extern from "dgl/runtime/c_object_api.h":
|
||||
int DGLObjectFree(ObjectHandle handle)
|
||||
@@ -127,6 +119,14 @@ cdef extern from "dgl/runtime/c_object_api.h":
|
||||
int* out_type_code,
|
||||
int* out_success)
|
||||
|
||||
cdef extern from "dgl/runtime/dlpack_convert.h":
|
||||
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
|
||||
DGLArrayHandle* out)
|
||||
int DGLArrayToDLPack(DGLArrayHandle arr_from,
|
||||
DLManagedTensor** out,
|
||||
int alignment)
|
||||
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
|
||||
|
||||
cdef inline py_str(const char* x):
|
||||
if PY_MAJOR_VERSION < 3:
|
||||
return x
|
||||
|
||||
@@ -4,7 +4,9 @@ from cpython cimport Py_INCREF, Py_DECREF
|
||||
from numbers import Number, Integral
|
||||
from ..base import string_types
|
||||
from ..object_generic import convert_to_object, ObjectGeneric
|
||||
from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray
|
||||
from ..runtime_ctypes import DGLDataType as CTypesDGLDataType, \
|
||||
DGLContext as CTypesDGLContext, \
|
||||
DGLByteArray
|
||||
|
||||
|
||||
cdef void dgl_callback_finalize(void* fhandle):
|
||||
@@ -107,13 +109,13 @@ cdef inline int make_arg(object arg,
|
||||
elif isinstance(arg, Number):
|
||||
value[0].v_float64 = arg
|
||||
tcode[0] = kFloat
|
||||
elif isinstance(arg, DGLType):
|
||||
elif isinstance(arg, CTypesDGLDataType):
|
||||
tstr = c_str(str(arg))
|
||||
value[0].v_str = tstr
|
||||
tcode[0] = kStr
|
||||
temp_args.append(tstr)
|
||||
elif isinstance(arg, DGLContext):
|
||||
value[0].v_ctx = (<DLContext*>(
|
||||
elif isinstance(arg, CTypesDGLContext):
|
||||
value[0].v_ctx = (<DGLContext*>(
|
||||
<unsigned long long>ctypes.addressof(arg)))[0]
|
||||
tcode[0] = kDGLContext
|
||||
elif isinstance(arg, bytearray):
|
||||
@@ -183,7 +185,7 @@ cdef inline object make_ret(DGLValue value, int tcode):
|
||||
elif tcode == kHandle:
|
||||
return ctypes_handle(value.v_handle)
|
||||
elif tcode == kDGLContext:
|
||||
return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
|
||||
return CTypesDGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
|
||||
# (minjie): class module are not used in DGL.
|
||||
#elif tcode == kModuleHandle:
|
||||
# return _CLASS_MODULE(ctypes_handle(value.v_handle))
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from ..runtime_ctypes import DGLArrayHandle
|
||||
from ..runtime_ctypes import DGLArrayHandle as PyDGLArrayHandle
|
||||
|
||||
cdef const char* _c_str_dltensor = "dltensor"
|
||||
cdef const char* _c_str_used_dltensor = "used_dltensor"
|
||||
@@ -13,7 +13,7 @@ cdef void _c_dlpack_deleter(object pycaps):
|
||||
|
||||
def _from_dlpack(object dltensor):
|
||||
cdef DLManagedTensor* ptr
|
||||
cdef DLTensorHandle chandle
|
||||
cdef DGLArrayHandle chandle
|
||||
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
|
||||
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
|
||||
CALL(DGLArrayFromDLPack(ptr, &chandle))
|
||||
@@ -25,7 +25,7 @@ def _from_dlpack(object dltensor):
|
||||
|
||||
|
||||
cdef class NDArrayBase:
|
||||
cdef DLTensor* chandle
|
||||
cdef DGLArray* chandle
|
||||
cdef int c_is_view
|
||||
|
||||
cdef inline _set_handle(self, handle):
|
||||
@@ -34,7 +34,7 @@ cdef class NDArrayBase:
|
||||
self.chandle = NULL
|
||||
else:
|
||||
ptr = ctypes.cast(handle, ctypes.c_void_p).value
|
||||
self.chandle = <DLTensor*>(ptr)
|
||||
self.chandle = <DGLArray*>(ptr)
|
||||
|
||||
property _dgl_handle:
|
||||
def __get__(self):
|
||||
@@ -46,7 +46,7 @@ cdef class NDArrayBase:
|
||||
return None
|
||||
else:
|
||||
return ctypes.cast(
|
||||
<unsigned long long>self.chandle, DGLArrayHandle)
|
||||
<unsigned long long>self.chandle, PyDGLArrayHandle)
|
||||
|
||||
def __set__(self, value):
|
||||
self._set_handle(value)
|
||||
@@ -82,7 +82,7 @@ cdef class NDArrayBase:
|
||||
|
||||
cdef c_make_array(void* chandle, is_view):
|
||||
ret = _CLASS_NDARRAY(None, is_view)
|
||||
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
|
||||
(<NDArrayBase>ret).chandle = <DGLArray*>chandle
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import sys
|
||||
import ctypes
|
||||
import numpy as np
|
||||
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
|
||||
from .runtime_ctypes import DGLType, DGLContext, DGLArray, DGLArrayHandle
|
||||
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
|
||||
from .runtime_ctypes import TypeCode, dgl_shape_index_t
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def numpyasarray(np_data):
|
||||
arr.data = data.ctypes.data_as(ctypes.c_void_p)
|
||||
arr.shape = shape
|
||||
arr.strides = None
|
||||
arr.dtype = DGLType(np.dtype(data.dtype).name)
|
||||
arr.dtype = DGLDataType(np.dtype(data.dtype).name)
|
||||
arr.ndim = data.ndim
|
||||
# CPU device
|
||||
arr.ctx = context(1, 0)
|
||||
@@ -101,7 +101,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
|
||||
shape = c_array(dgl_shape_index_t, shape)
|
||||
ndim = ctypes.c_int(len(shape))
|
||||
handle = DGLArrayHandle()
|
||||
dtype = DGLType(dtype)
|
||||
dtype = DGLDataType(dtype)
|
||||
check_call(_LIB.DGLArrayAlloc(
|
||||
shape, ndim,
|
||||
ctypes.c_int(dtype.type_code),
|
||||
@@ -139,7 +139,7 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
|
||||
shape = c_array(dgl_shape_index_t, shape)
|
||||
ndim = ctypes.c_int(len(shape))
|
||||
handle = DGLArrayHandle()
|
||||
dtype = DGLType(dtype)
|
||||
dtype = DGLDataType(dtype)
|
||||
check_call(_LIB.DGLArrayAllocSharedMem(
|
||||
name, shape, ndim,
|
||||
ctypes.c_int(dtype.type_code),
|
||||
@@ -254,7 +254,7 @@ class NDArrayBase(_NDArrayBase):
|
||||
except:
|
||||
raise TypeError('array must be an array_like data,' +
|
||||
'type %s is not supported' % str(type(source_array)))
|
||||
t = DGLType(self.dtype)
|
||||
t = DGLDataType(self.dtype)
|
||||
shape, dtype = self.shape, self.dtype
|
||||
if t.lanes > 1:
|
||||
shape = shape + (t.lanes,)
|
||||
@@ -286,7 +286,7 @@ class NDArrayBase(_NDArrayBase):
|
||||
np_arr : numpy.ndarray
|
||||
The corresponding numpy array.
|
||||
"""
|
||||
t = DGLType(self.dtype)
|
||||
t = DGLDataType(self.dtype)
|
||||
shape, dtype = self.shape, self.dtype
|
||||
if t.lanes > 1:
|
||||
shape = shape + (t.lanes,)
|
||||
|
||||
@@ -17,7 +17,7 @@ class TypeCode(object):
|
||||
FLOAT = 2
|
||||
HANDLE = 3
|
||||
NULL = 4
|
||||
DGL_TYPE = 5
|
||||
DGL_DATA_TYPE = 5
|
||||
DGL_CONTEXT = 6
|
||||
ARRAY_HANDLE = 7
|
||||
OBJECT_HANDLE = 8
|
||||
@@ -33,7 +33,7 @@ class DGLByteArray(ctypes.Structure):
|
||||
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
|
||||
("size", ctypes.c_size_t)]
|
||||
|
||||
class DGLType(ctypes.Structure):
|
||||
class DGLDataType(ctypes.Structure):
|
||||
"""DGL datatype structure"""
|
||||
_fields_ = [("type_code", ctypes.c_uint8),
|
||||
("bits", ctypes.c_uint8),
|
||||
@@ -50,7 +50,7 @@ class DGLType(ctypes.Structure):
|
||||
if type_str in cls._cache:
|
||||
return cls._cache[type_str]
|
||||
|
||||
inst = super(DGLType, cls).__new__(DGLType)
|
||||
inst = super(DGLDataType, cls).__new__(DGLDataType)
|
||||
|
||||
if isinstance(type_str, np.dtype):
|
||||
type_str = str(type_str)
|
||||
@@ -84,7 +84,7 @@ class DGLType(ctypes.Structure):
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
x = "%s%d" % (DGLType.CODE2STR[self.type_code], self.bits)
|
||||
x = "%s%d" % (DGLDataType.CODE2STR[self.type_code], self.bits)
|
||||
if self.lanes != 1:
|
||||
x += "x%d" % self.lanes
|
||||
return x
|
||||
@@ -250,7 +250,7 @@ class DGLArray(ctypes.Structure):
|
||||
_fields_ = [("data", ctypes.c_void_p),
|
||||
("ctx", DGLContext),
|
||||
("ndim", ctypes.c_int),
|
||||
("dtype", DGLType),
|
||||
("dtype", DGLDataType),
|
||||
("shape", ctypes.POINTER(dgl_shape_index_t)),
|
||||
("strides", ctypes.POINTER(dgl_shape_index_t)),
|
||||
("byte_offset", ctypes.c_uint64)]
|
||||
|
||||
@@ -13,7 +13,7 @@ import numpy as _np
|
||||
|
||||
from ._ffi.object import register_object, ObjectBase
|
||||
from ._ffi.function import _init_api
|
||||
from ._ffi.ndarray import DGLContext, DGLType, NDArrayBase
|
||||
from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
|
||||
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
|
||||
from ._ffi.ndarray import _set_class_ndarray
|
||||
from . import backend as F
|
||||
|
||||
@@ -19,8 +19,8 @@ using namespace dgl::runtime;
|
||||
namespace dgl {
|
||||
namespace aten {
|
||||
|
||||
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) {
|
||||
return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx);
|
||||
IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
|
||||
return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);
|
||||
}
|
||||
|
||||
IdArray Clone(IdArray arr) {
|
||||
@@ -29,7 +29,7 @@ IdArray Clone(IdArray arr) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
|
||||
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
|
||||
IdArray ret;
|
||||
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
|
||||
if (nbits == 32) {
|
||||
@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
|
||||
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
|
||||
IdArray ret;
|
||||
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
|
||||
if (nbits == 32) {
|
||||
@@ -58,7 +58,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
|
||||
}
|
||||
|
||||
template <typename DType>
|
||||
NDArray Full(DType val, int64_t length, DLContext ctx) {
|
||||
NDArray Full(DType val, int64_t length, DGLContext ctx) {
|
||||
NDArray ret;
|
||||
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
|
||||
ret = impl::Full<XPU, DType>(val, length, ctx);
|
||||
@@ -66,10 +66,10 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template NDArray Full<int32_t>(int32_t val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<int64_t>(int64_t val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<float>(float val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<double>(double val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
|
||||
|
||||
IdArray AsNumBits(IdArray arr, uint8_t bits) {
|
||||
CHECK(bits == 32 || bits == 64)
|
||||
@@ -315,7 +315,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
|
||||
|
||||
std::string ToDebugString(NDArray array) {
|
||||
std::ostringstream oss;
|
||||
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
|
||||
NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
|
||||
oss << "array([";
|
||||
ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
|
||||
for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {
|
||||
@@ -1132,10 +1132,10 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
|
||||
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
|
||||
.set_body([] (DGLArgs args, DGLRetValue* rv) {
|
||||
NDArray array = args[0];
|
||||
CHECK_EQ(array->dtype.code, kDLUInt);
|
||||
CHECK_EQ(array->dtype.code, kDGLUInt);
|
||||
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
|
||||
DLDataType dtype = array->dtype;
|
||||
dtype.code = kDLInt;
|
||||
DGLDataType dtype = array->dtype;
|
||||
dtype.code = kDGLInt;
|
||||
*rv = array.CreateView(shape, dtype, 0);
|
||||
});
|
||||
|
||||
|
||||
@@ -16,176 +16,176 @@ namespace dgl {
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
IdArray Full(IdType val, int64_t length, DLContext ctx);
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Full(IdType val, int64_t length, DGLContext ctx);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DLContext ctx);
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DGLContext ctx);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray AsNumBits(IdArray arr, uint8_t bits);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdArray rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdType rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdType lhs, IdArray rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray UnaryElewise(IdArray array);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray IndexSelect(NDArray array, IdArray index);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
DType IndexSelect(NDArray array, int64_t index);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
IdArray NonZero(BoolArray bool_arr);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray Scatter(NDArray array, IdArray indices);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
void Scatter_(IdArray index, NDArray value, NDArray out);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray Repeat(NDArray array, IdArray repeats);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Relabel_(const std::vector<IdArray>& arrays);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray Concat(const std::vector<IdArray>& arrays);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray CumSum(IdArray array, bool prepend_zero);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray NonZero(NDArray array);
|
||||
|
||||
// sparse arrays
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRHasDuplicate(CSRMatrix csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsSorted(CSRMatrix csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
runtime::NDArray CSRGetData(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
|
||||
runtime::NDArray weights, DType filler);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
runtime::NDArray CSRGetData(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
|
||||
runtime::NDArray weights, DType filler) {
|
||||
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
|
||||
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::vector<runtime::NDArray> CSRGetDataAndIndices(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRTranspose(CSRMatrix csr);
|
||||
|
||||
// Convert CSR to COO
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOO(CSRMatrix csr);
|
||||
|
||||
// Convert CSR to COO using data array as order
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CSRSort_(CSRMatrix* csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename TagType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename TagType>
|
||||
std::pair<CSRMatrix, NDArray> CSRSortByTag(
|
||||
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix CSRRowWiseSampling(
|
||||
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix CSRRowWisePerEtypeSampling(
|
||||
CSRMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
|
||||
bool etype_sorted);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRRowWiseSamplingUniform(
|
||||
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRRowWisePerEtypeSamplingUniform(
|
||||
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
|
||||
bool replace, bool etype_sorted);
|
||||
|
||||
// FloatType is the type of weight data.
|
||||
template <DLDeviceType XPU, typename IdType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
COOMatrix CSRRowWiseTopk(
|
||||
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix CSRRowWiseSamplingBiased(
|
||||
CSRMatrix mat,
|
||||
IdArray rows,
|
||||
@@ -194,7 +194,7 @@ COOMatrix CSRRowWiseSamplingBiased(
|
||||
FloatArray bias,
|
||||
bool replace);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
const CSRMatrix& csr,
|
||||
int64_t num_samples,
|
||||
@@ -204,117 +204,117 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
double redundancy);
|
||||
|
||||
// Union CSRMatrixes
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool COOHasDuplicate(COOMatrix coo);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<runtime::NDArray, runtime::NDArray>
|
||||
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::vector<runtime::NDArray> COOGetDataAndIndices(
|
||||
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOTranspose(COOMatrix coo);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix COOToCSR(COOMatrix coo);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void COOSort_(COOMatrix* mat, bool sort_column);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COORemove(COOMatrix coo, IdArray entries);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix COORowWiseSampling(
|
||||
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
|
||||
|
||||
// FloatType is the type of probability data.
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix COORowWisePerEtypeSampling(
|
||||
COOMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COORowWiseSamplingUniform(
|
||||
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COORowWisePerEtypeSamplingUniform(
|
||||
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
|
||||
bool replace, bool etype_sorted);
|
||||
|
||||
// FloatType is the type of weight data.
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix COORowWiseTopk(
|
||||
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
|
||||
|
||||
///////////////////////// Graph Traverse routines //////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
|
||||
IdArray source,
|
||||
const bool has_reverse_edge,
|
||||
const bool has_nontree_edge,
|
||||
const bool return_labels);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace aten {
|
||||
|
||||
// Check whether the given arguments have the same context.
|
||||
inline void CheckCtx(
|
||||
const DLContext& ctx,
|
||||
const DGLContext& ctx,
|
||||
const std::vector<NDArray>& arrays,
|
||||
const std::vector<std::string>& names) {
|
||||
for (size_t i = 0; i < arrays.size(); ++i) {
|
||||
|
||||
@@ -10,7 +10,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray CumSum(IdArray array, bool prepend_zero) {
|
||||
const int64_t len = array.NumElements();
|
||||
if (len == 0)
|
||||
@@ -34,8 +34,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
|
||||
}
|
||||
}
|
||||
|
||||
template IdArray CumSum<kDLCPU, int32_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDLCPU, int64_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDGLCPU, int32_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDGLCPU, int64_t>(IdArray, bool);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -10,7 +10,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template<DLDeviceType XPU, typename DType, typename IdType>
|
||||
template<DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray IndexSelect(NDArray array, IdArray index) {
|
||||
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
|
||||
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
|
||||
@@ -28,25 +28,25 @@ NDArray IndexSelect(NDArray array, IdArray index) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template NDArray IndexSelect<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLCPU, double, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCPU, double, int64_t>(NDArray, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
DType IndexSelect(NDArray array, int64_t index) {
|
||||
const DType* data = static_cast<DType*>(array->data);
|
||||
return data[index];
|
||||
}
|
||||
|
||||
template int32_t IndexSelect<kDLCPU, int32_t>(NDArray array, int64_t index);
|
||||
template int64_t IndexSelect<kDLCPU, int64_t>(NDArray array, int64_t index);
|
||||
template float IndexSelect<kDLCPU, float>(NDArray array, int64_t index);
|
||||
template double IndexSelect<kDLCPU, double>(NDArray array, int64_t index);
|
||||
template int32_t IndexSelect<kDGLCPU, int32_t>(NDArray array, int64_t index);
|
||||
template int64_t IndexSelect<kDGLCPU, int64_t>(NDArray array, int64_t index);
|
||||
template float IndexSelect<kDGLCPU, float>(NDArray array, int64_t index);
|
||||
template double IndexSelect<kDGLCPU, double>(NDArray array, int64_t index);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -10,7 +10,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray NonZero(IdArray array) {
|
||||
std::vector<int64_t> ret;
|
||||
const IdType* data = array.Ptr<IdType>();
|
||||
@@ -20,8 +20,8 @@ IdArray NonZero(IdArray array) {
|
||||
return NDArray::FromVector(ret, array->ctx);
|
||||
}
|
||||
|
||||
template IdArray NonZero<kDLCPU, int32_t>(IdArray);
|
||||
template IdArray NonZero<kDLCPU, int64_t>(IdArray);
|
||||
template IdArray NonZero<kDGLCPU, int32_t>(IdArray);
|
||||
template IdArray NonZero<kDGLCPU, int64_t>(IdArray);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace impl {
|
||||
|
||||
///////////////////////////// AsNumBits /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray AsNumBits(IdArray arr, uint8_t bits) {
|
||||
CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
|
||||
if (sizeof(IdType) * 8 == bits) {
|
||||
@@ -40,12 +40,12 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDGLCPU, int32_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDGLCPU, int64_t>(IdArray arr, uint8_t bits);
|
||||
|
||||
///////////////////////////// BinaryElewise /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
|
||||
@@ -59,30 +59,30 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
|
||||
@@ -95,30 +95,30 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
|
||||
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
|
||||
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
|
||||
@@ -131,30 +131,30 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray UnaryElewise(IdArray lhs) {
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
|
||||
@@ -167,28 +167,28 @@ IdArray UnaryElewise(IdArray lhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDGLCPU, int32_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDGLCPU, int64_t, arith::Neg>(IdArray lhs);
|
||||
|
||||
///////////////////////////// Full /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
NDArray Full(DType val, int64_t length, DLContext ctx) {
|
||||
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
NDArray Full(DType val, int64_t length, DGLContext ctx) {
|
||||
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
|
||||
DType* ret_data = static_cast<DType*>(ret->data);
|
||||
std::fill(ret_data, ret_data + length, val);
|
||||
return ret;
|
||||
}
|
||||
|
||||
template NDArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<kDLCPU, float>(float val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<kDLCPU, double>(double val, int64_t length, DLContext ctx);
|
||||
template NDArray Full<kDGLCPU, int32_t>(int32_t val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<kDGLCPU, int64_t>(int64_t val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<kDGLCPU, float>(float val, int64_t length, DGLContext ctx);
|
||||
template NDArray Full<kDGLCPU, double>(double val, int64_t length, DGLContext ctx);
|
||||
|
||||
///////////////////////////// Range /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DLContext ctx) {
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DGLContext ctx) {
|
||||
CHECK(high >= low) << "high must be bigger than low";
|
||||
IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
|
||||
IdType* ret_data = static_cast<IdType*>(ret->data);
|
||||
@@ -196,12 +196,12 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
|
||||
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
|
||||
template IdArray Range<kDGLCPU, int32_t>(int32_t, int32_t, DGLContext);
|
||||
template IdArray Range<kDGLCPU, int64_t>(int64_t, int64_t, DGLContext);
|
||||
|
||||
///////////////////////////// Relabel_ /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
// build map & relabel
|
||||
IdType newid = 0;
|
||||
@@ -216,7 +216,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
}
|
||||
}
|
||||
// map array
|
||||
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdArray maparr = NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdType* maparr_data = static_cast<IdType*>(maparr->data);
|
||||
for (const auto& kv : oldv2newv) {
|
||||
maparr_data[kv.second] = kv.first;
|
||||
@@ -224,8 +224,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
return maparr;
|
||||
}
|
||||
|
||||
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDGLCPU, int32_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDGLCPU, int64_t>(const std::vector<IdArray>& arrays);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -14,7 +14,7 @@ using runtime::parallel_for;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template<DLDeviceType XPU, typename DType, typename IdType>
|
||||
template<DGLDeviceType XPU, typename DType, typename IdType>
|
||||
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
|
||||
const int64_t rows = lengths->shape[0];
|
||||
const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
|
||||
@@ -41,16 +41,16 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
|
||||
return std::make_pair(concat, offsets);
|
||||
}
|
||||
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(NDArray, IdArray);
|
||||
|
||||
template<DLDeviceType XPU, typename DType>
|
||||
template<DGLDeviceType XPU, typename DType>
|
||||
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
|
||||
CHECK_NDIM(array, 2, "array");
|
||||
const DType *array_data = static_cast<DType *>(array->data);
|
||||
@@ -75,10 +75,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
|
||||
return std::make_tuple(ret.first, length, ret.second);
|
||||
}
|
||||
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int32_t>(NDArray, int32_t);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int64_t>(NDArray, int64_t);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, float>(NDArray, float);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, double>(NDArray, double);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(NDArray, int32_t);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(NDArray, int64_t);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(NDArray, float);
|
||||
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(NDArray, double);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -11,7 +11,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray Repeat(NDArray array, IdArray repeats) {
|
||||
CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch";
|
||||
|
||||
@@ -34,14 +34,14 @@ NDArray Repeat(NDArray array, IdArray repeats) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template NDArray Repeat<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDLCPU, double, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -11,7 +11,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray Scatter(NDArray array, IdArray indices) {
|
||||
NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);
|
||||
|
||||
@@ -25,16 +25,16 @@ NDArray Scatter(NDArray array, IdArray indices) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template NDArray Scatter<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDLCPU, double, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
void Scatter_(IdArray index, NDArray value, NDArray out) {
|
||||
const int64_t len = index->shape[0];
|
||||
const IdType* idx = index.Ptr<IdType>();
|
||||
@@ -47,14 +47,14 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
|
||||
});
|
||||
}
|
||||
|
||||
template void Scatter_<kDLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, float, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, double, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, float, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLCPU, double, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, float, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, double, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, float, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCPU, double, int64_t>(IdArray, NDArray, NDArray);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -160,7 +160,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
|
||||
const int64_t nitem = array->shape[0];
|
||||
IdArray val = array.Clone();
|
||||
@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
|
||||
return std::make_pair(val, idx);
|
||||
}
|
||||
|
||||
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(IdArray, int num_bits);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -88,7 +88,7 @@ class IdHashMap {
|
||||
|
||||
// Return all the old ids collected so far, ordered by new id.
|
||||
IdArray Values() const {
|
||||
IdArray values = NewIdArray(oldv2newv_.size(), DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdArray values = NewIdArray(oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdType* values_data = static_cast<IdType*>(values->data);
|
||||
for (auto pair : oldv2newv_)
|
||||
values_data[pair.second] = pair.first;
|
||||
|
||||
@@ -13,7 +13,7 @@ namespace aten {
|
||||
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
|
||||
const int64_t nnz = coo.row->shape[0];
|
||||
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
|
||||
@@ -44,8 +44,8 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
|
||||
return std::make_pair(coo_result, NDArray::FromVector(count));
|
||||
}
|
||||
|
||||
// Explicit instantiations of COOCoalesce with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int32_t>(COOMatrix);
|
||||
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int64_t>(COOMatrix);
|
||||
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);
|
||||
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace dgl {
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
|
||||
const int64_t nnz = coo.row->shape[0];
|
||||
IdType* coo_row = coo.row.Ptr<IdType>();
|
||||
@@ -50,8 +50,8 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
|
||||
}
|
||||
|
||||
|
||||
// Explicit instantiations of COOLineGraph with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix COOLineGraph<kDLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
|
||||
template COOMatrix COOLineGraph<kDLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
|
||||
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
|
||||
template COOMatrix COOLineGraph<kDGLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace impl {
|
||||
namespace {
|
||||
|
||||
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void COORemoveConsecutive(
|
||||
COOMatrix coo,
|
||||
IdArray entries,
|
||||
@@ -47,7 +47,7 @@ void COORemoveConsecutive(
|
||||
}
|
||||
|
||||
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void COORemoveShuffled(
|
||||
COOMatrix coo,
|
||||
IdArray entries,
|
||||
@@ -73,7 +73,7 @@ void COORemoveShuffled(
|
||||
|
||||
}; // namespace
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
|
||||
const int64_t nnz = coo.row->shape[0];
|
||||
const int64_t n_entries = entries->shape[0];
|
||||
@@ -98,8 +98,8 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
|
||||
IdArray::FromVector(new_eids));
|
||||
}
|
||||
|
||||
// Explicit instantiations of COORemove with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix COORemove<kDLCPU, int32_t>(COOMatrix coo, IdArray entries);
|
||||
template COOMatrix COORemove<kDLCPU, int64_t>(COOMatrix coo, IdArray entries);
|
||||
template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);
|
||||
template COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix coo, IdArray entries);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -167,7 +167,7 @@ namespace impl {
|
||||
|
||||
///////////////////////////// COOSort_ /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void COOSort_(COOMatrix* coo, bool sort_column) {
|
||||
const int64_t nnz = coo->row->shape[0];
|
||||
IdType* coo_row = coo->row.Ptr<IdType>();
|
||||
@@ -208,13 +208,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
|
||||
coo->col_sorted = sort_column;
|
||||
}
|
||||
|
||||
// Explicit instantiations of COOSort_ with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template void COOSort_<kDLCPU, int32_t>(COOMatrix*, bool);
|
||||
template void COOSort_<kDLCPU, int64_t>(COOMatrix*, bool);
|
||||
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
|
||||
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
|
||||
|
||||
|
||||
///////////////////////////// COOIsSorted /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
|
||||
const int64_t nnz = coo.row->shape[0];
|
||||
IdType* row = coo.row.Ptr<IdType>();
|
||||
@@ -230,8 +230,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
|
||||
return {row_sorted, col_sorted};
|
||||
}
|
||||
|
||||
// Explicit instantiations of COOIsSorted with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template std::pair<bool, bool> COOIsSorted<kDLCPU, int32_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDLCPU, int64_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -17,7 +17,7 @@ using runtime::parallel_for;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
|
||||
const IdType start, const IdType end, const IdType col,
|
||||
std::vector<IdType> *ret_vec) {
|
||||
@@ -38,7 +38,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
NDArray CSRGetData(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
|
||||
const int64_t rowlen = rows->shape[0];
|
||||
@@ -59,7 +59,7 @@ NDArray CSRGetData(
|
||||
const int64_t retlen = std::max(rowlen, collen);
|
||||
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
|
||||
if (return_eids)
|
||||
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
|
||||
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
|
||||
"DType does not match row's dtype.";
|
||||
|
||||
NDArray ret = Full(filler, retlen, rows->ctx);
|
||||
@@ -106,19 +106,19 @@ NDArray CSRGetData(
|
||||
return ret;
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRGetData. The first group pairs
// IdType in {int32_t, int64_t} with DType in {float, double}; the second group
// uses DType equal to IdType (int32_t/int32_t and int64_t/int64_t).
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template NDArray CSRGetData<kDLCPU, int32_t, float>(
|
||||
template NDArray CSRGetData<kDGLCPU, int32_t, float>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
|
||||
template NDArray CSRGetData<kDLCPU, int64_t, float>(
|
||||
template NDArray CSRGetData<kDGLCPU, int64_t, float>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
|
||||
template NDArray CSRGetData<kDLCPU, int32_t, double>(
|
||||
template NDArray CSRGetData<kDGLCPU, int32_t, double>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
|
||||
template NDArray CSRGetData<kDLCPU, int64_t, double>(
|
||||
template NDArray CSRGetData<kDGLCPU, int64_t, double>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
|
||||
|
||||
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
// (integer DType instantiations; presumably used when edge ids rather than
// weights are returned — confirm against the caller.)
|
||||
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(
|
||||
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
|
||||
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(
|
||||
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -134,13 +134,13 @@ std::pair<CSRMatrix, NDArray> CSRMM(
|
||||
C_weights};
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRMM: second template argument in
// {int32_t, int64_t}, third in {float, double}.
// NOTE(review): presumably index type and value type; the template header is
// not visible next to these lines — confirm. The kDLCPU and kDGLCPU spellings
// look like the removed/added sides of a rename diff.
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
|
||||
}; // namespace aten
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace impl {
|
||||
|
||||
namespace {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CSRRemoveConsecutive(
|
||||
CSRMatrix csr,
|
||||
IdArray entries,
|
||||
@@ -48,7 +48,7 @@ void CSRRemoveConsecutive(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CSRRemoveShuffled(
|
||||
CSRMatrix csr,
|
||||
IdArray entries,
|
||||
@@ -77,7 +77,7 @@ void CSRRemoveShuffled(
|
||||
|
||||
}; // namespace
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
|
||||
CHECK_SAME_DTYPE(csr.indices, entries);
|
||||
const int64_t nnz = csr.indices->shape[0];
|
||||
@@ -103,8 +103,8 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
|
||||
IdArray::FromVector(new_eids));
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRemove with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template CSRMatrix CSRRemove<kDLCPU, int32_t>(CSRMatrix csr, IdArray entries);
|
||||
template CSRMatrix CSRRemove<kDLCPU, int64_t>(CSRMatrix csr, IdArray entries);
|
||||
template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
|
||||
template CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace aten {
|
||||
namespace impl {
|
||||
|
||||
///////////////////////////// CSRIsSorted /////////////////////////////
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsSorted(CSRMatrix csr) {
|
||||
const IdType* indptr = csr.indptr.Ptr<IdType>();
|
||||
const IdType* indices = csr.indices.Ptr<IdType>();
|
||||
@@ -31,12 +31,12 @@ bool CSRIsSorted(CSRMatrix csr) {
|
||||
[](bool a, bool b) { return a && b; });
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRIsSorted with IdType = int64_t and int32_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template bool CSRIsSorted<kDLCPU, int64_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDLCPU, int32_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
|
||||
|
||||
///////////////////////////// CSRSort /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CSRSort_(CSRMatrix* csr) {
|
||||
typedef std::pair<IdType, IdType> ShufflePair;
|
||||
const int64_t num_rows = csr->num_rows;
|
||||
@@ -79,10 +79,10 @@ void CSRSort_(CSRMatrix* csr) {
|
||||
csr->sorted = true;
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRSort_ with IdType = int64_t and int32_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename TagType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename TagType>
|
||||
std::pair<CSRMatrix, NDArray> CSRSortByTag(
|
||||
const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
|
||||
const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
|
||||
@@ -143,13 +143,13 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
|
||||
return std::make_pair(output, tag_pos);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRSortByTag for every combination of
// IdType in {int64_t, int32_t} and TagType in {int64_t, int32_t}.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int64_t>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(
|
||||
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int32_t>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(
|
||||
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int64_t>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(
|
||||
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int32_t>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(
|
||||
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -130,13 +130,13 @@ std::pair<CSRMatrix, NDArray> CSRSum(
|
||||
C_weights};
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRSum: second template argument in
// {int32_t, int64_t}, third in {float, double}.
// NOTE(review): presumably index type and value type; the template header is
// not visible next to these lines — confirm. The kDLCPU and kDGLCPU spellings
// look like the removed/added sides of a rename diff.
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, float>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, float>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, double>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, double>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
|
||||
}; // namespace aten
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace dgl {
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
|
||||
if (!csr.sorted)
|
||||
csr = CSRSort(csr);
|
||||
@@ -67,8 +67,8 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
|
||||
return std::make_tuple(res_csr, edge_count, eids_remapped);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRToSimple with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int32_t>(CSRMatrix);
|
||||
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int64_t>(CSRMatrix);
|
||||
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(CSRMatrix);
|
||||
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(CSRMatrix);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -14,7 +14,7 @@ namespace dgl {
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
|
||||
std::vector<IdType> res_indptr;
|
||||
std::vector<IdType> res_indices;
|
||||
@@ -109,8 +109,8 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
|
||||
sorted);
|
||||
}
|
||||
|
||||
// Explicit instantiations of UnionCsr with IdType = int64_t and int32_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&);
|
||||
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&);
|
||||
template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix>&);
|
||||
template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix>&);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -26,7 +26,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
|
||||
IdArray prefix_src_arr = NewIdArray(
|
||||
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
|
||||
@@ -52,7 +52,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
|
||||
CumSum(prefix_elm_arr, true));
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
|
||||
bool has_data = false;
|
||||
bool row_sorted = true;
|
||||
@@ -118,8 +118,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
|
||||
col_sorted);
|
||||
}
|
||||
|
||||
// Explicit instantiations of DisjointUnionCoo with IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix DisjointUnionCoo<kDLCPU, int32_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDLCPU, int64_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(const std::vector<COOMatrix>& coos);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -62,74 +62,74 @@ void GatherMMScatter(const NDArray A,
|
||||
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
|
||||
}
|
||||
|
||||
// Explicit instantiations of GatherMM: second template argument in
// {int32_t, int64_t}, third in {16, 32, 64}.
// NOTE(review): presumably index type and data bit-width; the template header
// is not visible next to these lines — confirm. The kDLCPU and kDGLCPU
// spellings look like the removed/added sides of a rename diff.
template void GatherMM<kDLCPU, int32_t, 16>(
|
||||
template void GatherMM<kDGLCPU, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLCPU, int64_t, 16>(
|
||||
template void GatherMM<kDGLCPU, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLCPU, int32_t, 32>(
|
||||
template void GatherMM<kDGLCPU, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLCPU, int64_t, 32>(
|
||||
template void GatherMM<kDGLCPU, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLCPU, int32_t, 64>(
|
||||
template void GatherMM<kDGLCPU, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLCPU, int64_t, 64>(
|
||||
template void GatherMM<kDGLCPU, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
|
||||
// Explicit instantiations of GatherMMScatter: second template argument in
// {int32_t, int64_t}, third in {16, 32, 64}.
// NOTE(review): presumably index type and data bit-width; the template header
// is not visible next to these lines — confirm. The kDLCPU and kDGLCPU
// spellings look like the removed/added sides of a rename diff.
template void GatherMMScatter<kDLCPU, int32_t, 16>(
|
||||
template void GatherMMScatter<kDGLCPU, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLCPU, int64_t, 16>(
|
||||
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLCPU, int32_t, 32>(
|
||||
template void GatherMMScatter<kDGLCPU, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLCPU, int64_t, 32>(
|
||||
template void GatherMMScatter<kDGLCPU, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLCPU, int32_t, 64>(
|
||||
template void GatherMMScatter<kDGLCPU, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLCPU, int64_t, 64>(
|
||||
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
|
||||
// Explicit instantiations of SegmentMM: second template argument in
// {int32_t, int64_t}, third in {16, 32, 64}.
// NOTE(review): presumably index type and data bit-width; the template header
// is not visible next to these lines — confirm. The kDLCPU and kDGLCPU
// spellings look like the removed/added sides of a rename diff.
template void SegmentMM<kDLCPU, int32_t, 16>(
|
||||
template void SegmentMM<kDGLCPU, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLCPU, int64_t, 16>(
|
||||
template void SegmentMM<kDGLCPU, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLCPU, int32_t, 32>(
|
||||
template void SegmentMM<kDGLCPU, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLCPU, int64_t, 32>(
|
||||
template void SegmentMM<kDGLCPU, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLCPU, int32_t, 64>(
|
||||
template void SegmentMM<kDGLCPU, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLCPU, int64_t, 64>(
|
||||
template void SegmentMM<kDGLCPU, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
|
||||
// Explicit instantiations of SegmentMMBackwardB: second template argument in
// {int32_t, int64_t}, third in {16, 32, 64}.
// NOTE(review): presumably index type and data bit-width; the template header
// is not visible next to these lines — confirm. The kDLCPU and kDGLCPU
// spellings look like the removed/added sides of a rename diff.
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
|
||||
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
|
||||
} // namespace aten
|
||||
|
||||
@@ -17,7 +17,7 @@ namespace dgl {
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
const CSRMatrix &csr,
|
||||
int64_t num_samples,
|
||||
@@ -61,9 +61,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)};
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRGlobalUniformNegativeSampling with
// IdType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int32_t>(
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int32_t>(
|
||||
const CSRMatrix&, int64_t, int, bool, bool, double);
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int64_t>(
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int64_t>(
|
||||
const CSRMatrix&, int64_t, int, bool, bool, double);
|
||||
|
||||
}; // namespace impl
|
||||
|
||||
@@ -97,13 +97,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
|
||||
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
|
||||
// significant. (minjie)
|
||||
IdArray picked_row = NDArray::Empty({num_rows * num_picks},
|
||||
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
|
||||
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
|
||||
ctx);
|
||||
IdArray picked_col = NDArray::Empty({num_rows * num_picks},
|
||||
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
|
||||
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
|
||||
ctx);
|
||||
IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
|
||||
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
|
||||
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
|
||||
ctx);
|
||||
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
|
||||
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);
|
||||
|
||||
@@ -117,7 +117,7 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
|
||||
|
||||
/////////////////////////////// CSR ///////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
|
||||
FloatArray prob, bool replace) {
|
||||
CHECK(prob.defined());
|
||||
@@ -125,16 +125,16 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
|
||||
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRowWiseSampling for every combination of
// IdxType in {int32_t, int64_t} and FloatType in {float, double}.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples,
|
||||
FloatArray prob, bool replace, bool etype_sorted) {
|
||||
@@ -143,28 +143,28 @@ COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes
|
||||
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRowWisePerEtypeSampling for every combination
// of IdxType in {int32_t, int64_t} and FloatType in {float, double}.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
|
||||
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
|
||||
/*!
 * \brief Row-wise uniform sampling on a CSR matrix.
 *
 * Builds a pick function with GetSamplingUniformPickFn<IdxType>(num_samples, replace)
 * and hands it, together with the unchanged arguments, to CSRRowWisePick.
 *
 * \param mat CSR matrix, forwarded to CSRRowWisePick.
 * \param rows Row array, forwarded to CSRRowWisePick.
 * \param num_samples Passed to both the pick-function factory and CSRRowWisePick.
 * \param replace Passed to both the pick-function factory and CSRRowWisePick
 *        (presumably "sample with replacement" — confirm).
 * \return The COOMatrix produced by CSRRowWisePick.
 *
 * NOTE(review): two template headers (DLDeviceType / DGLDeviceType) precede the
 * signature; they look like the removed/added sides of a rename diff — confirm
 * that only one of them is meant to survive.
 */
template <DLDeviceType XPU, typename IdxType>
|
||||
template <DGLDeviceType XPU, typename IdxType>
|
||||
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
|
||||
int64_t num_samples, bool replace) {
|
||||
// The uniform picker takes no probability array, only the count and the flag.
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
|
||||
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRowWiseSamplingUniform with
// IdxType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
|
||||
CSRMatrix, IdArray, int64_t, bool);
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
|
||||
CSRMatrix, IdArray, int64_t, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType>
|
||||
template <DGLDeviceType XPU, typename IdxType>
|
||||
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples,
|
||||
bool replace, bool etype_sorted) {
|
||||
@@ -172,12 +172,12 @@ COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray
|
||||
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRowWisePerEtypeSamplingUniform with
// IdxType = int32_t and int64_t.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
|
||||
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
|
||||
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
|
||||
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
|
||||
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
COOMatrix CSRRowWiseSamplingBiased(
|
||||
CSRMatrix mat,
|
||||
IdArray rows,
|
||||
@@ -191,22 +191,22 @@ COOMatrix CSRRowWiseSamplingBiased(
|
||||
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
|
||||
}
|
||||
|
||||
// Explicit instantiations of CSRRowWiseSamplingBiased for every combination of
// IdxType in {int32_t, int64_t} and FloatType in {float, double}.
// NOTE(review): the kDLCPU and kDGLCPU spellings look like the removed/added
// sides of a rename diff — confirm that only one of them is meant to survive.
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>(
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
|
||||
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>(
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
|
||||
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>(
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
|
||||
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>(
|
||||
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
|
||||
|
||||
|
||||
/////////////////////////////// COO ///////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
|
||||
FloatArray prob, bool replace) {
|
||||
CHECK(prob.defined());
|
||||
@@ -214,16 +214,16 @@ COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
|
||||
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>(
|
||||
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
|
||||
COOMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>(
|
||||
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
|
||||
COOMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>(
|
||||
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
|
||||
COOMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
|
||||
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
|
||||
COOMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
|
||||
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples,
|
||||
FloatArray prob, bool replace, bool etype_sorted) {
|
||||
@@ -232,28 +232,28 @@ COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes
|
||||
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
|
||||
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType>
|
||||
template <DGLDeviceType XPU, typename IdxType>
|
||||
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
|
||||
int64_t num_samples, bool replace) {
|
||||
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
|
||||
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>(
|
||||
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
|
||||
COOMatrix, IdArray, int64_t, bool);
|
||||
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
|
||||
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
|
||||
COOMatrix, IdArray, int64_t, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType>
|
||||
template <DGLDeviceType XPU, typename IdxType>
|
||||
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
|
||||
const std::vector<int64_t>& num_samples,
|
||||
bool replace, bool etype_sorted) {
|
||||
@@ -261,9 +261,9 @@ COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray
|
||||
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
|
||||
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
|
||||
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
|
||||
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
|
||||
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -55,52 +55,52 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
|
||||
|
||||
} // namespace
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename DType>
|
||||
COOMatrix CSRRowWiseTopk(
|
||||
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
|
||||
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
|
||||
return CSRRowWisePick(mat, rows, k, false, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int32_t>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int32_t>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int64_t>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int64_t>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, float>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, float>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, double>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, double>(
|
||||
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, NDArray, bool);
|
||||
|
||||
template <DLDeviceType XPU, typename IdxType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdxType, typename DType>
|
||||
COOMatrix COORowWiseTopk(
|
||||
COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
|
||||
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
|
||||
return COORowWisePick(mat, rows, k, false, pick_fn);
|
||||
}
|
||||
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int32_t>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int32_t>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int64_t>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int64_t>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, float>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, float>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, double>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, double>(
|
||||
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
|
||||
COOMatrix, IdArray, int64_t, NDArray, bool);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -102,67 +102,67 @@ void SDDMMCsrHetero(const std::string& op,
|
||||
});
|
||||
}
|
||||
|
||||
template void SDDMMCsr<kDLCPU, int32_t, 16>(
|
||||
template void SDDMMCsr<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLCPU, int64_t, 16>(
|
||||
template void SDDMMCsr<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLCPU, int32_t, 32>(
|
||||
template void SDDMMCsr<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLCPU, int64_t, 32>(
|
||||
template void SDDMMCsr<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLCPU, int32_t, 64>(
|
||||
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLCPU, int64_t, 64>(
|
||||
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
|
||||
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
|
||||
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
@@ -217,67 +217,67 @@ void SDDMMCooHetero(const std::string& op,
|
||||
});
|
||||
}
|
||||
|
||||
template void SDDMMCoo<kDLCPU, int32_t, 16>(
|
||||
template void SDDMMCoo<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLCPU, int64_t, 16>(
|
||||
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLCPU, int32_t, 32>(
|
||||
template void SDDMMCoo<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLCPU, int64_t, 32>(
|
||||
template void SDDMMCoo<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLCPU, int32_t, 64>(
|
||||
template void SDDMMCoo<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLCPU, int64_t, 64>(
|
||||
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
|
||||
template void SDDMMCooHetero<kDLCPU, int32_t, 16>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLCPU, int64_t, 16>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLCPU, int32_t, 32>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLCPU, int64_t, 32>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLCPU, int32_t, 64>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLCPU, int64_t, 64>(
|
||||
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
|
||||
@@ -74,113 +74,113 @@ void BackwardSegmentCmp(
|
||||
});
|
||||
}
|
||||
|
||||
template void SegmentReduce<kDLCPU, int32_t, 16>(
|
||||
template void SegmentReduce<kDGLCPU, int32_t, 16>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLCPU, int64_t, 16>(
|
||||
template void SegmentReduce<kDGLCPU, int64_t, 16>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLCPU, int32_t, 32>(
|
||||
template void SegmentReduce<kDGLCPU, int32_t, 32>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLCPU, int64_t, 32>(
|
||||
template void SegmentReduce<kDGLCPU, int64_t, 32>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLCPU, int32_t, 64>(
|
||||
template void SegmentReduce<kDGLCPU, int32_t, 64>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLCPU, int64_t, 64>(
|
||||
template void SegmentReduce<kDGLCPU, int64_t, 64>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void ScatterAdd<kDLCPU, int32_t, 16>(
|
||||
template void ScatterAdd<kDGLCPU, int32_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLCPU, int64_t, 16>(
|
||||
template void ScatterAdd<kDGLCPU, int64_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLCPU, int32_t, 32>(
|
||||
template void ScatterAdd<kDGLCPU, int32_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLCPU, int64_t, 32>(
|
||||
template void ScatterAdd<kDGLCPU, int64_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLCPU, int32_t, 64>(
|
||||
template void ScatterAdd<kDGLCPU, int32_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLCPU, int64_t, 64>(
|
||||
template void ScatterAdd<kDGLCPU, int64_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
|
||||
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLCPU, int64_t, 16>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLCPU, int32_t, 32>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLCPU, int64_t, 32>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLCPU, int32_t, 64>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLCPU, int64_t, 64>(
|
||||
template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
|
||||
@@ -29,7 +29,7 @@ namespace impl {
|
||||
|
||||
///////////////////////////// COOIsNonZero /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
|
||||
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
|
||||
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
|
||||
@@ -42,10 +42,10 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template bool COOIsNonZero<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
|
||||
template bool COOIsNonZero<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
|
||||
template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
|
||||
template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
|
||||
const auto rowlen = row->shape[0];
|
||||
const auto collen = col->shape[0];
|
||||
@@ -67,12 +67,12 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
|
||||
return rst;
|
||||
}
|
||||
|
||||
template NDArray COOIsNonZero<kDLCPU, int32_t>(COOMatrix, NDArray, NDArray);
|
||||
template NDArray COOIsNonZero<kDLCPU, int64_t>(COOMatrix, NDArray, NDArray);
|
||||
template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
|
||||
template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);
|
||||
|
||||
///////////////////////////// COOHasDuplicate /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool COOHasDuplicate(COOMatrix coo) {
|
||||
std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
|
||||
const IdType* src_data = static_cast<IdType*>(coo.row->data);
|
||||
@@ -89,12 +89,12 @@ bool COOHasDuplicate(COOMatrix coo) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template bool COOHasDuplicate<kDLCPU, int32_t>(COOMatrix coo);
|
||||
template bool COOHasDuplicate<kDLCPU, int64_t>(COOMatrix coo);
|
||||
template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
|
||||
template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);
|
||||
|
||||
///////////////////////////// COOGetRowNNZ /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
|
||||
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
|
||||
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
|
||||
@@ -106,10 +106,10 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
|
||||
return result;
|
||||
}
|
||||
|
||||
template int64_t COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
|
||||
CHECK_SAME_DTYPE(coo.col, rows);
|
||||
const auto len = rows->shape[0];
|
||||
@@ -123,12 +123,12 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
|
||||
return rst;
|
||||
}
|
||||
|
||||
template NDArray COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);
|
||||
|
||||
///////////////////////////// COOGetRowDataAndIndices /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
|
||||
COOMatrix coo, int64_t row) {
|
||||
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
|
||||
@@ -151,13 +151,13 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
|
||||
}
|
||||
|
||||
template std::pair<NDArray, NDArray>
|
||||
COOGetRowDataAndIndices<kDLCPU, int32_t>(COOMatrix, int64_t);
|
||||
COOGetRowDataAndIndices<kDGLCPU, int32_t>(COOMatrix, int64_t);
|
||||
template std::pair<NDArray, NDArray>
|
||||
COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t);
|
||||
COOGetRowDataAndIndices<kDGLCPU, int64_t>(COOMatrix, int64_t);
|
||||
|
||||
///////////////////////////// COOGetData /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
|
||||
const int64_t rowlen = rows->shape[0];
|
||||
const int64_t collen = cols->shape[0];
|
||||
@@ -211,12 +211,12 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray COOGetData<kDLCPU, int32_t>(COOMatrix, IdArray, IdArray);
|
||||
template IdArray COOGetData<kDLCPU, int64_t>(COOMatrix, IdArray, IdArray);
|
||||
template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
|
||||
template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);
|
||||
|
||||
///////////////////////////// COOGetDataAndIndices /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
|
||||
NDArray cols) {
|
||||
CHECK_SAME_DTYPE(coo.col, rows);
|
||||
@@ -286,20 +286,20 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
|
||||
NDArray::FromVector(ret_data)};
|
||||
}
|
||||
|
||||
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t>(
|
||||
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
|
||||
COOMatrix coo, NDArray rows, NDArray cols);
|
||||
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t>(
|
||||
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
|
||||
COOMatrix coo, NDArray rows, NDArray cols);
|
||||
|
||||
///////////////////////////// COOTranspose /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOTranspose(COOMatrix coo) {
|
||||
return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
|
||||
}
|
||||
|
||||
template COOMatrix COOTranspose<kDLCPU, int32_t>(COOMatrix coo);
|
||||
template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo);
|
||||
template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
|
||||
template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);
|
||||
|
||||
///////////////////////////// COOToCSR /////////////////////////////
|
||||
namespace {
|
||||
@@ -615,7 +615,7 @@ P^2).
|
||||
degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P), space O(NNZ +
|
||||
N*P).
|
||||
*/
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix COOToCSR(COOMatrix coo) {
|
||||
if (!coo.row_sorted) {
|
||||
const int64_t num_threads = omp_get_num_threads();
|
||||
@@ -632,12 +632,12 @@ CSRMatrix COOToCSR(COOMatrix coo) {
|
||||
return SortedCOOToCSR<IdType>(coo);
|
||||
}
|
||||
|
||||
template CSRMatrix COOToCSR<kDLCPU, int32_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDLCPU, int64_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);
|
||||
|
||||
///////////////////////////// COOSliceRows /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
|
||||
// TODO(minjie): use binary search when coo.row_sorted is true
|
||||
CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
|
||||
@@ -669,10 +669,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
|
||||
coo.col_sorted);
|
||||
}
|
||||
|
||||
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
|
||||
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
|
||||
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
|
||||
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
|
||||
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
|
||||
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
|
||||
@@ -703,12 +703,12 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
|
||||
coo.row_sorted, coo.col_sorted};
|
||||
}
|
||||
|
||||
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix , NDArray);
|
||||
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix , NDArray);
|
||||
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix , NDArray);
|
||||
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix , NDArray);
|
||||
|
||||
///////////////////////////// COOSliceMatrix /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
|
||||
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
|
||||
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
|
||||
@@ -740,15 +740,15 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
|
||||
coo.row_sorted, coo.col_sorted);
|
||||
}
|
||||
|
||||
template COOMatrix COOSliceMatrix<kDLCPU, int32_t>(
|
||||
template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
|
||||
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
|
||||
template COOMatrix COOSliceMatrix<kDLCPU, int64_t>(
|
||||
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
|
||||
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
|
||||
///////////////////////////// COOReorder /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
|
||||
runtime::NDArray new_col_id_arr) {
|
||||
CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
|
||||
@@ -785,9 +785,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
|
||||
return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
|
||||
}
|
||||
|
||||
template COOMatrix COOReorder<kDLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
|
||||
template COOMatrix COOReorder<kDGLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
|
||||
runtime::NDArray new_col_ids);
|
||||
template COOMatrix COOReorder<kDLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
|
||||
template COOMatrix COOReorder<kDGLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
|
||||
runtime::NDArray new_col_ids);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace impl {
|
||||
|
||||
///////////////////////////// CSRIsNonZero /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
|
||||
@@ -39,10 +39,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
|
||||
const auto rowlen = row->shape[0];
|
||||
const auto collen = col->shape[0];
|
||||
@@ -62,12 +62,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
|
||||
return rst;
|
||||
}
|
||||
|
||||
template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
|
||||
|
||||
///////////////////////////// CSRHasDuplicate /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRHasDuplicate(CSRMatrix csr) {
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
|
||||
@@ -85,21 +85,21 @@ bool CSRHasDuplicate(CSRMatrix csr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);
|
||||
|
||||
///////////////////////////// CSRGetRowNNZ /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
|
||||
// Returns the number of stored entries in `row` of `csr`, widened to int64_t.
// csr.indptr's buffer is read as an array of IdType; no dtype or bounds check
// on `row` is done in this function, so `row + 1` must be a valid index into
// indptr (NOTE(review): presumably validated by the caller -- confirm).
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
// The count is the difference between the two consecutive indptr entries
// at positions `row` and `row + 1`.
return indptr_data[row + 1] - indptr_data[row];
|
||||
}
|
||||
|
||||
template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
|
||||
CHECK_SAME_DTYPE(csr.indices, rows);
|
||||
const auto len = rows->shape[0];
|
||||
@@ -114,12 +114,12 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
|
||||
return rst;
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);
|
||||
|
||||
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
|
||||
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -127,12 +127,12 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
|
||||
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
///////////////////////////// CSRGetRowData /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
|
||||
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -143,13 +143,13 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
|
||||
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
///////////////////////////// CSRGetData /////////////////////////////
|
||||
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
|
||||
const IdType start, const IdType end, const IdType col,
|
||||
std::vector<IdType> *col_vec,
|
||||
@@ -172,7 +172,7 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
|
||||
// TODO(minjie): more efficient implementation for matrix without duplicate entries
|
||||
const int64_t rowlen = rows->shape[0];
|
||||
@@ -224,16 +224,16 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
|
||||
NDArray::FromVector(ret_data, csr.data->ctx)};
|
||||
}
|
||||
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols);
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>(
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols);
|
||||
|
||||
///////////////////////////// CSRTranspose /////////////////////////////
|
||||
|
||||
// for a matrix of shape (N, M) and NNZ
|
||||
// complexity: time O(NNZ + max(N, M)), space O(1)
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRTranspose(CSRMatrix csr) {
|
||||
const int64_t N = csr.num_rows;
|
||||
const int64_t M = csr.num_cols;
|
||||
@@ -281,11 +281,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
|
||||
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
|
||||
}
|
||||
|
||||
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);
|
||||
|
||||
///////////////////////////// CSRToCOO /////////////////////////////
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOO(CSRMatrix csr) {
|
||||
const int64_t nnz = csr.indices->shape[0];
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -303,11 +303,11 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
|
||||
true, csr.sorted);
|
||||
}
|
||||
|
||||
template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);
|
||||
|
||||
// complexity: time O(NNZ), space O(1)
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
|
||||
const int64_t N = csr.num_rows;
|
||||
const int64_t M = csr.num_cols;
|
||||
@@ -333,12 +333,12 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
|
||||
return COOMatrix(N, M, ret_row, ret_col);
|
||||
}
|
||||
|
||||
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);
|
||||
|
||||
///////////////////////////// CSRSliceRows /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
|
||||
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
|
||||
const int64_t num_rows = end - start;
|
||||
@@ -362,10 +362,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
|
||||
csr.sorted);
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
|
||||
CHECK_SAME_DTYPE(csr.indices, rows);
|
||||
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -467,12 +467,12 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix , NDArray);
|
||||
|
||||
///////////////////////////// CSRSliceMatrix /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
|
||||
IdHashMap<IdType> hashmap(cols);
|
||||
const int64_t new_nrows = rows->shape[0];
|
||||
@@ -521,14 +521,14 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
|
||||
sub_data_arr};
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>(
|
||||
template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>(
|
||||
template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
///////////////////////////// CSRReorder /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
|
||||
runtime::NDArray new_col_id_arr) {
|
||||
CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
|
||||
@@ -599,9 +599,9 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
|
||||
out_indptr_arr, out_indices_arr, out_data_arr);
|
||||
}
|
||||
|
||||
template CSRMatrix CSRReorder<kDLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
|
||||
template CSRMatrix CSRReorder<kDGLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
|
||||
runtime::NDArray new_col_ids);
|
||||
template CSRMatrix CSRReorder<kDLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
|
||||
template CSRMatrix CSRReorder<kDGLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
|
||||
runtime::NDArray new_col_ids);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -124,67 +124,67 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
|
||||
}
|
||||
}
|
||||
|
||||
template void SpMMCsr<kDLCPU, int32_t, 16>(
|
||||
template void SpMMCsr<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLCPU, int64_t, 16>(
|
||||
template void SpMMCsr<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLCPU, int32_t, 32>(
|
||||
template void SpMMCsr<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLCPU, int64_t, 32>(
|
||||
template void SpMMCsr<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLCPU, int32_t, 64>(
|
||||
template void SpMMCsr<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLCPU, int64_t, 64>(
|
||||
template void SpMMCsr<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_node_tids,
|
||||
const std::vector<dgl_type_t>& out_node_tids);
|
||||
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_node_tids,
|
||||
const std::vector<dgl_type_t>& out_node_tids);
|
||||
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_node_tids,
|
||||
const std::vector<dgl_type_t>& out_node_tids);
|
||||
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_node_tids,
|
||||
const std::vector<dgl_type_t>& out_node_tids);
|
||||
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_node_tids,
|
||||
const std::vector<dgl_type_t>& out_node_tids);
|
||||
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
|
||||
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
@@ -222,52 +222,52 @@ void Edge_softmax_csr_backward(const std::string& op,
|
||||
});
|
||||
}
|
||||
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
|
||||
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
|
||||
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out);
|
||||
@@ -303,27 +303,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
|
||||
}
|
||||
}
|
||||
|
||||
template void SpMMCoo<kDLCPU, int32_t, 16>(
|
||||
template void SpMMCoo<kDGLCPU, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLCPU, int64_t, 16>(
|
||||
template void SpMMCoo<kDGLCPU, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLCPU, int32_t, 32>(
|
||||
template void SpMMCoo<kDGLCPU, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLCPU, int64_t, 32>(
|
||||
template void SpMMCoo<kDGLCPU, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLCPU, int32_t, 64>(
|
||||
template void SpMMCoo<kDGLCPU, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLCPU, int64_t, 64>(
|
||||
template void SpMMCoo<kDGLCPU, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals(
|
||||
total_len += traces[i].size();
|
||||
}
|
||||
IdArray ret = IdArray::Empty({total_len},
|
||||
DLDataType{kDLInt, sizeof(DType) * 8, 1},
|
||||
DLContext{kDLCPU, 0});
|
||||
DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
|
||||
DGLContext{kDGLCPU, 0});
|
||||
DType* ret_data = static_cast<DType*>(ret->data);
|
||||
for (int64_t i = 0; i < max_len; ++i) {
|
||||
for (size_t j = 0; j < traces.size(); ++j) {
|
||||
@@ -79,7 +79,7 @@ IdArray ComputeMergedSections(
|
||||
const int64_t tracelen = traces[i].size();
|
||||
max_len = std::max(max_len, tracelen);
|
||||
}
|
||||
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
|
||||
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
|
||||
int64_t* ret_data = static_cast<int64_t*>(ret->data);
|
||||
for (int64_t i = 0; i < max_len; ++i) {
|
||||
int64_t sec_len = 0;
|
||||
@@ -96,7 +96,7 @@ IdArray ComputeMergedSections(
|
||||
|
||||
} // namespace
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
|
||||
std::vector<IdType> ids;
|
||||
std::vector<int64_t> sections;
|
||||
@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
|
||||
return front;
|
||||
}
|
||||
|
||||
template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
|
||||
std::vector<IdType> ids;
|
||||
std::vector<int64_t> sections;
|
||||
@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
|
||||
return front;
|
||||
}
|
||||
|
||||
template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
|
||||
std::vector<IdType> ids;
|
||||
std::vector<int64_t> sections;
|
||||
@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
|
||||
return front;
|
||||
}
|
||||
|
||||
template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
|
||||
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
|
||||
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&);
|
||||
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
|
||||
const int64_t len = source->shape[0];
|
||||
const IdType* src_data = static_cast<IdType*>(source->data);
|
||||
@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
|
||||
return front;
|
||||
}
|
||||
|
||||
template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
|
||||
template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
|
||||
IdArray source,
|
||||
const bool has_reverse_edge,
|
||||
@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
|
||||
return front;
|
||||
}
|
||||
|
||||
template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
|
||||
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&,
|
||||
IdArray,
|
||||
const bool,
|
||||
const bool,
|
||||
const bool);
|
||||
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
|
||||
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
|
||||
IdArray,
|
||||
const bool,
|
||||
const bool,
|
||||
|
||||
@@ -13,7 +13,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray CumSum(IdArray array, bool prepend_zero) {
|
||||
const int64_t len = array.NumElements();
|
||||
if (len == 0)
|
||||
@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
|
||||
template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -13,7 +13,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template<DLDeviceType XPU, typename DType, typename IdType>
|
||||
template<DGLDeviceType XPU, typename DType, typename IdType>
|
||||
NDArray IndexSelect(NDArray array, IdArray index) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const DType* array_data = static_cast<DType*>(array->data);
|
||||
@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
|
||||
#ifdef USE_FP16
|
||||
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
|
||||
#endif
|
||||
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
|
||||
template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
DType IndexSelect(NDArray array, int64_t index) {
|
||||
auto device = runtime::DeviceAPI::Get(array->ctx);
|
||||
#ifdef USE_FP16
|
||||
@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) {
|
||||
#endif
|
||||
device->CopyDataFromTo(
|
||||
static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
|
||||
sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
|
||||
array->dtype);
|
||||
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
|
||||
return reinterpret_cast<DType&>(ret);
|
||||
}
|
||||
|
||||
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
|
||||
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
|
||||
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
|
||||
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
|
||||
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
|
||||
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
|
||||
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
|
||||
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
|
||||
#ifdef USE_FP16
|
||||
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
|
||||
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
|
||||
#endif
|
||||
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
|
||||
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
|
||||
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
|
||||
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -26,7 +26,7 @@ struct IsNonZeroIndex {
|
||||
const IdType * array_;
|
||||
};
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray NonZero(IdArray array) {
|
||||
const auto& ctx = array->ctx;
|
||||
auto device = runtime::DeviceAPI::Get(ctx);
|
||||
@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
|
||||
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
|
||||
}
|
||||
|
||||
template IdArray NonZero<kDLGPU, int32_t>(IdArray);
|
||||
template IdArray NonZero<kDLGPU, int64_t>(IdArray);
|
||||
template IdArray NonZero<kDGLCUDA, int32_t>(IdArray);
|
||||
template IdArray NonZero<kDGLCUDA, int64_t>(IdArray);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
|
||||
const int64_t len = lhs->shape[0];
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
|
||||
|
||||
|
||||
template <typename IdType, typename Op>
|
||||
@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
|
||||
const int64_t len = lhs->shape[0];
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
|
||||
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
|
||||
const int64_t len = rhs->shape[0];
|
||||
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
|
||||
@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
|
||||
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
|
||||
|
||||
template <typename IdType, typename Op>
|
||||
__global__ void _UnaryElewiseKernel(
|
||||
@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename Op>
|
||||
template <DGLDeviceType XPU, typename IdType, typename Op>
|
||||
IdArray UnaryElewise(IdArray lhs) {
|
||||
const int64_t len = lhs->shape[0];
|
||||
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
|
||||
@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);
|
||||
template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
|
||||
|
||||
///////////////////////////// Full /////////////////////////////
|
||||
|
||||
@@ -210,9 +210,9 @@ __global__ void _FullKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename DType>
|
||||
NDArray Full(DType val, int64_t length, DLContext ctx) {
|
||||
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
|
||||
template <DGLDeviceType XPU, typename DType>
|
||||
NDArray Full(DType val, int64_t length, DGLContext ctx) {
|
||||
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
|
||||
DType* ret_data = static_cast<DType*>(ret->data);
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
int nt = cuda::FindNumThreads(length);
|
||||
@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
|
||||
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
|
||||
template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
|
||||
template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
|
||||
#ifdef USE_FP16
|
||||
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
|
||||
template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
|
||||
#endif
|
||||
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
|
||||
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
|
||||
template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
|
||||
template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
|
||||
|
||||
|
||||
///////////////////////////// Range /////////////////////////////
|
||||
@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DLContext ctx) {
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Range(IdType low, IdType high, DGLContext ctx) {
|
||||
CHECK(high >= low) << "high must be bigger than low";
|
||||
const IdType length = high - low;
|
||||
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
|
||||
@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
|
||||
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
|
||||
template IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);
|
||||
template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
|
||||
|
||||
///////////////////////////// Relabel_ //////////////////////////////
|
||||
|
||||
@@ -278,7 +278,7 @@ __global__ void _RelabelKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
IdArray all_nodes = Concat(arrays);
|
||||
const int64_t total_length = all_nodes->shape[0];
|
||||
@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
&num_induced, 0,
|
||||
sizeof(num_induced),
|
||||
ctx,
|
||||
DGLContext{kDLCPU, 0},
|
||||
DGLType{kDLInt, 64, 1});
|
||||
DGLContext{kDGLCPU, 0},
|
||||
DGLDataType{kDGLInt, 64, 1});
|
||||
|
||||
device->StreamSync(ctx, stream);
|
||||
device->FreeWorkspace(ctx, num_induced_device);
|
||||
@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
|
||||
return induced_nodes;
|
||||
}
|
||||
|
||||
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDGLCUDA, int32_t>(const std::vector<IdArray>& arrays);
|
||||
template IdArray Relabel_<kDGLCUDA, int64_t>(const std::vector<IdArray>& arrays);
|
||||
|
||||
///////////////////////////// AsNumBits /////////////////////////////
|
||||
|
||||
@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
IdArray AsNumBits(IdArray arr, uint8_t bits) {
|
||||
const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
|
||||
IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
|
||||
IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);
|
||||
const int64_t length = ret.NumElements();
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
int nt = cuda::FindNumThreads(length);
|
||||
@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
|
||||
}
|
||||
|
||||
|
||||
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
|
||||
template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value,
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename DType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename DType, typename IdType>
|
||||
void Scatter_(IdArray index, NDArray value, NDArray out) {
|
||||
const int64_t len = index->shape[0];
|
||||
const IdType* idx = index.Ptr<IdType>();
|
||||
@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
|
||||
idx, val, len, outd);
|
||||
}
|
||||
|
||||
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
|
||||
#ifdef USE_FP16
|
||||
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
|
||||
#endif
|
||||
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
|
||||
#ifdef USE_FP16
|
||||
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
|
||||
#endif
|
||||
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
|
||||
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
|
||||
|
||||
}; // namespace impl
|
||||
}; // namespace aten
|
||||
|
||||
@@ -13,7 +13,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
|
||||
const auto& ctx = array->ctx;
|
||||
auto device = runtime::DeviceAPI::Get(ctx);
|
||||
@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
|
||||
return std::make_pair(sorted_array, sorted_idx);
|
||||
}
|
||||
|
||||
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits);
|
||||
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -14,14 +14,14 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix COOToCSR(COOMatrix coo) {
|
||||
LOG(FATAL) << "Unreachable code.";
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
|
||||
CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
|
||||
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
// allocate cusparse handle if needed
|
||||
@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound(
|
||||
}
|
||||
|
||||
template <>
|
||||
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
|
||||
CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
|
||||
const auto& ctx = coo.row->ctx;
|
||||
const auto nbits = coo.row->dtype.bits;
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
|
||||
indptr, coo.col, coo.data, col_sorted);
|
||||
}
|
||||
|
||||
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
|
||||
template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) {
|
||||
return bits;
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void COOSort_(COOMatrix* coo, bool sort_column) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const int row_bits = _NumberOfBits(coo->num_rows);
|
||||
@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
|
||||
}
|
||||
}
|
||||
|
||||
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
|
||||
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
|
||||
template void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);
|
||||
template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
|
||||
|
||||
///////////////////////////// COOIsSorted /////////////////////////////
|
||||
|
||||
@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
|
||||
const int64_t nnz = coo.row->shape[0];
|
||||
const auto& ctx = coo.row->ctx;
|
||||
@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
|
||||
return {row_sorted, col_sorted};
|
||||
}
|
||||
|
||||
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);
|
||||
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -14,14 +14,14 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOO(CSRMatrix csr) {
|
||||
LOG(FATAL) << "Unreachable codes";
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
|
||||
COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
|
||||
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
// allocate cusparse handle if needed
|
||||
@@ -77,7 +77,7 @@ __global__ void _RepeatKernel(
|
||||
}
|
||||
|
||||
template <>
|
||||
COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
|
||||
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
|
||||
const auto& ctx = csr.indptr->ctx;
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
|
||||
@@ -99,18 +99,18 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
|
||||
true, csr.sorted);
|
||||
}
|
||||
|
||||
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
|
||||
LOG(FATAL) << "Unreachable codes";
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
|
||||
COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
|
||||
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
|
||||
COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
|
||||
if (aten::IsNullArray(coo.data))
|
||||
return coo;
|
||||
|
||||
@@ -156,8 +156,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
|
||||
}
|
||||
|
||||
template <>
|
||||
COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
|
||||
COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
|
||||
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
|
||||
COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
|
||||
if (aten::IsNullArray(coo.data))
|
||||
return coo;
|
||||
const auto& sorted = Sort(coo.data);
|
||||
@@ -173,8 +173,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
|
||||
return coo;
|
||||
}
|
||||
|
||||
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
|
||||
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -17,7 +17,7 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType, typename DType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename DType>
|
||||
NDArray CSRGetData(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
|
||||
const int64_t rowlen = rows->shape[0];
|
||||
@@ -38,7 +38,7 @@ NDArray CSRGetData(
|
||||
const int nt = cuda::FindNumThreads(rstlen);
|
||||
const int nb = (rstlen + nt - 1) / nt;
|
||||
if (return_eids)
|
||||
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
|
||||
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
|
||||
"DType does not match row's dtype.";
|
||||
|
||||
// TODO(minjie): use binary search for sorted csr
|
||||
@@ -53,24 +53,24 @@ NDArray CSRGetData(
|
||||
}
|
||||
|
||||
#ifdef USE_FP16
|
||||
template NDArray CSRGetData<kDLGPU, int32_t, __half>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
|
||||
template NDArray CSRGetData<kDLGPU, int64_t, __half>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
|
||||
#endif
|
||||
template NDArray CSRGetData<kDLGPU, int32_t, float>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
|
||||
template NDArray CSRGetData<kDLGPU, int64_t, float>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
|
||||
template NDArray CSRGetData<kDLGPU, int32_t, double>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
|
||||
template NDArray CSRGetData<kDLGPU, int64_t, double>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
|
||||
|
||||
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
|
||||
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
|
||||
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
|
||||
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
|
||||
}
|
||||
|
||||
#ifdef USE_FP16
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
#endif
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
|
||||
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
|
||||
|
||||
} // namespace aten
|
||||
|
||||
@@ -34,7 +34,7 @@ __global__ void _SegmentIsSorted(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsSorted(CSRMatrix csr) {
|
||||
const auto& ctx = csr.indptr->ctx;
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
@@ -53,16 +53,16 @@ bool CSRIsSorted(CSRMatrix csr) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
|
||||
template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void CSRSort_(CSRMatrix* csr) {
|
||||
LOG(FATAL) << "Unreachable codes";
|
||||
}
|
||||
|
||||
template <>
|
||||
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
|
||||
void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
|
||||
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
|
||||
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
@@ -108,7 +108,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
|
||||
}
|
||||
|
||||
template <>
|
||||
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
|
||||
void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
|
||||
|
||||
@@ -147,8 +147,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
|
||||
device->FreeWorkspace(ctx, workspace);
|
||||
}
|
||||
|
||||
template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
|
||||
template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
|
||||
}
|
||||
|
||||
#ifdef USE_FP16
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
#endif
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
|
||||
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
|
||||
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
|
||||
|
||||
} // namespace aten
|
||||
|
||||
@@ -13,14 +13,14 @@ using runtime::NDArray;
|
||||
namespace aten {
|
||||
namespace impl {
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRTranspose(CSRMatrix csr) {
|
||||
LOG(FATAL) << "Unreachable codes";
|
||||
return {};
|
||||
}
|
||||
|
||||
template <>
|
||||
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
|
||||
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
|
||||
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
// allocate cusparse handle if needed
|
||||
@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
|
||||
}
|
||||
|
||||
template <>
|
||||
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
|
||||
CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {
|
||||
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
|
||||
}
|
||||
|
||||
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr);
|
||||
template CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -105,7 +105,7 @@ IdArray _PerformFilter(
|
||||
&num_unique, 0,
|
||||
sizeof(num_unique),
|
||||
ctx,
|
||||
DGLContext{kDLCPU, 0},
|
||||
DGLContext{kDGLCPU, 0},
|
||||
test->dtype);
|
||||
|
||||
// insert items into set
|
||||
@@ -150,13 +150,13 @@ class CudaFilterSet : public Filter {
|
||||
|
||||
} // namespace
|
||||
|
||||
template<DLDeviceType XPU, typename IdType>
|
||||
template<DGLDeviceType XPU, typename IdType>
|
||||
FilterRef CreateSetFilter(IdArray set) {
|
||||
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
|
||||
}
|
||||
|
||||
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
|
||||
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
|
||||
template FilterRef CreateSetFilter<kDGLCUDA, int32_t>(IdArray set);
|
||||
template FilterRef CreateSetFilter<kDGLCUDA, int64_t>(IdArray set);
|
||||
|
||||
} // namespace array
|
||||
} // namespace dgl
|
||||
|
||||
@@ -47,7 +47,7 @@ __global__ void _DisjointUnionKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
|
||||
IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
|
||||
IdArray n_rows = NewIdArray(n, CPU, nbits);
|
||||
@@ -71,10 +71,10 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
|
||||
CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
|
||||
int64_t n_arrs, int n_elms,
|
||||
DGLContext ctx, DGLType dtype, cudaStream_t stream) {
|
||||
DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
|
||||
auto device = runtime::DeviceAPI::Get(ctx);
|
||||
int nt = 256;
|
||||
int nb = (n_elms + nt - 1) / nt;
|
||||
@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
|
||||
|
||||
device->CopyDataFromTo(
|
||||
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs,
|
||||
DGLContext{kDLCPU, 0}, ctx, dtype);
|
||||
DGLContext{kDGLCPU, 0}, ctx, dtype);
|
||||
|
||||
CUDA_KERNEL_CALL(_DisjointUnionKernel,
|
||||
nb, nt, 0, stream,
|
||||
@@ -94,7 +94,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
|
||||
device->FreeWorkspace(ctx, arrs_dev);
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);
|
||||
@@ -133,17 +133,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
|
||||
IdType n_elements = 0;
|
||||
device->CopyDataFromTo(
|
||||
&prefix_elm[coos.size()], 0, &n_elements, 0,
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
|
||||
coos[0].row->dtype);
|
||||
|
||||
device->CopyDataFromTo(
|
||||
&prefix_src[coos.size()], 0, &src_offset, 0,
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
|
||||
coos[0].row->dtype);
|
||||
|
||||
device->CopyDataFromTo(
|
||||
&prefix_dst[coos.size()], 0, &dst_offset, 0,
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
|
||||
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
|
||||
coos[0].row->dtype);
|
||||
|
||||
// Union src array
|
||||
@@ -176,8 +176,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
|
||||
col_sorted);
|
||||
}
|
||||
|
||||
template COOMatrix DisjointUnionCoo<kDLGPU, int32_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDLGPU, int64_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos);
|
||||
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -394,74 +394,74 @@ void GatherMMScatter(const NDArray A,
|
||||
}
|
||||
|
||||
|
||||
template void GatherMM<kDLGPU, int32_t, 16>(
|
||||
template void GatherMM<kDGLCUDA, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLGPU, int64_t, 16>(
|
||||
template void GatherMM<kDGLCUDA, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLGPU, int32_t, 32>(
|
||||
template void GatherMM<kDGLCUDA, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLGPU, int64_t, 32>(
|
||||
template void GatherMM<kDGLCUDA, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLGPU, int32_t, 64>(
|
||||
template void GatherMM<kDGLCUDA, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
template void GatherMM<kDLGPU, int64_t, 64>(
|
||||
template void GatherMM<kDGLCUDA, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b);
|
||||
|
||||
template void GatherMMScatter<kDLGPU, int32_t, 16>(
|
||||
template void GatherMMScatter<kDGLCUDA, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLGPU, int64_t, 16>(
|
||||
template void GatherMMScatter<kDGLCUDA, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLGPU, int32_t, 32>(
|
||||
template void GatherMMScatter<kDGLCUDA, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLGPU, int64_t, 32>(
|
||||
template void GatherMMScatter<kDGLCUDA, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLGPU, int32_t, 64>(
|
||||
template void GatherMMScatter<kDGLCUDA, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
template void GatherMMScatter<kDLGPU, int64_t, 64>(
|
||||
template void GatherMMScatter<kDGLCUDA, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
|
||||
|
||||
template void SegmentMM<kDLGPU, int32_t, 16>(
|
||||
template void SegmentMM<kDGLCUDA, int32_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLGPU, int64_t, 16>(
|
||||
template void SegmentMM<kDGLCUDA, int64_t, 16>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLGPU, int32_t, 32>(
|
||||
template void SegmentMM<kDGLCUDA, int32_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLGPU, int64_t, 32>(
|
||||
template void SegmentMM<kDGLCUDA, int64_t, 32>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLGPU, int32_t, 64>(
|
||||
template void SegmentMM<kDGLCUDA, int32_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
template void SegmentMM<kDLGPU, int64_t, 64>(
|
||||
template void SegmentMM<kDGLCUDA, int64_t, 64>(
|
||||
const NDArray A, const NDArray B, NDArray C,
|
||||
const NDArray seglen_A, bool a_trans, bool b_trans);
|
||||
|
||||
template void SegmentMMBackwardB<kDLGPU, int32_t, 16>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 16>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLGPU, int64_t, 16>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 16>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLGPU, int32_t, 32>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 32>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLGPU, int64_t, 32>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 32>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLGPU, int32_t, 64>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 64>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
template void SegmentMMBackwardB<kDLGPU, int64_t, 64>(
|
||||
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 64>(
|
||||
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
|
||||
|
||||
} // namespace aten
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
} \
|
||||
} else { \
|
||||
constexpr bool UseBcast = true; \
|
||||
const DLContext ctx = (CTX); \
|
||||
const DGLContext ctx = (CTX); \
|
||||
const auto device = runtime::DeviceAPI::Get(ctx); \
|
||||
(LHS_OFF) = static_cast<int64_t*>( \
|
||||
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \
|
||||
|
||||
@@ -93,7 +93,7 @@ struct IsNotMinusOne {
|
||||
template <typename IdType>
|
||||
void SortOrderedPairs(
|
||||
runtime::DeviceAPI* device,
|
||||
DLContext ctx,
|
||||
DGLContext ctx,
|
||||
IdType* major,
|
||||
IdType* minor,
|
||||
IdType* tmp_major,
|
||||
@@ -128,7 +128,7 @@ void SortOrderedPairs(
|
||||
|
||||
}; // namespace
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
const CSRMatrix& csr,
|
||||
int64_t num_samples,
|
||||
@@ -211,9 +211,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
|
||||
return result;
|
||||
}
|
||||
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>(
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int32_t>(
|
||||
const CSRMatrix&, int64_t, int, bool, bool, double);
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
|
||||
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int64_t>(
|
||||
const CSRMatrix&, int64_t, int, bool, bool, double);
|
||||
|
||||
}; // namespace impl
|
||||
|
||||
@@ -240,7 +240,7 @@ __global__ void _CSRRowWiseSampleUniformReplaceKernel(
|
||||
|
||||
///////////////////////////// CSR sampling //////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
|
||||
IdArray rows,
|
||||
const int64_t num_picks,
|
||||
@@ -311,7 +311,7 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
|
||||
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
|
||||
sizeof(new_len),
|
||||
ctx,
|
||||
DGLContext{kDLCPU, 0},
|
||||
DGLContext{kDGLCPU, 0},
|
||||
mat.indptr->dtype);
|
||||
CUDA_CALL(cudaEventRecord(copyEvent, stream));
|
||||
|
||||
@@ -369,9 +369,9 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
|
||||
picked_col, picked_idx);
|
||||
}
|
||||
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int32_t>(
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
|
||||
CSRMatrix, IdArray, int64_t, bool);
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int64_t>(
|
||||
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
|
||||
CSRMatrix, IdArray, int64_t, bool);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -416,7 +416,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
|
||||
* @param replace Is replacement sampling?
|
||||
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
|
||||
*/
|
||||
template <DLDeviceType XPU, typename IdType, typename FloatType>
|
||||
template <DGLDeviceType XPU, typename IdType, typename FloatType>
|
||||
COOMatrix CSRRowWiseSampling(CSRMatrix mat,
|
||||
IdArray rows,
|
||||
int64_t num_picks,
|
||||
@@ -492,7 +492,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
|
||||
device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
|
||||
sizeof(temp_len),
|
||||
ctx,
|
||||
DGLContext{kDLCPU, 0},
|
||||
DGLContext{kDGLCPU, 0},
|
||||
mat.indptr->dtype);
|
||||
device->StreamSync(ctx, stream);
|
||||
|
||||
@@ -523,7 +523,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
|
||||
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
|
||||
sizeof(new_len),
|
||||
ctx,
|
||||
DGLContext{kDLCPU, 0},
|
||||
DGLContext{kDGLCPU, 0},
|
||||
mat.indptr->dtype);
|
||||
CUDA_CALL(cudaEventRecord(copyEvent, stream));
|
||||
|
||||
@@ -651,13 +651,13 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
|
||||
picked_col, picked_idx);
|
||||
}
|
||||
|
||||
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, float>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, float>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, double>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, double>(
|
||||
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(
|
||||
CSRMatrix, IdArray, int64_t, FloatArray, bool);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -54,52 +54,52 @@ void SDDMMCoo(const std::string& op,
|
||||
}
|
||||
|
||||
|
||||
template void SDDMMCsr<kDLGPU, int32_t, 16>(
|
||||
template void SDDMMCsr<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLGPU, int64_t, 16>(
|
||||
template void SDDMMCsr<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLGPU, int32_t, 32>(
|
||||
template void SDDMMCsr<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLGPU, int64_t, 32>(
|
||||
template void SDDMMCsr<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLGPU, int32_t, 64>(
|
||||
template void SDDMMCsr<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCsr<kDLGPU, int64_t, 64>(
|
||||
template void SDDMMCsr<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
|
||||
template void SDDMMCoo<kDLGPU, int32_t, 16>(
|
||||
template void SDDMMCoo<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLGPU, int64_t, 16>(
|
||||
template void SDDMMCoo<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLGPU, int32_t, 32>(
|
||||
template void SDDMMCoo<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLGPU, int64_t, 32>(
|
||||
template void SDDMMCoo<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLGPU, int32_t, 64>(
|
||||
template void SDDMMCoo<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
template void SDDMMCoo<kDLGPU, int64_t, 64>(
|
||||
template void SDDMMCoo<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray lhs, NDArray rhs, NDArray out,
|
||||
int lhs_target, int rhs_target);
|
||||
|
||||
@@ -42,42 +42,42 @@ void SDDMMCooHetero(const std::string& op,
|
||||
}
|
||||
|
||||
|
||||
template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
|
||||
template void SDDMMCooHetero<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<COOMatrix>& vec_coo,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
|
||||
@@ -41,42 +41,42 @@ void SDDMMCsrHetero(const std::string& op,
|
||||
});
|
||||
}
|
||||
|
||||
template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
std::vector<NDArray> out, int lhs_target, int rhs_target,
|
||||
const std::vector<dgl_type_t>& in_eid,
|
||||
const std::vector<dgl_type_t>& out_eid);
|
||||
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
|
||||
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const BcastOff& bcast,
|
||||
const std::vector<CSRMatrix>& vec_csr,
|
||||
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
|
||||
|
||||
@@ -73,113 +73,113 @@ void BackwardSegmentCmp(NDArray feat,
|
||||
}
|
||||
|
||||
|
||||
template void SegmentReduce<kDLGPU, int32_t, 16>(
|
||||
template void SegmentReduce<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLGPU, int64_t, 16>(
|
||||
template void SegmentReduce<kDGLCUDA, int64_t, 16>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLGPU, int32_t, 32>(
|
||||
template void SegmentReduce<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLGPU, int64_t, 32>(
|
||||
template void SegmentReduce<kDGLCUDA, int64_t, 32>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLGPU, int32_t, 64>(
|
||||
template void SegmentReduce<kDGLCUDA, int32_t, 64>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void SegmentReduce<kDLGPU, int64_t, 64>(
|
||||
template void SegmentReduce<kDGLCUDA, int64_t, 64>(
|
||||
const std::string &op,
|
||||
NDArray feat,
|
||||
NDArray offsets,
|
||||
NDArray out,
|
||||
NDArray arg);
|
||||
template void ScatterAdd<kDLGPU, int32_t, 16>(
|
||||
template void ScatterAdd<kDGLCUDA, int32_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLGPU, int64_t, 16>(
|
||||
template void ScatterAdd<kDGLCUDA, int64_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLGPU, int32_t, 32>(
|
||||
template void ScatterAdd<kDGLCUDA, int32_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLGPU, int64_t, 32>(
|
||||
template void ScatterAdd<kDGLCUDA, int64_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLGPU, int32_t, 64>(
|
||||
template void ScatterAdd<kDGLCUDA, int32_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
template void ScatterAdd<kDLGPU, int64_t, 64>(
|
||||
template void ScatterAdd<kDGLCUDA, int64_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray idx,
|
||||
NDArray out);
|
||||
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 16>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 16>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 16>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 16>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 32>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 32>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 32>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 32>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 64>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 64>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 64>(
|
||||
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 64>(
|
||||
const HeteroGraphPtr& g, const std::string& op,
|
||||
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
|
||||
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
|
||||
|
||||
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLGPU, int64_t, 16>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 16>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLGPU, int32_t, 32>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLGPU, int64_t, 32>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 32>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLGPU, int32_t, 64>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
template void BackwardSegmentCmp<kDLGPU, int64_t, 64>(
|
||||
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 64>(
|
||||
NDArray feat,
|
||||
NDArray arg,
|
||||
NDArray out);
|
||||
|
||||
@@ -71,7 +71,7 @@ __global__ void _COOGetRowNNZKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto& ctx = coo.row->ctx;
|
||||
@@ -84,12 +84,12 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
|
||||
nb, nt, 0, stream,
|
||||
coo.row.Ptr<IdType>(), rst.Ptr<IdType>(),
|
||||
row, nnz);
|
||||
rst = rst.CopyTo(DLContext{kDLCPU, 0});
|
||||
rst = rst.CopyTo(DGLContext{kDGLCPU, 0});
|
||||
return *rst.Ptr<IdType>();
|
||||
}
|
||||
|
||||
template int64_t COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, int64_t);
|
||||
template int64_t COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, int64_t);
|
||||
|
||||
template <typename IdType>
|
||||
__global__ void _COOGetAllRowNNZKernel(
|
||||
@@ -104,7 +104,7 @@ __global__ void _COOGetAllRowNNZKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto& ctx = coo.row->ctx;
|
||||
@@ -112,7 +112,7 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
|
||||
IdType num_rows = coo.num_rows;
|
||||
IdType num_queries = rows->shape[0];
|
||||
if (num_queries == 1) {
|
||||
auto rows_cpu = rows.CopyTo(DLContext{kDLCPU, 0});
|
||||
auto rows_cpu = rows.CopyTo(DGLContext{kDGLCPU, 0});
|
||||
int64_t row = *rows_cpu.Ptr<IdType>();
|
||||
IdType nt = 1024;
|
||||
IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);
|
||||
@@ -136,8 +136,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
|
||||
}
|
||||
}
|
||||
|
||||
template NDArray COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, NDArray);
|
||||
template NDArray COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, NDArray);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace aten
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace impl {
|
||||
|
||||
///////////////////////////// CSRIsNonZero /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto& ctx = csr.indptr->ctx;
|
||||
@@ -38,14 +38,14 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
|
||||
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
|
||||
1, 1, 1,
|
||||
static_cast<IdType*>(nullptr), static_cast<IdType>(-1), out.Ptr<IdType>());
|
||||
out = out.CopyTo(DLContext{kDLCPU, 0});
|
||||
out = out.CopyTo(DGLContext{kDGLCPU, 0});
|
||||
return *out.Ptr<IdType>() != -1;
|
||||
}
|
||||
|
||||
template bool CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
|
||||
const auto rowlen = row->shape[0];
|
||||
const auto collen = col->shape[0];
|
||||
@@ -69,8 +69,8 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
|
||||
return rst != -1;
|
||||
}
|
||||
|
||||
template NDArray CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);
|
||||
template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);
|
||||
|
||||
///////////////////////////// CSRHasDuplicate /////////////////////////////
|
||||
|
||||
@@ -95,7 +95,7 @@ __global__ void _SegmentHasNoDuplicate(
|
||||
}
|
||||
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
bool CSRHasDuplicate(CSRMatrix csr) {
|
||||
if (!csr.sorted)
|
||||
csr = CSRSort(csr);
|
||||
@@ -116,20 +116,20 @@ bool CSRHasDuplicate(CSRMatrix csr) {
|
||||
return !ret;
|
||||
}
|
||||
|
||||
template bool CSRHasDuplicate<kDLGPU, int32_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDLGPU, int64_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);
|
||||
template bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);
|
||||
|
||||
///////////////////////////// CSRGetRowNNZ /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
|
||||
const IdType cur = aten::IndexSelect<IdType>(csr.indptr, row);
|
||||
const IdType next = aten::IndexSelect<IdType>(csr.indptr, row + 1);
|
||||
return next - cur;
|
||||
}
|
||||
|
||||
template int64_t CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
|
||||
template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
template <typename IdType>
|
||||
__global__ void _CSRGetRowNNZKernel(
|
||||
@@ -146,7 +146,7 @@ __global__ void _CSRGetRowNNZKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto len = rows->shape[0];
|
||||
@@ -162,24 +162,24 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
|
||||
return rst;
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
|
||||
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);
|
||||
|
||||
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
|
||||
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
|
||||
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
|
||||
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowColumnIndices<kDLGPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDLGPU, int64_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
///////////////////////////// CSRGetRowData /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
|
||||
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
|
||||
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
|
||||
@@ -189,12 +189,12 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
|
||||
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
|
||||
}
|
||||
|
||||
template NDArray CSRGetRowData<kDLGPU, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDLGPU, int64_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
|
||||
template NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
|
||||
|
||||
///////////////////////////// CSRSliceRows /////////////////////////////
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
|
||||
const int64_t num_rows = end - start;
|
||||
const IdType st_pos = aten::IndexSelect<IdType>(csr.indptr, start);
|
||||
@@ -215,8 +215,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
|
||||
csr.sorted);
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
|
||||
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);
|
||||
|
||||
/*!
|
||||
* \brief Copy data segment to output buffers
|
||||
@@ -243,7 +243,7 @@ __global__ void _SegmentCopyKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const int64_t len = rows->shape[0];
|
||||
@@ -272,8 +272,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
|
||||
csr.sorted);
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix , NDArray);
|
||||
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray);
|
||||
|
||||
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
|
||||
|
||||
@@ -345,7 +345,7 @@ __global__ void _SortedSearchKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
|
||||
const auto rowlen = row->shape[0];
|
||||
const auto collen = col->shape[0];
|
||||
@@ -392,9 +392,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
|
||||
return {ret_row, ret_col, ret_data};
|
||||
}
|
||||
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int32_t>(
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols);
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
|
||||
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
|
||||
CSRMatrix csr, NDArray rows, NDArray cols);
|
||||
|
||||
///////////////////////////// CSRSliceMatrix /////////////////////////////
|
||||
@@ -422,7 +422,7 @@ __global__ void _SegmentMaskColKernel(
|
||||
}
|
||||
}
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto& ctx = rows->ctx;
|
||||
@@ -501,9 +501,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
|
||||
ret_col, ret_data);
|
||||
}
|
||||
|
||||
template CSRMatrix CSRSliceMatrix<kDLGPU, int32_t>(
|
||||
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
template CSRMatrix CSRSliceMatrix<kDLGPU, int64_t>(
|
||||
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(
|
||||
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
|
||||
|
||||
} // namespace impl
|
||||
|
||||
@@ -147,53 +147,53 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
|
||||
}
|
||||
}
|
||||
|
||||
template void SpMMCsr<kDLGPU, int32_t, 16>(
|
||||
template void SpMMCsr<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLGPU, int64_t, 16>(
|
||||
template void SpMMCsr<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLGPU, int32_t, 32>(
|
||||
template void SpMMCsr<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLGPU, int64_t, 32>(
|
||||
template void SpMMCsr<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLGPU, int32_t, 64>(
|
||||
template void SpMMCsr<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCsr<kDLGPU, int64_t, 64>(
|
||||
template void SpMMCsr<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const CSRMatrix& csr,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
|
||||
template void SpMMCoo<kDLGPU, int32_t, 16>(
|
||||
template void SpMMCoo<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLGPU, int64_t, 16>(
|
||||
template void SpMMCoo<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLGPU, int32_t, 32>(
|
||||
template void SpMMCoo<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLGPU, int64_t, 32>(
|
||||
template void SpMMCoo<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLGPU, int32_t, 64>(
|
||||
template void SpMMCoo<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
template void SpMMCoo<kDLGPU, int64_t, 64>(
|
||||
template void SpMMCoo<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const COOMatrix& coo,
|
||||
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
|
||||
|
||||
@@ -203,7 +203,7 @@ cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t tr
|
||||
/*! Cusparse implementation of SpMM on Csr format. */
|
||||
template <typename DType, typename IdType>
|
||||
void CusparseCsrmm2(
|
||||
const DLContext& ctx,
|
||||
const DGLContext& ctx,
|
||||
const CSRMatrix& csr,
|
||||
const DType* B_data, const DType* A_data,
|
||||
DType* C_data,
|
||||
@@ -303,7 +303,7 @@ void CusparseCsrmm2(
|
||||
/*! Cusparse implementation of SpMM on Csr format. */
|
||||
template <typename DType, typename IdType>
|
||||
void CusparseCsrmm2Hetero(
|
||||
const DLContext& ctx,
|
||||
const DGLContext& ctx,
|
||||
const CSRMatrix& csr,
|
||||
const DType* B_data, const DType* A_data,
|
||||
DType* C_data,
|
||||
|
||||
@@ -199,37 +199,37 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
|
||||
});
|
||||
}
|
||||
|
||||
template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int32_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
|
||||
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int64_t, 16>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
|
||||
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int32_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
|
||||
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int64_t, 32>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
|
||||
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int32_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
|
||||
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
|
||||
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
|
||||
template void SpMMCsrHetero<kDGLCUDA, int64_t, 64>(
|
||||
const std::string& op, const std::string& reduce,
|
||||
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
|
||||
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
namespace dgl {
|
||||
namespace cuda {
|
||||
|
||||
bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx) {
|
||||
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {
|
||||
auto device = runtime::DeviceAPI::Get(ctx);
|
||||
int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));
|
||||
// Call CUB's reduction
|
||||
|
||||
@@ -7,9 +7,9 @@
|
||||
#define DGL_ARRAY_CUDA_UTILS_H_
|
||||
|
||||
#include <dmlc/logging.h>
|
||||
#include <dgl/runtime/c_runtime_api.h>
|
||||
#include <dgl/runtime/device_api.h>
|
||||
#include <dgl/runtime/ndarray.h>
|
||||
#include <dlpack/dlpack.h>
|
||||
#include "../../runtime/cuda/cuda_common.h"
|
||||
|
||||
namespace dgl {
|
||||
@@ -115,7 +115,7 @@ __device__ __forceinline__ T _ldg(T* addr) {
|
||||
* \param ctx Device context.
|
||||
* \return True if all the flags are true.
|
||||
*/
|
||||
bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx);
|
||||
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx);
|
||||
|
||||
/*!
|
||||
* \brief CUDA Kernel of filling the vector started from ptr of size length
|
||||
@@ -187,7 +187,7 @@ __global__ void _LinearSearchKernel(
|
||||
template <typename DType>
|
||||
inline DType GetCUDAScalar(
|
||||
runtime::DeviceAPI* device_api,
|
||||
DLContext ctx,
|
||||
DGLContext ctx,
|
||||
const DType* cuda_ptr) {
|
||||
DType result;
|
||||
device_api->CopyDataFromTo(
|
||||
@@ -195,8 +195,8 @@ inline DType GetCUDAScalar(
|
||||
&result, 0,
|
||||
sizeof(result),
|
||||
ctx,
|
||||
DLContext{kDLCPU, 0},
|
||||
DLDataTypeTraits<DType>::dtype);
|
||||
DGLContext{kDGLCPU, 0},
|
||||
DGLDataTypeTraits<DType>::dtype);
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
|
||||
std::vector<int64_t> shape{len};
|
||||
|
||||
CHECK(array.IsPinned());
|
||||
CHECK_EQ(index->ctx.device_type, kDLGPU);
|
||||
CHECK_EQ(index->ctx.device_type, kDGLCUDA);
|
||||
|
||||
for (int d = 1; d < array->ndim; ++d) {
|
||||
num_feat *= array->shape[d];
|
||||
@@ -85,8 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
|
||||
std::vector<int64_t> shape{len};
|
||||
|
||||
CHECK(dest.IsPinned());
|
||||
CHECK_EQ(index->ctx.device_type, kDLGPU);
|
||||
CHECK_EQ(source->ctx.device_type, kDLGPU);
|
||||
CHECK_EQ(index->ctx.device_type, kDGLCUDA);
|
||||
CHECK_EQ(source->ctx.device_type, kDGLCUDA);
|
||||
|
||||
for (int d = 1; d < source->ndim; ++d) {
|
||||
num_feat *= source->shape[d];
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace array {
|
||||
|
||||
using namespace dgl::runtime;
|
||||
|
||||
template<DLDeviceType XPU, typename IdType>
|
||||
template<DGLDeviceType XPU, typename IdType>
|
||||
FilterRef CreateSetFilter(IdArray set);
|
||||
|
||||
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
|
||||
@@ -23,10 +23,10 @@ DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
|
||||
IdArray array = args[0];
|
||||
auto ctx = array->ctx;
|
||||
// TODO(nv-dlasalle): Implement CPU version.
|
||||
if (ctx.device_type == kDLGPU) {
|
||||
if (ctx.device_type == kDGLCUDA) {
|
||||
#ifdef DGL_USE_CUDA
|
||||
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
|
||||
*rv = CreateSetFilter<kDLGPU, IdType>(array);
|
||||
*rv = CreateSetFilter<kDGLCUDA, IdType>(array);
|
||||
});
|
||||
#else
|
||||
LOG(FATAL) << "GPU support not compiled.";
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#ifdef USE_TVM
|
||||
#include <featgraph.h>
|
||||
#include <dgl/runtime/dlpack_convert.h>
|
||||
#endif // USE_TVM
|
||||
|
||||
#include "kernel_decl.h"
|
||||
@@ -70,7 +71,7 @@ void SegmentMM(const NDArray A,
|
||||
}
|
||||
CHECK_EQ(B->shape[0], seglen_A.NumElements())
|
||||
<< "segment_mm expects len(seglen_A) == B.shape[0]";
|
||||
CHECK_EQ(seglen_A->ctx.device_type, kDLCPU)
|
||||
CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)
|
||||
<< "segment_mm expects seglen_A to be on CPU.";
|
||||
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device";
|
||||
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
|
||||
@@ -89,7 +90,7 @@ void SegmentMMBackwardB(const NDArray A,
|
||||
CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor for the first input.";
|
||||
CHECK_EQ(dC->ndim, 2)
|
||||
<< "segment_mm_backward operator expects a 2D tensor for the second input.";
|
||||
CHECK_EQ(seglen->ctx.device_type, kDLCPU)
|
||||
CHECK_EQ(seglen->ctx.device_type, kDGLCPU)
|
||||
<< "segment_mm expects seglen to be on CPU.";
|
||||
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
|
||||
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
|
||||
@@ -829,8 +830,12 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
|
||||
// {lhs, rhs, out},
|
||||
// {"U_data", "E_data", "V_data"});
|
||||
COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
|
||||
dgl::featgraph::SDDMMTreeReduction(coo.row.ToDLPack(), coo.col.ToDLPack(),
|
||||
lhs.ToDLPack(), rhs.ToDLPack(), out.ToDLPack());
|
||||
dgl::featgraph::SDDMMTreeReduction(
|
||||
DLPackConvert::ToDLPack(coo.row),
|
||||
DLPackConvert::ToDLPack(coo.col),
|
||||
DLPackConvert::ToDLPack(lhs),
|
||||
DLPackConvert::ToDLPack(rhs),
|
||||
DLPackConvert::ToDLPack(out));
|
||||
});
|
||||
#endif // USE_TVM
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ namespace aten {
|
||||
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
|
||||
#ifdef DGL_USE_CUDA
|
||||
CHECK(array.IsPinned()) << "Input array must be in pinned memory.";
|
||||
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
|
||||
CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU.";
|
||||
CHECK_GE(array->ndim, 1) << "Input array must have at least 1 dimension.";
|
||||
CHECK_EQ(index->ndim, 1) << "Index must be a 1D array.";
|
||||
|
||||
@@ -34,8 +34,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
|
||||
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
|
||||
#ifdef DGL_USE_CUDA
|
||||
CHECK(dest.IsPinned()) << "Destination array must be in pinned memory.";
|
||||
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
|
||||
CHECK_EQ(source->ctx.device_type, kDLGPU) << "Source array must be on the GPU.";
|
||||
CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU.";
|
||||
CHECK_EQ(source->ctx.device_type, kDGLCUDA) << "Source array must be on the GPU.";
|
||||
CHECK_EQ(dest->dtype, source->dtype) << "Destination array and source "
|
||||
"array must have the same dtype.";
|
||||
CHECK_GE(dest->ndim, 1) << "Destination array must have at least 1 dimension.";
|
||||
|
||||
@@ -41,8 +41,8 @@ dgl::runtime::NDArray CopyVectorToNDArray(
|
||||
const std::vector<DType>& vec) {
|
||||
using dgl::runtime::NDArray;
|
||||
const int64_t len = vec.size();
|
||||
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, sizeof(IdType) * 8, 1},
|
||||
DLContext{kDLCPU, 0});
|
||||
NDArray a = NDArray::Empty({len}, DGLDataType{kDGLInt, sizeof(IdType) * 8, 1},
|
||||
DGLContext{kDGLCPU, 0});
|
||||
std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
|
||||
return a;
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ template void GroupIndexShuffle<int64_t>(
|
||||
|
||||
template <typename IdType>
|
||||
IdArray RandomPerm(int64_t num_nodes) {
|
||||
IdArray perm = aten::NewIdArray(num_nodes, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdType* perm_data = static_cast<IdType*>(perm->data);
|
||||
std::iota(perm_data, perm_data + num_nodes, 0);
|
||||
IndexShuffle(perm_data, num_nodes);
|
||||
@@ -59,7 +59,7 @@ IdArray RandomPerm(int64_t num_nodes) {
|
||||
|
||||
template <typename IdType>
|
||||
IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {
|
||||
IdArray perm = aten::NewIdArray(num_nodes, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
|
||||
IdType* perm_data = static_cast<IdType*>(perm->data);
|
||||
std::iota(perm_data, perm_data + num_nodes, 0);
|
||||
GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);
|
||||
@@ -77,7 +77,7 @@ IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_
|
||||
* Finally, we pick the point with the maximum such distance.
|
||||
* This process will be repeated for ``sample_points`` - 1 times.
|
||||
*/
|
||||
template <DLDeviceType XPU, typename FloatType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename FloatType, typename IdType>
|
||||
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
|
||||
NDArray dist, IdArray start_idx, IdArray result) {
|
||||
const FloatType* array_data = static_cast<FloatType*>(array->data);
|
||||
@@ -135,20 +135,20 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
|
||||
ret_start += sample_points;
|
||||
}
|
||||
}
|
||||
template void FarthestPointSampler<kDLCPU, float, int32_t>(
|
||||
template void FarthestPointSampler<kDGLCPU, float, int32_t>(
|
||||
NDArray array, int64_t batch_size, int64_t sample_points,
|
||||
NDArray dist, IdArray start_idx, IdArray result);
|
||||
template void FarthestPointSampler<kDLCPU, float, int64_t>(
|
||||
template void FarthestPointSampler<kDGLCPU, float, int64_t>(
|
||||
NDArray array, int64_t batch_size, int64_t sample_points,
|
||||
NDArray dist, IdArray start_idx, IdArray result);
|
||||
template void FarthestPointSampler<kDLCPU, double, int32_t>(
|
||||
template void FarthestPointSampler<kDGLCPU, double, int32_t>(
|
||||
NDArray array, int64_t batch_size, int64_t sample_points,
|
||||
NDArray dist, IdArray start_idx, IdArray result);
|
||||
template void FarthestPointSampler<kDLCPU, double, int64_t>(
|
||||
template void FarthestPointSampler<kDGLCPU, double, int64_t>(
|
||||
NDArray array, int64_t batch_size, int64_t sample_points,
|
||||
NDArray dist, IdArray start_idx, IdArray result);
|
||||
|
||||
template <DLDeviceType XPU, typename FloatType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename FloatType, typename IdType>
|
||||
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
|
||||
const int64_t num_nodes = result->shape[0];
|
||||
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -181,16 +181,16 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
|
||||
result_data[v_max] = result_data[u];
|
||||
}
|
||||
}
|
||||
template void WeightedNeighborMatching<kDLCPU, float, int32_t>(
|
||||
template void WeightedNeighborMatching<kDGLCPU, float, int32_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLCPU, float, int64_t>(
|
||||
template void WeightedNeighborMatching<kDGLCPU, float, int64_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLCPU, double, int32_t>(
|
||||
template void WeightedNeighborMatching<kDGLCPU, double, int32_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLCPU, double, int64_t>(
|
||||
template void WeightedNeighborMatching<kDGLCPU, double, int64_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
|
||||
const int64_t num_nodes = result->shape[0];
|
||||
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
|
||||
@@ -221,8 +221,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
|
||||
}
|
||||
}
|
||||
}
|
||||
template void NeighborMatching<kDLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDGLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDGLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace geometry
|
||||
|
||||
@@ -150,7 +150,7 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
|
||||
* are marked, mark this node with its id. Else match this (BLUE, RED) node
|
||||
* pair and mark them with the smaller id between them.
|
||||
*/
|
||||
template <DLDeviceType XPU, typename FloatType, typename IdType>
|
||||
template <DGLDeviceType XPU, typename FloatType, typename IdType>
|
||||
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
const auto& ctx = result->ctx;
|
||||
@@ -182,13 +182,13 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
|
||||
}
|
||||
device->FreeWorkspace(ctx, prop);
|
||||
}
|
||||
template void WeightedNeighborMatching<kDLGPU, float, int32_t>(
|
||||
template void WeightedNeighborMatching<kDGLCUDA, float, int32_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLGPU, float, int64_t>(
|
||||
template void WeightedNeighborMatching<kDGLCUDA, float, int64_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLGPU, double, int32_t>(
|
||||
template void WeightedNeighborMatching<kDGLCUDA, double, int32_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
|
||||
template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
|
||||
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
|
||||
|
||||
/*! \brief Unweighted neighbor matching procedure (GPU version).
|
||||
@@ -201,7 +201,7 @@ template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
|
||||
* 2. Graph is sparse, thus neighborhood of each node is small,
|
||||
* which is suitable for GPU implementation.
|
||||
*/
|
||||
template <DLDeviceType XPU, typename IdType>
|
||||
template <DGLDeviceType XPU, typename IdType>
|
||||
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
|
||||
const int64_t num_edges = csr.indices->shape[0];
|
||||
const auto& ctx = result->ctx;
|
||||
@@ -211,7 +211,7 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
|
||||
// generate random weights
|
||||
cudaStream_t stream = runtime::getCurrentCUDAStream();
|
||||
NDArray weight = NDArray::Empty(
|
||||
{num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
|
||||
{num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);
|
||||
float *weight_data = static_cast<float*>(weight->data);
|
||||
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
|
||||
auto num_threads = cuda::FindNumThreads(num_edges);
|
||||
@@ -221,8 +221,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
|
||||
|
||||
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
|
||||
}
|
||||
template void NeighborMatching<kDLGPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDLGPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDGLCUDA, int32_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
template void NeighborMatching<kDGLCUDA, int64_t>(const aten::CSRMatrix &csr, IdArray result);
|
||||
|
||||
} // namespace impl
|
||||
} // namespace geometry
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user