mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
dll bug fix
This commit is contained in:
@@ -210,7 +210,7 @@ def create_attention_indices(
|
||||
chain_ids is not None and len(torch.unique(chain_ids)) > 3
|
||||
): # Multi-chain structure
|
||||
# Reserve 25% of attention keys for inter-chain interactions
|
||||
k_inter_chain = max(32, k_actual // 4) # At least 32 inter-chain keys
|
||||
k_inter_chain = min(max(32, k_actual // 4), k_actual) # At least 32 inter-chain keys
|
||||
k_intra_chain = k_actual - k_inter_chain
|
||||
|
||||
attn_indices = get_sparse_attention_indices_with_inter_chain(
|
||||
@@ -413,7 +413,10 @@ def extend_index_mask_with_neighbours(
|
||||
|
||||
# 2. Find k-nn excluding forced indices
|
||||
D_LL = torch.where(mask, inf, D_LL)
|
||||
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
|
||||
try:
|
||||
filler_idx = torch.topk(D_LL, k, dim=-1, largest=False).indices
|
||||
except:
|
||||
raise ValueError(f"DLL has nan?: {torch.isnan(D_LL).any()}, D_LL shape: {D_LL.shape}, k: {k}")
|
||||
|
||||
# ... Reverse last axis s.t. best matched indices are last
|
||||
filler_idx = filler_idx.flip(dims=[-1])
|
||||
|
||||
Reference in New Issue
Block a user