[GraphBolt] Async feature fetch refactor (#7540)

This commit is contained in:
Muhammed Fatih BALIN
2024-07-22 06:59:08 -04:00
committed by GitHub
parent 2074cbf556
commit 5b4635a3e6
4 changed files with 45 additions and 20 deletions

View File

@@ -40,6 +40,7 @@ class Feature:
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
@@ -52,21 +53,25 @@ class Feature:
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
--------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
raise NotImplementedError
def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device

View File

@@ -83,6 +83,7 @@ class CPUCachedFeature(Feature):
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
@@ -95,14 +96,15 @@ class CPUCachedFeature(Feature):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
--------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
policy = self._feature._policy
@@ -309,7 +311,10 @@ class CPUCachedFeature(Feature):
def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device

View File

@@ -90,6 +90,7 @@ class GPUCachedFeature(Feature):
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
@@ -102,14 +103,15 @@ class GPUCachedFeature(Feature):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
--------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
values, missing_index, missing_keys = self._feature.query(ids)
@@ -136,7 +138,10 @@ class GPUCachedFeature(Feature):
def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device

View File

@@ -127,6 +127,7 @@ class TorchBasedFeature(Feature):
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
@@ -139,14 +140,15 @@ class TorchBasedFeature(Feature):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
--------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
assert self._tensor.device.type == "cpu"
@@ -206,7 +208,10 @@ class TorchBasedFeature(Feature):
def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device
@@ -408,6 +413,7 @@ class DiskBasedFeature(Feature):
def read_async(self, ids: torch.Tensor):
"""Read the feature by index asynchronously.
Parameters
----------
ids : torch.Tensor
@@ -420,14 +426,15 @@ class DiskBasedFeature(Feature):
`read_async_num_stages(ids.device)`th invocation. The return result
can be accessed by calling `.wait()` on the returned future object.
It is undefined behavior to call `.wait()` more than once.
Example Usage
--------
>>> import dgl.graphbolt as gb
>>> feature = gb.Feature(...)
>>> ids = torch.tensor([0, 2])
>>> async_handle = feature.read_async(ids)
>>> for _ in range(feature.read_async_num_stages(ids.device)):
... future = next(async_handle)
>>> for stage, future in enumerate(feature.read_async(ids)):
... pass
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
>>> result = future.wait() # result contains the read values.
"""
assert torch.ops.graphbolt.detect_io_uring()
@@ -468,7 +475,10 @@ class DiskBasedFeature(Feature):
def read_async_num_stages(self, ids_device: torch.device):
"""The number of stages of the read_async operation. See read_async
function for directions on its use.
function for directions on its use. This function is required to return
the number of yield operations when read_async is used with a tensor
residing on ids_device.
Parameters
----------
ids_device : torch.device