pyright ignore flash_attn import because it is guarded (#274)

This commit is contained in:
Ishaan Mathur
2025-09-18 13:54:10 -04:00
committed by GitHub
parent cbaccc1693
commit 97eb26b2fe
3 changed files with 7 additions and 7 deletions

View File

@@ -25,9 +25,9 @@ jobs:
uses: actions/checkout@v4
- name: Setup Environment
uses: prefix-dev/setup-pixi@v0.8.1
uses: prefix-dev/setup-pixi@v0.9.0
with:
pixi-version: v0.47.0
pixi-version: v0.54.0
cache: false
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}
@@ -43,9 +43,9 @@ jobs:
uses: actions/checkout@v4
- name: Setup Environment
uses: prefix-dev/setup-pixi@v0.8.1
uses: prefix-dev/setup-pixi@v0.9.0
with:
pixi-version: v0.47.0
pixi-version: v0.54.0
cache: false
cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }}

View File

@@ -8,7 +8,7 @@ from torch import nn
from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding
try:
from flash_attn import flash_attn_varlen_qkvpacked_func
from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore
except (ImportError, RuntimeError):
flash_attn_varlen_qkvpacked_func = None

View File

@@ -53,7 +53,7 @@ def test_oss_esm3_client():
def test_oss_esmc_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
sequence = "MALWMRLLPLLALLALAVPDPAAA"
model = "esmc-300m-2024-12"
esmc_client = client(model=model, url=URL, token=API_TOKEN)
@@ -75,7 +75,7 @@ def test_oss_esmc_client():
def test_oss_sequence_structure_forge_inference_client():
assert URL is not None
sequence = "MALWMRLLPLLALLALAVUUPDPAAA"
sequence = "MALWMRLLPLLALLALAVPDPAAA"
model = "esm3-small-2024-03"
client = SequenceStructureForgeInferenceClient(
model=model, url=URL, token=API_TOKEN