mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Improve typing and error checks in mask_mean
PiperOrigin-RevId: 923361640 Change-Id: I708b5b8888e826f24e1ed0f1aa73a48e8174a161
This commit is contained in:
committed by
Copybara-Service
parent
eb5154f441
commit
1dc1e0ad0b
@@ -11,6 +11,7 @@
|
||||
"""Utility functions for training AlphaFold and similar models."""
|
||||
|
||||
from collections import abc
|
||||
from collections.abc import Sequence
|
||||
import contextlib
|
||||
import numbers
|
||||
|
||||
@@ -19,7 +20,6 @@ import haiku as hk
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
VALID_DTYPES = [np.float32, np.float64, np.int8, np.int32, np.int64, bool]
|
||||
|
||||
|
||||
@@ -48,23 +48,26 @@ def bfloat16_context():
|
||||
yield
|
||||
|
||||
|
||||
def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
|
||||
def mask_mean(
|
||||
mask: jnp.ndarray,
|
||||
value: jnp.ndarray,
|
||||
axis: int | Sequence[int] | None = None,
|
||||
keepdims: bool = False,
|
||||
eps: float = 1e-10,
|
||||
) -> jnp.ndarray:
|
||||
"""Masked mean."""
|
||||
|
||||
mask_shape = mask.shape
|
||||
value_shape = value.shape
|
||||
|
||||
assert len(mask_shape) == len(
|
||||
value_shape
|
||||
), 'Shapes are not compatible, shapes: {}, {}'.format(mask_shape, value_shape)
|
||||
if len(mask_shape) != len(value_shape):
|
||||
raise ValueError(f'Incompatible shapes: {mask_shape=}, {value_shape=}')
|
||||
|
||||
if isinstance(axis, numbers.Integral):
|
||||
axis = [axis]
|
||||
elif axis is None:
|
||||
axis = list(range(len(mask_shape)))
|
||||
assert isinstance(
|
||||
axis, abc.Iterable
|
||||
), 'axis needs to be either an iterable, integer or "None"'
|
||||
assert isinstance(axis, abc.Sequence)
|
||||
|
||||
broadcast_factor = 1.0
|
||||
for axis_ in axis:
|
||||
@@ -73,8 +76,8 @@ def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
|
||||
if mask_size == 1:
|
||||
broadcast_factor *= value_size
|
||||
else:
|
||||
error = f'Shapes are not compatible, shapes: {mask_shape}, {value_shape}'
|
||||
assert mask_size == value_size, error
|
||||
if mask_size != value_size:
|
||||
raise ValueError(f'Incompatible shapes: {mask_shape=}, {value_shape=}')
|
||||
|
||||
return jnp.sum(mask * value, keepdims=keepdims, axis=axis) / (
|
||||
jnp.maximum(
|
||||
|
||||
Reference in New Issue
Block a user