Disable JAX acceleration by default (#1694)

* Disable JAX acceleration by default

* ruff fmt

* add logging info

* fix url

* fix list

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add example of error message

* added note about disabling jax acel by default

---------

Co-authored-by: Irfan Alibay <IAlibay@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
Mike Henry
2025-12-02 14:12:24 -07:00
committed by GitHub
parent 78d0f901f8
commit f491184b58
3 changed files with 41 additions and 0 deletions

View File

@@ -73,3 +73,23 @@ If the necessary libraries for GPU acceleration are not installed and JAX detect
This warning does not mean that the *molecular dynamics* simulation will fall back to using the CPU.
The simulation will still use the computing platform specified in the settings.
PYMBAR_DISABLE_JAX
------------------
Due to a suspected memory leak in the JAX acceleration code in ``pymbar`` we disable JAX acceleration by default.
This memory leak may result in the simulation crashing, wasting compute time.
The error message may look like this:
.. code-block:: bash
LLVM compilation error: Cannot allocate memory
LLVM ERROR: Unable to allocate section memory!
We have decided to disable JAX acceleration by default to prevent wasted compute.
However, if you wish to use the JAX acceleration, you may set ``PYMBAR_DISABLE_JAX`` to ``TRUE`` (e.g. put ``export PYMBAR_DISABLE_JAX=FALSE`` in your submission script before running ``openfe quickrun``).
For more information, see these issues on github:
- https://github.com/choderalab/pymbar/issues/564
- https://github.com/OpenFreeEnergy/openfe/issues/1534
- https://github.com/OpenFreeEnergy/openfe/issues/1654

View File

@@ -3,6 +3,8 @@
* Emit a clarifying log message when a user gets a warning from JAX (#1585).
Fixes #1499.
* Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (#1694).
**Changed:**
* <news item>

View File

@@ -1,3 +1,22 @@
# Before we do anything else, we want to disable JAX
# acceleration by default but if a user has set
# PYMBAR_DISABLE_JAX to some value, we want to keep
# it
import logging
import os
logger = logging.getLogger(__name__)
if "PYMBAR_DISABLE_JAX" not in os.environ:
logger.warn(
"PYMBAR_DISABLE_JAX not set, setting to TRUE, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more details"
)
# setdefault will only set PYMBAR_DISABLE_JAX if it is unset
os.environ.setdefault("PYMBAR_DISABLE_JAX", "TRUE")
# We need to do this first so that we can set up our
# log control since some modules have warnings on import
from openfe.utils import logging_control