CuEq->DeepSpeed Fallback

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
This commit is contained in:
Boris Fomitchev
2025-10-24 16:37:05 -07:00
parent cdf9d6927c
commit a1cf318e2a
2 changed files with 38 additions and 22 deletions

View File

@@ -44,9 +44,13 @@ def enforce_config_constraints(config):
(
"globals.use_lma",
"globals.use_flash",
"globals.use_cuequivariance_attention",
"globals.use_deepspeed_evo_attention"
),
(
"globals.use_lma",
"globals.use_flash",
"globals.use_cuequivariance_attention",
),
]
for options in mutually_exclusive_bools:

View File

@@ -17,6 +17,9 @@ import importlib
import math
from typing import Optional, Callable, List, Tuple
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
@@ -31,17 +34,25 @@ if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
if cuequivariance_is_installed:
try:
from cuequivariance_torch.primitives.triangle import triangle_attention
except ImportError as e:
print(e)
cuequivariance_is_installed = False
cueq_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
if cueq_is_installed:
from cuequivariance_ops_torch.triangle_attention import (
CUEQ_TRIATTN_FALLBACK_THRESHOLD,
)
from cuequivariance_torch.primitives.triangle import triangle_attention
import torch
import torch.nn as nn
from scipy.stats import truncnorm
def cueq_would_fall_back(n_token: int, hidden_dim: int, dtype: torch.dtype):
# for q_x, dimension -2 is the context length
if n_token <= CUEQ_TRIATTN_FALLBACK_THRESHOLD:
return True
if dtype == torch.float32:
if hidden_dim > 32 or hidden_dim % 4 != 0:
return True
else:
# float16, bfloat16
if hidden_dim > 128 or hidden_dim % 8 != 0:
return True
return False
from openfold.utils.checkpointing import get_checkpoint_fn
from openfold.utils.kernel.attention_core import attention_core
@@ -519,7 +530,7 @@ class Attention(nn.Module):
"cuEquivariance attention requires exactly two bias terms"
)
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, use_flash]
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention or use_cuequivariance_attention, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError(
"Choose at most one alternative attention algorithm"
@@ -530,12 +541,20 @@ class Attention(nn.Module):
# DeepSpeed attention kernel applies scaling internally
q, k, v = self._prep_qkv(q_x, kv_x,
apply_scale=not use_deepspeed_evo_attention)
apply_scale=not use_deepspeed_evo_attention or use_cuequivariance_attention)
if is_fp16_enabled():
use_memory_efficient_kernel = False
if use_memory_efficient_kernel:
# cuequivariance kernel takes precedence over use_deepspeed_evo_attention
if use_cuequivariance_attention:
if not cueq_is_installed:
raise ValueError(
"Running with `use_cuequivariance_attention` but package is not "
"installed. See documentation for installation instructions."
)
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
elif use_memory_efficient_kernel:
if len(biases) > 2:
raise ValueError(
"If use_memory_efficient_kernel is True, you may only "
@@ -550,8 +569,6 @@ class Attention(nn.Module):
"provide up to two bias terms"
)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_cuequivariance_attention:
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
elif use_lma:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
@@ -874,10 +891,6 @@ def _cuequivariance_attn(
Returns:
[*, H, Q, C_hidden] attention output
"""
if not cuequivariance_is_installed:
raise ValueError(
"_cuequivariance_attn requires that cuequivariance_torch be installed"
)
# Check input dimensionality
qdim = len(q.shape)
@@ -909,8 +922,7 @@ def _cuequivariance_attn(
k=k,
v=v,
bias=bias,
mask=mask,
scale=1.0
mask=mask
)
# If we added a batch dimension for 4D inputs, remove it