mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
CuEq->DeepSpeed Fallback
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
This commit is contained in:
@@ -44,9 +44,13 @@ def enforce_config_constraints(config):
|
||||
(
|
||||
"globals.use_lma",
|
||||
"globals.use_flash",
|
||||
"globals.use_cuequivariance_attention",
|
||||
"globals.use_deepspeed_evo_attention"
|
||||
),
|
||||
(
|
||||
"globals.use_lma",
|
||||
"globals.use_flash",
|
||||
"globals.use_cuequivariance_attention",
|
||||
),
|
||||
]
|
||||
|
||||
for options in mutually_exclusive_bools:
|
||||
|
||||
@@ -17,6 +17,9 @@ import importlib
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
|
||||
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
|
||||
@@ -31,17 +34,25 @@ if fa_is_installed:
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cuequivariance_is_installed:
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
cuequivariance_is_installed = False
|
||||
cueq_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cueq_is_installed:
|
||||
from cuequivariance_ops_torch.triangle_attention import (
|
||||
CUEQ_TRIATTN_FALLBACK_THRESHOLD,
|
||||
)
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import truncnorm
|
||||
def cueq_would_fall_back(n_token: int, hidden_dim: int, dtype: torch.dtype):
    """Tell whether the given attention problem trips one of the fallback checks.

    The answer is True when either of the following holds:

    * ``n_token`` is less than or equal to ``CUEQ_TRIATTN_FALLBACK_THRESHOLD``;
    * ``hidden_dim`` is larger than the per-dtype maximum, or is not a
      multiple of the per-dtype granularity (32 and 4 for ``torch.float32``,
      128 and 8 for every other dtype).

    Args:
        n_token: Context length (for ``q_x`` this is dimension -2).
        hidden_dim: Hidden dimension to check against the per-dtype limits.
        dtype: Tensor dtype; only ``torch.float32`` is distinguished, any
            other value is treated like float16 / bfloat16.

    Returns:
        True if any check above fires, False otherwise.
    """
    # Short contexts are at or below the threshold regardless of dtype.
    if n_token <= CUEQ_TRIATTN_FALLBACK_THRESHOLD:
        return True

    # Pick the (maximum size, required multiple) pair for this dtype.
    if dtype == torch.float32:
        dim_limit, dim_multiple = 32, 4
    else:
        # float16, bfloat16
        dim_limit, dim_multiple = 128, 8

    return bool(hidden_dim > dim_limit or hidden_dim % dim_multiple != 0)
|
||||
|
||||
from openfold.utils.checkpointing import get_checkpoint_fn
|
||||
from openfold.utils.kernel.attention_core import attention_core
|
||||
@@ -519,7 +530,7 @@ class Attention(nn.Module):
|
||||
"cuEquivariance attention requires exactly two bias terms"
|
||||
)
|
||||
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, use_flash]
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention or use_cuequivariance_attention, use_lma, use_flash]
|
||||
if sum(attn_options) > 1:
|
||||
raise ValueError(
|
||||
"Choose at most one alternative attention algorithm"
|
||||
@@ -530,12 +541,20 @@ class Attention(nn.Module):
|
||||
|
||||
# DeepSpeed attention kernel applies scaling internally
|
||||
q, k, v = self._prep_qkv(q_x, kv_x,
|
||||
apply_scale=not use_deepspeed_evo_attention)
|
||||
apply_scale=not use_deepspeed_evo_attention or use_cuequivariance_attention)
|
||||
|
||||
if is_fp16_enabled():
|
||||
use_memory_efficient_kernel = False
|
||||
|
||||
if use_memory_efficient_kernel:
|
||||
# cuequivariance kernel takes precedence over use_deepspeed_evo_attention
|
||||
if use_cuequivariance_attention:
|
||||
if not cueq_is_installed:
|
||||
raise ValueError(
|
||||
"Running with `use_cuequivariance_attention` but package is not "
|
||||
"installed. See documentation for installation instructions."
|
||||
)
|
||||
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
|
||||
elif use_memory_efficient_kernel:
|
||||
if len(biases) > 2:
|
||||
raise ValueError(
|
||||
"If use_memory_efficient_kernel is True, you may only "
|
||||
@@ -550,8 +569,6 @@ class Attention(nn.Module):
|
||||
"provide up to two bias terms"
|
||||
)
|
||||
o = _deepspeed_evo_attn(q, k, v, biases)
|
||||
elif use_cuequivariance_attention:
|
||||
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
|
||||
elif use_lma:
|
||||
biases = [
|
||||
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
|
||||
@@ -874,10 +891,6 @@ def _cuequivariance_attn(
|
||||
Returns:
|
||||
[*, H, Q, C_hidden] attention output
|
||||
"""
|
||||
if not cuequivariance_is_installed:
|
||||
raise ValueError(
|
||||
"_cuequivariance_attn requires that cuequivariance_torch be installed"
|
||||
)
|
||||
|
||||
# Check input dimensionality
|
||||
qdim = len(q.shape)
|
||||
@@ -909,8 +922,7 @@ def _cuequivariance_attn(
|
||||
k=k,
|
||||
v=v,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
scale=1.0
|
||||
mask=mask
|
||||
)
|
||||
|
||||
# If we added a batch dimension for 4D inputs, remove it
|
||||
|
||||
Reference in New Issue
Block a user