Revert a change that is not compatible with JAX 0.3.25.

PiperOrigin-RevId: 578486539
Change-Id: Id0e535ebf75916b5179791d093ca714e343ccb69
This commit is contained in:
Augustin Zidek
2023-11-01 05:31:26 -07:00
committed by Copybara-Service
parent f78c589304
commit f715f016d8

View File

@@ -163,7 +163,7 @@ def padding_consistent_rng(f):
keys = grid_keys(key, shape)
signature = (
'()->()'
if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key)
if isinstance(keys, jax.random.PRNGKeyArray)
else '(2)->()'
)
return jnp.vectorize(