[GraphBolt] Minor improvement to item sampler and cache policy. (#7734)

This commit is contained in:
Muhammed Fatih BALIN
2024-08-23 14:55:14 -04:00
committed by GitHub
parent f37f24c77c
commit 9514e7b9cd
2 changed files with 4 additions and 5 deletions

View File

@@ -238,7 +238,7 @@ class BaseCachePolicy {
// Move the element to the beginning of the queue.
to.splice(to.begin(), temp);
// The iterators and references are not invalidated.
// TORCH_CHECK(it == to.begin());
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(it == to.begin());
}
int64_t capacity_;

View File

@@ -330,12 +330,11 @@ class ItemSampler(IterDataPipe):
self._drop_uneven_inputs,
)
if self._shuffle:
g = torch.Generator()
g.manual_seed(self._seed + self._epoch)
g = torch.Generator().manual_seed(self._seed + self._epoch)
permutation = torch.randperm(total, generator=g)
indices = permutation[start_offset : start_offset + assigned_count]
else:
permutation = torch.arange(total)
indices = permutation[start_offset : start_offset + assigned_count]
indices = torch.arange(start_offset, start_offset + assigned_count)
for i in range(0, assigned_count, self._batch_size):
if output_count <= 0:
break