diff --git a/alphafold/model/utils.py b/alphafold/model/utils.py index 634f038..3e5ac62 100644 --- a/alphafold/model/utils.py +++ b/alphafold/model/utils.py @@ -163,7 +163,7 @@ def padding_consistent_rng(f): keys = grid_keys(key, shape) signature = ( '()->()' - if jax.dtypes.issubdtype(keys.dtype, jax.dtypes.prng_key) + if isinstance(keys, jax.random.PRNGKeyArray) else '(2)->()' ) return jnp.vectorize(