Remove unused utility functions

PiperOrigin-RevId: 875760003
Change-Id: I23d7759ddf0835394dbd6b6ab3c304fa52c6a9b1
This commit is contained in:
Augustin Zidek
2026-02-26 09:31:46 -08:00
committed by Copybara-Service
parent 1e9ece2cfb
commit 6daafb546f

View File

@@ -13,72 +13,9 @@
from collections.abc import Iterable
import numbers
import jax
from jax import lax
import jax.numpy as jnp
def safe_select(condition, true_fn, false_fn):
"""Safe version of selection (i.e. `where`).
This applies the double-where trick.
Like jnp.where, this function will still execute both branches and is
expected to be more lightweight than lax.cond. Other than NaN-semantics,
safe_select(condition, true_fn, false_fn) is equivalent to
jax.tree.map(lambda x, y: jnp.where(condition, x, y),
true_fn(),
false_fn()),
Compared to the naive implementation above, safe_select provides the
following guarantee: in either the forward or backward pass, a NaN produced
*during the execution of true_fn()* will not propagate to the rest of the
computation and similarly for false_fn. It is very important to note that
while true_fn and false_fn will typically close over other tensors (i.e. they
use values computed prior to the safe_select function), there is no NaN-safety
for the backward pass of closed over values. It is important than any NaN's
are produced within the branch functions and not before them. For example,
safe_select(x < eps, lambda: 0., lambda: jnp.sqrt(x))
will not produce NaN on the backward pass even if x == 0. since sqrt happens
within the false_fn, but the very similar
y = jnp.sqrt(x)
safe_select(x < eps, lambda: 0., lambda: y)
will produce a NaN on the backward pass if x == 0 because the sqrt happens
prior to the false_fn.
Args:
condition: Boolean array to use in where
true_fn: Zero-argument function to construct the values used in the True
condition. Tensors that this function closes over will be extracted
automatically to implement the double-where trick to suppress spurious NaN
propagation.
false_fn: False branch equivalent of true_fn
Returns:
Resulting PyTree equivalent to tree_map line above.
"""
true_fn, true_args = jax.closure_convert(true_fn)
false_fn, false_args = jax.closure_convert(false_fn)
true_args = jax.tree.map(
lambda x: jnp.where(condition, x, lax.stop_gradient(x)), true_args
)
false_args = jax.tree.map(
lambda x: jnp.where(condition, lax.stop_gradient(x), x), false_args
)
return jax.tree.map(
lambda x, y: jnp.where(condition, x, y),
true_fn(*true_args),
false_fn(*false_args),
)
def unstack(value: jnp.ndarray, axis: int = -1) -> list[jnp.ndarray]:
return [
jnp.squeeze(v, axis=axis)
@@ -93,18 +30,6 @@ def angdiff(alpha: jnp.ndarray, beta: jnp.ndarray) -> jnp.ndarray:
return d
def safe_arctan2(
x1: jnp.ndarray, x2: jnp.ndarray, eps: float = 1e-8
) -> jnp.ndarray:
"""Safe version of arctan2 that avoids NaN gradients when x1=x2=0."""
return safe_select(
jnp.abs(x1) + jnp.abs(x2) < eps,
lambda: jnp.zeros_like(jnp.arctan2(x1, x2)),
lambda: jnp.arctan2(x1, x2),
)
def weighted_mean(
*,
weights: jnp.ndarray,