mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt] Async feature fetch refactor (#7540)
This commit is contained in:
committed by
GitHub
parent
2074cbf556
commit
5b4635a3e6
@@ -40,6 +40,7 @@ class Feature:
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids : torch.Tensor
|
||||
@@ -52,21 +53,25 @@ class Feature:
|
||||
`read_async_num_stages(ids.device)`th invocation. The return result
|
||||
can be accessed by calling `.wait()`. on the returned future object.
|
||||
It is undefined behavior to call `.wait()` more than once.
|
||||
|
||||
Example Usage
|
||||
--------
|
||||
>>> import dgl.graphbolt as gb
|
||||
>>> feature = gb.Feature(...)
|
||||
>>> ids = torch.tensor([0, 2])
|
||||
>>> async_handle = feature.read_async(ids)
|
||||
>>> for _ in range(feature.read_async_num_stages(ids.device)):
|
||||
... future = next(async_handle)
|
||||
>>> for stage, future in enumerate(feature.read_async(ids)):
|
||||
... pass
|
||||
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
|
||||
>>> result = future.wait() # result contains the read values.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def read_async_num_stages(self, ids_device: torch.device):
|
||||
"""The number of stages of the read_async operation. See read_async
|
||||
function for directions on its use.
|
||||
function for directions on its use. This function is required to return
|
||||
the number of yield operations when read_async is used with a tensor
|
||||
residing on ids_device.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids_device : torch.device
|
||||
|
||||
@@ -83,6 +83,7 @@ class CPUCachedFeature(Feature):
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids : torch.Tensor
|
||||
@@ -95,14 +96,15 @@ class CPUCachedFeature(Feature):
|
||||
`read_async_num_stages(ids.device)`th invocation. The return result
|
||||
can be accessed by calling `.wait()`. on the returned future object.
|
||||
It is undefined behavior to call `.wait()` more than once.
|
||||
|
||||
Example Usage
|
||||
--------
|
||||
>>> import dgl.graphbolt as gb
|
||||
>>> feature = gb.Feature(...)
|
||||
>>> ids = torch.tensor([0, 2])
|
||||
>>> async_handle = feature.read_async(ids)
|
||||
>>> for _ in range(feature.read_async_num_stages(ids.device)):
|
||||
... future = next(async_handle)
|
||||
>>> for stage, future in enumerate(feature.read_async(ids)):
|
||||
... pass
|
||||
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
|
||||
>>> result = future.wait() # result contains the read values.
|
||||
"""
|
||||
policy = self._feature._policy
|
||||
@@ -309,7 +311,10 @@ class CPUCachedFeature(Feature):
|
||||
|
||||
def read_async_num_stages(self, ids_device: torch.device):
|
||||
"""The number of stages of the read_async operation. See read_async
|
||||
function for directions on its use.
|
||||
function for directions on its use. This function is required to return
|
||||
the number of yield operations when read_async is used with a tensor
|
||||
residing on ids_device.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids_device : torch.device
|
||||
|
||||
@@ -90,6 +90,7 @@ class GPUCachedFeature(Feature):
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids : torch.Tensor
|
||||
@@ -102,14 +103,15 @@ class GPUCachedFeature(Feature):
|
||||
`read_async_num_stages(ids.device)`th invocation. The return result
|
||||
can be accessed by calling `.wait()`. on the returned future object.
|
||||
It is undefined behavior to call `.wait()` more than once.
|
||||
|
||||
Example Usage
|
||||
--------
|
||||
>>> import dgl.graphbolt as gb
|
||||
>>> feature = gb.Feature(...)
|
||||
>>> ids = torch.tensor([0, 2])
|
||||
>>> async_handle = feature.read_async(ids)
|
||||
>>> for _ in range(feature.read_async_num_stages(ids.device)):
|
||||
... future = next(async_handle)
|
||||
>>> for stage, future in enumerate(feature.read_async(ids)):
|
||||
... pass
|
||||
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
|
||||
>>> result = future.wait() # result contains the read values.
|
||||
"""
|
||||
values, missing_index, missing_keys = self._feature.query(ids)
|
||||
@@ -136,7 +138,10 @@ class GPUCachedFeature(Feature):
|
||||
|
||||
def read_async_num_stages(self, ids_device: torch.device):
|
||||
"""The number of stages of the read_async operation. See read_async
|
||||
function for directions on its use.
|
||||
function for directions on its use. This function is required to return
|
||||
the number of yield operations when read_async is used with a tensor
|
||||
residing on ids_device.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids_device : torch.device
|
||||
|
||||
@@ -127,6 +127,7 @@ class TorchBasedFeature(Feature):
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids : torch.Tensor
|
||||
@@ -139,14 +140,15 @@ class TorchBasedFeature(Feature):
|
||||
`read_async_num_stages(ids.device)`th invocation. The return result
|
||||
can be accessed by calling `.wait()`. on the returned future object.
|
||||
It is undefined behavior to call `.wait()` more than once.
|
||||
|
||||
Example Usage
|
||||
--------
|
||||
>>> import dgl.graphbolt as gb
|
||||
>>> feature = gb.Feature(...)
|
||||
>>> ids = torch.tensor([0, 2])
|
||||
>>> async_handle = feature.read_async(ids)
|
||||
>>> for _ in range(feature.read_async_num_stages(ids.device)):
|
||||
... future = next(async_handle)
|
||||
>>> for stage, future in enumerate(feature.read_async(ids)):
|
||||
... pass
|
||||
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
|
||||
>>> result = future.wait() # result contains the read values.
|
||||
"""
|
||||
assert self._tensor.device.type == "cpu"
|
||||
@@ -206,7 +208,10 @@ class TorchBasedFeature(Feature):
|
||||
|
||||
def read_async_num_stages(self, ids_device: torch.device):
|
||||
"""The number of stages of the read_async operation. See read_async
|
||||
function for directions on its use.
|
||||
function for directions on its use. This function is required to return
|
||||
the number of yield operations when read_async is used with a tensor
|
||||
residing on ids_device.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids_device : torch.device
|
||||
@@ -408,6 +413,7 @@ class DiskBasedFeature(Feature):
|
||||
|
||||
def read_async(self, ids: torch.Tensor):
|
||||
"""Read the feature by index asynchronously.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids : torch.Tensor
|
||||
@@ -420,14 +426,15 @@ class DiskBasedFeature(Feature):
|
||||
`read_async_num_stages(ids.device)`th invocation. The return result
|
||||
can be accessed by calling `.wait()`. on the returned future object.
|
||||
It is undefined behavior to call `.wait()` more than once.
|
||||
|
||||
Example Usage
|
||||
--------
|
||||
>>> import dgl.graphbolt as gb
|
||||
>>> feature = gb.Feature(...)
|
||||
>>> ids = torch.tensor([0, 2])
|
||||
>>> async_handle = feature.read_async(ids)
|
||||
>>> for _ in range(feature.read_async_num_stages(ids.device)):
|
||||
... future = next(async_handle)
|
||||
>>> for stage, future in enumerate(feature.read_async(ids)):
|
||||
... pass
|
||||
>>> assert stage + 1 == feature.read_async_num_stages(ids.device)
|
||||
>>> result = future.wait() # result contains the read values.
|
||||
"""
|
||||
assert torch.ops.graphbolt.detect_io_uring()
|
||||
@@ -468,7 +475,10 @@ class DiskBasedFeature(Feature):
|
||||
|
||||
def read_async_num_stages(self, ids_device: torch.device):
|
||||
"""The number of stages of the read_async operation. See read_async
|
||||
function for directions on its use.
|
||||
function for directions on its use. This function is required to return
|
||||
the number of yield operations when read_async is used with a tensor
|
||||
residing on ids_device.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ids_device : torch.device
|
||||
|
||||
Reference in New Issue
Block a user