mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
fix gpu issues
This commit is contained in:
29765
examples/aarti/9he1.cif
29765
examples/aarti/9he1.cif
File diff suppressed because it is too large
Load Diff
@@ -1,27 +0,0 @@
|
||||
entities:
|
||||
- file:
|
||||
path:
|
||||
- ../nanobodies/7eow.yaml
|
||||
- ../nanobodies/7xl0.yaml
|
||||
- ../nanobodies/8coh.yaml
|
||||
- ../nanobodies/8z8v.yaml
|
||||
- file:
|
||||
path: 9he1.cif
|
||||
|
||||
include:
|
||||
- chain:
|
||||
id: A
|
||||
- chain:
|
||||
id: B
|
||||
res_index: 43..
|
||||
|
||||
include_proximity:
|
||||
- chain:
|
||||
id: B
|
||||
radius: 20
|
||||
|
||||
|
||||
# binding_types:
|
||||
# - chain:
|
||||
# id: A
|
||||
# binding: 46,63,64,66,69,70,73,95,96,109,117,120,121,122,141,162,163,164,167,184,185,187,204,206,207,228,229,233,235,272,273,274,297,298
|
||||
@@ -1,23 +0,0 @@
|
||||
entities:
|
||||
- protein:
|
||||
id: A
|
||||
sequence: 80..140
|
||||
- file:
|
||||
path: 9he1.cif
|
||||
|
||||
include:
|
||||
- chain:
|
||||
id: A
|
||||
- chain:
|
||||
id: B
|
||||
res_index: 43..
|
||||
|
||||
include_proximity:
|
||||
- chain:
|
||||
id: B
|
||||
radius: 20
|
||||
|
||||
# binding_types:
|
||||
# - chain:
|
||||
# id: A
|
||||
# binding: 46,63,64,66,69,70,73,95,96,109,117,120,121,122,141,162,163,164,167,184,185,187,204,206,207,228,229,233,235,272,273,274,297,298
|
||||
@@ -3,16 +3,6 @@ from torch import Tensor, nn
|
||||
|
||||
from boltzgen.model.layers import initialize as init
|
||||
|
||||
_cueq_available = False
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import (
|
||||
triangle_multiplicative_update as _triangle_multiplicative_update,
|
||||
)
|
||||
|
||||
_cueq_available = True
|
||||
except ModuleNotFoundError:
|
||||
_cueq_available = False
|
||||
|
||||
|
||||
@torch.compiler.disable # noqa: E402 – decorator must follow import of torch
|
||||
def _kernel_triangular_mult(
|
||||
@@ -30,7 +20,11 @@ def _kernel_triangular_mult(
|
||||
g_out_weight: Tensor,
|
||||
eps: float,
|
||||
):
|
||||
if not _cueq_available:
|
||||
try:
|
||||
from cuequivariance_torch.primitives.triangle import (
|
||||
triangle_multiplicative_update as _triangle_multiplicative_update,
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
raise RuntimeError(
|
||||
"cuEquivariance kernels requested via use_kernels=True but the package is not available."
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from typing import Callable, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
|
||||
from boltzgen.model.layers.triangular_attention.utils import (
|
||||
flatten_final_dims,
|
||||
@@ -265,6 +264,7 @@ def _attention(
|
||||
|
||||
@torch.compiler.disable
|
||||
def kernel_triangular_attn(q, k, v, tri_bias, mask, scale):
|
||||
from cuequivariance_torch.primitives.triangle import triangle_attention
|
||||
return triangle_attention(q, k, v, tri_bias, mask=mask, scale=scale)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user