Migrate AF3 to tokamax.gated_linear_unit and tokamax.dot_product_attention.

PiperOrigin-RevId: 838262306
Change-Id: I321a78b2a7d0d5cdeabe797c59b1e9c03e33780d
This commit is contained in:
Ryan Pachauri
2025-11-29 18:54:14 -08:00
committed by Copybara-Service
parent 2e3703e82a
commit 389078218c
21 changed files with 3828 additions and 3168 deletions

View File

@@ -145,15 +145,12 @@ AlphaFold 3 uses the following separate libraries and packages:
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)
* [Haiku](https://github.com/deepmind/dm-haiku)
* [JAX](https://github.com/jax-ml/jax/)
* [jax-triton](https://github.com/jax-ml/jax-triton)
* [jaxtyping](https://github.com/patrick-kidger/jaxtyping)
* [libcifpp](https://github.com/pdb-redo/libcifpp)
* [NumPy](https://github.com/numpy/numpy)
* [pybind11](https://github.com/pybind/pybind11) and
[pybind11_abseil](https://github.com/pybind/pybind11_abseil)
* [RDKit](https://github.com/rdkit/rdkit)
* [Tree](https://github.com/deepmind/tree)
* [Triton](https://github.com/triton-lang/triton)
* [Tokamax](https://github.com/openxla/tokamax)
* [tqdm](https://github.com/tqdm/tqdm)
We thank all their contributors and maintainers!

File diff suppressed because it is too large Load Diff

View File

@@ -15,18 +15,15 @@ requires-python = ">=3.11"
readme = "README.md"
license = {file = "LICENSE"}
dependencies = [
"absl-py",
"absl-py>=2.3.1",
"dm-haiku==0.0.13",
"dm-tree",
"jax==0.4.34",
"jax[cuda12]==0.4.34",
"jax-triton==0.2.0",
"jaxtyping==0.2.34",
"jax==0.8.0",
"jax[cuda12]==0.8.0",
"numpy",
"rdkit==2024.3.5",
"triton==3.1.0",
"setuptools==78.1.0",
"tokamax==0.0.4",
"tqdm",
"typeguard==2.13.3",
"zstandard",
]

File diff suppressed because it is too large Load Diff

View File

@@ -42,7 +42,6 @@ import alphafold3.cpp
from alphafold3.data import featurisation
from alphafold3.data import pipeline
from alphafold3.data.tools import shards
from alphafold3.jax.attention import attention
from alphafold3.model import features
from alphafold3.model import model
from alphafold3.model import params
@@ -52,6 +51,7 @@ import haiku as hk
import jax
from jax import numpy as jnp
import numpy as np
import tokamax
_HOME_DIR = pathlib.Path(os.environ.get('HOME'))
@@ -373,7 +373,7 @@ _FORCE_OUTPUT_DIR = flags.DEFINE_bool(
def make_model_config(
*,
flash_attention_implementation: attention.Implementation = 'triton',
flash_attention_implementation: tokamax.DotProductAttentionImplementation = 'triton',
num_diffusion_samples: int = 5,
num_recycles: int = 10,
return_embeddings: bool = False,
@@ -935,7 +935,8 @@ def main(_):
model_runner = ModelRunner(
config=make_model_config(
flash_attention_implementation=typing.cast(
attention.Implementation, _FLASH_ATTENTION_IMPLEMENTATION.value
tokamax.DotProductAttentionImplementation,
_FLASH_ATTENTION_IMPLEMENTATION.value,
),
num_diffusion_samples=_NUM_DIFFUSION_SAMPLES.value,
num_recycles=_NUM_RECYCLES.value,

View File

@@ -1,139 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Scaled dot-product attention."""
import typing
from typing import Literal, TypeAlias
from alphafold3.jax.attention import attention_base as base
from alphafold3.jax.attention import flash_attention as attention_triton
from alphafold3.jax.attention import xla_attention
from alphafold3.jax.common import triton_utils
import jax
from jax.typing import DTypeLike # pylint: disable=g-importing-member
import jaxtyping
from jaxtyping import Array # pylint: disable=g-importing-member
from jaxtyping import Bool # pylint: disable=g-importing-member
from jaxtyping import Float # pylint: disable=g-importing-member
import typeguard
Implementation: TypeAlias = Literal["cudnn", "xla", "triton"]
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def dot_product_attention(
query: Float[Array, "*B T H D"],
key: Float[Array, "*B t #H D"],
value: Float[Array, "*B t #H D"],
*,
bias: Float[Array, "*#B #H #T #t"] | None = None,
mask: Bool[Array, "*#B #H #T #t"] | None = None,
implementation: Implementation | None = None,
logits_dtype: DTypeLike | None = None,
precision: (
jax.lax.Precision | tuple[jax.lax.Precision, jax.lax.Precision] | None
) = None,
) -> Float[Array, "*B T H D"]:
"""Performs scaled dot-product attention.
Scaled dot-product attention from "Attention is all you need"
https://arxiv.org/abs/1706.03762.
Computes self- or cross-attention. The following is computed:
softmax(qk_scale * query @ key^T + bias) @ value.
Supports both multi-head and multi-query attention
(https://arxiv.org/abs/1911.02150).
Arguments:
query: Query array of shape `[batch, seq_len_q, num_heads, head_dim]`.
key: Key array of shape `[batch, seq_len_kv, num_heads, head_dim]`.
`num_heads` can be 1 for multi-query attention.
value: Value array of shape `[batch, seq_len_kv, num_heads, head_dim]`.
`num_heads` can be 1 for multi-query attention.
bias: Optional bias array, broadcastable to shape `[batch, num_heads,
seq_len_q, seq_len_kv]`.
mask: Optional boolean mask, broadcastable to `[batch, num_heads, seq_len_q,
seq_len_kv]`. Attention weights are masked out if the corresponding mask
value is `False`.
implementation: if `None` (default), an implementation is automatically
chosen. 'xla' will use standard XLA and work on any platform, 'triton'
will use a fused Triton GPU kernel, and 'cudnn' a cuDNN FlashAttention
kernel. Only a subset of data types, shapes and GPUs are supported by
'triton' and 'cudnn', with an exception thrown in this case.
logits_dtype: Data type for attention logits (`query @ key^T`). If `None` is
passed (the default), the accumulator type from the `query @ key^T` dot
product will be used, which is FP32 for BF16/FP16/FP32 inputs. Note that
this default increases the memory usage for BF16/FP16 inputs when using
`implementation='xla'`, but does not increase memory usage when using
`implementation='triton'`.
precision: The precision for the dot products. Either `None` (default) which
uses the default JAX precision for a backend; a tuple `(
query_key_dot_precision, weights_value_dot_precision)` of
`jax.lax.Precision` objects; or a single `jax.lax.Precision` object
applied to both dot products.
Returns:
An array with the same shape as `query`.
"""
if implementation is not None:
named_args = typing.get_args(Implementation)
if implementation not in named_args:
raise ValueError(
f"Unsupported named implementation. Must be one of {named_args}."
)
if implementation == "cudnn":
if logits_dtype is not None:
raise ValueError(
"logits_dtype is not supported for cudnn implementation."
)
if precision is not None:
raise NotImplementedError(
"precision is not supported for cudnn implementation."
)
return jax.nn.dot_product_attention(
query=query,
key=key,
value=value,
bias=bias,
mask=mask,
implementation="cudnn",
)
logits_dtype = base.AUTO if logits_dtype is None else logits_dtype
precision = jax.lax.Precision.DEFAULT if precision is None else precision
args = (query, key, value)
kwargs = dict(
precision=precision,
logits_dtype=logits_dtype,
bias=bias,
mask=mask,
)
if implementation == "triton":
if not triton_utils.has_triton_support():
raise ValueError(
"implementation='triton' for FlashAttention is unsupported on this"
" GPU generation. Please use implementation='xla' instead."
)
return attention_triton.TritonFlashAttention()(*args, **kwargs)
if implementation is None and triton_utils.has_triton_support():
try:
return attention_triton.TritonFlashAttention()(*args, **kwargs)
except Exception: # pylint: disable=broad-exception-caught
pass # Fallback to XLA.
return xla_attention.XlaDotProductAttention()(*args, **kwargs)

View File

@@ -1,363 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Common types and utilities for attention kernels."""
import abc
import dataclasses
import enum
import functools
import math
from typing import Any, Self
from alphafold3.jax.common import array_view
from alphafold3.jax.common import precision as precision_lib
import jax
import jax.numpy as jnp
from jax.typing import DTypeLike # pylint: disable=g-importing-member
import jaxtyping
from jaxtyping import Array, Bool, Float, Int # pylint: disable=g-multiple-import,g-importing-member
import typeguard
class AUTO: # Used as a sentinel value.
pass
DotPrecisionLike = jax.lax.Precision | precision_lib.DotPrecision
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class Mask:
"""An attention mask.
`k_start` (inclusive) and `k_end` (exclusive) define range of enabled
k-sequence values for each row of logits.
For example, a local attention mask could be defined as follows:
```
seq_len_q = seq_len_k = 4
window_size = 2
k_start = jnp.maximum(0, jnp.arange(seq_len_q) + 1 - window_size)
mask = Mask(k_start=k_start, is_causal=True)
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
[[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1]], dtype=bool)
```
Or equivalently (but less efficiently):
```
k_end = jnp.arange(seq_len_q) + 1
k_start = jnp.maximum(0, k_end - window_size)
mask = Mask(k_start=k_start, k_end=k_end)
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
[[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 1, 1, 0],
[0, 0, 1, 1]], dtype=bool)
```
A mask for two independent causal sequences could be defined as follows:
```
k_start = jnp.array([0, 0, 2, 2])
mask = Mask(k_start=k_start, is_causal=True)
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
[[1, 0, 0, 0],
[1, 1, 0, 0],
[0, 0, 1, 0],
[0, 0, 1, 1]], dtype=bool)
```
"""
bool_mask: Bool[Array, "*#B #T #t"] | None = None
_: dataclasses.KW_ONLY
q_start: Int[Array, "*#B #t"] | None = None
q_end: Int[Array, "*#B #t"] | None = None
k_start: Int[Array, "*#B #T"] | None = None
k_end: Int[Array, "*#B #T"] | None = None
is_causal: bool = False
def tree_flatten(self):
return (
self.bool_mask,
self.q_start,
self.q_end,
self.k_start,
self.k_end,
), (self.is_causal,)
@classmethod
def tree_unflatten(cls, aux, children) -> Self:
(is_causal,) = aux
bool_mask, q_start, q_end, k_start, k_end = children
return cls(
bool_mask,
q_start=q_start,
q_end=q_end,
k_start=k_start,
k_end=k_end,
is_causal=is_causal,
)
def as_array(
self,
q_len_or_indices: int | Int[Array, "*#B T"],
k_len_or_indices: int | Int[Array, "*#B t"],
) -> Bool[Array, "*#B #T #t"] | None:
"""Returns the mask as a boolean array."""
if isinstance(q_len_or_indices, int):
q_indices = jnp.arange(q_len_or_indices)
else:
q_indices = q_len_or_indices
if isinstance(k_len_or_indices, int):
k_indices = jnp.arange(k_len_or_indices)
else:
k_indices = k_len_or_indices
q_indices = q_indices[..., None]
k_indices = k_indices[..., None, :]
mask = []
if self.bool_mask is not None:
mask.append(self.bool_mask)
# Check `bool_mask` shape is compatible with `{q,kv}_indices`.
_ = jnp.broadcast_shapes(
q_indices.shape, k_indices.shape, self.bool_mask.shape
)
if self.q_start is not None:
mask.append(q_indices >= self.q_start[..., None, :])
if self.q_end is not None:
mask.append(q_indices < self.q_end[..., None, :])
if self.k_start is not None:
mask.append(k_indices >= self.k_start[..., None])
if self.k_end is not None:
mask.append(k_indices < self.k_end[..., None])
if self.is_causal:
mask.append(q_indices >= k_indices)
logical_and = functools.partial(functools.reduce, jnp.logical_and)
return jax.lax.broadcast_to_rank(logical_and(mask), 3) if mask else None
def take(self, *attrs: str) -> tuple[Any, ...]:
"""Returns a mask with attrs removed and the removed attrs."""
default_mask = type(self)()
replacements = {attr: getattr(default_mask, attr) for attr in attrs}
values = (getattr(self, attr) for attr in attrs)
return dataclasses.replace(self, **replacements), *values
def __and__(self, other: "Bool[Array, '*#B #T #t'] | Mask") -> "Mask": # pylint: disable=g-inconsistent-quotes
"""Returns the intersection of two masks."""
if not isinstance(other, Mask):
other = Mask(other)
def combine(op):
return lambda a, b: b if a is None else a if b is None else op(a, b)
return Mask(
bool_mask=combine(jnp.logical_and)(self.bool_mask, other.bool_mask),
q_end=combine(jnp.minimum)(self.q_end, other.q_end),
k_start=combine(jnp.maximum)(self.k_start, other.k_start),
k_end=combine(jnp.minimum)(self.k_end, other.k_end),
is_causal=self.is_causal or other.is_causal,
)
CAUSAL_MASK = Mask(is_causal=True)
SoftmaxResidual = (
tuple[Float[Array, "*B H T"], Float[Array, "*B H T"]]
| Float[Array, "*B H T"]
)
@enum.unique
class SoftmaxResidualMode(enum.Enum):
"""The mode of storing softmax residuals for the backwards pass.
The stable softmax calculation performs two reductions calculating:
- the maximum input value (`x_max`),
- the sum of exponentiated values (`denom`).
We can store these values as residuals to avoid the need to recompute them
in the backwards pass.
It is also possible to combine the two residuals into a single residual,
`res = x_max + log(denom)`, as `exp(x - res) === exp(x - x_max - log(denom))
=== exp(x - x_max) / denom`. Combining the residuals reduces the memory usage
of the residuals, but will reduce the accuracy of the backwards pass if
`abs(x_max) >> log(denom)`.
"""
SEPARATE = "separate"
COMBINED = "combined"
def conform(self, aux: SoftmaxResidual) -> SoftmaxResidual | None:
match self, aux:
case None, _:
return None
case SoftmaxResidualMode.SEPARATE, (_, _):
return aux
case SoftmaxResidualMode.SEPARATE, _: # pytype: disable=redundant-match # b/300135240
raise ValueError("`aux` has been combined.")
case SoftmaxResidualMode.COMBINED, (x_max, denom):
return x_max + jnp.log(denom)
case SoftmaxResidualMode.COMBINED, _: # pytype: disable=redundant-match # b/300135240
return aux
class DotProductAttention(abc.ABC):
"""Dot product attention function."""
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def __call__(
self,
query: Float[Array | array_view.ArrayView, "*B T H D"],
key: Float[Array | array_view.ArrayView, "*B t h D"],
value: Float[Array | array_view.ArrayView, "*B t h D"],
*,
precision: (
DotPrecisionLike | tuple[DotPrecisionLike, DotPrecisionLike]
) = jax.lax.Precision.DEFAULT,
logits_dtype: DTypeLike | type[AUTO] = AUTO,
bias: Float[Array, "*#B #H #T #t"] | None = None,
mask: Bool[Array, "*#B #H #T #t"] | Mask | None = None,
q_indices: Int[Array, "*#B #H T"] | None = None,
k_indices: Int[Array, "*#B #H t"] | None = None,
) -> Float[Array, "*B T H D"]:
"""Performs scaled dot-product attention.
Scaled dot-product attention from "Attention is all you need"
https://arxiv.org/abs/1706.03762.
Computes self- or cross-attention. The following is computed:
softmax(qk_scale * query @ key^T + bias) @ value.
Supports both multi-head and multi-query attention
(https://arxiv.org/abs/1911.02150).
Arguments:
query: Query array of shape `[batch, seq_len_q, num_heads_q, head_dim]`.
It must be a multiple of num_heads_kv.
Here's an example of how q/kv heads are interleaved:
For 8 key/value heads and 4 query heads:
- key/value heads [0, 1] see query head 0
- key/value heads [2, 3] see query head 1
- key/value heads [4, 5] see query head 2
key: Key array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`. It
must be divisible by num_heads_q.
value: Value array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`.
precision: The precision for the dot products. Either a tuple `(
query_key_dot_precision, weights_value_dot_precision)` or a single
precision applied to both dot products.
logits_dtype: Data type for attention logits (`query @ key^T`). If `AUTO`
is passed (the default), the accumulator type from the `query @ key^T`
dot product will be used.
bias: Optional bias array, broadcastable to shape `[batch, num_heads,
seq_len_q, seq_len_kv]`.
mask: Optional boolean mask, broadcastable to `[batch, num_heads,
seq_len_q, seq_len_kv]`. Attention weights are masked out if the
corresponding mask value is `False`.
q_indices: Optional indices for each token in query sequence.
k_indices: Optional indices for each token in key/value sequence.
Returns:
An array with the same shape as `query`.
""" # fmt: skip
return self.fwd(
query,
key,
value,
precision=precision,
logits_dtype=logits_dtype,
bias=bias,
mask=mask,
q_indices=q_indices,
k_indices=k_indices,
)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def fwd(
self,
query: Float[Array | array_view.ArrayView, "*B T H D"],
key: Float[Array | array_view.ArrayView, "*B t h D"],
value: Float[Array | array_view.ArrayView, "*B t h D"],
*,
precision: (
DotPrecisionLike | tuple[DotPrecisionLike, DotPrecisionLike]
) = jax.lax.Precision.DEFAULT,
logits_dtype: DTypeLike | type[AUTO] = AUTO,
bias: Float[Array, "*#B #H #T #t"] | None = None,
mask: Bool[Array, "*#B #H #T #t"] | Mask | None = None,
q_indices: Int[Array, "*#B #H T"] | None = None,
k_indices: Int[Array, "*#B #H t"] | None = None,
) -> Float[Array, "*B T H D"]:
"""Performs attention."""
if not isinstance(precision, tuple):
precision = (precision, precision)
q_k_dot_precision, weights_v_dot_precision = precision
if not isinstance(q_k_dot_precision, precision_lib.DotPrecision):
q_k_dot_precision = precision_lib.get_equivalent_dot_precision(
query.dtype, key.dtype, q_k_dot_precision
)
if not isinstance(weights_v_dot_precision, precision_lib.DotPrecision):
weights_v_dot_precision = precision_lib.get_equivalent_dot_precision(
value.dtype, value.dtype, weights_v_dot_precision
)
if logits_dtype is AUTO:
logits_dtype = q_k_dot_precision.accumulator_dtype
if not isinstance(mask, Mask):
mask = Mask(mask)
return self._fwd(
array_view.as_array_view(query),
array_view.as_array_view(key),
array_view.as_array_view(value),
q_k_dot_precision=q_k_dot_precision,
logits_dtype=jnp.dtype(logits_dtype),
logits_scale=1 / math.sqrt(query.shape[-1]),
bias=bias,
mask=mask,
weights_v_dot_precision=weights_v_dot_precision,
q_indices=q_indices,
k_indices=k_indices,
)
@abc.abstractmethod
def _fwd(
self,
q: Float[array_view.ArrayView, "*B T H D"],
k: Float[array_view.ArrayView, "*B t h D"],
v: Float[array_view.ArrayView, "*B t h D"],
*,
q_k_dot_precision: precision_lib.DotPrecision,
logits_dtype: jnp.dtype,
logits_scale: float,
bias: Float[Array, "*#B #H #T #t"] | None,
mask: Mask | None,
weights_v_dot_precision: precision_lib.DotPrecision,
q_indices: Int[Array, "*#B #H T"] | None = None,
k_indices: Int[Array, "*#B #H t"] | None = None,
) -> Float[Array, "*B T H D"]:
"""Performs attention."""
...

View File

@@ -1,62 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Attention call argument specifications.
Attention argument specifications used by users of the library.
They are the most important test cases, and also cases for optimize
performance of via autotuning.
"""
from typing import Any
import jax
ShapedArray = jax.ShapeDtypeStruct
def _make_argspec(
*,
q_shape,
dtype,
k_shape=None,
v_shape=None,
bias_shape=None,
mask_shape=None,
**kwargs,
) -> dict[str, Any]:
"""Make argspec from shapes and kwargs."""
if k_shape is None:
k_shape = q_shape
if v_shape is None:
v_shape = k_shape
return dict(
query=ShapedArray(q_shape, dtype),
key=ShapedArray(k_shape, dtype),
value=ShapedArray(v_shape, dtype),
bias=ShapedArray(bias_shape, dtype) if bias_shape is not None else None,
mask=ShapedArray(mask_shape, 'bool_') if mask_shape is not None else None,
**kwargs,
)
# A subset of the full set of argument specifications. Useful for tap-tests and
# microbenchmarks.
CALL_ARG_SPECS = dict(
vanilla_f32=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='float32'),
vanilla_bf16=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='bfloat16'),
alphafold=_make_argspec(
q_shape=(384, 384, 4, 32),
bias_shape=(1, 4, 384, 384),
mask_shape=(384, 1, 1, 384),
dtype='bfloat16',
),
)

View File

@@ -1,703 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Triton FlashAttention implementation."""
import dataclasses
import functools
from alphafold3.jax.attention import attention_base as base
from alphafold3.jax.common import array_view
from alphafold3.jax.common import precision as precision_lib
from alphafold3.jax.common import triton_utils
import jax
import jax.numpy as jnp
import jax_triton as jt
import jaxtyping
from jaxtyping import Array, Bool, Float, Int # pylint: disable=g-multiple-import,g-importing-member
import triton
import triton.language as tl
import typeguard
@triton.jit
def _fwd_kernel_inner(
start_loop,
end_loop,
q,
span_q,
k_block_ptr,
v_block_ptr,
bias_block_ptr,
mask_block_ptr,
k_start,
k_end,
seq_len_k,
acc,
m_i,
l_i,
bias_advance: tl.constexpr,
mask_advance: tl.constexpr,
is_causal: tl.constexpr,
use_attention_mask: tl.constexpr,
use_k_start: tl.constexpr,
use_k_end: tl.constexpr,
use_bias: tl.constexpr,
block_k: tl.constexpr,
use_mask_k: tl.constexpr,
k_boundary_check: tl.constexpr,
v_boundary_check: tl.constexpr,
dot_fn_qk: tl.constexpr,
dot_fn_kv: tl.constexpr,
):
"""Triton MHA forward kernel's inner loop."""
for start_k in range(start_loop, end_loop, block_k):
start_k = tl.multiple_of(start_k, block_k)
span_k = start_k + tl.arange(0, block_k)
k = tl.load(
k_block_ptr,
boundary_check=k_boundary_check,
padding_option="zero" if len(k_boundary_check.value) else "",
)
v = tl.load(
v_block_ptr,
boundary_check=v_boundary_check,
padding_option="zero" if len(v_boundary_check.value) else "",
)
if use_bias:
bias = tl.load(bias_block_ptr)
qk = dot_fn_qk(q.to(k.dtype), k) # [block_q, block_k]
if use_bias:
# Prevent dot accumulating into the bias tensor. It appears that Triton
# doesn't pipeline the bias load as it does the `k` load, so the bias load
# blocks the matmul if the add is merged.
qk = qk.to(tl.uint32, bitcast=True) & 0xFFFFFFFF
qk = qk.to(tl.float32, bitcast=True)
qk += bias
if use_attention_mask | use_k_start | use_k_end:
mask_value = float(jnp.finfo(jnp.float32).min)
if use_attention_mask:
mask = tl.load(mask_block_ptr)
qk = tl.where(mask, qk, mask_value)
if use_k_start:
# This check is there to work around a triton compiler bug, but it
# shouldn't be strictly needed.
if tl.sum(k_start) != 0:
qk = tl.where(k_start[:, None] <= span_k[None, :], qk, mask_value)
if is_causal:
qk = tl.where(span_q[:, None] >= span_k[None, :], qk, float("-inf"))
elif use_k_end:
# When called with k_end and is_causal=True, the causal mask gets folded
# into k_end and is_causal is set to False.
qk = tl.where(k_end[:, None] > span_k[None, :], qk, mask_value)
if use_mask_k:
qk = tl.where((span_k < seq_len_k)[None, :], qk, float("-inf"))
m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) # Shape [block_q].
p = tl.exp(qk - m_ij[:, None]) # Shape [block_q, block_k].
alpha = tl.exp(m_i - m_ij)
m_i = m_ij
acc *= alpha[:, None]
l_i *= alpha
l_i += tl.sum(p, axis=1)
# Add the new block of attention weights.
acc += dot_fn_kv(p.to(v.dtype), v)
k_block_ptr = tl.advance(k_block_ptr, (0, block_k))
v_block_ptr = tl.advance(v_block_ptr, (block_k, 0))
bias_block_ptr = tl.advance(bias_block_ptr, bias_advance.value)
mask_block_ptr = tl.advance(mask_block_ptr, mask_advance.value)
return (
k_block_ptr,
v_block_ptr,
bias_block_ptr,
mask_block_ptr,
acc,
m_i,
l_i,
)
# Based on Algorithm 1 of https://arxiv.org/abs/2205.14135.
# Inspired by the official Triton tutorial implementation
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
@triton.jit
def _fwd_kernel(
# Input arrays.
q_ptr,
k_ptr,
v_ptr,
bias_ptr,
mask_ptr,
k_start_ptr,
k_end_ptr,
# Scalar inputs.
q_offset,
k_offset,
v_offset,
q_stride_b,
q_stride_s,
q_stride_h,
q_stride_d,
k_stride_b,
k_stride_s,
k_stride_h,
k_stride_d,
v_stride_b,
v_stride_s,
v_stride_h,
v_stride_d,
bias_stride_b,
bias_stride_h,
bias_stride_sq,
bias_stride_sk,
mask_stride_b,
mask_stride_h,
mask_stride_sq,
mask_stride_sk,
k_start_stride_b,
k_start_stride_h,
k_start_stride_sq,
k_end_stride_b,
k_end_stride_h,
k_end_stride_sq,
o_stride_b,
o_stride_s,
o_stride_h,
o_stride_d,
num_heads_q,
num_heads_k,
seq_len_q,
seq_len_k,
# Output arrays.
o_ptr,
# Compile-time constants.
is_causal: tl.constexpr,
use_attention_mask: tl.constexpr,
use_k_start: tl.constexpr,
use_k_end: tl.constexpr,
use_bias: tl.constexpr,
sm_scale: tl.constexpr,
block_q: tl.constexpr,
block_k: tl.constexpr,
head_dim: tl.constexpr,
use_mask_q: tl.constexpr,
use_mask_k: tl.constexpr,
bias_bcast_sq: tl.constexpr,
mask_bcast_sq: tl.constexpr,
dot_fn_qk: tl.constexpr,
dot_fn_kv: tl.constexpr,
):
"""Triton MHA forward kernel."""
# pytype: disable=annotation-type-mismatch,unsupported-operands
block_d: tl.constexpr = jt.utils.next_power_of_2(head_dim.value)
# Each thread block processes one batch element (b) and one head (h).
start_q = tl.program_id(1) * block_q
off_h = tl.program_id(0) # int in [0, num_heads_o).
off_b = tl.program_id(2) # int in [0, batch_size)
off_h_k = off_h // (num_heads_q // num_heads_k)
q_ptr += off_h * q_stride_h + off_b * q_stride_b + q_offset
k_ptr += off_h_k * k_stride_h + off_b * k_stride_b + k_offset
v_ptr += off_h_k * v_stride_h + off_b * v_stride_b + v_offset
o_ptr += off_h * o_stride_h + off_b * o_stride_b
if use_bias:
bias_ptr += off_b * bias_stride_b + off_h * bias_stride_h
if use_attention_mask:
mask_ptr += off_b * mask_stride_b + off_h * mask_stride_h
if use_k_start:
k_start_ptr += off_b * k_start_stride_b + off_h * k_start_stride_h
if use_k_end:
k_end_ptr += off_b * k_end_stride_b + off_h * k_end_stride_h
q_block_ptr = tl.make_block_ptr(
q_ptr,
shape=(seq_len_q, head_dim),
strides=(q_stride_s, q_stride_d),
offsets=(start_q, 0),
block_shape=(block_q, block_d),
order=(1, 0),
)
k_block_ptr = tl.make_block_ptr(
k_ptr,
shape=(head_dim, seq_len_k),
strides=(k_stride_d, k_stride_s),
offsets=(0, 0),
block_shape=(block_d, block_k),
order=(0, 1),
)
v_block_ptr = tl.make_block_ptr(
v_ptr,
shape=(seq_len_k, head_dim),
strides=(v_stride_s, v_stride_d),
offsets=(0, 0),
block_shape=(block_k, block_d),
order=(1, 0),
)
q_boundary_check0: tl.constexpr = (0,) if use_mask_q else ()
q_boundary_check1: tl.constexpr = (1,) if head_dim != block_d else ()
q_boundary_check: tl.constexpr = q_boundary_check0 + q_boundary_check1
q_padding_option: tl.constexpr = "zero" if len(q_boundary_check.value) else ""
k_boundary_check: tl.constexpr = (0,) if head_dim != block_d else ()
v_boundary_check: tl.constexpr = (0,) if use_mask_k else ()
# If broadcasting in a given dim, use a 1D block (observed to be faster).
bias_start_dim: tl.constexpr = 1 if bias_bcast_sq else 0
bias_block_ptr = tl.make_block_ptr(
bias_ptr,
shape=(seq_len_q, seq_len_k)[bias_start_dim:],
strides=(bias_stride_sq, bias_stride_sk)[bias_start_dim:],
offsets=(start_q, 0)[bias_start_dim:],
block_shape=(block_q, block_k)[bias_start_dim:],
order=(1, 0)[bias_start_dim:],
)
bias_advance: tl.constexpr = (0, block_k)[bias_start_dim:]
mask_start_dim: tl.constexpr = 1 if mask_bcast_sq else 0
mask_block_ptr = tl.make_block_ptr(
mask_ptr,
shape=(seq_len_q, seq_len_k)[mask_start_dim:],
strides=(mask_stride_sq, mask_stride_sk)[mask_start_dim:],
offsets=(start_q, 0)[mask_start_dim:],
block_shape=(block_q, block_k)[mask_start_dim:],
order=(1, 0)[mask_start_dim:],
)
mask_advance: tl.constexpr = (0, block_k)[mask_start_dim:]
k_start_block_ptr = tl.make_block_ptr(
k_start_ptr,
shape=(seq_len_q,),
strides=(k_start_stride_sq,),
offsets=(start_q,),
block_shape=(block_q,),
order=(0,),
)
k_end_block_ptr = tl.make_block_ptr(
k_end_ptr,
shape=(seq_len_q,),
strides=(k_end_stride_sq,),
offsets=(start_q,),
block_shape=(block_q,),
order=(0,),
)
# pytype: enable=annotation-type-mismatch,unsupported-operands
# Each thread block processes a block of block_q queries.
span_q = start_q + tl.arange(0, block_q)
# m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
m_i = tl.full([block_q], float("-inf"), dtype=tl.float32)
l_i = tl.zeros([block_q], dtype=tl.float32)
# acc is the buffer where we accumulate the output on sram.
acc = tl.zeros([block_q, block_d], dtype=tl.float32)
# Load q: it will stay in smem throughout. Indices form a matrix because we
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
q = tl.load(
q_block_ptr,
boundary_check=q_boundary_check,
padding_option=q_padding_option,
)
q *= sm_scale
# In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
# (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
# Here we only loop over blocks of kv to process entire seq_len, the loop over
# blocks of q is carried out by the grid.
k_start = None
if use_k_start:
k_start = tl.load(k_start_block_ptr)
start_loop = tl.maximum(tl.min(k_start), 0)
blocks_to_skip = start_loop // block_k
start_loop = block_k * blocks_to_skip # Floor to multiple of block_k.
for _ in range(blocks_to_skip):
# Advance all block pointers to the first valid block.
k_block_ptr = tl.advance(k_block_ptr, (0, block_k))
v_block_ptr = tl.advance(v_block_ptr, (block_k, 0))
bias_block_ptr = tl.advance(bias_block_ptr, bias_advance.value)
mask_block_ptr = tl.advance(mask_block_ptr, mask_advance.value)
else:
start_loop = 0
k_end = None
if is_causal:
end_loop = tl.minimum((start_q // block_k) * block_k, seq_len_k)
elif use_k_end:
k_end = tl.load(k_end_block_ptr)
end_loop = tl.minimum(tl.max(k_end), seq_len_k)
else:
end_loop = seq_len_k
(
k_block_ptr,
v_block_ptr,
bias_block_ptr,
mask_block_ptr,
acc,
m_i,
l_i,
) = _fwd_kernel_inner(
start_loop,
end_loop,
q,
span_q,
k_block_ptr,
v_block_ptr,
bias_block_ptr,
mask_block_ptr,
k_start,
k_end,
seq_len_k,
acc,
m_i,
l_i,
bias_advance,
mask_advance,
False, # is_causal
use_attention_mask,
use_k_start,
use_k_end,
use_bias,
block_k,
use_mask_k,
k_boundary_check,
v_boundary_check,
dot_fn_qk,
dot_fn_kv,
)
if is_causal:
tl.debug_barrier() # Help compiler schedule loops independently.
start_loop, end_loop = end_loop, tl.minimum(end_loop + block_k, seq_len_k)
_, _, _, _, acc, _, l_i = _fwd_kernel_inner(
start_loop,
end_loop,
q,
span_q,
k_block_ptr,
v_block_ptr,
bias_block_ptr,
mask_block_ptr,
k_start,
k_end,
seq_len_k,
acc,
m_i,
l_i,
bias_advance,
mask_advance,
True, # is_causal
use_attention_mask,
use_k_start,
use_k_end,
use_bias,
block_k,
use_mask_k,
k_boundary_check,
v_boundary_check,
dot_fn_qk,
dot_fn_kv,
)
# It is possible that every value in a row was masked to f32 min or that the
# main loop has been completely optimised out, and that `l_i` is `0` for that
# row. Add epsilon value to avoid NaNs from `0 / 0`.
l_i += float(jnp.finfo(jnp.float32).tiny)
acc /= l_i[:, None]
# Write output to dram.
o_block_ptr = tl.make_block_ptr(
o_ptr,
shape=(seq_len_q, head_dim),
strides=(o_stride_s, o_stride_d),
offsets=(start_q, 0),
block_shape=(block_q, block_d),
order=(1, 0),
)
acc = acc.to(o_ptr.dtype.element_ty)
tl.store(o_block_ptr, acc, boundary_check=q_boundary_check)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def _fwd(
q: Float[array_view.ArrayView, "*B T H D"],
k: Float[array_view.ArrayView, "*B t h D"],
v: Float[array_view.ArrayView, "*B t h D"],
bias: Float[Array, "*#B #H #T #t"] | None,
mask: Bool[Array, "*#B #H #T #t"] | None,
k_start: Int[Array, "*#B #H #T"] | None,
k_end: Int[Array, "*#B #H #T"] | None,
*,
logits_scale: float,
is_causal: bool,
q_k_dot_precision: precision_lib.DotPrecision,
weights_v_dot_precision: precision_lib.DotPrecision,
) -> Float[Array, "*B T H D"]:
"""Forward pass of Triton FlashAttention."""
orig_q_shape = q.shape
q = q.collapse(0, -3, allow_copy=True)
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
*_, seq_len_k, num_heads_kv, _ = k.shape
# Maybe broadcast `k`/`v` heads dimension.
kv_shape = (batch_size, seq_len_k, num_heads_kv, head_dim)
k = k.collapse(0, -3, allow_copy=True).broadcast_to(kv_shape)
v = v.collapse(0, -3, allow_copy=True).broadcast_to(kv_shape)
def get_bias_mask_view(x, dtype):
if x is None:
x = jnp.array([], dtype=dtype)
return array_view.ArrayView(x, shape=(0, 0, 0, 0), strides=(0, 0, 0, 0))
shape = orig_q_shape[:-3] + (num_heads_q, seq_len_q, seq_len_k)
return (
array_view.ArrayView(x)
.broadcast_to(shape)
.collapse(0, -3, allow_copy=True)
)
bias = get_bias_mask_view(bias, dtype=q.dtype)
mask = get_bias_mask_view(mask, dtype=jnp.bool_)
def get_range_view(x, seq_len):
if x is None:
x = jnp.array([], dtype=jnp.int32)
return array_view.ArrayView(x, shape=(0, 0, 0), strides=(0, 0, 0))
shape = orig_q_shape[:-3] + (num_heads_q, seq_len)
return (
array_view.ArrayView(x)
.broadcast_to(shape)
.collapse(0, -2, allow_copy=True)
)
k_start = get_range_view(k_start, seq_len_q)
k_end = get_range_view(k_end, seq_len_q)
block_q = 64
block_k = 64
return jt.triton_call(
q.base,
k.base,
v.base,
bias.base,
mask.base,
k_start.base,
k_end.base,
q.offset,
k.offset,
v.offset,
*q.strides,
*k.strides,
*v.strides,
*bias.strides,
*mask.strides,
k_start.strides,
k_end.strides,
*jt.utils.strides_from_shape(q.shape), # out strides.
num_heads_q,
num_heads_kv,
seq_len_q,
seq_len_k,
kernel=_fwd_kernel,
name="triton_flash_attention",
out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype),
grid=(num_heads_q, triton.cdiv(seq_len_q, block_q), batch_size),
num_stages=2,
num_warps=4,
is_causal=is_causal,
use_attention_mask=(mask.size != 0),
use_k_start=(k_start.size != 0),
use_k_end=(k_end.size != 0),
use_bias=(bias.size != 0),
sm_scale=logits_scale,
block_q=block_q,
block_k=block_k,
head_dim=head_dim,
use_mask_q=(seq_len_q % block_q != 0),
use_mask_k=(seq_len_k % block_q != 0),
bias_bcast_sq=(bias.strides[-2] == 0),
mask_bcast_sq=(mask.strides[-2] == 0),
dot_fn_qk=triton_utils.get_tl_dot_fn(q_k_dot_precision),
dot_fn_kv=triton_utils.get_tl_dot_fn(weights_v_dot_precision),
).reshape(orig_q_shape)
def _as_batched_array_view(x, axis_size):
batched_shape = (axis_size,) + x.shape
batched_strides = (x.base.size // axis_size,) + x.strides
return dataclasses.replace(x, shape=batched_shape, strides=batched_strides)
def _fwd_vmap_rule(
axis_size, in_batched, *args, fn: jax.custom_batching.custom_vmap
):
"""`vmap` rule for Triton FlashAttention forward op."""
q, k, v, bias, mask, k_start, k_end = args
(
q_batched,
k_batched,
v_batched,
bias_batched,
mask_batched,
k_start_batched,
k_end_batched,
) = in_batched
if q_batched.base:
q = _as_batched_array_view(q, axis_size)
if k_batched.base:
k = _as_batched_array_view(k, axis_size)
if v_batched.base:
v = _as_batched_array_view(v, axis_size)
# Triton op requires `q`, `k`, `v` batch dims to be identical.
if q_batched.base and k_batched.base and v_batched.base:
if bias is not None and not bias_batched:
bias = jax.lax.broadcast_to_rank(bias, bias.ndim + 1)
if mask is not None and not mask_batched:
mask = jax.lax.broadcast_to_rank(mask, mask.ndim + 1)
if k_start is not None and not k_start_batched:
k_start = jax.lax.broadcast_to_rank(k_start, k_start.ndim + 1)
if k_end is not None and not k_end_batched:
k_end = jax.lax.broadcast_to_rank(k_end, k_end.ndim + 1)
out = fn(q, k, v, bias, mask, k_start, k_end)
out_batched = True
return out, out_batched
# Fallback to sequential loop.
q, k, v = map(jnp.asarray, (q, k, v))
in_batched = [
q_batched.base,
k_batched.base,
v_batched.base,
bias_batched,
mask_batched,
k_start_batched,
k_end_batched,
]
def f(q, k, v, *args, **kwargs):
q, k, v = map(array_view.ArrayView, (q, k, v))
return fn.fun(q, k, v, *args, **kwargs)
sequential_vmap = jax.custom_batching.sequential_vmap(f)
return sequential_vmap.vmap_rule(axis_size, in_batched, q, k, v, *args[3:])
def _decompose_mask(mask, q, k, q_indices, k_indices):
"""Decomposes `mask` into a mask array, `is_causal`, `k_start` and `k_end`."""
if mask is None:
return None, False, None, None
is_causal = False
k_start = None
k_end = None
if q_indices is None and k_indices is None:
mask, is_causal, k_start, k_end = mask.take("is_causal", "k_start", "k_end")
if k_start is not None:
k_start = jax.lax.broadcast_to_rank(k_start, 2)
if k_end is not None:
k_end = jax.lax.broadcast_to_rank(k_end, 2)
if is_causal: # Fold is_causal into k_end
k_end = jnp.minimum(k_end, jnp.arange(1, q.shape[-3] + 1))
is_causal = False
q_len_or_indices = q.shape[-3] if q_indices is None else q_indices
k_len_or_indices = k.shape[-3] if k_indices is None else k_indices
return (
mask.as_array(q_len_or_indices, k_len_or_indices),
is_causal,
k_start,
k_end,
)
@dataclasses.dataclass(frozen=True)
class TritonFlashAttention(base.DotProductAttention):
"""Triton FlashAttention implementation."""
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def _fwd(
self,
q: Float[array_view.ArrayView, "*B T H D"],
k: Float[array_view.ArrayView, "*B t h D"],
v: Float[array_view.ArrayView, "*B t h D"],
bias: Float[Array, "*#B #H #T #t"] | None,
*,
q_k_dot_precision: precision_lib.DotPrecision,
logits_dtype: jnp.dtype,
logits_scale: float,
mask: base.Mask | None,
weights_v_dot_precision: precision_lib.DotPrecision,
q_indices: Int[Array, "*#B #H T"] | None = None,
k_indices: Int[Array, "*#B #H t"] | None = None,
) -> Float[Array, "*B T H D"]:
if logits_dtype != jnp.float32:
raise ValueError("`logits_dtype` must be float32.")
kwargs = dict(
logits_scale=logits_scale,
q_k_dot_precision=q_k_dot_precision,
weights_v_dot_precision=weights_v_dot_precision,
)
def attend_fwd(
q,
k,
v,
bias,
mask_,
q_indices,
k_indices,
):
mask, is_causal, k_start, k_end = _decompose_mask(
mask_, q, k, q_indices, k_indices
)
fwd_closed_kwargs = dict(
is_causal=is_causal,
**kwargs,
)
fwd_closed = functools.partial(_fwd, **fwd_closed_kwargs)
fwd_closed = jax.custom_batching.custom_vmap(fwd_closed)
fwd_closed.def_vmap(functools.partial(_fwd_vmap_rule, fn=fwd_closed))
return fwd_closed(q, k, v, bias, mask, k_start, k_end)
return attend_fwd(
q,
k,
v,
bias,
mask,
q_indices,
k_indices,
)

View File

@@ -1,140 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""XLA implementation of scaled dot-product attention."""
import dataclasses
from alphafold3.jax.attention import attention_base as base
from alphafold3.jax.common import array_view
from alphafold3.jax.common import precision as precision_lib
import jax
import jax.numpy as jnp
import jaxtyping
from jaxtyping import Array, Float, Int # pylint: disable=g-multiple-import,g-importing-member
import typeguard
def _get_precision(
backend: str, precision: precision_lib.DotPrecision
) -> jax.lax.Precision:
if backend == "gpu" and precision == precision_lib.DotPrecision.F32_F32:
return jax.lax.Precision.HIGHEST
return jax.lax.Precision.DEFAULT
def einsum_with_dot_precision(
subscript: str,
a: jax.Array,
b: jax.Array,
*,
precision: precision_lib.DotPrecision,
) -> jax.Array:
"""Evaluate `fn` with the given precision."""
result = jnp.einsum(
subscript,
a.astype(precision.operand_dtype),
b.astype(precision.operand_dtype),
precision=_get_precision(jax.default_backend().lower(), precision),
preferred_element_type=precision.accumulator_dtype,
)
assert result.dtype == precision.accumulator_dtype
return result
def _softmax(x: jax.Array) -> jax.Array:
"""Computes softmax."""
# Always perform reductions in at least f32 precision.
dtype = jnp.promote_types(x.dtype, jnp.float32)
x_max = jnp.max(x.astype(dtype), axis=-1, keepdims=True)
unnormalized = jnp.exp(x - x_max)
denom = jnp.sum(unnormalized, axis=-1, keepdims=True)
return unnormalized / denom
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def _attend(
q: Float[array_view.ArrayView, "*B T H D"],
k: Float[array_view.ArrayView, "*B t #H D"],
v: Float[array_view.ArrayView, "*B t #H D"],
*,
q_k_dot_precision: precision_lib.DotPrecision,
logits_dtype: jnp.dtype,
logits_scale: float,
bias: Float[Array, "*#B #H #T #t"] | None,
mask: base.Mask | None,
weights_v_dot_precision: precision_lib.DotPrecision,
q_indices: Int[Array, "*#B #H T"] | None,
k_indices: Int[Array, "*#B #H t"] | None,
) -> Float[Array, "*B T H D"]:
"""Computes attention."""
logits = einsum_with_dot_precision(
"...qhd,...khd->...hqk", q, k, precision=q_k_dot_precision
).astype(logits_dtype)
logits *= logits_scale
if bias is not None:
logits += bias
if mask is not None:
q_len_or_indices = q.shape[-3] if q_indices is None else q_indices
k_len_or_indices = k.shape[-3] if k_indices is None else k_indices
mask = mask.as_array(q_len_or_indices, k_len_or_indices)
if mask is not None:
mask_value = float(jnp.finfo(logits.dtype).min)
logits = jnp.where(jnp.asarray(mask), logits, mask_value)
weights = _softmax(logits)
weights = weights.astype(v.dtype)
out = einsum_with_dot_precision(
"...hqk,...khd->...qhd", weights, v, precision=weights_v_dot_precision
).astype(q.dtype)
return out
@dataclasses.dataclass(frozen=True)
class XlaDotProductAttention(base.DotProductAttention):
"""XLA dot product attention function."""
_: dataclasses.KW_ONLY
def _fwd(
self,
q: Float[array_view.ArrayView, "*B T H D"],
k: Float[array_view.ArrayView, "*B t #H D"],
v: Float[array_view.ArrayView, "*B t #H D"],
*,
q_k_dot_precision: precision_lib.DotPrecision,
logits_dtype: jnp.dtype,
logits_scale: float,
bias: Float[Array, "*#B #H #T #t"] | None,
mask: base.Mask | None,
weights_v_dot_precision: precision_lib.DotPrecision,
q_indices: Int[Array, "*#B #H T"] | None = None,
k_indices: Int[Array, "*#B #H t"] | None = None,
) -> Float[Array, "*B T H D"]:
return _attend(
q,
k,
v,
bias=bias,
mask=mask,
q_indices=q_indices,
k_indices=k_indices,
q_k_dot_precision=q_k_dot_precision,
logits_dtype=logits_dtype,
logits_scale=logits_scale,
weights_v_dot_precision=weights_v_dot_precision,
)

View File

@@ -1,405 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Array view class and utilities."""
from collections.abc import Sequence
import dataclasses
import math
import operator
from types import EllipsisType # pylint: disable=g-importing-member
from typing import Any, Self, TypeAlias, TypeVar
import jax
import jax.experimental
from jax.experimental import pallas as pl
import jax.numpy as jnp
from jax.typing import ArrayLike # pylint: disable=g-importing-member
from jaxtyping import Int # pylint: disable=g-importing-member
import numpy as np
ArrayT: TypeAlias = Any
ScalarInt: TypeAlias = (
Int[ArrayT, ""] | Int[np.generic, ""] | Int[jnp.generic, ""]
)
Indexer: TypeAlias = int | ScalarInt | slice | pl.Slice | EllipsisType
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class ArrayView:
"""A strided view of a JAX array."""
base: jax.Array
_: dataclasses.KW_ONLY
# These are set by `__post_init__` so `None` value is never seen after init.
shape: tuple[int, ...] = None # type: ignore
strides: tuple[int, ...] = None # type: ignore
offset: int | ScalarInt = 0
flatten_base: bool = True
def __post_init__(self):
if self.shape is None:
object.__setattr__(self, "shape", self.base.shape)
if self.strides is None:
object.__setattr__(self, "strides", pl.strides_from_shape(self.shape))
if len(self.shape) != len(self.strides):
raise ValueError("`shape` and `strides` must have the same length.")
# Within `jax.vjp`, we can get non-`Array` values here (such as `object`).
if isinstance(self.base, jax.Array):
if isinstance(self.offset, int):
if not (0 <= self.offset < max(self.base.size, 1)):
raise ValueError("Invalid `offset`.")
if self.flatten_base:
if len(self.base.shape) != 1:
object.__setattr__(self, "base", self.base.reshape((-1,)))
def tree_flatten(self):
if isinstance(self.offset, int):
return (self.base,), (self.offset, self.shape, self.strides)
return (self.base, self.offset), (self.shape, self.strides)
@classmethod
def tree_unflatten(cls, aux, children) -> Self:
base, offset, shape, strides = (*children, *aux)
return cls(base, shape=shape, strides=strides, offset=offset)
@property
def dtype(self) -> jnp.dtype:
return self.base.dtype
@property
def size(self) -> int:
return math.prod(self.shape)
@property
def ndim(self) -> int:
return len(self.shape)
@property
def T(self) -> Self: # pylint: disable=invalid-name
return self.transpose()
@property
def _index_dtype(self) -> jax.typing.DTypeLike:
i32_max = jnp.iinfo(jnp.int32).max
return jnp.int32 if (self.base.size <= i32_max) else jnp.int64
@property
def offsets(self) -> jax.Array:
"""Returns array of offsets into `base` for each element."""
with jax.experimental.enable_x64():
idxs = jnp.indices(self.shape, sparse=True, dtype=self._index_dtype)
return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
def astype(self, dtype: jax.typing.DTypeLike) -> Self:
return self._replace(base=self.base.astype(dtype))
def broadcast_to_rank(self, rank: int) -> Self:
"""Returns a new view with the specified rank."""
if rank < self.ndim:
raise ValueError(f"Cannot broadcast to lower rank: {rank} < {self.ndim}.")
shape = (1,) * (rank - self.ndim) + self.shape
strides = (0,) * (rank - self.ndim) + self.strides
return self._replace(shape=shape, strides=strides)
def broadcast_to(self, shape: tuple[int, ...]) -> Self:
"""Returns a new view with the specified shape."""
view = self.broadcast_to_rank(len(shape))
strides = []
for dim_size, stride, target_size in zip(
view.shape, view.strides, shape, strict=True
):
if dim_size == target_size:
strides.append(stride)
elif dim_size == 1:
strides.append(0)
else:
raise ValueError(f"Cannot broadcast {self.shape} to {shape}.")
return self._replace(shape=shape, strides=strides)
def collapse(
self, start: int, stop: int | None = None, *, allow_copy: bool = False
) -> Self:
"""Returns a new view with the axis range collapsed into one axis."""
lo, hi, _ = slice(start, stop).indices(self.ndim)
if hi < lo:
raise ValueError(
"Invalid dimension range passed to collapse: "
f"{self.shape} [{start}:{stop}]"
)
shape = self.shape[:lo] + (-1,) + self.shape[hi:]
return self.reshape(shape, allow_copy=allow_copy)
def reshape(self, shape: Sequence[int], *, allow_copy: bool = False) -> Self:
"""Returns a new view with the specified shape."""
try:
return self._reshape(tuple(shape))
except ValueError:
if not allow_copy:
raise
return type(self)(jnp.array(self)).reshape(shape)
def _reshape(self, shape: tuple[int, ...]) -> Self:
"""Returns a new view with the specified shape."""
if (num_minus_one_dims := shape.count(-1)) > 0:
if num_minus_one_dims > 1:
raise ValueError("`shape` may only contain a single `-1` dimension.")
pos = shape.index(-1)
shape = list(shape)
shape[pos] = self.size // math.prod(d for d in shape if d != -1)
if math.prod(shape) != self.size:
raise ValueError("Mismatched number of elements.")
# Logic copied from `numpy` C++ code.
# Remove axes with length 1, to simplify logic below.
old_shape = [d for d in self.shape if d != 1]
old_strides = [s for i, s in enumerate(self.strides) if self.shape[i] != 1]
strides = [0] * len(shape)
# Axes currently being worked upon.
old_start, old_stop = 0, 1
new_start, new_stop = 0, 1
while (old_start < len(old_shape)) and (new_start < len(shape)):
old_axes_prod = old_shape[old_start]
new_axes_prod = shape[new_start]
while old_axes_prod != new_axes_prod:
if old_axes_prod < new_axes_prod:
old_axes_prod *= old_shape[old_stop]
old_stop += 1
else:
new_axes_prod *= shape[new_stop]
new_stop += 1
# Check if original axes can be combined.
for i in range(old_start, old_stop - 1):
if old_strides[i] != old_shape[i + 1] * old_strides[i + 1]:
raise ValueError("Cannot combine axes non-contiguous in memory.")
# Calculate new strides.
strides[new_stop - 1] = old_strides[old_stop - 1]
for i in range(new_stop - 1, new_start, -1):
strides[i - 1] = strides[i] * shape[i]
old_start, old_stop = old_stop, old_stop + 1
new_start, new_stop = new_stop, new_stop + 1
return self._replace(shape=shape, strides=strides)
def split(
self, indices_or_sections: int | Sequence[int], axis: int = 0
) -> tuple[Self, ...]:
"""Splits the view into multiple slice views."""
if isinstance(indices_or_sections, int):
if self.shape[axis] % indices_or_sections != 0:
raise ValueError("Axis size is not divisible by number of sections.")
chunk = self.shape[axis] // indices_or_sections
indices_or_sections = [i * chunk for i in range(1, indices_or_sections)]
los = (0, *indices_or_sections)
his = (*indices_or_sections, None)
slice_prefix = (slice(None),) * _canonicalize_axis(axis, self.ndim)
return tuple(self[*slice_prefix, slice(lo, hi)] for lo, hi in zip(los, his))
def swapaxes(self, axis1: int, axis2: int) -> Self:
"""Returns a new view with the specified axis swapped."""
axes = list(range(self.ndim))
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
return self.transpose(axes)
def moveaxis(self, source: int, destination: int) -> Self:
"""Returns a new view with the specified axis moved."""
source, destination = source % self.ndim, destination % self.ndim
axes = list(range(self.ndim))
del axes[source]
axes.insert(destination, source)
return self.transpose(axes)
def transpose(self, axes: Sequence[int] | None = None) -> Self:
"""Returns a new view with the specified axes order."""
if axes is None:
axes = tuple(reversed(range(self.ndim)))
if len(axes) != self.ndim:
raise ValueError("`axes` must have the same dimensionality as the array.")
shape = tuple(self.shape[a] for a in axes)
strides = tuple(self.strides[a] for a in axes)
return self._replace(shape=shape, strides=strides)
def __getitem__(self, idxs: Indexer | tuple[Indexer, ...]) -> Self:
if not isinstance(idxs, tuple):
idxs = (idxs,)
if len(idxs) > self.ndim:
raise ValueError("Too many slice indices.")
num_ellipses = idxs.count(Ellipsis)
if num_ellipses > 1:
raise ValueError("Multiple `...` are not supported.")
elif num_ellipses == 0:
idxs += (Ellipsis,) # `[a:b]` is equivalent to `[a:b, ...]`.
# Replace `...` with slices that take the entirety of the missing axes.
ellipsis_idx = idxs.index(Ellipsis)
ellipsis_slices = (slice(None),) * (self.ndim - len(idxs) + 1)
idxs = idxs[:ellipsis_idx] + ellipsis_slices + idxs[ellipsis_idx + 1 :]
shape = []
strides = []
with jax.experimental.enable_x64():
def as_index(x):
return x.astype(self._index_dtype) if isinstance(x, jax.Array) else x
offset = as_index(self.offset)
for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
if isinstance(idx, int):
if not (-dim <= idx < dim):
raise ValueError("Slice index out of range.")
offset += stride * (idx % dim)
elif isinstance(idx, ScalarInt):
offset += stride * as_index(idx)
elif isinstance(idx, slice):
start, stop, step = idx.indices(dim)
if step >= 0:
shape.append(pl.cdiv(stop - start, step))
else:
shape.append(pl.cdiv(start - stop, -step))
strides.append(stride * step)
offset += stride * start
elif isinstance(idx, pl.Slice):
shape.append(idx.size)
strides.append(stride * idx.stride)
offset += stride * as_index(idx.start)
else:
raise ValueError(f"Unexpected indexer: {idx}")
return self._replace(shape=shape, strides=strides, offset=offset)
def _replace(self, **kwargs) -> Self:
if "shape" in kwargs:
kwargs["shape"] = tuple(kwargs["shape"])
if "strides" in kwargs:
kwargs["strides"] = tuple(kwargs["strides"])
return dataclasses.replace(self, **kwargs)
def set(self, value: ArrayLike | "ArrayView") -> Self:
"""Returns a new view with the views values set to `value`."""
if any(s == 0 for s in self.strides):
raise ValueError("Cannot set values on a broadcasted array.")
# Try to just transpose the value, if possible.
major_to_minor = np.argsort(-np.array(self.strides), kind="stable")
value = jnp.array(value)
value_transposed = value.transpose(major_to_minor)
if (
self.transpose(major_to_minor).strides
== ArrayView(value_transposed).strides
):
base = jax.lax.dynamic_update_slice(
self.base, value_transposed.flatten(), (self.offset,)
)
else:
base = self.base.at[self.offsets].set(value)
return self._replace(base=base)
def __jax_array__(self) -> jax.Array:
"""Returns values as a dense array."""
# Try to express using transpose, slice, and reshape, to encourage XLA to
# fuse into other ops, rather than materialising the values. Otherwise,
# fall back to using a gather.
if (self.ndim == 0) or any(s < 0 for s in self.strides):
return self.base[self.offsets]
major_to_minor = np.argsort(-np.array(self.strides), kind="stable")
# Construct a shape that gives us the correct strides.
bcast_axes = []
shape = []
for axis in major_to_minor[::-1]: # minor to major
stride = self.strides[axis]
if stride == 0:
bcast_axes.append(axis)
shape.append(1)
continue
if stride % math.prod(shape) != 0:
raise ValueError("Cannot express as a reshape, then slice.")
shape.append(stride // math.prod(shape))
if self.base.size % math.prod(shape) != 0:
return self.base[self.offsets]
shape = [self.base.size // math.prod(shape), *reversed(shape)]
slice_sizes = [
*(1 if a in bcast_axes else self.shape[a] for a in major_to_minor),
1,
]
if shape[0] == self.shape[major_to_minor[0]]:
needs_offset_slice = False
elif not isinstance(self.offset, int):
needs_offset_slice = True
else:
start_indices = np.unravel_index(self.offset, shape)
end_indices = [s + size for s, size in zip(start_indices, slice_sizes)]
needs_offset_slice = any(e > dim for e, dim in zip(end_indices, shape))
if needs_offset_slice:
shape[0] = self.shape[major_to_minor[0]]
size = math.prod(shape)
# The pad is necessary to ensure that the dynamic slice is in range.
vals = jnp.pad(self.base, (0, size))
vals = jax.lax.dynamic_slice(vals, (self.offset,), (size,))
start_indices = [0] * len(shape)
else:
vals = self.base
start_indices = jnp.unravel_index(self.offset, shape)
vals = vals.reshape(shape)
vals = jax.lax.dynamic_slice(vals, start_indices, slice_sizes)[..., 0]
# Move axes from their physical ordering to their logical ordering.
vals = vals.transpose(np.argsort(major_to_minor))
return jnp.broadcast_to(vals, self.shape)
def as_array_view(x: jax.Array | ArrayView) -> ArrayView:
return x if isinstance(x, ArrayView) else ArrayView(x)
T = TypeVar("T", jax.Array, ArrayView)
def zeros_like(x: T) -> T:
if isinstance(x, ArrayView):
return x._replace(base=jnp.zeros_like(x.base))
return jnp.zeros_like(x)
def _canonicalize_axis(axis, num_dims) -> int:
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
axis = operator.index(axis)
if not -num_dims <= axis < num_dims:
raise ValueError(
f"axis {axis} is out of bounds for array of dimension {num_dims}"
)
if axis < 0:
axis = axis + num_dims
return axis

View File

@@ -1,92 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Precision classes and utilities."""
import enum
import jax
import jax.numpy as jnp
@enum.unique
class DotPrecision(enum.Enum):
"""Precision for `dot` operation.
Naming scheme: {OPERAND_DTYPE}_{ACCUMULATOR_DTYPE}[_{NUM_PASSES}x]
"""
BF16_F32 = "bf16_f32"
# GPU only precisions.
F32_F32 = "f32_f32" # Full f32 precision (doesn't use TensorCores).
TF32_F32 = "tf32_f32" # Equivalent to `DEFAULT`/`HIGH` on GPU.
TF32_F32_3X = "tf32_f32_3x"
F16_F16 = "f16_f16"
F16_F32 = "f16_f32"
@property
def operand_dtype(self) -> jnp.dtype:
match self:
case DotPrecision.BF16_F32:
return jnp.bfloat16
case DotPrecision.F16_F16 | DotPrecision.F16_F32:
return jnp.float16
case _:
return jnp.float32
@property
def accumulator_dtype(self) -> jnp.dtype:
return jnp.float16 if (self == DotPrecision.F16_F16) else jnp.float32
_JAX_GPU_PRECISION_MAP = {
(jnp.float16, jax.lax.Precision.DEFAULT): DotPrecision.F16_F32,
(jnp.bfloat16, jax.lax.Precision.DEFAULT): DotPrecision.BF16_F32,
(jnp.float32, jax.lax.Precision.DEFAULT): DotPrecision.TF32_F32,
(jnp.float32, jax.lax.Precision.HIGH): DotPrecision.TF32_F32,
(jnp.float32, jax.lax.Precision.HIGHEST): DotPrecision.F32_F32,
}
_JAX_CPU_PRECISION_MAP = {
(jnp.float16, jax.lax.Precision.DEFAULT): DotPrecision.F16_F32,
(jnp.bfloat16, jax.lax.Precision.DEFAULT): DotPrecision.F32_F32,
(jnp.float32, jax.lax.Precision.DEFAULT): DotPrecision.F32_F32,
(jnp.float32, jax.lax.Precision.HIGH): DotPrecision.F32_F32,
(jnp.float32, jax.lax.Precision.HIGHEST): DotPrecision.F32_F32,
}
def _create_jax_precision_map():
precision_map = {}
for (dtype, jax_precision), dot_precision in _JAX_GPU_PRECISION_MAP.items():
precision_map[("gpu", jnp.dtype(dtype), jax_precision)] = dot_precision
for (dtype, jax_precision), dot_precision in _JAX_CPU_PRECISION_MAP.items():
precision_map[("cpu", jnp.dtype(dtype), jax_precision)] = dot_precision
return precision_map
_JAX_PRECISION_MAP = _create_jax_precision_map()
def get_equivalent_dot_precision(
a_dtype: jnp.dtype, b_dtype: jnp.dtype, jax_precision: jax.lax.Precision
) -> DotPrecision:
"""Returns `DotPrecision` replicating default XLA behaviour."""
if a_dtype != b_dtype:
raise ValueError("Cannot infer precision if operand types differ.")
backend = jax.default_backend().lower()
if (jax_precision != jax.lax.Precision.DEFAULT) and (a_dtype != jnp.float32):
raise ValueError(
"`jax.lax.Precision` values other than `DEFAULT` only have an effect if"
" the operand type is `float32`."
)
return _JAX_PRECISION_MAP[(backend, a_dtype, jax_precision)]

View File

@@ -1,125 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Triton utils."""
from collections.abc import Callable, Mapping
from alphafold3.jax.common import precision as precision_lib
import jax
import jax.numpy as jnp
import triton
import triton.language as tl
_JNP_TO_TL_DTYPES: Mapping[jnp.dtype, tl.dtype] = {
jnp.bool_: tl.int1,
jnp.int8: tl.int8,
jnp.int16: tl.int16,
jnp.int32: tl.int32,
jnp.int64: tl.int64,
jnp.uint8: tl.uint8,
jnp.uint16: tl.uint16,
jnp.uint32: tl.uint32,
jnp.uint64: tl.uint64,
jnp.float16: tl.float16,
jnp.bfloat16: tl.bfloat16,
jnp.float32: tl.float32,
jnp.float64: tl.float64,
}
def jnp_to_tl_dtype(jnp_dtype: jnp.dtype) -> tl.dtype:
return _JNP_TO_TL_DTYPES[jnp_dtype]
def get_tl_dot_fn(
precision: precision_lib.DotPrecision,
) -> Callable[..., tl.tensor]:
"""Returns a tl `dot` implementation with the specified precision.
Args:
precision: The `dot` precision.
"""
if not is_precision_supported(precision):
raise ValueError(f'Unsupported dot precision: {precision}')
if precision == precision_lib.DotPrecision.TF32_F32_3X:
return _dot_tf32_f32_3x
in_dtype = jnp_to_tl_dtype(precision.operand_dtype)
out_dtype = jnp_to_tl_dtype(precision.accumulator_dtype)
allow_tf32 = precision == precision_lib.DotPrecision.TF32_F32
@tl.core.extern
def _dot_fn(
a: tl.core.tensor,
b: tl.core.tensor,
*,
trans_a: bool = False,
trans_b: bool = False,
_builder,
):
if in_dtype == tl.float32:
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
else:
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
a = a.to(in_dtype, _builder=_builder)
b = b.to(in_dtype, _builder=_builder)
a = tl.trans(a, _builder=_builder) if trans_a else a
b = tl.trans(b, _builder=_builder) if trans_b else b
return tl.dot(
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
)
return _dot_fn
def is_precision_supported(precision: precision_lib.DotPrecision) -> bool:
return precision in {
precision_lib.DotPrecision.F32_F32,
precision_lib.DotPrecision.TF32_F32,
precision_lib.DotPrecision.F16_F32,
precision_lib.DotPrecision.BF16_F32,
precision_lib.DotPrecision.TF32_F32_3X,
}
@triton.jit
def _dot_tf32_f32_3x(a, b, trans_a=False, trans_b=False):
"""Perform the 3-pass tf32 dot function."""
tl.static_assert(a.dtype == tl.float32)
tl.static_assert(b.dtype == tl.float32)
a_ = (a.to(tl.uint32, bitcast=True) & 0xFFFFE000).to(tl.float32, bitcast=True)
b_ = (b.to(tl.uint32, bitcast=True) & 0xFFFFE000).to(tl.float32, bitcast=True)
a_err = a - a_
b_err = b - b_
if trans_a:
a_ = tl.trans(a_)
a_err = tl.trans(a_err)
if trans_b:
b_ = tl.trans(b_)
b_err = tl.trans(b_err)
# Add smallest terms first for better accuracy.
return tl.dot(a_, b_, out_dtype=tl.float32) + (
tl.dot(a_, b_err, out_dtype=tl.float32)
+ tl.dot(a_err, b_, out_dtype=tl.float32)
)
def has_triton_support() -> bool:
"""Returns True if Triton is supported by the default JAX device."""
if jax.default_backend() != 'gpu':
return False
# Only currently supported for Ampere and above.
return float(jax.devices()[0].compute_capability) >= 8.0

View File

@@ -1,126 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Pallas block load / store utilities."""
from collections.abc import Sequence
from typing import Any, TypeAlias
from alphafold3.jax.common import array_view
import jax
import jax.experimental
from jax.experimental import pallas as pl
import jax.numpy as jnp
import jaxtyping
from jaxtyping import Int # pylint: disable=g-importing-member
import numpy as np
import typeguard
ArrayT: TypeAlias = Any
ScalarInt: TypeAlias = (
Int[ArrayT, ""] | Int[np.generic, ""] | Int[jnp.generic, ""]
)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def load_block(
ref,
idx: Sequence[int | ScalarInt],
*,
block_shape: Sequence[int | None],
other=None,
**kwargs,
) -> jax.Array:
"""Loads a block from the given `ref`, masking where necessary."""
idx, mask = _get_block_indexer_and_mask(ref, idx, block_shape=block_shape)
if isinstance(ref, array_view.ArrayView):
idx = ref[idx].offsets
ref = ref.base
other = None if mask is None else other
with jax.experimental.enable_x64():
return pl.load(ref, idx, mask=mask, other=other, **kwargs)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def store_block(
ref,
val: jax.Array,
idx: Sequence[int | ScalarInt],
*,
block_shape: Sequence[int | None] | None = None,
**kwargs,
):
"""Stores a block from the given `ref`, masking where necessary."""
if block_shape is None:
block_shape = val.shape
idx, mask = _get_block_indexer_and_mask(ref, idx, block_shape=block_shape)
if isinstance(ref, array_view.ArrayView):
idx = ref[idx].offsets
ref = ref.base
with jax.experimental.enable_x64():
pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
def in_bounds_mask(
idx: Sequence[int | slice | pl.Slice | jax.Array],
shape: Sequence[int],
*,
check: Sequence[bool] | None = None,
) -> jax.Array | None:
"""Returns a boolean mask denoting which indices are within bounds.
Args:
idx: Indices for each dimension.
shape: Shape designating the valid bounds.
check: Whether or not to check bounds in each dimension. Useful for ignoring
indices known to be in bounds. Defaults to all True.
"""
if check is None:
check = [True] * len(shape)
# Remove `int` indexed dims (mask shape must match slice result shape).
shape = [dim for i, dim in enumerate(shape) if not isinstance(idx[i], int)]
check = [chk for i, chk in enumerate(check) if not isinstance(idx[i], int)]
idx = [idx for idx in idx if not isinstance(idx, int)]
mask = None
for i, (dim_idx, dim, chk) in enumerate(zip(idx, shape, check, strict=True)):
if not chk:
continue
if isinstance(dim_idx, slice):
dim_idx = pl.Slice.from_slice(dim_idx, dim)
if isinstance(dim_idx, pl.Slice):
dim_idx = dim_idx.start + dim_idx.stride * jnp.arange(dim_idx.size)
if dim_idx.ndim != 1:
raise NotImplementedError("Only one-dimensional indices are supported.")
bcast_axes = [a for a in range(len(shape)) if a != i]
dim_mask = jnp.expand_dims(dim_idx < dim, bcast_axes)
mask = dim_mask if mask is None else (mask & dim_mask)
return mask
def _get_block_indexer_and_mask(
ref, idx: Sequence[int | ScalarInt], *, block_shape: Sequence[int | None]
) -> tuple[tuple[int | slice | pl.Slice, ...], jax.Array | None]:
"""Return indices and mask for loading / storing a block."""
shape = ref.shape
idxs = []
check = []
for dim, block_idx, block_dim in zip(shape, idx, block_shape, strict=True):
if block_dim is None:
idxs.append(block_idx)
check.append(False)
else:
idxs.append(pl.dslice(block_dim * block_idx, block_dim))
check.append(dim % block_dim != 0)
return tuple(idxs), in_bounds_mask(idxs, shape, check=check)

View File

@@ -1,124 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Public API for gated linear unit functions."""
from collections.abc import Callable
import typing
from typing import Literal, TypeAlias
from alphafold3.jax.common import array_view
from alphafold3.jax.common import triton_utils
from alphafold3.jax.gated_linear_unit import gated_linear_unit_base
from alphafold3.jax.gated_linear_unit import matmul_ext
import jax
import jaxtyping
from jaxtyping import Array, Float # pylint: disable=g-importing-member,g-multiple-import
import typeguard
Implementation: TypeAlias = Literal['xla', 'triton']
class PallasGatedLinearUnit(gated_linear_unit_base.GatedLinearUnit):
"""Pallas gated linear unit."""
def _fwd(self, x, weight, *, activation, precision):
weight_view = array_view.ArrayView(weight)
return self.apply_vmap_rule_forward(
matmul_ext.gated_linear_unit,
activation=activation,
precision=precision,
)(
x,
weight_view[:, 1],
weight_view[:, 0],
)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def gated_linear_unit(
x: Float[Array, '*B M K'],
weight: Float[Array, 'K 2 N'],
*,
activation: Callable[[jax.Array], jax.Array] | None = None,
precision: jax.lax.Precision | None = None,
implementation: Implementation | None = None,
) -> Float[Array, '*B M N']:
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
This is SwiGLU when `activation=jax.nn.swish`, GEGLU when
`activation=jax.nn.gelu`, REGLU when `activation=jax.nn.relu`, and GLU when
`activation=jax.nn.sigmoid` (https://arxiv.org/abs/2002.05202).
Args:
x: the input array.
weight: the combined weight array.
activation: optional activation function.
precision: specifies the matrix multiplication precision. Either `None`
(default), which means the default precision for the backend, or a
`jax.lax.Precision` enum.
implementation: if `None` (default), an implementation is automatically
chosen. 'xla' will use standard XLA and work on any platform, and 'triton'
will use a fused Triton GPU kernel. Only a subset of data types, shapes
and GPUs are supported by 'triton', with an exception thrown in this case.
Raises:
NotImplementedError: if `implementation='triton'` does not support a given
input or device.
ValueError: if the arguments are invalid.
Returns:
The output array.
"""
match implementation:
case 'triton':
if not triton_utils.has_triton_support():
raise NotImplementedError('Triton not supported on this platform.')
case _:
...
if x.dtype.name != weight.dtype.name:
raise ValueError(
f'Input and weight must have the same dtype. {x.dtype} !='
f' {weight.dtype}'
)
if implementation is not None:
named_args = typing.get_args(Implementation)
if implementation not in named_args:
raise ValueError(
f'Unsupported named implementation. Must be one of {named_args}.'
)
if implementation is None or implementation == 'triton':
try:
return PallasGatedLinearUnit()(
x=x,
weight=weight,
activation=activation,
precision=precision,
)
# When `implementation=None`, we must catch any exception, and use XLA
# as a fallback. As we rely on a third-party library (Triton), it might
# not be possible to enumerate all possible exceptions that could be
# thrown, hence catching the broadest possible one.
except Exception as e: # pylint: disable=broad-exception-caught
if implementation == 'triton':
raise e
return gated_linear_unit_base.gated_linear_unit_xla(
x=x,
weight=weight,
activation=activation,
precision=precision,
)

View File

@@ -1,130 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Common types for gated linear unit kernels."""
import abc
from collections.abc import Callable
import functools
from typing import Any
import jax
import jax.numpy as jnp
import jaxtyping
from jaxtyping import Array, Float # pylint: disable=g-importing-member,g-multiple-import
import typeguard
class GatedLinearUnit(abc.ABC):
"""Gated linear unit."""
def __call__(
self,
x: Float[Array, '*B M K'],
weight: Float[Array, 'K 2 N'],
*,
activation: Callable[[jax.Array], jax.Array] | None = None,
precision: jax.lax.Precision | None = None,
**kwargs,
) -> Float[Array, '*B M N']:
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
Args:
x: the input array.
weight: the combined weight array.
activation: optional activation function.
precision: specifies the matrix multiplication precision. Either `None`
(default), which means the default precision for the backend, or a
`jax.lax.Precision` enum.
**kwargs: additional keyword arguments.
Returns:
The output array.
"""
return self._fwd(
x, weight, activation=activation, precision=precision, **kwargs
)
# Default vmap rule.
@property
def vmap_rule_forward(self) -> Callable[..., Any]:
def _vmap_rule(
axis_size, in_batched, *args, fn: jax.custom_batching.custom_vmap
):
sequential_vmap = jax.custom_batching.sequential_vmap(fn.fun)
return sequential_vmap.vmap_rule(axis_size, in_batched, *args)
return _vmap_rule
def apply_vmap_rule_forward(
self, fn: Callable[..., Any], **kwargs
) -> jax.custom_batching.custom_vmap:
fn_closed = functools.partial(fn, **kwargs)
fn_closed = jax.custom_batching.custom_vmap(fn_closed)
vmap_rule = functools.partial(self.vmap_rule_forward, fn=fn_closed)
fn_closed.def_vmap(vmap_rule)
return fn_closed
@abc.abstractmethod
def _fwd(
self,
x: Float[Array, '*B M K'],
weight: Float[Array, 'K 2 N'],
*,
activation: Callable[[jax.Array], jax.Array] | None,
precision: jax.lax.Precision | None,
) -> Float[Array, '*B M N']:
"""Gated linear unit."""
...
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def gated_linear_unit_xla(
x: Float[Array, '*B M K'],
weight: Float[Array, 'K 2 N'],
*,
activation: Callable[[jax.Array], jax.Array] | None = None,
precision: jax.lax.Precision | None = None,
) -> Float[Array, '*B M N']:
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
This is SwiGLU when `activation=jax.nn.swish`, GEGLU when
`activation=jax.nn.gelu`, REGLU when `activation=jax.nn.relu`, and GLU when
`activation=jax.nn.sigmoid` (https://arxiv.org/abs/2002.05202).
Args:
x: the input array.
weight: the combined weight array.
activation: optional activation function.
precision: specifies the matrix multiplication precision. Either `None`
(default), which means the default precision for the backend, or a
`jax.lax.Precision` enum.
Returns:
The output array.
"""
weight_reshaped = jax.lax.collapse(
weight, start_dimension=-2, stop_dimension=None
)
assert weight_reshaped.ndim == 2
y = jnp.dot(x, weight_reshaped, precision=precision)
# Apply activation and compute product of FP8/FP16/BF16 in FP32.
y = y.astype(jnp.promote_types(x.dtype, jnp.float32))
a, b = jnp.split(y, 2, axis=-1)
out = a * b if activation is None else activation(a) * b
out = out.astype(x.dtype)
return out

View File

@@ -1,78 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Auto-tuned configs for matmul."""
import dataclasses
import functools
import math
import jax
from jax.experimental import pallas as pl
@dataclasses.dataclass(frozen=True, kw_only=True)
class Config:
block_m: int
block_n: int
block_k: int
num_warps: int
num_stages: int
@functools.cache
def _get_best_block_size(
m: int, n: int, k: int, core_count: int
) -> tuple[int, int, int]:
"""Returns the best block size for the given shape."""
min_block_dim = 32
block_m = min(max(min_block_dim, pl.next_power_of_2(m)), 128)
block_n = min(max(min_block_dim, pl.next_power_of_2(n)), 256)
block_n = min(block_n, (128 * 128) // block_m)
block_k = 32
split_k = 1
num_blocks = pl.cdiv(m, block_m) * pl.cdiv(n, block_n)
while num_blocks < core_count:
if block_m > min_block_dim:
block_m //= 2
num_blocks = pl.cdiv(m, block_m) * pl.cdiv(n, block_n)
elif split_k * block_k < pl.next_power_of_2(k):
split_k *= 2
num_blocks *= 2
else:
break
return block_m, block_n, block_k
def _abstractify(x):
return jax.api_util.shaped_abstractify(x) if isinstance(x, jax.Array) else x
def get_config(
x: jax.Array, w: jax.Array, core_count: int | None = None
) -> Config:
"""Returns a config for the given args."""
if core_count is None:
core_count = jax.devices()[0].core_count
x = _abstractify(x)
w = _abstractify(w)
m, k = math.prod(x.shape[:-1]), x.shape[-1]
n = w.shape[1]
if n >= m: # Prefer `block_n` > `block_m`.
block_m, block_n, block_k = _get_best_block_size(m, n, k, core_count)
else:
block_n, block_m, block_k = _get_best_block_size(n, m, k, core_count)
return Config(
block_m=block_m,
block_n=block_n // 2, # Halve `block_n` as we read two `w` blocks.
block_k=block_k,
num_warps=4,
num_stages=4,
)

View File

@@ -1,273 +0,0 @@
# Copyright 2024 DeepMind Technologies Limited
#
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
#
# To request access to the AlphaFold 3 model parameters, follow the process set
# out at https://github.com/google-deepmind/alphafold3. You may only use these
# if received directly from Google. Use is subject to terms of use available at
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
"""Extended matmul ops."""
from collections.abc import Callable
import functools
from typing import Any, TypeAlias
from alphafold3.jax.common import array_view
from alphafold3.jax.common import triton_utils
from alphafold3.jax.gated_linear_unit import block
from alphafold3.jax.gated_linear_unit import matmul_config
import jax
from jax._src.state import discharge
from jax.experimental import pallas as pl
import jax.numpy as jnp
import jaxtyping
from jaxtyping import Array, Float, Int # pylint: disable=g-importing-member,g-multiple-import
import numpy as np
import typeguard
ArrayView = array_view.ArrayView
PyTree: TypeAlias = Any
ArrayT: TypeAlias = Any
ScalarInt: TypeAlias = (
Int[ArrayT, ''] | Int[np.generic, ''] | Int[jnp.generic, '']
)
def _get_group_cache_usage(
group_size_m, num_blocks_m, num_blocks_n, block_m_bytes, block_n_bytes
) -> int:
"""Returns the cache usage in bytes for the given group size."""
num_live_progs = jax.devices()[0].core_count
num_live_blocks_n = min(pl.cdiv(num_live_progs, group_size_m), num_blocks_n)
num_live_groups = pl.cdiv(num_live_progs, group_size_m * num_live_blocks_n)
num_live_blocks_m = min(num_live_groups * group_size_m, num_blocks_m)
return num_live_blocks_m * block_m_bytes + num_live_blocks_n * block_n_bytes
def _get_pids(
pid, num_blocks_m, num_blocks_n, group_size_m
) -> tuple[ScalarInt, ScalarInt]:
"""Returns the program IDs in each grid axis."""
# Use `floor_divide` and `remainder` (instead of lax.div and lax.rem)
# to handle dtypes: pid (int32) vs. num_blocks_n (int64) when `jax_enable_x64`
# is set.
if group_size_m == 1:
return jnp.floor_divide(pid, num_blocks_n), jnp.remainder(pid, num_blocks_n)
num_progs_in_group = group_size_m * num_blocks_n
group_start_m = jnp.floor_divide(pid, num_progs_in_group) * group_size_m
group_size_m = jnp.minimum(num_blocks_m - group_start_m, group_size_m)
pid_m = group_start_m + jnp.remainder(pid, group_size_m)
pid_n = jnp.floor_divide(jnp.remainder(pid, num_progs_in_group), group_size_m)
return pid_m, pid_n
def _get_best_pids(
pid, *, m, n, block_m, block_n, a_dtype_bytes, b_dtype_bytes
) -> tuple[ScalarInt, ScalarInt]:
"""Returns the grouped program IDs that minimize cache usage."""
num_blocks_m = pl.cdiv(m, block_m)
num_blocks_n = pl.cdiv(n, block_n)
block_m_bytes = block_m * a_dtype_bytes
block_n_bytes = block_n * b_dtype_bytes
num_live_progs = jax.devices()[0].core_count
def group_size_m_usage(group_size_m):
return _get_group_cache_usage(
group_size_m, num_blocks_m, num_blocks_n, block_m_bytes, block_n_bytes
)
group_size_m = min(
range(1, min(num_live_progs, num_blocks_m) + 1), key=group_size_m_usage
)
def group_size_n_usage(group_size_n):
return _get_group_cache_usage(
group_size_n, num_blocks_n, num_blocks_m, block_n_bytes, block_m_bytes
)
group_size_n = min(
range(1, min(num_live_progs, num_blocks_n) + 1), key=group_size_n_usage
)
if group_size_m_usage(group_size_m) <= group_size_n_usage(group_size_n):
pid_m, pid_n = _get_pids(pid, num_blocks_m, num_blocks_n, group_size_m)
else:
pid_n, pid_m = _get_pids(pid, num_blocks_n, num_blocks_m, group_size_n)
return pid_m, pid_n
def _apply_epilogue(
epilogue: Callable[..., jax.Array], x: jax.Array, args: PyTree
) -> jax.Array:
"""Applies the epilogue to the output."""
# Convert array view arguments to JAX arrays. This means that we can use the
# array view slices, rather than the gather that discharging state gives us.
is_leaf = lambda x: isinstance(x, ArrayView)
args_flat, args_tree = jax.tree.flatten((x, args), is_leaf=is_leaf)
args_flat = tuple(map(jnp.array, args_flat))
def epilogue_wrapper(refs):
x_ref, arg_refs = args_tree.unflatten(refs)
x_ref[:] = epilogue(x_ref[:], arg_refs, 0, 0)
return discharge.run_state_reference(epilogue_wrapper)(args_flat)[0]
def _gated_linear_unit_kernel(
x_ref,
w_ref,
v_ref,
_, # Destination, aliased with `out_ref`.
epilogue_in_refs,
out_ref,
*,
block_m,
block_n,
block_k,
activation,
precision,
epilogue,
):
"""Pallas GLU kernel."""
m = x_ref.shape[0]
n = w_ref.shape[1]
pid_m, pid_n = _get_best_pids(
pl.program_id(0),
m=m,
n=n,
block_m=block_m,
block_n=block_n,
a_dtype_bytes=jnp.dtype(x_ref.dtype).itemsize,
b_dtype_bytes=jnp.dtype(w_ref.dtype).itemsize * 2, # Two blocks.
)
def body(i, acc):
x = block.load_block(x_ref, (pid_m, i), block_shape=(block_m, block_k))
w = block.load_block(w_ref, (i, pid_n), block_shape=(block_k, block_n))
v = block.load_block(v_ref, (i, pid_n), block_shape=(block_k, block_n))
acc[0] += pl.dot(x, w.astype(x.dtype), precision=precision)
acc[1] += pl.dot(x, v.astype(x.dtype), precision=precision)
return acc
num_iters = pl.cdiv(x_ref.shape[-1], block_k)
acc0 = jnp.zeros((block_m, block_n), dtype=jnp.float32)
acc1 = jnp.zeros((block_m, block_n), dtype=jnp.float32)
proj, gates = jax.lax.fori_loop(0, num_iters, body, init_val=[acc0, acc1])
proj = proj.astype(x_ref.dtype).astype(jnp.float32)
gates = gates.astype(x_ref.dtype).astype(jnp.float32)
out = proj * (gates if activation is None else activation(gates))
if epilogue is not None:
out = epilogue(out, epilogue_in_refs, pid_m, pid_n)
block.store_block(out_ref, out, (pid_m, pid_n))
def _gated_linear_unit(
x: Float[Array | ArrayView, 'M K'],
weights_projection: Float[Array | ArrayView, 'K N'],
weights_gate: Float[Array | ArrayView, 'K N'],
*,
dst: Float[ArrayView, 'M N'] | None = None,
activation: Callable[[jax.Array], jax.Array] | None,
epilogue: Any, # Callable[..., Any] | None - breaks `typed`.
epilogue_args: PyTree,
precision: jax.lax.Precision | None,
) -> jax.Array: # Float[Array, 'M N'] | Float[Array, 'N M']
"""Applies a gated linear unit (arxiv.org/abs/1612.08083)."""
if epilogue is None and epilogue_args is not None:
raise ValueError('`epilogue_args` is specified but `epilogue` is None.')
name = 'pallas_glu'
if activation is not None:
name += f'_{getattr(activation, "__name__", repr(activation))}'
if epilogue is not None:
name += f'_{getattr(epilogue, "__name__", repr(epilogue))}'
w = weights_projection
config = matmul_config.get_config(x, w)
m = x.shape[0]
n = w.shape[1]
kernel = functools.partial(
_gated_linear_unit_kernel,
block_m=config.block_m,
block_n=config.block_n,
block_k=config.block_k,
activation=activation,
precision=precision,
epilogue=epilogue,
)
if dst is None:
input_output_aliases = {}
else:
input_output_aliases = {3: 0}
compiler_params = dict(
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
)
return pl.pallas_call(
kernel,
name=name,
grid=(pl.cdiv(m, config.block_m) * pl.cdiv(n, config.block_n),),
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
input_output_aliases=input_output_aliases,
compiler_params=compiler_params,
backend='triton',
)(x, weights_projection, weights_gate, dst, epilogue_args)
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
def gated_linear_unit(
x: Float[Array | ArrayView, '*B M K'],
weights_projection: Float[Array | ArrayView, 'K N'],
weights_gate: Float[Array | ArrayView, 'K N'],
*,
activation: Callable[[jax.Array], jax.Array] | None = None,
precision: jax.lax.Precision | None = None,
) -> Float[Array | ArrayView, '*B M N']:
"""Applies a gated linear unit (arxiv.org/abs/1612.08083).
Args:
x: Input activations.
weights_projection: Weights for linear projection.
weights_gate: Weights for gates.
activation: Optional activation function.
precision: Specifies the precision of the matmuls.
Returns:
`(x @ weights_projection) * activation(x @ weights_gate)`
"""
supported_dtypes = {'float16', 'bfloat16', 'float32'}
if x.dtype.name not in supported_dtypes:
raise NotImplementedError(
f'Triton kernel does not support input datatype {x.dtype.name}. Must be'
f' one of {supported_dtypes}.'
)
if not triton_utils.has_triton_support():
raise NotImplementedError('Triton kernel not supported on current device.')
*batch, m, _ = x.shape
n = weights_projection.shape[1]
x = array_view.as_array_view(x).collapse(start=0, stop=-1)
return _gated_linear_unit(
x,
weights_projection,
weights_gate,
dst=None,
activation=activation,
precision=precision,
epilogue=None,
epilogue_args=None,
).reshape(batch + [m, n])

View File

@@ -14,13 +14,14 @@ from collections.abc import Sequence
from typing import Literal, TypeAlias
from alphafold3.common import base_config
from alphafold3.jax.attention import attention
import tokamax
_Shape2DType: TypeAlias = tuple[int | None, int | None]
class GlobalConfig(base_config.BaseConfig):
"""Global configuration for the AlphaFold3 model."""
bfloat16: Literal['all', 'none', 'intermediate'] = 'all'
final_init: Literal['zeros', 'linear'] = 'zeros'
pair_attention_chunk_size: Sequence[_Shape2DType] = ((1536, 128), (None, 32))
@@ -29,4 +30,6 @@ class GlobalConfig(base_config.BaseConfig):
(None, 1024),
)
# Note: flash_attention_implementation = 'xla' means no flash attention.
flash_attention_implementation: attention.Implementation = 'triton'
flash_attention_implementation: tokamax.DotProductAttentionImplementation = (
'triton'
)

View File

@@ -11,13 +11,13 @@
"""Diffusion transformer model."""
from alphafold3.common import base_config
from alphafold3.jax.gated_linear_unit import gated_linear_unit
from alphafold3.model import model_config
from alphafold3.model.atom_layout import atom_layout
from alphafold3.model.components import haiku_modules as hm
import haiku as hk
import jax
from jax import numpy as jnp
import tokamax
def adaptive_layernorm(x, single_cond, name):
@@ -97,9 +97,7 @@ def transition_block(
name=f'{name}ffw_transition1',
)
weights = jnp.reshape(weights, (len(weights), 2, num_intermediates))
c = gated_linear_unit.gated_linear_unit(
x=x, weight=weights, implementation=None, activation=jax.nn.swish
)
c = tokamax.gated_linear_unit(x=x, weights=weights, activation=jax.nn.swish)
else:
x = hm.Linear(
num_intermediates * 2, initializer='relu', name=f'{name}ffw_transition1'

View File

@@ -14,8 +14,6 @@ from collections.abc import Sequence
from typing import Literal
from alphafold3.common import base_config
from alphafold3.jax.attention import attention
from alphafold3.jax.gated_linear_unit import gated_linear_unit
from alphafold3.model import model_config
from alphafold3.model.components import haiku_modules as hm
from alphafold3.model.components import mapping
@@ -23,6 +21,7 @@ from alphafold3.model.network import diffusion_transformer
import haiku as hk
import jax
import jax.numpy as jnp
import tokamax
def get_shard_size(
@@ -68,8 +67,8 @@ class TransitionBlock(hk.Module):
name='transition1',
)
weights = jnp.reshape(weights, (len(weights), 2, num_intermediate))
c = gated_linear_unit.gated_linear_unit(
x=act, weight=weights, implementation=None, activation=jax.nn.swish
c = tokamax.gated_linear_unit(
x=act, weights=weights, activation=jax.nn.swish
)
else:
act = hm.Linear(
@@ -172,7 +171,7 @@ class GridSelfAttention(hk.Module):
# Dot product attention requires the bias term to have a batch dimension.
bias = jnp.expand_dims(bias, 0)
weighted_avg = attention.dot_product_attention(
weighted_avg = tokamax.dot_product_attention(
q,
k,
v,
@@ -289,11 +288,8 @@ class TriangleMultiplication(hk.Module):
)
weights_glu = jnp.stack([weights_gate, weights_projection], axis=1)
projection = gated_linear_unit.gated_linear_unit(
x=act,
weight=weights_glu,
activation=jax.nn.sigmoid,
implementation=None,
projection = tokamax.gated_linear_unit(
act, weights_glu, activation=jax.nn.sigmoid
)
projection = jnp.transpose(projection, (2, 0, 1))
projection *= mask