mirror of
https://github.com/OpenFreeEnergy/openfe.git
synced 2026-06-04 14:14:22 +08:00
Disable JAX acceleration by default (#1694)
* Disable JAX acceleration by default * ruff fmt * add logging info * fix url * fix list * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add example of error message * added note about disabling jax acel by default --------- Co-authored-by: Irfan Alibay <IAlibay@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -73,3 +73,23 @@ If the necessary libraries for GPU acceleration are not installed and JAX detect
|
||||
|
||||
This warning does not mean that the *molecular dynamics* simulation will fall back to using the CPU.
|
||||
The simulation will still use the computing platform specified in the settings.
|
||||
|
||||
PYMBAR_DISABLE_JAX
|
||||
------------------
|
||||
|
||||
Due to a suspected memory leak in the JAX acceleration code in ``pymbar`` we disable JAX acceleration by default.
|
||||
This memory leak may result in the simulation crashing, wasting compute time.
|
||||
The error message may look like this:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
LLVM compilation error: Cannot allocate memory
|
||||
LLVM ERROR: Unable to allocate section memory!
|
||||
|
||||
We have decided to disable JAX acceleration by default to prevent wasted compute.
|
||||
However, if you wish to use the JAX acceleration, you may set ``PYMBAR_DISABLE_JAX`` to ``TRUE`` (e.g. put ``export PYMBAR_DISABLE_JAX=FALSE`` in your submission script before running ``openfe quickrun``).
|
||||
For more information, see these issues on github:
|
||||
|
||||
- https://github.com/choderalab/pymbar/issues/564
|
||||
- https://github.com/OpenFreeEnergy/openfe/issues/1534
|
||||
- https://github.com/OpenFreeEnergy/openfe/issues/1654
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
* Emit a clarifying log message when a user gets a warning from JAX (#1585).
|
||||
Fixes #1499.
|
||||
|
||||
* Disable JAX acceleration by default, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more information (#1694).
|
||||
|
||||
**Changed:**
|
||||
|
||||
* <news item>
|
||||
|
||||
@@ -1,3 +1,22 @@
|
||||
# Before we do anything else, we want to disable JAX
|
||||
# acceleration by default but if a user has set
|
||||
# PYMBAR_DISABLE_JAX to some value, we want to keep
|
||||
# it
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if "PYMBAR_DISABLE_JAX" not in os.environ:
|
||||
logger.warn(
|
||||
"PYMBAR_DISABLE_JAX not set, setting to TRUE, see https://docs.openfree.energy/en/latest/guide/troubleshooting.html#pymbar-disable-jax for more details"
|
||||
)
|
||||
|
||||
# setdefault will only set PYMBAR_DISABLE_JAX if it is unset
|
||||
os.environ.setdefault("PYMBAR_DISABLE_JAX", "TRUE")
|
||||
|
||||
|
||||
# We need to do this first so that we can set up our
|
||||
# log control since some modules have warnings on import
|
||||
from openfe.utils import logging_control
|
||||
|
||||
Reference in New Issue
Block a user