mirror of
https://github.com/dmlc/dgl.git
synced 2026-06-03 19:34:33 +08:00
[GraphBolt] Minor improvement to item sampler and cache policy. (#7734)
This commit is contained in:
committed by
GitHub
parent
f37f24c77c
commit
9514e7b9cd
@@ -238,7 +238,7 @@ class BaseCachePolicy {
|
||||
// Move the element to the beginning of the queue.
|
||||
to.splice(to.begin(), temp);
|
||||
// The iterators and references are not invalidated.
|
||||
// TORCH_CHECK(it == to.begin());
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(it == to.begin());
|
||||
}
|
||||
|
||||
int64_t capacity_;
|
||||
|
||||
@@ -330,12 +330,11 @@ class ItemSampler(IterDataPipe):
|
||||
self._drop_uneven_inputs,
|
||||
)
|
||||
if self._shuffle:
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self._seed + self._epoch)
|
||||
g = torch.Generator().manual_seed(self._seed + self._epoch)
|
||||
permutation = torch.randperm(total, generator=g)
|
||||
indices = permutation[start_offset : start_offset + assigned_count]
|
||||
else:
|
||||
permutation = torch.arange(total)
|
||||
indices = permutation[start_offset : start_offset + assigned_count]
|
||||
indices = torch.arange(start_offset, start_offset + assigned_count)
|
||||
for i in range(0, assigned_count, self._batch_size):
|
||||
if output_count <= 0:
|
||||
break
|
||||
|
||||
Reference in New Issue
Block a user