mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt][CUDA] Enable non_blocking copy_to in gb.DataLoader. (#7603)
This commit is contained in:
committed by
GitHub
parent
a004a2535b
commit
ee20edb83d
@@ -319,8 +319,6 @@ def apply_to(x, device, non_blocking=False):
|
||||
return x
|
||||
if not non_blocking:
|
||||
return x.to(device)
|
||||
# The copy is non blocking only if the objects are pinned.
|
||||
assert x.is_pinned(), f"{x} should be pinned."
|
||||
return x.to(device, non_blocking=True)
|
||||
|
||||
|
||||
@@ -373,6 +371,9 @@ class CopyTo(IterDataPipe):
|
||||
|
||||
def __iter__(self):
|
||||
for data in self.datapipe:
|
||||
if self.non_blocking:
|
||||
# The copy is non blocking only if contents of data are pinned.
|
||||
assert data.is_pinned(), f"{data} should be pinned."
|
||||
yield recursive_apply(
|
||||
data, apply_to, self.device, self.non_blocking
|
||||
)
|
||||
|
||||
@@ -224,14 +224,23 @@ class DataLoader(torch.utils.data.DataLoader):
|
||||
),
|
||||
)
|
||||
|
||||
# (4) Cut datapipe at CopyTo and wrap with prefetcher. This enables the
|
||||
# data pipeline up to the CopyTo operation to run in a separate thread.
|
||||
datapipe_graph = _find_and_wrap_parent(
|
||||
datapipe_graph,
|
||||
CopyTo,
|
||||
dp.iter.Prefetcher,
|
||||
buffer_size=2,
|
||||
)
|
||||
# (4) Cut datapipe at CopyTo and wrap with pinning and prefetching
|
||||
# before it. This enables non_blocking copies to the device.
|
||||
# Prefetching enables the data pipeline up to the CopyTo to run in a
|
||||
# separate thread.
|
||||
if torch.cuda.is_available():
|
||||
copiers = dp_utils.find_dps(datapipe_graph, CopyTo)
|
||||
for copier in copiers:
|
||||
if copier.device.type == "cuda":
|
||||
datapipe_graph = dp_utils.replace_dp(
|
||||
datapipe_graph,
|
||||
copier,
|
||||
copier.datapipe.transform(
|
||||
lambda x: x.pin_memory()
|
||||
).prefetch(2)
|
||||
# After the data gets pinned, we can copy non_blocking.
|
||||
.copy_to(copier.device, non_blocking=True),
|
||||
)
|
||||
|
||||
# The stages after feature fetching are still done in the main process.
|
||||
# So we set num_workers to 0 here.
|
||||
|
||||
Reference in New Issue
Block a user