mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[Misc.] Avoid calling IsPinned in the coo/csr constructor from every sampling process (#6568)
This commit is contained in:
@@ -64,11 +64,6 @@ struct COOMatrix {
|
||||
data(darr),
|
||||
row_sorted(rsorted),
|
||||
col_sorted(csorted) {
|
||||
if (!IsEmpty()) {
|
||||
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
|
||||
(aten::IsNullArray(col) || col.IsPinned()) &&
|
||||
(aten::IsNullArray(data) || data.IsPinned());
|
||||
}
|
||||
CheckValidity();
|
||||
}
|
||||
|
||||
@@ -134,6 +129,15 @@ struct COOMatrix {
|
||||
aten::IsNullArray(data);
|
||||
}
|
||||
|
||||
// Check and update the internal flag is_pinned.
|
||||
// This function will initialize a cuda context.
|
||||
inline bool CheckIfPinnedInCUDA() {
|
||||
is_pinned = (aten::IsNullArray(row) || row.IsPinned()) &&
|
||||
(aten::IsNullArray(col) || col.IsPinned()) &&
|
||||
(aten::IsNullArray(data) || data.IsPinned());
|
||||
return is_pinned;
|
||||
}
|
||||
|
||||
/** @brief Return a copy of this matrix on the give device context. */
|
||||
inline COOMatrix CopyTo(const DGLContext& ctx) const {
|
||||
if (ctx == row->ctx) return *this;
|
||||
@@ -151,7 +155,7 @@ struct COOMatrix {
|
||||
num_rows, num_cols, row.PinMemory(), col.PinMemory(),
|
||||
aten::IsNullArray(data) ? data : data.PinMemory(), row_sorted,
|
||||
col_sorted);
|
||||
CHECK(new_coo.is_pinned)
|
||||
CHECK(new_coo.CheckIfPinnedInCUDA())
|
||||
<< "An internal DGL error has occured while trying to pin a COO "
|
||||
"matrix. Please file a bug at "
|
||||
"'https://github.com/dmlc/dgl/issues' "
|
||||
|
||||
@@ -60,11 +60,6 @@ struct CSRMatrix {
|
||||
indices(iarr),
|
||||
data(darr),
|
||||
sorted(sorted_flag) {
|
||||
if (!IsEmpty()) {
|
||||
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
|
||||
(aten::IsNullArray(indices) || indices.IsPinned()) &&
|
||||
(aten::IsNullArray(data) || data.IsPinned());
|
||||
}
|
||||
CheckValidity();
|
||||
}
|
||||
|
||||
@@ -128,6 +123,15 @@ struct CSRMatrix {
|
||||
aten::IsNullArray(data);
|
||||
}
|
||||
|
||||
// Check and update the internal flag is_pinned.
|
||||
// This function will initialize a cuda context.
|
||||
inline bool CheckIfPinnedInCUDA() {
|
||||
is_pinned = (aten::IsNullArray(indptr) || indptr.IsPinned()) &&
|
||||
(aten::IsNullArray(indices) || indices.IsPinned()) &&
|
||||
(aten::IsNullArray(data) || data.IsPinned());
|
||||
return is_pinned;
|
||||
}
|
||||
|
||||
/** @brief Return a copy of this matrix on the give device context. */
|
||||
inline CSRMatrix CopyTo(const DGLContext& ctx) const {
|
||||
if (ctx == indptr->ctx) return *this;
|
||||
@@ -143,7 +147,7 @@ struct CSRMatrix {
|
||||
auto new_csr = CSRMatrix(
|
||||
num_rows, num_cols, indptr.PinMemory(), indices.PinMemory(),
|
||||
aten::IsNullArray(data) ? data : data.PinMemory(), sorted);
|
||||
CHECK(new_csr.is_pinned)
|
||||
CHECK(new_csr.CheckIfPinnedInCUDA())
|
||||
<< "An internal DGL error has occured while trying to pin a CSR "
|
||||
"matrix. Please file a bug at "
|
||||
"'https://github.com/dmlc/dgl/issues' "
|
||||
|
||||
Reference in New Issue
Block a user