mirror of
https://github.com/KosinskiLab/AlphaPulldown.git
synced 2026-06-04 14:14:24 +08:00
Skip GPU functional suites on CI without override
This commit is contained in:
@@ -13,6 +13,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
@@ -38,6 +39,37 @@ os.environ["JAX_COMPILATION_CACHE_DIR"] = "/scratch/dima/jax_cache"
|
||||
# from alphafold.model import config
|
||||
# config.CONFIG_MULTIMER.model.embeddings_and_evoformer.evoformer_num_block = 1
|
||||
|
||||
|
||||
def _has_nvidia_gpu() -> bool:
|
||||
nvidia_smi = shutil.which("nvidia-smi")
|
||||
if not nvidia_smi:
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[nvidia_smi, "-L"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except OSError:
|
||||
return False
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
|
||||
|
||||
def _gpu_functional_test_skip_reason() -> str | None:
|
||||
if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
||||
return None
|
||||
if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv(
|
||||
"GITHUB_ACTIONS", ""
|
||||
).lower() == "true":
|
||||
return (
|
||||
"GPU functional tests are disabled on CI/CD. "
|
||||
"Set RUN_GPU_FUNCTIONAL_TESTS=1 to override."
|
||||
)
|
||||
if not _has_nvidia_gpu():
|
||||
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
|
||||
return None
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -47,6 +79,9 @@ class _TestBase(parameterized.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
skip_reason = _gpu_functional_test_skip_reason()
|
||||
if skip_reason:
|
||||
raise unittest.SkipTest(skip_reason)
|
||||
# do the skip here so import-time doesn't abort discovery
|
||||
#if not DATA_DIR.is_dir():
|
||||
# cls.skipTest(f"set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests (tried {DATA_DIR!r})")
|
||||
@@ -405,4 +440,4 @@ class TestDropoutDiversity(_TestBase):
|
||||
# The test passes if calculations succeed - the diversity check is informational
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
absltest.main()
|
||||
|
||||
@@ -18,6 +18,7 @@ import pickle
|
||||
import json
|
||||
import numpy as np
|
||||
import re
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple, Any
|
||||
|
||||
from absl.testing import absltest, parameterized
|
||||
@@ -45,6 +46,37 @@ if not os.path.exists(DATA_DIR):
|
||||
absltest.skip("set $ALPHAFOLD_DATA_DIR to run Alphafold functional tests")
|
||||
|
||||
|
||||
def _has_nvidia_gpu() -> bool:
|
||||
nvidia_smi = shutil.which("nvidia-smi")
|
||||
if not nvidia_smi:
|
||||
return False
|
||||
try:
|
||||
result = subprocess.run(
|
||||
[nvidia_smi, "-L"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except OSError:
|
||||
return False
|
||||
return result.returncode == 0 and bool(result.stdout.strip())
|
||||
|
||||
|
||||
def _gpu_functional_test_skip_reason() -> str | None:
|
||||
if os.getenv("RUN_GPU_FUNCTIONAL_TESTS", "").lower() in ("1", "true", "yes"):
|
||||
return None
|
||||
if os.getenv("CI", "").lower() in ("1", "true", "yes") or os.getenv(
|
||||
"GITHUB_ACTIONS", ""
|
||||
).lower() == "true":
|
||||
return (
|
||||
"GPU functional tests are disabled on CI/CD. "
|
||||
"Set RUN_GPU_FUNCTIONAL_TESTS=1 to override."
|
||||
)
|
||||
if not _has_nvidia_gpu():
|
||||
return "GPU functional tests require an NVIDIA GPU and nvidia-smi."
|
||||
return None
|
||||
|
||||
|
||||
# --------------------------------------------------------------------------- #
|
||||
# common helper mix-in / assertions #
|
||||
# --------------------------------------------------------------------------- #
|
||||
@@ -54,6 +86,9 @@ class _TestBase(parameterized.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
skip_reason = _gpu_functional_test_skip_reason()
|
||||
if skip_reason:
|
||||
raise unittest.SkipTest(skip_reason)
|
||||
# Create a base directory for all test outputs
|
||||
if cls.use_temp_dir:
|
||||
cls.base_output_dir = Path(tempfile.mkdtemp(prefix="af3_test_"))
|
||||
|
||||
Reference in New Issue
Block a user