mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Fix jax.tree_multimap deprecation warning.
PiperOrigin-RevId: 451994826 Change-Id: I4573baf61d33010c75de717d3b49f47bc9c6a8ac
This commit is contained in:
committed by
Copybara-Service
parent
197bd19ee3
commit
d9e5e1d9c6
@@ -426,7 +426,7 @@ def torsion_angles_to_frames(
|
||||
chi3_frame_to_backb = chi2_frame_to_backb @ all_frames[:, 6]
|
||||
chi4_frame_to_backb = chi3_frame_to_backb @ all_frames[:, 7]
|
||||
|
||||
all_frames_to_backb = jax.tree_multimap(
|
||||
all_frames_to_backb = jax.tree_map(
|
||||
lambda *x: jnp.concatenate(x, axis=-1), all_frames[:, 0:5],
|
||||
chi2_frame_to_backb[:, None], chi3_frame_to_backb[:, None],
|
||||
chi4_frame_to_backb[:, None])
|
||||
|
||||
@@ -546,7 +546,7 @@ def generate_monomer_rigids(representations: Mapping[str, jnp.ndarray],
|
||||
)
|
||||
outputs.append(output)
|
||||
|
||||
output = jax.tree_multimap(lambda *x: jnp.stack(x), *outputs)
|
||||
output = jax.tree_map(lambda *x: jnp.stack(x), *outputs)
|
||||
# Pass along for LDDT-Head.
|
||||
output['act'] = activations['act']
|
||||
|
||||
@@ -823,7 +823,7 @@ def compute_frames(
|
||||
alt_gt_frames = frames_batch['rigidgroups_alt_gt_frames']
|
||||
use_alt = use_alt[:, None]
|
||||
|
||||
renamed_gt_frames = jax.tree_multimap(
|
||||
renamed_gt_frames = jax.tree_map(
|
||||
lambda x, y: (1. - use_alt) * x + use_alt * y, gt_frames, alt_gt_frames)
|
||||
|
||||
return renamed_gt_frames, frames_batch['rigidgroups_gt_exists']
|
||||
@@ -1160,4 +1160,3 @@ class MultiRigidSidechain(hk.Module):
|
||||
'frames': all_frames_to_global, # geometry.Rigid3Array (N, 8)
|
||||
})
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -53,10 +53,10 @@ class Vec3Array:
|
||||
assert all([x == z for x, z in zip(self.x.shape, self.z.shape)])
|
||||
|
||||
def __add__(self, other: Vec3Array) -> Vec3Array:
|
||||
return jax.tree_multimap(lambda x, y: x + y, self, other)
|
||||
return jax.tree_map(lambda x, y: x + y, self, other)
|
||||
|
||||
def __sub__(self, other: Vec3Array) -> Vec3Array:
|
||||
return jax.tree_multimap(lambda x, y: x - y, self, other)
|
||||
return jax.tree_map(lambda x, y: x - y, self, other)
|
||||
|
||||
def __mul__(self, other: Float) -> Vec3Array:
|
||||
return jax.tree_map(lambda x: x * other, self)
|
||||
|
||||
@@ -198,8 +198,8 @@ class LayerStackTest(parameterized.TestCase):
|
||||
assert_fn = functools.partial(
|
||||
np.testing.assert_allclose, atol=1e-4, rtol=1e-4)
|
||||
|
||||
jax.tree_multimap(assert_fn, unrolled_grad,
|
||||
_slice_layers_params(layer_stack_grad))
|
||||
jax.tree_map(assert_fn, unrolled_grad,
|
||||
_slice_layers_params(layer_stack_grad))
|
||||
|
||||
def test_random(self):
|
||||
"""Random numbers should be handled correctly."""
|
||||
|
||||
@@ -125,7 +125,7 @@ def sharded_apply(
|
||||
# Expand in axes and Determine Loop range
|
||||
in_axes_ = _expand_axes(in_axes, args)
|
||||
|
||||
in_sizes = jax.tree_multimap(_maybe_get_size, args, in_axes_)
|
||||
in_sizes = jax.tree_map(_maybe_get_size, args, in_axes_)
|
||||
flat_sizes = jax.tree_flatten(in_sizes)[0]
|
||||
in_size = max(flat_sizes)
|
||||
assert all(i in {in_size, -1} for i in flat_sizes)
|
||||
@@ -137,7 +137,7 @@ def sharded_apply(
|
||||
last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
|
||||
|
||||
def apply_fun_to_slice(slice_start, slice_size):
|
||||
input_slice = jax.tree_multimap(
|
||||
input_slice = jax.tree_map(
|
||||
lambda array, axis: _maybe_slice(array, slice_start, slice_size, axis
|
||||
), args, in_axes_)
|
||||
return fun(*input_slice)
|
||||
@@ -158,11 +158,11 @@ def sharded_apply(
|
||||
shard_shape[axis] * num_extra_shards +
|
||||
remainder_shape[axis],) + shard_shape[axis + 1:]
|
||||
|
||||
out_shapes = jax.tree_multimap(make_output_shape, out_axes_, shard_shapes,
|
||||
out_shapes)
|
||||
out_shapes = jax.tree_map(make_output_shape, out_axes_, shard_shapes,
|
||||
out_shapes)
|
||||
|
||||
# Calls dynamic Update slice with different argument order
|
||||
# This is here since tree_multimap only works with positional arguments
|
||||
# This is here since tree_map only works with positional arguments
|
||||
def dynamic_update_slice_in_dim(full_array, update, axis, i):
|
||||
return jax.lax.dynamic_update_slice_in_dim(full_array, update, i, axis)
|
||||
|
||||
@@ -170,7 +170,7 @@ def sharded_apply(
|
||||
slice_out = apply_fun_to_slice(slice_start, slice_size)
|
||||
update_slice = partial(
|
||||
dynamic_update_slice_in_dim, i=slice_start)
|
||||
return jax.tree_multimap(update_slice, outputs, slice_out, out_axes_)
|
||||
return jax.tree_map(update_slice, outputs, slice_out, out_axes_)
|
||||
|
||||
def scan_iteration(outputs, i):
|
||||
new_outputs = compute_shard(outputs, i, shard_size)
|
||||
@@ -181,7 +181,7 @@ def sharded_apply(
|
||||
def allocate_buffer(dtype, shape):
|
||||
return jnp.zeros(shape, dtype=dtype)
|
||||
|
||||
outputs = jax.tree_multimap(allocate_buffer, out_dtypes, out_shapes)
|
||||
outputs = jax.tree_map(allocate_buffer, out_dtypes, out_shapes)
|
||||
|
||||
if slice_starts.shape[0] > 0:
|
||||
outputs, _ = hk.scan(scan_iteration, outputs, slice_starts)
|
||||
|
||||
Reference in New Issue
Block a user