mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Avoid use of deprecated jax.util APIs
jax.util was deprecated in JAX v0.6.0, and will be removed in JAX v0.7.0. PiperOrigin-RevId: 759626266 Change-Id: If3dcb9a8151a99ecab1f8ec670bd99bbb31bd5de
This commit is contained in:
committed by
Copybara-Service
parent
8050b46921
commit
b522811d9e
@@ -17,7 +17,7 @@
|
||||
import functools
|
||||
import inspect
|
||||
|
||||
from typing import Any, Callable, Optional, Sequence, Union
|
||||
from typing import Any, Callable, Optional, Sequence, TypeVar, Union
|
||||
|
||||
import haiku as hk
|
||||
import jax
|
||||
@@ -31,6 +31,19 @@ partial = functools.partial
|
||||
PROXY = object()
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
def _set_docstring(docstr: str) -> Callable[[T], T]:
|
||||
"""Decorator for setting the docstring of a function."""
|
||||
|
||||
def wrapped(fun: T) -> T:
|
||||
fun.__doc__ = docstr.format(fun=getattr(fun, '__name__', repr(fun)))
|
||||
return fun
|
||||
|
||||
return wrapped
|
||||
|
||||
|
||||
def _maybe_slice(array, i, slice_size, axis):
|
||||
if axis is PROXY:
|
||||
return array
|
||||
@@ -120,7 +133,8 @@ def sharded_apply(
|
||||
if shard_size is None:
|
||||
return fun
|
||||
|
||||
@jax.util.wraps(fun, docstr=docstr)
|
||||
@_set_docstring(docstr)
|
||||
@functools.wraps(fun)
|
||||
def mapped_fn(*args):
|
||||
# Expand in axes and Determine Loop range
|
||||
in_axes_ = _expand_axes(in_axes, args)
|
||||
|
||||
Reference in New Issue
Block a user