xla flash attention env

This commit is contained in:
Dima
2026-03-19 13:14:05 +01:00
parent 24faf843bb
commit 630fd6028c

View File

@@ -888,15 +888,19 @@ class _TestBase(parameterized.TestCase):
print(f" ✓ Chain {chain_id}: Valid sequence with correct chain ID")
def _make_af3_test_env(self) -> Dict[str, str]:
flash_impl = self._af3_flash_attention_impl()
env = os.environ.copy()
env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0"
env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
env["XLA_CLIENT_MEM_FRACTION"] = "0.95"
env["JAX_FLASH_ATTENTION_IMPL"] = "xla"
env["JAX_FLASH_ATTENTION_IMPL"] = flash_impl
if "XLA_PYTHON_CLIENT_MEM_FRACTION" in env:
del env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
return env
def _af3_flash_attention_impl(self) -> str:
return os.getenv("AF3_TEST_FLASH_ATTENTION_IMPL", "xla")
def _require_af3_functional_environment(self) -> None:
if not os.path.exists(DATA_DIR):
self.skipTest(
@@ -971,6 +975,7 @@ class _TestBase(parameterized.TestCase):
# convenience builder
def _args(self, *, plist, script):
flash_impl = self._af3_flash_attention_impl()
# Determine mode from protein list name
if "homooligomer" in plist:
mode = "homo-oligomer"
@@ -998,7 +1003,7 @@ class _TestBase(parameterized.TestCase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
]
# Add special arguments for multi_seeds_samples test
@@ -1026,7 +1031,7 @@ class _TestBase(parameterized.TestCase):
+ f"={self.test_protein_lists_dir / plist}",
# Ensure AF3 backend and keep runtime small
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
]
return args
@@ -1201,6 +1206,7 @@ class TestAlphaFold3RunModes(_TestBase):
"""A single explicit output dir must remain flat even with --use_ap_style."""
self._require_af3_functional_environment()
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
json_input = self.test_features_dir / "protein_with_ptms.json"
res = subprocess.run(
@@ -1212,7 +1218,7 @@ class TestAlphaFold3RunModes(_TestBase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
"--use_ap_style",
],
@@ -1230,6 +1236,7 @@ class TestAlphaFold3RunModes(_TestBase):
"""Multiple AF3 JSON jobs sharing one root must be split into per-job subdirectories."""
self._require_af3_functional_environment()
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
json_inputs = [
self.test_features_dir / "protein_with_ptms.json",
self.test_features_dir / "test_alphafold3_prediction.json",
@@ -1244,7 +1251,7 @@ class TestAlphaFold3RunModes(_TestBase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
"--use_ap_style",
],
@@ -1272,6 +1279,7 @@ class TestAlphaFold3RunModes(_TestBase):
"""Shared AF3 wrapper output roots must isolate multiple jobs by subdirectory."""
self._require_af3_functional_environment()
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
protein_list = self.test_protein_lists_dir / "test_multiple_monomers.txt"
res = subprocess.run(
@@ -1286,7 +1294,7 @@ class TestAlphaFold3RunModes(_TestBase):
"--mode=custom",
f"--protein_lists={protein_list}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
],
capture_output=True,
@@ -1356,14 +1364,7 @@ class TestAlphaFold3RunModes(_TestBase):
)
def test_(self, protein_list, script):
# Create environment with GPU settings
env = os.environ.copy()
env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0"
env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
env["XLA_CLIENT_MEM_FRACTION"] = "0.95"
env["JAX_FLASH_ATTENTION_IMPL"] = "xla"
# Remove deprecated variable if present
if "XLA_PYTHON_CLIENT_MEM_FRACTION" in env:
del env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
env = self._make_af3_test_env()
# Debug output
print("\nEnvironment variables:")
@@ -1395,13 +1396,8 @@ class TestAlphaFold3RunModes(_TestBase):
def test_af3_writes_embeddings_and_distogram(self):
"""Run AF3 with embeddings and distogram enabled and check files exist."""
env = os.environ.copy()
env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0"
env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
env["XLA_CLIENT_MEM_FRACTION"] = "0.95"
env["JAX_FLASH_ATTENTION_IMPL"] = "xla"
if "XLA_PYTHON_CLIENT_MEM_FRACTION" in env:
del env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
args = [
sys.executable,
@@ -1411,7 +1407,7 @@ class TestAlphaFold3RunModes(_TestBase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--save_embeddings",
"--save_distogram",
"--num_diffusion_samples=1",
@@ -1456,13 +1452,8 @@ class TestAlphaFold3RunModes(_TestBase):
def test_af3_num_recycles_affects_runtime(self):
"""num_recycles=1 should be faster than default (keeping other knobs same)."""
env = os.environ.copy()
env["XLA_FLAGS"] = "--xla_disable_hlo_passes=custom-kernel-fusion-rewriter --xla_gpu_force_compilation_parallelism=0"
env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true"
env["XLA_CLIENT_MEM_FRACTION"] = "0.95"
env["JAX_FLASH_ATTENTION_IMPL"] = "xla"
if "XLA_PYTHON_CLIENT_MEM_FRACTION" in env:
del env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
common = [
sys.executable,
@@ -1472,7 +1463,7 @@ class TestAlphaFold3RunModes(_TestBase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
"--num_seeds=2", # ensures second seed reuses compiled XLA and timing reflects compute
]
@@ -1507,8 +1498,8 @@ class TestAlphaFold3RunModes(_TestBase):
def test_af3_rejects_alphafold2_flag(self):
"""Passing AF2-only flags to AF3 backend should fail via validator."""
env = os.environ.copy()
env["JAX_FLASH_ATTENTION_IMPL"] = "xla"
env = self._make_af3_test_env()
flash_impl = self._af3_flash_attention_impl()
args = [
sys.executable,
@@ -1518,7 +1509,7 @@ class TestAlphaFold3RunModes(_TestBase):
f"--data_directory={DATA_DIR}",
f"--features_directory={self.test_features_dir}",
"--fold_backend=alphafold3",
"--flash_attention_implementation=xla",
f"--flash_attention_implementation={flash_impl}",
"--num_diffusion_samples=1",
# Intentionally invalid for AF3:
"--num_predictions_per_model=1",