mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-04 19:44:23 +08:00
[DistPart] Fix corner case in dist partition which always led to an assertion error being triggered. (#7395)
This commit is contained in:
committed by
GitHub
parent
64750575a2
commit
0851db72de
@@ -402,7 +402,6 @@ def exchange_feature(
|
||||
"""
|
||||
# type_ids for this feature subset on the current rank
|
||||
gids_feat = np.arange(gid_start, gid_end)
|
||||
tids_feat = np.arange(type_id_start, type_id_end)
|
||||
local_idx = np.arange(0, type_id_end - type_id_start)
|
||||
|
||||
feats_per_rank = []
|
||||
@@ -473,12 +472,22 @@ def exchange_feature(
|
||||
)
|
||||
|
||||
# exchange actual data here.
|
||||
if featdata_key != None:
|
||||
logging.debug(f"Rank: {rank} {featdata_key.shape=}")
|
||||
if featdata_key is not None:
|
||||
feat_dims_dtype = list(featdata_key.shape)
|
||||
assert (
|
||||
len(featdata_key.shape) == 2 or len(featdata_key.shape) == 1
|
||||
), f"We expect 1D or 2D tensors for features, got shape {featdata_key.shape}"
|
||||
# When a feature is 2-dim, shape[1] (number of columns) is the feature dimension; otherwise it is 0.
|
||||
if len(featdata_key.shape) == 2:
|
||||
feature_dimension = feat_dims_dtype[1]
|
||||
else:
|
||||
feature_dimension = 0
|
||||
feat_dims_dtype.append(DATA_TYPE_ID[featdata_key.dtype])
|
||||
else:
|
||||
feat_dims_dtype = list(np.zeros((rank0_shape_len), dtype=np.int64))
|
||||
feat_dims_dtype.append(DATA_TYPE_ID[torch.float32])
|
||||
feature_dimension = 0
|
||||
|
||||
logging.debug(f"Sending the feature shape information - {feat_dims_dtype}")
|
||||
all_dims_dtype = allgather_sizes(
|
||||
@@ -488,13 +497,18 @@ def exchange_feature(
|
||||
for idx in range(world_size):
|
||||
cond = partid_slice == (idx + local_part_id * world_size)
|
||||
gids_per_partid = gids_feat[cond]
|
||||
tids_per_partid = tids_feat[cond]
|
||||
local_idx_partid = local_idx[cond]
|
||||
|
||||
if gids_per_partid.shape[0] == 0:
|
||||
assert len(all_dims_dtype) % world_size == 0
|
||||
dim_len = int(len(all_dims_dtype) / world_size)
|
||||
rank0_shape = tuple(list(np.zeros((dim_len - 1), dtype=np.int32)))
|
||||
rank0_shape = list(np.zeros((dim_len - 1), dtype=np.int32))
|
||||
assert (
|
||||
len(rank0_shape) == 2 or len(rank0_shape) == 1
|
||||
), f"We expect 1D or 2D tensors for features, got shape {rank0_shape}"
|
||||
# When a feature is 2-dim, the shape[1] (number of columns) should match the feature dimension.
|
||||
if len(rank0_shape) == 2:
|
||||
rank0_shape[1] = feature_dimension
|
||||
rank0_dtype = REV_DATA_TYPE_ID[
|
||||
all_dims_dtype[(dim_len - 1) : (dim_len)][0]
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user