Improve typing and error checks in mask_mean

PiperOrigin-RevId: 923361640
Change-Id: I708b5b8888e826f24e1ed0f1aa73a48e8174a161
This commit is contained in:
Augustin Zidek
2026-05-29 05:50:28 -07:00
committed by Copybara-Service
parent eb5154f441
commit 1dc1e0ad0b

View File

@@ -11,6 +11,7 @@
"""Utility functions for training AlphaFold and similar models."""
from collections import abc
from collections.abc import Sequence
import contextlib
import numbers
@@ -19,7 +20,6 @@ import haiku as hk
import jax.numpy as jnp
import numpy as np
VALID_DTYPES = [np.float32, np.float64, np.int8, np.int32, np.int64, bool]
@@ -48,23 +48,26 @@ def bfloat16_context():
yield
def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
def mask_mean(
mask: jnp.ndarray,
value: jnp.ndarray,
axis: int | Sequence[int] | None = None,
keepdims: bool = False,
eps: float = 1e-10,
) -> jnp.ndarray:
"""Masked mean."""
mask_shape = mask.shape
value_shape = value.shape
assert len(mask_shape) == len(
value_shape
), 'Shapes are not compatible, shapes: {}, {}'.format(mask_shape, value_shape)
if len(mask_shape) != len(value_shape):
raise ValueError(f'Incompatible shapes: {mask_shape=}, {value_shape=}')
if isinstance(axis, numbers.Integral):
axis = [axis]
elif axis is None:
axis = list(range(len(mask_shape)))
assert isinstance(
axis, abc.Iterable
), 'axis needs to be either an iterable, integer or "None"'
assert isinstance(axis, abc.Sequence)
broadcast_factor = 1.0
for axis_ in axis:
@@ -73,8 +76,8 @@ def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
if mask_size == 1:
broadcast_factor *= value_size
else:
error = f'Shapes are not compatible, shapes: {mask_shape}, {value_shape}'
assert mask_size == value_size, error
if mask_size != value_size:
raise ValueError(f'Incompatible shapes: {mask_shape=}, {value_shape=}')
return jnp.sum(mask * value, keepdims=keepdims, axis=axis) / (
jnp.maximum(