[Feature] Bump DLPack to v0.7 and decouple DLPack from the core library (#4454)

* rename `DLContext` to `DGLContext`

* rename `kDLGPU` to `kDLCUDA`

* replace DLTensor with DGLArray

* fix linting

* Unify DGLType and DLDataType to DGLDataType

* Fix FFI

* rename DLDeviceType to DGLDeviceType

* decouple dlpack from the core library

* fix bug

* fix lint

* fix merge

* fix build

* address comments

* rename dl_converter to dlpack_convert

* remove redundant comments
This commit is contained in:
Xin Yao
2022-09-19 16:02:43 +08:00
committed by GitHub
parent f1689ad0e1
commit cded5b80fe
171 changed files with 2213 additions and 2073 deletions

View File

@@ -24,8 +24,8 @@ namespace aten {
//////////////////////////////////////////////////////////////////////
/*! \return A special array to represent null. */
inline NDArray NullArray(const DLDataType& dtype = DLDataType{kDLInt, 64, 1},
const DLContext& ctx = DLContext{kDLCPU, 0}) {
inline NDArray NullArray(const DGLDataType& dtype = DGLDataType{kDGLInt, 64, 1},
const DGLContext& ctx = DGLContext{kDGLCPU, 0}) {
return NDArray::Empty({0}, dtype, ctx);
}
@@ -44,7 +44,7 @@ inline bool IsNullArray(NDArray array) {
* \return id array
*/
IdArray NewIdArray(int64_t length,
DLContext ctx = DLContext{kDLCPU, 0},
DGLContext ctx = DGLContext{kDGLCPU, 0},
uint8_t nbits = 64);
/*!
@@ -57,7 +57,7 @@ IdArray NewIdArray(int64_t length,
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits = 64,
DLContext ctx = DLContext{kDLCPU, 0});
DGLContext ctx = DGLContext{kDGLCPU, 0});
/*!
* \brief Return an array representing a 1D range.
@@ -67,7 +67,7 @@ IdArray VecToIdArray(const std::vector<T>& vec,
* \param ctx Device context
* \return range array
*/
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx);
/*!
* \brief Return an array full of the given value
@@ -77,7 +77,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx);
* \param ctx Device context
* \return the result array
*/
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx);
/*!
* \brief Return an array full of the given value with the given type.
@@ -87,7 +87,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx);
* \return the result array
*/
template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx);
NDArray Full(DType val, int64_t length, DGLContext ctx);
/*! \brief Create a deep copy of the given array */
IdArray Clone(IdArray arr);
@@ -226,7 +226,7 @@ NDArray Concat(const std::vector<IdArray>& arrays);
/*!\brief Return whether the array is a valid 1D int array*/
inline bool IsValidIdArray(const dgl::runtime::NDArray& arr) {
return arr->ndim == 1 && arr->dtype.code == kDLInt;
return arr->ndim == 1 && arr->dtype.code == kDGLInt;
}
/*!
@@ -343,8 +343,8 @@ std::string ToDebugString(NDArray array);
template <typename T>
IdArray VecToIdArray(const std::vector<T>& vec,
uint8_t nbits,
DLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DLContext{kDLCPU, 0}, nbits);
DGLContext ctx) {
IdArray ret = NewIdArray(vec.size(), DGLContext{kDGLCPU, 0}, nbits);
if (nbits == 32) {
std::copy(vec.begin(), vec.end(), static_cast<int32_t*>(ret->data));
} else if (nbits == 64) {
@@ -359,9 +359,9 @@ IdArray VecToIdArray(const std::vector<T>& vec,
* \brief Get the context of the first array, and check if the non-null arrays'
* contexts are the same.
*/
inline DLContext GetContextOf(const std::vector<IdArray>& arrays) {
inline DGLContext GetContextOf(const std::vector<IdArray>& arrays) {
bool first = true;
DLContext result;
DGLContext result;
for (auto& array : arrays) {
if (first) {
first = false;

View File

@@ -122,11 +122,10 @@ struct COOMatrix {
}
/*! \brief Return a copy of this matrix on the give device context. */
inline COOMatrix CopyTo(const DLContext &ctx) const {
inline COOMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == row->ctx)
return *this;
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx),
col.CopyTo(ctx),
return COOMatrix(num_rows, num_cols, row.CopyTo(ctx), col.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
row_sorted, col_sorted);
}
@@ -134,9 +133,9 @@ struct COOMatrix {
/*!
* \brief Pin the row, col and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {

View File

@@ -115,21 +115,19 @@ struct CSRMatrix {
}
/*! \brief Return a copy of this matrix on the give device context. */
inline CSRMatrix CopyTo(const DLContext &ctx) const {
inline CSRMatrix CopyTo(const DGLContext &ctx) const {
if (ctx == indptr->ctx)
return *this;
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx),
indices.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx),
sorted);
return CSRMatrix(num_rows, num_cols, indptr.CopyTo(ctx), indices.CopyTo(ctx),
aten::IsNullArray(data) ? data : data.CopyTo(ctx), sorted);
}
/*!
* \brief Pin the indptr, indices and data (if not Null) of the matrix.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDGLCUDA: invalid, will throw an error.
* The context check is deferred to pinning the NDArray.
*/
inline void PinMemory_() {

View File

@@ -18,8 +18,8 @@
* });
*/
#define ATEN_XPU_SWITCH(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
@@ -43,11 +43,11 @@
*/
#ifdef DGL_USE_CUDA
#define ATEN_XPU_SWITCH_CUDA(val, XPU, op, ...) do { \
if ((val) == kDLCPU) { \
constexpr auto XPU = kDLCPU; \
if ((val) == kDGLCPU) { \
constexpr auto XPU = kDGLCPU; \
{__VA_ARGS__} \
} else if ((val) == kDLGPU) { \
constexpr auto XPU = kDLGPU; \
} else if ((val) == kDGLCUDA) { \
constexpr auto XPU = kDGLCUDA; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << "Operator " << (op) << " does not support " \
@@ -69,7 +69,7 @@
* });
*/
#define ATEN_ID_TYPE_SWITCH(val, IdType, ...) do { \
CHECK_EQ((val).code, kDLInt) << "ID must be integer type"; \
CHECK_EQ((val).code, kDGLInt) << "ID must be integer type"; \
if ((val).bits == 32) { \
typedef int32_t IdType; \
{__VA_ARGS__} \
@@ -114,7 +114,7 @@
* });
*/
#define ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
CHECK_EQ((val).code, kDGLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
@@ -128,7 +128,7 @@
} while (0)
#define ATEN_FLOAT_BITS_SWITCH(val, bits, val_name, ...) do { \
CHECK_EQ((val).code, kDLFloat) \
CHECK_EQ((val).code, kDGLFloat) \
<< (val_name) << " must be float type"; \
if ((val).bits == 16) { \
constexpr int bits = 16; \
@@ -154,16 +154,16 @@
* });
*/
#define ATEN_DTYPE_SWITCH(val, DType, val_name, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 32) { \
} else if ((val).code == kDGLFloat && (val).bits == 32) { \
typedef float DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLFloat && (val).bits == 64) { \
} else if ((val).code == kDGLFloat && (val).bits == 64) { \
typedef double DType; \
{__VA_ARGS__} \
} else { \
@@ -205,10 +205,10 @@
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
*/
#define ATEN_CSR_DTYPE_SWITCH(val, DType, ...) do { \
if ((val).code == kDLInt && (val).bits == 32) { \
if ((val).code == kDGLInt && (val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).code == kDLInt && (val).bits == 64) { \
} else if ((val).code == kDGLInt && (val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
@@ -278,13 +278,13 @@
///////////////////////// Array checks //////////////////////////
#define IS_INT32(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 32)
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 32)
#define IS_INT64(a) \
((a)->dtype.code == kDLInt && (a)->dtype.bits == 64)
((a)->dtype.code == kDGLInt && (a)->dtype.bits == 64)
#define IS_FLOAT32(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 32)
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 32)
#define IS_FLOAT64(a) \
((a)->dtype.code == kDLFloat && (a)->dtype.bits == 64)
((a)->dtype.code == kDGLFloat && (a)->dtype.bits == 64)
#define CHECK_IF(cond, prop, value_name, dtype_name) \
CHECK(cond) << "Expecting " << (prop) << " of " << (value_name) << " to be " << (dtype_name)

View File

@@ -34,7 +34,7 @@ typedef NDArray TypeArray;
namespace aten {
static const DLContext CPU{kDLCPU, 0};
static const DGLContext CPU{kDGLCPU, 0};
} // namespace aten
} // namespace dgl

View File

@@ -104,12 +104,12 @@ class BaseHeteroGraph : public runtime::Object {
/*!
* \brief Get the data type of node and edge IDs of this graph.
*/
virtual DLDataType DataType() const = 0;
virtual DGLDataType DataType() const = 0;
/*!
* \brief Get the device context of this graph.
*/
virtual DLContext Context() const = 0;
virtual DGLContext Context() const = 0;
/*!
* \brief Pin graph.

View File

@@ -89,8 +89,8 @@ class Graph: public GraphInterface {
num_edges_ = 0;
}
DLContext Context() const override {
return DLContext{kDLCPU, 0};
DGLContext Context() const override {
return DGLContext{kDGLCPU, 0};
}
uint8_t NumBits() const override {

View File

@@ -137,7 +137,7 @@ class GraphInterface : public runtime::Object {
/*!
* \brief Get the device context of this graph.
*/
virtual DLContext Context() const = 0;
virtual DGLContext Context() const = 0;
/*!
* \brief Get the number of integer bits used to store node/edge ids (32 or 64).

View File

@@ -69,7 +69,7 @@ class CSR : public GraphInterface {
LOG(FATAL) << "CSR graph does not allow mutation.";
}
DLContext Context() const override {
DGLContext Context() const override {
return adj_.indptr->ctx;
}
@@ -214,7 +214,7 @@ class CSR : public GraphInterface {
* \param ctx The target context.
* \return The graph under another context.
*/
CSR CopyTo(const DLContext& ctx) const;
CSR CopyTo(const DGLContext& ctx) const;
/*!
* \brief Copy data to shared memory.
@@ -288,7 +288,7 @@ class COO : public GraphInterface {
LOG(FATAL) << "COO graph does not allow mutation.";
}
DLContext Context() const override {
DGLContext Context() const override {
return adj_.row->ctx;
}
@@ -472,7 +472,7 @@ class COO : public GraphInterface {
* \param ctx The target context.
* \return The graph under another context.
*/
COO CopyTo(const DLContext& ctx) const;
COO CopyTo(const DGLContext& ctx) const;
/*!
* \brief Copy data to shared memory.
@@ -578,7 +578,7 @@ class ImmutableGraph: public GraphInterface {
LOG(FATAL) << "Clear isn't supported in ImmutableGraph";
}
DLContext Context() const override {
DGLContext Context() const override {
return AnyGraph()->Context();
}
@@ -911,7 +911,7 @@ class ImmutableGraph: public GraphInterface {
* \param ctx The target context.
* \return The graph under another context.
*/
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DLContext& ctx);
static ImmutableGraphPtr CopyTo(ImmutableGraphPtr g, const DGLContext& ctx);
/*!
* \brief Copy data to shared memory.

View File

@@ -145,7 +145,7 @@ class RandomEngine {
*/
template <typename IdxType, typename FloatType>
IdArray Choice(IdxType num, FloatArray prob, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, prob->ctx);
Choice<IdxType, FloatType>(num, prob, static_cast<IdxType*>(ret->data), replace);
return ret;
@@ -178,9 +178,9 @@ class RandomEngine {
*/
template <typename IdxType>
IdArray UniformChoice(IdxType num, IdxType population, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
// TODO(minjie): only CPU implementation right now
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
UniformChoice<IdxType>(num, population, static_cast<IdxType*>(ret->data), replace);
return ret;
}
@@ -230,8 +230,8 @@ class RandomEngine {
template <typename IdxType, typename FloatType>
IdArray BiasedChoice(
IdxType num, const IdxType *split, FloatArray bias, bool replace = true) {
const DLDataType dtype{kDLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DLContext{kDLCPU, 0});
const DGLDataType dtype{kDGLInt, sizeof(IdxType) * 8, 1};
IdArray ret = IdArray::Empty({num}, dtype, DGLContext{kDGLCPU, 0});
BiasedChoice<IdxType, FloatType>(num, split, bias, static_cast<IdxType*>(ret->data), replace);
return ret;
}

View File

@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2016 by Contributors
* Copyright (c) 2016-2022 by Contributors
* \file dgl/runtime/c_runtime_api.h
* \brief DGL runtime library.
*
@@ -35,10 +35,6 @@
// DGL version
#define DGL_VERSION "0.9"
// DGL Runtime is DLPack compatible.
#include <dlpack/dlpack.h>
#ifdef __cplusplus
extern "C" {
#endif
@@ -48,28 +44,31 @@ extern "C" {
/*! \brief type of array index. */
typedef int64_t dgl_index_t;
/*! \brief Extension device types in DGL */
/*!
* \brief The device type in DGLContext.
*/
#ifdef __cplusplus
typedef enum : int32_t {
#else
typedef enum {
kDLAOCL = 5,
kDLSDAccel = 6,
kOpenGL = 11,
// Extension DRAM type, used for quickly test extension device
// The device api can differ depending on the xpu driver registered.
kExtDev = 12,
// AddExtraDGLType which is not in DLPack here
} DGLDeviceExtType;
#endif
/*! \brief CPU device */
kDGLCPU = 1,
/*! \brief CUDA GPU device */
kDGLCUDA = 2,
// add more devices once supported
} DGLDeviceType;
/*!
* \brief The type code in DGLType
* \note DGLType is used in two places.
* \brief The object type code is used in DGL FFI to indicate the types of objects passed between C and Python.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by DGL API calls.
kInt = 0U,
kUInt = 1U,
kFloat = 2U,
kHandle = 3U,
kNull = 4U,
kDGLType = 5U,
kDGLDataType = 5U,
kDGLContext = 6U,
kArrayHandle = 7U,
kObjectHandle = 8U,
@@ -88,29 +87,112 @@ typedef enum {
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U
} DGLTypeCode;
} DGLObjectTypeCode;
/*!
* \brief The data type used in DGL Runtime.
* \brief The type code options DGLDataType.
*/
typedef enum {
/*! \brief signed integer */
kDGLInt = 0U,
/*! \brief unsigned integer */
kDGLUInt = 1U,
/*! \brief IEEE floating point */
kDGLFloat = 2U,
/*! \brief bfloat16 */
kDGLBfloat = 4U,
// add more data types if we are going to support them
} DGLDataTypeCode;
/*!
* \brief The data type the tensor can hold. The data type is assumed to follow the
* native endian-ness. An explicit error message should be raised when attempting to
* export an array with non-native endianness
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments DGL API function always takes bits=64 and lanes=1
*/
typedef DLDataType DGLType;
typedef struct {
/*!
* \brief Type code of base types.
* We keep it uint8_t instead of DGLDataTypeCode for minimal memory
* footprint, but the value should be one of DGLDataTypeCode enum values.
* */
uint8_t code;
/*!
* \brief Number of bits, common choices are 8, 16, 32.
*/
uint8_t bits;
/*! \brief Number of lanes in the type, used for vector types. */
uint16_t lanes;
} DGLDataType;
/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext DGLContext;
typedef struct {
/*! \brief The device type used in the device. */
DGLDeviceType device_type;
/*!
* \brief The device index.
* For vanilla CPU memory, pinned memory, or managed memory, this is set to 0.
*/
int32_t device_id;
} DGLContext;
/*!
* \brief The tensor array stucture to DGL API.
* The structure is heavily inspired by DLTensor from DLPack.
*/
typedef DLTensor DGLArray;
typedef struct {
/*!
* \brief The data pointer points to the allocated data.
*
* Depending on the device context, it can be a CPU pointer, or a CUDA
* device pointer or acl_mem handle in OpenCL.
* This pointer is always aligned to 256 bytes as in CUDA. Use the
* `byte_offset` field to mark the beginning of the actual data (if the address
* is not 256 byte aligned).
*
* Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow,
* TVM, perhaps others) do not adhere to this 256 byte alignment requirement
* on CPU/CUDA/ROCm, and always use `byte_offset=0`. This is likely to be
* fixed in the future; at the moment it is recommended
* to not rely on the data pointer being correctly aligned.
*
* For a DGLArray, the size of memory required to store the contents of
* data can be calculated as follows:
*
* \code{.c}
* static inline size_t GetDataSize(const DGLArray* t) {
* size_t size = 1;
* for (int32_t i = 0; i < t->ndim; ++i) {
* size *= t->shape[i];
* }
* size *= (t->dtype.bits * t->dtype.lanes + 7) / 8;
* return size;
* }
* \endcode
*/
void* data;
/*! \brief The device of the tensor */
DGLContext ctx;
/*! \brief Number of dimensions */
int32_t ndim;
/*! \brief The data type of the pointer*/
DGLDataType dtype;
/*! \brief The shape of the tensor */
int64_t* shape;
/*!
* \brief strides of the tensor (in number of elements, not bytes)
* can be NULL, indicating tensor is compact and row-majored.
*/
int64_t* strides;
/*! \brief The offset in bytes to the beginning pointer to data */
uint64_t byte_offset;
} DGLArray;
/*! \brief the array handle */
typedef DGLArray* DGLArrayHandle;
@@ -124,7 +206,7 @@ typedef union {
double v_float64;
void* v_handle;
const char* v_str;
DGLType v_type;
DGLDataType v_type;
DGLContext v_ctx;
} DGLValue;
@@ -455,32 +537,6 @@ DGL_DLL int DGLArrayCopyToBytes(DGLArrayHandle handle,
DGL_DLL int DGLArrayCopyFromTo(DGLArrayHandle from,
DGLArrayHandle to);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Create a new runtime stream.
*
@@ -557,12 +613,12 @@ DGL_DLL int DGLLoadTensorAdapter(const char *path);
/*!
* \brief Pin host memory.
*/
int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);
int DGLArrayPinData(DGLArrayHandle handle, DGLContext ctx);
/*!
* \brief Unpin host memory.
*/
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);
int DGLArrayUnpinData(DGLArrayHandle handle, DGLContext ctx);
/*!
* \brief Record the stream that's using this tensor.

View File

@@ -75,7 +75,7 @@ class DeviceAPI {
virtual void* AllocDataSpace(DGLContext ctx,
size_t nbytes,
size_t alignment,
DGLType type_hint) = 0;
DGLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
@@ -101,7 +101,7 @@ class DeviceAPI {
size_t num_bytes,
DGLContext ctx_from,
DGLContext ctx_to,
DGLType type_hint) = 0;
DGLDataType type_hint) = 0;
/*!
* \brief Create a new stream of execution.
*
@@ -189,7 +189,7 @@ class DeviceAPI {
*/
DGL_DLL virtual void* AllocWorkspace(DGLContext ctx,
size_t nbytes,
DGLType type_hint = {});
DGLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
@@ -213,7 +213,7 @@ class DeviceAPI {
* \param allow_missing Whether allow missing
* \return The corresponding device API.
*/
DGL_DLL static DeviceAPI* Get(DLDeviceType dev_type, bool allow_missing = false);
DGL_DLL static DeviceAPI* Get(DGLDeviceType dev_type, bool allow_missing = false);
};
/*! \brief The device type bigger than this is RPC device */

View File

@@ -0,0 +1,85 @@
/*!
* Copyright (c) 2022 by Contributors
* \file include/dgl/runtime/dlpack_convert.h
* \brief Conversion between NDArray and DLPack.
*/
#ifndef DGL_RUNTIME_DLPACK_CONVERT_H_
#define DGL_RUNTIME_DLPACK_CONVERT_H_
#include "c_runtime_api.h"
#include "ndarray.h"
struct DLManagedTensor;
namespace dgl {
namespace runtime {
struct DLPackConvert {
/*!
* \brief Create a DGL NDArray from a DLPack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
* The memory is retained until the NDArray went out of scope.
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Deleter for NDArray converted from DLPack.
*
* This is used from data which is passed from external DLPack(DLManagedTensor)
* that are not allocated inside of DGL.
* This enables us to create NDArray from memory allocated by other
* frameworks that are DLPack compatible
*/
static void DLPackDeleter(NDArray::Container* ptr);
/*! \brief Convert a DGL NDArray to a DLPack tensor.
*
* \param from The DGL NDArray.
* \return A DLPack tensor.
*/
static DLManagedTensor* ToDLPack(const NDArray &from);
};
} // namespace runtime
} // namespace dgl
#ifdef __cplusplus
extern "C" {
#endif
/*!
* \brief Delete (free) a DLManagedTensor's data.
* \param dltensor Pointer to the DLManagedTensor.
*/
DGL_DLL void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor);
/*!
* \brief Produce an array from the DLManagedTensor that shares data memory
* with the DLManagedTensor.
* \param from The source DLManagedTensor.
* \param out The output array handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayFromDLPack(DLManagedTensor* from,
DGLArrayHandle* out);
/*!
* \brief Produce a DLMangedTensor from the array that shares data memory with
* the array.
* \param from The source array.
* \param out The DLManagedTensor handle.
* \return 0 when success, -1 when failure happens
*/
DGL_DLL int DGLArrayToDLPack(DGLArrayHandle from, DLManagedTensor** out,
int alignment = 0);
#ifdef __cplusplus
} // DGL_EXTERN_C
#endif
#endif // DGL_RUNTIME_DLPACK_CONVERT_H_

View File

@@ -1,5 +1,5 @@
/*!
* Copyright (c) 2017 by Contributors
* Copyright (c) 2017-2022 by Contributors
* \file dgl/runtime/ndarray.h
* \brief Abstract device memory management API
*/
@@ -14,7 +14,6 @@
#include <memory>
#include "c_runtime_api.h"
#include "dlpack/dlpack.h"
#include "serializer.h"
#include "shared_mem.h"
@@ -23,44 +22,49 @@
#endif
// forward declaration
inline std::ostream& operator << (std::ostream& os, DGLType t);
inline std::ostream& operator << (std::ostream& os, DGLDataType t);
namespace dgl {
/*!
* \brief Type traits that converts a C type to a DLDataType.
* \brief Type traits that converts a C type to a DGLDataType.
*
* Usage:
* DLDataTypeTraits<int>::dtype == dtype
* DGLDataTypeTraits<int>::dtype == dtype
*/
template<typename T>
struct DLDataTypeTraits {
static constexpr DLDataType dtype{0, 0, 0}; // dummy
struct DGLDataTypeTraits {
static constexpr DGLDataType dtype{0, 0, 0}; // dummy
};
#define GEN_DLDATATYPETRAITS_FOR(T, code, bits) \
#define GEN_DGLDATATYPETRAITS_FOR(T, code, bits) \
template<> \
struct DLDataTypeTraits<T> { \
static constexpr DLDataType dtype{code, bits, 1}; \
struct DGLDataTypeTraits<T> { \
static constexpr DGLDataType dtype{code, bits, 1}; \
}
GEN_DLDATATYPETRAITS_FOR(int8_t, kDLInt, 8);
GEN_DLDATATYPETRAITS_FOR(int16_t, kDLInt, 16);
GEN_DLDATATYPETRAITS_FOR(int32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(int64_t, kDLInt, 64);
GEN_DGLDATATYPETRAITS_FOR(int8_t, kDGLInt, 8);
GEN_DGLDATATYPETRAITS_FOR(int16_t, kDGLInt, 16);
GEN_DGLDATATYPETRAITS_FOR(int32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(int64_t, kDGLInt, 64);
// XXX(BarclayII) most DL frameworks do not support unsigned int and long arrays, so I'm just
// converting uints to signed DTypes.
GEN_DLDATATYPETRAITS_FOR(uint32_t, kDLInt, 32);
GEN_DLDATATYPETRAITS_FOR(uint64_t, kDLInt, 64);
GEN_DGLDATATYPETRAITS_FOR(uint32_t, kDGLInt, 32);
GEN_DGLDATATYPETRAITS_FOR(uint64_t, kDGLInt, 64);
#ifdef DGL_USE_CUDA
#ifdef USE_FP16
GEN_DLDATATYPETRAITS_FOR(__half, kDLFloat, 16);
GEN_DGLDATATYPETRAITS_FOR(__half, kDGLFloat, 16);
#endif
#endif
GEN_DLDATATYPETRAITS_FOR(float, kDLFloat, 32);
GEN_DLDATATYPETRAITS_FOR(double, kDLFloat, 64);
#undef GEN_DLDATATYPETRAITS_FOR
GEN_DGLDATATYPETRAITS_FOR(float, kDGLFloat, 32);
GEN_DGLDATATYPETRAITS_FOR(double, kDGLFloat, 64);
#undef GEN_DGLDATATYPETRAITS_FOR
namespace runtime {
/*!
* \brief DLPack converter.
*/
struct DLPackConvert;
/*!
* \brief Managed NDArray.
* The array is backed by reference counted blocks.
@@ -135,8 +139,8 @@ class NDArray {
* \note this number is approximate in multi-threaded setting.
*/
inline int use_count() const;
/*! \return Pointer to content of DLTensor */
inline const DLTensor* operator->() const;
/*! \return Pointer to content of DGLArray */
inline const DGLArray* operator->() const;
/*! \return True if the ndarray is contiguous. */
bool IsContiguous() const;
/*! \return the data pointer with type. */
@@ -152,9 +156,9 @@ class NDArray {
* \param other The source array to be copied from.
* \note The copy runs on the dgl internal stream if it involves a GPU context.
*/
inline void CopyFrom(DLTensor* other);
inline void CopyFrom(DGLArray* other);
inline void CopyFrom(const NDArray& other);
inline void CopyTo(DLTensor *other) const;
inline void CopyTo(DGLArray *other) const;
inline void CopyTo(const NDArray &other) const;
/*!
@@ -162,7 +166,7 @@ class NDArray {
* \param ctx The target context.
* \return The array under another context.
*/
inline NDArray CopyTo(const DLContext &ctx) const;
inline NDArray CopyTo(const DGLContext &ctx) const;
/*!
* \brief Return a new array with a copy of the content.
*/
@@ -171,9 +175,9 @@ class NDArray {
* \brief In-place method to pin the current array by calling PinContainer
* on the underlying NDArray:Container.
* \note This is an in-place method. Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDGLCUDA: invalid, will throw an error.
*/
inline void PinMemory_();
/*!
@@ -212,13 +216,7 @@ class NDArray {
* \note The memory size of new array must be smaller than the current one.
*/
DGL_DLL NDArray CreateView(
std::vector<int64_t> shape, DLDataType dtype, int64_t offset = 0);
/*!
* \brief Create a reference view of NDArray that
* represents as DLManagedTensor.
* \return A DLManagedTensor
*/
DGL_DLL DLManagedTensor* ToDLPack() const;
std::vector<int64_t> shape, DGLDataType dtype, int64_t offset = 0);
/*!
* \brief Create an empty NDArray.
* \param shape The shape of the new array.
@@ -227,8 +225,8 @@ class NDArray {
* \return The created Array
*/
DGL_DLL static NDArray Empty(std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx);
DGLDataType dtype,
DGLContext ctx);
/*!
* \brief Create an empty NDArray with shared memory.
* \param name The name of shared memory.
@@ -240,8 +238,8 @@ class NDArray {
*/
DGL_DLL static NDArray EmptyShared(const std::string &name,
std::vector<int64_t> shape,
DLDataType dtype,
DLContext ctx,
DGLDataType dtype,
DGLContext ctx,
bool is_create);
/*!
* \brief Get the size of the array in the number of bytes.
@@ -253,26 +251,19 @@ class NDArray {
*/
int64_t NumElements() const;
/*!
* \brief Create a NDArray backed by a dlpack tensor.
*
* This allows us to create a NDArray using the memory
* allocated by an external deep learning framework
* that is DLPack compatible.
*
* The memory is retained until the NDArray went out of scope.
* \param tensor The DLPack tensor to copy from.
* \return The created NDArray view.
*/
DGL_DLL static NDArray FromDLPack(DLManagedTensor* tensor);
/*!
* \brief Create a NDArray by copying from std::vector.
* \tparam T Type of vector data. Determines the dtype of returned array.
*/
template<typename T>
DGL_DLL static NDArray FromVector(
const std::vector<T>& vec, DLContext ctx = DLContext{kDLCPU, 0});
const std::vector<T>& vec, DGLContext ctx = DGLContext{kDGLCPU, 0});
/*!
* \brief Create a NDArray from a raw pointer.
*/
DGL_DLL static NDArray CreateFromRaw(const std::vector<int64_t>& shape,
DGLDataType dtype, DGLContext ctx, void* raw, bool auto_free);
/*!
* \brief Create a std::vector from a 1D NDArray.
@@ -292,23 +283,23 @@ class NDArray {
* \param (optional) stream The stream used in copy.
*/
DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to);
DGLArray* from, DGLArray* to);
DGL_DLL static void CopyFromTo(
DLTensor* from, DLTensor* to, DGLStreamHandle stream);
DGLArray* from, DGLArray* to, DGLStreamHandle stream);
/*!
* \brief Function to pin the DLTensor of a Container.
* \brief Function to pin the DGLArray of a Container.
* \param ptr The container to be pinned.
* \note Data of the given array will be pinned inplace.
* Behavior depends on the current context,
* kDLCPU: will be pinned;
* kDGLCPU: will be pinned;
* IsPinned: directly return;
* kDLGPU: invalid, will throw an error.
* kDGLCUDA: invalid, will throw an error.
*/
DGL_DLL static void PinContainer(Container* ptr);
/*!
* \brief Function to unpin the DLTensor of a Container.
* \brief Function to unpin the DGLArray of a Container.
* \param ptr The container to be unpinned.
* \note Data of the given array will be unpinned inplace.
* Behavior depends on the current context,
@@ -318,7 +309,7 @@ class NDArray {
DGL_DLL static void UnpinContainer(Container* ptr);
/*!
* \brief Function check if the DLTensor of a Container is pinned.
* \brief Function check if the DGLArray of a Container is pinned.
* \param ptr The container to be checked.
* \return true if pinned.
*/
@@ -332,45 +323,57 @@ class NDArray {
DGL_DLL static void RecordStream(DGLArray* tensor, DGLStreamHandle stream);
// internal namespace
struct Internal;
struct Internal {
// Default deleter for the container
static void DefaultDeleter(NDArray::Container* ptr);
// Local create function which allocates tensor metadata
// but does not allocate space for the data.
static NDArray Create(std::vector<int64_t> shape,
DGLDataType dtype, DGLContext ctx);
// Implementation of API function
static DGLArray* MoveAsDGLArray(NDArray arr);
};
private:
/*! \brief Internal Data content */
Container* data_{nullptr};
// enable internal functions
friend struct Internal;
friend struct DLPackConvert;
friend class DGLRetValue;
friend class DGLArgsSetter;
};
/*!
* \brief Save a DLTensor to stream
* \brief Save a DGLArray to stream
* \param strm The outpu stream
* \param tensor The tensor to be saved.
*/
inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor);
inline bool SaveDGLArray(dmlc::Stream* strm, const DGLArray* tensor);
/*!
* \brief Reference counted Container object used to back NDArray.
*
* This object is DLTensor compatible:
* This object is DGLArray compatible:
* the pointer to the NDArrayContainer can be directly
* interpreted as a DLTensor*
* interpreted as a DGLArray*
*
* \note: do not use this function directly, use NDArray.
*/
struct NDArray::Container {
public:
// NOTE: the first part of this structure is the same as
// DLManagedTensor, note that, however, the deleter
// is only called when the reference counter goes to 0
/*!
* \brief The corresponding dl_tensor field.
* \note it is important that the first field is DLTensor
* So that this data structure is DLTensor compatible.
* The head ptr of this struct can be viewed as DLTensor*.
/*! NOTE: the first part of this structure is the same as
* DLManagedTensor, note that, however, the deleter
* is only called when the reference counter goes to 0
*/
DLTensor dl_tensor;
/*!
* \brief Tensor structure.
* \note it is important that the first field is DGLArray
* So that this data structure is DGLArray compatible.
* The head ptr of this struct can be viewed as DGLArray*.
*/
DGLArray dl_tensor;
/*!
* \brief addtional context, reserved for recycling
* \note We can attach additional content here
@@ -411,6 +414,7 @@ struct NDArray::Container {
}
private:
friend struct DLPackConvert;
friend class NDArray;
friend class RPCWrappedFunc;
/*!
@@ -450,7 +454,7 @@ inline void NDArray::reset() {
}
}
inline void NDArray::CopyFrom(DLTensor* other) {
inline void NDArray::CopyFrom(DGLArray* other) {
CHECK(data_ != nullptr);
CopyFromTo(other, &(data_->dl_tensor));
}
@@ -460,7 +464,7 @@ inline void NDArray::CopyFrom(const NDArray& other) {
CopyFrom(&(other.data_->dl_tensor));
}
inline void NDArray::CopyTo(DLTensor *other) const {
inline void NDArray::CopyTo(DGLArray *other) const {
CHECK(data_ != nullptr);
CopyFromTo(&(data_->dl_tensor), other);
}
@@ -470,9 +474,9 @@ inline void NDArray::CopyTo(const NDArray &other) const {
CopyTo(&(other.data_->dl_tensor));
}
inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
inline NDArray NDArray::CopyTo(const DGLContext &ctx) const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
const DGLArray* dptr = operator->();
NDArray ret = Empty(std::vector<int64_t>(dptr->shape, dptr->shape + dptr->ndim),
dptr->dtype, ctx);
this->CopyTo(ret);
@@ -481,7 +485,7 @@ inline NDArray NDArray::CopyTo(const DLContext &ctx) const {
inline NDArray NDArray::Clone() const {
CHECK(data_ != nullptr);
const DLTensor* dptr = operator->();
const DGLArray* dptr = operator->();
return this->CopyTo(dptr->ctx);
}
@@ -510,15 +514,15 @@ inline int NDArray::use_count() const {
return data_->ref_counter_.load(std::memory_order_relaxed);
}
inline const DLTensor* NDArray::operator->() const {
inline const DGLArray* NDArray::operator->() const {
return &(data_->dl_tensor);
}
/*! \brief Magic number for NDArray file */
constexpr uint64_t kDGLNDArrayMagic = 0xDD5E40F096B4A13F;
inline bool SaveDLTensor(dmlc::Stream* strm,
DLTensor* tensor) {
inline bool SaveDGLArray(dmlc::Stream* strm,
DGLArray* tensor) {
uint64_t header = kDGLNDArrayMagic, reserved = 0;
strm->Write(header);
strm->Write(reserved);
@@ -531,8 +535,8 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
//
// We can always do array.CopyTo(target_ctx) to get a corresponding
// array in the target context.
DLContext cpu_ctx;
cpu_ctx.device_type = kDLCPU;
DGLContext cpu_ctx;
cpu_ctx.device_type = kDGLCPU;
cpu_ctx.device_id = 0;
strm->Write(cpu_ctx);
strm->Write(tensor->ndim);
@@ -548,7 +552,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
strm->Write(data_byte_size);
if (DMLC_IO_NO_ENDIAN_SWAP &&
tensor->ctx.device_type == kDLCPU &&
tensor->ctx.device_type == kDGLCPU &&
tensor->strides == nullptr &&
tensor->byte_offset == 0) {
// quick path
@@ -573,16 +577,16 @@ inline bool SaveDLTensor(dmlc::Stream* strm,
*/
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt: return "int";
case kDLUInt: return "uint";
case kDLFloat: return "float";
case kDGLInt: return "int";
case kDGLUInt: return "uint";
case kDGLFloat: return "float";
case kStr: return "str";
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kObjectHandle: return "ObjectHandle";
case kArrayHandle: return "ArrayHandle";
case kDGLType: return "DGLType";
case kDGLDataType: return "DGLDataType";
case kDGLContext: return "DGLContext";
case kFuncHandle: return "FunctionHandle";
case kModuleHandle: return "ModuleHandle";
@@ -597,17 +601,11 @@ inline const char* TypeCode2Str(int type_code) {
* \param device_type The device type code.
* \return The name of the device.
*/
inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
inline const char* DeviceTypeCode2Str(DGLDeviceType device_type) {
switch (device_type) {
case kDLCPU: return "cpu";
case kDLGPU: return "cuda";
case kDLCPUPinned: return "cpu_pinned";
case kDLOpenCL: return "opencl";
case kDLVulkan: return "vulkan";
case kDLMetal: return "metal";
case kDLVPI: return "vpi";
case kDLROCM: return "rocm";
default: LOG(FATAL) << "Unknown device type code="
case kDGLCPU: return "cpu";
case kDGLCUDA: return "cuda";
default: LOG(FATAL) << "Unsupported device type code="
<< static_cast<int>(device_type); return "";
}
}
@@ -617,16 +615,16 @@ inline const char* DeviceTypeCode2Str(DLDeviceType device_type) {
* \param s The string to be converted.
* \return The corresponding dgl type.
*/
inline DGLType String2DGLType(std::string s) {
DGLType t;
inline DGLDataType String2DGLDataType(std::string s) {
DGLDataType t;
t.bits = 32; t.lanes = 1;
const char* scan;
if (s.substr(0, 3) == "int") {
t.code = kDLInt; scan = s.c_str() + 3;
t.code = kDGLInt; scan = s.c_str() + 3;
} else if (s.substr(0, 4) == "uint") {
t.code = kDLUInt; scan = s.c_str() + 4;
t.code = kDGLUInt; scan = s.c_str() + 4;
} else if (s.substr(0, 5) == "float") {
t.code = kDLFloat; scan = s.c_str() + 5;
t.code = kDGLFloat; scan = s.c_str() + 5;
} else if (s.substr(0, 6) == "handle") {
t.code = kHandle;
t.bits = 64; // handle uses 64 bit by default.
@@ -649,7 +647,7 @@ inline DGLType String2DGLType(std::string s) {
* \param t The type to be converted.
* \return The corresponding dgl type in string.
*/
inline std::string DGLType2String(DGLType t) {
inline std::string DGLDataType2String(DGLDataType t) {
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
std::ostringstream os;
os << t;
@@ -728,20 +726,20 @@ dgl::runtime::NDArray operator != (int64_t lhs, const dgl::runtime::NDArray& a2)
std::ostream& operator << (std::ostream& os, dgl::runtime::NDArray array);
///////////////// Operator overloading for DLDataType /////////////////
///////////////// Operator overloading for DGLDataType /////////////////
/*! \brief Check whether two data types are the same.*/
inline bool operator == (const DLDataType& ty1, const DLDataType& ty2) {
inline bool operator == (const DGLDataType& ty1, const DGLDataType& ty2) {
return ty1.code == ty2.code && ty1.bits == ty2.bits && ty1.lanes == ty2.lanes;
}
/*! \brief Check whether two data types are different.*/
inline bool operator != (const DLDataType& ty1, const DLDataType& ty2) {
inline bool operator != (const DGLDataType& ty1, const DGLDataType& ty2) {
return !(ty1 == ty2);
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, DGLType t) {
inline std::ostream& operator << (std::ostream& os, DGLDataType t) {
os << dgl::runtime::TypeCode2Str(t.code);
if (t.code == kHandle) return os;
os << static_cast<int>(t.bits);
@@ -752,20 +750,20 @@ inline std::ostream& operator << (std::ostream& os, DGLType t) {
}
#endif
///////////////// Operator overloading for DLContext /////////////////
///////////////// Operator overloading for DGLContext /////////////////
/*! \brief Check whether two device contexts are the same.*/
inline bool operator == (const DLContext& ctx1, const DLContext& ctx2) {
inline bool operator == (const DGLContext& ctx1, const DGLContext& ctx2) {
return ctx1.device_type == ctx2.device_type && ctx1.device_id == ctx2.device_id;
}
/*! \brief Check whether two device contexts are different.*/
inline bool operator != (const DLContext& ctx1, const DLContext& ctx2) {
inline bool operator != (const DGLContext& ctx1, const DGLContext& ctx2) {
return !(ctx1 == ctx2);
}
#ifndef _LIBCPP_SGX_NO_IOSTREAMS
inline std::ostream& operator << (std::ostream& os, const DLContext& ctx) {
inline std::ostream& operator << (std::ostream& os, const DGLContext& ctx) {
return os << dgl::runtime::DeviceTypeCode2Str(ctx.device_type) << ":" << ctx.device_id;
}
#endif

View File

@@ -350,28 +350,28 @@ class DGLPODValue_ {
// Allow automatic conversion from int to float
// This avoids errors when user pass in int from
// the frontend while the API expects a float.
if (type_code_ == kDLInt) {
if (type_code_ == kDGLInt) {
return static_cast<double>(value_.v_int64);
}
DGL_CHECK_TYPE_CODE(type_code_, kDLFloat);
DGL_CHECK_TYPE_CODE(type_code_, kDGLFloat);
return value_.v_float64;
}
operator int64_t() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64;
}
operator uint64_t() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64;
}
operator int() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
CHECK_LE(value_.v_int64,
std::numeric_limits<int>::max());
return static_cast<int>(value_.v_int64);
}
operator bool() const {
DGL_CHECK_TYPE_CODE(type_code_, kDLInt);
DGL_CHECK_TYPE_CODE(type_code_, kDGLInt);
return value_.v_int64 != 0;
}
operator void*() const {
@@ -380,14 +380,14 @@ class DGLPODValue_ {
DGL_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle;
}
operator DLTensor*() const {
operator DGLArray*() const {
if (type_code_ == kArrayHandle ||
type_code_ == kNDArrayContainer) {
return static_cast<DLTensor*>(value_.v_handle);
return static_cast<DGLArray*>(value_.v_handle);
} else {
if (type_code_ == kNull) return nullptr;
LOG(FATAL) << "Expected "
<< "DLTensor* or NDArray but get "
<< "DGLArray* or NDArray but get "
<< TypeCode2Str(type_code_);
return nullptr;
}
@@ -457,14 +457,14 @@ class DGLArgValue : public DGLPODValue_ {
using DGLPODValue_::operator int;
using DGLPODValue_::operator bool;
using DGLPODValue_::operator void*;
using DGLPODValue_::operator DLTensor*;
using DGLPODValue_::operator DGLArray*;
using DGLPODValue_::operator NDArray;
using DGLPODValue_::operator DGLContext;
// conversion operator.
operator std::string() const {
if (type_code_ == kDGLType) {
return DGLType2String(operator DGLType());
if (type_code_ == kDGLDataType) {
return DGLDataType2String(operator DGLDataType());
} else if (type_code_ == kBytes) {
DGLByteArray* arr = static_cast<DGLByteArray*>(value_.v_handle);
return std::string(arr->data, arr->size);
@@ -473,11 +473,11 @@ class DGLArgValue : public DGLPODValue_ {
return std::string(value_.v_str);
}
}
operator DGLType() const {
operator DGLDataType() const {
if (type_code_ == kStr) {
return String2DGLType(operator std::string());
return String2DGLDataType(operator std::string());
}
DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
return value_.v_type;
}
operator PackedFunc() const {
@@ -549,7 +549,7 @@ class DGLRetValue : public DGLPODValue_ {
using DGLPODValue_::operator int;
using DGLPODValue_::operator bool;
using DGLPODValue_::operator void*;
using DGLPODValue_::operator DLTensor*;
using DGLPODValue_::operator DGLArray*;
using DGLPODValue_::operator DGLContext;
using DGLPODValue_::operator NDArray;
// Disable copy and assign from another value, but allow move.
@@ -558,19 +558,19 @@ class DGLRetValue : public DGLPODValue_ {
}
// conversion operators
operator std::string() const {
if (type_code_ == kDGLType) {
return DGLType2String(operator DGLType());
if (type_code_ == kDGLDataType) {
return DGLDataType2String(operator DGLDataType());
} else if (type_code_ == kBytes) {
return *ptr<std::string>();
}
DGL_CHECK_TYPE_CODE(type_code_, kStr);
return *ptr<std::string>();
}
operator DGLType() const {
operator DGLDataType() const {
if (type_code_ == kStr) {
return String2DGLType(operator std::string());
return String2DGLDataType(operator std::string());
}
DGL_CHECK_TYPE_CODE(type_code_, kDGLType);
DGL_CHECK_TYPE_CODE(type_code_, kDGLDataType);
return value_.v_type;
}
operator PackedFunc() const {
@@ -595,7 +595,7 @@ class DGLRetValue : public DGLPODValue_ {
return *this;
}
DGLRetValue& operator=(double value) {
this->SwitchToPOD(kDLFloat);
this->SwitchToPOD(kDGLFloat);
value_.v_float64 = value;
return *this;
}
@@ -610,17 +610,17 @@ class DGLRetValue : public DGLPODValue_ {
return *this;
}
DGLRetValue& operator=(int64_t value) {
this->SwitchToPOD(kDLInt);
this->SwitchToPOD(kDGLInt);
value_.v_int64 = value;
return *this;
}
DGLRetValue& operator=(int value) {
this->SwitchToPOD(kDLInt);
this->SwitchToPOD(kDGLInt);
value_.v_int64 = value;
return *this;
}
DGLRetValue& operator=(DGLType t) {
this->SwitchToPOD(kDGLType);
DGLRetValue& operator=(DGLDataType t) {
this->SwitchToPOD(kDGLDataType);
value_.v_type = t;
return *this;
}
@@ -630,7 +630,7 @@ class DGLRetValue : public DGLPODValue_ {
return *this;
}
DGLRetValue& operator=(bool value) {
this->SwitchToPOD(kDLInt);
this->SwitchToPOD(kDGLInt);
value_.v_int64 = value;
return *this;
}
@@ -859,17 +859,17 @@ class DGLArgsSetter {
std::is_integral<T>::value>::type>
void operator()(size_t i, T value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
type_codes_[i] = kDLInt;
type_codes_[i] = kDGLInt;
}
void operator()(size_t i, uint64_t value) const {
values_[i].v_int64 = static_cast<int64_t>(value);
CHECK_LE(value,
static_cast<uint64_t>(std::numeric_limits<int64_t>::max()));
type_codes_[i] = kDLInt;
type_codes_[i] = kDGLInt;
}
void operator()(size_t i, double value) const {
values_[i].v_float64 = value;
type_codes_[i] = kDLFloat;
type_codes_[i] = kDGLFloat;
}
void operator()(size_t i, std::nullptr_t value) const {
values_[i].v_handle = value;
@@ -883,7 +883,7 @@ class DGLArgsSetter {
values_[i].v_handle = value;
type_codes_[i] = kHandle;
}
void operator()(size_t i, DLTensor* value) const {
void operator()(size_t i, DGLArray* value) const {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
@@ -891,9 +891,9 @@ class DGLArgsSetter {
values_[i].v_ctx = value;
type_codes_[i] = kDGLContext;
}
void operator()(size_t i, DGLType value) const {
void operator()(size_t i, DGLDataType value) const {
values_[i].v_type = value;
type_codes_[i] = kDGLType;
type_codes_[i] = kDGLDataType;
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext
* Include this file to enable serialization of DGLDataType, DGLContext
*/
#ifndef DGL_RUNTIME_SERIALIZER_H_
#define DGL_RUNTIME_SERIALIZER_H_
@@ -16,13 +16,13 @@ namespace dmlc {
namespace serializer {
template <>
struct Handler<DLDataType> {
inline static void Write(Stream *strm, const DLDataType &dtype) {
struct Handler<DGLDataType> {
inline static void Write(Stream *strm, const DGLDataType &dtype) {
Handler<uint8_t>::Write(strm, dtype.code);
Handler<uint8_t>::Write(strm, dtype.bits);
Handler<uint16_t>::Write(strm, dtype.lanes);
}
inline static bool Read(Stream *strm, DLDataType *dtype) {
inline static bool Read(Stream *strm, DGLDataType *dtype) {
if (!Handler<uint8_t>::Read(strm, &(dtype->code))) return false;
if (!Handler<uint8_t>::Read(strm, &(dtype->bits))) return false;
if (!Handler<uint16_t>::Read(strm, &(dtype->lanes))) return false;
@@ -31,16 +31,16 @@ struct Handler<DLDataType> {
};
template <>
struct Handler<DLContext> {
inline static void Write(Stream *strm, const DLContext &ctx) {
struct Handler<DGLContext> {
inline static void Write(Stream *strm, const DGLContext &ctx) {
int32_t device_type = static_cast<int32_t>(ctx.device_type);
Handler<int32_t>::Write(strm, device_type);
Handler<int32_t>::Write(strm, ctx.device_id);
}
inline static bool Read(Stream *strm, DLContext *ctx) {
inline static bool Read(Stream *strm, DGLContext *ctx) {
int32_t device_type = 0;
if (!Handler<int32_t>::Read(strm, &(device_type))) return false;
ctx->device_type = static_cast<DLDeviceType>(device_type);
ctx->device_type = static_cast<DGLDeviceType>(device_type);
if (!Handler<int32_t>::Read(strm, &(ctx->device_id))) return false;
return true;
}

View File

@@ -2,7 +2,7 @@
* Copyright (c) 2017 by Contributors
* \file dgl/runtime/serializer.h
* \brief Serializer extension to support DGL data types
* Include this file to enable serialization of DLDataType, DLContext
* Include this file to enable serialization of DGLDataType, DGLContext
*/
#ifndef DGL_RUNTIME_SMART_PTR_SERIALIZER_H_
#define DGL_RUNTIME_SMART_PTR_SERIALIZER_H_

View File

@@ -18,7 +18,7 @@ namespace runtime {
* \param bits The number of bits to be matched.
* \param lanes The number of lanes sin the type.
*/
inline bool TypeMatch(DGLType t, int code, int bits, int lanes = 1) {
inline bool TypeMatch(DGLDataType t, int code, int bits, int lanes = 1) {
return t.code == code && t.bits == bits && t.lanes == lanes;
}
} // namespace runtime

View File

@@ -10,7 +10,7 @@ from numbers import Number, Integral
from ..base import _LIB, check_call
from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLByteArray, DGLContext
from ..runtime_ctypes import DGLDataType, DGLByteArray, DGLContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
from .types import DGLValue, TypeCode
@@ -115,7 +115,7 @@ def _make_dgl_args(args, temp_args):
elif isinstance(arg, Number):
values[i].v_float64 = arg
type_codes[i] = TypeCode.FLOAT
elif isinstance(arg, DGLType):
elif isinstance(arg, DGLDataType):
values[i].v_str = c_str(str(arg))
type_codes[i] = TypeCode.STR
elif isinstance(arg, DGLContext):

View File

@@ -4,7 +4,7 @@ from __future__ import absolute_import as _abs
import ctypes
from ..base import py_str, check_call, _LIB
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLType, DGLContext
from ..runtime_ctypes import DGLByteArray, TypeCode, DGLDataType, DGLContext
class DGLValue(ctypes.Union):
"""DGLValue in C API"""
@@ -12,7 +12,7 @@ class DGLValue(ctypes.Union):
("v_float64", ctypes.c_double),
("v_handle", ctypes.c_void_p),
("v_str", ctypes.c_char_p),
("v_type", DGLType),
("v_type", DGLDataType),
("v_ctx", DGLContext)]

View File

@@ -3,16 +3,16 @@ from libcpp.vector cimport vector
from libcpp cimport bool
from cpython.version cimport PY_MAJOR_VERSION
from cpython cimport pycapsule
from libc.stdint cimport int64_t, uint64_t, uint8_t, uint16_t
from libc.stdint cimport int32_t, int64_t, uint64_t, uint8_t, uint16_t
import ctypes
cdef enum DGLTypeCode:
cdef enum DGLObjectTypeCode:
kInt = 0
kUInt = 1
kFloat = 2
kHandle = 3
kNull = 4
kDGLType = 5
kDGLDataType = 5
kDGLContext = 6
kArrayHandle = 7
kObjectHandle = 8
@@ -24,26 +24,26 @@ cdef enum DGLTypeCode:
kExtBegin = 15
cdef extern from "dgl/runtime/c_runtime_api.h":
ctypedef struct DLDataType:
ctypedef struct DGLDataType:
uint8_t code
uint8_t bits
uint16_t lanes
ctypedef struct DLContext:
int device_type
int device_id
ctypedef struct DGLContext:
int32_t device_type
int32_t device_id
ctypedef struct DLTensor:
ctypedef struct DGLArray:
void* data
DLContext ctx
int ndim
DLDataType dtype
DGLContext ctx
int32_t ndim
DGLDataType dtype
int64_t* shape
int64_t* strides
uint64_t byte_offset
ctypedef struct DLManagedTensor:
DLTensor dl_tensor
DGLArray dl_tensor
void* manager_ctx
void (*deleter)(DLManagedTensor* self)
@@ -52,13 +52,11 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
double v_float64
void* v_handle
const char* v_str
DLDataType v_type
DLContext v_ctx
DGLDataType v_type
DGLContext v_ctx
ctypedef int64_t dgl_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef DLTensor DGLArray
ctypedef DGLArray* CDGLArrayHandle
ctypedef DGLArray* DGLArrayHandle
ctypedef void* DGLStreamHandle
ctypedef void* DGLRetValueHandle
ctypedef void* DGLFunctionHandle
@@ -94,9 +92,9 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int DGLCbArgToReturn(DGLValue* value, int code)
int DGLArrayAlloc(dgl_index_t* shape,
dgl_index_t ndim,
DLDataType dtype,
DLContext ctx,
DLTensorHandle* out)
DGLDataType dtype,
DGLContext ctx,
DGLArrayHandle* out)
int DGLArrayAllocSharedMem(const char *mem_name,
const dgl_index_t *shape,
int ndim,
@@ -104,16 +102,10 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int dtype_bits,
int dtype_lanes,
bool is_create,
CDGLArrayHandle* out)
int DGLArrayFree(DLTensorHandle handle)
int DGLArrayCopyFromTo(DLTensorHandle src,
DLTensorHandle to)
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DLTensorHandle* out)
int DGLArrayToDLPack(DLTensorHandle arr_from,
DLManagedTensor** out,
int alignment)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
DGLArrayHandle* out)
int DGLArrayFree(DGLArrayHandle handle)
int DGLArrayCopyFromTo(DGLArrayHandle src,
DGLArrayHandle to)
cdef extern from "dgl/runtime/c_object_api.h":
int DGLObjectFree(ObjectHandle handle)
@@ -127,6 +119,14 @@ cdef extern from "dgl/runtime/c_object_api.h":
int* out_type_code,
int* out_success)
cdef extern from "dgl/runtime/dlpack_convert.h":
int DGLArrayFromDLPack(DLManagedTensor* arr_from,
DGLArrayHandle* out)
int DGLArrayToDLPack(DGLArrayHandle arr_from,
DLManagedTensor** out,
int alignment)
void DGLDLManagedTensorCallDeleter(DLManagedTensor* dltensor)
cdef inline py_str(const char* x):
if PY_MAJOR_VERSION < 3:
return x

View File

@@ -4,7 +4,9 @@ from cpython cimport Py_INCREF, Py_DECREF
from numbers import Number, Integral
from ..base import string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import DGLType, DGLContext, DGLByteArray
from ..runtime_ctypes import DGLDataType as CTypesDGLDataType, \
DGLContext as CTypesDGLContext, \
DGLByteArray
cdef void dgl_callback_finalize(void* fhandle):
@@ -107,13 +109,13 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, Number):
value[0].v_float64 = arg
tcode[0] = kFloat
elif isinstance(arg, DGLType):
elif isinstance(arg, CTypesDGLDataType):
tstr = c_str(str(arg))
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
elif isinstance(arg, DGLContext):
value[0].v_ctx = (<DLContext*>(
elif isinstance(arg, CTypesDGLContext):
value[0].v_ctx = (<DGLContext*>(
<unsigned long long>ctypes.addressof(arg)))[0]
tcode[0] = kDGLContext
elif isinstance(arg, bytearray):
@@ -183,7 +185,7 @@ cdef inline object make_ret(DGLValue value, int tcode):
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kDGLContext:
return DGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
return CTypesDGLContext(value.v_ctx.device_type, value.v_ctx.device_id)
# (minjie): class module are not used in DGL.
#elif tcode == kModuleHandle:
# return _CLASS_MODULE(ctypes_handle(value.v_handle))

View File

@@ -1,4 +1,4 @@
from ..runtime_ctypes import DGLArrayHandle
from ..runtime_ctypes import DGLArrayHandle as PyDGLArrayHandle
cdef const char* _c_str_dltensor = "dltensor"
cdef const char* _c_str_used_dltensor = "used_dltensor"
@@ -13,7 +13,7 @@ cdef void _c_dlpack_deleter(object pycaps):
def _from_dlpack(object dltensor):
cdef DLManagedTensor* ptr
cdef DLTensorHandle chandle
cdef DGLArrayHandle chandle
if pycapsule.PyCapsule_IsValid(dltensor, _c_str_dltensor):
ptr = <DLManagedTensor*>pycapsule.PyCapsule_GetPointer(dltensor, _c_str_dltensor)
CALL(DGLArrayFromDLPack(ptr, &chandle))
@@ -25,7 +25,7 @@ def _from_dlpack(object dltensor):
cdef class NDArrayBase:
cdef DLTensor* chandle
cdef DGLArray* chandle
cdef int c_is_view
cdef inline _set_handle(self, handle):
@@ -34,7 +34,7 @@ cdef class NDArrayBase:
self.chandle = NULL
else:
ptr = ctypes.cast(handle, ctypes.c_void_p).value
self.chandle = <DLTensor*>(ptr)
self.chandle = <DGLArray*>(ptr)
property _dgl_handle:
def __get__(self):
@@ -46,7 +46,7 @@ cdef class NDArrayBase:
return None
else:
return ctypes.cast(
<unsigned long long>self.chandle, DGLArrayHandle)
<unsigned long long>self.chandle, PyDGLArrayHandle)
def __set__(self, value):
self._set_handle(value)
@@ -82,7 +82,7 @@ cdef class NDArrayBase:
cdef c_make_array(void* chandle, is_view):
ret = _CLASS_NDARRAY(None, is_view)
(<NDArrayBase>ret).chandle = <DLTensor*>chandle
(<NDArrayBase>ret).chandle = <DGLArray*>chandle
return ret

View File

@@ -6,7 +6,7 @@ import sys
import ctypes
import numpy as np
from .base import _LIB, check_call, c_array, string_types, _FFI_MODE, c_str
from .runtime_ctypes import DGLType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import DGLDataType, DGLContext, DGLArray, DGLArrayHandle
from .runtime_ctypes import TypeCode, dgl_shape_index_t
@@ -72,7 +72,7 @@ def numpyasarray(np_data):
arr.data = data.ctypes.data_as(ctypes.c_void_p)
arr.shape = shape
arr.strides = None
arr.dtype = DGLType(np.dtype(data.dtype).name)
arr.dtype = DGLDataType(np.dtype(data.dtype).name)
arr.ndim = data.ndim
# CPU device
arr.ctx = context(1, 0)
@@ -101,7 +101,7 @@ def empty(shape, dtype="float32", ctx=context(1, 0)):
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLType(dtype)
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAlloc(
shape, ndim,
ctypes.c_int(dtype.type_code),
@@ -139,7 +139,7 @@ def empty_shared_mem(name, is_create, shape, dtype="float32"):
shape = c_array(dgl_shape_index_t, shape)
ndim = ctypes.c_int(len(shape))
handle = DGLArrayHandle()
dtype = DGLType(dtype)
dtype = DGLDataType(dtype)
check_call(_LIB.DGLArrayAllocSharedMem(
name, shape, ndim,
ctypes.c_int(dtype.type_code),
@@ -254,7 +254,7 @@ class NDArrayBase(_NDArrayBase):
except:
raise TypeError('array must be an array_like data,' +
'type %s is not supported' % str(type(source_array)))
t = DGLType(self.dtype)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
@@ -286,7 +286,7 @@ class NDArrayBase(_NDArrayBase):
np_arr : numpy.ndarray
The corresponding numpy array.
"""
t = DGLType(self.dtype)
t = DGLDataType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)

View File

@@ -17,7 +17,7 @@ class TypeCode(object):
FLOAT = 2
HANDLE = 3
NULL = 4
DGL_TYPE = 5
DGL_DATA_TYPE = 5
DGL_CONTEXT = 6
ARRAY_HANDLE = 7
OBJECT_HANDLE = 8
@@ -33,7 +33,7 @@ class DGLByteArray(ctypes.Structure):
_fields_ = [("data", ctypes.POINTER(ctypes.c_byte)),
("size", ctypes.c_size_t)]
class DGLType(ctypes.Structure):
class DGLDataType(ctypes.Structure):
"""DGL datatype structure"""
_fields_ = [("type_code", ctypes.c_uint8),
("bits", ctypes.c_uint8),
@@ -50,7 +50,7 @@ class DGLType(ctypes.Structure):
if type_str in cls._cache:
return cls._cache[type_str]
inst = super(DGLType, cls).__new__(DGLType)
inst = super(DGLDataType, cls).__new__(DGLDataType)
if isinstance(type_str, np.dtype):
type_str = str(type_str)
@@ -84,7 +84,7 @@ class DGLType(ctypes.Structure):
pass
def __repr__(self):
x = "%s%d" % (DGLType.CODE2STR[self.type_code], self.bits)
x = "%s%d" % (DGLDataType.CODE2STR[self.type_code], self.bits)
if self.lanes != 1:
x += "x%d" % self.lanes
return x
@@ -250,7 +250,7 @@ class DGLArray(ctypes.Structure):
_fields_ = [("data", ctypes.c_void_p),
("ctx", DGLContext),
("ndim", ctypes.c_int),
("dtype", DGLType),
("dtype", DGLDataType),
("shape", ctypes.POINTER(dgl_shape_index_t)),
("strides", ctypes.POINTER(dgl_shape_index_t)),
("byte_offset", ctypes.c_uint64)]

View File

@@ -13,7 +13,7 @@ import numpy as _np
from ._ffi.object import register_object, ObjectBase
from ._ffi.function import _init_api
from ._ffi.ndarray import DGLContext, DGLType, NDArrayBase
from ._ffi.ndarray import DGLContext, DGLDataType, NDArrayBase
from ._ffi.ndarray import context, empty, empty_shared_mem, from_dlpack, numpyasarray
from ._ffi.ndarray import _set_class_ndarray
from . import backend as F

View File

@@ -19,8 +19,8 @@ using namespace dgl::runtime;
namespace dgl {
namespace aten {
IdArray NewIdArray(int64_t length, DLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DLDataType{kDLInt, nbits, 1}, ctx);
IdArray NewIdArray(int64_t length, DGLContext ctx, uint8_t nbits) {
return IdArray::Empty({length}, DGLDataType{kDGLInt, nbits, 1}, ctx);
}
IdArray Clone(IdArray arr) {
@@ -29,7 +29,7 @@ IdArray Clone(IdArray arr) {
return ret;
}
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
IdArray Range(int64_t low, int64_t high, uint8_t nbits, DGLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Range", {
if (nbits == 32) {
@@ -43,7 +43,7 @@ IdArray Range(int64_t low, int64_t high, uint8_t nbits, DLContext ctx) {
return ret;
}
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
IdArray Full(int64_t val, int64_t length, uint8_t nbits, DGLContext ctx) {
IdArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
if (nbits == 32) {
@@ -58,7 +58,7 @@ IdArray Full(int64_t val, int64_t length, uint8_t nbits, DLContext ctx) {
}
template <typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret;
ATEN_XPU_SWITCH_CUDA(ctx.device_type, XPU, "Full", {
ret = impl::Full<XPU, DType>(val, length, ctx);
@@ -66,10 +66,10 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret;
}
template NDArray Full<int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<float>(float val, int64_t length, DLContext ctx);
template NDArray Full<double>(double val, int64_t length, DLContext ctx);
template NDArray Full<int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<double>(double val, int64_t length, DGLContext ctx);
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64)
@@ -315,7 +315,7 @@ std::pair<IdArray, IdArray> Sort(IdArray array, const int num_bits) {
std::string ToDebugString(NDArray array) {
std::ostringstream oss;
NDArray a = array.CopyTo(DLContext{kDLCPU, 0});
NDArray a = array.CopyTo(DGLContext{kDGLCPU, 0});
oss << "array([";
ATEN_DTYPE_SWITCH(a->dtype, DType, "array", {
for (int64_t i = 0; i < std::min<int64_t>(a.NumElements(), 10L); ++i) {
@@ -1132,10 +1132,10 @@ DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLExistSharedMemArray")
DGL_REGISTER_GLOBAL("ndarray._CAPI_DGLArrayCastToSigned")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
CHECK_EQ(array->dtype.code, kDLUInt);
CHECK_EQ(array->dtype.code, kDGLUInt);
std::vector<int64_t> shape(array->shape, array->shape + array->ndim);
DLDataType dtype = array->dtype;
dtype.code = kDLInt;
DGLDataType dtype = array->dtype;
dtype.code = kDGLInt;
*rv = array.CreateView(shape, dtype, 0);
});

View File

@@ -16,176 +16,176 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DLContext ctx);
template <DGLDeviceType XPU, typename IdType>
IdArray Full(IdType val, int64_t length, DGLContext ctx);
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx);
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray array);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
IdArray NonZero(BoolArray bool_arr);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray Concat(const std::vector<IdArray>& arrays);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(NDArray array);
// sparse arrays
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRIsNonZero(CSRMatrix csr, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowNNZ(CSRMatrix csr, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray CSRGetRowData(CSRMatrix csr, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols, bool return_eids,
runtime::NDArray weights, DType filler);
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
runtime::NDArray CSRGetData(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols,
runtime::NDArray weights, DType filler) {
return CSRGetData<XPU, IdType, DType>(csr, rows, cols, false, weights, filler);
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetData(CSRMatrix csr, NDArray rows, NDArray cols) {
return CSRGetData<XPU, IdType, IdType>(csr, rows, cols, true, NullArray(rows->dtype), -1);
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> CSRGetDataAndIndices(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr);
// Convert CSR to COO
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr);
// Convert CSR to COO using data array as order
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType>
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, IdArray tag_array, int64_t num_tags);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_ids, runtime::NDArray new_col_ids);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(
CSRMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(
CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace,
bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(
CSRMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(
CSRMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending);
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
@@ -194,7 +194,7 @@ COOMatrix CSRRowWiseSamplingBiased(
FloatArray bias,
bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
@@ -204,117 +204,117 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
double redundancy);
// Union CSRMatrixes
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr);
///////////////////////////////////////////////////////////////////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOIsNonZero(COOMatrix coo, runtime::NDArray row, runtime::NDArray col);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetRowNNZ(COOMatrix coo, runtime::NDArray row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<runtime::NDArray, runtime::NDArray>
COOGetRowDataAndIndices(COOMatrix coo, int64_t row);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<runtime::NDArray> COOGetDataAndIndices(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
runtime::NDArray COOGetData(COOMatrix mat, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, runtime::NDArray rows);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* mat, bool sort_column);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseSampling(
COOMatrix mat, IdArray rows, int64_t num_samples, FloatArray prob, bool replace);
// FloatType is the type of probability data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(
COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples, FloatArray prob, bool replace, bool etype_sorted);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWiseSamplingUniform(
COOMatrix mat, IdArray rows, int64_t num_samples, bool replace);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORowWisePerEtypeSamplingUniform(
COOMatrix mat, IdArray rows, IdArray etypes, const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted);
// FloatType is the type of weight data.
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, FloatArray weight, bool ascending);
///////////////////////// Graph Traverse routines //////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
const bool has_nontree_edge,
const bool return_labels);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking);
} // namespace impl

View File

@@ -16,7 +16,7 @@ namespace aten {
// Check whether the given arguments have the same context.
inline void CheckCtx(
const DLContext& ctx,
const DGLContext& ctx,
const std::vector<NDArray>& arrays,
const std::vector<std::string>& names) {
for (size_t i = 0; i < arrays.size(); ++i) {

View File

@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
@@ -34,8 +34,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
}
}
template IdArray CumSum<kDLCPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLCPU, int64_t>(IdArray, bool);
template IdArray CumSum<kDGLCPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDGLCPU, int64_t>(IdArray, bool);
} // namespace impl
} // namespace aten

View File

@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
CHECK_EQ(array->shape[0], array.NumElements()) << "Only support tensor"
<< " whose first dimension equals number of elements, e.g. (5,), (5, 1)";
@@ -28,25 +28,25 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template NDArray IndexSelect<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
const DType* data = static_cast<DType*>(array->data);
return data[index];
}
template int32_t IndexSelect<kDLCPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLCPU, int64_t>(NDArray array, int64_t index);
template float IndexSelect<kDLCPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLCPU, double>(NDArray array, int64_t index);
template int32_t IndexSelect<kDGLCPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCPU, int64_t>(NDArray array, int64_t index);
template float IndexSelect<kDGLCPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCPU, double>(NDArray array, int64_t index);
} // namespace impl
} // namespace aten

View File

@@ -10,7 +10,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
std::vector<int64_t> ret;
const IdType* data = array.Ptr<IdType>();
@@ -20,8 +20,8 @@ IdArray NonZero(IdArray array) {
return NDArray::FromVector(ret, array->ctx);
}
template IdArray NonZero<kDLCPU, int32_t>(IdArray);
template IdArray NonZero<kDLCPU, int64_t>(IdArray);
template IdArray NonZero<kDGLCPU, int32_t>(IdArray);
template IdArray NonZero<kDGLCPU, int64_t>(IdArray);
} // namespace impl
} // namespace aten

View File

@@ -17,7 +17,7 @@ namespace impl {
///////////////////////////// AsNumBits /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
CHECK(bits == 32 || bits == 64) << "invalid number of integer bits";
if (sizeof(IdType) * 8 == bits) {
@@ -40,12 +40,12 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
return ret;
}
template IdArray AsNumBits<kDLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLCPU, int64_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCPU, int64_t>(IdArray arr, uint8_t bits);
///////////////////////////// BinaryElewise /////////////////////////////
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
@@ -59,30 +59,30 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
@@ -95,30 +95,30 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
const IdType* rhs_data = static_cast<IdType*>(rhs->data);
@@ -131,30 +131,30 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
const IdType* lhs_data = static_cast<IdType*>(lhs->data);
@@ -167,28 +167,28 @@ IdArray UnaryElewise(IdArray lhs) {
return ret;
}
template IdArray UnaryElewise<kDLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLCPU, int64_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCPU, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
std::fill(ret_data, ret_data + length, val);
return ret;
}
template NDArray Full<kDLCPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, float>(float val, int64_t length, DLContext ctx);
template NDArray Full<kDLCPU, double>(double val, int64_t length, DLContext ctx);
template NDArray Full<kDGLCPU, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, int64_t>(int64_t val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, float>(float val, int64_t length, DGLContext ctx);
template NDArray Full<kDGLCPU, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low";
IdArray ret = NewIdArray(high - low, ctx, sizeof(IdType) * 8);
IdType* ret_data = static_cast<IdType*>(ret->data);
@@ -196,12 +196,12 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret;
}
template IdArray Range<kDLCPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLCPU, int64_t>(int64_t, int64_t, DLContext);
template IdArray Range<kDGLCPU, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDGLCPU, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
// build map & relabel
IdType newid = 0;
@@ -216,7 +216,7 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
}
}
// map array
IdArray maparr = NewIdArray(newid, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray maparr = NewIdArray(newid, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* maparr_data = static_cast<IdType*>(maparr->data);
for (const auto& kv : oldv2newv) {
maparr_data[kv.second] = kv.first;
@@ -224,8 +224,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return maparr;
}
template IdArray Relabel_<kDLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLCPU, int64_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCPU, int64_t>(const std::vector<IdArray>& arrays);
} // namespace impl
} // namespace aten

View File

@@ -14,7 +14,7 @@ using runtime::parallel_for;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
const int64_t rows = lengths->shape[0];
const int64_t cols = (array->ndim == 1 ? array->shape[0] : array->shape[1]);
@@ -41,16 +41,16 @@ std::pair<NDArray, IdArray> ConcatSlices(NDArray array, IdArray lengths) {
return std::make_pair(concat, offsets);
}
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDLCPU, double, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int32_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, float, int64_t>(NDArray, IdArray);
template std::pair<NDArray, IdArray> ConcatSlices<kDGLCPU, double, int64_t>(NDArray, IdArray);
template<DLDeviceType XPU, typename DType>
template<DGLDeviceType XPU, typename DType>
std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
CHECK_NDIM(array, 2, "array");
const DType *array_data = static_cast<DType *>(array->data);
@@ -75,10 +75,10 @@ std::tuple<NDArray, IdArray, IdArray> Pack(NDArray array, DType pad_value) {
return std::make_tuple(ret.first, length, ret.second);
}
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDLCPU, double>(NDArray, double);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int32_t>(NDArray, int32_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, int64_t>(NDArray, int64_t);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, float>(NDArray, float);
template std::tuple<NDArray, IdArray, IdArray> Pack<kDGLCPU, double>(NDArray, double);
} // namespace impl
} // namespace aten

View File

@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Repeat(NDArray array, IdArray repeats) {
CHECK(array->shape[0] == repeats->shape[0]) << "shape of array and repeats mismatch";
@@ -34,14 +34,14 @@ NDArray Repeat(NDArray array, IdArray repeats) {
return result;
}
template NDArray Repeat<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Repeat<kDGLCPU, double, int64_t>(NDArray, IdArray);
}; // namespace impl
}; // namespace aten

View File

@@ -11,7 +11,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
NDArray Scatter(NDArray array, IdArray indices) {
NDArray result = NDArray::Empty({indices->shape[0]}, array->dtype, array->ctx);
@@ -25,16 +25,16 @@ NDArray Scatter(NDArray array, IdArray indices) {
return result;
}
template NDArray Scatter<kDLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDLCPU, double, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, float, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, double, int32_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, float, int64_t>(NDArray, IdArray);
template NDArray Scatter<kDGLCPU, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
@@ -47,14 +47,14 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
});
}
template void Scatter_<kDLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLCPU, double, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCPU, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten

View File

@@ -160,7 +160,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
const int64_t nitem = array->shape[0];
IdArray val = array.Clone();
@@ -181,8 +181,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int /* num_bits */) {
return std::make_pair(val, idx);
}
template std::pair<IdArray, IdArray> Sort<kDLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLCPU, int64_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCPU, int64_t>(IdArray, int num_bits);
} // namespace impl
} // namespace aten

View File

@@ -88,7 +88,7 @@ class IdHashMap {
// Return all the old ids collected so far, ordered by new id.
IdArray Values() const {
IdArray values = NewIdArray(oldv2newv_.size(), DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray values = NewIdArray(oldv2newv_.size(), DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* values_data = static_cast<IdType*>(values->data);
for (auto pair : oldv2newv_)
values_data[pair.second] = pair.first;

View File

@@ -13,7 +13,7 @@ namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
@@ -44,8 +44,8 @@ std::pair<COOMatrix, IdArray> COOCoalesce(COOMatrix coo) {
return std::make_pair(coo_result, NDArray::FromVector(count));
}
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDLCPU, int64_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int32_t>(COOMatrix);
template std::pair<COOMatrix, IdArray> COOCoalesce<kDGLCPU, int64_t>(COOMatrix);
}; // namespace impl
}; // namespace aten

View File

@@ -14,7 +14,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
const int64_t nnz = coo.row->shape[0];
IdType* coo_row = coo.row.Ptr<IdType>();
@@ -50,8 +50,8 @@ COOMatrix COOLineGraph(const COOMatrix &coo, bool backtracking) {
}
template COOMatrix COOLineGraph<kDLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int32_t>(const COOMatrix &coo, bool backtracking);
template COOMatrix COOLineGraph<kDGLCPU, int64_t>(const COOMatrix &coo, bool backtracking);
} // namespace impl
} // namespace aten

View File

@@ -16,7 +16,7 @@ namespace impl {
namespace {
/*! \brief COORemove implementation for COOMatrix with default consecutive edge IDs */
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COORemoveConsecutive(
COOMatrix coo,
IdArray entries,
@@ -47,7 +47,7 @@ void COORemoveConsecutive(
}
/*! \brief COORemove implementation for COOMatrix with shuffled edge IDs */
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COORemoveShuffled(
COOMatrix coo,
IdArray entries,
@@ -73,7 +73,7 @@ void COORemoveShuffled(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COORemove(COOMatrix coo, IdArray entries) {
const int64_t nnz = coo.row->shape[0];
const int64_t n_entries = entries->shape[0];
@@ -98,8 +98,8 @@ COOMatrix COORemove(COOMatrix coo, IdArray entries) {
IdArray::FromVector(new_eids));
}
template COOMatrix COORemove<kDLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDLCPU, int64_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDGLCPU, int32_t>(COOMatrix coo, IdArray entries);
template COOMatrix COORemove<kDGLCPU, int64_t>(COOMatrix coo, IdArray entries);
}; // namespace impl
}; // namespace aten

View File

@@ -167,7 +167,7 @@ namespace impl {
///////////////////////////// COOSort_ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
const int64_t nnz = coo->row->shape[0];
IdType* coo_row = coo->row.Ptr<IdType>();
@@ -208,13 +208,13 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
coo->col_sorted = sort_column;
}
template void COOSort_<kDLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDLCPU, int64_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int32_t>(COOMatrix*, bool);
template void COOSort_<kDGLCPU, int64_t>(COOMatrix*, bool);
///////////////////////////// COOIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
IdType* row = coo.row.Ptr<IdType>();
@@ -230,8 +230,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted};
}
template std::pair<bool, bool> COOIsSorted<kDLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLCPU, int64_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCPU, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten

View File

@@ -17,7 +17,7 @@ using runtime::parallel_for;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *ret_vec) {
@@ -38,7 +38,7 @@ void CollectDataFromSorted(const IdType *indices_data, const IdType *data,
}
}
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
@@ -59,7 +59,7 @@ NDArray CSRGetData(
const int64_t retlen = std::max(rowlen, collen);
const DType* weight_data = return_eids ? nullptr : weights.Ptr<DType>();
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
NDArray ret = Full(filler, retlen, rows->ctx);
@@ -106,19 +106,19 @@ NDArray CSRGetData(
return ret;
}
template NDArray CSRGetData<kDLCPU, int32_t, float>(
template NDArray CSRGetData<kDGLCPU, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int64_t, float>(
template NDArray CSRGetData<kDGLCPU, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLCPU, int32_t, double>(
template NDArray CSRGetData<kDGLCPU, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLCPU, int64_t, double>(
template NDArray CSRGetData<kDGLCPU, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLCPU, int32_t, int32_t>(
template NDArray CSRGetData<kDGLCPU, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLCPU, int64_t, int64_t>(
template NDArray CSRGetData<kDGLCPU, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl

View File

@@ -134,13 +134,13 @@ std::pair<CSRMatrix, NDArray> CSRMM(
C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLCPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCPU, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
}; // namespace aten

View File

@@ -15,7 +15,7 @@ namespace impl {
namespace {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRRemoveConsecutive(
CSRMatrix csr,
IdArray entries,
@@ -48,7 +48,7 @@ void CSRRemoveConsecutive(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRRemoveShuffled(
CSRMatrix csr,
IdArray entries,
@@ -77,7 +77,7 @@ void CSRRemoveShuffled(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
CHECK_SAME_DTYPE(csr.indices, entries);
const int64_t nnz = csr.indices->shape[0];
@@ -103,8 +103,8 @@ CSRMatrix CSRRemove(CSRMatrix csr, IdArray entries) {
IdArray::FromVector(new_eids));
}
template CSRMatrix CSRRemove<kDLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDLCPU, int64_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDGLCPU, int32_t>(CSRMatrix csr, IdArray entries);
template CSRMatrix CSRRemove<kDGLCPU, int64_t>(CSRMatrix csr, IdArray entries);
}; // namespace impl
}; // namespace aten

View File

@@ -14,7 +14,7 @@ namespace aten {
namespace impl {
///////////////////////////// CSRIsSorted /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
const IdType* indptr = csr.indptr.Ptr<IdType>();
const IdType* indices = csr.indices.Ptr<IdType>();
@@ -31,12 +31,12 @@ bool CSRIsSorted(CSRMatrix csr) {
[](bool a, bool b) { return a && b; });
}
template bool CSRIsSorted<kDLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCPU, int32_t>(CSRMatrix csr);
///////////////////////////// CSRSort /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
typedef std::pair<IdType, IdType> ShufflePair;
const int64_t num_rows = csr->num_rows;
@@ -79,10 +79,10 @@ void CSRSort_(CSRMatrix* csr) {
csr->sorted = true;
}
template void CSRSort_<kDLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDLCPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCPU, int32_t>(CSRMatrix* csr);
template <DLDeviceType XPU, typename IdType, typename TagType>
template <DGLDeviceType XPU, typename IdType, typename TagType>
std::pair<CSRMatrix, NDArray> CSRSortByTag(
const CSRMatrix &csr, const IdArray tag_array, int64_t num_tags) {
const auto indptr_data = static_cast<const IdType *>(csr.indptr->data);
@@ -143,13 +143,13 @@ std::pair<CSRMatrix, NDArray> CSRSortByTag(
return std::make_pair(output, tag_pos);
}
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int64_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int64_t, int32_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int64_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int64_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int64_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDLCPU, int32_t, int32_t>(
template std::pair<CSRMatrix, NDArray> CSRSortByTag<kDGLCPU, int32_t, int32_t>(
const CSRMatrix &csr, const IdArray tag, int64_t num_tags);
} // namespace impl

View File

@@ -130,13 +130,13 @@ std::pair<CSRMatrix, NDArray> CSRSum(
C_weights};
}
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLCPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCPU, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
}; // namespace aten

View File

@@ -12,7 +12,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
if (!csr.sorted)
csr = CSRSort(csr);
@@ -67,8 +67,8 @@ std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple(CSRMatrix csr) {
return std::make_tuple(res_csr, edge_count, eids_remapped);
}
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int32_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDLCPU, int64_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int32_t>(CSRMatrix);
template std::tuple<CSRMatrix, IdArray, IdArray> CSRToSimple<kDGLCPU, int64_t>(CSRMatrix);
} // namespace impl
} // namespace aten

View File

@@ -14,7 +14,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
std::vector<IdType> res_indptr;
std::vector<IdType> res_indices;
@@ -109,8 +109,8 @@ CSRMatrix UnionCsr(const std::vector<CSRMatrix>& csrs) {
sorted);
}
template CSRMatrix UnionCsr<kDLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDLCPU, int32_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDGLCPU, int64_t>(const std::vector<CSRMatrix>&);
template CSRMatrix UnionCsr<kDGLCPU, int32_t>(const std::vector<CSRMatrix>&);
} // namespace impl
} // namespace aten

View File

@@ -26,7 +26,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdArray prefix_src_arr = NewIdArray(
coos.size(), coos[0].row->ctx, coos[0].row->dtype.bits);
@@ -52,7 +52,7 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(prefix_elm_arr, true));
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
bool has_data = false;
bool row_sorted = true;
@@ -118,8 +118,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted);
}
template COOMatrix DisjointUnionCoo<kDLCPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLCPU, int64_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCPU, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl
} // namespace aten

View File

@@ -62,74 +62,74 @@ void GatherMMScatter(const NDArray A,
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}
template void GatherMM<kDLCPU, int32_t, 16>(
template void GatherMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 16>(
template void GatherMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 32>(
template void GatherMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 32>(
template void GatherMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int32_t, 64>(
template void GatherMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLCPU, int64_t, 64>(
template void GatherMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLCPU, int32_t, 16>(
template void GatherMMScatter<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 16>(
template void GatherMMScatter<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 32>(
template void GatherMMScatter<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 32>(
template void GatherMMScatter<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int32_t, 64>(
template void GatherMMScatter<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLCPU, int64_t, 64>(
template void GatherMMScatter<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLCPU, int32_t, 16>(
template void SegmentMM<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 16>(
template void SegmentMM<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 32>(
template void SegmentMM<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 32>(
template void SegmentMM<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int32_t, 64>(
template void SegmentMM<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLCPU, int64_t, 64>(
template void SegmentMM<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLCPU, int32_t, 16>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 16>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 32>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int32_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLCPU, int64_t, 64>(
template void SegmentMMBackwardB<kDGLCPU, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten

View File

@@ -17,7 +17,7 @@ namespace dgl {
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix &csr,
int64_t num_samples,
@@ -61,9 +61,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return {row.CreateView({num_sampled}, row->dtype), col.CreateView({num_sampled}, col->dtype)};
}
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int32_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLCPU, int64_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCPU, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl

View File

@@ -97,13 +97,13 @@ COOMatrix CSRRowWisePick(CSRMatrix mat, IdArray rows,
// [02/29/2020 update]: OMP is disabled for now since batch-wise parallelism is more
// significant. (minjie)
IdArray picked_row = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_col = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdArray picked_idx = NDArray::Empty({num_rows * num_picks},
DLDataType{kDLInt, 8*sizeof(IdxType), 1},
DGLDataType{kDGLInt, 8*sizeof(IdxType), 1},
ctx);
IdxType* picked_rdata = static_cast<IdxType*>(picked_row->data);
IdxType* picked_cdata = static_cast<IdxType*>(picked_col->data);

View File

@@ -117,7 +117,7 @@ inline PickFn<IdxType> GetSamplingBiasedPickFn(
/////////////////////////////// CSR ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
@@ -125,16 +125,16 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat, IdArray rows, int64_t num_samples,
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
@@ -143,28 +143,28 @@ COOMatrix CSRRowWisePerEtypeSampling(CSRMatrix mat, IdArray rows, IdArray etypes
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix CSRRowWisePerEtypeSampling<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLCPU, int64_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
@@ -172,12 +172,12 @@ COOMatrix CSRRowWisePerEtypeSamplingUniform(CSRMatrix mat, IdArray rows, IdArray
return CSRRowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
template COOMatrix CSRRowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
CSRMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix CSRRowWiseSamplingBiased(
CSRMatrix mat,
IdArray rows,
@@ -191,22 +191,22 @@ COOMatrix CSRRowWiseSamplingBiased(
return CSRRowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
template COOMatrix CSRRowWiseSamplingBiased<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseSamplingBiased<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, FloatArray, bool);
/////////////////////////////// COO ///////////////////////////////
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
FloatArray prob, bool replace) {
CHECK(prob.defined());
@@ -214,16 +214,16 @@ COOMatrix COORowWiseSampling(COOMatrix mat, IdArray rows, int64_t num_samples,
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, float>(
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, float>(
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int32_t, double>(
template COOMatrix COORowWiseSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix COORowWiseSampling<kDLCPU, int64_t, double>(
template COOMatrix COORowWiseSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, FloatArray, bool);
template <DLDeviceType XPU, typename IdxType, typename FloatType>
template <DGLDeviceType XPU, typename IdxType, typename FloatType>
COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
FloatArray prob, bool replace, bool etype_sorted) {
@@ -232,28 +232,28 @@ COOMatrix COORowWisePerEtypeSampling(COOMatrix mat, IdArray rows, IdArray etypes
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, float>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, float>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int32_t, double>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template COOMatrix COORowWisePerEtypeSampling<kDLCPU, int64_t, double>(
template COOMatrix COORowWisePerEtypeSampling<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, FloatArray, bool, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWiseSamplingUniform(COOMatrix mat, IdArray rows,
int64_t num_samples, bool replace) {
auto pick_fn = GetSamplingUniformPickFn<IdxType>(num_samples, replace);
return COORowWisePick(mat, rows, num_samples, replace, pick_fn);
}
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int32_t>(
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, int64_t, bool);
template COOMatrix COORowWiseSamplingUniform<kDLCPU, int64_t>(
template COOMatrix COORowWiseSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, int64_t, bool);
template <DLDeviceType XPU, typename IdxType>
template <DGLDeviceType XPU, typename IdxType>
COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray etypes,
const std::vector<int64_t>& num_samples,
bool replace, bool etype_sorted) {
@@ -261,9 +261,9 @@ COOMatrix COORowWisePerEtypeSamplingUniform(COOMatrix mat, IdArray rows, IdArray
return COORowWisePerEtypePick(mat, rows, etypes, num_samples, replace, etype_sorted, pick_fn);
}
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int32_t>(
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int32_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
template COOMatrix COORowWisePerEtypeSamplingUniform<kDLCPU, int64_t>(
template COOMatrix COORowWisePerEtypeSamplingUniform<kDGLCPU, int64_t>(
COOMatrix, IdArray, IdArray, const std::vector<int64_t>&, bool, bool);
} // namespace impl

View File

@@ -55,52 +55,52 @@ inline PickFn<IdxType> GetTopkPickFn(int64_t k, NDArray weight, bool ascending)
} // namespace
template <DLDeviceType XPU, typename IdxType, typename DType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix CSRRowWiseTopk(
CSRMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return CSRRowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int32_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int32_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int32_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, int64_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, int64_t>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, int64_t>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, float>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, float>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, float>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int32_t, double>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int32_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix CSRRowWiseTopk<kDLCPU, int64_t, double>(
template COOMatrix CSRRowWiseTopk<kDGLCPU, int64_t, double>(
CSRMatrix, IdArray, int64_t, NDArray, bool);
template <DLDeviceType XPU, typename IdxType, typename DType>
template <DGLDeviceType XPU, typename IdxType, typename DType>
COOMatrix COORowWiseTopk(
COOMatrix mat, IdArray rows, int64_t k, NDArray weight, bool ascending) {
auto pick_fn = GetTopkPickFn<IdxType, DType>(k, weight, ascending);
return COORowWisePick(mat, rows, k, false, pick_fn);
}
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int32_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int32_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int32_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, int64_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, int64_t>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, int64_t>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, float>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, float>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, float>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int32_t, double>(
template COOMatrix COORowWiseTopk<kDGLCPU, int32_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool);
template COOMatrix COORowWiseTopk<kDLCPU, int64_t, double>(
template COOMatrix COORowWiseTopk<kDGLCPU, int64_t, double>(
COOMatrix, IdArray, int64_t, NDArray, bool);
} // namespace impl

View File

@@ -102,67 +102,67 @@ void SDDMMCsrHetero(const std::string& op,
});
}
template void SDDMMCsr<kDLCPU, int32_t, 16>(
template void SDDMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 16>(
template void SDDMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 32>(
template void SDDMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 32>(
template void SDDMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int32_t, 64>(
template void SDDMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLCPU, int64_t, 64>(
template void SDDMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsrHetero<kDLCPU, int32_t, 16>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 16>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 32>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int32_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLCPU, int64_t, 64>(
template void SDDMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
@@ -217,67 +217,67 @@ void SDDMMCooHetero(const std::string& op,
});
}
template void SDDMMCoo<kDLCPU, int32_t, 16>(
template void SDDMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 16>(
template void SDDMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 32>(
template void SDDMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 32>(
template void SDDMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int32_t, 64>(
template void SDDMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLCPU, int64_t, 64>(
template void SDDMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCooHetero<kDLCPU, int32_t, 16>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 16>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 32>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int32_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLCPU, int64_t, 64>(
template void SDDMMCooHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,

View File

@@ -74,113 +74,113 @@ void BackwardSegmentCmp(
});
}
template void SegmentReduce<kDLCPU, int32_t, 16>(
template void SegmentReduce<kDGLCPU, int32_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 16>(
template void SegmentReduce<kDGLCPU, int64_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 32>(
template void SegmentReduce<kDGLCPU, int32_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 32>(
template void SegmentReduce<kDGLCPU, int64_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int32_t, 64>(
template void SegmentReduce<kDGLCPU, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLCPU, int64_t, 64>(
template void SegmentReduce<kDGLCPU, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDLCPU, int32_t, 16>(
template void ScatterAdd<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 16>(
template void ScatterAdd<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 32>(
template void ScatterAdd<kDGLCPU, int32_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 32>(
template void ScatterAdd<kDGLCPU, int64_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int32_t, 64>(
template void ScatterAdd<kDGLCPU, int32_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLCPU, int64_t, 64>(
template void ScatterAdd<kDGLCPU, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLCPU, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 16>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 16>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 32>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int32_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int32_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLCPU, int64_t, 64>(
template void BackwardSegmentCmp<kDGLCPU, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);

View File

@@ -29,7 +29,7 @@ namespace impl {
///////////////////////////// COOIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
CHECK(col >= 0 && col < coo.num_cols) << "Invalid col index: " << col;
@@ -42,10 +42,10 @@ bool COOIsNonZero(COOMatrix coo, int64_t row, int64_t col) {
return false;
}
template bool COOIsNonZero<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template bool COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
@@ -67,12 +67,12 @@ NDArray COOIsNonZero(COOMatrix coo, NDArray row, NDArray col) {
return rst;
}
template NDArray COOIsNonZero<kDLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDLCPU, int64_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int32_t>(COOMatrix, NDArray, NDArray);
template NDArray COOIsNonZero<kDGLCPU, int64_t>(COOMatrix, NDArray, NDArray);
///////////////////////////// COOHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool COOHasDuplicate(COOMatrix coo) {
std::unordered_set<std::pair<IdType, IdType>, PairHash> hashmap;
const IdType* src_data = static_cast<IdType*>(coo.row->data);
@@ -89,12 +89,12 @@ bool COOHasDuplicate(COOMatrix coo) {
return false;
}
template bool COOHasDuplicate<kDLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDLCPU, int64_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int32_t>(COOMatrix coo);
template bool COOHasDuplicate<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
@@ -106,10 +106,10 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
return result;
}
template int64_t COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
CHECK_SAME_DTYPE(coo.col, rows);
const auto len = rows->shape[0];
@@ -123,12 +123,12 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
return rst;
}
template NDArray COOGetRowNNZ<kDLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLCPU, int64_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCPU, int64_t>(COOMatrix, NDArray);
///////////////////////////// COOGetRowDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
COOMatrix coo, int64_t row) {
CHECK(row >= 0 && row < coo.num_rows) << "Invalid row index: " << row;
@@ -151,13 +151,13 @@ std::pair<NDArray, NDArray> COOGetRowDataAndIndices(
}
template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int32_t>(COOMatrix, int64_t);
COOGetRowDataAndIndices<kDGLCPU, int32_t>(COOMatrix, int64_t);
template std::pair<NDArray, NDArray>
COOGetRowDataAndIndices<kDLCPU, int64_t>(COOMatrix, int64_t);
COOGetRowDataAndIndices<kDGLCPU, int64_t>(COOMatrix, int64_t);
///////////////////////////// COOGetData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
const int64_t rowlen = rows->shape[0];
const int64_t collen = cols->shape[0];
@@ -211,12 +211,12 @@ IdArray COOGetData(COOMatrix coo, IdArray rows, IdArray cols) {
return ret;
}
template IdArray COOGetData<kDLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDLCPU, int64_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int32_t>(COOMatrix, IdArray, IdArray);
template IdArray COOGetData<kDGLCPU, int64_t>(COOMatrix, IdArray, IdArray);
///////////////////////////// COOGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray cols) {
CHECK_SAME_DTYPE(coo.col, rows);
@@ -286,20 +286,20 @@ std::vector<NDArray> COOGetDataAndIndices(COOMatrix coo, NDArray rows,
NDArray::FromVector(ret_data)};
}
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int32_t>(
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int32_t>(
COOMatrix coo, NDArray rows, NDArray cols);
template std::vector<NDArray> COOGetDataAndIndices<kDLCPU, int64_t>(
template std::vector<NDArray> COOGetDataAndIndices<kDGLCPU, int64_t>(
COOMatrix coo, NDArray rows, NDArray cols);
///////////////////////////// COOTranspose /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOTranspose(COOMatrix coo) {
return COOMatrix{coo.num_cols, coo.num_rows, coo.col, coo.row, coo.data};
}
template COOMatrix COOTranspose<kDLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDLCPU, int64_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int32_t>(COOMatrix coo);
template COOMatrix COOTranspose<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOToCSR /////////////////////////////
namespace {
@@ -615,7 +615,7 @@ P^2).
degree), UnSortedDenseCOOToCSR<> is applied. Time: O(NNZ/P + N/P), space O(NNZ +
N*P).
*/
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
if (!coo.row_sorted) {
const int64_t num_threads = omp_get_num_threads();
@@ -632,12 +632,12 @@ CSRMatrix COOToCSR(COOMatrix coo) {
return SortedCOOToCSR<IdType>(coo);
}
template CSRMatrix COOToCSR<kDLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLCPU, int64_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCPU, int64_t>(COOMatrix coo);
///////////////////////////// COOSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
// TODO(minjie): use binary search when coo.row_sorted is true
CHECK(start >= 0 && start < coo.num_rows) << "Invalid start row " << start;
@@ -669,10 +669,10 @@ COOMatrix COOSliceRows(COOMatrix coo, int64_t start, int64_t end) {
coo.col_sorted);
}
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix, int64_t, int64_t);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
@@ -703,12 +703,12 @@ COOMatrix COOSliceRows(COOMatrix coo, NDArray rows) {
coo.row_sorted, coo.col_sorted};
}
template COOMatrix COOSliceRows<kDLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDLCPU, int64_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int32_t>(COOMatrix , NDArray);
template COOMatrix COOSliceRows<kDGLCPU, int64_t>(COOMatrix , NDArray);
///////////////////////////// COOSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols) {
const IdType* coo_row_data = static_cast<IdType*>(coo.row->data);
const IdType* coo_col_data = static_cast<IdType*>(coo.col->data);
@@ -740,15 +740,15 @@ COOMatrix COOSliceMatrix(COOMatrix coo, runtime::NDArray rows, runtime::NDArray
coo.row_sorted, coo.col_sorted);
}
template COOMatrix COOSliceMatrix<kDLCPU, int32_t>(
template COOMatrix COOSliceMatrix<kDGLCPU, int32_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
template COOMatrix COOSliceMatrix<kDLCPU, int64_t>(
template COOMatrix COOSliceMatrix<kDGLCPU, int64_t>(
COOMatrix coo, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// COOReorder /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(coo.row, new_row_id_arr);
@@ -785,9 +785,9 @@ COOMatrix COOReorder(COOMatrix coo, runtime::NDArray new_row_id_arr,
return COOMatrix(num_rows, num_cols, out_row_arr, out_col_arr, out_data_arr);
}
template COOMatrix COOReorder<kDLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
template COOMatrix COOReorder<kDGLCPU, int64_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
template COOMatrix COOReorder<kDLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
template COOMatrix COOReorder<kDGLCPU, int32_t>(COOMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
} // namespace impl

View File

@@ -21,7 +21,7 @@ namespace impl {
///////////////////////////// CSRIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
@@ -39,10 +39,10 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
return false;
}
template bool CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
@@ -62,12 +62,12 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst;
}
template NDArray CSRIsNonZero<kDLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCPU, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
const IdType* indices_data = static_cast<IdType*>(csr.indices->data);
@@ -85,21 +85,21 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return false;
}
template bool CSRHasDuplicate<kDLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLCPU, int64_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
return indptr_data[row + 1] - indptr_data[row];
}
template int64_t CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const auto len = rows->shape[0];
@@ -114,12 +114,12 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return rst;
}
template NDArray CSRGetRowNNZ<kDLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLCPU, int64_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCPU, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -127,12 +127,12 @@ NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}
template NDArray CSRGetRowColumnIndices<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLCPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -143,13 +143,13 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}
template NDArray CSRGetRowData<kDLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLCPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCPU, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetData /////////////////////////////
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data,
const IdType start, const IdType end, const IdType col,
std::vector<IdType> *col_vec,
@@ -172,7 +172,7 @@ void CollectDataIndicesFromSorted(const IdType *indices_data, const IdType *data
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray cols) {
// TODO(minjie): more efficient implementation for matrix without duplicate entries
const int64_t rowlen = rows->shape[0];
@@ -224,16 +224,16 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray rows, NDArray c
NDArray::FromVector(ret_data, csr.data->ctx)};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int32_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLCPU, int64_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCPU, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRTranspose /////////////////////////////
// for a matrix of shape (N, M) and NNZ
// complexity: time O(NNZ + max(N, M)), space O(1)
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
@@ -281,11 +281,11 @@ CSRMatrix CSRTranspose(CSRMatrix csr) {
return CSRMatrix{csr.num_cols, csr.num_rows, ret_indptr, ret_indices, ret_data};
}
template CSRMatrix CSRTranspose<kDLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLCPU, int64_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRToCOO /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
const int64_t nnz = csr.indices->shape[0];
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -303,11 +303,11 @@ COOMatrix CSRToCOO(CSRMatrix csr) {
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLCPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCPU, int64_t>(CSRMatrix csr);
// complexity: time O(NNZ), space O(1)
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
const int64_t N = csr.num_rows;
const int64_t M = csr.num_cols;
@@ -333,12 +333,12 @@ COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
return COOMatrix(N, M, ret_row, ret_col);
}
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLCPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCPU, int64_t>(CSRMatrix csr);
///////////////////////////// CSRSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
const IdType* indptr = static_cast<IdType*>(csr.indptr->data);
const int64_t num_rows = end - start;
@@ -362,10 +362,10 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
CHECK_SAME_DTYPE(csr.indices, rows);
const IdType* indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -467,12 +467,12 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
return ret;
}
template CSRMatrix CSRSliceRows<kDLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLCPU, int64_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCPU, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRSliceMatrix /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
IdHashMap<IdType> hashmap(cols);
const int64_t new_nrows = rows->shape[0];
@@ -521,14 +521,14 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
sub_data_arr};
}
template CSRMatrix CSRSliceMatrix<kDLCPU, int32_t>(
template CSRMatrix CSRSliceMatrix<kDGLCPU, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLCPU, int64_t>(
template CSRMatrix CSRSliceMatrix<kDGLCPU, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
///////////////////////////// CSRReorder /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
runtime::NDArray new_col_id_arr) {
CHECK_SAME_DTYPE(csr.indices, new_row_id_arr);
@@ -599,9 +599,9 @@ CSRMatrix CSRReorder(CSRMatrix csr, runtime::NDArray new_row_id_arr,
out_indptr_arr, out_indices_arr, out_data_arr);
}
template CSRMatrix CSRReorder<kDLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
template CSRMatrix CSRReorder<kDGLCPU, int64_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
template CSRMatrix CSRReorder<kDLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
template CSRMatrix CSRReorder<kDGLCPU, int32_t>(CSRMatrix csr, runtime::NDArray new_row_ids,
runtime::NDArray new_col_ids);
} // namespace impl

View File

@@ -124,67 +124,67 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
}
}
template void SpMMCsr<kDLCPU, int32_t, 16>(
template void SpMMCsr<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 16>(
template void SpMMCsr<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 32>(
template void SpMMCsr<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 32>(
template void SpMMCsr<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int32_t, 64>(
template void SpMMCsr<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLCPU, int64_t, 64>(
template void SpMMCsr<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsrHetero<kDLCPU, int32_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 16>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 32>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 32>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int32_t, 64>(
template void SpMMCsrHetero<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_node_tids,
const std::vector<dgl_type_t>& out_node_tids);
template void SpMMCsrHetero<kDLCPU, int64_t, 64>(
template void SpMMCsrHetero<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
@@ -222,52 +222,52 @@ void Edge_softmax_csr_backward(const std::string& op,
});
}
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 16>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 16>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 32>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int32_t, 64>(
template void Edge_softmax_csr_forward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_forward<kDLCPU, int64_t, 64>(
template void Edge_softmax_csr_forward<kDGLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 16>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 16>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 16>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 32>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 32>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int32_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int32_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
template void Edge_softmax_csr_backward<kDLCPU, int64_t, 64>(
template void Edge_softmax_csr_backward<kDGLCPU, int64_t, 64>(
const std::string& op,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out);
@@ -303,27 +303,27 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
}
template void SpMMCoo<kDLCPU, int32_t, 16>(
template void SpMMCoo<kDGLCPU, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 16>(
template void SpMMCoo<kDGLCPU, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 32>(
template void SpMMCoo<kDGLCPU, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 32>(
template void SpMMCoo<kDGLCPU, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int32_t, 64>(
template void SpMMCoo<kDGLCPU, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLCPU, int64_t, 64>(
template void SpMMCoo<kDGLCPU, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

View File

@@ -54,8 +54,8 @@ IdArray MergeMultipleTraversals(
total_len += traces[i].size();
}
IdArray ret = IdArray::Empty({total_len},
DLDataType{kDLInt, sizeof(DType) * 8, 1},
DLContext{kDLCPU, 0});
DGLDataType{kDGLInt, sizeof(DType) * 8, 1},
DGLContext{kDGLCPU, 0});
DType* ret_data = static_cast<DType*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
for (size_t j = 0; j < traces.size(); ++j) {
@@ -79,7 +79,7 @@ IdArray ComputeMergedSections(
const int64_t tracelen = traces[i].size();
max_len = std::max(max_len, tracelen);
}
IdArray ret = IdArray::Empty({max_len}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
IdArray ret = IdArray::Empty({max_len}, DGLDataType{kDGLInt, 64, 1}, DGLContext{kDGLCPU, 0});
int64_t* ret_data = static_cast<int64_t*>(ret->data);
for (int64_t i = 0; i < max_len; ++i) {
int64_t sec_len = 0;
@@ -96,7 +96,7 @@ IdArray ComputeMergedSections(
} // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
@@ -116,10 +116,10 @@ Frontiers BFSNodesFrontiers(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers BFSNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
@@ -144,10 +144,10 @@ Frontiers BFSEdgesFrontiers(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers BFSEdgesFrontiers<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers BFSEdgesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
std::vector<IdType> ids;
std::vector<int64_t> sections;
@@ -167,10 +167,10 @@ Frontiers TopologicalNodesFrontiers(const CSRMatrix& csr) {
return front;
}
template Frontiers TopologicalNodesFrontiers<kDLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDLCPU, int64_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int32_t>(const CSRMatrix&);
template Frontiers TopologicalNodesFrontiers<kDGLCPU, int64_t>(const CSRMatrix&);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
const int64_t len = source->shape[0];
const IdType* src_data = static_cast<IdType*>(source->data);
@@ -187,10 +187,10 @@ Frontiers DGLDFSEdges(const CSRMatrix& csr, IdArray source) {
return front;
}
template Frontiers DGLDFSEdges<kDLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDLCPU, int64_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int32_t>(const CSRMatrix&, IdArray);
template Frontiers DGLDFSEdges<kDGLCPU, int64_t>(const CSRMatrix&, IdArray);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
IdArray source,
const bool has_reverse_edge,
@@ -226,12 +226,12 @@ Frontiers DGLDFSLabeledEdges(const CSRMatrix& csr,
return front;
}
template Frontiers DGLDFSLabeledEdges<kDLCPU, int32_t>(const CSRMatrix&,
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int32_t>(const CSRMatrix&,
IdArray,
const bool,
const bool,
const bool);
template Frontiers DGLDFSLabeledEdges<kDLCPU, int64_t>(const CSRMatrix&,
template Frontiers DGLDFSLabeledEdges<kDGLCPU, int64_t>(const CSRMatrix&,
IdArray,
const bool,
const bool,

View File

@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray CumSum(IdArray array, bool prepend_zero) {
const int64_t len = array.NumElements();
if (len == 0)
@@ -46,8 +46,8 @@ IdArray CumSum(IdArray array, bool prepend_zero) {
return ret;
}
template IdArray CumSum<kDLGPU, int32_t>(IdArray, bool);
template IdArray CumSum<kDLGPU, int64_t>(IdArray, bool);
template IdArray CumSum<kDGLCUDA, int32_t>(IdArray, bool);
template IdArray CumSum<kDGLCUDA, int64_t>(IdArray, bool);
} // namespace impl
} // namespace aten

View File

@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template<DLDeviceType XPU, typename DType, typename IdType>
template<DGLDeviceType XPU, typename DType, typename IdType>
NDArray IndexSelect(NDArray array, IdArray index) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const DType* array_data = static_cast<DType*>(array->data);
@@ -51,20 +51,20 @@ NDArray IndexSelect(NDArray array, IdArray index) {
return ret;
}
template NDArray IndexSelect<kDLGPU, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, int64_t, int64_t>(NDArray, IdArray);
#ifdef USE_FP16
template NDArray IndexSelect<kDLGPU, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, __half, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, __half, int64_t>(NDArray, IdArray);
#endif
template NDArray IndexSelect<kDLGPU, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDLGPU, double, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, float, int64_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int32_t>(NDArray, IdArray);
template NDArray IndexSelect<kDGLCUDA, double, int64_t>(NDArray, IdArray);
template <DLDeviceType XPU, typename DType>
template <DGLDeviceType XPU, typename DType>
DType IndexSelect(NDArray array, int64_t index) {
auto device = runtime::DeviceAPI::Get(array->ctx);
#ifdef USE_FP16
@@ -79,20 +79,19 @@ DType IndexSelect(NDArray array, int64_t index) {
#endif
device->CopyDataFromTo(
static_cast<DType*>(array->data) + index, 0, reinterpret_cast<DType*>(&ret), 0,
sizeof(DType), array->ctx, DLContext{kDLCPU, 0},
array->dtype);
sizeof(DType), array->ctx, DGLContext{kDGLCPU, 0}, array->dtype);
return reinterpret_cast<DType&>(ret);
}
template int32_t IndexSelect<kDLGPU, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDLGPU, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDLGPU, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDLGPU, uint64_t>(NDArray array, int64_t index);
template int32_t IndexSelect<kDGLCUDA, int32_t>(NDArray array, int64_t index);
template int64_t IndexSelect<kDGLCUDA, int64_t>(NDArray array, int64_t index);
template uint32_t IndexSelect<kDGLCUDA, uint32_t>(NDArray array, int64_t index);
template uint64_t IndexSelect<kDGLCUDA, uint64_t>(NDArray array, int64_t index);
#ifdef USE_FP16
template __half IndexSelect<kDLGPU, __half>(NDArray array, int64_t index);
template __half IndexSelect<kDGLCUDA, __half>(NDArray array, int64_t index);
#endif
template float IndexSelect<kDLGPU, float>(NDArray array, int64_t index);
template double IndexSelect<kDLGPU, double>(NDArray array, int64_t index);
template float IndexSelect<kDGLCUDA, float>(NDArray array, int64_t index);
template double IndexSelect<kDGLCUDA, double>(NDArray array, int64_t index);
} // namespace impl
} // namespace aten

View File

@@ -26,7 +26,7 @@ struct IsNonZeroIndex {
const IdType * array_;
};
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray NonZero(IdArray array) {
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
@@ -63,8 +63,8 @@ IdArray NonZero(IdArray array) {
return ret.CreateView({num_nonzeros}, ret->dtype, 0);
}
template IdArray NonZero<kDLGPU, int32_t>(IdArray);
template IdArray NonZero<kDLGPU, int64_t>(IdArray);
template IdArray NonZero<kDGLCUDA, int32_t>(IdArray);
template IdArray NonZero<kDGLCUDA, int64_t>(IdArray);
} // namespace impl
} // namespace aten

View File

@@ -28,7 +28,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -44,28 +44,28 @@ IdArray BinaryElewise(IdArray lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, IdArray rhs);
template <typename IdType, typename Op>
@@ -79,7 +79,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdArray lhs, IdType rhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -94,28 +94,28 @@ IdArray BinaryElewise(IdArray lhs, IdType rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(IdArray lhs, int32_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(IdArray lhs, int64_t rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(IdArray lhs, int64_t rhs);
@@ -130,7 +130,7 @@ __global__ void _BinaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray BinaryElewise(IdType lhs, IdArray rhs) {
const int64_t len = rhs->shape[0];
IdArray ret = NewIdArray(rhs->shape[0], rhs->ctx, rhs->dtype.bits);
@@ -145,28 +145,28 @@ IdArray BinaryElewise(IdType lhs, IdArray rhs) {
return ret;
}
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDLGPU, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Add>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Sub>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mul>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Div>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::Mod>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LT>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::GE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::LE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::EQ>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int32_t, arith::NE>(int32_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Add>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Sub>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mul>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Div>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::Mod>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LT>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::GE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::LE>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::EQ>(int64_t lhs, IdArray rhs);
template IdArray BinaryElewise<kDGLCUDA, int64_t, arith::NE>(int64_t lhs, IdArray rhs);
template <typename IdType, typename Op>
__global__ void _UnaryElewiseKernel(
@@ -179,7 +179,7 @@ __global__ void _UnaryElewiseKernel(
}
}
template <DLDeviceType XPU, typename IdType, typename Op>
template <DGLDeviceType XPU, typename IdType, typename Op>
IdArray UnaryElewise(IdArray lhs) {
const int64_t len = lhs->shape[0];
IdArray ret = NewIdArray(lhs->shape[0], lhs->ctx, lhs->dtype.bits);
@@ -194,8 +194,8 @@ IdArray UnaryElewise(IdArray lhs) {
return ret;
}
template IdArray UnaryElewise<kDLGPU, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDLGPU, int64_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCUDA, int32_t, arith::Neg>(IdArray lhs);
template IdArray UnaryElewise<kDGLCUDA, int64_t, arith::Neg>(IdArray lhs);
///////////////////////////// Full /////////////////////////////
@@ -210,9 +210,9 @@ __global__ void _FullKernel(
}
}
template <DLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DLContext ctx) {
NDArray ret = NDArray::Empty({length}, DLDataTypeTraits<DType>::dtype, ctx);
template <DGLDeviceType XPU, typename DType>
NDArray Full(DType val, int64_t length, DGLContext ctx) {
NDArray ret = NDArray::Empty({length}, DGLDataTypeTraits<DType>::dtype, ctx);
DType* ret_data = static_cast<DType*>(ret->data);
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
@@ -222,13 +222,13 @@ NDArray Full(DType val, int64_t length, DLContext ctx) {
return ret;
}
template IdArray Full<kDLGPU, int32_t>(int32_t val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, int64_t>(int64_t val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, int32_t>(int32_t val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, int64_t>(int64_t val, int64_t length, DGLContext ctx);
#ifdef USE_FP16
template IdArray Full<kDLGPU, __half>(__half val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, __half>(__half val, int64_t length, DGLContext ctx);
#endif
template IdArray Full<kDLGPU, float>(float val, int64_t length, DLContext ctx);
template IdArray Full<kDLGPU, double>(double val, int64_t length, DLContext ctx);
template IdArray Full<kDGLCUDA, float>(float val, int64_t length, DGLContext ctx);
template IdArray Full<kDGLCUDA, double>(double val, int64_t length, DGLContext ctx);
///////////////////////////// Range /////////////////////////////
@@ -243,8 +243,8 @@ __global__ void _RangeKernel(IdType* out, IdType low, IdType length) {
}
}
template <DLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DLContext ctx) {
template <DGLDeviceType XPU, typename IdType>
IdArray Range(IdType low, IdType high, DGLContext ctx) {
CHECK(high >= low) << "high must be bigger than low";
const IdType length = high - low;
IdArray ret = NewIdArray(length, ctx, sizeof(IdType) * 8);
@@ -260,8 +260,8 @@ IdArray Range(IdType low, IdType high, DLContext ctx) {
return ret;
}
template IdArray Range<kDLGPU, int32_t>(int32_t, int32_t, DLContext);
template IdArray Range<kDLGPU, int64_t>(int64_t, int64_t, DLContext);
template IdArray Range<kDGLCUDA, int32_t>(int32_t, int32_t, DGLContext);
template IdArray Range<kDGLCUDA, int64_t>(int64_t, int64_t, DGLContext);
///////////////////////////// Relabel_ //////////////////////////////
@@ -278,7 +278,7 @@ __global__ void _RelabelKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray Relabel_(const std::vector<IdArray>& arrays) {
IdArray all_nodes = Concat(arrays);
const int64_t total_length = all_nodes->shape[0];
@@ -316,8 +316,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
&num_induced, 0,
sizeof(num_induced),
ctx,
DGLContext{kDLCPU, 0},
DGLType{kDLInt, 64, 1});
DGLContext{kDGLCPU, 0},
DGLDataType{kDGLInt, 64, 1});
device->StreamSync(ctx, stream);
device->FreeWorkspace(ctx, num_induced_device);
@@ -338,8 +338,8 @@ IdArray Relabel_(const std::vector<IdArray>& arrays) {
return induced_nodes;
}
template IdArray Relabel_<kDLGPU, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDLGPU, int64_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int32_t>(const std::vector<IdArray>& arrays);
template IdArray Relabel_<kDGLCUDA, int64_t>(const std::vector<IdArray>& arrays);
///////////////////////////// AsNumBits /////////////////////////////
@@ -353,10 +353,10 @@ __global__ void _CastKernel(const InType* in, OutType* out, size_t length) {
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
IdArray AsNumBits(IdArray arr, uint8_t bits) {
const std::vector<int64_t> shape(arr->shape, arr->shape + arr->ndim);
IdArray ret = IdArray::Empty(shape, DLDataType{kDLInt, bits, 1}, arr->ctx);
IdArray ret = IdArray::Empty(shape, DGLDataType{kDGLInt, bits, 1}, arr->ctx);
const int64_t length = ret.NumElements();
cudaStream_t stream = runtime::getCurrentCUDAStream();
int nt = cuda::FindNumThreads(length);
@@ -374,8 +374,8 @@ IdArray AsNumBits(IdArray arr, uint8_t bits) {
}
template IdArray AsNumBits<kDLGPU, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDLGPU, int64_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int32_t>(IdArray arr, uint8_t bits);
template IdArray AsNumBits<kDGLCUDA, int64_t>(IdArray arr, uint8_t bits);
} // namespace impl
} // namespace aten

View File

@@ -23,7 +23,7 @@ __global__ void _ScatterKernel(const IdType* index, const DType* value,
}
}
template <DLDeviceType XPU, typename DType, typename IdType>
template <DGLDeviceType XPU, typename DType, typename IdType>
void Scatter_(IdArray index, NDArray value, NDArray out) {
const int64_t len = index->shape[0];
const IdType* idx = index.Ptr<IdType>();
@@ -37,20 +37,20 @@ void Scatter_(IdArray index, NDArray value, NDArray out) {
idx, val, len, outd);
}
template void Scatter_<kDLGPU, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int32_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int32_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, int64_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, float, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int32_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int32_t, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, int64_t, int64_t>(IdArray, NDArray, NDArray);
#ifdef USE_FP16
template void Scatter_<kDLGPU, __half, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, __half, int64_t>(IdArray, NDArray, NDArray);
#endif
template void Scatter_<kDLGPU, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDLGPU, double, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, float, int64_t>(IdArray, NDArray, NDArray);
template void Scatter_<kDGLCUDA, double, int64_t>(IdArray, NDArray, NDArray);
}; // namespace impl
}; // namespace aten

View File

@@ -13,7 +13,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
const auto& ctx = array->ctx;
auto device = runtime::DeviceAPI::Get(ctx);
@@ -47,8 +47,8 @@ std::pair<IdArray, IdArray> Sort(IdArray array, int num_bits) {
return std::make_pair(sorted_array, sorted_idx);
}
template std::pair<IdArray, IdArray> Sort<kDLGPU, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDLGPU, int64_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int32_t>(IdArray, int num_bits);
template std::pair<IdArray, IdArray> Sort<kDGLCUDA, int64_t>(IdArray, int num_bits);
} // namespace impl
} // namespace aten

View File

@@ -14,14 +14,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix COOToCSR(COOMatrix coo) {
LOG(FATAL) << "Unreachable code.";
return {};
}
template <>
CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
@@ -100,7 +100,7 @@ __global__ void _SortedSearchKernelUpperBound(
}
template <>
CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo) {
const auto& ctx = coo.row->ctx;
const auto nbits = coo.row->dtype.bits;
cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -133,8 +133,8 @@ CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo) {
indptr, coo.col, coo.data, col_sorted);
}
template CSRMatrix COOToCSR<kDLGPU, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDLGPU, int64_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCUDA, int32_t>(COOMatrix coo);
template CSRMatrix COOToCSR<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten

View File

@@ -84,7 +84,7 @@ int _NumberOfBits(const T& range) {
return bits;
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void COOSort_(COOMatrix* coo, bool sort_column) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int row_bits = _NumberOfBits(coo->num_rows);
@@ -131,8 +131,8 @@ void COOSort_(COOMatrix* coo, bool sort_column) {
}
}
template void COOSort_<kDLGPU, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDLGPU, int64_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDGLCUDA, int32_t>(COOMatrix* coo, bool sort_column);
template void COOSort_<kDGLCUDA, int64_t>(COOMatrix* coo, bool sort_column);
///////////////////////////// COOIsSorted /////////////////////////////
@@ -155,7 +155,7 @@ __global__ void _COOIsSortedKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
const int64_t nnz = coo.row->shape[0];
const auto& ctx = coo.row->ctx;
@@ -180,8 +180,8 @@ std::pair<bool, bool> COOIsSorted(COOMatrix coo) {
return {row_sorted, col_sorted};
}
template std::pair<bool, bool> COOIsSorted<kDLGPU, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDLGPU, int64_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int32_t>(COOMatrix coo);
template std::pair<bool, bool> COOIsSorted<kDGLCUDA, int64_t>(COOMatrix coo);
} // namespace impl
} // namespace aten

View File

@@ -14,14 +14,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOO(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
@@ -77,7 +77,7 @@ __global__ void _RepeatKernel(
}
template <>
COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -99,18 +99,18 @@ COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr) {
true, csr.sorted);
}
template COOMatrix CSRToCOO<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOO<kDGLCUDA, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRToCOODataAsOrder(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int32_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int32_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
@@ -156,8 +156,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDLGPU, int64_t>(csr);
COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr) {
COOMatrix coo = CSRToCOO<kDGLCUDA, int64_t>(csr);
if (aten::IsNullArray(coo.data))
return coo;
const auto& sorted = Sort(coo.data);
@@ -173,8 +173,8 @@ COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr) {
return coo;
}
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDLGPU, int64_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int32_t>(CSRMatrix csr);
template COOMatrix CSRToCOODataAsOrder<kDGLCUDA, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten

View File

@@ -17,7 +17,7 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType, typename DType>
template <DGLDeviceType XPU, typename IdType, typename DType>
NDArray CSRGetData(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, DType filler) {
const int64_t rowlen = rows->shape[0];
@@ -38,7 +38,7 @@ NDArray CSRGetData(
const int nt = cuda::FindNumThreads(rstlen);
const int nb = (rstlen + nt - 1) / nt;
if (return_eids)
BUG_IF_FAIL(DLDataTypeTraits<DType>::dtype == rows->dtype) <<
BUG_IF_FAIL(DGLDataTypeTraits<DType>::dtype == rows->dtype) <<
"DType does not match row's dtype.";
// TODO(minjie): use binary search for sorted csr
@@ -53,24 +53,24 @@ NDArray CSRGetData(
}
#ifdef USE_FP16
template NDArray CSRGetData<kDLGPU, int32_t, __half>(
template NDArray CSRGetData<kDGLCUDA, int32_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
template NDArray CSRGetData<kDLGPU, int64_t, __half>(
template NDArray CSRGetData<kDGLCUDA, int64_t, __half>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, __half filler);
#endif
template NDArray CSRGetData<kDLGPU, int32_t, float>(
template NDArray CSRGetData<kDGLCUDA, int32_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int64_t, float>(
template NDArray CSRGetData<kDGLCUDA, int64_t, float>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, float filler);
template NDArray CSRGetData<kDLGPU, int32_t, double>(
template NDArray CSRGetData<kDGLCUDA, int32_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
template NDArray CSRGetData<kDLGPU, int64_t, double>(
template NDArray CSRGetData<kDGLCUDA, int64_t, double>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, double filler);
// For CSRGetData<XPU, IdType>(CSRMatrix, NDArray, NDArray)
template NDArray CSRGetData<kDLGPU, int32_t, int32_t>(
template NDArray CSRGetData<kDGLCUDA, int32_t, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int32_t filler);
template NDArray CSRGetData<kDLGPU, int64_t, int64_t>(
template NDArray CSRGetData<kDGLCUDA, int64_t, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols, bool return_eids, NDArray weights, int64_t filler);
} // namespace impl

View File

@@ -256,18 +256,18 @@ std::pair<CSRMatrix, NDArray> CSRMM(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, __half>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
#endif
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, float>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int32_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
template std::pair<CSRMatrix, NDArray> CSRMM<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRMM<kDGLCUDA, int64_t, double>(
const CSRMatrix&, NDArray, const CSRMatrix&, NDArray);
} // namespace aten

View File

@@ -34,7 +34,7 @@ __global__ void _SegmentIsSorted(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsSorted(CSRMatrix csr) {
const auto& ctx = csr.indptr->ctx;
cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -53,16 +53,16 @@ bool CSRIsSorted(CSRMatrix csr) {
return ret;
}
template bool CSRIsSorted<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDLGPU, int64_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRIsSorted<kDGLCUDA, int64_t>(CSRMatrix csr);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void CSRSort_(CSRMatrix* csr) {
LOG(FATAL) << "Unreachable codes";
}
template <>
void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
cudaStream_t stream = runtime::getCurrentCUDAStream();
@@ -108,7 +108,7 @@ void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr) {
}
template <>
void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(csr->indptr->ctx);
@@ -147,8 +147,8 @@ void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr) {
device->FreeWorkspace(ctx, workspace);
}
template void CSRSort_<kDLGPU, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDLGPU, int64_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int32_t>(CSRMatrix* csr);
template void CSRSort_<kDGLCUDA, int64_t>(CSRMatrix* csr);
} // namespace impl
} // namespace aten

View File

@@ -168,18 +168,18 @@ std::pair<CSRMatrix, NDArray> CSRSum(
}
#ifdef USE_FP16
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, __half>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, __half>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
#endif
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, float>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, float>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int32_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int32_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
template std::pair<CSRMatrix, NDArray> CSRSum<kDLGPU, int64_t, double>(
template std::pair<CSRMatrix, NDArray> CSRSum<kDGLCUDA, int64_t, double>(
const std::vector<CSRMatrix>&, const std::vector<NDArray>&);
} // namespace aten

View File

@@ -13,14 +13,14 @@ using runtime::NDArray;
namespace aten {
namespace impl {
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRTranspose(CSRMatrix csr) {
LOG(FATAL) << "Unreachable codes";
return {};
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
cudaStream_t stream = runtime::getCurrentCUDAStream();
// allocate cusparse handle if needed
@@ -90,12 +90,12 @@ CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr) {
}
template <>
CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr) {
CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr) {
return COOToCSR(COOTranspose(CSRToCOO(csr, false)));
}
template CSRMatrix CSRTranspose<kDLGPU, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDLGPU, int64_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCUDA, int32_t>(CSRMatrix csr);
template CSRMatrix CSRTranspose<kDGLCUDA, int64_t>(CSRMatrix csr);
} // namespace impl
} // namespace aten

View File

@@ -105,7 +105,7 @@ IdArray _PerformFilter(
&num_unique, 0,
sizeof(num_unique),
ctx,
DGLContext{kDLCPU, 0},
DGLContext{kDGLCPU, 0},
test->dtype);
// insert items into set
@@ -150,13 +150,13 @@ class CudaFilterSet : public Filter {
} // namespace
template<DLDeviceType XPU, typename IdType>
template<DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set) {
return FilterRef(std::make_shared<CudaFilterSet<IdType>>(set));
}
template FilterRef CreateSetFilter<kDLGPU, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDLGPU, int64_t>(IdArray set);
template FilterRef CreateSetFilter<kDGLCUDA, int32_t>(IdArray set);
template FilterRef CreateSetFilter<kDGLCUDA, int64_t>(IdArray set);
} // namespace array
} // namespace dgl

View File

@@ -47,7 +47,7 @@ __global__ void _DisjointUnionKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMatrix>& coos) {
IdType n = coos.size(), nbits = coos[0].row->dtype.bits;
IdArray n_rows = NewIdArray(n, CPU, nbits);
@@ -71,10 +71,10 @@ std::tuple<IdArray, IdArray, IdArray> _ComputePrefixSums(const std::vector<COOMa
CumSum(n_elms.CopyTo(coos[0].row->ctx), true));
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
int64_t n_arrs, int n_elms,
DGLContext ctx, DGLType dtype, cudaStream_t stream) {
DGLContext ctx, DGLDataType dtype, cudaStream_t stream) {
auto device = runtime::DeviceAPI::Get(ctx);
int nt = 256;
int nb = (n_elms + nt - 1) / nt;
@@ -84,7 +84,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->CopyDataFromTo(
arrs, 0, arrs_dev, 0, sizeof(IdType*)*n_arrs,
DGLContext{kDLCPU, 0}, ctx, dtype);
DGLContext{kDGLCPU, 0}, ctx, dtype);
CUDA_KERNEL_CALL(_DisjointUnionKernel,
nb, nt, 0, stream,
@@ -94,7 +94,7 @@ void _Merge(IdType** arrs, IdType* prefix, IdType* offset, IdType* out,
device->FreeWorkspace(ctx, arrs_dev);
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
auto device = runtime::DeviceAPI::Get(coos[0].row->ctx);
@@ -133,17 +133,17 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
IdType n_elements = 0;
device->CopyDataFromTo(
&prefix_elm[coos.size()], 0, &n_elements, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
device->CopyDataFromTo(
&prefix_src[coos.size()], 0, &src_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
device->CopyDataFromTo(
&prefix_dst[coos.size()], 0, &dst_offset, 0,
sizeof(IdType), coos[0].row->ctx, DGLContext{kDLCPU, 0},
sizeof(IdType), coos[0].row->ctx, DGLContext{kDGLCPU, 0},
coos[0].row->dtype);
// Union src array
@@ -176,8 +176,8 @@ COOMatrix DisjointUnionCoo(const std::vector<COOMatrix>& coos) {
col_sorted);
}
template COOMatrix DisjointUnionCoo<kDLGPU, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDLGPU, int64_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int32_t>(const std::vector<COOMatrix>& coos);
template COOMatrix DisjointUnionCoo<kDGLCUDA, int64_t>(const std::vector<COOMatrix>& coos);
} // namespace impl
} // namespace aten

View File

@@ -394,74 +394,74 @@ void GatherMMScatter(const NDArray A,
}
template void GatherMM<kDLGPU, int32_t, 16>(
template void GatherMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 16>(
template void GatherMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 32>(
template void GatherMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 32>(
template void GatherMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int32_t, 64>(
template void GatherMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMM<kDLGPU, int64_t, 64>(
template void GatherMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b);
template void GatherMMScatter<kDLGPU, int32_t, 16>(
template void GatherMMScatter<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 16>(
template void GatherMMScatter<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 32>(
template void GatherMMScatter<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 32>(
template void GatherMMScatter<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int32_t, 64>(
template void GatherMMScatter<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDLGPU, int64_t, 64>(
template void GatherMMScatter<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray idx_a, const NDArray idx_b, const NDArray idx_c);
template void SegmentMM<kDLGPU, int32_t, 16>(
template void SegmentMM<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 16>(
template void SegmentMM<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 32>(
template void SegmentMM<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 32>(
template void SegmentMM<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int32_t, 64>(
template void SegmentMM<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMM<kDLGPU, int64_t, 64>(
template void SegmentMM<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray B, NDArray C,
const NDArray seglen_A, bool a_trans, bool b_trans);
template void SegmentMMBackwardB<kDLGPU, int32_t, 16>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 16>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 32>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 32>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 32>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int32_t, 64>(
template void SegmentMMBackwardB<kDGLCUDA, int32_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDLGPU, int64_t, 64>(
template void SegmentMMBackwardB<kDGLCUDA, int64_t, 64>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
} // namespace aten

View File

@@ -26,7 +26,7 @@
} \
} else { \
constexpr bool UseBcast = true; \
const DLContext ctx = (CTX); \
const DGLContext ctx = (CTX); \
const auto device = runtime::DeviceAPI::Get(ctx); \
(LHS_OFF) = static_cast<int64_t*>( \
device->AllocWorkspace(ctx, sizeof(int64_t) * info.lhs_offset.size())); \

View File

@@ -93,7 +93,7 @@ struct IsNotMinusOne {
template <typename IdType>
void SortOrderedPairs(
runtime::DeviceAPI* device,
DLContext ctx,
DGLContext ctx,
IdType* major,
IdType* minor,
IdType* tmp_major,
@@ -128,7 +128,7 @@ void SortOrderedPairs(
}; // namespace
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
const CSRMatrix& csr,
int64_t num_samples,
@@ -211,9 +211,9 @@ std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling(
return result;
}
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int32_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int32_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDLGPU, int64_t>(
template std::pair<IdArray, IdArray> CSRGlobalUniformNegativeSampling<kDGLCUDA, int64_t>(
const CSRMatrix&, int64_t, int, bool, bool, double);
}; // namespace impl

View File

@@ -240,7 +240,7 @@ __global__ void _CSRRowWiseSampleUniformReplaceKernel(
///////////////////////////// CSR sampling //////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
IdArray rows,
const int64_t num_picks,
@@ -311,7 +311,7 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
sizeof(new_len),
ctx,
DGLContext{kDLCPU, 0},
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
CUDA_CALL(cudaEventRecord(copyEvent, stream));
@@ -369,9 +369,9 @@ COOMatrix CSRRowWiseSamplingUniform(CSRMatrix mat,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int32_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int32_t>(
CSRMatrix, IdArray, int64_t, bool);
template COOMatrix CSRRowWiseSamplingUniform<kDLGPU, int64_t>(
template COOMatrix CSRRowWiseSamplingUniform<kDGLCUDA, int64_t>(
CSRMatrix, IdArray, int64_t, bool);
} // namespace impl

View File

@@ -416,7 +416,7 @@ __global__ void _CSRRowWiseSampleReplaceKernel(
* @param replace Is replacement sampling?
* @author pengqirong (OPPO), dlasalle and Xin from Nvidia.
*/
template <DLDeviceType XPU, typename IdType, typename FloatType>
template <DGLDeviceType XPU, typename IdType, typename FloatType>
COOMatrix CSRRowWiseSampling(CSRMatrix mat,
IdArray rows,
int64_t num_picks,
@@ -492,7 +492,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
device->CopyDataFromTo(temp_ptr, num_rows * sizeof(temp_len), &temp_len, 0,
sizeof(temp_len),
ctx,
DGLContext{kDLCPU, 0},
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
device->StreamSync(ctx, stream);
@@ -523,7 +523,7 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
device->CopyDataFromTo(out_ptr, num_rows * sizeof(new_len), &new_len, 0,
sizeof(new_len),
ctx,
DGLContext{kDLCPU, 0},
DGLContext{kDGLCPU, 0},
mat.indptr->dtype);
CUDA_CALL(cudaEventRecord(copyEvent, stream));
@@ -651,13 +651,13 @@ COOMatrix CSRRowWiseSampling(CSRMatrix mat,
picked_col, picked_idx);
}
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, float>(
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, float>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int32_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int32_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
template COOMatrix CSRRowWiseSampling<kDLGPU, int64_t, double>(
template COOMatrix CSRRowWiseSampling<kDGLCUDA, int64_t, double>(
CSRMatrix, IdArray, int64_t, FloatArray, bool);
} // namespace impl

View File

@@ -54,52 +54,52 @@ void SDDMMCoo(const std::string& op,
}
template void SDDMMCsr<kDLGPU, int32_t, 16>(
template void SDDMMCsr<kDGLCUDA, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 16>(
template void SDDMMCsr<kDGLCUDA, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 32>(
template void SDDMMCsr<kDGLCUDA, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 32>(
template void SDDMMCsr<kDGLCUDA, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int32_t, 64>(
template void SDDMMCsr<kDGLCUDA, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCsr<kDLGPU, int64_t, 64>(
template void SDDMMCsr<kDGLCUDA, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 16>(
template void SDDMMCoo<kDGLCUDA, int32_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 16>(
template void SDDMMCoo<kDGLCUDA, int64_t, 16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 32>(
template void SDDMMCoo<kDGLCUDA, int32_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 32>(
template void SDDMMCoo<kDGLCUDA, int64_t, 32>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int32_t, 64>(
template void SDDMMCoo<kDGLCUDA, int32_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);
template void SDDMMCoo<kDLGPU, int64_t, 64>(
template void SDDMMCoo<kDGLCUDA, int64_t, 64>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out,
int lhs_target, int rhs_target);

View File

@@ -42,42 +42,42 @@ void SDDMMCooHetero(const std::string& op,
}
template void SDDMMCooHetero<kDLGPU, int32_t, 16>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 16>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 32>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 32>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int32_t, 64>(
template void SDDMMCooHetero<kDGLCUDA, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDLGPU, int64_t, 64>(
template void SDDMMCooHetero<kDGLCUDA, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,

View File

@@ -41,42 +41,42 @@ void SDDMMCsrHetero(const std::string& op,
});
}
template void SDDMMCsrHetero<kDLGPU, int32_t, 16>(
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 16>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 32>(
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 32>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 32>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int32_t, 64>(
template void SDDMMCsrHetero<kDGLCUDA, int32_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,
std::vector<NDArray> out, int lhs_target, int rhs_target,
const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDLGPU, int64_t, 64>(
template void SDDMMCsrHetero<kDGLCUDA, int64_t, 64>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr,
const std::vector<NDArray>& lhs, const std::vector<NDArray>& rhs,

View File

@@ -73,113 +73,113 @@ void BackwardSegmentCmp(NDArray feat,
}
template void SegmentReduce<kDLGPU, int32_t, 16>(
template void SegmentReduce<kDGLCUDA, int32_t, 16>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 16>(
template void SegmentReduce<kDGLCUDA, int64_t, 16>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int32_t, 32>(
template void SegmentReduce<kDGLCUDA, int32_t, 32>(
const std::string& op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 32>(
template void SegmentReduce<kDGLCUDA, int64_t, 32>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int32_t, 64>(
template void SegmentReduce<kDGLCUDA, int32_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void SegmentReduce<kDLGPU, int64_t, 64>(
template void SegmentReduce<kDGLCUDA, int64_t, 64>(
const std::string &op,
NDArray feat,
NDArray offsets,
NDArray out,
NDArray arg);
template void ScatterAdd<kDLGPU, int32_t, 16>(
template void ScatterAdd<kDGLCUDA, int32_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 16>(
template void ScatterAdd<kDGLCUDA, int64_t, 16>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 32>(
template void ScatterAdd<kDGLCUDA, int32_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 32>(
template void ScatterAdd<kDGLCUDA, int64_t, 32>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int32_t, 64>(
template void ScatterAdd<kDGLCUDA, int32_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void ScatterAdd<kDLGPU, int64_t, 64>(
template void ScatterAdd<kDGLCUDA, int64_t, 64>(
NDArray feat,
NDArray idx,
NDArray out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 16>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 32>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 32>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int32_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int32_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void UpdateGradMinMax_hetero<kDLGPU, int64_t, 64>(
template void UpdateGradMinMax_hetero<kDGLCUDA, int64_t, 64>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 16>(
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 16>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 16>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 32>(
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 32>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 32>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int32_t, 64>(
template void BackwardSegmentCmp<kDGLCUDA, int32_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);
template void BackwardSegmentCmp<kDLGPU, int64_t, 64>(
template void BackwardSegmentCmp<kDGLCUDA, int64_t, 64>(
NDArray feat,
NDArray arg,
NDArray out);

View File

@@ -71,7 +71,7 @@ __global__ void _COOGetRowNNZKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = coo.row->ctx;
@@ -84,12 +84,12 @@ int64_t COOGetRowNNZ(COOMatrix coo, int64_t row) {
nb, nt, 0, stream,
coo.row.Ptr<IdType>(), rst.Ptr<IdType>(),
row, nnz);
rst = rst.CopyTo(DLContext{kDLCPU, 0});
rst = rst.CopyTo(DGLContext{kDGLCPU, 0});
return *rst.Ptr<IdType>();
}
template int64_t COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, int64_t);
template int64_t COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, int64_t);
template <typename IdType>
__global__ void _COOGetAllRowNNZKernel(
@@ -104,7 +104,7 @@ __global__ void _COOGetAllRowNNZKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = coo.row->ctx;
@@ -112,7 +112,7 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
IdType num_rows = coo.num_rows;
IdType num_queries = rows->shape[0];
if (num_queries == 1) {
auto rows_cpu = rows.CopyTo(DLContext{kDLCPU, 0});
auto rows_cpu = rows.CopyTo(DGLContext{kDGLCPU, 0});
int64_t row = *rows_cpu.Ptr<IdType>();
IdType nt = 1024;
IdType nb = dgl::cuda::FindNumBlocks<'x'>((nnz + nt - 1) / nt);
@@ -136,8 +136,8 @@ NDArray COOGetRowNNZ(COOMatrix coo, NDArray rows) {
}
}
template NDArray COOGetRowNNZ<kDLGPU, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDLGPU, int64_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCUDA, int32_t>(COOMatrix, NDArray);
template NDArray COOGetRowNNZ<kDGLCUDA, int64_t>(COOMatrix, NDArray);
} // namespace impl
} // namespace aten

View File

@@ -21,7 +21,7 @@ namespace impl {
///////////////////////////// CSRIsNonZero /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = csr.indptr->ctx;
@@ -38,14 +38,14 @@ bool CSRIsNonZero(CSRMatrix csr, int64_t row, int64_t col) {
rows.Ptr<IdType>(), cols.Ptr<IdType>(),
1, 1, 1,
static_cast<IdType*>(nullptr), static_cast<IdType>(-1), out.Ptr<IdType>());
out = out.CopyTo(DLContext{kDLCPU, 0});
out = out.CopyTo(DGLContext{kDGLCPU, 0});
return *out.Ptr<IdType>() != -1;
}
template bool CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template bool CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
@@ -69,8 +69,8 @@ NDArray CSRIsNonZero(CSRMatrix csr, NDArray row, NDArray col) {
return rst != -1;
}
template NDArray CSRIsNonZero<kDLGPU, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDLGPU, int64_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCUDA, int32_t>(CSRMatrix, NDArray, NDArray);
template NDArray CSRIsNonZero<kDGLCUDA, int64_t>(CSRMatrix, NDArray, NDArray);
///////////////////////////// CSRHasDuplicate /////////////////////////////
@@ -95,7 +95,7 @@ __global__ void _SegmentHasNoDuplicate(
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
bool CSRHasDuplicate(CSRMatrix csr) {
if (!csr.sorted)
csr = CSRSort(csr);
@@ -116,20 +116,20 @@ bool CSRHasDuplicate(CSRMatrix csr) {
return !ret;
}
template bool CSRHasDuplicate<kDLGPU, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDLGPU, int64_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCUDA, int32_t>(CSRMatrix csr);
template bool CSRHasDuplicate<kDGLCUDA, int64_t>(CSRMatrix csr);
///////////////////////////// CSRGetRowNNZ /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
int64_t CSRGetRowNNZ(CSRMatrix csr, int64_t row) {
const IdType cur = aten::IndexSelect<IdType>(csr.indptr, row);
const IdType next = aten::IndexSelect<IdType>(csr.indptr, row + 1);
return next - cur;
}
template int64_t CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template int64_t CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
template <typename IdType>
__global__ void _CSRGetRowNNZKernel(
@@ -146,7 +146,7 @@ __global__ void _CSRGetRowNNZKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto len = rows->shape[0];
@@ -162,24 +162,24 @@ NDArray CSRGetRowNNZ(CSRMatrix csr, NDArray rows) {
return rst;
}
template NDArray CSRGetRowNNZ<kDLGPU, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDLGPU, int64_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int32_t>(CSRMatrix, NDArray);
template NDArray CSRGetRowNNZ<kDGLCUDA, int64_t>(CSRMatrix, NDArray);
///////////////////////////// CSRGetRowColumnIndices /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowColumnIndices(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
return csr.indices.CreateView({len}, csr.indices->dtype, offset);
}
template NDArray CSRGetRowColumnIndices<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDLGPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowColumnIndices<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRGetRowData /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
const int64_t len = impl::CSRGetRowNNZ<XPU, IdType>(csr, row);
const int64_t offset = aten::IndexSelect<IdType>(csr.indptr, row) * sizeof(IdType);
@@ -189,12 +189,12 @@ NDArray CSRGetRowData(CSRMatrix csr, int64_t row) {
return aten::Range(offset, offset + len, csr.indptr->dtype.bits, csr.indptr->ctx);
}
template NDArray CSRGetRowData<kDLGPU, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDLGPU, int64_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCUDA, int32_t>(CSRMatrix, int64_t);
template NDArray CSRGetRowData<kDGLCUDA, int64_t>(CSRMatrix, int64_t);
///////////////////////////// CSRSliceRows /////////////////////////////
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
const int64_t num_rows = end - start;
const IdType st_pos = aten::IndexSelect<IdType>(csr.indptr, start);
@@ -215,8 +215,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, int64_t start, int64_t end) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix, int64_t, int64_t);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix, int64_t, int64_t);
/*!
* \brief Copy data segment to output buffers
@@ -243,7 +243,7 @@ __global__ void _SegmentCopyKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const int64_t len = rows->shape[0];
@@ -272,8 +272,8 @@ CSRMatrix CSRSliceRows(CSRMatrix csr, NDArray rows) {
csr.sorted);
}
template CSRMatrix CSRSliceRows<kDLGPU, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDLGPU, int64_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int32_t>(CSRMatrix , NDArray);
template CSRMatrix CSRSliceRows<kDGLCUDA, int64_t>(CSRMatrix , NDArray);
///////////////////////////// CSRGetDataAndIndices /////////////////////////////
@@ -345,7 +345,7 @@ __global__ void _SortedSearchKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray col) {
const auto rowlen = row->shape[0];
const auto collen = col->shape[0];
@@ -392,9 +392,9 @@ std::vector<NDArray> CSRGetDataAndIndices(CSRMatrix csr, NDArray row, NDArray co
return {ret_row, ret_col, ret_data};
}
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int32_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int32_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
template std::vector<NDArray> CSRGetDataAndIndices<kDLGPU, int64_t>(
template std::vector<NDArray> CSRGetDataAndIndices<kDGLCUDA, int64_t>(
CSRMatrix csr, NDArray rows, NDArray cols);
///////////////////////////// CSRSliceMatrix /////////////////////////////
@@ -422,7 +422,7 @@ __global__ void _SegmentMaskColKernel(
}
}
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = rows->ctx;
@@ -501,9 +501,9 @@ CSRMatrix CSRSliceMatrix(CSRMatrix csr, runtime::NDArray rows, runtime::NDArray
ret_col, ret_data);
}
template CSRMatrix CSRSliceMatrix<kDLGPU, int32_t>(
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int32_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
template CSRMatrix CSRSliceMatrix<kDLGPU, int64_t>(
template CSRMatrix CSRSliceMatrix<kDGLCUDA, int64_t>(
CSRMatrix csr, runtime::NDArray rows, runtime::NDArray cols);
} // namespace impl

View File

@@ -147,53 +147,53 @@ void SpMMCoo(const std::string& op, const std::string& reduce,
}
}
template void SpMMCsr<kDLGPU, int32_t, 16>(
template void SpMMCsr<kDGLCUDA, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 16>(
template void SpMMCsr<kDGLCUDA, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, 32>(
template void SpMMCsr<kDGLCUDA, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 32>(
template void SpMMCsr<kDGLCUDA, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int32_t, 64>(
template void SpMMCsr<kDGLCUDA, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCsr<kDLGPU, int64_t, 64>(
template void SpMMCsr<kDGLCUDA, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const CSRMatrix& csr,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 16>(
template void SpMMCoo<kDGLCUDA, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 16>(
template void SpMMCoo<kDGLCUDA, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 32>(
template void SpMMCoo<kDGLCUDA, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 32>(
template void SpMMCoo<kDGLCUDA, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int32_t, 64>(
template void SpMMCoo<kDGLCUDA, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);
template void SpMMCoo<kDLGPU, int64_t, 64>(
template void SpMMCoo<kDGLCUDA, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const COOMatrix& coo,
NDArray ufeat, NDArray efeat, NDArray out, std::vector<NDArray> out_aux);

View File

@@ -203,7 +203,7 @@ cusparseStatus_t Xcsrmm2<double>(cusparseHandle_t handle, cusparseOperation_t tr
/*! Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2(
const DLContext& ctx,
const DGLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,
@@ -303,7 +303,7 @@ void CusparseCsrmm2(
/*! Cusparse implementation of SpMM on Csr format. */
template <typename DType, typename IdType>
void CusparseCsrmm2Hetero(
const DLContext& ctx,
const DGLContext& ctx,
const CSRMatrix& csr,
const DType* B_data, const DType* A_data,
DType* C_data,

View File

@@ -199,37 +199,37 @@ void SpMMCsrHetero(const std::string& op, const std::string& reduce,
});
}
template void SpMMCsrHetero<kDLGPU, int32_t, 16>(
template void SpMMCsrHetero<kDGLCUDA, int32_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 16>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, 16>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 32>(
template void SpMMCsrHetero<kDGLCUDA, int32_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 32>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, 32>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int32_t, 64>(
template void SpMMCsrHetero<kDGLCUDA, int32_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,
std::vector<NDArray>* out, std::vector<std::vector<NDArray>>* out_aux,
const std::vector<dgl_type_t>& ufeat_ntids, const std::vector<dgl_type_t>& out_ntids);
template void SpMMCsrHetero<kDLGPU, int64_t, 64>(
template void SpMMCsrHetero<kDGLCUDA, int64_t, 64>(
const std::string& op, const std::string& reduce,
const BcastOff& bcast, const std::vector<CSRMatrix>& csr,
const std::vector<NDArray>& ufeat, const std::vector<NDArray>& efeat,

View File

@@ -11,7 +11,7 @@
namespace dgl {
namespace cuda {
bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx) {
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx) {
auto device = runtime::DeviceAPI::Get(ctx);
int8_t* rst = static_cast<int8_t*>(device->AllocWorkspace(ctx, 1));
// Call CUB's reduction

View File

@@ -7,9 +7,9 @@
#define DGL_ARRAY_CUDA_UTILS_H_
#include <dmlc/logging.h>
#include <dgl/runtime/c_runtime_api.h>
#include <dgl/runtime/device_api.h>
#include <dgl/runtime/ndarray.h>
#include <dlpack/dlpack.h>
#include "../../runtime/cuda/cuda_common.h"
namespace dgl {
@@ -115,7 +115,7 @@ __device__ __forceinline__ T _ldg(T* addr) {
* \param ctx Device context.
* \return True if all the flags are true.
*/
bool AllTrue(int8_t* flags, int64_t length, const DLContext& ctx);
bool AllTrue(int8_t* flags, int64_t length, const DGLContext& ctx);
/*!
* \brief CUDA Kernel of filling the vector started from ptr of size length
@@ -187,7 +187,7 @@ __global__ void _LinearSearchKernel(
template <typename DType>
inline DType GetCUDAScalar(
runtime::DeviceAPI* device_api,
DLContext ctx,
DGLContext ctx,
const DType* cuda_ptr) {
DType result;
device_api->CopyDataFromTo(
@@ -195,8 +195,8 @@ inline DType GetCUDAScalar(
&result, 0,
sizeof(result),
ctx,
DLContext{kDLCPU, 0},
DLDataTypeTraits<DType>::dtype);
DGLContext{kDGLCPU, 0},
DGLDataTypeTraits<DType>::dtype);
return result;
}

View File

@@ -25,7 +25,7 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
std::vector<int64_t> shape{len};
CHECK(array.IsPinned());
CHECK_EQ(index->ctx.device_type, kDLGPU);
CHECK_EQ(index->ctx.device_type, kDGLCUDA);
for (int d = 1; d < array->ndim; ++d) {
num_feat *= array->shape[d];
@@ -85,8 +85,8 @@ void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
std::vector<int64_t> shape{len};
CHECK(dest.IsPinned());
CHECK_EQ(index->ctx.device_type, kDLGPU);
CHECK_EQ(source->ctx.device_type, kDLGPU);
CHECK_EQ(index->ctx.device_type, kDGLCUDA);
CHECK_EQ(source->ctx.device_type, kDGLCUDA);
for (int d = 1; d < source->ndim; ++d) {
num_feat *= source->shape[d];

View File

@@ -15,7 +15,7 @@ namespace array {
using namespace dgl::runtime;
template<DLDeviceType XPU, typename IdType>
template<DGLDeviceType XPU, typename IdType>
FilterRef CreateSetFilter(IdArray set);
DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
@@ -23,10 +23,10 @@ DGL_REGISTER_GLOBAL("utils.filter._CAPI_DGLFilterCreateFromSet")
IdArray array = args[0];
auto ctx = array->ctx;
// TODO(nv-dlasalle): Implement CPU version.
if (ctx.device_type == kDLGPU) {
if (ctx.device_type == kDGLCUDA) {
#ifdef DGL_USE_CUDA
ATEN_ID_TYPE_SWITCH(array->dtype, IdType, {
*rv = CreateSetFilter<kDLGPU, IdType>(array);
*rv = CreateSetFilter<kDGLCUDA, IdType>(array);
});
#else
LOG(FATAL) << "GPU support not compiled.";

View File

@@ -8,6 +8,7 @@
#ifdef USE_TVM
#include <featgraph.h>
#include <dgl/runtime/dlpack_convert.h>
#endif // USE_TVM
#include "kernel_decl.h"
@@ -70,7 +71,7 @@ void SegmentMM(const NDArray A,
}
CHECK_EQ(B->shape[0], seglen_A.NumElements())
<< "segment_mm expects len(seglen_A) == B.shape[0]";
CHECK_EQ(seglen_A->ctx.device_type, kDLCPU)
CHECK_EQ(seglen_A->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen_A to be on CPU.";
CHECK(A->ctx == B->ctx) << "segment_mm expects A and B to be of the same device";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMM", {
@@ -89,7 +90,7 @@ void SegmentMMBackwardB(const NDArray A,
CHECK_EQ(A->ndim, 2) << "segment_mm_backward operator expects a 2D tensor for the first input.";
CHECK_EQ(dC->ndim, 2)
<< "segment_mm_backward operator expects a 2D tensor for the second input.";
CHECK_EQ(seglen->ctx.device_type, kDLCPU)
CHECK_EQ(seglen->ctx.device_type, kDGLCPU)
<< "segment_mm expects seglen to be on CPU.";
ATEN_XPU_SWITCH_CUDA(A->ctx.device_type, XPU, "SegmentMMBackwardB", {
ATEN_ID_TYPE_SWITCH(seglen->dtype, IdType, {
@@ -829,8 +830,12 @@ DGL_REGISTER_GLOBAL("sparse._CAPI_FG_SDDMMTreeReduction")
// {lhs, rhs, out},
// {"U_data", "E_data", "V_data"});
COOMatrix coo = graph.sptr()->GetCOOMatrix(0);
dgl::featgraph::SDDMMTreeReduction(coo.row.ToDLPack(), coo.col.ToDLPack(),
lhs.ToDLPack(), rhs.ToDLPack(), out.ToDLPack());
dgl::featgraph::SDDMMTreeReduction(
DLPackConvert::ToDLPack(coo.row),
DLPackConvert::ToDLPack(coo.col),
DLPackConvert::ToDLPack(lhs),
DLPackConvert::ToDLPack(rhs),
DLPackConvert::ToDLPack(out));
});
#endif // USE_TVM

View File

@@ -16,7 +16,7 @@ namespace aten {
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
CHECK(array.IsPinned()) << "Input array must be in pinned memory.";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU.";
CHECK_GE(array->ndim, 1) << "Input array must have at least 1 dimension.";
CHECK_EQ(index->ndim, 1) << "Index must be a 1D array.";
@@ -34,8 +34,8 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
#ifdef DGL_USE_CUDA
CHECK(dest.IsPinned()) << "Destination array must be in pinned memory.";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Index must be on the GPU.";
CHECK_EQ(source->ctx.device_type, kDLGPU) << "Source array must be on the GPU.";
CHECK_EQ(index->ctx.device_type, kDGLCUDA) << "Index must be on the GPU.";
CHECK_EQ(source->ctx.device_type, kDGLCUDA) << "Source array must be on the GPU.";
CHECK_EQ(dest->dtype, source->dtype) << "Destination array and source "
"array must have the same dtype.";
CHECK_GE(dest->ndim, 1) << "Destination array must have at least 1 dimension.";

View File

@@ -41,8 +41,8 @@ dgl::runtime::NDArray CopyVectorToNDArray(
const std::vector<DType>& vec) {
using dgl::runtime::NDArray;
const int64_t len = vec.size();
NDArray a = NDArray::Empty({len}, DLDataType{kDLInt, sizeof(IdType) * 8, 1},
DLContext{kDLCPU, 0});
NDArray a = NDArray::Empty({len}, DGLDataType{kDGLInt, sizeof(IdType) * 8, 1},
DGLContext{kDGLCPU, 0});
std::copy(vec.begin(), vec.end(), static_cast<IdType*>(a->data));
return a;
}

View File

@@ -50,7 +50,7 @@ template void GroupIndexShuffle<int64_t>(
template <typename IdType>
IdArray RandomPerm(int64_t num_nodes) {
IdArray perm = aten::NewIdArray(num_nodes, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* perm_data = static_cast<IdType*>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0);
IndexShuffle(perm_data, num_nodes);
@@ -59,7 +59,7 @@ IdArray RandomPerm(int64_t num_nodes) {
template <typename IdType>
IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_t num_nodes) {
IdArray perm = aten::NewIdArray(num_nodes, DLContext{kDLCPU, 0}, sizeof(IdType) * 8);
IdArray perm = aten::NewIdArray(num_nodes, DGLContext{kDGLCPU, 0}, sizeof(IdType) * 8);
IdType* perm_data = static_cast<IdType*>(perm->data);
std::iota(perm_data, perm_data + num_nodes, 0);
GroupIndexShuffle(group_idxs, perm_data, num_group_idxs, num_nodes);
@@ -77,7 +77,7 @@ IdArray GroupRandomPerm(const IdType *group_idxs, int64_t num_group_idxs, int64_
* Finally, we pick the point with the maximum such distance.
* This process will be repeated for ``sample_points`` - 1 times.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result) {
const FloatType* array_data = static_cast<FloatType*>(array->data);
@@ -135,20 +135,20 @@ void FarthestPointSampler(NDArray array, int64_t batch_size, int64_t sample_poin
ret_start += sample_points;
}
}
template void FarthestPointSampler<kDLCPU, float, int32_t>(
template void FarthestPointSampler<kDGLCPU, float, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, float, int64_t>(
template void FarthestPointSampler<kDGLCPU, float, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int32_t>(
template void FarthestPointSampler<kDGLCPU, double, int32_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template void FarthestPointSampler<kDLCPU, double, int64_t>(
template void FarthestPointSampler<kDGLCPU, double, int64_t>(
NDArray array, int64_t batch_size, int64_t sample_points,
NDArray dist, IdArray start_idx, IdArray result);
template <DLDeviceType XPU, typename FloatType, typename IdType>
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -181,16 +181,16 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
result_data[v_max] = result_data[u];
}
}
template void WeightedNeighborMatching<kDLCPU, float, int32_t>(
template void WeightedNeighborMatching<kDGLCPU, float, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLCPU, float, int64_t>(
template void WeightedNeighborMatching<kDGLCPU, float, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLCPU, double, int32_t>(
template void WeightedNeighborMatching<kDGLCPU, double, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLCPU, double, int64_t>(
template void WeightedNeighborMatching<kDGLCPU, double, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_nodes = result->shape[0];
const IdType *indptr_data = static_cast<IdType*>(csr.indptr->data);
@@ -221,8 +221,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
}
}
}
template void NeighborMatching<kDLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
} // namespace impl
} // namespace geometry

View File

@@ -150,7 +150,7 @@ bool Colorize(IdType * result_data, int64_t num_nodes, float * const prop) {
* are marked, mark this node with its id. Else match this (BLUE, RED) node
* pair and mark them with the smaller id between them.
*/
template <DLDeviceType XPU, typename FloatType, typename IdType>
template <DGLDeviceType XPU, typename FloatType, typename IdType>
void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight, IdArray result) {
cudaStream_t stream = runtime::getCurrentCUDAStream();
const auto& ctx = result->ctx;
@@ -182,13 +182,13 @@ void WeightedNeighborMatching(const aten::CSRMatrix &csr, const NDArray weight,
}
device->FreeWorkspace(ctx, prop);
}
template void WeightedNeighborMatching<kDLGPU, float, int32_t>(
template void WeightedNeighborMatching<kDGLCUDA, float, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, float, int64_t>(
template void WeightedNeighborMatching<kDGLCUDA, float, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int32_t>(
template void WeightedNeighborMatching<kDGLCUDA, double, int32_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
template void WeightedNeighborMatching<kDGLCUDA, double, int64_t>(
const aten::CSRMatrix &csr, const NDArray weight, IdArray result);
/*! \brief Unweighted neighbor matching procedure (GPU version).
@@ -201,7 +201,7 @@ template void WeightedNeighborMatching<kDLGPU, double, int64_t>(
* 2. Graph is sparse, thus neighborhood of each node is small,
* which is suitable for GPU implementation.
*/
template <DLDeviceType XPU, typename IdType>
template <DGLDeviceType XPU, typename IdType>
void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
const int64_t num_edges = csr.indices->shape[0];
const auto& ctx = result->ctx;
@@ -211,7 +211,7 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
// generate random weights
cudaStream_t stream = runtime::getCurrentCUDAStream();
NDArray weight = NDArray::Empty(
{num_edges}, DLDataType{kDLFloat, sizeof(float) * 8, 1}, ctx);
{num_edges}, DGLDataType{kDGLFloat, sizeof(float) * 8, 1}, ctx);
float *weight_data = static_cast<float*>(weight->data);
uint64_t seed = dgl::RandomEngine::ThreadLocal()->RandInt(UINT64_MAX);
auto num_threads = cuda::FindNumThreads(num_edges);
@@ -221,8 +221,8 @@ void NeighborMatching(const aten::CSRMatrix &csr, IdArray result) {
WeightedNeighborMatching<XPU, float, IdType>(csr, weight, result);
}
template void NeighborMatching<kDLGPU, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDLGPU, int64_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCUDA, int32_t>(const aten::CSRMatrix &csr, IdArray result);
template void NeighborMatching<kDGLCUDA, int64_t>(const aten::CSRMatrix &csr, IdArray result);
} // namespace impl
} // namespace geometry

Some files were not shown because too many files have changed in this diff Show More