[GraphBolt] Feature.count(). (#7730)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-21 11:27:33 -04:00
committed by GitHub
parent 0e649fc68c
commit 8bdcd7eeea
9 changed files with 88 additions and 5 deletions

View File

@@ -93,6 +93,16 @@ class Feature:
"""
raise NotImplementedError
def count(self):
"""Get the count of the feature.
Returns
-------
int
The count of the feature.
"""
raise NotImplementedError
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature.
@@ -194,6 +204,29 @@ class FeatureStore:
"""
return self.__getitem__((domain, type_name, feature_name)).size()
def count(
self,
domain: str,
type_name: str,
feature_name: str,
):
"""Get the count the specified feature in the feature store.
Parameters
----------
domain : str
The domain of the feature such as "node", "edge" or "graph".
type_name : str
The node or edge type name.
feature_name : str
The feature name.
Returns
-------
int
The count of the specified feature in the feature store.
"""
return self.__getitem__((domain, type_name, feature_name)).count()
def metadata(
self,
domain: str,

View File

@@ -422,6 +422,16 @@ class CPUCachedFeature(Feature):
"""
return self._fallback_feature.size()
def count(self):
"""Get the count of the feature.
Returns
-------
int
The count of the feature.
"""
return self._fallback_feature.count()
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature.

View File

@@ -203,6 +203,16 @@ class GPUCachedFeature(Feature):
"""
return self._fallback_feature.size()
def count(self):
"""Get the count of the feature.
Returns
-------
int
The count of the feature.
"""
return self._fallback_feature.count()
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature.

View File

@@ -239,6 +239,16 @@ class TorchBasedFeature(Feature):
"""
return self._tensor.size()[1:]
def count(self):
"""Get the count of the feature.
Returns
-------
int
The count of the feature.
"""
return self._tensor.size()[0]
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Update the feature store.
@@ -493,6 +503,16 @@ class DiskBasedFeature(Feature):
"""
return self._tensor.size()[1:]
def count(self):
"""Get the count of the feature.
Returns
-------
int
The count of the feature.
"""
return self._tensor.size()[0]
def update(self, value: torch.Tensor, ids: torch.Tensor = None):
"""Disk based feature does not support update for now."""
raise NotImplementedError