mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[GraphBolt] Move get_attributes utils. (#7601)
This commit is contained in:
committed by
GitHub
parent
0af92c2303
commit
781cc500c5
@@ -19,11 +19,15 @@ from ..internal import (
|
||||
calculate_dir_hash,
|
||||
check_dataset_change,
|
||||
copy_or_convert_data,
|
||||
get_attributes,
|
||||
read_data,
|
||||
read_edges,
|
||||
)
|
||||
from ..internal_utils import download, extract_archive, gb_warning
|
||||
from ..internal_utils import (
|
||||
download,
|
||||
extract_archive,
|
||||
gb_warning,
|
||||
get_attributes,
|
||||
)
|
||||
from ..itemset import HeteroItemSet, ItemSet
|
||||
from ..sampling_graph import SamplingGraph
|
||||
from .fused_csc_sampling_graph import (
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Dict, Union
|
||||
import torch
|
||||
|
||||
from ..base import CSCFormatBase, etype_str_to_tuple
|
||||
from ..internal import get_attributes
|
||||
from ..internal_utils import get_attributes
|
||||
from ..sampled_subgraph import SampledSubgraph
|
||||
|
||||
__all__ = ["SampledSubgraphImpl"]
|
||||
|
||||
@@ -144,32 +144,6 @@ def copy_or_convert_data(
|
||||
save_data(data, output_path, output_format)
|
||||
|
||||
|
||||
def get_nonproperty_attributes(_obj) -> list:
|
||||
"""Get attributes of the class except for the properties."""
|
||||
attributes = [
|
||||
attribute
|
||||
for attribute in dir(_obj)
|
||||
if not attribute.startswith("__")
|
||||
and (
|
||||
not hasattr(type(_obj), attribute)
|
||||
or not isinstance(getattr(type(_obj), attribute), property)
|
||||
)
|
||||
and not callable(getattr(_obj, attribute))
|
||||
]
|
||||
return attributes
|
||||
|
||||
|
||||
def get_attributes(_obj) -> list:
|
||||
"""Get attributes of the class."""
|
||||
attributes = [
|
||||
attribute
|
||||
for attribute in dir(_obj)
|
||||
if not attribute.startswith("__")
|
||||
and not callable(getattr(_obj, attribute))
|
||||
]
|
||||
return attributes
|
||||
|
||||
|
||||
def read_edges(dataset_dir, edge_fmt, edge_path):
|
||||
"""Read egde data from numpy or csv."""
|
||||
assert edge_fmt in [
|
||||
|
||||
@@ -151,6 +151,32 @@ def recursive_apply_reduce_all(data, fn, *args, **kwargs):
|
||||
return fn(data, *args, **kwargs)
|
||||
|
||||
|
||||
def get_nonproperty_attributes(_obj) -> list:
|
||||
"""Get attributes of the class except for the properties."""
|
||||
attributes = [
|
||||
attribute
|
||||
for attribute in dir(_obj)
|
||||
if not attribute.startswith("__")
|
||||
and (
|
||||
not hasattr(type(_obj), attribute)
|
||||
or not isinstance(getattr(type(_obj), attribute), property)
|
||||
)
|
||||
and not callable(getattr(_obj, attribute))
|
||||
]
|
||||
return attributes
|
||||
|
||||
|
||||
def get_attributes(_obj) -> list:
|
||||
"""Get attributes of the class."""
|
||||
attributes = [
|
||||
attribute
|
||||
for attribute in dir(_obj)
|
||||
if not attribute.startswith("__")
|
||||
and not callable(getattr(_obj, attribute))
|
||||
]
|
||||
return attributes
|
||||
|
||||
|
||||
def download(
|
||||
url,
|
||||
path=None,
|
||||
|
||||
@@ -6,8 +6,11 @@ from typing import Dict, List, Tuple, Union
|
||||
import torch
|
||||
|
||||
from .base import CSCFormatBase, etype_str_to_tuple, expand_indptr
|
||||
from .internal import get_attributes, get_nonproperty_attributes
|
||||
from .internal_utils import recursive_apply
|
||||
from .internal_utils import (
|
||||
get_attributes,
|
||||
get_nonproperty_attributes,
|
||||
recursive_apply,
|
||||
)
|
||||
from .sampled_subgraph import SampledSubgraph
|
||||
|
||||
__all__ = ["MiniBatch"]
|
||||
|
||||
Reference in New Issue
Block a user