debug cleanup

This commit is contained in:
Raktim Mitra
2026-03-20 16:11:19 -07:00
parent e94ea147a6
commit f10771f98d

View File

@@ -413,10 +413,7 @@ def extend_index_mask_with_neighbours(
# 2. Find k-nn excluding forced indices
D_LL = torch.where(mask, inf, D_LL)
try:
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
except:
raise ValueError(f"DLL has nan?: {torch.isnan(D_LL).any()}, D_LL shape: {D_LL.shape}, k: {k}")
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
# ... Reverse last axis s.t. best matched indices are last
filler_idx = filler_idx.flip(dims=[-1])