dll bug fix

This commit is contained in:
Raktim Mitra
2026-02-16 12:48:48 -08:00
parent 22d24edc5b
commit 4478253fa9

View File

@@ -210,7 +210,7 @@ def create_attention_indices(
chain_ids is not None and len(torch.unique(chain_ids)) > 3
): # Multi-chain structure
# Reserve 25% of attention keys for inter-chain interactions
k_inter_chain = max(32, k_actual // 4) # At least 32 inter-chain keys
k_inter_chain = min(max(32, k_actual // 4), k_actual) # Aim for at least 32 inter-chain keys, but never more than k_actual
k_intra_chain = k_actual - k_inter_chain
attn_indices = get_sparse_attention_indices_with_inter_chain(
@@ -413,7 +413,10 @@ def extend_index_mask_with_neighbours(
# 2. Find k-nn excluding forced indices
D_LL = torch.where(mask, inf, D_LL)
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
try:
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
except:
raise ValueError(f"DLL has nan?: {torch.isnan(D_LL).any()}, D_LL shape: {D_LL.shape}, k: {k}")
# ... Reverse last axis s.t. best matched indices are last
filler_idx = filler_idx.flip(dims=[-1])