diff --git a/docs/guide/troubleshooting.rst b/docs/guide/troubleshooting.rst index b85fb218..ae821105 100644 --- a/docs/guide/troubleshooting.rst +++ b/docs/guide/troubleshooting.rst @@ -73,3 +73,23 @@ If the necessary libraries for GPU acceleration are not installed and JAX detect This warning does not mean that the *molecular dynamics* simulation will fall back to using the CPU. The simulation will still use the computing platform specified in the settings. + +PYMBAR_DISABLE_JAX +------------------ + +Due to a suspected memory leak in the JAX acceleration code in ``pymbar`` we disable JAX acceleration by default. +This memory leak may result in the simulation crashing, wasting compute time. +The error message may look like this: + +.. code-block:: bash + + LLVM compilation error: Cannot allocate memory + LLVM ERROR: Unable to allocate section memory! + +We have decided to disable JAX acceleration by default to prevent wasted compute. +However, if you wish to use the JAX acceleration, you may set ``PYMBAR_DISABLE_JAX`` to ``TRUE`` (e.g. put ``export PYMBAR_DISABLE_JAX=FALSE`` in your submission script before running ``openfe quickrun``). +For more information, see these issues on github: + +- https://github.com/choderalab/pymbar/issues/564 +- https://github.com/OpenFreeEnergy/openfe/issues/1534 +- https://github.com/OpenFreeEnergy/openfe/issues/1654 diff --git a/news/jax-warning.rst b/news/jax-warning.rst index f50b3dbd..57aebff8 100644 --- a/news/jax-warning.rst +++ b/news/jax-warning.rst @@ -3,6 +3,8 @@ * Emit a clarifying log message when a user gets a warning from JAX (#1585). Fixes #1499. +* Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (#1694). + **Changed:** * diff --git a/openfe/__init__.py b/openfe/__init__.py index d56e7340..2ec7c009 100644 --- a/openfe/__init__.py +++ b/openfe/__init__.py @@ -1,3 +1,22 @@ +# Before we do anything else, we want to disable JAX +# acceleration by default but if a user has set +# PYMBAR_DISABLE_JAX to some value, we want to keep +# it + +import logging +import os + +logger = logging.getLogger(__name__) + +if "PYMBAR_DISABLE_JAX" not in os.environ: + logger.warn( + "PYMBAR_DISABLE_JAX not set, setting to TRUE, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more details" + ) + +# setdefault will only set PYMBAR_DISABLE_JAX if it is unset +os.environ.setdefault("PYMBAR_DISABLE_JAX", "TRUE") + + # We need to do this first so that we can set up our # log control since some modules have warnings on import from openfe.utils import logging_control