Mirror of https://github.com/dmlc/dgl.git (synced 2026-06-03 19:34:33 +08:00)
[GraphBolt] Use query_and_then_replace_async in read_async. (#7626)
This commit is contained in:
Committed by: GitHub; parent: f37741c9ad; commit: 65f85b50bb
@@ -121,32 +121,35 @@ class CPUCachedFeature(Feature):
|
||||
yield # first stage is done.
|
||||
|
||||
ids_copy_event.synchronize()
|
||||
policy_future = policy.query_async(ids)
|
||||
policy_future = policy.query_and_then_replace_async(ids)
|
||||
|
||||
yield
|
||||
|
||||
(
|
||||
positions,
|
||||
index,
|
||||
pointers,
|
||||
missing_keys,
|
||||
found_pointers,
|
||||
found_offsets,
|
||||
missing_offsets,
|
||||
) = policy_future.wait()
|
||||
self._feature.total_queries += ids.shape[0]
|
||||
self._feature.total_miss += missing_keys.shape[0]
|
||||
found_cnt = ids.size(0) - missing_keys.size(0)
|
||||
found_positions = positions[:found_cnt]
|
||||
missing_positions = positions[found_cnt:]
|
||||
found_pointers = pointers[:found_cnt]
|
||||
missing_pointers = pointers[found_cnt:]
|
||||
host_to_device_stream = get_host_to_device_uva_stream()
|
||||
with torch.cuda.stream(host_to_device_stream):
|
||||
positions_cuda = positions.to(ids_device, non_blocking=True)
|
||||
values_from_cpu = cache.index_select(positions_cuda)
|
||||
found_positions = found_positions.to(
|
||||
ids_device, non_blocking=True
|
||||
)
|
||||
values_from_cpu = cache.index_select(found_positions)
|
||||
values_from_cpu.record_stream(current_stream)
|
||||
values_from_cpu_copy_event = torch.cuda.Event()
|
||||
values_from_cpu_copy_event.record()
|
||||
|
||||
positions_future = policy.replace_async(
|
||||
missing_keys, missing_offsets
|
||||
)
|
||||
|
||||
fallback_reader = self._fallback_feature.read_async(missing_keys)
|
||||
for _ in range(
|
||||
self._fallback_feature.read_async_num_stages(
|
||||
@@ -162,17 +165,18 @@ class CPUCachedFeature(Feature):
|
||||
)
|
||||
|
||||
missing_values = missing_values_future.wait()
|
||||
positions, pointers, offsets = positions_future.wait()
|
||||
replace_future = cache.replace_async(positions, missing_values)
|
||||
replace_future = cache.replace_async(
|
||||
missing_positions, missing_values
|
||||
)
|
||||
|
||||
host_to_device_stream = get_host_to_device_uva_stream()
|
||||
with torch.cuda.stream(host_to_device_stream):
|
||||
index = index.to(ids_device, non_blocking=True)
|
||||
missing_values_cuda = missing_values.to(
|
||||
missing_values = missing_values.to(
|
||||
ids_device, non_blocking=True
|
||||
)
|
||||
index.record_stream(current_stream)
|
||||
missing_values_cuda.record_stream(current_stream)
|
||||
missing_values.record_stream(current_stream)
|
||||
missing_values_copy_event = torch.cuda.Event()
|
||||
missing_values_copy_event.record()
|
||||
|
||||
@@ -180,8 +184,9 @@ class CPUCachedFeature(Feature):
|
||||
|
||||
reading_completed.wait()
|
||||
replace_future.wait()
|
||||
missing_values_copy_event.wait()
|
||||
writing_completed = policy.writing_completed_async(
|
||||
pointers, offsets
|
||||
missing_pointers, missing_offsets
|
||||
)
|
||||
|
||||
class _Waiter:
|
||||
@@ -211,9 +216,9 @@ class CPUCachedFeature(Feature):
|
||||
return values
|
||||
|
||||
yield _Waiter(
|
||||
[missing_values_copy_event, writing_completed],
|
||||
[writing_completed],
|
||||
values_from_cpu,
|
||||
missing_values_cuda,
|
||||
missing_values,
|
||||
index,
|
||||
)
|
||||
elif ids.is_cuda:
|
||||
@@ -230,24 +235,27 @@ class CPUCachedFeature(Feature):
|
||||
yield # first stage is done.
|
||||
|
||||
ids_copy_event.synchronize()
|
||||
policy_future = policy.query_async(ids)
|
||||
policy_future = policy.query_and_then_replace_async(ids)
|
||||
|
||||
yield
|
||||
|
||||
(
|
||||
positions,
|
||||
index,
|
||||
pointers,
|
||||
missing_keys,
|
||||
found_pointers,
|
||||
found_offsets,
|
||||
missing_offsets,
|
||||
) = policy_future.wait()
|
||||
self._feature.total_queries += ids.shape[0]
|
||||
self._feature.total_miss += missing_keys.shape[0]
|
||||
values_future = cache.query_async(positions, index, ids.shape[0])
|
||||
|
||||
positions_future = policy.replace_async(
|
||||
missing_keys, missing_offsets
|
||||
found_cnt = ids.size(0) - missing_keys.size(0)
|
||||
found_positions = positions[:found_cnt]
|
||||
missing_positions = positions[found_cnt:]
|
||||
found_pointers = pointers[:found_cnt]
|
||||
missing_pointers = pointers[found_cnt:]
|
||||
values_future = cache.query_async(
|
||||
found_positions, index, ids.shape[0]
|
||||
)
|
||||
|
||||
fallback_reader = self._fallback_feature.read_async(missing_keys)
|
||||
@@ -264,11 +272,12 @@ class CPUCachedFeature(Feature):
|
||||
found_pointers, found_offsets
|
||||
)
|
||||
|
||||
missing_index = index[positions.size(0) :]
|
||||
missing_index = index[found_cnt:]
|
||||
|
||||
missing_values = missing_values_future.wait()
|
||||
positions, pointers, offsets = positions_future.wait()
|
||||
replace_future = cache.replace_async(positions, missing_values)
|
||||
replace_future = cache.replace_async(
|
||||
missing_positions, missing_values
|
||||
)
|
||||
values = torch.ops.graphbolt.scatter_async(
|
||||
values, missing_index, missing_values
|
||||
)
|
||||
@@ -277,15 +286,15 @@ class CPUCachedFeature(Feature):
|
||||
|
||||
host_to_device_stream = get_host_to_device_uva_stream()
|
||||
with torch.cuda.stream(host_to_device_stream):
|
||||
values_cuda = values.wait().to(ids_device, non_blocking=True)
|
||||
values_cuda.record_stream(current_stream)
|
||||
values = values.wait().to(ids_device, non_blocking=True)
|
||||
values.record_stream(current_stream)
|
||||
values_copy_event = torch.cuda.Event()
|
||||
values_copy_event.record()
|
||||
|
||||
reading_completed.wait()
|
||||
replace_future.wait()
|
||||
writing_completed = policy.writing_completed_async(
|
||||
pointers, offsets
|
||||
missing_pointers, missing_offsets
|
||||
)
|
||||
|
||||
class _Waiter:
|
||||
@@ -302,26 +311,29 @@ class CPUCachedFeature(Feature):
|
||||
self.events = self.values = None
|
||||
return values
|
||||
|
||||
yield _Waiter([values_copy_event, writing_completed], values_cuda)
|
||||
yield _Waiter([values_copy_event, writing_completed], values)
|
||||
else:
|
||||
policy_future = policy.query_async(ids)
|
||||
policy_future = policy.query_and_then_replace_async(ids)
|
||||
|
||||
yield
|
||||
|
||||
(
|
||||
positions,
|
||||
index,
|
||||
pointers,
|
||||
missing_keys,
|
||||
found_pointers,
|
||||
found_offsets,
|
||||
missing_offsets,
|
||||
) = policy_future.wait()
|
||||
self._feature.total_queries += ids.shape[0]
|
||||
self._feature.total_miss += missing_keys.shape[0]
|
||||
values_future = cache.query_async(positions, index, ids.shape[0])
|
||||
|
||||
positions_future = policy.replace_async(
|
||||
missing_keys, missing_offsets
|
||||
found_cnt = ids.size(0) - missing_keys.size(0)
|
||||
found_positions = positions[:found_cnt]
|
||||
missing_positions = positions[found_cnt:]
|
||||
found_pointers = pointers[:found_cnt]
|
||||
missing_pointers = pointers[found_cnt:]
|
||||
values_future = cache.query_async(
|
||||
found_positions, index, ids.shape[0]
|
||||
)
|
||||
|
||||
fallback_reader = self._fallback_feature.read_async(missing_keys)
|
||||
@@ -338,11 +350,12 @@ class CPUCachedFeature(Feature):
|
||||
found_pointers, found_offsets
|
||||
)
|
||||
|
||||
missing_index = index[positions.size(0) :]
|
||||
missing_index = index[found_cnt:]
|
||||
|
||||
missing_values = missing_values_future.wait()
|
||||
positions, pointers, offsets = positions_future.wait()
|
||||
replace_future = cache.replace_async(positions, missing_values)
|
||||
replace_future = cache.replace_async(
|
||||
missing_positions, missing_values
|
||||
)
|
||||
values = torch.ops.graphbolt.scatter_async(
|
||||
values, missing_index, missing_values
|
||||
)
|
||||
@@ -352,7 +365,7 @@ class CPUCachedFeature(Feature):
|
||||
reading_completed.wait()
|
||||
replace_future.wait()
|
||||
writing_completed = policy.writing_completed_async(
|
||||
pointers, offsets
|
||||
missing_pointers, missing_offsets
|
||||
)
|
||||
|
||||
class _Waiter:
|
||||
|
||||
Reference in New Issue
Block a user