cuEquivariance integration

Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
This commit is contained in:
Boris Fomitchev
2025-09-04 13:09:38 -07:00
parent e938c184a2
commit 41d1c82165
16 changed files with 649 additions and 74 deletions

View File

@@ -29,6 +29,7 @@ def enforce_config_constraints(config):
(
"globals.use_lma",
"globals.use_flash",
"globals.use_cuequivariance_attention",
"globals.use_deepspeed_evo_attention"
),
]
@@ -51,6 +52,10 @@ def enforce_config_constraints(config):
"and that the deepspeed.ops.deepspeed4science package exists"
)
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
if (config.globals.use_cuequivariance_attention or config.globals.use_cuequivariance_multiplicative_update) and not cuequivariance_is_installed:
raise ValueError("use_cuequivariance_xxx requires that cuequivariance_torch is installed")
if(
config.globals.offload_inference and
not config.model.template.average_templates
@@ -64,6 +69,8 @@ def model_config(
low_prec=False,
long_sequence_inference=False,
use_deepspeed_evoformer_attention=False,
use_cuequivariance_attention=False,
use_cuequivariance_multiplicative_update=False,
):
c = copy.deepcopy(config)
# TRAINING PRESETS
@@ -240,7 +247,13 @@ def model_config(
if use_deepspeed_evoformer_attention:
c.globals.use_deepspeed_evo_attention = True
if use_cuequivariance_attention:
c.globals.use_cuequivariance_attention = True
if use_cuequivariance_multiplicative_update:
c.globals.use_cuequivariance_multiplicative_update = True
if train:
c.globals.blocks_per_ckpt = 1
c.globals.chunk_size = None
@@ -475,6 +488,11 @@ config = mlc.ConfigDict(
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
# on long sequences (>1000 residues).
"use_flash": False,
# Use cuEquivariance kernels for accelerated triangle attention and
# triangle multiplicative update operations. Requires CUDA and
# cuequivariance_torch package.
"use_cuequivariance_attention": False,
"use_cuequivariance_multiplicative_update": False,
"offload_inference": False,
"c_z": c_z,
"c_m": c_m,

View File

@@ -50,6 +50,8 @@ class Dropout(nn.Module):
Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim
"""
if not self.training:
return x
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:

View File

@@ -658,6 +658,8 @@ class TemplateEmbedder(nn.Module):
chunk_size,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma=False,
inplace_safe=False
):
@@ -709,6 +711,8 @@ class TemplateEmbedder(nn.Module):
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -896,6 +900,8 @@ class TemplateEmbedderMultimer(nn.Module):
multichain_mask_2d,
_mask_trans=True,
use_deepspeed_evo_attention=False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma=False,
inplace_safe=False
):
@@ -971,6 +977,8 @@ class TemplateEmbedderMultimer(nn.Module):
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,

View File

@@ -1,5 +1,6 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import sys
import torch
@@ -19,6 +21,7 @@ import torch.nn as nn
from typing import Tuple, Sequence, Optional
from functools import partial
from abc import ABC, abstractmethod
from torch.fx._symbolic_trace import is_fx_tracing
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
@@ -179,6 +182,8 @@ class PairStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -197,6 +202,7 @@ class PairStack(nn.Module):
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
)
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
@@ -210,6 +216,7 @@ class PairStack(nn.Module):
mask=pair_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
)
if (not inplace_safe):
z = z + self.ps_dropout_row_layer(tmu_update)
@@ -226,6 +233,7 @@ class PairStack(nn.Module):
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -245,6 +253,7 @@ class PairStack(nn.Module):
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -363,6 +372,7 @@ class MSABlock(nn.Module, ABC):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -427,6 +437,8 @@ class EvoformerBlock(MSABlock):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -467,6 +479,7 @@ class EvoformerBlock(MSABlock):
chunk_size=_attn_chunk_size,
use_memory_efficient_kernel=False,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
)
),
@@ -489,6 +502,7 @@ class EvoformerBlock(MSABlock):
mask=msa_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
use_flash=use_flash,
),
@@ -534,6 +548,8 @@ class EvoformerBlock(MSABlock):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -610,6 +626,8 @@ class ExtraMSABlock(MSABlock):
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -618,8 +636,8 @@ class ExtraMSABlock(MSABlock):
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if(_attn_chunk_size is None):
_attn_chunk_size = chunk_size
_attn_chunk_size = chunk_size
if(_offload_inference and inplace_safe):
input_tensors = _offloadable_inputs
del _offloadable_inputs
@@ -646,7 +664,8 @@ class ExtraMSABlock(MSABlock):
chunk_size=_attn_chunk_size,
use_lma=use_lma,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
use_cuequivariance_attention=use_cuequivariance_attention,
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention or use_cuequivariance_attention),
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
@@ -719,6 +738,8 @@ class ExtraMSABlock(MSABlock):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -857,13 +878,15 @@ class EvoformerStack(nn.Module):
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
self.chunk_size_tuner = ChunkSizeTuner(2048)
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_cuequivariance_attention: bool,
use_cuequivariance_multiplicative_update: bool,
use_lma: bool,
use_flash: bool,
msa_mask: Optional[torch.Tensor],
@@ -878,6 +901,8 @@ class EvoformerStack(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
use_flash=use_flash,
inplace_safe=inplace_safe,
@@ -901,12 +926,13 @@ class EvoformerStack(nn.Module):
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
# A temporary measure to address torch's occasional
# inability to allocate large tensors
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
_attn_chunk_size=max(chunk_size, attn_chunk),
) for b in blocks
]
@@ -918,6 +944,8 @@ class EvoformerStack(nn.Module):
pair_mask: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
use_flash: bool = False,
_mask_trans: bool = True,
@@ -930,6 +958,8 @@ class EvoformerStack(nn.Module):
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
@@ -960,8 +990,10 @@ class EvoformerStack(nn.Module):
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
chunk_size: int = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -996,12 +1028,19 @@ class EvoformerStack(nn.Module):
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
"""
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
inplace_safe = False
chunk_size = None
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
use_flash=use_flash,
msa_mask=msa_mask,
@@ -1080,13 +1119,15 @@ class ExtraMSAStack(nn.Module):
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if(tune_chunk_size):
self.chunk_size_tuner = ChunkSizeTuner()
self.chunk_size_tuner = ChunkSizeTuner(2048)
def _prep_blocks(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool,
use_cuequivariance_attention: bool,
use_cuequivariance_multiplicative_update: bool,
use_lma: bool,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
@@ -1100,6 +1141,8 @@ class ExtraMSAStack(nn.Module):
pair_mask=pair_mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -1122,12 +1165,15 @@ class ExtraMSAStack(nn.Module):
args=(m.clone(), z.clone(),),
min_chunk_size=chunk_size,
)
# A temporary measure to address torch's occasional
# inability to allocate large tensors
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
# A temporary measure to address torch's occasional
# inability to allocate large tensors
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
_attn_chunk_size=max(chunk_size, attn_chunk),
) for b in blocks
]
@@ -1137,6 +1183,8 @@ class ExtraMSAStack(nn.Module):
input_tensors: Sequence[torch.Tensor],
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
@@ -1150,6 +1198,8 @@ class ExtraMSAStack(nn.Module):
z=input_tensors[1],
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,
@@ -1175,8 +1225,10 @@ class ExtraMSAStack(nn.Module):
z: torch.Tensor,
msa_mask: Optional[torch.Tensor],
pair_mask: Optional[torch.Tensor],
chunk_size: int,
chunk_size: int = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -1197,12 +1249,19 @@ class ExtraMSAStack(nn.Module):
Returns:
[*, N_res, N_res, C_z] pair update
"""
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
inplace_safe = False
chunk_size = None
checkpoint_fn = get_checkpoint_fn()
blocks = self._prep_blocks(
m=m,
z=z,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
msa_mask=msa_mask,
pair_mask=pair_mask,

View File

@@ -147,6 +147,8 @@ class AlphaFold(nn.Module):
chunk_size=self.globals.chunk_size,
multichain_mask_2d=multichain_mask_2d,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
@@ -171,6 +173,8 @@ class AlphaFold(nn.Module):
templ_dim,
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=self.config._mask_trans
@@ -382,6 +386,8 @@ class AlphaFold(nn.Module):
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
_mask_trans=self.config._mask_trans,
@@ -395,6 +401,8 @@ class AlphaFold(nn.Module):
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
pair_mask=pair_mask.to(dtype=m.dtype),
inplace_safe=inplace_safe,
@@ -414,6 +422,8 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
_mask_trans=self.config._mask_trans,
)
@@ -427,6 +437,8 @@ class AlphaFold(nn.Module):
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
use_lma=self.globals.use_lma,
use_flash=self.globals.use_flash,
inplace_safe=inplace_safe,

View File

@@ -17,6 +17,7 @@ import math
import torch
import torch.nn as nn
from typing import Optional, List, Tuple
from torch.fx._symbolic_trace import is_fx_tracing
from openfold.model.primitives import (
Linear,
@@ -93,6 +94,7 @@ class MSAAttention(nn.Module):
chunk_size: int,
use_memory_efficient_kernel: bool,
use_deepspeed_evo_attention: bool,
use_cuequivariance_attention: bool,
use_lma: bool,
use_flash: bool,
flash_mask: Optional[torch.Tensor],
@@ -105,6 +107,7 @@ class MSAAttention(nn.Module):
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=flash_mask,
@@ -132,37 +135,50 @@ class MSAAttention(nn.Module):
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
inplace_safe: bool = False,
use_cuequivariance_attention: bool = False,
chunk_size: int = 256
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res),
m.shape[:-1],
)
# [*, N_seq, 1, 1, N_res]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
if use_cuequivariance_attention:
mask_bias = mask[..., :, None, None, :]
else:
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
chunks = []
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
inplace_safe = False
chunk_size = None
for i in range(0, z.shape[-3], 256):
z_chunk = z[..., i: i + 256, :, :]
if chunk_size is None:
z = self.layer_norm_z(z)
z = self.linear_z(z)
else:
chunks = []
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
for i in range(0, z.shape[-3], chunk_size):
z_chunk = z[..., i: i + chunk_size, :, :]
# [*, N_res, N_res, C_z]
z_chunk = self.layer_norm_z(z_chunk)
# [*, N_res, N_res, no_heads]
z_chunk = self.linear_z(z_chunk)
chunks.append(z_chunk)
z = torch.cat(chunks, dim=-3)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
@@ -224,6 +240,7 @@ class MSAAttention(nn.Module):
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
inplace_safe: bool = False,
@@ -252,16 +269,19 @@ class MSAAttention(nn.Module):
checkpoint=_checkpoint_chunks,
inplace_safe=inplace_safe,
)
if(use_flash):
assert z is None
biases = None
else:
m, mask_bias, z = self._prep_inputs(
m, z, mask, inplace_safe=inplace_safe
m, z, mask, inplace_safe=inplace_safe,
use_cuequivariance_attention=use_cuequivariance_attention,
)
biases = [mask_bias]
if z is None and use_cuequivariance_attention:
z = torch.zeros_like(m)
if(z is not None):
biases.append(z)
@@ -272,6 +292,7 @@ class MSAAttention(nn.Module):
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
@@ -284,6 +305,7 @@ class MSAAttention(nn.Module):
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
use_flash=use_flash,
flash_mask=mask,
@@ -362,6 +384,7 @@ class MSAColumnAttention(nn.Module):
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
use_flash: bool = False,
) -> torch.Tensor:
@@ -386,6 +409,7 @@ class MSAColumnAttention(nn.Module):
mask=mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
use_flash=use_flash,
)

View File

@@ -30,6 +30,14 @@ if fa_is_installed:
from flash_attn.bert_padding import unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
if cuequivariance_is_installed:
try:
from cuequivariance_torch.primitives.triangle import triangle_attention
except ImportError as e:
print(e)
cuequivariance_is_installed = False
import torch
import torch.nn as nn
from scipy.stats import truncnorm
@@ -199,7 +207,7 @@ class Linear(nn.Linear):
bias).to(dtype=d)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast('cuda', enabled=False):
bias = self.bias.to(dtype=d) if self.bias is not None else None
return nn.functional.linear(input, self.weight.to(dtype=d), bias)
@@ -223,7 +231,7 @@ class LayerNorm(nn.Module):
deepspeed.comm.comm.is_initialized()
)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast('cuda', enabled=False):
out = nn.functional.layer_norm(
x,
self.c_in,
@@ -255,7 +263,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
deepspeed.comm.comm.is_initialized()
)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
with torch.amp.autocast('cuda', enabled=False):
s = torch.nn.functional.softmax(t, dim=dim)
else:
s = torch.nn.functional.softmax(t, dim=dim)
@@ -452,6 +460,7 @@ class Attention(nn.Module):
biases: Optional[List[torch.Tensor]] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
@@ -483,6 +492,11 @@ class Attention(nn.Module):
Query chunk size (for LMA)
lma_kv_chunk_size:
Key/Value chunk size (for LMA)
use_cuequivariance_attention:
Whether to use cuEquivariance attention kernel.
When on, biases[0] contains 0/1 mask tensor for cuEquivariance attention (0 for invalid positions)
Returns
[*, Q, C_q] attention update
"""
@@ -498,7 +512,13 @@ class Attention(nn.Module):
"use flash_mask instead"
)
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
if use_cuequivariance_attention:
if biases is None or len(biases) != 2:
raise ValueError(
"cuEquivariance attention requires exactly two bias terms"
)
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, use_flash]
if sum(attn_options) > 1:
raise ValueError(
"Choose at most one alternative attention algorithm"
@@ -529,6 +549,8 @@ class Attention(nn.Module):
"provide up to two bias terms"
)
o = _deepspeed_evo_attn(q, k, v, biases)
elif use_cuequivariance_attention:
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
elif use_lma:
biases = [
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
@@ -828,3 +850,75 @@ def _flash_attn(q, k, v, kv_mask):
out = out.to(dtype=dtype)
return out
@torch.jit.ignore
def _cuequivariance_attn(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    bias: torch.Tensor,
    mask: Optional[torch.Tensor] = None,
):
    """
    Compute attention using the cuEquivariance triangle attention kernel.

    The kernel is called with 5D tensors. A 4D q (no batch dimension) gets
    a temporary leading batch dimension that is removed from the output.
    When q has more than one batch dimension (q.dim() > 5), all leading
    dimensions are collapsed into a single one for the kernel call and
    restored on the output.

    Args:
        q: [*, H, Q, C_hidden] query data
        k: [*, H, K, C_hidden] key data
        v: [*, H, V, C_hidden] value data
        bias: bias term added to the attention logits; cast to float32
            before the kernel call.
            NOTE(review): callers appear to pass a [*, 1, H, Q, K]
            triangle bias -- confirm against the kernel documentation.
        mask: optional 0/1 mask (0 for invalid positions).
            NOTE(review): callers appear to pass [*, N, 1, 1, K] --
            confirm against the kernel documentation.
    Returns:
        [*, Q, H, C_hidden] attention output (the head and query
        dimensions of the kernel output are swapped before returning)
    Raises:
        ValueError: if cuequivariance_torch could not be imported
    """
    if not cuequivariance_is_installed:
        raise ValueError(
            "_cuequivariance_attn requires that cuequivariance_torch be installed"
        )

    # Check input dimensionality
    qdim = len(q.shape)
    # Batch dimensions in front of the trailing four dims of q
    lead_shape = q.shape[:-4]

    if qdim == 4:
        # No batch dimension: add one for the kernel call
        q = q.unsqueeze(0)
        k = k.unsqueeze(0)
        v = v.unsqueeze(0)
        bias = bias.unsqueeze(0)
        if mask is not None:
            mask = mask.unsqueeze(0)
    elif qdim > 5:
        # More than one batch dimension: collapse them into a single one.
        flat_batch_size = 1
        for dim in lead_shape:
            flat_batch_size *= dim

        def _collapse(t: torch.Tensor) -> torch.Tensor:
            # Materialize broadcast (size-1 or missing) batch dims first so
            # that every operand ends up with the same flattened batch size.
            t = t.expand(*lead_shape, *t.shape[-4:])
            return t.reshape(flat_batch_size, *t.shape[-4:])

        q = _collapse(q)
        k = _collapse(k)
        v = _collapse(v)
        bias = _collapse(bias)
        if mask is not None:
            mask = _collapse(mask)

    # Convert bias to float32
    bias = bias.to(dtype=torch.float32)

    # Apply cuEquivariance triangle attention.
    # scale=1.0 disables the kernel's own query scaling; presumably the
    # caller has already scaled q -- TODO confirm.
    o = triangle_attention(
        q=q,
        k=k,
        v=v,
        bias=bias,
        mask=mask,
        scale=1.0
    )

    if qdim == 4:
        # Remove the batch dimension that was added above
        o = o.squeeze(0)
    elif qdim > 5:
        # Restore the original batch dimensions that were collapsed above
        o = o.reshape(*lead_shape, *o.shape[-4:])

    # Final transpose to match expected output format
    o = o.transpose(-2, -3)
    return o

View File

@@ -216,6 +216,7 @@ class TemplatePairStackBlock(nn.Module):
_attn_chunk_size: Optional[int],
single_mask: torch.Tensor,
use_deepspeed_evo_attention: bool,
use_cuequivariance_attention: bool,
use_lma: bool,
inplace_safe: bool):
single = add(single,
@@ -225,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
chunk_size=_attn_chunk_size,
mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -239,6 +241,7 @@ class TemplatePairStackBlock(nn.Module):
chunk_size=_attn_chunk_size,
mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -251,12 +254,14 @@ class TemplatePairStackBlock(nn.Module):
def tri_mul_out_in(self,
single: torch.Tensor,
single_mask: torch.Tensor,
use_cuequivariance_multiplicative_update: bool,
inplace_safe: bool):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
)
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
@@ -270,6 +275,7 @@ class TemplatePairStackBlock(nn.Module):
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
)
if not inplace_safe:
single = single + self.dropout_row(tmu_update)
@@ -285,6 +291,8 @@ class TemplatePairStackBlock(nn.Module):
mask: torch.Tensor,
chunk_size: Optional[int] = None,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -307,10 +315,12 @@ class TemplatePairStackBlock(nn.Module):
if self.tri_mul_first:
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
single_mask=single_mask,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
inplace_safe=inplace_safe),
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe)
else:
@@ -319,9 +329,11 @@ class TemplatePairStackBlock(nn.Module):
_attn_chunk_size=_attn_chunk_size,
single_mask=single_mask,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe),
single_mask=single_mask,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
inplace_safe=inplace_safe)
single = add(single,
@@ -405,7 +417,7 @@ class TemplatePairStack(nn.Module):
self.tune_chunk_size = tune_chunk_size
self.chunk_size_tuner = None
if tune_chunk_size:
self.chunk_size_tuner = ChunkSizeTuner()
self.chunk_size_tuner = ChunkSizeTuner(2048)
def forward(
self,
@@ -413,6 +425,8 @@ class TemplatePairStack(nn.Module):
mask: torch.tensor,
chunk_size: int,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
_mask_trans: bool = True,
@@ -437,6 +451,8 @@ class TemplatePairStack(nn.Module):
mask=mask,
chunk_size=chunk_size,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
@@ -451,10 +467,11 @@ class TemplatePairStack(nn.Module):
args=(t.clone(),),
min_chunk_size=chunk_size,
)
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
blocks = [
partial(b,
chunk_size=tuned_chunk_size,
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
_attn_chunk_size=max(chunk_size, attn_chunk),
) for b in blocks
]
@@ -528,6 +545,8 @@ def embed_templates_offload(
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,
@@ -647,6 +666,8 @@ def embed_templates_average(
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=model.globals.chunk_size,
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
use_lma=model.globals.use_lma,
inplace_safe=inplace_safe,
_mask_trans=model.config._mask_trans,

View File

@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
chunk_size: int,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
@@ -79,6 +80,7 @@ class TriangleAttention(nn.Module):
self.mha,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma
),
mha_inputs,
@@ -93,6 +95,7 @@ class TriangleAttention(nn.Module):
chunk_size: Optional[int] = None,
use_memory_efficient_kernel: bool = False,
use_deepspeed_evo_attention: bool = False,
use_cuequivariance_attention: bool = False,
use_lma: bool = False,
inplace_safe: bool = False,
) -> torch.Tensor:
@@ -117,7 +120,10 @@ class TriangleAttention(nn.Module):
x = self.layer_norm(x)
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
if use_cuequivariance_attention:
mask_bias = mask[..., :, None, None, :]
else:
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
@@ -134,6 +140,7 @@ class TriangleAttention(nn.Module):
chunk_size,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
@@ -144,6 +151,7 @@ class TriangleAttention(nn.Module):
biases=biases,
use_memory_efficient_kernel=use_memory_efficient_kernel,
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
use_cuequivariance_attention=use_cuequivariance_attention,
use_lma=use_lma
)

View File

@@ -1,5 +1,6 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -16,15 +17,86 @@
from functools import partialmethod
from typing import Optional
from abc import ABC, abstractmethod
import importlib
import torch
import torch.nn as nn
from torch.fx._symbolic_trace import is_fx_tracing
from openfold.model.primitives import Linear, LayerNorm
from openfold.utils.chunk_utils import chunk_layer
from openfold.utils.precision_utils import is_fp16_enabled
from openfold.utils.tensor_utils import add, permute_final_dims
# cuEquivariance import handling.
# `import importlib` alone does not guarantee that the `importlib.util`
# submodule is loaded, so import it explicitly before calling find_spec().
import importlib.util

cuequivariance_is_installed = (
    importlib.util.find_spec("cuequivariance_torch") is not None
)
# Bound to None unless the kernel imports cleanly, so the name always exists
# at module level even when the optional package is missing.
triangle_multiplicative_update = None
if cuequivariance_is_installed:
    try:
        from cuequivariance_torch.primitives.triangle import (
            triangle_multiplicative_update,
        )
    except ImportError:
        # The package was found but its triangle primitives could not be
        # imported: treat cuEquivariance as unavailable.
        cuequivariance_is_installed = False
def _cuequivariance_triangular_mult(
    x: torch.Tensor,
    direction: str,
    mask: Optional[torch.Tensor],
    norm_in_weight: torch.Tensor,
    norm_in_bias: torch.Tensor,
    p_in_weight: torch.Tensor,
    p_in_bias: torch.Tensor,
    g_in_weight: torch.Tensor,
    g_in_bias: torch.Tensor,
    norm_out_weight: torch.Tensor,
    norm_out_bias: torch.Tensor,
    p_out_weight: torch.Tensor,
    p_out_bias: torch.Tensor,
    g_out_weight: torch.Tensor,
    g_out_bias: torch.Tensor,
    eps: float = 1e-5,
):
    """
    Forward all arguments to cuEquivariance's triangle multiplicative update
    and return the result viewed with the shape of ``x``.

    Args:
        x: [*, N, N, C] input tensor
        direction: "outgoing" or "incoming"
        mask: [*, N, N] mask tensor (may be None)
        norm_in_weight: [C] input normalization weight
        norm_in_bias: [C] input normalization bias
        p_in_weight: [2*C, C] input projection weight
        p_in_bias: bias paired with ``p_in_weight``
        g_in_weight: [2*C, C] input gating weight
        g_in_bias: bias paired with ``g_in_weight``
        norm_out_weight: [C] output normalization weight
        norm_out_bias: [C] output normalization bias
        p_out_weight: [C, C] output projection weight
        p_out_bias: bias paired with ``p_out_weight``
        g_out_weight: [C, C] output gating weight
        g_out_bias: bias paired with ``g_out_weight``
        eps: epsilon for numerical stability
    Returns:
        [*, N, N, C] output tensor
    Raises:
        ValueError: if cuequivariance_torch could not be imported.
    """
    # Guard first: without the package the kernel symbol is not usable.
    if not cuequivariance_is_installed:
        raise ValueError(
            "_cuequivariance_triangular_mult requires that cuequivariance_torch be installed"
        )

    # Everything is passed by keyword, one-to-one with this wrapper's own
    # parameters.
    kernel_kwargs = dict(
        x=x,
        direction=direction,
        mask=mask,
        norm_in_weight=norm_in_weight,
        norm_in_bias=norm_in_bias,
        p_in_weight=p_in_weight,
        p_in_bias=p_in_bias,
        g_in_weight=g_in_weight,
        g_in_bias=g_in_bias,
        norm_out_weight=norm_out_weight,
        norm_out_bias=norm_out_bias,
        p_out_weight=p_out_weight,
        p_out_bias=p_out_bias,
        g_out_weight=g_out_weight,
        g_out_bias=g_out_bias,
        eps=eps,
    )
    updated = triangle_multiplicative_update(**kernel_kwargs)
    # Hand back a tensor with exactly the caller's input shape.
    return updated.view(x.shape)
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
"""
@@ -87,6 +159,7 @@ class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
_add_with_inplace: bool = False
) -> torch.Tensor:
"""
@@ -397,6 +470,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256,
) -> torch.Tensor:
@@ -409,7 +483,38 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if(inplace_safe):
if use_cuequivariance_multiplicative_update:
p_in_weight = torch.cat([self.linear_a_p.weight, self.linear_b_p.weight], dim=0)
g_in_weight = torch.cat([self.linear_a_g.weight, self.linear_b_g.weight], dim=0)
p_in_bias = torch.cat([self.linear_a_p.bias, self.linear_b_p.bias], dim=0)
g_in_bias = torch.cat([self.linear_a_g.bias, self.linear_b_g.bias], dim=0)
result = _cuequivariance_triangular_mult(
z,
direction="outgoing" if self._outgoing else "incoming",
mask=mask,
norm_in_weight=self.layer_norm_in.weight,
norm_in_bias=self.layer_norm_in.bias,
p_in_weight=p_in_weight,
p_in_bias=p_in_bias,
g_in_weight=g_in_weight,
g_in_bias=g_in_bias,
norm_out_weight=self.layer_norm_out.weight,
norm_out_bias=self.layer_norm_out.bias,
p_out_weight=self.linear_z.weight,
p_out_bias=self.linear_z.bias,
g_out_weight=self.linear_g.weight,
g_out_bias=self.linear_g.bias,
eps=1e-5,
)
# When not inplace_safe (training), caller should have set _add_with_inplace to False
if inplace_safe and _add_with_inplace:
result += z
return result
if inplace_safe:
x = self._inference_forward(
z,
mask,
@@ -422,7 +527,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = mask
a = a * self.sigmoid(self.linear_a_g(z))
@@ -433,13 +538,12 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
# Prevents overflow of torch.matmul in combine projections in
# reduced-precision modes
a_std = a.std()
b_std = b.std()
if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
a = a / a.std()
b = b / b.std()
if(is_fp16_enabled()):
if is_fp16_enabled():
a_std = a.std()
b_std = b.std()
if a_std != 0. and b_std != 0.:
a = a / a.std()
b = b / b.std()
with torch.cuda.amp.autocast(enabled=False):
x = self._combine_projections(a.float(), b.float())
else:
@@ -545,6 +649,7 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
inplace_safe: bool = False,
use_cuequivariance_multiplicative_update: bool = False,
_add_with_inplace: bool = False,
_inplace_chunk_size: Optional[int] = 256
) -> torch.Tensor:
@@ -557,6 +662,28 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if use_cuequivariance_multiplicative_update:
direction = "outgoing" if self._outgoing else "incoming"
result = _cuequivariance_triangular_mult(
x=z,
direction=direction,
mask=mask,
norm_in_weight=self.layer_norm_in.weight,
norm_in_bias=self.layer_norm_in.bias,
p_in_weight=self.linear_ab_p.weight,
g_in_weight=self.linear_ab_g.weight,
norm_out_weight=self.layer_norm_out.weight,
norm_out_bias=self.layer_norm_out.bias,
p_out_weight=self.linear_z.weight,
g_out_weight=self.linear_g.weight,
eps=1e-5,
)
# When not inplace_safe (training), caller should have set _add_with_inplace to False
if inplace_safe and _add_with_inplace:
result += z
return result
if (inplace_safe):
x = self._inference_forward(
z,

View File

@@ -252,6 +252,16 @@ def chunk_layer(
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
if no_chunks == 1:
return layer(**inputs)
def _prep_inputs(t):
if(not low_mem):
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
@@ -267,14 +277,6 @@ def chunk_layer(
reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
prepped_outputs = tensor_tree_map(reshape_fn, _out)
flat_batch_dim = 1
for d in orig_batch_dims:
flat_batch_dim *= d
no_chunks = flat_batch_dim // chunk_size + (
flat_batch_dim % chunk_size != 0
)
i = 0
out = prepped_outputs
for _ in range(no_chunks):

View File

@@ -182,6 +182,7 @@ def trace_model_(model, sample_input):
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
]
verify_arg_order(
@@ -203,6 +204,7 @@ def trace_model_(model, sample_input):
("mask", msa_mask),
("chunk_size", torch.tensor(evoformer_chunk_size)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("use_flash", torch.tensor(model.globals.use_flash)),
]
@@ -286,6 +288,7 @@ def trace_model_(model, sample_input):
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]
@@ -309,6 +312,7 @@ def trace_model_(model, sample_input):
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
("use_memory_efficient_kernel", torch.tensor(False)),
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
("use_lma", torch.tensor(model.globals.use_lma)),
("inplace_safe", torch.tensor(True)),
]

View File

@@ -1,5 +1,6 @@
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
@@ -20,13 +22,13 @@ import os
import pickle
import random
import time
import torch
import json
logging.basicConfig()
logger = logging.getLogger(__file__)
logger.setLevel(level=logging.INFO)
import torch
torch_versions = torch.__version__.split(".")
torch_major_version = int(torch_versions[0])
torch_minor_version = int(torch_versions[1])
@@ -183,13 +185,15 @@ def main(args):
args.config_preset,
long_sequence_inference=args.long_sequence_inference,
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
)
use_cuequivariance_attention=args.use_cuequivariance_attention,
use_cuequivariance_multiplicative_update=args.use_cuequivariance_multiplicative_update,
)
if args.experiment_config_json:
with open(args.experiment_config_json, 'r') as f:
custom_config_dict = json.load(f)
config.update_from_flattened_dict(custom_config_dict)
if args.trace_model:
if not config.data.predict.fixed_size:
raise ValueError(
@@ -482,6 +486,14 @@ if __name__ == "__main__":
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
)
parser.add_argument(
"--use_cuequivariance_attention", action="store_true", default=False,
help="""Use cuEquivariance kernels for attention computation."""
)
parser.add_argument(
"--use_cuequivariance_multiplicative_update", action="store_true", default=False,
help="""Use cuEquivariance kernels for triangular multiplicative update computation."""
)
add_data_args(parser)
args = parser.parse_args()

View File

@@ -52,15 +52,23 @@ def get_cuda_bare_metal_version(cuda_dir):
return raw_output, bare_metal_major, bare_metal_minor
compute_capabilities = set([
(5, 2), # Titan X
(6, 1), # GeForce 1000-series
(9, 0), # Hopper
])
compute_capabilities.add((7, 0))
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
if int(bare_metal_major) >= 11:
compute_capabilities.add((8, 0))
compute_capabilities.add((8, 6))
compute_capabilities.add((8, 9))
if int(bare_metal_major) >= 12:
compute_capabilities.add((9, 0))
if int(bare_metal_major) >= 13:
compute_capabilities.add((10, 0))
compute_capabilities.add((10, 3))
compute_capabilities.add((12, 0))
else:
compute_capabilities.add((7, 0))
compute_capability, _ = get_nvidia_cc()
if compute_capability is not None:
@@ -75,8 +83,6 @@ for major, minor in list(compute_capabilities):
extra_cuda_flags += cc_flag
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
if bare_metal_major != -1:
modules = [CUDAExtension(
name="attn_core_inplace_cuda",
@@ -127,6 +133,12 @@ setup(
},
ext_modules=modules,
cmdclass={'build_ext': BuildExtension},
extras_require={
'cuequivariance': [
'cuequivariance-torch; sys_platform != "darwin"', # Not available on macOS
'triton>=3.3.0; sys_platform != "darwin"', # Required for triangle multiplicative update
],
},
classifiers=[
'License :: OSI Approved :: Apache Software License',
'Operating System :: POSIX :: Linux',

View File

@@ -0,0 +1,160 @@
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests to compare components of OpenFold run with the cuEquivariance attention and
triangle multiplicative update kernels vs. the stock PyTorch implementations.
"""
import unittest
import numpy as np
import pickle
import torch
from torch.nn import functional as F
from openfold.data import data_transforms
from openfold.model.primitives import (
lecun_normal_init_,
Attention
)
from openfold.utils.tensor_utils import tensor_tree_map
from tests.config import consts
import tests.compare_utils as compare_utils
from tests.data_utils import random_template_feats, random_attention_inputs
class TestCuEquivarianceKernel(unittest.TestCase):
    """
    Compare OpenFold outputs with and without the cuEquivariance kernels.

    The tests toggle kernel flags on the model returned by
    ``compare_utils.get_global_pretrained_openfold()``. That model appears to
    be shared between tests (TODO confirm), so each test sets every flag it
    depends on explicitly and the previous values are restored afterwards.
    """

    # Kernel switches on ``model.globals`` that these tests read or modify.
    _KERNEL_FLAGS = (
        "use_deepspeed_evo_attention",
        "use_cuequivariance_attention",
        "use_cuequivariance_multiplicative_update",
    )

    @classmethod
    def setUpClass(cls):
        """Skip the whole class when the kernels cannot run at all."""
        import importlib.util

        if not torch.cuda.is_available():
            raise unittest.SkipTest("cuEquivariance kernel tests require CUDA")
        if importlib.util.find_spec("cuequivariance_torch") is None:
            raise unittest.SkipTest(
                "cuEquivariance kernel tests require cuequivariance_torch"
            )

    def _get_model(self):
        """
        Fetch the pretrained model and register a cleanup that puts the
        kernel flags back to the values they had before the test ran.
        """
        model = compare_utils.get_global_pretrained_openfold()
        saved = {
            flag: getattr(model.globals, flag) for flag in self._KERNEL_FLAGS
        }

        def _restore():
            for flag, value in saved.items():
                setattr(model.globals, flag, value)

        # addCleanup also runs when the test fails or errors.
        self.addCleanup(_restore)
        return model

    def test_compare_template_stack(self):
        """
        Compare Template Stack output with and without using the cuEquivariance kernels
        (triangle attention and triangle multiplicative update).
        """
        n_templ = consts.n_templ
        n_res = 20
        eps = 2e-2

        batch = random_template_feats(n_templ, n_res)
        batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
        if consts.is_multimer:
            batch["asym_id"] = batch['asym_id'][0]

        pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
        pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
        template_feats = {
            k: v for k, v in batch.items() if k.startswith("template_")
        }

        with torch.no_grad():
            model = self._get_model()

            # Baseline: make sure no accelerated kernel is active, whatever
            # state a previous test left behind.
            model.globals.use_deepspeed_evo_attention = False
            model.globals.use_cuequivariance_attention = False
            model.globals.use_cuequivariance_multiplicative_update = False
            out_repro = model.embed_templates(
                template_feats,
                batch,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                inplace_safe=False
            )
            out_repro = out_repro["template_pair_embedding"].cpu()

            model.globals.use_cuequivariance_attention = True
            model.globals.use_cuequivariance_multiplicative_update = True
            out_repro_ds = model.embed_templates(
                template_feats,
                batch,
                torch.as_tensor(pair_act).cuda(),
                torch.as_tensor(pair_mask).cuda(),
                templ_dim=0,
                inplace_safe=False
            )
            out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()

        compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)

    def test_compare_model(self):
        """
        Run full model with and without using the cuEquivariance kernels
        and compare output coordinates.
        """
        eps = 0.2
        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
            batch = pickle.load(fp)

        # atom37_to_atom14 doesn't like batches
        batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
        batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]

        batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])

        if consts.is_multimer:
            n_res = batch['aatype'].shape[1]
            n_extra_seq = batch['extra_msa'].shape[1]
            batch["asym_id"] = np.ones((4, n_res))
            batch["entity_id"] = np.ones((4, n_res))
            batch["sym_id"] = np.ones((4, n_res))
            batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res))

        batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}

        batch["aatype"] = batch["aatype"].long()
        batch["template_aatype"] = batch["template_aatype"].long()
        batch["extra_msa"] = batch["extra_msa"].long()
        batch["residx_atom37_to_atom14"] = batch[
            "residx_atom37_to_atom14"
        ].long()
        batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
        batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
        batch.update(
            data_transforms.atom37_to_torsion_angles("template_")(batch)
        )

        # Move the recycling dimension to the end
        move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
        batch = tensor_tree_map(move_dim, batch)

        # Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
        # https://github.com/aqlaboratory/openfold/issues/532
        with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
            model = self._get_model()

            # Baseline without either cuEquivariance kernel
            model.globals.use_cuequivariance_attention = False
            model.globals.use_cuequivariance_multiplicative_update = False
            out_repro = model(batch)
            out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
            out_repro = out_repro["sm"]["positions"][-1].squeeze(0)

            # Enable attention
            model.globals.use_cuequivariance_attention = True
            out_repro_attn = model(batch)
            out_repro_attn = tensor_tree_map(lambda t: t.cpu(), out_repro_attn)
            out_repro_attn = out_repro_attn["sm"]["positions"][-1].squeeze(0)
            compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_attn, eps)

            # Enable multiplication
            model.globals.use_cuequivariance_attention = True
            model.globals.use_cuequivariance_multiplicative_update = True
            out_repro_mul = model(batch)
            out_repro_mul = tensor_tree_map(lambda t: t.cpu(), out_repro_mul)
            out_repro_mul = out_repro_mul["sm"]["positions"][-1].squeeze(0)
            compare_utils.assert_mean_abs_diff_small(out_repro_attn, out_repro_mul, eps)
if __name__ == "__main__":
unittest.main()

View File

@@ -113,7 +113,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True)
def _tri_mul_inplace(self, incoming=False):
def _tri_mul_inplace(self, incoming=False, dtype = torch.float32):
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
@@ -126,26 +126,38 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
if incoming
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
act = torch.as_tensor(pair_act, dtype=dtype).cuda()
mask = torch.as_tensor(pair_mask, dtype=dtype).cuda()
module = module.to(dtype=dtype)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
act,
mask=mask,
inplace_safe=False,
).cpu()
)
# This has to come second because inference mode is in-place
out_inplace = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
act,
mask=mask,
inplace_safe=True, _inplace_chunk_size=2,
).cpu()
)
self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
torch.testing.assert_close(out_stock, out_inplace, rtol=0.1, atol=0.1)
def test_tri_mul_out_inference(self):
    """Compare stock vs. in-place inference for the outgoing update using the helper's defaults (float32)."""
    self._tri_mul_inplace()
def test_tri_mul_out_inference_bf16(self):
    """Compare stock vs. in-place inference for the outgoing update with inputs and module cast to bfloat16."""
    self._tri_mul_inplace(dtype=torch.bfloat16)
def test_tri_mul_in_inference(self):
    """Compare stock vs. in-place inference for the incoming update using the helper's default dtype (float32)."""
    self._tri_mul_inplace(incoming=True)
def test_tri_mul_in_inference_bf16(self):
    """Compare stock vs. in-place inference for the incoming update with inputs and module cast to bfloat16."""
    self._tri_mul_inplace(incoming=True, dtype=torch.bfloat16)
if __name__ == "__main__":
unittest.main()