mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt] Feature.count(). (#7730)
This commit is contained in:
committed by
GitHub
parent
0e649fc68c
commit
8bdcd7eeea
@@ -93,6 +93,16 @@ class Feature:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def count(self):
|
||||
"""Get the count of the feature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the feature.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
|
||||
"""Update the feature.
|
||||
|
||||
@@ -194,6 +204,29 @@ class FeatureStore:
|
||||
"""
|
||||
return self.__getitem__((domain, type_name, feature_name)).size()
|
||||
|
||||
def count(
|
||||
self,
|
||||
domain: str,
|
||||
type_name: str,
|
||||
feature_name: str,
|
||||
):
|
||||
"""Get the count the specified feature in the feature store.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
domain : str
|
||||
The domain of the feature such as "node", "edge" or "graph".
|
||||
type_name : str
|
||||
The node or edge type name.
|
||||
feature_name : str
|
||||
The feature name.
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the specified feature in the feature store.
|
||||
"""
|
||||
return self.__getitem__((domain, type_name, feature_name)).count()
|
||||
|
||||
def metadata(
|
||||
self,
|
||||
domain: str,
|
||||
|
||||
@@ -422,6 +422,16 @@ class CPUCachedFeature(Feature):
|
||||
"""
|
||||
return self._fallback_feature.size()
|
||||
|
||||
def count(self):
|
||||
"""Get the count of the feature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the feature.
|
||||
"""
|
||||
return self._fallback_feature.count()
|
||||
|
||||
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
|
||||
"""Update the feature.
|
||||
|
||||
|
||||
@@ -203,6 +203,16 @@ class GPUCachedFeature(Feature):
|
||||
"""
|
||||
return self._fallback_feature.size()
|
||||
|
||||
def count(self):
|
||||
"""Get the count of the feature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the feature.
|
||||
"""
|
||||
return self._fallback_feature.count()
|
||||
|
||||
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
|
||||
"""Update the feature.
|
||||
|
||||
|
||||
@@ -239,6 +239,16 @@ class TorchBasedFeature(Feature):
|
||||
"""
|
||||
return self._tensor.size()[1:]
|
||||
|
||||
def count(self):
|
||||
"""Get the count of the feature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the feature.
|
||||
"""
|
||||
return self._tensor.size()[0]
|
||||
|
||||
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
|
||||
"""Update the feature store.
|
||||
|
||||
@@ -493,6 +503,16 @@ class DiskBasedFeature(Feature):
|
||||
"""
|
||||
return self._tensor.size()[1:]
|
||||
|
||||
def count(self):
|
||||
"""Get the count of the feature.
|
||||
|
||||
Returns
|
||||
-------
|
||||
int
|
||||
The count of the feature.
|
||||
"""
|
||||
return self._tensor.size()[0]
|
||||
|
||||
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
|
||||
"""Disk based feature does not support update for now."""
|
||||
raise NotImplementedError
|
||||
|
||||
Reference in New Issue
Block a user