[GraphBolt] Use query_and_then_replace_async in read_async. (#7626)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-31 11:02:24 -04:00
committed by GitHub
parent f37741c9ad
commit 65f85b50bb

View File

@@ -121,32 +121,35 @@ class CPUCachedFeature(Feature):
yield # first stage is done.
ids_copy_event.synchronize()
policy_future = policy.query_async(ids)
policy_future = policy.query_and_then_replace_async(ids)
yield
(
positions,
index,
pointers,
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
found_cnt = ids.size(0) - missing_keys.size(0)
found_positions = positions[:found_cnt]
missing_positions = positions[found_cnt:]
found_pointers = pointers[:found_cnt]
missing_pointers = pointers[found_cnt:]
host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
positions_cuda = positions.to(ids_device, non_blocking=True)
values_from_cpu = cache.index_select(positions_cuda)
found_positions = found_positions.to(
ids_device, non_blocking=True
)
values_from_cpu = cache.index_select(found_positions)
values_from_cpu.record_stream(current_stream)
values_from_cpu_copy_event = torch.cuda.Event()
values_from_cpu_copy_event.record()
positions_future = policy.replace_async(
missing_keys, missing_offsets
)
fallback_reader = self._fallback_feature.read_async(missing_keys)
for _ in range(
self._fallback_feature.read_async_num_stages(
@@ -162,17 +165,18 @@ class CPUCachedFeature(Feature):
)
missing_values = missing_values_future.wait()
positions, pointers, offsets = positions_future.wait()
replace_future = cache.replace_async(positions, missing_values)
replace_future = cache.replace_async(
missing_positions, missing_values
)
host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
index = index.to(ids_device, non_blocking=True)
missing_values_cuda = missing_values.to(
missing_values = missing_values.to(
ids_device, non_blocking=True
)
index.record_stream(current_stream)
missing_values_cuda.record_stream(current_stream)
missing_values.record_stream(current_stream)
missing_values_copy_event = torch.cuda.Event()
missing_values_copy_event.record()
@@ -180,8 +184,9 @@ class CPUCachedFeature(Feature):
reading_completed.wait()
replace_future.wait()
missing_values_copy_event.wait()
writing_completed = policy.writing_completed_async(
pointers, offsets
missing_pointers, missing_offsets
)
class _Waiter:
@@ -211,9 +216,9 @@ class CPUCachedFeature(Feature):
return values
yield _Waiter(
[missing_values_copy_event, writing_completed],
[writing_completed],
values_from_cpu,
missing_values_cuda,
missing_values,
index,
)
elif ids.is_cuda:
@@ -230,24 +235,27 @@ class CPUCachedFeature(Feature):
yield # first stage is done.
ids_copy_event.synchronize()
policy_future = policy.query_async(ids)
policy_future = policy.query_and_then_replace_async(ids)
yield
(
positions,
index,
pointers,
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])
positions_future = policy.replace_async(
missing_keys, missing_offsets
found_cnt = ids.size(0) - missing_keys.size(0)
found_positions = positions[:found_cnt]
missing_positions = positions[found_cnt:]
found_pointers = pointers[:found_cnt]
missing_pointers = pointers[found_cnt:]
values_future = cache.query_async(
found_positions, index, ids.shape[0]
)
fallback_reader = self._fallback_feature.read_async(missing_keys)
@@ -264,11 +272,12 @@ class CPUCachedFeature(Feature):
found_pointers, found_offsets
)
missing_index = index[positions.size(0) :]
missing_index = index[found_cnt:]
missing_values = missing_values_future.wait()
positions, pointers, offsets = positions_future.wait()
replace_future = cache.replace_async(positions, missing_values)
replace_future = cache.replace_async(
missing_positions, missing_values
)
values = torch.ops.graphbolt.scatter_async(
values, missing_index, missing_values
)
@@ -277,15 +286,15 @@ class CPUCachedFeature(Feature):
host_to_device_stream = get_host_to_device_uva_stream()
with torch.cuda.stream(host_to_device_stream):
values_cuda = values.wait().to(ids_device, non_blocking=True)
values_cuda.record_stream(current_stream)
values = values.wait().to(ids_device, non_blocking=True)
values.record_stream(current_stream)
values_copy_event = torch.cuda.Event()
values_copy_event.record()
reading_completed.wait()
replace_future.wait()
writing_completed = policy.writing_completed_async(
pointers, offsets
missing_pointers, missing_offsets
)
class _Waiter:
@@ -302,26 +311,29 @@ class CPUCachedFeature(Feature):
self.events = self.values = None
return values
yield _Waiter([values_copy_event, writing_completed], values_cuda)
yield _Waiter([values_copy_event, writing_completed], values)
else:
policy_future = policy.query_async(ids)
policy_future = policy.query_and_then_replace_async(ids)
yield
(
positions,
index,
pointers,
missing_keys,
found_pointers,
found_offsets,
missing_offsets,
) = policy_future.wait()
self._feature.total_queries += ids.shape[0]
self._feature.total_miss += missing_keys.shape[0]
values_future = cache.query_async(positions, index, ids.shape[0])
positions_future = policy.replace_async(
missing_keys, missing_offsets
found_cnt = ids.size(0) - missing_keys.size(0)
found_positions = positions[:found_cnt]
missing_positions = positions[found_cnt:]
found_pointers = pointers[:found_cnt]
missing_pointers = pointers[found_cnt:]
values_future = cache.query_async(
found_positions, index, ids.shape[0]
)
fallback_reader = self._fallback_feature.read_async(missing_keys)
@@ -338,11 +350,12 @@ class CPUCachedFeature(Feature):
found_pointers, found_offsets
)
missing_index = index[positions.size(0) :]
missing_index = index[found_cnt:]
missing_values = missing_values_future.wait()
positions, pointers, offsets = positions_future.wait()
replace_future = cache.replace_async(positions, missing_values)
replace_future = cache.replace_async(
missing_positions, missing_values
)
values = torch.ops.graphbolt.scatter_async(
values, missing_index, missing_values
)
@@ -352,7 +365,7 @@ class CPUCachedFeature(Feature):
reading_completed.wait()
replace_future.wait()
writing_completed = policy.writing_completed_async(
pointers, offsets
missing_pointers, missing_offsets
)
class _Waiter: