diff --git a/python/dgl/graphbolt/feature_store.py b/python/dgl/graphbolt/feature_store.py index 33efbd7089..25c5f2fe93 100644 --- a/python/dgl/graphbolt/feature_store.py +++ b/python/dgl/graphbolt/feature_store.py @@ -93,6 +93,16 @@ class Feature: """ raise NotImplementedError + def count(self): + """Get the count of the feature. + + Returns + ------- + int + The count of the feature. + """ + raise NotImplementedError + def update(self, value: torch.Tensor, ids: torch.Tensor = None): """Update the feature. @@ -194,6 +204,29 @@ class FeatureStore: """ return self.__getitem__((domain, type_name, feature_name)).size() + def count( + self, + domain: str, + type_name: str, + feature_name: str, + ): + """Get the count the specified feature in the feature store. + + Parameters + ---------- + domain : str + The domain of the feature such as "node", "edge" or "graph". + type_name : str + The node or edge type name. + feature_name : str + The feature name. + Returns + ------- + int + The count of the specified feature in the feature store. + """ + return self.__getitem__((domain, type_name, feature_name)).count() + def metadata( self, domain: str, diff --git a/python/dgl/graphbolt/impl/cpu_cached_feature.py b/python/dgl/graphbolt/impl/cpu_cached_feature.py index 8fa626c2cf..96bb31fc6b 100644 --- a/python/dgl/graphbolt/impl/cpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/cpu_cached_feature.py @@ -422,6 +422,16 @@ class CPUCachedFeature(Feature): """ return self._fallback_feature.size() + def count(self): + """Get the count of the feature. + + Returns + ------- + int + The count of the feature. + """ + return self._fallback_feature.count() + def update(self, value: torch.Tensor, ids: torch.Tensor = None): """Update the feature. diff --git a/python/dgl/graphbolt/impl/gpu_cached_feature.py b/python/dgl/graphbolt/impl/gpu_cached_feature.py index 621349f4d4..c6903e2086 100644 --- a/python/dgl/graphbolt/impl/gpu_cached_feature.py +++ b/python/dgl/graphbolt/impl/gpu_cached_feature.py @@ -203,6 +203,16 @@ class GPUCachedFeature(Feature): """ return self._fallback_feature.size() + def count(self): + """Get the count of the feature. + + Returns + ------- + int + The count of the feature. + """ + return self._fallback_feature.count() + def update(self, value: torch.Tensor, ids: torch.Tensor = None): """Update the feature. diff --git a/python/dgl/graphbolt/impl/torch_based_feature_store.py b/python/dgl/graphbolt/impl/torch_based_feature_store.py index 9337c8cb4f..42d5e7a859 100644 --- a/python/dgl/graphbolt/impl/torch_based_feature_store.py +++ b/python/dgl/graphbolt/impl/torch_based_feature_store.py @@ -239,6 +239,16 @@ class TorchBasedFeature(Feature): """ return self._tensor.size()[1:] + def count(self): + """Get the count of the feature. + + Returns + ------- + int + The count of the feature. + """ + return self._tensor.size()[0] + def update(self, value: torch.Tensor, ids: torch.Tensor = None): """Update the feature store. @@ -493,6 +503,16 @@ class DiskBasedFeature(Feature): """ return self._tensor.size()[1:] + def count(self): + """Get the count of the feature. + + Returns + ------- + int + The count of the feature. + """ + return self._tensor.size()[0] + def update(self, value: torch.Tensor, ids: torch.Tensor = None): """Disk based feature does not support update for now.""" raise NotImplementedError diff --git a/tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py index d82e5a8113..261ac9d36b 100644 --- a/tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py +++ b/tests/python/pytorch/graphbolt/impl/test_basic_feature_store.py @@ -43,9 +43,11 @@ def test_basic_feature_store_homo(): torch.tensor([[[1, 2], [3, 4]]]), ) - # Test get the size of the entire feature. + # Test get the size and count of the entire feature. assert feature_store.size("node", None, "a") == torch.Size([3]) assert feature_store.size("node", None, "b") == torch.Size([2, 2]) + assert feature_store.count("node", None, "a") == a.size(0) + assert feature_store.count("node", None, "b") == b.size(0) # Test get metadata of the feature. assert feature_store.metadata("node", None, "a") == metadata diff --git a/tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py index 582d9fe939..ea8c5e9122 100644 --- a/tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_cpu_cached_feature.py @@ -81,9 +81,11 @@ def test_cpu_cached_feature(dtype, policy): assert total_miss == feat_store_b._feature.total_miss assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate - # Test get the size of the entire feature with ids. + # Test get the size and count of the entire feature. assert feat_store_a.size() == torch.Size([3]) assert feat_store_b.size() == torch.Size([2, 2]) + assert feat_store_a.count() == a.size(0) + assert feat_store_b.count() == b.size(0) # Test update the entire feature. feat_store_a.update(torch.tensor([[0, 1, 2], [3, 5, 2]], dtype=dtype)) diff --git a/tests/python/pytorch/graphbolt/impl/test_disk_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_disk_based_feature_store.py index 1fd5d3f9d3..300d98a4ef 100644 --- a/tests/python/pytorch/graphbolt/impl/test_disk_based_feature_store.py +++ b/tests/python/pytorch/graphbolt/impl/test_disk_based_feature_store.py @@ -82,9 +82,11 @@ def test_disk_based_feature(): ind_c = torch.randint(low=0, high=c.size(0), size=(4111,)) assert_equal(feature_c.read(ind_c), c[ind_c]) - # Test get the size of the entire feature. + # Test get the size and count of the entire feature. assert feature_a.size() == torch.Size([3]) assert feature_b.size() == torch.Size([2, 2]) + assert feature_a.count() == a.size(0) + assert feature_b.count() == b.size(0) # Test get metadata of the feature. assert feature_a.metadata() == metadata diff --git a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py index 9a9019ccab..4e2e2fabcd 100644 --- a/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py +++ b/tests/python/pytorch/graphbolt/impl/test_gpu_cached_feature.py @@ -85,9 +85,11 @@ def test_gpu_cached_feature(dtype, cache_size_a, cache_size_b): assert total_miss == feat_store_b._feature.total_miss assert feat_store_a._feature.miss_rate == feat_store_a.miss_rate - # Test get the size of the entire feature with ids. + # Test get the size and count of the entire feature. assert feat_store_a.size() == torch.Size([3]) assert feat_store_b.size() == torch.Size([2, 2]) + assert feat_store_a.count() == a.size(0) + assert feat_store_b.count() == b.size(0) # Test update the entire feature. feat_store_a.update( diff --git a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py index 445e5dbe8d..ff821b8092 100644 --- a/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py +++ b/tests/python/pytorch/graphbolt/impl/test_torch_based_feature_store.py @@ -79,9 +79,11 @@ def test_torch_based_feature(in_memory): ), ) - # Test get the size of the entire feature. + # Test get the size and count of the entire feature. assert feature_a.size() == torch.Size([3]) assert feature_b.size() == torch.Size([2, 2]) + assert feature_a.count() == 1 + assert feature_b.count() == 3 # Test get metadata of the feature. assert feature_a.metadata() == metadata