Avoid use of deprecated jax.util APIs

jax.util was deprecated in JAX v0.6.0, and will be removed in JAX v0.7.0.

PiperOrigin-RevId: 759626266
Change-Id: If3dcb9a8151a99ecab1f8ec670bd99bbb31bd5de
This commit is contained in:
Augustin Zidek
2025-05-16 08:32:39 -07:00
committed by Copybara-Service
parent 8050b46921
commit b522811d9e

View File

@@ -17,7 +17,7 @@
import functools
import inspect
from typing import Any, Callable, Optional, Sequence, Union
from typing import Any, Callable, Optional, Sequence, TypeVar, Union
import haiku as hk
import jax
@@ -31,6 +31,19 @@ partial = functools.partial
PROXY = object()
T = TypeVar('T')
def _set_docstring(docstr: str) -> Callable[[T], T]:
"""Decorator for setting the docstring of a function."""
def wrapped(fun: T) -> T:
fun.__doc__ = docstr.format(fun=getattr(fun, '__name__', repr(fun)))
return fun
return wrapped
def _maybe_slice(array, i, slice_size, axis):
if axis is PROXY:
return array
@@ -120,7 +133,8 @@ def sharded_apply(
if shard_size is None:
return fun
@jax.util.wraps(fun, docstr=docstr)
@_set_docstring(docstr)
@functools.wraps(fun)
def mapped_fn(*args):
# Expand in axes and Determine Loop range
in_axes_ = _expand_axes(in_axes, args)