From 97eb26b2fe04386205bd8049033cc14f022d1fa3 Mon Sep 17 00:00:00 2001 From: Ishaan Mathur <42471598+imathur1@users.noreply.github.com> Date: Thu, 18 Sep 2025 13:54:10 -0400 Subject: [PATCH] pyright ignore flash_attn import because it is guarded (#274) --- .github/workflows/ci.yml | 8 ++++---- esm/layers/attention.py | 2 +- tests/oss_pytests/test_oss_client.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 53e317e..9f46799 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -25,9 +25,9 @@ jobs: uses: actions/checkout@v4 - name: Setup Environment - uses: prefix-dev/setup-pixi@v0.8.1 + uses: prefix-dev/setup-pixi@v0.9.0 with: - pixi-version: v0.47.0 + pixi-version: v0.54.0 cache: false cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} @@ -43,9 +43,9 @@ jobs: uses: actions/checkout@v4 - name: Setup Environment - uses: prefix-dev/setup-pixi@v0.8.1 + uses: prefix-dev/setup-pixi@v0.9.0 with: - pixi-version: v0.47.0 + pixi-version: v0.54.0 cache: false cache-write: ${{ github.event_name == 'push' && github.ref_name == 'main' }} diff --git a/esm/layers/attention.py b/esm/layers/attention.py index 564ef90..964d428 100644 --- a/esm/layers/attention.py +++ b/esm/layers/attention.py @@ -8,7 +8,7 @@ from torch import nn from esm.layers.rotary import RotaryEmbedding, TritonRotaryEmbedding try: - from flash_attn import flash_attn_varlen_qkvpacked_func + from flash_attn import flash_attn_varlen_qkvpacked_func # type: ignore except (ImportError, RuntimeError): flash_attn_varlen_qkvpacked_func = None diff --git a/tests/oss_pytests/test_oss_client.py b/tests/oss_pytests/test_oss_client.py index 6fa4a30..9dd5b18 100644 --- a/tests/oss_pytests/test_oss_client.py +++ b/tests/oss_pytests/test_oss_client.py @@ -53,7 +53,7 @@ def test_oss_esm3_client(): def test_oss_esmc_client(): assert URL is not None - sequence = "MALWMRLLPLLALLALAVUUPDPAAA" + sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esmc-300m-2024-12" esmc_client = client(model=model, url=URL, token=API_TOKEN) @@ -75,7 +75,7 @@ def test_oss_esmc_client(): def test_oss_sequence_structure_forge_inference_client(): assert URL is not None - sequence = "MALWMRLLPLLALLALAVUUPDPAAA" + sequence = "MALWMRLLPLLALLALAVPDPAAA" model = "esm3-small-2024-03" client = SequenceStructureForgeInferenceClient( model=model, url=URL, token=API_TOKEN