fix gpu issues

This commit is contained in:
HannesStark
2025-10-31 00:31:26 +00:00
parent 3f241ab24d
commit 63fda400e7
5 changed files with 6 additions and 29827 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
entities:
- file:
path:
- ../nanobodies/7eow.yaml
- ../nanobodies/7xl0.yaml
- ../nanobodies/8coh.yaml
- ../nanobodies/8z8v.yaml
- file:
path: 9he1.cif
include:
- chain:
id: A
- chain:
id: B
res_index: 43..
include_proximity:
- chain:
id: B
radius: 20
# binding_types:
# - chain:
# id: A
# binding: 46,63,64,66,69,70,73,95,96,109,117,120,121,122,141,162,163,164,167,184,185,187,204,206,207,228,229,233,235,272,273,274,297,298

View File

@@ -1,23 +0,0 @@
entities:
- protein:
id: A
sequence: 80..140
- file:
path: 9he1.cif
include:
- chain:
id: A
- chain:
id: B
res_index: 43..
include_proximity:
- chain:
id: B
radius: 20
# binding_types:
# - chain:
# id: A
# binding: 46,63,64,66,69,70,73,95,96,109,117,120,121,122,141,162,163,164,167,184,185,187,204,206,207,228,229,233,235,272,273,274,297,298

View File

@@ -3,16 +3,6 @@ from torch import Tensor, nn
from boltzgen.model.layers import initialize as init
_cueq_available = False
try:
from cuequivariance_torch.primitives.triangle import (
triangle_multiplicative_update as _triangle_multiplicative_update,
)
_cueq_available = True
except ModuleNotFoundError:
_cueq_available = False
@torch.compiler.disable # noqa: E402 decorator must follow import of torch
def _kernel_triangular_mult(
@@ -30,7 +20,11 @@ def _kernel_triangular_mult(
g_out_weight: Tensor,
eps: float,
):
if not _cueq_available:
try:
from cuequivariance_torch.primitives.triangle import (
triangle_multiplicative_update as _triangle_multiplicative_update,
)
except ModuleNotFoundError:
raise RuntimeError(
"cuEquivariance kernels requested via use_kernels=True but the package is not available."
)

View File

@@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Tuple
import numpy as np
import torch
from cuequivariance_torch.primitives.triangle import triangle_attention
from boltzgen.model.layers.triangular_attention.utils import (
flatten_final_dims,
@@ -265,6 +264,7 @@ def _attention(
@torch.compiler.disable
def kernel_triangular_attn(q, k, v, tri_bias, mask, scale):
from cuequivariance_torch.primitives.triangle import triangle_attention
return triangle_attention(q, k, v, tri_bias, mask=mask, scale=scale)