mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
cuEquivariance integration
Signed-off-by: Boris Fomitchev <bfomitchev@nvidia.com>
This commit is contained in:
@@ -29,6 +29,7 @@ def enforce_config_constraints(config):
|
||||
(
|
||||
"globals.use_lma",
|
||||
"globals.use_flash",
|
||||
"globals.use_cuequivariance_attention",
|
||||
"globals.use_deepspeed_evo_attention"
|
||||
),
|
||||
]
|
||||
@@ -51,6 +52,10 @@ def enforce_config_constraints(config):
|
||||
"and that the deepspeed.ops.deepspeed4science package exists"
|
||||
)
|
||||
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if (config.globals.use_cuequivariance_attention or config.globals.use_cuequivariance_multiplicative_update) and not cuequivariance_is_installed:
|
||||
raise ValueError("use_cuequivariance_xxx requires that cuequivariance_torch is installed")
|
||||
|
||||
if(
|
||||
config.globals.offload_inference and
|
||||
not config.model.template.average_templates
|
||||
@@ -64,6 +69,8 @@ def model_config(
|
||||
low_prec=False,
|
||||
long_sequence_inference=False,
|
||||
use_deepspeed_evoformer_attention=False,
|
||||
use_cuequivariance_attention=False,
|
||||
use_cuequivariance_multiplicative_update=False,
|
||||
):
|
||||
c = copy.deepcopy(config)
|
||||
# TRAINING PRESETS
|
||||
@@ -240,7 +247,13 @@ def model_config(
|
||||
|
||||
if use_deepspeed_evoformer_attention:
|
||||
c.globals.use_deepspeed_evo_attention = True
|
||||
|
||||
|
||||
if use_cuequivariance_attention:
|
||||
c.globals.use_cuequivariance_attention = True
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
c.globals.use_cuequivariance_multiplicative_update = True
|
||||
|
||||
if train:
|
||||
c.globals.blocks_per_ckpt = 1
|
||||
c.globals.chunk_size = None
|
||||
@@ -475,6 +488,11 @@ config = mlc.ConfigDict(
|
||||
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
|
||||
# on long sequences (>1000 residues).
|
||||
"use_flash": False,
|
||||
# Use cuEquivariance kernels for accelerated triangle attention and
|
||||
# triangle multiplicative update operations. Requires CUDA and
|
||||
# cuequivariance_torch package.
|
||||
"use_cuequivariance_attention": False,
|
||||
"use_cuequivariance_multiplicative_update": False,
|
||||
"offload_inference": False,
|
||||
"c_z": c_z,
|
||||
"c_m": c_m,
|
||||
|
||||
@@ -50,6 +50,8 @@ class Dropout(nn.Module):
|
||||
Tensor to which dropout is applied. Can have any shape
|
||||
compatible with self.batch_dim
|
||||
"""
|
||||
if not self.training:
|
||||
return x
|
||||
shape = list(x.shape)
|
||||
if self.batch_dim is not None:
|
||||
for bd in self.batch_dim:
|
||||
|
||||
@@ -658,6 +658,8 @@ class TemplateEmbedder(nn.Module):
|
||||
chunk_size,
|
||||
_mask_trans=True,
|
||||
use_deepspeed_evo_attention=False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma=False,
|
||||
inplace_safe=False
|
||||
):
|
||||
@@ -709,6 +711,8 @@ class TemplateEmbedder(nn.Module):
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -896,6 +900,8 @@ class TemplateEmbedderMultimer(nn.Module):
|
||||
multichain_mask_2d,
|
||||
_mask_trans=True,
|
||||
use_deepspeed_evo_attention=False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma=False,
|
||||
inplace_safe=False
|
||||
):
|
||||
@@ -971,6 +977,8 @@ class TemplateEmbedderMultimer(nn.Module):
|
||||
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import sys
|
||||
import torch
|
||||
@@ -19,6 +21,7 @@ import torch.nn as nn
|
||||
from typing import Tuple, Sequence, Optional
|
||||
from functools import partial
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import Linear, LayerNorm
|
||||
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
|
||||
@@ -179,6 +182,8 @@ class PairStack(nn.Module):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -197,6 +202,7 @@ class PairStack(nn.Module):
|
||||
mask=pair_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if (not inplace_safe):
|
||||
z = z + self.ps_dropout_row_layer(tmu_update)
|
||||
@@ -210,6 +216,7 @@ class PairStack(nn.Module):
|
||||
mask=pair_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if (not inplace_safe):
|
||||
z = z + self.ps_dropout_row_layer(tmu_update)
|
||||
@@ -226,6 +233,7 @@ class PairStack(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -245,6 +253,7 @@ class PairStack(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -363,6 +372,7 @@ class MSABlock(nn.Module, ABC):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -427,6 +437,8 @@ class EvoformerBlock(MSABlock):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -467,6 +479,7 @@ class EvoformerBlock(MSABlock):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
)
|
||||
),
|
||||
@@ -489,6 +502,7 @@ class EvoformerBlock(MSABlock):
|
||||
mask=msa_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
),
|
||||
@@ -534,6 +548,8 @@ class EvoformerBlock(MSABlock):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -610,6 +626,8 @@ class ExtraMSABlock(MSABlock):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -618,8 +636,8 @@ class ExtraMSABlock(MSABlock):
|
||||
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if(_attn_chunk_size is None):
|
||||
_attn_chunk_size = chunk_size
|
||||
|
||||
_attn_chunk_size = chunk_size
|
||||
|
||||
if(_offload_inference and inplace_safe):
|
||||
input_tensors = _offloadable_inputs
|
||||
del _offloadable_inputs
|
||||
@@ -646,7 +664,8 @@ class ExtraMSABlock(MSABlock):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_lma=use_lma,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention or use_cuequivariance_attention),
|
||||
_checkpoint_chunks=
|
||||
self.ckpt if torch.is_grad_enabled() else False,
|
||||
)
|
||||
@@ -719,6 +738,8 @@ class ExtraMSABlock(MSABlock):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -857,13 +878,15 @@ class EvoformerStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if(tune_chunk_size):
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def _prep_blocks(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
use_lma: bool,
|
||||
use_flash: bool,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
@@ -878,6 +901,8 @@ class EvoformerStack(nn.Module):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
inplace_safe=inplace_safe,
|
||||
@@ -901,12 +926,13 @@ class EvoformerStack(nn.Module):
|
||||
args=(m.clone(), z.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -918,6 +944,8 @@ class EvoformerStack(nn.Module):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -930,6 +958,8 @@ class EvoformerStack(nn.Module):
|
||||
z=input_tensors[1],
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
msa_mask=msa_mask,
|
||||
@@ -960,8 +990,10 @@ class EvoformerStack(nn.Module):
|
||||
z: torch.Tensor,
|
||||
msa_mask: torch.Tensor,
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
chunk_size: int = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -996,12 +1028,19 @@ class EvoformerStack(nn.Module):
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
s:
|
||||
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
||||
"""
|
||||
"""
|
||||
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
blocks = self._prep_blocks(
|
||||
m=m,
|
||||
z=z,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
msa_mask=msa_mask,
|
||||
@@ -1080,13 +1119,15 @@ class ExtraMSAStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if(tune_chunk_size):
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def _prep_blocks(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
use_lma: bool,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
pair_mask: Optional[torch.Tensor],
|
||||
@@ -1100,6 +1141,8 @@ class ExtraMSAStack(nn.Module):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -1122,12 +1165,15 @@ class ExtraMSAStack(nn.Module):
|
||||
args=(m.clone(), z.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -1137,6 +1183,8 @@ class ExtraMSAStack(nn.Module):
|
||||
input_tensors: Sequence[torch.Tensor],
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
msa_mask: Optional[torch.Tensor] = None,
|
||||
pair_mask: Optional[torch.Tensor] = None,
|
||||
@@ -1150,6 +1198,8 @@ class ExtraMSAStack(nn.Module):
|
||||
z=input_tensors[1],
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
msa_mask=msa_mask,
|
||||
pair_mask=pair_mask,
|
||||
@@ -1175,8 +1225,10 @@ class ExtraMSAStack(nn.Module):
|
||||
z: torch.Tensor,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
pair_mask: Optional[torch.Tensor],
|
||||
chunk_size: int,
|
||||
chunk_size: int = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -1197,12 +1249,19 @@ class ExtraMSAStack(nn.Module):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair update
|
||||
"""
|
||||
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
checkpoint_fn = get_checkpoint_fn()
|
||||
blocks = self._prep_blocks(
|
||||
m=m,
|
||||
z=z,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
msa_mask=msa_mask,
|
||||
pair_mask=pair_mask,
|
||||
|
||||
@@ -147,6 +147,8 @@ class AlphaFold(nn.Module):
|
||||
chunk_size=self.globals.chunk_size,
|
||||
multichain_mask_2d=multichain_mask_2d,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=self.config._mask_trans
|
||||
@@ -171,6 +173,8 @@ class AlphaFold(nn.Module):
|
||||
templ_dim,
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=self.config._mask_trans
|
||||
@@ -382,6 +386,8 @@ class AlphaFold(nn.Module):
|
||||
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
pair_mask=pair_mask.to(dtype=m.dtype),
|
||||
_mask_trans=self.config._mask_trans,
|
||||
@@ -395,6 +401,8 @@ class AlphaFold(nn.Module):
|
||||
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
pair_mask=pair_mask.to(dtype=m.dtype),
|
||||
inplace_safe=inplace_safe,
|
||||
@@ -414,6 +422,8 @@ class AlphaFold(nn.Module):
|
||||
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
_mask_trans=self.config._mask_trans,
|
||||
)
|
||||
@@ -427,6 +437,8 @@ class AlphaFold(nn.Module):
|
||||
pair_mask=pair_mask.to(dtype=z.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
use_flash=self.globals.use_flash,
|
||||
inplace_safe=inplace_safe,
|
||||
|
||||
@@ -17,6 +17,7 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, List, Tuple
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import (
|
||||
Linear,
|
||||
@@ -93,6 +94,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size: int,
|
||||
use_memory_efficient_kernel: bool,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_lma: bool,
|
||||
use_flash: bool,
|
||||
flash_mask: Optional[torch.Tensor],
|
||||
@@ -105,6 +107,7 @@ class MSAAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=flash_mask,
|
||||
@@ -132,37 +135,50 @@ class MSAAttention(nn.Module):
|
||||
z: Optional[torch.Tensor],
|
||||
mask: Optional[torch.Tensor],
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
chunk_size: int = 256
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
n_seq, n_res = m.shape[-3:-1]
|
||||
|
||||
if mask is None:
|
||||
# [*, N_seq, N_res]
|
||||
mask = m.new_ones(
|
||||
m.shape[:-3] + (n_seq, n_res),
|
||||
m.shape[:-1],
|
||||
)
|
||||
|
||||
# [*, N_seq, 1, 1, N_res]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
if use_cuequivariance_attention:
|
||||
mask_bias = mask[..., :, None, None, :]
|
||||
else:
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
if (self.pair_bias and
|
||||
z is not None and # For the
|
||||
self.layer_norm_z is not None and # benefit of
|
||||
self.linear_z is not None # TorchScript
|
||||
):
|
||||
chunks = []
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
for i in range(0, z.shape[-3], 256):
|
||||
z_chunk = z[..., i: i + 256, :, :]
|
||||
if chunk_size is None:
|
||||
z = self.layer_norm_z(z)
|
||||
z = self.linear_z(z)
|
||||
else:
|
||||
chunks = []
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z_chunk = self.layer_norm_z(z_chunk)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z_chunk = self.linear_z(z_chunk)
|
||||
|
||||
chunks.append(z_chunk)
|
||||
|
||||
z = torch.cat(chunks, dim=-3)
|
||||
|
||||
for i in range(0, z.shape[-3], chunk_size):
|
||||
z_chunk = z[..., i: i + chunk_size, :, :]
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z_chunk = self.layer_norm_z(z_chunk)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z_chunk = self.linear_z(z_chunk)
|
||||
|
||||
chunks.append(z_chunk)
|
||||
z = torch.cat(chunks, dim=-3)
|
||||
|
||||
# [*, 1, no_heads, N_res, N_res]
|
||||
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
||||
|
||||
@@ -224,6 +240,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size: Optional[int] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -252,16 +269,19 @@ class MSAAttention(nn.Module):
|
||||
checkpoint=_checkpoint_chunks,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
|
||||
|
||||
if(use_flash):
|
||||
assert z is None
|
||||
biases = None
|
||||
else:
|
||||
m, mask_bias, z = self._prep_inputs(
|
||||
m, z, mask, inplace_safe=inplace_safe
|
||||
m, z, mask, inplace_safe=inplace_safe,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
)
|
||||
|
||||
biases = [mask_bias]
|
||||
if z is None and use_cuequivariance_attention:
|
||||
z = torch.zeros_like(m)
|
||||
if(z is not None):
|
||||
biases.append(z)
|
||||
|
||||
@@ -272,6 +292,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=mask,
|
||||
@@ -284,6 +305,7 @@ class MSAAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=mask,
|
||||
@@ -362,6 +384,7 @@ class MSAColumnAttention(nn.Module):
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -386,6 +409,7 @@ class MSAColumnAttention(nn.Module):
|
||||
mask=mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
)
|
||||
|
||||
@@ -30,6 +30,14 @@ if fa_is_installed:
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cuequivariance_is_installed:
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
except ImportError as e:
|
||||
print(e)
|
||||
cuequivariance_is_installed = False
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import truncnorm
|
||||
@@ -199,7 +207,7 @@ class Linear(nn.Linear):
|
||||
bias).to(dtype=d)
|
||||
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
bias = self.bias.to(dtype=d) if self.bias is not None else None
|
||||
return nn.functional.linear(input, self.weight.to(dtype=d), bias)
|
||||
|
||||
@@ -223,7 +231,7 @@ class LayerNorm(nn.Module):
|
||||
deepspeed.comm.comm.is_initialized()
|
||||
)
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
out = nn.functional.layer_norm(
|
||||
x,
|
||||
self.c_in,
|
||||
@@ -255,7 +263,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
deepspeed.comm.comm.is_initialized()
|
||||
)
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
else:
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
@@ -452,6 +460,7 @@ class Attention(nn.Module):
|
||||
biases: Optional[List[torch.Tensor]] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
|
||||
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
|
||||
@@ -483,6 +492,11 @@ class Attention(nn.Module):
|
||||
Query chunk size (for LMA)
|
||||
lma_kv_chunk_size:
|
||||
Key/Value chunk size (for LMA)
|
||||
use_cuequivariance_attention:
|
||||
Whether to use cuEquivariance attention kernel.
|
||||
When on, biases[0] contains 0/1 mask tensor for cuEquivariance attention (0 for invalid positions)
|
||||
|
||||
|
||||
Returns
|
||||
[*, Q, C_q] attention update
|
||||
"""
|
||||
@@ -498,7 +512,13 @@ class Attention(nn.Module):
|
||||
"use flash_mask instead"
|
||||
)
|
||||
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
|
||||
if use_cuequivariance_attention:
|
||||
if biases is None or len(biases) != 2:
|
||||
raise ValueError(
|
||||
"cuEquivariance attention requires exactly two bias terms"
|
||||
)
|
||||
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_cuequivariance_attention, use_lma, use_flash]
|
||||
if sum(attn_options) > 1:
|
||||
raise ValueError(
|
||||
"Choose at most one alternative attention algorithm"
|
||||
@@ -529,6 +549,8 @@ class Attention(nn.Module):
|
||||
"provide up to two bias terms"
|
||||
)
|
||||
o = _deepspeed_evo_attn(q, k, v, biases)
|
||||
elif use_cuequivariance_attention:
|
||||
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
|
||||
elif use_lma:
|
||||
biases = [
|
||||
b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],))
|
||||
@@ -828,3 +850,75 @@ def _flash_attn(q, k, v, kv_mask):
|
||||
out = out.to(dtype=dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _cuequivariance_attn(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Compute attention using the cuEquivariance triangle attention kernel.
|
||||
|
||||
Args:
|
||||
q: [*, H, Q, C_hidden] query data
|
||||
k: [*, H, K, C_hidden] key data
|
||||
v: [*, H, V, C_hidden] value data
|
||||
bias: [*, H, Q, K] triangular bias
|
||||
mask: [*, Q, K] mask for masking invalid positions
|
||||
|
||||
Returns:
|
||||
[*, H, Q, C_hidden] attention output
|
||||
"""
|
||||
if not cuequivariance_is_installed:
|
||||
raise ValueError(
|
||||
"_cuequivariance_attn requires that cuequivariance_torch be installed"
|
||||
)
|
||||
|
||||
# Check input dimensionality
|
||||
qdim = len(q.shape)
|
||||
# If we have 4D tensors ([*, H, Q, D]), add batch dimension
|
||||
if qdim == 4:
|
||||
q = q.unsqueeze(0) # [1, H, Q, D]
|
||||
k = k.unsqueeze(0) # [1, H, K, D]
|
||||
v = v.unsqueeze(0) # [1, H, V, D]
|
||||
bias = bias.unsqueeze(0) # [1, H, Q, K]
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(0) # [1, Q, K]
|
||||
elif len(q.shape[:-3]) > 2:
|
||||
# If there are more than 2 leading dimensions, flatten them into B*N
|
||||
batch_shape = q.shape[:-3]
|
||||
flat_batch_size = 1
|
||||
for dim in batch_shape:
|
||||
flat_batch_size *= dim
|
||||
|
||||
q = q.reshape(flat_batch_size, *q.shape[-3:])
|
||||
k = k.reshape(flat_batch_size, *k.shape[-3:])
|
||||
v = v.reshape(flat_batch_size, *v.shape[-3:])
|
||||
bias = bias.reshape(flat_batch_size, *bias.shape[-3:])
|
||||
if mask is not None:
|
||||
mask = mask.reshape(flat_batch_size, *mask.shape[-2:])
|
||||
# Convert bias to float32
|
||||
bias = bias.to(dtype=torch.float32)
|
||||
|
||||
# Apply cuEquivariance triangle attention
|
||||
o = triangle_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
scale=1.0
|
||||
)
|
||||
|
||||
# If we added a batch dimension for 4D inputs, remove it
|
||||
if qdim == 4:
|
||||
o = o.squeeze(0)
|
||||
|
||||
# Final transpose to match expected output format
|
||||
o = o.transpose(-2, -3)
|
||||
|
||||
return o
|
||||
|
||||
@@ -216,6 +216,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
_attn_chunk_size: Optional[int],
|
||||
single_mask: torch.Tensor,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_lma: bool,
|
||||
inplace_safe: bool):
|
||||
single = add(single,
|
||||
@@ -225,6 +226,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -239,6 +241,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -251,12 +254,14 @@ class TemplatePairStackBlock(nn.Module):
|
||||
def tri_mul_out_in(self,
|
||||
single: torch.Tensor,
|
||||
single_mask: torch.Tensor,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
inplace_safe: bool):
|
||||
tmu_update = self.tri_mul_out(
|
||||
single,
|
||||
mask=single_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if not inplace_safe:
|
||||
single = single + self.dropout_row(tmu_update)
|
||||
@@ -270,6 +275,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
mask=single_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if not inplace_safe:
|
||||
single = single + self.dropout_row(tmu_update)
|
||||
@@ -285,6 +291,8 @@ class TemplatePairStackBlock(nn.Module):
|
||||
mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -307,10 +315,12 @@ class TemplatePairStackBlock(nn.Module):
|
||||
if self.tri_mul_first:
|
||||
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
|
||||
single_mask=single_mask,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
inplace_safe=inplace_safe),
|
||||
_attn_chunk_size=_attn_chunk_size,
|
||||
single_mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe)
|
||||
else:
|
||||
@@ -319,9 +329,11 @@ class TemplatePairStackBlock(nn.Module):
|
||||
_attn_chunk_size=_attn_chunk_size,
|
||||
single_mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe),
|
||||
single_mask=single_mask,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
inplace_safe=inplace_safe)
|
||||
|
||||
single = add(single,
|
||||
@@ -405,7 +417,7 @@ class TemplatePairStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if tune_chunk_size:
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -413,6 +425,8 @@ class TemplatePairStack(nn.Module):
|
||||
mask: torch.tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -437,6 +451,8 @@ class TemplatePairStack(nn.Module):
|
||||
mask=mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -451,10 +467,11 @@ class TemplatePairStack(nn.Module):
|
||||
args=(t.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -528,6 +545,8 @@ def embed_templates_offload(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
@@ -647,6 +666,8 @@ def embed_templates_average(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
|
||||
@@ -64,6 +64,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size: int,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -79,6 +80,7 @@ class TriangleAttention(nn.Module):
|
||||
self.mha,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma
|
||||
),
|
||||
mha_inputs,
|
||||
@@ -93,6 +95,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size: Optional[int] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -117,7 +120,10 @@ class TriangleAttention(nn.Module):
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
if use_cuequivariance_attention:
|
||||
mask_bias = mask[..., :, None, None, :]
|
||||
else:
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# [*, H, I, J]
|
||||
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||
@@ -134,6 +140,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -144,6 +151,7 @@ class TriangleAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,15 +17,86 @@
|
||||
from functools import partialmethod
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import Linear, LayerNorm
|
||||
from openfold.utils.chunk_utils import chunk_layer
|
||||
from openfold.utils.precision_utils import is_fp16_enabled
|
||||
from openfold.utils.tensor_utils import add, permute_final_dims
|
||||
|
||||
# cuEquivariance import handling
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cuequivariance_is_installed:
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import triangle_multiplicative_update
|
||||
except ImportError:
|
||||
cuequivariance_is_installed = False
|
||||
|
||||
|
||||
def _cuequivariance_triangular_mult(
|
||||
x: torch.Tensor,
|
||||
direction: str,
|
||||
mask: Optional[torch.Tensor],
|
||||
norm_in_weight: torch.Tensor,
|
||||
norm_in_bias: torch.Tensor,
|
||||
p_in_weight: torch.Tensor,
|
||||
p_in_bias: torch.Tensor,
|
||||
g_in_weight: torch.Tensor,
|
||||
g_in_bias: torch.Tensor,
|
||||
norm_out_weight: torch.Tensor,
|
||||
norm_out_bias: torch.Tensor,
|
||||
p_out_weight: torch.Tensor,
|
||||
p_out_bias: torch.Tensor,
|
||||
g_out_weight: torch.Tensor,
|
||||
g_out_bias: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
"""
|
||||
Wrapper function for cuEquivariance triangle multiplicative update.
|
||||
|
||||
Args:
|
||||
x: [*, N, N, C] input tensor
|
||||
direction: "outgoing" or "incoming"
|
||||
mask: [*, N, N] mask tensor
|
||||
norm_in_weight: [C] input normalization weight
|
||||
norm_in_bias: [C] input normalization bias
|
||||
p_in_weight: [2*C, C] input projection weight
|
||||
g_in_weight: [2*C, C] input gating weight
|
||||
norm_out_weight: [C] output normalization weight
|
||||
norm_out_bias: [C] output normalization bias
|
||||
p_out_weight: [C, C] output projection weight
|
||||
g_out_weight: [C, C] output gating weight
|
||||
eps: epsilon for numerical stability
|
||||
|
||||
Returns:
|
||||
[*, N, N, C] output tensor
|
||||
"""
|
||||
if not cuequivariance_is_installed:
|
||||
raise ValueError(
|
||||
"_cuequivariance_triangular_mult requires that cuequivariance_torch be installed"
|
||||
)
|
||||
return triangle_multiplicative_update(
|
||||
x=x,
|
||||
direction=direction,
|
||||
mask=mask,
|
||||
norm_in_weight=norm_in_weight,
|
||||
norm_in_bias=norm_in_bias,
|
||||
p_in_weight=p_in_weight,
|
||||
p_in_bias=p_in_bias,
|
||||
g_in_weight=g_in_weight,
|
||||
g_in_bias=g_in_bias,
|
||||
norm_out_weight=norm_out_weight,
|
||||
norm_out_bias=norm_out_bias,
|
||||
p_out_weight=p_out_weight,
|
||||
p_out_bias=p_out_bias,
|
||||
g_out_weight=g_out_weight,
|
||||
g_out_bias=g_out_bias,
|
||||
eps=eps,
|
||||
).view(x.shape)
|
||||
|
||||
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
|
||||
"""
|
||||
@@ -87,6 +159,7 @@ class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -397,6 +470,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False,
|
||||
_inplace_chunk_size: Optional[int] = 256,
|
||||
) -> torch.Tensor:
|
||||
@@ -409,7 +483,38 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
if(inplace_safe):
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
p_in_weight = torch.cat([self.linear_a_p.weight, self.linear_b_p.weight], dim=0)
|
||||
g_in_weight = torch.cat([self.linear_a_g.weight, self.linear_b_g.weight], dim=0)
|
||||
|
||||
p_in_bias = torch.cat([self.linear_a_p.bias, self.linear_b_p.bias], dim=0)
|
||||
g_in_bias = torch.cat([self.linear_a_g.bias, self.linear_b_g.bias], dim=0)
|
||||
|
||||
result = _cuequivariance_triangular_mult(
|
||||
z,
|
||||
direction="outgoing" if self._outgoing else "incoming",
|
||||
mask=mask,
|
||||
norm_in_weight=self.layer_norm_in.weight,
|
||||
norm_in_bias=self.layer_norm_in.bias,
|
||||
p_in_weight=p_in_weight,
|
||||
p_in_bias=p_in_bias,
|
||||
g_in_weight=g_in_weight,
|
||||
g_in_bias=g_in_bias,
|
||||
norm_out_weight=self.layer_norm_out.weight,
|
||||
norm_out_bias=self.layer_norm_out.bias,
|
||||
p_out_weight=self.linear_z.weight,
|
||||
p_out_bias=self.linear_z.bias,
|
||||
g_out_weight=self.linear_g.weight,
|
||||
g_out_bias=self.linear_g.bias,
|
||||
eps=1e-5,
|
||||
)
|
||||
# When not inplace_safe (training), caller should have set _add_with_inplace to False
|
||||
if inplace_safe and _add_with_inplace:
|
||||
result += z
|
||||
return result
|
||||
|
||||
if inplace_safe:
|
||||
x = self._inference_forward(
|
||||
z,
|
||||
mask,
|
||||
@@ -422,7 +527,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
|
||||
z = self.layer_norm_in(z)
|
||||
a = mask
|
||||
a = a * self.sigmoid(self.linear_a_g(z))
|
||||
@@ -433,13 +538,12 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
|
||||
# Prevents overflow of torch.matmul in combine projections in
|
||||
# reduced-precision modes
|
||||
a_std = a.std()
|
||||
b_std = b.std()
|
||||
if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
|
||||
a = a / a.std()
|
||||
b = b / b.std()
|
||||
|
||||
if(is_fp16_enabled()):
|
||||
if is_fp16_enabled():
|
||||
a_std = a.std()
|
||||
b_std = b.std()
|
||||
if a_std != 0. and b_std != 0.:
|
||||
a = a / a.std()
|
||||
b = b / b.std()
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x = self._combine_projections(a.float(), b.float())
|
||||
else:
|
||||
@@ -545,6 +649,7 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False,
|
||||
_inplace_chunk_size: Optional[int] = 256
|
||||
) -> torch.Tensor:
|
||||
@@ -557,6 +662,28 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
direction = "outgoing" if self._outgoing else "incoming"
|
||||
result = _cuequivariance_triangular_mult(
|
||||
x=z,
|
||||
direction=direction,
|
||||
mask=mask,
|
||||
norm_in_weight=self.layer_norm_in.weight,
|
||||
norm_in_bias=self.layer_norm_in.bias,
|
||||
p_in_weight=self.linear_ab_p.weight,
|
||||
g_in_weight=self.linear_ab_g.weight,
|
||||
norm_out_weight=self.layer_norm_out.weight,
|
||||
norm_out_bias=self.layer_norm_out.bias,
|
||||
p_out_weight=self.linear_z.weight,
|
||||
g_out_weight=self.linear_g.weight,
|
||||
eps=1e-5,
|
||||
)
|
||||
# When not inplace_safe (training), caller should have set _add_with_inplace to False
|
||||
if inplace_safe and _add_with_inplace:
|
||||
result += z
|
||||
return result
|
||||
|
||||
if (inplace_safe):
|
||||
x = self._inference_forward(
|
||||
z,
|
||||
|
||||
@@ -252,6 +252,16 @@ def chunk_layer(
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
if no_chunks == 1:
|
||||
return layer(**inputs)
|
||||
|
||||
def _prep_inputs(t):
|
||||
if(not low_mem):
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
@@ -267,14 +277,6 @@ def chunk_layer(
|
||||
reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
|
||||
prepped_outputs = tensor_tree_map(reshape_fn, _out)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
|
||||
i = 0
|
||||
out = prepped_outputs
|
||||
for _ in range(no_chunks):
|
||||
|
||||
@@ -182,6 +182,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
]
|
||||
verify_arg_order(
|
||||
@@ -203,6 +204,7 @@ def trace_model_(model, sample_input):
|
||||
("mask", msa_mask),
|
||||
("chunk_size", torch.tensor(evoformer_chunk_size)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("use_flash", torch.tensor(model.globals.use_flash)),
|
||||
]
|
||||
@@ -286,6 +288,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("inplace_safe", torch.tensor(True)),
|
||||
]
|
||||
@@ -309,6 +312,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("inplace_safe", torch.tensor(True)),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
@@ -20,13 +22,13 @@ import os
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
import json
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
import torch
|
||||
torch_versions = torch.__version__.split(".")
|
||||
torch_major_version = int(torch_versions[0])
|
||||
torch_minor_version = int(torch_versions[1])
|
||||
@@ -183,13 +185,15 @@ def main(args):
|
||||
args.config_preset,
|
||||
long_sequence_inference=args.long_sequence_inference,
|
||||
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
|
||||
)
|
||||
use_cuequivariance_attention=args.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=args.use_cuequivariance_multiplicative_update,
|
||||
)
|
||||
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
|
||||
if args.trace_model:
|
||||
if not config.data.predict.fixed_size:
|
||||
raise ValueError(
|
||||
@@ -482,6 +486,14 @@ if __name__ == "__main__":
|
||||
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
|
||||
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuequivariance_attention", action="store_true", default=False,
|
||||
help="""Use cuEquivariance kernels for attention computation."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuequivariance_multiplicative_update", action="store_true", default=False,
|
||||
help="""Use cuEquivariance kernels for triangular multiplicative update computation."""
|
||||
)
|
||||
add_data_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
24
setup.py
24
setup.py
@@ -52,15 +52,23 @@ def get_cuda_bare_metal_version(cuda_dir):
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
compute_capabilities = set([
|
||||
(5, 2), # Titan X
|
||||
(6, 1), # GeForce 1000-series
|
||||
(9, 0), # Hopper
|
||||
])
|
||||
|
||||
compute_capabilities.add((7, 0))
|
||||
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11:
|
||||
compute_capabilities.add((8, 0))
|
||||
compute_capabilities.add((8, 6))
|
||||
compute_capabilities.add((8, 9))
|
||||
|
||||
if int(bare_metal_major) >= 12:
|
||||
compute_capabilities.add((9, 0))
|
||||
|
||||
if int(bare_metal_major) >= 13:
|
||||
compute_capabilities.add((10, 0))
|
||||
compute_capabilities.add((10, 3))
|
||||
compute_capabilities.add((12, 0))
|
||||
else:
|
||||
compute_capabilities.add((7, 0))
|
||||
|
||||
compute_capability, _ = get_nvidia_cc()
|
||||
if compute_capability is not None:
|
||||
@@ -75,8 +83,6 @@ for major, minor in list(compute_capabilities):
|
||||
|
||||
extra_cuda_flags += cc_flag
|
||||
|
||||
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
|
||||
|
||||
if bare_metal_major != -1:
|
||||
modules = [CUDAExtension(
|
||||
name="attn_core_inplace_cuda",
|
||||
@@ -127,6 +133,12 @@ setup(
|
||||
},
|
||||
ext_modules=modules,
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
extras_require={
|
||||
'cuequivariance': [
|
||||
'cuequivariance-torch; sys_platform != "darwin"', # Not available on macOS
|
||||
'triton>=3.3.0; sys_platform != "darwin"', # Required for triangle multiplicative update
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
|
||||
160
tests/test_cuequivariance.py
Normal file
160
tests/test_cuequivariance.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit tests to compare components of OpenFold run with the cuEquivariance memory-efficient
|
||||
attention kernel vs. a stock PyTorch attention implementation.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from openfold.data import data_transforms
|
||||
from openfold.model.primitives import (
|
||||
lecun_normal_init_,
|
||||
Attention
|
||||
)
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
|
||||
from tests.config import consts
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.data_utils import random_template_feats, random_attention_inputs
|
||||
|
||||
|
||||
|
||||
class TestCuEquivarianceKernel(unittest.TestCase):
|
||||
|
||||
def test_compare_template_stack(self):
|
||||
"""
|
||||
Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
|
||||
Kernel can be used for Triangle Attention in the Template Pair Stack.
|
||||
"""
|
||||
n_templ = consts.n_templ
|
||||
n_res = 20
|
||||
eps = 2e-2
|
||||
|
||||
batch = random_template_feats(n_templ, n_res)
|
||||
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
|
||||
if consts.is_multimer:
|
||||
batch["asym_id"] = batch['asym_id'][0]
|
||||
|
||||
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
|
||||
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
|
||||
|
||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
|
||||
template_feats = {
|
||||
k: v for k, v in batch.items() if k.startswith("template_")
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
model.globals.use_deepspeed_evo_attention = False
|
||||
out_repro = model.embed_templates(
|
||||
template_feats,
|
||||
batch,
|
||||
torch.as_tensor(pair_act).cuda(),
|
||||
torch.as_tensor(pair_mask).cuda(),
|
||||
templ_dim=0,
|
||||
inplace_safe=False
|
||||
)
|
||||
out_repro = out_repro["template_pair_embedding"].cpu()
|
||||
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
model.globals.use_cuequivariance_multiplicative_update = True
|
||||
|
||||
out_repro_ds = model.embed_templates(
|
||||
template_feats,
|
||||
batch,
|
||||
torch.as_tensor(pair_act).cuda(),
|
||||
torch.as_tensor(pair_mask).cuda(),
|
||||
templ_dim=0,
|
||||
inplace_safe=False
|
||||
)
|
||||
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
|
||||
|
||||
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
|
||||
|
||||
def test_compare_model(self):
|
||||
"""
|
||||
Run full model with and without using CuEquivariance Evoformer attention kernel
|
||||
and compare output coordinates.
|
||||
"""
|
||||
eps = 0.2
|
||||
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
|
||||
batch = pickle.load(fp)
|
||||
|
||||
# atom37_to_atom14 doesn't like batches
|
||||
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
|
||||
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
|
||||
|
||||
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
|
||||
|
||||
if consts.is_multimer:
|
||||
n_res = batch['aatype'].shape[1]
|
||||
n_extra_seq = batch['extra_msa'].shape[1]
|
||||
batch["asym_id"] = np.ones((4, n_res))
|
||||
batch["entity_id"] = np.ones((4, n_res))
|
||||
batch["sym_id"] = np.ones((4, n_res))
|
||||
batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res))
|
||||
|
||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
|
||||
|
||||
batch["aatype"] = batch["aatype"].long()
|
||||
batch["template_aatype"] = batch["template_aatype"].long()
|
||||
batch["extra_msa"] = batch["extra_msa"].long()
|
||||
batch["residx_atom37_to_atom14"] = batch[
|
||||
"residx_atom37_to_atom14"
|
||||
].long()
|
||||
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
|
||||
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
|
||||
batch.update(
|
||||
data_transforms.atom37_to_torsion_angles("template_")(batch)
|
||||
)
|
||||
|
||||
# Move the recycling dimension to the end
|
||||
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
|
||||
batch = tensor_tree_map(move_dim, batch)
|
||||
# Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
|
||||
# https://github.com/aqlaboratory/openfold/issues/532
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
model.globals.use_cuequivariance_attention = False
|
||||
model.globals.use_cuequivariance_multiplicative_update = False
|
||||
out_repro = model(batch)
|
||||
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
|
||||
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
# Enable attention
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
out_repro_attn = model(batch)
|
||||
out_repro_attn = tensor_tree_map(lambda t: t.cpu(), out_repro_attn)
|
||||
out_repro_attn = out_repro_attn["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_attn, eps)
|
||||
|
||||
# Enable multiplication
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
model.globals.use_cuequivariance_multiplicative_update = True
|
||||
out_repro_mul = model(batch)
|
||||
out_repro_mul = tensor_tree_map(lambda t: t.cpu(), out_repro_mul)
|
||||
out_repro_mul = out_repro_mul["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
compare_utils.assert_mean_abs_diff_small(out_repro_attn, out_repro_mul, eps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -113,7 +113,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
def test_tri_mul_in_compare(self):
|
||||
self._tri_mul_compare(incoming=True)
|
||||
|
||||
def _tri_mul_inplace(self, incoming=False):
|
||||
def _tri_mul_inplace(self, incoming=False, dtype = torch.float32):
|
||||
n_res = consts.n_res
|
||||
|
||||
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
|
||||
@@ -126,26 +126,38 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
if incoming
|
||||
else model.evoformer.blocks[0].pair_stack.tri_mul_out
|
||||
)
|
||||
|
||||
act = torch.as_tensor(pair_act, dtype=dtype).cuda()
|
||||
mask = torch.as_tensor(pair_mask, dtype=dtype).cuda()
|
||||
module = module.to(dtype=dtype)
|
||||
|
||||
out_stock = module(
|
||||
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
|
||||
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
|
||||
act,
|
||||
mask=mask,
|
||||
inplace_safe=False,
|
||||
).cpu()
|
||||
)
|
||||
|
||||
# This has to come second because inference mode is in-place
|
||||
out_inplace = module(
|
||||
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
|
||||
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
|
||||
act,
|
||||
mask=mask,
|
||||
inplace_safe=True, _inplace_chunk_size=2,
|
||||
).cpu()
|
||||
)
|
||||
|
||||
self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
|
||||
torch.testing.assert_close(out_stock, out_inplace, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def test_tri_mul_out_inference(self):
|
||||
self._tri_mul_inplace()
|
||||
|
||||
def test_tri_mul_out_inference_bf16(self):
|
||||
self._tri_mul_inplace(dtype=torch.bfloat16)
|
||||
|
||||
def test_tri_mul_in_inference(self):
|
||||
self._tri_mul_inplace(incoming=True)
|
||||
|
||||
def test_tri_mul_in_inference_bf16(self):
|
||||
self._tri_mul_inplace(incoming=True, dtype=torch.bfloat16)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user