mirror of
https://github.com/google-deepmind/alphafold.git
synced 2026-06-04 14:58:05 +08:00
Change softmax to use where and float32.
PiperOrigin-RevId: 519675443 Change-Id: If87e6d16189ddcc03bb8435308d37f5919353107
This commit is contained in:
committed by
Copybara-Service
parent
e1d2d53af8
commit
e51692c9ae
@@ -32,6 +32,9 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
_SOFTMAX_MASK = -1e9
|
||||
|
||||
|
||||
def softmax_cross_entropy(logits, labels):
|
||||
"""Computes softmax cross entropy given logits and one-hot class labels."""
|
||||
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
|
||||
@@ -548,14 +551,14 @@ class Attention(hk.Module):
|
||||
self.global_config = global_config
|
||||
self.output_dim = output_dim
|
||||
|
||||
def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
|
||||
def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
|
||||
"""Builds Attention module.
|
||||
|
||||
Arguments:
|
||||
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
|
||||
m_data: A tensor of memories from which the keys and values are
|
||||
projected, shape [batch_size, N_keys, m_channels].
|
||||
bias: A bias for the attention, shape [batch_size, N_queries, N_keys].
|
||||
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
|
||||
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
|
||||
|
||||
Returns:
|
||||
@@ -586,10 +589,11 @@ class Attention(hk.Module):
|
||||
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
|
||||
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
|
||||
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
|
||||
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias
|
||||
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
|
||||
if nonbatched_bias is not None:
|
||||
logits += jnp.expand_dims(nonbatched_bias, axis=0)
|
||||
weights = jax.nn.softmax(logits)
|
||||
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
|
||||
weights = utils.stable_softmax(logits)
|
||||
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
|
||||
|
||||
if self.global_config.zero_init:
|
||||
@@ -686,9 +690,10 @@ class GlobalAttention(hk.Module):
|
||||
|
||||
q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
|
||||
k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
|
||||
bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
|
||||
logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
|
||||
weights = jax.nn.softmax(logits)
|
||||
bias = q_mask[:, None, :, 0]
|
||||
logits = jnp.einsum('bhc,bkc->bhk', q, k)
|
||||
logits = jnp.where(bias, logits, _SOFTMAX_MASK)
|
||||
weights = utils.stable_softmax(logits)
|
||||
weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
|
||||
|
||||
if self.global_config.zero_init:
|
||||
@@ -761,8 +766,8 @@ class MSARowAttentionWithPairBias(hk.Module):
|
||||
assert len(msa_mask.shape) == 2
|
||||
assert c.orientation == 'per_row'
|
||||
|
||||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
|
||||
assert len(bias.shape) == 4
|
||||
mask = msa_mask[:, None, None, :]
|
||||
assert len(mask.shape) == 4
|
||||
|
||||
msa_act = common_modules.LayerNorm(
|
||||
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
|
||||
@@ -788,7 +793,7 @@ class MSARowAttentionWithPairBias(hk.Module):
|
||||
msa_act = mapping.inference_subbatch(
|
||||
attn_mod,
|
||||
self.global_config.subbatch_size,
|
||||
batched_args=[msa_act, msa_act, bias],
|
||||
batched_args=[msa_act, msa_act, mask],
|
||||
nonbatched_args=[nonbatched_bias],
|
||||
low_memory=not is_training)
|
||||
|
||||
@@ -829,8 +834,8 @@ class MSAColumnAttention(hk.Module):
|
||||
msa_act = jnp.swapaxes(msa_act, -2, -3)
|
||||
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
|
||||
|
||||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
|
||||
assert len(bias.shape) == 4
|
||||
mask = msa_mask[:, None, None, :]
|
||||
assert len(mask.shape) == 4
|
||||
|
||||
msa_act = common_modules.LayerNorm(
|
||||
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
|
||||
@@ -841,7 +846,7 @@ class MSAColumnAttention(hk.Module):
|
||||
msa_act = mapping.inference_subbatch(
|
||||
attn_mod,
|
||||
self.global_config.subbatch_size,
|
||||
batched_args=[msa_act, msa_act, bias],
|
||||
batched_args=[msa_act, msa_act, mask],
|
||||
nonbatched_args=[],
|
||||
low_memory=not is_training)
|
||||
|
||||
@@ -884,9 +889,6 @@ class MSAColumnGlobalAttention(hk.Module):
|
||||
msa_act = jnp.swapaxes(msa_act, -2, -3)
|
||||
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
|
||||
|
||||
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
|
||||
assert len(bias.shape) == 4
|
||||
|
||||
msa_act = common_modules.LayerNorm(
|
||||
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
|
||||
msa_act)
|
||||
@@ -941,8 +943,8 @@ class TriangleAttention(hk.Module):
|
||||
pair_act = jnp.swapaxes(pair_act, -2, -3)
|
||||
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
|
||||
|
||||
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
|
||||
assert len(bias.shape) == 4
|
||||
mask = pair_mask[:, None, None, :]
|
||||
assert len(mask.shape) == 4
|
||||
|
||||
pair_act = common_modules.LayerNorm(
|
||||
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
|
||||
@@ -961,7 +963,7 @@ class TriangleAttention(hk.Module):
|
||||
pair_act = mapping.inference_subbatch(
|
||||
attn_mod,
|
||||
self.global_config.subbatch_size,
|
||||
batched_args=[pair_act, pair_act, bias],
|
||||
batched_args=[pair_act, pair_act, mask],
|
||||
nonbatched_args=[nonbatched_bias],
|
||||
low_memory=not is_training)
|
||||
|
||||
@@ -2171,11 +2173,11 @@ class TemplateEmbedding(hk.Module):
|
||||
jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
|
||||
[num_res * num_res, num_templates, num_channels])
|
||||
|
||||
bias = (1e9 * (template_mask[None, None, None, :] - 1.))
|
||||
mask = template_mask[None, None, None, :]
|
||||
|
||||
template_pointwise_attention_module = Attention(
|
||||
self.config.attention, self.global_config, query_num_channels)
|
||||
nonbatched_args = [bias]
|
||||
nonbatched_args = [mask]
|
||||
batched_args = [flat_query, flat_templates]
|
||||
|
||||
embedding = mapping.inference_subbatch(
|
||||
|
||||
@@ -26,6 +26,20 @@ import jax.numpy as jnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
def stable_softmax(logits: jax.Array) -> jax.Array:
  """Numerically stable softmax for (potentially) 16-bit float logits.

  float32 inputs are passed straight to `jax.nn.softmax`. 16-bit float inputs
  (bfloat16 or float16) are upcast to float32 for the softmax and the result
  is cast back to the input dtype.

  Args:
    logits: Array of logits with dtype float32, bfloat16 or float16.

  Returns:
    The result of `jax.nn.softmax(logits)`, with the same dtype as `logits`.

  Raises:
    ValueError: If `logits` has any other dtype.
  """
  in_dtype = logits.dtype

  if in_dtype == jnp.float32:
    return jax.nn.softmax(logits)

  if in_dtype == jnp.bfloat16 or in_dtype == jnp.float16:
    # Need to explicitly do softmax in float32 to avoid numerical issues
    # with large negatives. Large negatives can occur if trying to mask
    # by adding on large negative logits so that things softmax to zero.
    # NOTE(review): float16 cannot represent a mask fill value such as -1e9
    # (it overflows to -inf before reaching this function), so callers that
    # mask entire rows in float16 should confirm the result is acceptable.
    return jax.nn.softmax(logits.astype(jnp.float32)).astype(in_dtype)

  raise ValueError(f'Unexpected input dtype {logits.dtype}')
|
||||
|
||||
|
||||
def bfloat16_creator(next_creator, shape, dtype, init, context):
|
||||
"""Creates float32 variables when bfloat16 is requested."""
|
||||
if context.original_dtype == jnp.bfloat16:
|
||||
|
||||
@@ -604,7 +604,6 @@
|
||||
" pbar.set_description(f'Running {model_name}')\n",
|
||||
"\n",
|
||||
" cfg = config.model_config(model_name)\n",
|
||||
" cfg.model.global_config.bfloat16 = False\n",
|
||||
"\n",
|
||||
" if model_type_to_use == ModelType.MONOMER:\n",
|
||||
" cfg.data.eval.num_ensemble = 1\n",
|
||||
|
||||
Reference in New Issue
Block a user