mirror of
https://github.com/aqlaboratory/openfold.git
synced 2026-06-04 12:44:26 +08:00
Merge pull request #553 from borisfom/nv_upstream_trt_cuequivariance
NVIDIA cuEquivariance and TensorRT integration
This commit is contained in:
@@ -143,14 +143,38 @@ Some commonly used command line flags are here. A full list of flags can be view
|
||||
|
||||
### Advanced Options for Increasing Efficiency
|
||||
|
||||
#### Speeding up inference
|
||||
#### Turning on TF32 (TensorFloat-32) precision on compatible hardware
|
||||
|
||||
When running on latest NVIDIA GPUs, starting from Ampere, you can enable TF32 precision to get about 1.3x performance boost.
|
||||
TF32 uses 1 sign bit, 8 exponent bits (like FP32), and 10 mantissa (significand) bits (like FP16), packed into a 32-bit word.
|
||||
It was found generally safe to use OF2 with TF32 instead of full FP32. To enable it globally in Torch:
|
||||
|
||||
```
|
||||
torch.backends.cuda.matmul.allow_tf32 = True # Enable TF32 for matrix multiplications
|
||||
torch.backends.cudnn.allow_tf32 = True # Enable TF32 for convolutions
|
||||
```
|
||||
Make sure NVIDIA_TF32_OVERRIDE environment variable is either not defined or set to 1.
|
||||
|
||||
#### Applying lower BF16 precision to EvoformerStack and ExtraMSAStack
|
||||
|
||||
BF16 occupies 16 bits: 1 sign bit, 8 exponent bits (same as FP32), and 7 mantissa (fraction) bits. Its dynamic range is equivalent to FP32, but BF16 can only represent numbers with about three decimal digits of precision.
|
||||
It was found generally safe to apply BF16 precision cast to EvoformerStack and ExtraMSAStack. This allows to achieve ~1.5x speedup compared to TF32 inferenceof the whole model.
|
||||
To apply BF16, use '--precision=bf16' argument. '--precision=fp16' is also supported, but not recommended due to numerical instability.
|
||||
|
||||
#### Speeding up inference with custom attention and multiplicative update kernels
|
||||
|
||||
The **DeepSpeed DS4Sci_EvoformerAttention kernel** is a memory-efficient attention kernel developed as part of a collaboration between OpenFold and the DeepSpeed4Science initiative.
|
||||
|
||||
If your system supports deepseed, using deepspeed generally leads an inference speedup of 2 - 3x without significant additional memory use. You may specify this option by selecting the `--use_deepspeed_inference` argument.
|
||||
|
||||
OF2 supports the cuEquivariance [triangle_multiplicative_update](https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.triangle_multiplicative_update.html) and [triangle_attention](https://docs.nvidia.com/cuda/cuequivariance/api/generated/cuequivariance_torch.triangle_attention.html) kernels which can speed up inference/training of the model 1.2 to 1.5 on top of DeepSpeed and even more for sequences with > 1000 residues. cuEquivariance attention actually uses much less memory than default or DeepSpeed attention. To enable, pass '--use_cuequivariance_attention' and '--use_cuequivariance_multiplicative_update' arguments to run_pretrained_openfold.py.
|
||||
CUEquivariance does fall back to DeepSpeed on shapes it does not efficiently support, so enable both for best effect.
|
||||
|
||||
If DeepSpeed is unavailable for your system, you may also try using [FlashAttention](https://github.com/HazyResearch/flash-attention) by adding `globals.use_flash = True` to the `--experiment_config_json`. Note that FlashAttention appears to work best for sequences with < 1000 residues.
|
||||
|
||||
#### Speeding up inference with TensorRT
|
||||
Alternatively (or together with cuEquivariance), you can try applying [TensorRT](https://developer.nvidia.com/tensorrt) to key modules. OF2 comes with built-in TensorRT lazy compilation support. It allows to build TensorRT engine for Evoformer on the first inference run and to reuse it on subsequent runs. To enable, pass '--trt_mode-run', '--trt_engine_dir', '--trt_max_sequence_len', '--trt_num_profiles' and '--trt_optimization_level' arguments to run_pretrained_openfold.py.
|
||||
|
||||
#### Large-scale batch inference
|
||||
For large-scale batch inference, we offer an optional tracing mode, which massively improves runtimes at the cost of a lengthy model compilation process. To enable it, add `--trace_model` to the inference command.
|
||||
|
||||
|
||||
@@ -56,6 +56,8 @@ Certain tests perform equivalence comparisons with the AlphaFold implementation.
|
||||
### MPI
|
||||
To use OpenFold with MPI support, you will need to add the package [`mpi4py`](https://pypi.org/project/mpi4py/). This can be done with pip in your OpenFold environment, e.g. `$ pip install mpi4py`.
|
||||
|
||||
### cuEquivariance
|
||||
cuEquivariance can be installed from pip: `$ pip install cuequivariance_ops_torch_cu13 cuequivariance_torch` (on CUDA13) or `$ pip install cuequivariance_ops_torch_cu12 cuequivariance_torch` (on CUDA12)
|
||||
|
||||
### Install OpenFold parameters without aws
|
||||
If you don't have access to `aws` on your system, you can use a different download source:
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import copy
|
||||
import importlib
|
||||
@@ -31,6 +46,11 @@ def enforce_config_constraints(config):
|
||||
"globals.use_flash",
|
||||
"globals.use_deepspeed_evo_attention"
|
||||
),
|
||||
(
|
||||
"globals.use_lma",
|
||||
"globals.use_flash",
|
||||
"globals.use_cuequivariance_attention",
|
||||
),
|
||||
]
|
||||
|
||||
for options in mutually_exclusive_bools:
|
||||
@@ -51,6 +71,10 @@ def enforce_config_constraints(config):
|
||||
"and that the deepspeed.ops.deepspeed4science package exists"
|
||||
)
|
||||
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if (config.globals.use_cuequivariance_attention or config.globals.use_cuequivariance_multiplicative_update) and not cuequivariance_is_installed:
|
||||
raise ValueError("use_cuequivariance_xxx requires that cuequivariance_torch is installed")
|
||||
|
||||
if(
|
||||
config.globals.offload_inference and
|
||||
not config.model.template.average_templates
|
||||
@@ -64,8 +88,22 @@ def model_config(
|
||||
low_prec=False,
|
||||
long_sequence_inference=False,
|
||||
use_deepspeed_evoformer_attention=False,
|
||||
use_cuequivariance_attention=False,
|
||||
use_cuequivariance_multiplicative_update=False,
|
||||
precision="tf32",
|
||||
trt_mode=None,
|
||||
trt_engine_dir=None,
|
||||
trt_num_profiles=1,
|
||||
trt_optimization_level=3,
|
||||
trt_max_sequence_len=640,
|
||||
):
|
||||
c = copy.deepcopy(config)
|
||||
c.precision = precision
|
||||
c.trt.mode = trt_mode
|
||||
c.trt.engine_dir = trt_engine_dir
|
||||
c.trt.num_profiles = trt_num_profiles
|
||||
c.trt.optimization_level = trt_optimization_level
|
||||
c.trt.max_sequence_len = trt_max_sequence_len
|
||||
# TRAINING PRESETS
|
||||
if name == "initial_training":
|
||||
# AF2 Suppl. Table 4, "initial training" setting
|
||||
@@ -240,7 +278,13 @@ def model_config(
|
||||
|
||||
if use_deepspeed_evoformer_attention:
|
||||
c.globals.use_deepspeed_evo_attention = True
|
||||
|
||||
|
||||
if use_cuequivariance_attention:
|
||||
c.globals.use_cuequivariance_attention = True
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
c.globals.use_cuequivariance_multiplicative_update = True
|
||||
|
||||
if train:
|
||||
c.globals.blocks_per_ckpt = 1
|
||||
c.globals.chunk_size = None
|
||||
@@ -286,6 +330,14 @@ NUM_TEMPLATES = "num templates placeholder"
|
||||
|
||||
config = mlc.ConfigDict(
|
||||
{
|
||||
"precision": "tf32",
|
||||
"trt": {
|
||||
"mode": None,
|
||||
"engine_dir": None,
|
||||
"num_profiles": 1,
|
||||
"optimization_level": 3,
|
||||
"max_sequence_len": 640
|
||||
},
|
||||
"data": {
|
||||
"common": {
|
||||
"feat": {
|
||||
@@ -475,6 +527,11 @@ config = mlc.ConfigDict(
|
||||
# use_deepspeed_evo_attention and use_lma. Doesn't work that well
|
||||
# on long sequences (>1000 residues).
|
||||
"use_flash": False,
|
||||
# Use cuEquivariance kernels for accelerated triangle attention and
|
||||
# triangle multiplicative update operations. Requires CUDA and
|
||||
# cuequivariance_torch package.
|
||||
"use_cuequivariance_attention": False,
|
||||
"use_cuequivariance_multiplicative_update": False,
|
||||
"offload_inference": False,
|
||||
"c_z": c_z,
|
||||
"c_m": c_m,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -50,6 +51,8 @@ class Dropout(nn.Module):
|
||||
Tensor to which dropout is applied. Can have any shape
|
||||
compatible with self.batch_dim
|
||||
"""
|
||||
if not self.training:
|
||||
return x
|
||||
shape = list(x.shape)
|
||||
if self.batch_dim is not None:
|
||||
for bd in self.batch_dim:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -658,6 +659,8 @@ class TemplateEmbedder(nn.Module):
|
||||
chunk_size,
|
||||
_mask_trans=True,
|
||||
use_deepspeed_evo_attention=False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma=False,
|
||||
inplace_safe=False
|
||||
):
|
||||
@@ -709,6 +712,8 @@ class TemplateEmbedder(nn.Module):
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -896,6 +901,8 @@ class TemplateEmbedderMultimer(nn.Module):
|
||||
multichain_mask_2d,
|
||||
_mask_trans=True,
|
||||
use_deepspeed_evo_attention=False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma=False,
|
||||
inplace_safe=False
|
||||
):
|
||||
@@ -971,6 +978,8 @@ class TemplateEmbedderMultimer(nn.Module):
|
||||
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
import sys
|
||||
import torch
|
||||
@@ -19,6 +21,7 @@ import torch.nn as nn
|
||||
from typing import Tuple, Sequence, Optional
|
||||
from functools import partial
|
||||
from abc import ABC, abstractmethod
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import Linear, LayerNorm
|
||||
from openfold.model.dropout import DropoutRowwise, DropoutColumnwise
|
||||
@@ -179,6 +182,8 @@ class PairStack(nn.Module):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -197,6 +202,7 @@ class PairStack(nn.Module):
|
||||
mask=pair_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if (not inplace_safe):
|
||||
z = z + self.ps_dropout_row_layer(tmu_update)
|
||||
@@ -210,6 +216,7 @@ class PairStack(nn.Module):
|
||||
mask=pair_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if (not inplace_safe):
|
||||
z = z + self.ps_dropout_row_layer(tmu_update)
|
||||
@@ -226,6 +233,7 @@ class PairStack(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -245,6 +253,7 @@ class PairStack(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -363,6 +372,7 @@ class MSABlock(nn.Module, ABC):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -427,6 +437,8 @@ class EvoformerBlock(MSABlock):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -467,6 +479,7 @@ class EvoformerBlock(MSABlock):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_memory_efficient_kernel=False,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
)
|
||||
),
|
||||
@@ -489,6 +502,7 @@ class EvoformerBlock(MSABlock):
|
||||
mask=msa_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
),
|
||||
@@ -534,6 +548,8 @@ class EvoformerBlock(MSABlock):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -610,6 +626,8 @@ class ExtraMSABlock(MSABlock):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -618,8 +636,8 @@ class ExtraMSABlock(MSABlock):
|
||||
_offloadable_inputs: Optional[Sequence[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if(_attn_chunk_size is None):
|
||||
_attn_chunk_size = chunk_size
|
||||
|
||||
_attn_chunk_size = chunk_size
|
||||
|
||||
if(_offload_inference and inplace_safe):
|
||||
input_tensors = _offloadable_inputs
|
||||
del _offloadable_inputs
|
||||
@@ -646,7 +664,8 @@ class ExtraMSABlock(MSABlock):
|
||||
chunk_size=_attn_chunk_size,
|
||||
use_lma=use_lma,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention),
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_memory_efficient_kernel=not (use_lma or use_deepspeed_evo_attention or use_cuequivariance_attention),
|
||||
_checkpoint_chunks=
|
||||
self.ckpt if torch.is_grad_enabled() else False,
|
||||
)
|
||||
@@ -719,6 +738,8 @@ class ExtraMSABlock(MSABlock):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -857,13 +878,15 @@ class EvoformerStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if(tune_chunk_size):
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def _prep_blocks(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
use_lma: bool,
|
||||
use_flash: bool,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
@@ -878,6 +901,8 @@ class EvoformerStack(nn.Module):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
inplace_safe=inplace_safe,
|
||||
@@ -901,12 +926,13 @@ class EvoformerStack(nn.Module):
|
||||
args=(m.clone(), z.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -918,6 +944,8 @@ class EvoformerStack(nn.Module):
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -930,6 +958,8 @@ class EvoformerStack(nn.Module):
|
||||
z=input_tensors[1],
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
msa_mask=msa_mask,
|
||||
@@ -960,8 +990,10 @@ class EvoformerStack(nn.Module):
|
||||
z: torch.Tensor,
|
||||
msa_mask: torch.Tensor,
|
||||
pair_mask: torch.Tensor,
|
||||
chunk_size: int,
|
||||
chunk_size: int = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -996,12 +1028,19 @@ class EvoformerStack(nn.Module):
|
||||
[*, N_res, N_res, C_z] pair embedding
|
||||
s:
|
||||
[*, N_res, C_s] single embedding (or None if extra MSA stack)
|
||||
"""
|
||||
"""
|
||||
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
blocks = self._prep_blocks(
|
||||
m=m,
|
||||
z=z,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
msa_mask=msa_mask,
|
||||
@@ -1080,13 +1119,15 @@ class ExtraMSAStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if(tune_chunk_size):
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def _prep_blocks(self,
|
||||
m: torch.Tensor,
|
||||
z: torch.Tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
use_lma: bool,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
pair_mask: Optional[torch.Tensor],
|
||||
@@ -1100,6 +1141,8 @@ class ExtraMSAStack(nn.Module):
|
||||
pair_mask=pair_mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -1122,12 +1165,15 @@ class ExtraMSAStack(nn.Module):
|
||||
args=(m.clone(), z.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
# A temporary measure to address torch's occasional
|
||||
# inability to allocate large tensors
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -1137,6 +1183,8 @@ class ExtraMSAStack(nn.Module):
|
||||
input_tensors: Sequence[torch.Tensor],
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
msa_mask: Optional[torch.Tensor] = None,
|
||||
pair_mask: Optional[torch.Tensor] = None,
|
||||
@@ -1150,6 +1198,8 @@ class ExtraMSAStack(nn.Module):
|
||||
z=input_tensors[1],
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
msa_mask=msa_mask,
|
||||
pair_mask=pair_mask,
|
||||
@@ -1175,8 +1225,10 @@ class ExtraMSAStack(nn.Module):
|
||||
z: torch.Tensor,
|
||||
msa_mask: Optional[torch.Tensor],
|
||||
pair_mask: Optional[torch.Tensor],
|
||||
chunk_size: int,
|
||||
chunk_size: int = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -1197,12 +1249,19 @@ class ExtraMSAStack(nn.Module):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] pair update
|
||||
"""
|
||||
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
checkpoint_fn = get_checkpoint_fn()
|
||||
blocks = self._prep_blocks(
|
||||
m=m,
|
||||
z=z,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
msa_mask=msa_mask,
|
||||
pair_mask=pair_mask,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -147,6 +148,8 @@ class AlphaFold(nn.Module):
|
||||
chunk_size=self.globals.chunk_size,
|
||||
multichain_mask_2d=multichain_mask_2d,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=self.config._mask_trans
|
||||
@@ -171,6 +174,8 @@ class AlphaFold(nn.Module):
|
||||
templ_dim,
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=self.config._mask_trans
|
||||
@@ -382,6 +387,8 @@ class AlphaFold(nn.Module):
|
||||
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
pair_mask=pair_mask.to(dtype=m.dtype),
|
||||
_mask_trans=self.config._mask_trans,
|
||||
@@ -395,6 +402,8 @@ class AlphaFold(nn.Module):
|
||||
msa_mask=feats["extra_msa_mask"].to(dtype=m.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
pair_mask=pair_mask.to(dtype=m.dtype),
|
||||
inplace_safe=inplace_safe,
|
||||
@@ -414,6 +423,8 @@ class AlphaFold(nn.Module):
|
||||
pair_mask=pair_mask.to(dtype=input_tensors[1].dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
_mask_trans=self.config._mask_trans,
|
||||
)
|
||||
@@ -427,6 +438,8 @@ class AlphaFold(nn.Module):
|
||||
pair_mask=pair_mask.to(dtype=z.dtype),
|
||||
chunk_size=self.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=self.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=self.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=self.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=self.globals.use_lma,
|
||||
use_flash=self.globals.use_flash,
|
||||
inplace_safe=inplace_safe,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -17,6 +18,7 @@ import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from typing import Optional, List, Tuple
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import (
|
||||
Linear,
|
||||
@@ -93,6 +95,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size: int,
|
||||
use_memory_efficient_kernel: bool,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_lma: bool,
|
||||
use_flash: bool,
|
||||
flash_mask: Optional[torch.Tensor],
|
||||
@@ -105,6 +108,7 @@ class MSAAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=flash_mask,
|
||||
@@ -132,37 +136,50 @@ class MSAAttention(nn.Module):
|
||||
z: Optional[torch.Tensor],
|
||||
mask: Optional[torch.Tensor],
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
chunk_size: int = 256
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
n_seq, n_res = m.shape[-3:-1]
|
||||
|
||||
if mask is None:
|
||||
# [*, N_seq, N_res]
|
||||
mask = m.new_ones(
|
||||
m.shape[:-3] + (n_seq, n_res),
|
||||
m.shape[:-1],
|
||||
)
|
||||
|
||||
# [*, N_seq, 1, 1, N_res]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
if use_cuequivariance_attention:
|
||||
mask_bias = mask[..., :, None, None, :]
|
||||
else:
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
if (self.pair_bias and
|
||||
z is not None and # For the
|
||||
self.layer_norm_z is not None and # benefit of
|
||||
self.linear_z is not None # TorchScript
|
||||
):
|
||||
chunks = []
|
||||
if torch.onnx.is_in_onnx_export() or is_fx_tracing():
|
||||
inplace_safe = False
|
||||
chunk_size = None
|
||||
|
||||
for i in range(0, z.shape[-3], 256):
|
||||
z_chunk = z[..., i: i + 256, :, :]
|
||||
if chunk_size is None:
|
||||
z = self.layer_norm_z(z)
|
||||
z = self.linear_z(z)
|
||||
else:
|
||||
chunks = []
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z_chunk = self.layer_norm_z(z_chunk)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z_chunk = self.linear_z(z_chunk)
|
||||
|
||||
chunks.append(z_chunk)
|
||||
|
||||
z = torch.cat(chunks, dim=-3)
|
||||
|
||||
for i in range(0, z.shape[-3], chunk_size):
|
||||
z_chunk = z[..., i: i + chunk_size, :, :]
|
||||
|
||||
# [*, N_res, N_res, C_z]
|
||||
z_chunk = self.layer_norm_z(z_chunk)
|
||||
|
||||
# [*, N_res, N_res, no_heads]
|
||||
z_chunk = self.linear_z(z_chunk)
|
||||
|
||||
chunks.append(z_chunk)
|
||||
z = torch.cat(chunks, dim=-3)
|
||||
|
||||
# [*, 1, no_heads, N_res, N_res]
|
||||
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
|
||||
|
||||
@@ -224,6 +241,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size: Optional[int] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
@@ -252,16 +270,20 @@ class MSAAttention(nn.Module):
|
||||
checkpoint=_checkpoint_chunks,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
|
||||
|
||||
if(use_flash):
|
||||
assert z is None
|
||||
biases = None
|
||||
else:
|
||||
m, mask_bias, z = self._prep_inputs(
|
||||
m, z, mask, inplace_safe=inplace_safe
|
||||
m, z, mask, inplace_safe=inplace_safe,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
)
|
||||
|
||||
biases = [mask_bias]
|
||||
if z is None and use_cuequivariance_attention:
|
||||
z = m.new_zeros(1, self.no_heads, m.shape[-2], m.shape[-2])
|
||||
|
||||
if(z is not None):
|
||||
biases.append(z)
|
||||
|
||||
@@ -272,6 +294,7 @@ class MSAAttention(nn.Module):
|
||||
chunk_size,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=mask,
|
||||
@@ -284,6 +307,7 @@ class MSAAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
flash_mask=mask,
|
||||
@@ -362,6 +386,7 @@ class MSAColumnAttention(nn.Module):
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
use_flash: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -386,6 +411,7 @@ class MSAColumnAttention(nn.Module):
|
||||
mask=mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
use_flash=use_flash,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,6 +17,9 @@ import importlib
|
||||
import math
|
||||
from typing import Optional, Callable, List, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import truncnorm
|
||||
|
||||
deepspeed_is_installed = importlib.util.find_spec("deepspeed") is not None
|
||||
ds4s_is_installed = deepspeed_is_installed and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
|
||||
@@ -30,9 +34,25 @@ if fa_is_installed:
|
||||
from flash_attn.bert_padding import unpad_input
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_kvpacked_func
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from scipy.stats import truncnorm
|
||||
cueq_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cueq_is_installed:
|
||||
from cuequivariance_ops_torch.triangle_attention import (
|
||||
CUEQ_TRIATTN_FALLBACK_THRESHOLD,
|
||||
)
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
|
||||
def cueq_would_fall_back(n_token: int, hidden_dim: int, dtype: torch.dtype):
|
||||
# for q_x, dimension -2 is the context length
|
||||
if n_token <= CUEQ_TRIATTN_FALLBACK_THRESHOLD:
|
||||
return True
|
||||
if dtype == torch.float32:
|
||||
if hidden_dim > 32 or hidden_dim % 4 != 0:
|
||||
return True
|
||||
else:
|
||||
# float16, bfloat16
|
||||
if hidden_dim > 128 or hidden_dim % 8 != 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
from openfold.utils.checkpointing import get_checkpoint_fn
|
||||
from openfold.utils.kernel.attention_core import attention_core
|
||||
@@ -199,7 +219,7 @@ class Linear(nn.Linear):
|
||||
bias).to(dtype=d)
|
||||
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
bias = self.bias.to(dtype=d) if self.bias is not None else None
|
||||
return nn.functional.linear(input, self.weight.to(dtype=d), bias)
|
||||
|
||||
@@ -223,7 +243,7 @@ class LayerNorm(nn.Module):
|
||||
deepspeed.comm.comm.is_initialized()
|
||||
)
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
out = nn.functional.layer_norm(
|
||||
x,
|
||||
self.c_in,
|
||||
@@ -255,7 +275,7 @@ def softmax_no_cast(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
|
||||
deepspeed.comm.comm.is_initialized()
|
||||
)
|
||||
if d is torch.bfloat16 and not deepspeed_is_initialized:
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
with torch.amp.autocast('cuda', enabled=False):
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
else:
|
||||
s = torch.nn.functional.softmax(t, dim=dim)
|
||||
@@ -452,6 +472,7 @@ class Attention(nn.Module):
|
||||
biases: Optional[List[torch.Tensor]] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
lma_q_chunk_size: int = DEFAULT_LMA_Q_CHUNK_SIZE,
|
||||
lma_kv_chunk_size: int = DEFAULT_LMA_KV_CHUNK_SIZE,
|
||||
@@ -483,6 +504,11 @@ class Attention(nn.Module):
|
||||
Query chunk size (for LMA)
|
||||
lma_kv_chunk_size:
|
||||
Key/Value chunk size (for LMA)
|
||||
use_cuequivariance_attention:
|
||||
Whether to use cuEquivariance attention kernel.
|
||||
When on, biases[0] contains 0/1 mask tensor for cuEquivariance attention (0 for invalid positions)
|
||||
|
||||
|
||||
Returns
|
||||
[*, Q, C_q] attention update
|
||||
"""
|
||||
@@ -498,7 +524,13 @@ class Attention(nn.Module):
|
||||
"use flash_mask instead"
|
||||
)
|
||||
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention, use_lma, use_flash]
|
||||
if use_cuequivariance_attention:
|
||||
if biases is None or len(biases) != 2:
|
||||
raise ValueError(
|
||||
"cuEquivariance attention requires exactly two bias terms"
|
||||
)
|
||||
|
||||
attn_options = [use_memory_efficient_kernel, use_deepspeed_evo_attention or use_cuequivariance_attention, use_lma, use_flash]
|
||||
if sum(attn_options) > 1:
|
||||
raise ValueError(
|
||||
"Choose at most one alternative attention algorithm"
|
||||
@@ -509,12 +541,20 @@ class Attention(nn.Module):
|
||||
|
||||
# DeepSpeed attention kernel applies scaling internally
|
||||
q, k, v = self._prep_qkv(q_x, kv_x,
|
||||
apply_scale=not use_deepspeed_evo_attention)
|
||||
apply_scale=not use_deepspeed_evo_attention or use_cuequivariance_attention)
|
||||
|
||||
if is_fp16_enabled():
|
||||
use_memory_efficient_kernel = False
|
||||
|
||||
if use_memory_efficient_kernel:
|
||||
# cuequivariance kernel takes precedence over use_deepspeed_evo_attention
|
||||
if use_cuequivariance_attention:
|
||||
if not cueq_is_installed:
|
||||
raise ValueError(
|
||||
"Running with `use_cuequivariance_attention` but package is not "
|
||||
"installed. See documentation for installation instructions."
|
||||
)
|
||||
o = _cuequivariance_attn(q, k, v, biases[1], biases[0])
|
||||
elif use_memory_efficient_kernel:
|
||||
if len(biases) > 2:
|
||||
raise ValueError(
|
||||
"If use_memory_efficient_kernel is True, you may only "
|
||||
@@ -828,3 +868,68 @@ def _flash_attn(q, k, v, kv_mask):
|
||||
out = out.to(dtype=dtype)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@torch.jit.ignore
|
||||
def _cuequivariance_attn(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
bias: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Compute attention using the cuEquivariance triangle attention kernel.
|
||||
|
||||
Args:
|
||||
q: [*, H, Q, C_hidden] query data
|
||||
k: [*, H, K, C_hidden] key data
|
||||
v: [*, H, V, C_hidden] value data
|
||||
bias: [*, H, Q, K] triangular bias
|
||||
mask: [*, Q, K] mask for masking invalid positions
|
||||
|
||||
Returns:
|
||||
[*, H, Q, C_hidden] attention output
|
||||
"""
|
||||
|
||||
# Check input dimensionality
|
||||
qdim = len(q.shape)
|
||||
# If we have 4D tensors ([*, H, Q, D]), add batch dimension
|
||||
if qdim == 4:
|
||||
q = q.unsqueeze(0) # [1, H, Q, D]
|
||||
k = k.unsqueeze(0) # [1, H, K, D]
|
||||
v = v.unsqueeze(0) # [1, H, V, D]
|
||||
bias = bias.unsqueeze(0) # [1, H, Q, K]
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(0) # [1, Q, K]
|
||||
elif len(q.shape[:-3]) > 2:
|
||||
# If there are more than 2 leading dimensions, flatten them into B*N
|
||||
batch_shape = q.shape[:-3]
|
||||
flat_batch_size = 1
|
||||
for dim in batch_shape:
|
||||
flat_batch_size *= dim
|
||||
|
||||
q = q.reshape(flat_batch_size, *q.shape[-3:])
|
||||
k = k.reshape(flat_batch_size, *k.shape[-3:])
|
||||
v = v.reshape(flat_batch_size, *v.shape[-3:])
|
||||
bias = bias.reshape(flat_batch_size, *bias.shape[-3:])
|
||||
if mask is not None:
|
||||
mask = mask.reshape(flat_batch_size, *mask.shape[-2:])
|
||||
|
||||
# Apply cuEquivariance triangle attention
|
||||
o = triangle_attention(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
bias=bias,
|
||||
mask=mask
|
||||
)
|
||||
|
||||
# If we added a batch dimension for 4D inputs, remove it
|
||||
if qdim == 4:
|
||||
o = o.squeeze(0)
|
||||
|
||||
# Final transpose to match expected output format
|
||||
o = o.transpose(-2, -3)
|
||||
|
||||
return o
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -216,6 +217,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
_attn_chunk_size: Optional[int],
|
||||
single_mask: torch.Tensor,
|
||||
use_deepspeed_evo_attention: bool,
|
||||
use_cuequivariance_attention: bool,
|
||||
use_lma: bool,
|
||||
inplace_safe: bool):
|
||||
single = add(single,
|
||||
@@ -225,6 +227,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -239,6 +242,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
chunk_size=_attn_chunk_size,
|
||||
mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -251,12 +255,14 @@ class TemplatePairStackBlock(nn.Module):
|
||||
def tri_mul_out_in(self,
|
||||
single: torch.Tensor,
|
||||
single_mask: torch.Tensor,
|
||||
use_cuequivariance_multiplicative_update: bool,
|
||||
inplace_safe: bool):
|
||||
tmu_update = self.tri_mul_out(
|
||||
single,
|
||||
mask=single_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if not inplace_safe:
|
||||
single = single + self.dropout_row(tmu_update)
|
||||
@@ -270,6 +276,7 @@ class TemplatePairStackBlock(nn.Module):
|
||||
mask=single_mask,
|
||||
inplace_safe=inplace_safe,
|
||||
_add_with_inplace=True,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update
|
||||
)
|
||||
if not inplace_safe:
|
||||
single = single + self.dropout_row(tmu_update)
|
||||
@@ -285,6 +292,8 @@ class TemplatePairStackBlock(nn.Module):
|
||||
mask: torch.Tensor,
|
||||
chunk_size: Optional[int] = None,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -307,10 +316,12 @@ class TemplatePairStackBlock(nn.Module):
|
||||
if self.tri_mul_first:
|
||||
single = self.tri_att_start_end(single=self.tri_mul_out_in(single=single,
|
||||
single_mask=single_mask,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
inplace_safe=inplace_safe),
|
||||
_attn_chunk_size=_attn_chunk_size,
|
||||
single_mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe)
|
||||
else:
|
||||
@@ -319,9 +330,11 @@ class TemplatePairStackBlock(nn.Module):
|
||||
_attn_chunk_size=_attn_chunk_size,
|
||||
single_mask=single_mask,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe),
|
||||
single_mask=single_mask,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
inplace_safe=inplace_safe)
|
||||
|
||||
single = add(single,
|
||||
@@ -405,7 +418,7 @@ class TemplatePairStack(nn.Module):
|
||||
self.tune_chunk_size = tune_chunk_size
|
||||
self.chunk_size_tuner = None
|
||||
if tune_chunk_size:
|
||||
self.chunk_size_tuner = ChunkSizeTuner()
|
||||
self.chunk_size_tuner = ChunkSizeTuner(2048)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -413,6 +426,8 @@ class TemplatePairStack(nn.Module):
|
||||
mask: torch.tensor,
|
||||
chunk_size: int,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
_mask_trans: bool = True,
|
||||
@@ -437,6 +452,8 @@ class TemplatePairStack(nn.Module):
|
||||
mask=mask,
|
||||
chunk_size=chunk_size,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=use_cuequivariance_multiplicative_update,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=_mask_trans,
|
||||
@@ -451,10 +468,11 @@ class TemplatePairStack(nn.Module):
|
||||
args=(t.clone(),),
|
||||
min_chunk_size=chunk_size,
|
||||
)
|
||||
attn_chunk = tuned_chunk_size if use_cuequivariance_attention else (tuned_chunk_size // 4)
|
||||
blocks = [
|
||||
partial(b,
|
||||
chunk_size=tuned_chunk_size,
|
||||
_attn_chunk_size=max(chunk_size, tuned_chunk_size // 4),
|
||||
_attn_chunk_size=max(chunk_size, attn_chunk),
|
||||
) for b in blocks
|
||||
]
|
||||
|
||||
@@ -528,6 +546,8 @@ def embed_templates_offload(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
@@ -647,6 +667,8 @@ def embed_templates_average(
|
||||
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
|
||||
chunk_size=model.globals.chunk_size,
|
||||
use_deepspeed_evo_attention=model.globals.use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=model.globals.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=model.globals.use_cuequivariance_multiplicative_update,
|
||||
use_lma=model.globals.use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
_mask_trans=model.config._mask_trans,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -64,6 +65,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size: int,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -79,6 +81,7 @@ class TriangleAttention(nn.Module):
|
||||
self.mha,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma
|
||||
),
|
||||
mha_inputs,
|
||||
@@ -93,6 +96,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size: Optional[int] = None,
|
||||
use_memory_efficient_kernel: bool = False,
|
||||
use_deepspeed_evo_attention: bool = False,
|
||||
use_cuequivariance_attention: bool = False,
|
||||
use_lma: bool = False,
|
||||
inplace_safe: bool = False,
|
||||
) -> torch.Tensor:
|
||||
@@ -117,7 +121,10 @@ class TriangleAttention(nn.Module):
|
||||
x = self.layer_norm(x)
|
||||
|
||||
# [*, I, 1, 1, J]
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
if use_cuequivariance_attention:
|
||||
mask_bias = mask[..., :, None, None, :]
|
||||
else:
|
||||
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
|
||||
|
||||
# [*, H, I, J]
|
||||
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
|
||||
@@ -134,6 +141,7 @@ class TriangleAttention(nn.Module):
|
||||
chunk_size,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma,
|
||||
inplace_safe=inplace_safe,
|
||||
)
|
||||
@@ -144,6 +152,7 @@ class TriangleAttention(nn.Module):
|
||||
biases=biases,
|
||||
use_memory_efficient_kernel=use_memory_efficient_kernel,
|
||||
use_deepspeed_evo_attention=use_deepspeed_evo_attention,
|
||||
use_cuequivariance_attention=use_cuequivariance_attention,
|
||||
use_lma=use_lma
|
||||
)
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -16,15 +17,86 @@
|
||||
from functools import partialmethod
|
||||
from typing import Optional
|
||||
from abc import ABC, abstractmethod
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.fx._symbolic_trace import is_fx_tracing
|
||||
|
||||
from openfold.model.primitives import Linear, LayerNorm
|
||||
from openfold.utils.chunk_utils import chunk_layer
|
||||
from openfold.utils.precision_utils import is_fp16_enabled
|
||||
from openfold.utils.tensor_utils import add, permute_final_dims
|
||||
|
||||
# cuEquivariance import handling
|
||||
cuequivariance_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
if cuequivariance_is_installed:
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import triangle_multiplicative_update
|
||||
except ImportError:
|
||||
cuequivariance_is_installed = False
|
||||
|
||||
|
||||
def _cuequivariance_triangular_mult(
|
||||
x: torch.Tensor,
|
||||
direction: str,
|
||||
mask: Optional[torch.Tensor],
|
||||
norm_in_weight: torch.Tensor,
|
||||
norm_in_bias: torch.Tensor,
|
||||
p_in_weight: torch.Tensor,
|
||||
p_in_bias: torch.Tensor,
|
||||
g_in_weight: torch.Tensor,
|
||||
g_in_bias: torch.Tensor,
|
||||
norm_out_weight: torch.Tensor,
|
||||
norm_out_bias: torch.Tensor,
|
||||
p_out_weight: torch.Tensor,
|
||||
p_out_bias: torch.Tensor,
|
||||
g_out_weight: torch.Tensor,
|
||||
g_out_bias: torch.Tensor,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
"""
|
||||
Wrapper function for cuEquivariance triangle multiplicative update.
|
||||
|
||||
Args:
|
||||
x: [*, N, N, C] input tensor
|
||||
direction: "outgoing" or "incoming"
|
||||
mask: [*, N, N] mask tensor
|
||||
norm_in_weight: [C] input normalization weight
|
||||
norm_in_bias: [C] input normalization bias
|
||||
p_in_weight: [2*C, C] input projection weight
|
||||
g_in_weight: [2*C, C] input gating weight
|
||||
norm_out_weight: [C] output normalization weight
|
||||
norm_out_bias: [C] output normalization bias
|
||||
p_out_weight: [C, C] output projection weight
|
||||
g_out_weight: [C, C] output gating weight
|
||||
eps: epsilon for numerical stability
|
||||
|
||||
Returns:
|
||||
[*, N, N, C] output tensor
|
||||
"""
|
||||
if not cuequivariance_is_installed:
|
||||
raise ValueError(
|
||||
"_cuequivariance_triangular_mult requires that cuequivariance_torch be installed"
|
||||
)
|
||||
return triangle_multiplicative_update(
|
||||
x=x,
|
||||
direction=direction,
|
||||
mask=mask,
|
||||
norm_in_weight=norm_in_weight,
|
||||
norm_in_bias=norm_in_bias,
|
||||
p_in_weight=p_in_weight,
|
||||
p_in_bias=p_in_bias,
|
||||
g_in_weight=g_in_weight,
|
||||
g_in_bias=g_in_bias,
|
||||
norm_out_weight=norm_out_weight,
|
||||
norm_out_bias=norm_out_bias,
|
||||
p_out_weight=p_out_weight,
|
||||
p_out_bias=p_out_bias,
|
||||
g_out_weight=g_out_weight,
|
||||
g_out_bias=g_out_bias,
|
||||
eps=eps,
|
||||
).view(x.shape)
|
||||
|
||||
class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
|
||||
"""
|
||||
@@ -87,6 +159,7 @@ class BaseTriangleMultiplicativeUpdate(nn.Module, ABC):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
@@ -397,6 +470,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False,
|
||||
_inplace_chunk_size: Optional[int] = 256,
|
||||
) -> torch.Tensor:
|
||||
@@ -409,7 +483,38 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
if(inplace_safe):
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
p_in_weight = torch.cat([self.linear_a_p.weight, self.linear_b_p.weight], dim=0)
|
||||
g_in_weight = torch.cat([self.linear_a_g.weight, self.linear_b_g.weight], dim=0)
|
||||
|
||||
p_in_bias = torch.cat([self.linear_a_p.bias, self.linear_b_p.bias], dim=0)
|
||||
g_in_bias = torch.cat([self.linear_a_g.bias, self.linear_b_g.bias], dim=0)
|
||||
|
||||
result = _cuequivariance_triangular_mult(
|
||||
z,
|
||||
direction="outgoing" if self._outgoing else "incoming",
|
||||
mask=mask,
|
||||
norm_in_weight=self.layer_norm_in.weight,
|
||||
norm_in_bias=self.layer_norm_in.bias,
|
||||
p_in_weight=p_in_weight,
|
||||
p_in_bias=p_in_bias,
|
||||
g_in_weight=g_in_weight,
|
||||
g_in_bias=g_in_bias,
|
||||
norm_out_weight=self.layer_norm_out.weight,
|
||||
norm_out_bias=self.layer_norm_out.bias,
|
||||
p_out_weight=self.linear_z.weight,
|
||||
p_out_bias=self.linear_z.bias,
|
||||
g_out_weight=self.linear_g.weight,
|
||||
g_out_bias=self.linear_g.bias,
|
||||
eps=1e-5,
|
||||
)
|
||||
# When not inplace_safe (training), caller should have set _add_with_inplace to False
|
||||
if inplace_safe and _add_with_inplace:
|
||||
result += z
|
||||
return result
|
||||
|
||||
if inplace_safe:
|
||||
x = self._inference_forward(
|
||||
z,
|
||||
mask,
|
||||
@@ -422,7 +527,7 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
mask = z.new_ones(z.shape[:-1])
|
||||
|
||||
mask = mask.unsqueeze(-1)
|
||||
|
||||
|
||||
z = self.layer_norm_in(z)
|
||||
a = mask
|
||||
a = a * self.sigmoid(self.linear_a_g(z))
|
||||
@@ -433,13 +538,12 @@ class TriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
|
||||
# Prevents overflow of torch.matmul in combine projections in
|
||||
# reduced-precision modes
|
||||
a_std = a.std()
|
||||
b_std = b.std()
|
||||
if(is_fp16_enabled() and a_std != 0. and b_std != 0.):
|
||||
a = a / a.std()
|
||||
b = b / b.std()
|
||||
|
||||
if(is_fp16_enabled()):
|
||||
if is_fp16_enabled():
|
||||
a_std = a.std()
|
||||
b_std = b.std()
|
||||
if a_std != 0. and b_std != 0.:
|
||||
a = a / a.std()
|
||||
b = b / b.std()
|
||||
with torch.cuda.amp.autocast(enabled=False):
|
||||
x = self._combine_projections(a.float(), b.float())
|
||||
else:
|
||||
@@ -545,6 +649,7 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
z: torch.Tensor,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
inplace_safe: bool = False,
|
||||
use_cuequivariance_multiplicative_update: bool = False,
|
||||
_add_with_inplace: bool = False,
|
||||
_inplace_chunk_size: Optional[int] = 256
|
||||
) -> torch.Tensor:
|
||||
@@ -557,6 +662,32 @@ class FusedTriangleMultiplicativeUpdate(BaseTriangleMultiplicativeUpdate):
|
||||
Returns:
|
||||
[*, N_res, N_res, C_z] output tensor
|
||||
"""
|
||||
|
||||
if use_cuequivariance_multiplicative_update:
|
||||
direction = "outgoing" if self._outgoing else "incoming"
|
||||
result = _cuequivariance_triangular_mult(
|
||||
x=z,
|
||||
direction=direction,
|
||||
mask=mask,
|
||||
norm_in_weight=self.layer_norm_in.weight,
|
||||
norm_in_bias=self.layer_norm_in.bias,
|
||||
p_in_weight=self.linear_ab_p.weight,
|
||||
p_in_bias=self.linear_ab_p.bias,
|
||||
g_in_weight=self.linear_ab_g.weight,
|
||||
g_in_bias=self.linear_ab_g.bias,
|
||||
norm_out_weight=self.layer_norm_out.weight,
|
||||
norm_out_bias=self.layer_norm_out.bias,
|
||||
p_out_weight=self.linear_z.weight,
|
||||
p_out_bias=self.linear_z.bias,
|
||||
g_out_weight=self.linear_g.weight,
|
||||
g_out_bias=self.linear_g.bias,
|
||||
eps=1e-5,
|
||||
)
|
||||
# When not inplace_safe (training), caller should have set _add_with_inplace to False
|
||||
if inplace_safe and _add_with_inplace:
|
||||
result += z
|
||||
return result
|
||||
|
||||
if (inplace_safe):
|
||||
x = self._inference_forward(
|
||||
z,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -252,6 +253,16 @@ def chunk_layer(
|
||||
initial_dims = [shape[:no_batch_dims] for shape in _fetch_dims(inputs)]
|
||||
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
if no_chunks == 1:
|
||||
return layer(**inputs)
|
||||
|
||||
def _prep_inputs(t):
|
||||
if(not low_mem):
|
||||
if not sum(t.shape[:no_batch_dims]) == no_batch_dims:
|
||||
@@ -267,14 +278,6 @@ def chunk_layer(
|
||||
reshape_fn = lambda t: t.view([-1] + list(t.shape[no_batch_dims:]))
|
||||
prepped_outputs = tensor_tree_map(reshape_fn, _out)
|
||||
|
||||
flat_batch_dim = 1
|
||||
for d in orig_batch_dims:
|
||||
flat_batch_dim *= d
|
||||
|
||||
no_chunks = flat_batch_dim // chunk_size + (
|
||||
flat_batch_dim % chunk_size != 0
|
||||
)
|
||||
|
||||
i = 0
|
||||
out = prepped_outputs
|
||||
for _ in range(no_chunks):
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2022 AlQuraishi Laboratory
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -15,6 +16,57 @@ import importlib
|
||||
|
||||
import torch
|
||||
|
||||
def cast_tensor(x, from_dtype, to_dtype):
|
||||
return x.to(dtype=to_dtype) if torch.is_tensor(x) and x.dtype == from_dtype else x
|
||||
|
||||
|
||||
def cast_all(x, from_dtype, to_dtype):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return cast_tensor(x, from_dtype=from_dtype, to_dtype=to_dtype)
|
||||
else:
|
||||
if isinstance(x, dict):
|
||||
new_dict = {}
|
||||
for k in x.keys():
|
||||
new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
|
||||
return new_dict
|
||||
elif isinstance(x, tuple):
|
||||
return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
|
||||
elif isinstance(x, list):
|
||||
return list(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
|
||||
else:
|
||||
return x
|
||||
|
||||
class PrecisionWrapper(torch.nn.Module):
|
||||
def __init__(self, model, precision):
|
||||
super().__init__()
|
||||
self.precision = precision
|
||||
if self.precision == "bf16":
|
||||
print(f"Converting {model.__class__} to BF16 ...")
|
||||
model = model.bfloat16()
|
||||
elif self.precision == "fp16":
|
||||
print(f"Converting {model.__class__} to FP16 ...")
|
||||
model = model.half()
|
||||
self.model = model
|
||||
|
||||
# TODO: generalize!!
|
||||
def forward(self, *args, **kwargs):
|
||||
if self.precision == "bf16":
|
||||
args = cast_all(args, from_dtype=torch.float32, to_dtype=torch.bfloat16)
|
||||
kwargs = cast_all(kwargs, from_dtype=torch.float32, to_dtype=torch.bfloat16)
|
||||
elif self.precision == "fp16":
|
||||
args = cast_all(args, from_dtype=torch.float32, to_dtype=torch.float16)
|
||||
kwargs = cast_all(kwargs, from_dtype=torch.float32, to_dtype=torch.float16)
|
||||
out = self.model(*args, **kwargs)
|
||||
if self.precision == "bf16":
|
||||
out = cast_all(out, from_dtype=torch.bfloat16, to_dtype=torch.float32)
|
||||
elif self.precision == "fp16":
|
||||
out = cast_all(out, from_dtype=torch.float16, to_dtype=torch.float32)
|
||||
|
||||
return out
|
||||
|
||||
def wrap_for_precision(model, precision):
|
||||
return PrecisionWrapper(model, precision)
|
||||
|
||||
def is_fp16_enabled():
|
||||
# Autocast world
|
||||
fp16_enabled = torch.get_autocast_gpu_dtype() == torch.float16
|
||||
|
||||
@@ -1,3 +1,18 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -19,11 +34,13 @@ from pytorch_lightning.utilities.deepspeed import (
|
||||
convert_zero_checkpoint_to_fp32_state_dict
|
||||
)
|
||||
|
||||
from .tensorrt_utils import instrument_with_trt_compile
|
||||
from .precision_utils import wrap_for_precision
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
|
||||
def count_models_to_evaluate(openfold_checkpoint_path, jax_param_path):
|
||||
model_count = 0
|
||||
if openfold_checkpoint_path:
|
||||
@@ -50,6 +67,14 @@ def make_output_directory(output_dir, model_name, multiple_model_mode):
|
||||
return prediction_dir
|
||||
|
||||
|
||||
def _accelerate(model, config):
|
||||
if config.trt.mode is not None:
|
||||
instrument_with_trt_compile(model, config)
|
||||
if config.precision is not None and config.precision in ['bf16', 'fp16']:
|
||||
model.evoformer = wrap_for_precision(model.evoformer, config.precision)
|
||||
model.extra_msa_stack = wrap_for_precision(model.extra_msa_stack, config.precision)
|
||||
|
||||
|
||||
def load_models_from_command_line(config, model_device, openfold_checkpoint_path, jax_param_path, output_dir):
|
||||
# Create the output directory
|
||||
|
||||
@@ -71,6 +96,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
|
||||
f"Successfully loaded JAX parameters at {path}..."
|
||||
)
|
||||
output_directory = make_output_directory(output_dir, model_basename, multiple_model_mode)
|
||||
_accelerate(model, config)
|
||||
yield model, output_directory
|
||||
|
||||
if openfold_checkpoint_path:
|
||||
@@ -106,6 +132,7 @@ def load_models_from_command_line(config, model_device, openfold_checkpoint_path
|
||||
f"Loaded OpenFold parameters at {path}..."
|
||||
)
|
||||
output_directory = make_output_directory(output_dir, checkpoint_basename, multiple_model_mode)
|
||||
_accelerate(model, config)
|
||||
yield model, output_directory
|
||||
|
||||
if not jax_param_path and not openfold_checkpoint_path:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -20,6 +21,8 @@ from typing import Tuple, List, Callable, Any, Dict, Sequence, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
def maybe_to(x, dtype):
|
||||
return x.to(dtype=dtype) if x is not None and x.dtype in [torch.float32, torch.float16, torch.bfloat16] else x
|
||||
|
||||
def add(m1, m2, inplace):
|
||||
# The first operation in a checkpoint can't be in-place, but it's
|
||||
@@ -33,9 +36,9 @@ def add(m1, m2, inplace):
|
||||
|
||||
|
||||
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
|
||||
zero_index = -1 * len(inds)
|
||||
first_inds = list(range(len(tensor.shape[:zero_index])))
|
||||
return tensor.permute(first_inds + [zero_index + i for i in inds])
|
||||
num_first_dims = len(tensor.shape)-len(inds)
|
||||
first_inds = list(range(num_first_dims))
|
||||
return tensor.permute(first_inds + [num_first_dims + i for i in inds])
|
||||
|
||||
|
||||
def flatten_final_dims(t: torch.Tensor, no_dims: int):
|
||||
|
||||
897
openfold/utils/tensorrt_lazy_compiler.py
Normal file
897
openfold/utils/tensorrt_lazy_compiler.py
Normal file
@@ -0,0 +1,897 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import Any, Dict, List, Sequence, Tuple, Union
|
||||
|
||||
import cuda.cudart as cudart
|
||||
import tensorrt as trt
|
||||
import torch
|
||||
from polygraphy.backend.common import bytes_from_path
|
||||
from polygraphy.backend.onnx.loader import fold_constants, onnx_from_path, save_onnx
|
||||
from polygraphy.backend.trt import (
|
||||
CreateConfig,
|
||||
Profile,
|
||||
engine_bytes_from_network,
|
||||
engine_from_bytes,
|
||||
network_from_onnx_path,
|
||||
)
|
||||
from polygraphy.logger import G_LOGGER
|
||||
|
||||
lock_sm = threading.Lock()
|
||||
G_LOGGER.module_severity = G_LOGGER.VERBOSE
|
||||
G_LOGGER.use_python_logging_system = True
|
||||
|
||||
|
||||
def trt_to_torch_dtype_dict():
|
||||
"""
|
||||
Map of TRT dtype -> Torch dtype
|
||||
"""
|
||||
return {
|
||||
trt.int32: torch.int32,
|
||||
trt.float32: torch.float32,
|
||||
trt.float16: torch.float16,
|
||||
trt.bfloat16: torch.bfloat16,
|
||||
trt.int64: torch.int64,
|
||||
trt.int8: torch.int8,
|
||||
trt.bool: torch.bool,
|
||||
}
|
||||
|
||||
|
||||
def get_profile_shapes(
|
||||
input_shape: Sequence[int], dynamic_batchsize: Sequence[int] | None
|
||||
):
|
||||
"""
|
||||
Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
|
||||
"""
|
||||
|
||||
def scale_batch_size(input_shape: Sequence[int], scale_num: int):
|
||||
scale_shape = [*input_shape]
|
||||
scale_shape[0] = scale_num
|
||||
return scale_shape
|
||||
|
||||
# Use the dynamic batchsize range to generate the min, opt and max model input shape
|
||||
if dynamic_batchsize:
|
||||
min_input_shape = scale_batch_size(input_shape, dynamic_batchsize[0])
|
||||
opt_input_shape = scale_batch_size(input_shape, dynamic_batchsize[1])
|
||||
max_input_shape = scale_batch_size(input_shape, dynamic_batchsize[2])
|
||||
else:
|
||||
min_input_shape = opt_input_shape = max_input_shape = input_shape
|
||||
return min_input_shape, opt_input_shape, max_input_shape
|
||||
|
||||
|
||||
def get_dynamic_axes(profiles):
|
||||
"""
|
||||
This method calculates dynamic_axes to use in onnx.export().
|
||||
Args:
|
||||
profiles: [[min,opt,max],...] list of profile dimensions
|
||||
"""
|
||||
dynamic_axes: dict[str, list[int]] = {}
|
||||
if not profiles:
|
||||
return dynamic_axes
|
||||
for profile in profiles:
|
||||
for key in profile:
|
||||
axes = []
|
||||
vals = profile[key]
|
||||
for i in range(len(vals[0])):
|
||||
if vals[0][i] != vals[2][i]:
|
||||
axes.append(i)
|
||||
if len(axes) > 0:
|
||||
dynamic_axes[key] = axes
|
||||
return dynamic_axes
|
||||
|
||||
|
||||
def cuassert(cuda_ret):
|
||||
"""
|
||||
Error reporting method for CUDA calls.
|
||||
Args:
|
||||
cuda_ret: CUDA return code.
|
||||
"""
|
||||
err = cuda_ret[0]
|
||||
if err != 0:
|
||||
raise RuntimeError(f"CUDA ERROR: {err}")
|
||||
if len(cuda_ret) > 1:
|
||||
return cuda_ret[1]
|
||||
return None
|
||||
|
||||
|
||||
class ShapeError(Exception):
|
||||
"""
|
||||
Exception class to report errors from setting TRT plan input shapes
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TRTEngine:
|
||||
"""
|
||||
An auxiliary class to implement running of TRT optimized engines
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, plan_path, logger=None):
|
||||
"""
|
||||
Loads serialized engine, creates execution context and activates it
|
||||
Args:
|
||||
plan_path: path to serialized TRT engine.
|
||||
logger: optional logger object
|
||||
"""
|
||||
self.input_names = []
|
||||
self.output_names = []
|
||||
self.dtypes = []
|
||||
self.cur_profile = 0
|
||||
self.input_table = {}
|
||||
dtype_dict = trt_to_torch_dtype_dict()
|
||||
|
||||
self.plan_path = plan_path
|
||||
self.logger = logger or getLogger("trt_compile")
|
||||
self.logger.info(f"Loading TensorRT engine: {self.plan_path}")
|
||||
self.engine = engine_from_bytes(bytes_from_path(self.plan_path))
|
||||
self.tensors = OrderedDict()
|
||||
self.cuda_graph_instance = None # cuda graph
|
||||
for idx in range(self.engine.num_io_tensors):
|
||||
binding = self.engine[idx]
|
||||
if self.engine.get_tensor_mode(binding) == trt.TensorIOMode.INPUT:
|
||||
self.input_names.append(binding)
|
||||
elif self.engine.get_tensor_mode(binding) == trt.TensorIOMode.OUTPUT:
|
||||
self.output_names.append(binding)
|
||||
dtype = dtype_dict[self.engine.get_tensor_dtype(binding)]
|
||||
self.dtypes.append(dtype)
|
||||
self.context = self.engine.create_execution_context()
|
||||
required_size = self.engine.device_memory_size
|
||||
if self.context:
|
||||
self.logger.info(
|
||||
f"Loaded TensorRT engine: {self.plan_path}.\nInputs: {self.input_names}\nOutputs: {self.output_names}\nContext memory size: {required_size}"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Failed to create execution context for TensorRT engine: {self.plan_path}"
|
||||
)
|
||||
self.disabled = True
|
||||
|
||||
def allocate_buffers(self, device):
|
||||
"""
|
||||
Allocates outputs to run TRT engine
|
||||
Args:
|
||||
device: GPU device to allocate memory on
|
||||
"""
|
||||
ctx = self.context
|
||||
|
||||
for i, binding in enumerate(self.output_names):
|
||||
shape = list(ctx.get_tensor_shape(binding))
|
||||
if (
|
||||
binding not in self.tensors
|
||||
or list(self.tensors[binding].shape) != shape
|
||||
):
|
||||
t = torch.empty(shape, dtype=self.dtypes[i], device=device).contiguous()
|
||||
self.tensors[binding] = t
|
||||
ctx.set_tensor_address(binding, t.data_ptr())
|
||||
|
||||
def _check_shape_in_range(self, dims: list[trt.Dims], shape: torch.Size) -> bool:
|
||||
"""
|
||||
Checks if shape is within the range of the optimization profile.
|
||||
"""
|
||||
min_opt = dims[0]
|
||||
max_opt = dims[-1]
|
||||
in_range = True
|
||||
|
||||
in_range = in_range and all(shape[i] >= d for i, d in enumerate(min_opt))
|
||||
in_range = in_range and all(shape[i] <= d for i, d in enumerate(max_opt))
|
||||
return in_range
|
||||
|
||||
|
||||
def set_inputs(self, feed_dict, stream):
|
||||
"""
|
||||
Sets input bindings for TRT engine according to feed_dict
|
||||
|
||||
Args:
|
||||
feed_dict: a dictionary [str->Tensor]
|
||||
stream: CUDA stream to use
|
||||
"""
|
||||
|
||||
def set_profile():
|
||||
next_profile = self.cur_profile
|
||||
found = False
|
||||
for _ in range(e.num_optimization_profiles):
|
||||
tmp_profile = next_profile
|
||||
for binding in self.input_names:
|
||||
dims = e.get_tensor_profile_shape(binding, next_profile)
|
||||
t = feed_dict.get(self.input_table[binding], None)
|
||||
if t is None:
|
||||
raise ValueError(f"Not found tensor {binding} in feed_dict")
|
||||
in_range = self._check_shape_in_range(dims, t.shape)
|
||||
if not in_range:
|
||||
next_profile = (next_profile + 1) % e.num_optimization_profiles
|
||||
break
|
||||
if tmp_profile == next_profile:
|
||||
found = True
|
||||
break
|
||||
if found:
|
||||
self.logger.debug(f"Using optimization profile {next_profile}")
|
||||
if next_profile != self.cur_profile:
|
||||
ctx.set_optimization_profile_async(next_profile, stream)
|
||||
self.cur_profile = next_profile
|
||||
else:
|
||||
raise ShapeError("Shape out of range")
|
||||
|
||||
def try_set_inputs():
|
||||
for binding in self.input_names:
|
||||
t = feed_dict.get(self.input_table[binding], None)
|
||||
if t is not None:
|
||||
t = t.contiguous()
|
||||
shape = t.shape
|
||||
ctx.set_input_shape(binding, shape)
|
||||
ctx.set_tensor_address(binding, t.data_ptr())
|
||||
|
||||
e = self.engine
|
||||
ctx = self.context
|
||||
if e.num_optimization_profiles > 1:
|
||||
set_profile()
|
||||
|
||||
try_set_inputs()
|
||||
left = ctx.infer_shapes()
|
||||
# required_size = ctx.update_device_memory_size_for_shapes()
|
||||
# self.logger.info(f"Need context memory: {required_size}")
|
||||
assert len(left) == 0
|
||||
|
||||
def infer(self, stream, use_cuda_graph=False):
|
||||
"""
|
||||
Runs TRT engine.
|
||||
Args:
|
||||
stream: CUDA stream to run on
|
||||
use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls.
|
||||
"""
|
||||
if use_cuda_graph:
|
||||
if self.cuda_graph_instance is not None:
|
||||
cuassert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
|
||||
cuassert(cudart.cudaStreamSynchronize(stream))
|
||||
else:
|
||||
# do inference before CUDA graph capture
|
||||
noerror = self.context.execute_async_v3(stream)
|
||||
if not noerror:
|
||||
raise ValueError("ERROR: inference failed.")
|
||||
# capture cuda graph
|
||||
cuassert(
|
||||
cudart.cudaStreamBeginCapture(
|
||||
stream,
|
||||
cudart.cudaStreamCaptureMode.cudaStreamCaptureModeThreadLocal,
|
||||
)
|
||||
)
|
||||
self.context.execute_async_v3(stream)
|
||||
graph = cuassert(cudart.cudaStreamEndCapture(stream))
|
||||
self.cuda_graph_instance = cuassert(
|
||||
cudart.cudaGraphInstantiate(graph, 0)
|
||||
)
|
||||
self.logger.info("CUDA Graph captured!")
|
||||
else:
|
||||
noerror = self.context.execute_async_v3(stream)
|
||||
cuassert(cudart.cudaStreamSynchronize(stream))
|
||||
if not noerror:
|
||||
raise ValueError(f"ERROR: inference failed: {noerror}.")
|
||||
return self.tensors
|
||||
|
||||
|
||||
def make_tensor(d):
|
||||
"""
|
||||
Creates a new tensor from d, returns d if d is already a tensor
|
||||
"""
|
||||
return d if isinstance(d, torch.Tensor) else torch.tensor(d).cuda()
|
||||
|
||||
|
||||
def unroll_input(input_names, input_example):
|
||||
"""
|
||||
Simulates list/tuple unrolling during ONNX export
|
||||
"""
|
||||
|
||||
def unroll_one(name, val):
|
||||
res = {}
|
||||
try:
|
||||
if val is not None:
|
||||
if isinstance(val, dict):
|
||||
for key, data in val.items():
|
||||
subname = f"{name}_{key}"
|
||||
vals = unroll_one(subname, data)
|
||||
res.update(vals)
|
||||
elif isinstance(val, list) or isinstance(val, tuple):
|
||||
for i in range(len(val)):
|
||||
res.update(unroll_one(f"{name}_{i}", val[i]))
|
||||
else:
|
||||
res[name] = make_tensor(val)
|
||||
except Exception:
|
||||
pass
|
||||
return res
|
||||
|
||||
unrolled_input = {}
|
||||
for name in input_names:
|
||||
val = input_example.get(name, None)
|
||||
unrolled_input.update(unroll_one(name, val))
|
||||
return unrolled_input
|
||||
|
||||
|
||||
def parse_groups(
|
||||
ret: List[torch.Tensor], output_lists: List[List[int]]
|
||||
) -> Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]:
|
||||
"""
|
||||
Implements parsing of 'output_lists' arg of trt_compile().
|
||||
|
||||
Args:
|
||||
ret: plain list of Tensors
|
||||
|
||||
output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list
|
||||
of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
|
||||
Format: [[group_n] | [], ...]
|
||||
[] or group_n == 0 : next output from ret is a scalar
|
||||
group_n > 0 : next output from ret is a list of group_n length
|
||||
group_n == -1: next output is a dynamic list. This entry can be at any
|
||||
position in output_lists, but can appear only once.
|
||||
Returns:
|
||||
Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists
|
||||
|
||||
"""
|
||||
groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
|
||||
cur = 0
|
||||
for i in range(len(output_lists)):
|
||||
gl = output_lists[i]
|
||||
assert len(gl) == 0 or len(gl) == 1
|
||||
if len(gl) == 0 or gl[0] == 0:
|
||||
groups = (*groups, ret[cur])
|
||||
cur = cur + 1
|
||||
elif gl[0] > 0:
|
||||
groups = (*groups, ret[cur : cur + gl[0]])
|
||||
cur = cur + gl[0]
|
||||
elif gl[0] == -1:
|
||||
rev_groups: Tuple[Union[torch.Tensor, List[torch.Tensor]], ...] = tuple()
|
||||
rcur = len(ret)
|
||||
for rl in range(len(output_lists) - 1, i, -1):
|
||||
rgl = output_lists[rl]
|
||||
assert len(rgl) == 0 or len(rgl) == 1
|
||||
if len(rgl) == 0 or rgl[0] == 0:
|
||||
rcur = rcur - 1
|
||||
rev_groups = (*rev_groups, ret[rcur])
|
||||
elif rgl[0] > 0:
|
||||
rcur = rcur - rgl[0]
|
||||
rev_groups = (*rev_groups, ret[rcur : rcur + rgl[0]])
|
||||
else:
|
||||
raise ValueError("Two -1 lists in output")
|
||||
groups = (*groups, ret[cur:rcur], *rev_groups[::-1])
|
||||
break
|
||||
return groups
|
||||
|
||||
|
||||
class TrtCompiler:
|
||||
"""
|
||||
This class implements:
|
||||
- TRT lazy persistent export
|
||||
- Running TRT with optional fallback to Torch
|
||||
(for TRT engines with limited profiles)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
plan_path,
|
||||
precision="fp16",
|
||||
method="onnx",
|
||||
input_names=None,
|
||||
output_names=None,
|
||||
output_lists=None,
|
||||
export_args=None,
|
||||
build_args=None,
|
||||
input_profiles=None,
|
||||
dynamic_batchsize=None,
|
||||
use_cuda_graph=False,
|
||||
timestamp=None,
|
||||
fallback=False,
|
||||
function="forward",
|
||||
skip_once_registry=None,
|
||||
logger=None,
|
||||
verify=False,
|
||||
):
|
||||
"""
|
||||
Initialization method:
|
||||
Tries to load persistent serialized TRT engine
|
||||
Saves its arguments for lazy TRT build on first forward() call
|
||||
Args:
|
||||
model: Model to "wrap".
|
||||
plan_path : Path where to save persistent serialized TRT engine.
|
||||
precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.
|
||||
method: One of 'onnx'|'torch_trt'.
|
||||
Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.
|
||||
'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
|
||||
input_names: Optional list of input names. If None, will be read from the function signature.
|
||||
output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
|
||||
output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
|
||||
of their dimensions, like [[], [5], [-1]] for Tensor, list of 5 items and dynamic list.
|
||||
export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
|
||||
build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
|
||||
input_profiles: Optional list of profiles for TRT builder and ONNX export.
|
||||
Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}.
|
||||
dynamic_batchsize: A sequence with three elements to define the input batch size range for the model to be
|
||||
converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
|
||||
[note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used.
|
||||
use_cuda_graph: Use CUDA Graph for inference. Note: inputs have to be the same GPU memory between calls!
|
||||
timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes).
|
||||
fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile).
|
||||
"""
|
||||
|
||||
method_vals = ["onnx", "torch_trt"]
|
||||
if method not in method_vals:
|
||||
raise ValueError(
|
||||
f"trt_compile(): 'method' should be one of {method_vals}, got: {method}."
|
||||
)
|
||||
precision_vals = ["fp32", "tf32", "fp16", "bf16"]
|
||||
if precision not in precision_vals:
|
||||
raise ValueError(
|
||||
f"trt_compile(): 'precision' should be one of {precision_vals}, got: {precision}."
|
||||
)
|
||||
|
||||
if skip_once_registry:
|
||||
if not fallback:
|
||||
raise ValueError(
|
||||
"trt_compile(): skip_once functionality requires fallback"
|
||||
)
|
||||
skip_once_registry.register_skip_once(self)
|
||||
|
||||
self.plan_path = plan_path
|
||||
self.precision = precision
|
||||
self.method = method
|
||||
self.return_dict = output_names is not None
|
||||
self.output_names = output_names or []
|
||||
self.output_lists = output_lists or []
|
||||
self.profiles = input_profiles or []
|
||||
self.dynamic_batchsize = dynamic_batchsize
|
||||
self.export_args = export_args or {}
|
||||
self.build_args = build_args or {}
|
||||
self.engine: TRTEngine | None = None
|
||||
self.use_cuda_graph = use_cuda_graph
|
||||
self.fallback = fallback
|
||||
self.verify = verify
|
||||
self.skip_once = False
|
||||
self.disabled = False
|
||||
|
||||
self.logger = logger or getLogger("trt_compile")
|
||||
self.argspec = inspect.getfullargspec(model.forward)
|
||||
# Normally we read input_names from forward() but can be overridden
|
||||
if input_names is None:
|
||||
input_names = self.argspec.args[1:]
|
||||
self.defaults = {}
|
||||
if self.argspec.defaults is not None:
|
||||
for i in range(len(self.argspec.defaults)):
|
||||
d = self.argspec.defaults[-i - 1]
|
||||
if d is not None:
|
||||
# d = make_tensor(d)
|
||||
self.defaults[self.argspec.args[-i - 1]] = d
|
||||
|
||||
self.input_names = input_names
|
||||
self.orig_function = getattr(model, function)
|
||||
setattr(model, function, MethodType(trt_forward, model))
|
||||
|
||||
# Force engine rebuild if older than the timestamp
|
||||
if (
|
||||
timestamp is not None
|
||||
and os.path.exists(self.plan_path)
|
||||
and os.path.getmtime(self.plan_path) < timestamp
|
||||
):
|
||||
os.remove(self.plan_path)
|
||||
|
||||
def _inputs_to_dict(self, input_example):
|
||||
trt_inputs = {}
|
||||
for i, inp in enumerate(input_example):
|
||||
input_name = self.input_names[i]
|
||||
trt_inputs[input_name] = inp
|
||||
return trt_inputs
|
||||
|
||||
def _load_engine(self):
|
||||
"""
|
||||
Loads TRT plan from disk and activates its execution context.
|
||||
"""
|
||||
try:
|
||||
self.engine = TRTEngine(self.plan_path, self.logger)
|
||||
# Make sure we have names correct
|
||||
input_table = {}
|
||||
for name in self.engine.input_names:
|
||||
if name.startswith("__") and name not in self.input_names:
|
||||
orig_name = name[2:]
|
||||
else:
|
||||
orig_name = name
|
||||
input_table[name] = orig_name
|
||||
self.engine.input_table = input_table
|
||||
except Exception as e:
|
||||
self.logger.info(f"Exception while loading the engine:\n{e}")
|
||||
|
||||
def forward(self, model, argv, kwargs):
|
||||
"""
|
||||
Main forward method:
|
||||
Builds TRT engine if not available yet.
|
||||
Tries to run TRT engine
|
||||
If exception thrown and self.callback==True: falls back to original Pytorch
|
||||
|
||||
Args: Passing through whatever args wrapped module's forward() has
|
||||
Returns: Passing through wrapped module's forward() return value(s)
|
||||
|
||||
"""
|
||||
# Let the caches be filled
|
||||
if self.skip_once:
|
||||
self.skip_once = False
|
||||
self.logger.info("Skipping once...")
|
||||
return self.orig_function(*argv, **kwargs)
|
||||
|
||||
args = self.defaults
|
||||
args.update(kwargs)
|
||||
if len(argv) > 0:
|
||||
args.update(self._inputs_to_dict(argv))
|
||||
|
||||
if self.engine is None and not self.disabled:
|
||||
# Restore original forward for export
|
||||
new_forward = model.forward
|
||||
model.forward = self.orig_function
|
||||
try:
|
||||
self._load_engine()
|
||||
if self.engine is None:
|
||||
build_args = args.copy()
|
||||
with torch.no_grad():
|
||||
self._build_and_save(model, build_args)
|
||||
# This will reassign input_names from the engine
|
||||
self._load_engine()
|
||||
assert self.engine is not None
|
||||
except Exception as e:
|
||||
if self.fallback:
|
||||
self.logger.info(f"Failed to build engine: {e}")
|
||||
self.disabled = True
|
||||
else:
|
||||
raise e
|
||||
if not self.disabled:
|
||||
self.move_model_to_cpu(model)
|
||||
# restore TRT hook
|
||||
model.forward = new_forward
|
||||
# Run the engine
|
||||
try:
|
||||
verifying = False
|
||||
if self.engine is not None:
|
||||
# forward_trt is not thread safe as we do not use per-thread execution contexts
|
||||
with lock_sm:
|
||||
device = torch.cuda.current_device()
|
||||
stream = torch.cuda.Stream(device=device)
|
||||
self.engine.set_inputs(
|
||||
unroll_input(self.input_names, args), stream.cuda_stream
|
||||
)
|
||||
self.engine.allocate_buffers(device=device)
|
||||
# Need this to synchronize with Torch stream
|
||||
stream.wait_stream(torch.cuda.current_stream())
|
||||
ret = self.engine.infer(
|
||||
stream.cuda_stream, use_cuda_graph=self.use_cuda_graph
|
||||
)
|
||||
# if output_names is not None, return dictionary
|
||||
if not self.return_dict:
|
||||
ret = list(ret.values())
|
||||
if self.output_lists:
|
||||
ret = parse_groups(ret, self.output_lists)
|
||||
elif len(ret) == 1:
|
||||
ret = ret[0]
|
||||
if self.verify:
|
||||
verifying = True
|
||||
orig_ret = self.orig_function(*argv, **kwargs)
|
||||
# breakpoint()
|
||||
torch.testing.assert_close(ret, orig_ret)
|
||||
self.logger.info("Results verified")
|
||||
return ret
|
||||
except Exception as e:
|
||||
if self.fallback and not verifying:
|
||||
self.logger.debug(f"Exception: {e}\nFalling back to Pytorch ...")
|
||||
else:
|
||||
raise e
|
||||
# fallback path
|
||||
if not self.disabled:
|
||||
model.cuda()
|
||||
ret = self.orig_function(*argv, **kwargs)
|
||||
if not self.disabled:
|
||||
model.cpu()
|
||||
torch.cuda.empty_cache()
|
||||
return ret
|
||||
|
||||
def _onnx_to_trt(self, onnx_path, enable_all_tactics=True):
|
||||
"""
|
||||
Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path
|
||||
"""
|
||||
torch.cuda.empty_cache()
|
||||
profiles = []
|
||||
for profile in self.profiles:
|
||||
p = Profile()
|
||||
for id, val in profile.items():
|
||||
p.add(id, min=val[0], opt=val[1], max=val[2])
|
||||
profiles.append(p)
|
||||
|
||||
build_args = self.build_args.copy()
|
||||
build_args["tf32"] = self.precision != "fp32"
|
||||
if self.precision == "fp16":
|
||||
build_args["fp16"] = True
|
||||
elif self.precision == "bf16":
|
||||
build_args["bf16"] = True
|
||||
|
||||
if not enable_all_tactics:
|
||||
build_args["tactic_sources"] = []
|
||||
else:
|
||||
build_args["tactic_sources"] = [
|
||||
trt.TacticSource.CUBLAS,
|
||||
trt.TacticSource.CUBLAS_LT,
|
||||
trt.TacticSource.EDGE_MASK_CONVOLUTIONS,
|
||||
trt.TacticSource.JIT_CONVOLUTIONS,
|
||||
]
|
||||
|
||||
self.logger.info(
|
||||
f"Building TensorRT engine for {onnx_path}: {self.plan_path}. Build args:\n{build_args}\nProfiles: {profiles}"
|
||||
)
|
||||
network = network_from_onnx_path(
|
||||
onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]
|
||||
)
|
||||
return engine_bytes_from_network(
|
||||
network, config=CreateConfig(profiles=profiles, **build_args)
|
||||
)
|
||||
|
||||
def move_model_to_cpu(self, model):
|
||||
free_mem0, total_mem = torch.cuda.mem_get_info()
|
||||
model.cpu()
|
||||
# Call empty_cache to release GPU memory
|
||||
torch.cuda.empty_cache()
|
||||
free_mem, total_mem = torch.cuda.mem_get_info()
|
||||
self.logger.info(
|
||||
f"Deallocated model memory: {(free_mem - free_mem0) / 1024**2:.2f} MB"
|
||||
)
|
||||
|
||||
def _build_and_save(self, model, input_example):
|
||||
"""
|
||||
If TRT engine is not ready, exports model to ONNX,
|
||||
builds TRT engine and saves serialized TRT engine to the disk.
|
||||
Args:
|
||||
input_example: passed to onnx.export()
|
||||
"""
|
||||
|
||||
if self.engine is not None:
|
||||
return
|
||||
|
||||
export_args = self.export_args
|
||||
engine_bytes = None
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if True:
|
||||
dbs = self.dynamic_batchsize
|
||||
if dbs:
|
||||
if len(self.profiles) > 0:
|
||||
raise ValueError(
|
||||
"ERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!"
|
||||
)
|
||||
if len(dbs) != 3:
|
||||
raise ValueError("dynamic_batchsize has to have len ==3 ")
|
||||
profile = {}
|
||||
for id, val in input_example.items():
|
||||
|
||||
def add_profile(id, val):
|
||||
sh = val.shape
|
||||
if len(sh) > 0:
|
||||
sh = sh[1:]
|
||||
profile[id] = [[dbs[0], *sh], [dbs[1], *sh], [dbs[2], *sh]]
|
||||
|
||||
if isinstance(val, list) or isinstance(val, tuple):
|
||||
for i in range(len(val)):
|
||||
add_profile(f"{id}_{i}", val[i])
|
||||
elif isinstance(val, torch.Tensor):
|
||||
add_profile(id, val)
|
||||
self.profiles = [profile]
|
||||
|
||||
if (
|
||||
"dynamic_axes" not in export_args
|
||||
and "dynamic_shapes" not in export_args
|
||||
):
|
||||
dynamic_axes = get_dynamic_axes(self.profiles)
|
||||
if dynamic_axes:
|
||||
export_args.update({"dynamic_axes": dynamic_axes})
|
||||
|
||||
if self.method == "torch_trt":
|
||||
raise ValueError("Torch-TensorRT option not implemented")
|
||||
else:
|
||||
# Use temporary directory for easy cleanup in case of external weights
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
post_proc = export_args.pop("postprocess", None)
|
||||
if export_args.get("dynamo", False):
|
||||
input_names = None
|
||||
else:
|
||||
input_names = list(
|
||||
unroll_input(self.input_names, input_example).keys()
|
||||
)
|
||||
|
||||
inputs = list(input_example.values())
|
||||
input_shapes = [inp.shape for inp in inputs if torch.is_tensor(inp)]
|
||||
onnx_path = str(Path(tmpdir) / "model.onnx")
|
||||
# onnx_path = "model.onnx"
|
||||
self.logger.info(
|
||||
f"Exporting to {onnx_path}:\n"
|
||||
+ f"output_names={self.output_names}\ninput_names={self.input_names}\nexport args: {export_args}\ninput shapes: {input_shapes}"
|
||||
)
|
||||
|
||||
if False: # self.verify:
|
||||
from torch.onnx.verification import VerificationOptions
|
||||
|
||||
ver_opts = VerificationOptions(rtol=1e-2, atol=1e-2)
|
||||
torch.onnx.verification.find_mismatch(
|
||||
model,
|
||||
tuple(input_example.values()),
|
||||
verbose=False,
|
||||
options=ver_opts,
|
||||
opset_version=export_args["opset_version"],
|
||||
)
|
||||
torch.onnx.export(
|
||||
model,
|
||||
(input_example,),
|
||||
onnx_path,
|
||||
input_names=input_names,
|
||||
output_names=self.output_names,
|
||||
**export_args,
|
||||
)
|
||||
|
||||
onnx_model = fold_constants(
|
||||
onnx_from_path(onnx_path),
|
||||
allow_onnxruntime_shape_inference=False,
|
||||
size_threshold=64 * 1024 * 1024,
|
||||
)
|
||||
if post_proc:
|
||||
onnx_model = post_proc(onnx_model)
|
||||
save_onnx(onnx_model, onnx_path)
|
||||
self.logger.info("Export to ONNX successful.")
|
||||
self.move_model_to_cpu(model)
|
||||
engine_bytes = self._onnx_to_trt(onnx_path)
|
||||
if engine_bytes:
|
||||
open(self.plan_path, "wb").write(engine_bytes)
|
||||
|
||||
|
||||
def trt_forward(self, *argv, **kwargs):
|
||||
"""
|
||||
Patch function to replace original model's forward() with.
|
||||
Redirects to TrtCompiler.forward()
|
||||
"""
|
||||
return self._trt_compiler.forward(self, argv, kwargs)
|
||||
|
||||
|
||||
def trt_registry_forward(self, *argv, **kwargs):
|
||||
"""
|
||||
Patch function to replace original model's forward() with.
|
||||
Redirects to TrtCompilerRegistry.forward()
|
||||
"""
|
||||
return self._trt_compiler_registry.forward(self, argv, kwargs)
|
||||
|
||||
|
||||
def trt_compile(
|
||||
model: torch.nn.Module,
|
||||
base_path: str,
|
||||
args: Dict[str, Any] | None = None,
|
||||
submodule: Union[str, List[str]] | None = None,
|
||||
logger: Any | None = None,
|
||||
) -> torch.nn.Module:
|
||||
"""
|
||||
Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
|
||||
Note: TRT 10.13+ is recommended for best performance.
|
||||
Args:
|
||||
model: module to patch with TrtCompiler object.
|
||||
base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
|
||||
dirname(base_path) must exist, base_path does not have to.
|
||||
If base_path does point to existing file (e.g. associated checkpoint),
|
||||
that file becomes a dependency - its mtime is added to args["timestamp"].
|
||||
args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details.
|
||||
submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']
|
||||
If None, TrtCompiler patch is applied to the whole model.
|
||||
Otherwise, submodule (or list of) is being patched.
|
||||
logger: Optional logger for diagnostics.
|
||||
Returns:
|
||||
Always returns same model passed in as argument. This is for ease of use in configs.
|
||||
"""
|
||||
|
||||
default_args: Dict[str, Any] = {
|
||||
"method": "onnx",
|
||||
"precision": "bf16",
|
||||
"build_args": {
|
||||
"builder_optimization_level": 5,
|
||||
"precision_constraints": "prefer",
|
||||
},
|
||||
}
|
||||
|
||||
default_args.update(args or {})
|
||||
args = default_args
|
||||
|
||||
if torch.cuda.is_available():
|
||||
# if "path" filename point to existing file (e.g. checkpoint)
|
||||
# it's also treated as dependency
|
||||
if os.path.exists(base_path):
|
||||
timestamp = int(os.path.getmtime(base_path))
|
||||
if "timestamp" in args:
|
||||
timestamp = max(int(args["timestamp"]), timestamp)
|
||||
args["timestamp"] = timestamp
|
||||
|
||||
def wrap(model, path):
|
||||
if not hasattr(model, "_trt_compiler"):
|
||||
model.orig_forward = model.forward
|
||||
wrapper = TrtCompiler(model, path + ".plan", logger=logger, **args)
|
||||
model._trt_compiler = wrapper
|
||||
model.forward = MethodType(trt_forward, model)
|
||||
|
||||
def find_sub(parent, submodule):
|
||||
idx = submodule.find(".")
|
||||
# if there is "." in name, call recursively
|
||||
if idx != -1:
|
||||
parent_name = submodule[:idx]
|
||||
parent = getattr(parent, parent_name)
|
||||
submodule = submodule[idx + 1 :]
|
||||
return find_sub(parent, submodule)
|
||||
return parent, submodule
|
||||
|
||||
if submodule is not None:
|
||||
if isinstance(submodule, str):
|
||||
submodule = [submodule]
|
||||
for s in submodule:
|
||||
parent, sub = find_sub(model, s)
|
||||
wrap(getattr(parent, sub), base_path + "." + s)
|
||||
else:
|
||||
wrap(model, base_path)
|
||||
else:
|
||||
logger = logger or getLogger("trt_compile")
|
||||
logger.warning(
|
||||
"TensorRT and/or polygraphy packages are not available! trt_compile() has no effect."
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
class TrtCompilerRegistry:
|
||||
"""
|
||||
Add-on class to be applied to higher-level module in caching situations
|
||||
Supports skip_once functionality by resetting registered sub-modules skip flags
|
||||
so they can skip the first forward() call and let the caches be filled
|
||||
"""
|
||||
|
||||
def __init__(self, model, function="forward", logger=None):
|
||||
self.logger = logger or getLogger("trt_compile")
|
||||
self.orig_function = getattr(model, function)
|
||||
setattr(model, function, MethodType(trt_registry_forward, model))
|
||||
self.registry = []
|
||||
|
||||
def register_skip_once(self, c):
|
||||
self.registry.append(c)
|
||||
|
||||
def reset_skip_once(self):
|
||||
for c in self.registry:
|
||||
c.skip_once = True
|
||||
|
||||
def forward(self, model, argv, kwargs):
|
||||
self.reset_skip_once()
|
||||
return self.orig_function(*argv, **kwargs)
|
||||
|
||||
|
||||
def trt_compile_make_registry(model, function="forward"):
|
||||
"""
|
||||
Instruments model or submodule(s) with TrtCompilerRegistry and replaces its forward() with TRT registry hook.
|
||||
"""
|
||||
if not hasattr(model, "_trt_compiler_registry"):
|
||||
wrapper = TrtCompilerRegistry(model, function)
|
||||
model._trt_compiler_registry = wrapper
|
||||
|
||||
return wrapper
|
||||
161
openfold/utils/tensorrt_utils.py
Normal file
161
openfold/utils/tensorrt_utils.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from .tensorrt_lazy_compiler import trt_compile
|
||||
|
||||
logger = logging.getLogger("trt_compile")
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
|
||||
def instrument_with_trt_compile(model, config):
|
||||
if config.trt.mode is None:
|
||||
return
|
||||
|
||||
if (
|
||||
config.globals.use_cuequivariance_attention
|
||||
or config.globals.use_cuequivariance_multiplicative_update
|
||||
):
|
||||
from cuequivariance_ops_torch.onnx import op_table
|
||||
from cuequivariance_ops_torch.tensorrt import register_plugins
|
||||
|
||||
register_plugins()
|
||||
else:
|
||||
op_table = None
|
||||
|
||||
engine_dir = config.trt.engine_dir
|
||||
os.makedirs(engine_dir, exist_ok=True)
|
||||
# Clean the directory if rebuilding
|
||||
if config.trt.mode == "build":
|
||||
for filename in os.listdir(engine_dir):
|
||||
file_path = os.path.join(engine_dir, filename)
|
||||
if os.path.isfile(file_path):
|
||||
os.remove(file_path)
|
||||
|
||||
# skip_once_registry = trt_compile_make_registry(model.structure_module, "sample")
|
||||
|
||||
S_MIN = 16
|
||||
S_MAX = config.trt.max_sequence_len
|
||||
|
||||
# TODO: use config for those numbers
|
||||
seq_dim = torch.export.Dim("seq_len", max=S_MAX)
|
||||
|
||||
evoformer_dynamic_shapes = {
|
||||
"m": {1: seq_dim},
|
||||
"z": {0: seq_dim, 1: seq_dim},
|
||||
"msa_mask": {1: seq_dim},
|
||||
"pair_mask": {0: seq_dim, 1: seq_dim},
|
||||
"chunk_size": None,
|
||||
"use_deepspeed_evo_attention": None,
|
||||
"use_cuequivariance_attention": None,
|
||||
"use_cuequivariance_multiplicative_update": None,
|
||||
"use_lma": None,
|
||||
"use_flash": None,
|
||||
"inplace_safe": None,
|
||||
"_mask_trans": None,
|
||||
}
|
||||
|
||||
def evoformer_profile_one(min_len, max_len):
|
||||
return {
|
||||
"m": [[516, min_len, 256], [516, max_len, 256], [516, max_len, 256]],
|
||||
"z": [
|
||||
[min_len, min_len, 128],
|
||||
[max_len, max_len, 128],
|
||||
[max_len, max_len, 128],
|
||||
],
|
||||
"msa_mask": [[516, min_len], [516, max_len], [516, max_len]],
|
||||
"pair_mask": [[min_len, min_len], [max_len, max_len], [max_len, max_len]],
|
||||
}
|
||||
|
||||
def msa_profile_one(min_len, max_len):
|
||||
return {
|
||||
"m": [[5120, min_len, 64], [5120, max_len, 64], [5120, max_len, 64]],
|
||||
"z": [
|
||||
[min_len, min_len, 128],
|
||||
[max_len, max_len, 128],
|
||||
[max_len, max_len, 128],
|
||||
],
|
||||
"msa_mask": [[5120, min_len], [5120, max_len], [5120, max_len]],
|
||||
"pair_mask": [[min_len, min_len], [max_len, max_len], [max_len, max_len]],
|
||||
}
|
||||
|
||||
def input_profiles(input_profile_one, num_profiles=1):
|
||||
if num_profiles == 4:
|
||||
return [
|
||||
input_profile_one(S_MIN, S_MAX // 4),
|
||||
input_profile_one(S_MAX // 4 + 1, S_MAX // 2),
|
||||
input_profile_one(S_MAX // 2 + 1, (S_MAX // 4) * 3),
|
||||
input_profile_one((S_MAX // 4) * 3, S_MAX),
|
||||
]
|
||||
elif num_profiles == 2:
|
||||
return [
|
||||
input_profile_one(S_MIN, S_MAX // 2),
|
||||
input_profile_one(S_MAX // 2 + 1, S_MAX),
|
||||
]
|
||||
else: # default: num_profiles = 1
|
||||
return [
|
||||
input_profile_one(S_MIN, S_MAX),
|
||||
]
|
||||
|
||||
evoformer_compile_args = {
|
||||
"precision": config.precision,
|
||||
"fallback": True,
|
||||
"input_profiles": input_profiles(
|
||||
evoformer_profile_one, config.trt.num_profiles
|
||||
),
|
||||
"export_args": {
|
||||
"opset_version": 20,
|
||||
"dynamo": True,
|
||||
"report": False,
|
||||
"dynamic_shapes": evoformer_dynamic_shapes,
|
||||
},
|
||||
"build_args": {
|
||||
"builder_optimization_level": config.trt.optimization_level,
|
||||
"precision_constraints": "prefer",
|
||||
},
|
||||
}
|
||||
|
||||
if op_table is not None:
|
||||
evoformer_compile_args["export_args"]["custom_translation_table"] = op_table
|
||||
|
||||
trt_compile(
|
||||
model.evoformer,
|
||||
f"{engine_dir}/EvoformerStack",
|
||||
args=evoformer_compile_args,
|
||||
logger=logger,
|
||||
)
|
||||
logger.info("model.evoformer instrumented")
|
||||
|
||||
"""
|
||||
msa_dynamic_shapes = copy.deepcopy(evoformer_dynamic_shapes)
|
||||
msa_dynamic_shapes.pop("use_flash")
|
||||
|
||||
msa_compile_args = copy.deepcopy(evoformer_compile_args)
|
||||
msa_compile_args["export_args"]["dynamic_shapes"] = msa_dynamic_shapes
|
||||
msa_compile_args["input_profiles"] = input_profiles(msa_profile_one, 2)
|
||||
|
||||
# Requires too much memory - also need to share context memory if using more than one engine
|
||||
if False: # model.extra_msa_config.enabled:
|
||||
trt_compile(
|
||||
model.extra_msa_stack,
|
||||
f"{engine_dir}/ExtraMSAStack",
|
||||
args = msa_compile_args,
|
||||
logger = logger
|
||||
)
|
||||
"""
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2022 AlQuraishi Laboratory
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -182,6 +183,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
]
|
||||
verify_arg_order(
|
||||
@@ -203,6 +205,7 @@ def trace_model_(model, sample_input):
|
||||
("mask", msa_mask),
|
||||
("chunk_size", torch.tensor(evoformer_chunk_size)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("use_flash", torch.tensor(model.globals.use_flash)),
|
||||
]
|
||||
@@ -286,6 +289,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("inplace_safe", torch.tensor(True)),
|
||||
]
|
||||
@@ -309,6 +313,7 @@ def trace_model_(model, sample_input):
|
||||
("chunk_size", torch.tensor(evoformer_attn_chunk_size)),
|
||||
("use_memory_efficient_kernel", torch.tensor(False)),
|
||||
("use_deepspeed_evo_attention", torch.tensor(model.globals.use_deepspeed_evo_attention)),
|
||||
("use_cuequivariance_attention", torch.tensor(model.globals.use_cuequivariance_attention)),
|
||||
("use_lma", torch.tensor(model.globals.use_lma)),
|
||||
("inplace_safe", torch.tensor(True)),
|
||||
]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
@@ -20,13 +22,13 @@ import os
|
||||
import pickle
|
||||
import random
|
||||
import time
|
||||
import torch
|
||||
import json
|
||||
|
||||
logging.basicConfig()
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(level=logging.INFO)
|
||||
|
||||
import torch
|
||||
torch_versions = torch.__version__.split(".")
|
||||
torch_major_version = int(torch_versions[0])
|
||||
torch_minor_version = int(torch_versions[1])
|
||||
@@ -183,13 +185,21 @@ def main(args):
|
||||
args.config_preset,
|
||||
long_sequence_inference=args.long_sequence_inference,
|
||||
use_deepspeed_evoformer_attention=args.use_deepspeed_evoformer_attention,
|
||||
)
|
||||
use_cuequivariance_attention=args.use_cuequivariance_attention,
|
||||
use_cuequivariance_multiplicative_update=args.use_cuequivariance_multiplicative_update,
|
||||
precision=args.precision,
|
||||
trt_mode=args.trt_mode,
|
||||
trt_engine_dir=args.trt_engine_dir,
|
||||
trt_num_profiles=args.trt_num_profiles,
|
||||
trt_optimization_level=args.trt_optimization_level,
|
||||
trt_max_sequence_len=args.trt_max_sequence_len,
|
||||
)
|
||||
|
||||
if args.experiment_config_json:
|
||||
with open(args.experiment_config_json, 'r') as f:
|
||||
custom_config_dict = json.load(f)
|
||||
config.update_from_flattened_dict(custom_config_dict)
|
||||
|
||||
|
||||
if args.trace_model:
|
||||
if not config.data.predict.fixed_size:
|
||||
raise ValueError(
|
||||
@@ -482,6 +492,38 @@ if __name__ == "__main__":
|
||||
"--use_deepspeed_evoformer_attention", action="store_true", default=False,
|
||||
help="Whether to use the DeepSpeed evoformer attention layer. Must have deepspeed installed in the environment.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuequivariance_attention", action="store_true", default=False,
|
||||
help="""Use cuEquivariance kernels for attention computation."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_cuequivariance_multiplicative_update", action="store_true", default=False,
|
||||
help="""Use cuEquivariance kernels for triangular multiplicative update computation."""
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trt_mode", type=str, default=None,
|
||||
help="build = Build engine; run = Run engine; None = Disable TRT"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trt_engine_dir", type=str, default=None,
|
||||
help="Absolute path to directory containing .onnx and .plan files"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precision", type=str, default="tf32",
|
||||
help="tf32 | fp32 | fp16 | bf16"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trt_max_sequence_len", type=int, default=640,
|
||||
help="Maximum sequence length supported by TRT, default=640"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trt_num_profiles", type=int, default=1,
|
||||
help="1 = Single profile[50-800]; 2 = [50-200][200-800]; 4 = [50-100]; [100-200]; [200-400]; [400-800]"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--trt_optimization_level", type=int, default=3,
|
||||
help="Allowed values: 0 to 5"
|
||||
)
|
||||
add_data_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
25
setup.py
25
setup.py
@@ -1,5 +1,6 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright 2021 DeepMind Technologies Limited
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -52,15 +53,23 @@ def get_cuda_bare_metal_version(cuda_dir):
|
||||
return raw_output, bare_metal_major, bare_metal_minor
|
||||
|
||||
compute_capabilities = set([
|
||||
(5, 2), # Titan X
|
||||
(6, 1), # GeForce 1000-series
|
||||
(9, 0), # Hopper
|
||||
])
|
||||
|
||||
compute_capabilities.add((7, 0))
|
||||
_, bare_metal_major, _ = get_cuda_bare_metal_version(CUDA_HOME)
|
||||
if int(bare_metal_major) >= 11:
|
||||
compute_capabilities.add((8, 0))
|
||||
compute_capabilities.add((8, 6))
|
||||
compute_capabilities.add((8, 9))
|
||||
|
||||
if int(bare_metal_major) >= 12:
|
||||
compute_capabilities.add((9, 0))
|
||||
|
||||
if int(bare_metal_major) >= 13:
|
||||
compute_capabilities.add((10, 0))
|
||||
compute_capabilities.add((10, 3))
|
||||
compute_capabilities.add((12, 0))
|
||||
else:
|
||||
compute_capabilities.add((7, 0))
|
||||
|
||||
compute_capability, _ = get_nvidia_cc()
|
||||
if compute_capability is not None:
|
||||
@@ -75,8 +84,6 @@ for major, minor in list(compute_capabilities):
|
||||
|
||||
extra_cuda_flags += cc_flag
|
||||
|
||||
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
|
||||
|
||||
if bare_metal_major != -1:
|
||||
modules = [CUDAExtension(
|
||||
name="attn_core_inplace_cuda",
|
||||
@@ -127,6 +134,12 @@ setup(
|
||||
},
|
||||
ext_modules=modules,
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
extras_require={
|
||||
'cuequivariance': [
|
||||
'cuequivariance-torch; sys_platform != "darwin"', # Not available on macOS
|
||||
'triton>=3.3.0; sys_platform != "darwin"', # Required for triangle multiplicative update
|
||||
],
|
||||
},
|
||||
classifiers=[
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
|
||||
@@ -27,6 +27,9 @@ def skip_unless_ds4s_installed():
|
||||
"deepspeed.ops.deepspeed4science") is not None
|
||||
return unittest.skipUnless(ds4s_is_installed, "Requires DeepSpeed with version ≥ 0.10.4")
|
||||
|
||||
def skip_unless_cueq_installed():
|
||||
cueq_is_installed = importlib.util.find_spec("cuequivariance_torch") is not None
|
||||
return unittest.skipUnless(cueq_is_installed, "Requires cuEquivariance")
|
||||
|
||||
def skip_unless_flash_attn_installed():
|
||||
fa_is_installed = importlib.util.find_spec("flash_attn") is not None
|
||||
|
||||
162
tests/test_cuequivariance.py
Normal file
162
tests/test_cuequivariance.py
Normal file
@@ -0,0 +1,162 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Unit tests to compare components of OpenFold run with the cuEquivariance memory-efficient
|
||||
attention kernel vs. a stock PyTorch attention implementation.
|
||||
"""
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
import pickle
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from openfold.data import data_transforms
|
||||
from openfold.model.primitives import (
|
||||
lecun_normal_init_,
|
||||
Attention
|
||||
)
|
||||
from openfold.utils.tensor_utils import tensor_tree_map
|
||||
|
||||
from tests.config import consts
|
||||
import tests.compare_utils as compare_utils
|
||||
from tests.data_utils import random_template_feats, random_attention_inputs
|
||||
|
||||
|
||||
@compare_utils.skip_unless_cueq_installed()
|
||||
class TestCuEquivarianceKernel(unittest.TestCase):
|
||||
|
||||
def test_compare_template_stack(self):
|
||||
"""
|
||||
Compare Template Stack output with and without using DeepSpeed Evoformer attention kernel.
|
||||
Kernel can be used for Triangle Attention in the Template Pair Stack.
|
||||
"""
|
||||
n_templ = consts.n_templ
|
||||
n_res = 20
|
||||
eps = 2e-2
|
||||
|
||||
batch = random_template_feats(n_templ, n_res)
|
||||
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
|
||||
if consts.is_multimer:
|
||||
batch["asym_id"] = batch['asym_id'][0]
|
||||
|
||||
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
|
||||
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
|
||||
|
||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
|
||||
template_feats = {
|
||||
k: v for k, v in batch.items() if k.startswith("template_")
|
||||
}
|
||||
|
||||
with torch.no_grad():
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
model.globals.use_deepspeed_evo_attention = False
|
||||
out_repro = model.embed_templates(
|
||||
template_feats,
|
||||
batch,
|
||||
torch.as_tensor(pair_act).cuda(),
|
||||
torch.as_tensor(pair_mask).cuda(),
|
||||
templ_dim=0,
|
||||
inplace_safe=False
|
||||
)
|
||||
out_repro = out_repro["template_pair_embedding"].cpu()
|
||||
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
model.globals.use_cuequivariance_multiplicative_update = True
|
||||
|
||||
out_repro_ds = model.embed_templates(
|
||||
template_feats,
|
||||
batch,
|
||||
torch.as_tensor(pair_act).cuda(),
|
||||
torch.as_tensor(pair_mask).cuda(),
|
||||
templ_dim=0,
|
||||
inplace_safe=False
|
||||
)
|
||||
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
|
||||
|
||||
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
|
||||
|
||||
def test_compare_model(self):
|
||||
"""
|
||||
Run full model with and without using CuEquivariance Evoformer attention kernel
|
||||
and compare output coordinates.
|
||||
"""
|
||||
eps = 0.2
|
||||
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
|
||||
batch = pickle.load(fp)
|
||||
|
||||
# atom37_to_atom14 doesn't like batches
|
||||
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
|
||||
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
|
||||
|
||||
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
|
||||
|
||||
if consts.is_multimer:
|
||||
n_res = batch['aatype'].shape[1]
|
||||
n_extra_seq = batch['extra_msa'].shape[1]
|
||||
batch["asym_id"] = np.ones((4, n_res))
|
||||
batch["entity_id"] = np.ones((4, n_res))
|
||||
batch["sym_id"] = np.ones((4, n_res))
|
||||
batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res))
|
||||
|
||||
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
|
||||
|
||||
batch["aatype"] = batch["aatype"].long()
|
||||
batch["template_aatype"] = batch["template_aatype"].long()
|
||||
batch["extra_msa"] = batch["extra_msa"].long()
|
||||
batch["residx_atom37_to_atom14"] = batch[
|
||||
"residx_atom37_to_atom14"
|
||||
].long()
|
||||
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
|
||||
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
|
||||
batch.update(
|
||||
data_transforms.atom37_to_torsion_angles("template_")(batch)
|
||||
)
|
||||
|
||||
# Move the recycling dimension to the end
|
||||
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
|
||||
batch = tensor_tree_map(move_dim, batch)
|
||||
# Restrict this test to use only torch.float32 precision due to instability with torch.bfloat16
|
||||
# https://github.com/aqlaboratory/openfold/issues/532
|
||||
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
|
||||
model = compare_utils.get_global_pretrained_openfold()
|
||||
model.globals.use_deepspeed_evo_attention = False
|
||||
model.globals.use_cuequivariance_attention = False
|
||||
model.globals.use_cuequivariance_multiplicative_update = False
|
||||
out_repro = model(batch)
|
||||
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
|
||||
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
# Enable attention
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
out_repro_attn = model(batch)
|
||||
out_repro_attn = tensor_tree_map(lambda t: t.cpu(), out_repro_attn)
|
||||
out_repro_attn = out_repro_attn["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_attn, eps)
|
||||
|
||||
# Enable multiplication
|
||||
model.globals.use_cuequivariance_attention = True
|
||||
model.globals.use_cuequivariance_multiplicative_update = True
|
||||
out_repro_mul = model(batch)
|
||||
out_repro_mul = tensor_tree_map(lambda t: t.cpu(), out_repro_mul)
|
||||
out_repro_mul = out_repro_mul["sm"]["positions"][-1].squeeze(0)
|
||||
|
||||
compare_utils.assert_mean_abs_diff_small(out_repro_attn, out_repro_mul, eps)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2021 AlQuraishi Laboratory
|
||||
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -113,7 +114,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
def test_tri_mul_in_compare(self):
|
||||
self._tri_mul_compare(incoming=True)
|
||||
|
||||
def _tri_mul_inplace(self, incoming=False):
|
||||
def _tri_mul_inplace(self, incoming=False, dtype = torch.float32):
|
||||
n_res = consts.n_res
|
||||
|
||||
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
|
||||
@@ -126,26 +127,38 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
|
||||
if incoming
|
||||
else model.evoformer.blocks[0].pair_stack.tri_mul_out
|
||||
)
|
||||
|
||||
act = torch.as_tensor(pair_act, dtype=dtype).cuda()
|
||||
mask = torch.as_tensor(pair_mask, dtype=dtype).cuda()
|
||||
module = module.to(dtype=dtype)
|
||||
|
||||
out_stock = module(
|
||||
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
|
||||
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
|
||||
act,
|
||||
mask=mask,
|
||||
inplace_safe=False,
|
||||
).cpu()
|
||||
)
|
||||
|
||||
# This has to come second because inference mode is in-place
|
||||
out_inplace = module(
|
||||
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
|
||||
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
|
||||
act,
|
||||
mask=mask,
|
||||
inplace_safe=True, _inplace_chunk_size=2,
|
||||
).cpu()
|
||||
)
|
||||
|
||||
self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
|
||||
torch.testing.assert_close(out_stock, out_inplace, rtol=0.1, atol=0.1)
|
||||
|
||||
|
||||
def test_tri_mul_out_inference(self):
|
||||
self._tri_mul_inplace()
|
||||
|
||||
def test_tri_mul_out_inference_bf16(self):
|
||||
self._tri_mul_inplace(dtype=torch.bfloat16)
|
||||
|
||||
def test_tri_mul_in_inference(self):
|
||||
self._tri_mul_inplace(incoming=True)
|
||||
|
||||
def test_tri_mul_in_inference_bf16(self):
|
||||
self._tri_mul_inplace(incoming=True, dtype=torch.bfloat16)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user