mirror of
https://github.com/google-deepmind/alphafold3.git
synced 2026-06-02 11:54:36 +08:00
Migrate AF3 to tokamax.gated_linear_unit and tokamax.dot_product_attention.
PiperOrigin-RevId: 838262306 Change-Id: I321a78b2a7d0d5cdeabe797c59b1e9c03e33780d
This commit is contained in:
committed by
Copybara-Service
parent
2e3703e82a
commit
389078218c
@@ -145,15 +145,12 @@ AlphaFold 3 uses the following separate libraries and packages:
|
||||
* [HMMER Suite](https://github.com/EddyRivasLab/hmmer)
|
||||
* [Haiku](https://github.com/deepmind/dm-haiku)
|
||||
* [JAX](https://github.com/jax-ml/jax/)
|
||||
* [jax-triton](https://github.com/jax-ml/jax-triton)
|
||||
* [jaxtyping](https://github.com/patrick-kidger/jaxtyping)
|
||||
* [libcifpp](https://github.com/pdb-redo/libcifpp)
|
||||
* [NumPy](https://github.com/numpy/numpy)
|
||||
* [pybind11](https://github.com/pybind/pybind11) and
|
||||
[pybind11_abseil](https://github.com/pybind/pybind11_abseil)
|
||||
* [RDKit](https://github.com/rdkit/rdkit)
|
||||
* [Tree](https://github.com/deepmind/tree)
|
||||
* [Triton](https://github.com/triton-lang/triton)
|
||||
* [Tokamax](https://github.com/openxla/tokamax)
|
||||
* [tqdm](https://github.com/tqdm/tqdm)
|
||||
|
||||
We thank all their contributors and maintainers!
|
||||
|
||||
2090
dev-requirements.txt
2090
dev-requirements.txt
File diff suppressed because it is too large
Load Diff
@@ -15,18 +15,15 @@ requires-python = ">=3.11"
|
||||
readme = "README.md"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"absl-py",
|
||||
"absl-py>=2.3.1",
|
||||
"dm-haiku==0.0.13",
|
||||
"dm-tree",
|
||||
"jax==0.4.34",
|
||||
"jax[cuda12]==0.4.34",
|
||||
"jax-triton==0.2.0",
|
||||
"jaxtyping==0.2.34",
|
||||
"jax==0.8.0",
|
||||
"jax[cuda12]==0.8.0",
|
||||
"numpy",
|
||||
"rdkit==2024.3.5",
|
||||
"triton==3.1.0",
|
||||
"setuptools==78.1.0",
|
||||
"tokamax==0.0.4",
|
||||
"tqdm",
|
||||
"typeguard==2.13.3",
|
||||
"zstandard",
|
||||
]
|
||||
|
||||
|
||||
2090
requirements.txt
2090
requirements.txt
File diff suppressed because it is too large
Load Diff
@@ -42,7 +42,6 @@ import alphafold3.cpp
|
||||
from alphafold3.data import featurisation
|
||||
from alphafold3.data import pipeline
|
||||
from alphafold3.data.tools import shards
|
||||
from alphafold3.jax.attention import attention
|
||||
from alphafold3.model import features
|
||||
from alphafold3.model import model
|
||||
from alphafold3.model import params
|
||||
@@ -52,6 +51,7 @@ import haiku as hk
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
import numpy as np
|
||||
import tokamax
|
||||
|
||||
|
||||
_HOME_DIR = pathlib.Path(os.environ.get('HOME'))
|
||||
@@ -373,7 +373,7 @@ _FORCE_OUTPUT_DIR = flags.DEFINE_bool(
|
||||
|
||||
def make_model_config(
|
||||
*,
|
||||
flash_attention_implementation: attention.Implementation = 'triton',
|
||||
flash_attention_implementation: tokamax.DotProductAttentionImplementation = 'triton',
|
||||
num_diffusion_samples: int = 5,
|
||||
num_recycles: int = 10,
|
||||
return_embeddings: bool = False,
|
||||
@@ -935,7 +935,8 @@ def main(_):
|
||||
model_runner = ModelRunner(
|
||||
config=make_model_config(
|
||||
flash_attention_implementation=typing.cast(
|
||||
attention.Implementation, _FLASH_ATTENTION_IMPLEMENTATION.value
|
||||
tokamax.DotProductAttentionImplementation,
|
||||
_FLASH_ATTENTION_IMPLEMENTATION.value,
|
||||
),
|
||||
num_diffusion_samples=_NUM_DIFFUSION_SAMPLES.value,
|
||||
num_recycles=_NUM_RECYCLES.value,
|
||||
|
||||
@@ -1,139 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Scaled dot-product attention."""
|
||||
|
||||
import typing
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from alphafold3.jax.attention import attention_base as base
|
||||
from alphafold3.jax.attention import flash_attention as attention_triton
|
||||
from alphafold3.jax.attention import xla_attention
|
||||
from alphafold3.jax.common import triton_utils
|
||||
import jax
|
||||
from jax.typing import DTypeLike # pylint: disable=g-importing-member
|
||||
import jaxtyping
|
||||
from jaxtyping import Array # pylint: disable=g-importing-member
|
||||
from jaxtyping import Bool # pylint: disable=g-importing-member
|
||||
from jaxtyping import Float # pylint: disable=g-importing-member
|
||||
import typeguard
|
||||
|
||||
Implementation: TypeAlias = Literal["cudnn", "xla", "triton"]
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def dot_product_attention(
|
||||
query: Float[Array, "*B T H D"],
|
||||
key: Float[Array, "*B t #H D"],
|
||||
value: Float[Array, "*B t #H D"],
|
||||
*,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None = None,
|
||||
mask: Bool[Array, "*#B #H #T #t"] | None = None,
|
||||
implementation: Implementation | None = None,
|
||||
logits_dtype: DTypeLike | None = None,
|
||||
precision: (
|
||||
jax.lax.Precision | tuple[jax.lax.Precision, jax.lax.Precision] | None
|
||||
) = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Performs scaled dot-product attention.
|
||||
|
||||
Scaled dot-product attention from "Attention is all you need"
|
||||
https://arxiv.org/abs/1706.03762.
|
||||
|
||||
Computes self- or cross-attention. The following is computed:
|
||||
softmax(qk_scale * query @ key^T + bias) @ value.
|
||||
|
||||
Supports both multi-head and multi-query attention
|
||||
(https://arxiv.org/abs/1911.02150).
|
||||
|
||||
Arguments:
|
||||
query: Query array of shape `[batch, seq_len_q, num_heads, head_dim]`.
|
||||
key: Key array of shape `[batch, seq_len_kv, num_heads, head_dim]`.
|
||||
`num_heads` can be 1 for multi-query attention.
|
||||
value: Value array of shape `[batch, seq_len_kv, num_heads, head_dim]`.
|
||||
`num_heads` can be 1 for multi-query attention.
|
||||
bias: Optional bias array, broadcastable to shape `[batch, num_heads,
|
||||
seq_len_q, seq_len_kv]`.
|
||||
mask: Optional boolean mask, broadcastable to `[batch, num_heads, seq_len_q,
|
||||
seq_len_kv]`. Attention weights are masked out if the corresponding mask
|
||||
value is `False`.
|
||||
implementation: if `None` (default), an implementation is automatically
|
||||
chosen. 'xla' will use standard XLA and work on any platform, 'triton'
|
||||
will use a fused Triton GPU kernel, and 'cudnn' a cuDNN FlashAttention
|
||||
kernel. Only a subset of data types, shapes and GPUs are supported by
|
||||
'triton' and 'cudnn', with an exception thrown in this case.
|
||||
logits_dtype: Data type for attention logits (`query @ key^T`). If `None` is
|
||||
passed (the default), the accumulator type from the `query @ key^T` dot
|
||||
product will be used, which is FP32 for BF16/FP16/FP32 inputs. Note that
|
||||
this default increases the memory usage for BF16/FP16 inputs when using
|
||||
`implementation='xla'`, but does not increase memory usage when using
|
||||
`implementation='triton'`.
|
||||
precision: The precision for the dot products. Either `None` (default) which
|
||||
uses the default JAX precision for a backend; a tuple `(
|
||||
query_key_dot_precision, weights_value_dot_precision)` of
|
||||
`jax.lax.Precision` objects; or a single `jax.lax.Precision` object
|
||||
applied to both dot products.
|
||||
|
||||
Returns:
|
||||
An array with the same shape as `query`.
|
||||
"""
|
||||
|
||||
if implementation is not None:
|
||||
named_args = typing.get_args(Implementation)
|
||||
if implementation not in named_args:
|
||||
raise ValueError(
|
||||
f"Unsupported named implementation. Must be one of {named_args}."
|
||||
)
|
||||
|
||||
if implementation == "cudnn":
|
||||
if logits_dtype is not None:
|
||||
raise ValueError(
|
||||
"logits_dtype is not supported for cudnn implementation."
|
||||
)
|
||||
if precision is not None:
|
||||
raise NotImplementedError(
|
||||
"precision is not supported for cudnn implementation."
|
||||
)
|
||||
|
||||
return jax.nn.dot_product_attention(
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
implementation="cudnn",
|
||||
)
|
||||
|
||||
logits_dtype = base.AUTO if logits_dtype is None else logits_dtype
|
||||
precision = jax.lax.Precision.DEFAULT if precision is None else precision
|
||||
|
||||
args = (query, key, value)
|
||||
kwargs = dict(
|
||||
precision=precision,
|
||||
logits_dtype=logits_dtype,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
if implementation == "triton":
|
||||
if not triton_utils.has_triton_support():
|
||||
raise ValueError(
|
||||
"implementation='triton' for FlashAttention is unsupported on this"
|
||||
" GPU generation. Please use implementation='xla' instead."
|
||||
)
|
||||
return attention_triton.TritonFlashAttention()(*args, **kwargs)
|
||||
|
||||
if implementation is None and triton_utils.has_triton_support():
|
||||
try:
|
||||
return attention_triton.TritonFlashAttention()(*args, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
pass # Fallback to XLA.
|
||||
|
||||
return xla_attention.XlaDotProductAttention()(*args, **kwargs)
|
||||
@@ -1,363 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Common types and utilities for attention kernels."""
|
||||
|
||||
import abc
|
||||
import dataclasses
|
||||
import enum
|
||||
import functools
|
||||
import math
|
||||
from typing import Any, Self
|
||||
|
||||
from alphafold3.jax.common import array_view
|
||||
from alphafold3.jax.common import precision as precision_lib
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax.typing import DTypeLike # pylint: disable=g-importing-member
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Bool, Float, Int # pylint: disable=g-multiple-import,g-importing-member
|
||||
import typeguard
|
||||
|
||||
|
||||
class AUTO: # Used as a sentinel value.
|
||||
pass
|
||||
|
||||
|
||||
DotPrecisionLike = jax.lax.Precision | precision_lib.DotPrecision
|
||||
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class Mask:
|
||||
"""An attention mask.
|
||||
|
||||
`k_start` (inclusive) and `k_end` (exclusive) define range of enabled
|
||||
k-sequence values for each row of logits.
|
||||
|
||||
For example, a local attention mask could be defined as follows:
|
||||
```
|
||||
seq_len_q = seq_len_k = 4
|
||||
window_size = 2
|
||||
k_start = jnp.maximum(0, jnp.arange(seq_len_q) + 1 - window_size)
|
||||
mask = Mask(k_start=k_start, is_causal=True)
|
||||
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
|
||||
[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 1, 1]], dtype=bool)
|
||||
```
|
||||
Or equivalently (but less efficiently):
|
||||
```
|
||||
k_end = jnp.arange(seq_len_q) + 1
|
||||
k_start = jnp.maximum(0, k_end - window_size)
|
||||
mask = Mask(k_start=k_start, k_end=k_end)
|
||||
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
|
||||
[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[0, 1, 1, 0],
|
||||
[0, 0, 1, 1]], dtype=bool)
|
||||
```
|
||||
|
||||
A mask for two independent causal sequences could be defined as follows:
|
||||
```
|
||||
k_start = jnp.array([0, 0, 2, 2])
|
||||
mask = Mask(k_start=k_start, is_causal=True)
|
||||
assert mask.as_array(seq_len_q, seq_len_k) == jnp.array(
|
||||
[[1, 0, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[0, 0, 1, 0],
|
||||
[0, 0, 1, 1]], dtype=bool)
|
||||
```
|
||||
"""
|
||||
|
||||
bool_mask: Bool[Array, "*#B #T #t"] | None = None
|
||||
_: dataclasses.KW_ONLY
|
||||
q_start: Int[Array, "*#B #t"] | None = None
|
||||
q_end: Int[Array, "*#B #t"] | None = None
|
||||
k_start: Int[Array, "*#B #T"] | None = None
|
||||
k_end: Int[Array, "*#B #T"] | None = None
|
||||
is_causal: bool = False
|
||||
|
||||
def tree_flatten(self):
|
||||
return (
|
||||
self.bool_mask,
|
||||
self.q_start,
|
||||
self.q_end,
|
||||
self.k_start,
|
||||
self.k_end,
|
||||
), (self.is_causal,)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux, children) -> Self:
|
||||
(is_causal,) = aux
|
||||
bool_mask, q_start, q_end, k_start, k_end = children
|
||||
return cls(
|
||||
bool_mask,
|
||||
q_start=q_start,
|
||||
q_end=q_end,
|
||||
k_start=k_start,
|
||||
k_end=k_end,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
def as_array(
|
||||
self,
|
||||
q_len_or_indices: int | Int[Array, "*#B T"],
|
||||
k_len_or_indices: int | Int[Array, "*#B t"],
|
||||
) -> Bool[Array, "*#B #T #t"] | None:
|
||||
"""Returns the mask as a boolean array."""
|
||||
if isinstance(q_len_or_indices, int):
|
||||
q_indices = jnp.arange(q_len_or_indices)
|
||||
else:
|
||||
q_indices = q_len_or_indices
|
||||
|
||||
if isinstance(k_len_or_indices, int):
|
||||
k_indices = jnp.arange(k_len_or_indices)
|
||||
else:
|
||||
k_indices = k_len_or_indices
|
||||
|
||||
q_indices = q_indices[..., None]
|
||||
k_indices = k_indices[..., None, :]
|
||||
|
||||
mask = []
|
||||
if self.bool_mask is not None:
|
||||
mask.append(self.bool_mask)
|
||||
# Check `bool_mask` shape is compatible with `{q,kv}_indices`.
|
||||
_ = jnp.broadcast_shapes(
|
||||
q_indices.shape, k_indices.shape, self.bool_mask.shape
|
||||
)
|
||||
|
||||
if self.q_start is not None:
|
||||
mask.append(q_indices >= self.q_start[..., None, :])
|
||||
|
||||
if self.q_end is not None:
|
||||
mask.append(q_indices < self.q_end[..., None, :])
|
||||
|
||||
if self.k_start is not None:
|
||||
mask.append(k_indices >= self.k_start[..., None])
|
||||
|
||||
if self.k_end is not None:
|
||||
mask.append(k_indices < self.k_end[..., None])
|
||||
|
||||
if self.is_causal:
|
||||
mask.append(q_indices >= k_indices)
|
||||
|
||||
logical_and = functools.partial(functools.reduce, jnp.logical_and)
|
||||
return jax.lax.broadcast_to_rank(logical_and(mask), 3) if mask else None
|
||||
|
||||
def take(self, *attrs: str) -> tuple[Any, ...]:
|
||||
"""Returns a mask with attrs removed and the removed attrs."""
|
||||
default_mask = type(self)()
|
||||
replacements = {attr: getattr(default_mask, attr) for attr in attrs}
|
||||
values = (getattr(self, attr) for attr in attrs)
|
||||
return dataclasses.replace(self, **replacements), *values
|
||||
|
||||
def __and__(self, other: "Bool[Array, '*#B #T #t'] | Mask") -> "Mask": # pylint: disable=g-inconsistent-quotes
|
||||
"""Returns the intersection of two masks."""
|
||||
if not isinstance(other, Mask):
|
||||
other = Mask(other)
|
||||
|
||||
def combine(op):
|
||||
return lambda a, b: b if a is None else a if b is None else op(a, b)
|
||||
|
||||
return Mask(
|
||||
bool_mask=combine(jnp.logical_and)(self.bool_mask, other.bool_mask),
|
||||
q_end=combine(jnp.minimum)(self.q_end, other.q_end),
|
||||
k_start=combine(jnp.maximum)(self.k_start, other.k_start),
|
||||
k_end=combine(jnp.minimum)(self.k_end, other.k_end),
|
||||
is_causal=self.is_causal or other.is_causal,
|
||||
)
|
||||
|
||||
|
||||
CAUSAL_MASK = Mask(is_causal=True)
|
||||
|
||||
|
||||
SoftmaxResidual = (
|
||||
tuple[Float[Array, "*B H T"], Float[Array, "*B H T"]]
|
||||
| Float[Array, "*B H T"]
|
||||
)
|
||||
|
||||
|
||||
@enum.unique
|
||||
class SoftmaxResidualMode(enum.Enum):
|
||||
"""The mode of storing softmax residuals for the backwards pass.
|
||||
|
||||
The stable softmax calculation performs two reductions calculating:
|
||||
- the maximum input value (`x_max`),
|
||||
- the sum of exponentiated values (`denom`).
|
||||
|
||||
We can store these values as residuals to avoid the need to recompute them
|
||||
in the backwards pass.
|
||||
|
||||
It is also possible to combine the two residuals into a single residual,
|
||||
`res = x_max + log(denom)`, as `exp(x - res) === exp(x - x_max - log(denom))
|
||||
=== exp(x - x_max) / denom`. Combining the residuals reduces the memory usage
|
||||
of the residuals, but will reduce the accuracy of the backwards pass if
|
||||
`abs(x_max) >> log(denom)`.
|
||||
"""
|
||||
|
||||
SEPARATE = "separate"
|
||||
COMBINED = "combined"
|
||||
|
||||
def conform(self, aux: SoftmaxResidual) -> SoftmaxResidual | None:
|
||||
match self, aux:
|
||||
case None, _:
|
||||
return None
|
||||
case SoftmaxResidualMode.SEPARATE, (_, _):
|
||||
return aux
|
||||
case SoftmaxResidualMode.SEPARATE, _: # pytype: disable=redundant-match # b/300135240
|
||||
raise ValueError("`aux` has been combined.")
|
||||
case SoftmaxResidualMode.COMBINED, (x_max, denom):
|
||||
return x_max + jnp.log(denom)
|
||||
case SoftmaxResidualMode.COMBINED, _: # pytype: disable=redundant-match # b/300135240
|
||||
return aux
|
||||
|
||||
|
||||
class DotProductAttention(abc.ABC):
|
||||
"""Dot product attention function."""
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def __call__(
|
||||
self,
|
||||
query: Float[Array | array_view.ArrayView, "*B T H D"],
|
||||
key: Float[Array | array_view.ArrayView, "*B t h D"],
|
||||
value: Float[Array | array_view.ArrayView, "*B t h D"],
|
||||
*,
|
||||
precision: (
|
||||
DotPrecisionLike | tuple[DotPrecisionLike, DotPrecisionLike]
|
||||
) = jax.lax.Precision.DEFAULT,
|
||||
logits_dtype: DTypeLike | type[AUTO] = AUTO,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None = None,
|
||||
mask: Bool[Array, "*#B #H #T #t"] | Mask | None = None,
|
||||
q_indices: Int[Array, "*#B #H T"] | None = None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Performs scaled dot-product attention.
|
||||
|
||||
Scaled dot-product attention from "Attention is all you need"
|
||||
https://arxiv.org/abs/1706.03762.
|
||||
|
||||
Computes self- or cross-attention. The following is computed:
|
||||
softmax(qk_scale * query @ key^T + bias) @ value.
|
||||
|
||||
Supports both multi-head and multi-query attention
|
||||
(https://arxiv.org/abs/1911.02150).
|
||||
|
||||
Arguments:
|
||||
query: Query array of shape `[batch, seq_len_q, num_heads_q, head_dim]`.
|
||||
It must be a multiple of num_heads_kv.
|
||||
Here's an example of how q/kv heads are interleaved:
|
||||
For 8 key/value heads and 4 query heads:
|
||||
- key/value heads [0, 1] see query head 0
|
||||
- key/value heads [2, 3] see query head 1
|
||||
- key/value heads [4, 5] see query head 2
|
||||
key: Key array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`. It
|
||||
must be divisible by num_heads_q.
|
||||
value: Value array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`.
|
||||
precision: The precision for the dot products. Either a tuple `(
|
||||
query_key_dot_precision, weights_value_dot_precision)` or a single
|
||||
precision applied to both dot products.
|
||||
logits_dtype: Data type for attention logits (`query @ key^T`). If `AUTO`
|
||||
is passed (the default), the accumulator type from the `query @ key^T`
|
||||
dot product will be used.
|
||||
bias: Optional bias array, broadcastable to shape `[batch, num_heads,
|
||||
seq_len_q, seq_len_kv]`.
|
||||
mask: Optional boolean mask, broadcastable to `[batch, num_heads,
|
||||
seq_len_q, seq_len_kv]`. Attention weights are masked out if the
|
||||
corresponding mask value is `False`.
|
||||
q_indices: Optional indices for each token in query sequence.
|
||||
k_indices: Optional indices for each token in key/value sequence.
|
||||
|
||||
Returns:
|
||||
An array with the same shape as `query`.
|
||||
""" # fmt: skip
|
||||
return self.fwd(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
precision=precision,
|
||||
logits_dtype=logits_dtype,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
q_indices=q_indices,
|
||||
k_indices=k_indices,
|
||||
)
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def fwd(
|
||||
self,
|
||||
query: Float[Array | array_view.ArrayView, "*B T H D"],
|
||||
key: Float[Array | array_view.ArrayView, "*B t h D"],
|
||||
value: Float[Array | array_view.ArrayView, "*B t h D"],
|
||||
*,
|
||||
precision: (
|
||||
DotPrecisionLike | tuple[DotPrecisionLike, DotPrecisionLike]
|
||||
) = jax.lax.Precision.DEFAULT,
|
||||
logits_dtype: DTypeLike | type[AUTO] = AUTO,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None = None,
|
||||
mask: Bool[Array, "*#B #H #T #t"] | Mask | None = None,
|
||||
q_indices: Int[Array, "*#B #H T"] | None = None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Performs attention."""
|
||||
if not isinstance(precision, tuple):
|
||||
precision = (precision, precision)
|
||||
|
||||
q_k_dot_precision, weights_v_dot_precision = precision
|
||||
|
||||
if not isinstance(q_k_dot_precision, precision_lib.DotPrecision):
|
||||
q_k_dot_precision = precision_lib.get_equivalent_dot_precision(
|
||||
query.dtype, key.dtype, q_k_dot_precision
|
||||
)
|
||||
|
||||
if not isinstance(weights_v_dot_precision, precision_lib.DotPrecision):
|
||||
weights_v_dot_precision = precision_lib.get_equivalent_dot_precision(
|
||||
value.dtype, value.dtype, weights_v_dot_precision
|
||||
)
|
||||
|
||||
if logits_dtype is AUTO:
|
||||
logits_dtype = q_k_dot_precision.accumulator_dtype
|
||||
|
||||
if not isinstance(mask, Mask):
|
||||
mask = Mask(mask)
|
||||
|
||||
return self._fwd(
|
||||
array_view.as_array_view(query),
|
||||
array_view.as_array_view(key),
|
||||
array_view.as_array_view(value),
|
||||
q_k_dot_precision=q_k_dot_precision,
|
||||
logits_dtype=jnp.dtype(logits_dtype),
|
||||
logits_scale=1 / math.sqrt(query.shape[-1]),
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
weights_v_dot_precision=weights_v_dot_precision,
|
||||
q_indices=q_indices,
|
||||
k_indices=k_indices,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def _fwd(
|
||||
self,
|
||||
q: Float[array_view.ArrayView, "*B T H D"],
|
||||
k: Float[array_view.ArrayView, "*B t h D"],
|
||||
v: Float[array_view.ArrayView, "*B t h D"],
|
||||
*,
|
||||
q_k_dot_precision: precision_lib.DotPrecision,
|
||||
logits_dtype: jnp.dtype,
|
||||
logits_scale: float,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None,
|
||||
mask: Mask | None,
|
||||
weights_v_dot_precision: precision_lib.DotPrecision,
|
||||
q_indices: Int[Array, "*#B #H T"] | None = None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Performs attention."""
|
||||
...
|
||||
@@ -1,62 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Attention call argument specifications.
|
||||
|
||||
Attention argument specifications used by users of the library.
|
||||
They are the most important test cases, and also cases for optimize
|
||||
performance of via autotuning.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
|
||||
ShapedArray = jax.ShapeDtypeStruct
|
||||
|
||||
|
||||
def _make_argspec(
|
||||
*,
|
||||
q_shape,
|
||||
dtype,
|
||||
k_shape=None,
|
||||
v_shape=None,
|
||||
bias_shape=None,
|
||||
mask_shape=None,
|
||||
**kwargs,
|
||||
) -> dict[str, Any]:
|
||||
"""Make argspec from shapes and kwargs."""
|
||||
if k_shape is None:
|
||||
k_shape = q_shape
|
||||
if v_shape is None:
|
||||
v_shape = k_shape
|
||||
|
||||
return dict(
|
||||
query=ShapedArray(q_shape, dtype),
|
||||
key=ShapedArray(k_shape, dtype),
|
||||
value=ShapedArray(v_shape, dtype),
|
||||
bias=ShapedArray(bias_shape, dtype) if bias_shape is not None else None,
|
||||
mask=ShapedArray(mask_shape, 'bool_') if mask_shape is not None else None,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# A subset of the full set of argument specifications. Useful for tap-tests and
|
||||
# microbenchmarks.
|
||||
CALL_ARG_SPECS = dict(
|
||||
vanilla_f32=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='float32'),
|
||||
vanilla_bf16=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='bfloat16'),
|
||||
alphafold=_make_argspec(
|
||||
q_shape=(384, 384, 4, 32),
|
||||
bias_shape=(1, 4, 384, 384),
|
||||
mask_shape=(384, 1, 1, 384),
|
||||
dtype='bfloat16',
|
||||
),
|
||||
)
|
||||
@@ -1,703 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Triton FlashAttention implementation."""
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
|
||||
from alphafold3.jax.attention import attention_base as base
|
||||
from alphafold3.jax.common import array_view
|
||||
from alphafold3.jax.common import precision as precision_lib
|
||||
from alphafold3.jax.common import triton_utils
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jax_triton as jt
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Bool, Float, Int # pylint: disable=g-multiple-import,g-importing-member
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import typeguard
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _fwd_kernel_inner(
|
||||
start_loop,
|
||||
end_loop,
|
||||
q,
|
||||
span_q,
|
||||
k_block_ptr,
|
||||
v_block_ptr,
|
||||
bias_block_ptr,
|
||||
mask_block_ptr,
|
||||
k_start,
|
||||
k_end,
|
||||
seq_len_k,
|
||||
acc,
|
||||
m_i,
|
||||
l_i,
|
||||
bias_advance: tl.constexpr,
|
||||
mask_advance: tl.constexpr,
|
||||
is_causal: tl.constexpr,
|
||||
use_attention_mask: tl.constexpr,
|
||||
use_k_start: tl.constexpr,
|
||||
use_k_end: tl.constexpr,
|
||||
use_bias: tl.constexpr,
|
||||
block_k: tl.constexpr,
|
||||
use_mask_k: tl.constexpr,
|
||||
k_boundary_check: tl.constexpr,
|
||||
v_boundary_check: tl.constexpr,
|
||||
dot_fn_qk: tl.constexpr,
|
||||
dot_fn_kv: tl.constexpr,
|
||||
):
|
||||
"""Triton MHA forward kernel's inner loop."""
|
||||
|
||||
for start_k in range(start_loop, end_loop, block_k):
|
||||
start_k = tl.multiple_of(start_k, block_k)
|
||||
span_k = start_k + tl.arange(0, block_k)
|
||||
|
||||
k = tl.load(
|
||||
k_block_ptr,
|
||||
boundary_check=k_boundary_check,
|
||||
padding_option="zero" if len(k_boundary_check.value) else "",
|
||||
)
|
||||
v = tl.load(
|
||||
v_block_ptr,
|
||||
boundary_check=v_boundary_check,
|
||||
padding_option="zero" if len(v_boundary_check.value) else "",
|
||||
)
|
||||
|
||||
if use_bias:
|
||||
bias = tl.load(bias_block_ptr)
|
||||
|
||||
qk = dot_fn_qk(q.to(k.dtype), k) # [block_q, block_k]
|
||||
|
||||
if use_bias:
|
||||
# Prevent dot accumulating into the bias tensor. It appears that Triton
|
||||
# doesn't pipeline the bias load as it does the `k` load, so the bias load
|
||||
# blocks the matmul if the add is merged.
|
||||
qk = qk.to(tl.uint32, bitcast=True) & 0xFFFFFFFF
|
||||
qk = qk.to(tl.float32, bitcast=True)
|
||||
qk += bias
|
||||
|
||||
if use_attention_mask | use_k_start | use_k_end:
|
||||
mask_value = float(jnp.finfo(jnp.float32).min)
|
||||
|
||||
if use_attention_mask:
|
||||
mask = tl.load(mask_block_ptr)
|
||||
qk = tl.where(mask, qk, mask_value)
|
||||
|
||||
if use_k_start:
|
||||
# This check is there to work around a triton compiler bug, but it
|
||||
# shouldn't be strictly needed.
|
||||
if tl.sum(k_start) != 0:
|
||||
qk = tl.where(k_start[:, None] <= span_k[None, :], qk, mask_value)
|
||||
if is_causal:
|
||||
qk = tl.where(span_q[:, None] >= span_k[None, :], qk, float("-inf"))
|
||||
elif use_k_end:
|
||||
# When called with k_end and is_causal=True, the causal mask gets folded
|
||||
# into k_end and is_causal is set to False.
|
||||
qk = tl.where(k_end[:, None] > span_k[None, :], qk, mask_value)
|
||||
|
||||
if use_mask_k:
|
||||
qk = tl.where((span_k < seq_len_k)[None, :], qk, float("-inf"))
|
||||
|
||||
m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) # Shape [block_q].
|
||||
p = tl.exp(qk - m_ij[:, None]) # Shape [block_q, block_k].
|
||||
alpha = tl.exp(m_i - m_ij)
|
||||
m_i = m_ij
|
||||
acc *= alpha[:, None]
|
||||
l_i *= alpha
|
||||
l_i += tl.sum(p, axis=1)
|
||||
|
||||
# Add the new block of attention weights.
|
||||
acc += dot_fn_kv(p.to(v.dtype), v)
|
||||
|
||||
k_block_ptr = tl.advance(k_block_ptr, (0, block_k))
|
||||
v_block_ptr = tl.advance(v_block_ptr, (block_k, 0))
|
||||
bias_block_ptr = tl.advance(bias_block_ptr, bias_advance.value)
|
||||
mask_block_ptr = tl.advance(mask_block_ptr, mask_advance.value)
|
||||
|
||||
return (
|
||||
k_block_ptr,
|
||||
v_block_ptr,
|
||||
bias_block_ptr,
|
||||
mask_block_ptr,
|
||||
acc,
|
||||
m_i,
|
||||
l_i,
|
||||
)
|
||||
|
||||
|
||||
# Based on Algorithm 1 of https://arxiv.org/abs/2205.14135.
|
||||
# Inspired by the official Triton tutorial implementation
|
||||
# https://triton-lang.org/main/getting-started/tutorials/06-fused-attention.html
|
||||
@triton.jit
|
||||
def _fwd_kernel(
|
||||
# Input arrays.
|
||||
q_ptr,
|
||||
k_ptr,
|
||||
v_ptr,
|
||||
bias_ptr,
|
||||
mask_ptr,
|
||||
k_start_ptr,
|
||||
k_end_ptr,
|
||||
# Scalar inputs.
|
||||
q_offset,
|
||||
k_offset,
|
||||
v_offset,
|
||||
q_stride_b,
|
||||
q_stride_s,
|
||||
q_stride_h,
|
||||
q_stride_d,
|
||||
k_stride_b,
|
||||
k_stride_s,
|
||||
k_stride_h,
|
||||
k_stride_d,
|
||||
v_stride_b,
|
||||
v_stride_s,
|
||||
v_stride_h,
|
||||
v_stride_d,
|
||||
bias_stride_b,
|
||||
bias_stride_h,
|
||||
bias_stride_sq,
|
||||
bias_stride_sk,
|
||||
mask_stride_b,
|
||||
mask_stride_h,
|
||||
mask_stride_sq,
|
||||
mask_stride_sk,
|
||||
k_start_stride_b,
|
||||
k_start_stride_h,
|
||||
k_start_stride_sq,
|
||||
k_end_stride_b,
|
||||
k_end_stride_h,
|
||||
k_end_stride_sq,
|
||||
o_stride_b,
|
||||
o_stride_s,
|
||||
o_stride_h,
|
||||
o_stride_d,
|
||||
num_heads_q,
|
||||
num_heads_k,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
# Output arrays.
|
||||
o_ptr,
|
||||
# Compile-time constants.
|
||||
is_causal: tl.constexpr,
|
||||
use_attention_mask: tl.constexpr,
|
||||
use_k_start: tl.constexpr,
|
||||
use_k_end: tl.constexpr,
|
||||
use_bias: tl.constexpr,
|
||||
sm_scale: tl.constexpr,
|
||||
block_q: tl.constexpr,
|
||||
block_k: tl.constexpr,
|
||||
head_dim: tl.constexpr,
|
||||
use_mask_q: tl.constexpr,
|
||||
use_mask_k: tl.constexpr,
|
||||
bias_bcast_sq: tl.constexpr,
|
||||
mask_bcast_sq: tl.constexpr,
|
||||
dot_fn_qk: tl.constexpr,
|
||||
dot_fn_kv: tl.constexpr,
|
||||
):
|
||||
"""Triton MHA forward kernel."""
|
||||
# pytype: disable=annotation-type-mismatch,unsupported-operands
|
||||
block_d: tl.constexpr = jt.utils.next_power_of_2(head_dim.value)
|
||||
|
||||
# Each thread block processes one batch element (b) and one head (h).
|
||||
start_q = tl.program_id(1) * block_q
|
||||
off_h = tl.program_id(0) # int in [0, num_heads_o).
|
||||
off_b = tl.program_id(2) # int in [0, batch_size)
|
||||
|
||||
off_h_k = off_h // (num_heads_q // num_heads_k)
|
||||
|
||||
q_ptr += off_h * q_stride_h + off_b * q_stride_b + q_offset
|
||||
k_ptr += off_h_k * k_stride_h + off_b * k_stride_b + k_offset
|
||||
v_ptr += off_h_k * v_stride_h + off_b * v_stride_b + v_offset
|
||||
o_ptr += off_h * o_stride_h + off_b * o_stride_b
|
||||
|
||||
if use_bias:
|
||||
bias_ptr += off_b * bias_stride_b + off_h * bias_stride_h
|
||||
if use_attention_mask:
|
||||
mask_ptr += off_b * mask_stride_b + off_h * mask_stride_h
|
||||
if use_k_start:
|
||||
k_start_ptr += off_b * k_start_stride_b + off_h * k_start_stride_h
|
||||
if use_k_end:
|
||||
k_end_ptr += off_b * k_end_stride_b + off_h * k_end_stride_h
|
||||
|
||||
q_block_ptr = tl.make_block_ptr(
|
||||
q_ptr,
|
||||
shape=(seq_len_q, head_dim),
|
||||
strides=(q_stride_s, q_stride_d),
|
||||
offsets=(start_q, 0),
|
||||
block_shape=(block_q, block_d),
|
||||
order=(1, 0),
|
||||
)
|
||||
k_block_ptr = tl.make_block_ptr(
|
||||
k_ptr,
|
||||
shape=(head_dim, seq_len_k),
|
||||
strides=(k_stride_d, k_stride_s),
|
||||
offsets=(0, 0),
|
||||
block_shape=(block_d, block_k),
|
||||
order=(0, 1),
|
||||
)
|
||||
v_block_ptr = tl.make_block_ptr(
|
||||
v_ptr,
|
||||
shape=(seq_len_k, head_dim),
|
||||
strides=(v_stride_s, v_stride_d),
|
||||
offsets=(0, 0),
|
||||
block_shape=(block_k, block_d),
|
||||
order=(1, 0),
|
||||
)
|
||||
|
||||
q_boundary_check0: tl.constexpr = (0,) if use_mask_q else ()
|
||||
q_boundary_check1: tl.constexpr = (1,) if head_dim != block_d else ()
|
||||
q_boundary_check: tl.constexpr = q_boundary_check0 + q_boundary_check1
|
||||
q_padding_option: tl.constexpr = "zero" if len(q_boundary_check.value) else ""
|
||||
k_boundary_check: tl.constexpr = (0,) if head_dim != block_d else ()
|
||||
v_boundary_check: tl.constexpr = (0,) if use_mask_k else ()
|
||||
|
||||
# If broadcasting in a given dim, use a 1D block (observed to be faster).
|
||||
bias_start_dim: tl.constexpr = 1 if bias_bcast_sq else 0
|
||||
bias_block_ptr = tl.make_block_ptr(
|
||||
bias_ptr,
|
||||
shape=(seq_len_q, seq_len_k)[bias_start_dim:],
|
||||
strides=(bias_stride_sq, bias_stride_sk)[bias_start_dim:],
|
||||
offsets=(start_q, 0)[bias_start_dim:],
|
||||
block_shape=(block_q, block_k)[bias_start_dim:],
|
||||
order=(1, 0)[bias_start_dim:],
|
||||
)
|
||||
bias_advance: tl.constexpr = (0, block_k)[bias_start_dim:]
|
||||
|
||||
mask_start_dim: tl.constexpr = 1 if mask_bcast_sq else 0
|
||||
mask_block_ptr = tl.make_block_ptr(
|
||||
mask_ptr,
|
||||
shape=(seq_len_q, seq_len_k)[mask_start_dim:],
|
||||
strides=(mask_stride_sq, mask_stride_sk)[mask_start_dim:],
|
||||
offsets=(start_q, 0)[mask_start_dim:],
|
||||
block_shape=(block_q, block_k)[mask_start_dim:],
|
||||
order=(1, 0)[mask_start_dim:],
|
||||
)
|
||||
mask_advance: tl.constexpr = (0, block_k)[mask_start_dim:]
|
||||
|
||||
k_start_block_ptr = tl.make_block_ptr(
|
||||
k_start_ptr,
|
||||
shape=(seq_len_q,),
|
||||
strides=(k_start_stride_sq,),
|
||||
offsets=(start_q,),
|
||||
block_shape=(block_q,),
|
||||
order=(0,),
|
||||
)
|
||||
k_end_block_ptr = tl.make_block_ptr(
|
||||
k_end_ptr,
|
||||
shape=(seq_len_q,),
|
||||
strides=(k_end_stride_sq,),
|
||||
offsets=(start_q,),
|
||||
block_shape=(block_q,),
|
||||
order=(0,),
|
||||
)
|
||||
# pytype: enable=annotation-type-mismatch,unsupported-operands
|
||||
|
||||
# Each thread block processes a block of block_q queries.
|
||||
span_q = start_q + tl.arange(0, block_q)
|
||||
|
||||
# m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
|
||||
m_i = tl.full([block_q], float("-inf"), dtype=tl.float32)
|
||||
l_i = tl.zeros([block_q], dtype=tl.float32)
|
||||
# acc is the buffer where we accumulate the output on sram.
|
||||
acc = tl.zeros([block_q, block_d], dtype=tl.float32)
|
||||
|
||||
# Load q: it will stay in smem throughout. Indices form a matrix because we
|
||||
# read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
|
||||
q = tl.load(
|
||||
q_block_ptr,
|
||||
boundary_check=q_boundary_check,
|
||||
padding_option=q_padding_option,
|
||||
)
|
||||
q *= sm_scale
|
||||
|
||||
# In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
|
||||
# (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
|
||||
# Here we only loop over blocks of kv to process entire seq_len, the loop over
|
||||
# blocks of q is carried out by the grid.
|
||||
k_start = None
|
||||
if use_k_start:
|
||||
k_start = tl.load(k_start_block_ptr)
|
||||
start_loop = tl.maximum(tl.min(k_start), 0)
|
||||
blocks_to_skip = start_loop // block_k
|
||||
start_loop = block_k * blocks_to_skip # Floor to multiple of block_k.
|
||||
for _ in range(blocks_to_skip):
|
||||
# Advance all block pointers to the first valid block.
|
||||
k_block_ptr = tl.advance(k_block_ptr, (0, block_k))
|
||||
v_block_ptr = tl.advance(v_block_ptr, (block_k, 0))
|
||||
bias_block_ptr = tl.advance(bias_block_ptr, bias_advance.value)
|
||||
mask_block_ptr = tl.advance(mask_block_ptr, mask_advance.value)
|
||||
else:
|
||||
start_loop = 0
|
||||
|
||||
k_end = None
|
||||
if is_causal:
|
||||
end_loop = tl.minimum((start_q // block_k) * block_k, seq_len_k)
|
||||
elif use_k_end:
|
||||
k_end = tl.load(k_end_block_ptr)
|
||||
end_loop = tl.minimum(tl.max(k_end), seq_len_k)
|
||||
else:
|
||||
end_loop = seq_len_k
|
||||
|
||||
(
|
||||
k_block_ptr,
|
||||
v_block_ptr,
|
||||
bias_block_ptr,
|
||||
mask_block_ptr,
|
||||
acc,
|
||||
m_i,
|
||||
l_i,
|
||||
) = _fwd_kernel_inner(
|
||||
start_loop,
|
||||
end_loop,
|
||||
q,
|
||||
span_q,
|
||||
k_block_ptr,
|
||||
v_block_ptr,
|
||||
bias_block_ptr,
|
||||
mask_block_ptr,
|
||||
k_start,
|
||||
k_end,
|
||||
seq_len_k,
|
||||
acc,
|
||||
m_i,
|
||||
l_i,
|
||||
bias_advance,
|
||||
mask_advance,
|
||||
False, # is_causal
|
||||
use_attention_mask,
|
||||
use_k_start,
|
||||
use_k_end,
|
||||
use_bias,
|
||||
block_k,
|
||||
use_mask_k,
|
||||
k_boundary_check,
|
||||
v_boundary_check,
|
||||
dot_fn_qk,
|
||||
dot_fn_kv,
|
||||
)
|
||||
|
||||
if is_causal:
|
||||
tl.debug_barrier() # Help compiler schedule loops independently.
|
||||
start_loop, end_loop = end_loop, tl.minimum(end_loop + block_k, seq_len_k)
|
||||
|
||||
_, _, _, _, acc, _, l_i = _fwd_kernel_inner(
|
||||
start_loop,
|
||||
end_loop,
|
||||
q,
|
||||
span_q,
|
||||
k_block_ptr,
|
||||
v_block_ptr,
|
||||
bias_block_ptr,
|
||||
mask_block_ptr,
|
||||
k_start,
|
||||
k_end,
|
||||
seq_len_k,
|
||||
acc,
|
||||
m_i,
|
||||
l_i,
|
||||
bias_advance,
|
||||
mask_advance,
|
||||
True, # is_causal
|
||||
use_attention_mask,
|
||||
use_k_start,
|
||||
use_k_end,
|
||||
use_bias,
|
||||
block_k,
|
||||
use_mask_k,
|
||||
k_boundary_check,
|
||||
v_boundary_check,
|
||||
dot_fn_qk,
|
||||
dot_fn_kv,
|
||||
)
|
||||
|
||||
# It is possible that every value in a row was masked to f32 min or that the
|
||||
# main loop has been completely optimised out, and that `l_i` is `0` for that
|
||||
# row. Add epsilon value to avoid NaNs from `0 / 0`.
|
||||
l_i += float(jnp.finfo(jnp.float32).tiny)
|
||||
|
||||
acc /= l_i[:, None]
|
||||
|
||||
# Write output to dram.
|
||||
o_block_ptr = tl.make_block_ptr(
|
||||
o_ptr,
|
||||
shape=(seq_len_q, head_dim),
|
||||
strides=(o_stride_s, o_stride_d),
|
||||
offsets=(start_q, 0),
|
||||
block_shape=(block_q, block_d),
|
||||
order=(1, 0),
|
||||
)
|
||||
acc = acc.to(o_ptr.dtype.element_ty)
|
||||
tl.store(o_block_ptr, acc, boundary_check=q_boundary_check)
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def _fwd(
|
||||
q: Float[array_view.ArrayView, "*B T H D"],
|
||||
k: Float[array_view.ArrayView, "*B t h D"],
|
||||
v: Float[array_view.ArrayView, "*B t h D"],
|
||||
bias: Float[Array, "*#B #H #T #t"] | None,
|
||||
mask: Bool[Array, "*#B #H #T #t"] | None,
|
||||
k_start: Int[Array, "*#B #H #T"] | None,
|
||||
k_end: Int[Array, "*#B #H #T"] | None,
|
||||
*,
|
||||
logits_scale: float,
|
||||
is_causal: bool,
|
||||
q_k_dot_precision: precision_lib.DotPrecision,
|
||||
weights_v_dot_precision: precision_lib.DotPrecision,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Forward pass of Triton FlashAttention."""
|
||||
|
||||
orig_q_shape = q.shape
|
||||
q = q.collapse(0, -3, allow_copy=True)
|
||||
batch_size, seq_len_q, num_heads_q, head_dim = q.shape
|
||||
*_, seq_len_k, num_heads_kv, _ = k.shape
|
||||
# Maybe broadcast `k`/`v` heads dimension.
|
||||
kv_shape = (batch_size, seq_len_k, num_heads_kv, head_dim)
|
||||
k = k.collapse(0, -3, allow_copy=True).broadcast_to(kv_shape)
|
||||
v = v.collapse(0, -3, allow_copy=True).broadcast_to(kv_shape)
|
||||
|
||||
def get_bias_mask_view(x, dtype):
|
||||
if x is None:
|
||||
x = jnp.array([], dtype=dtype)
|
||||
return array_view.ArrayView(x, shape=(0, 0, 0, 0), strides=(0, 0, 0, 0))
|
||||
|
||||
shape = orig_q_shape[:-3] + (num_heads_q, seq_len_q, seq_len_k)
|
||||
return (
|
||||
array_view.ArrayView(x)
|
||||
.broadcast_to(shape)
|
||||
.collapse(0, -3, allow_copy=True)
|
||||
)
|
||||
|
||||
bias = get_bias_mask_view(bias, dtype=q.dtype)
|
||||
mask = get_bias_mask_view(mask, dtype=jnp.bool_)
|
||||
|
||||
def get_range_view(x, seq_len):
|
||||
if x is None:
|
||||
x = jnp.array([], dtype=jnp.int32)
|
||||
return array_view.ArrayView(x, shape=(0, 0, 0), strides=(0, 0, 0))
|
||||
|
||||
shape = orig_q_shape[:-3] + (num_heads_q, seq_len)
|
||||
return (
|
||||
array_view.ArrayView(x)
|
||||
.broadcast_to(shape)
|
||||
.collapse(0, -2, allow_copy=True)
|
||||
)
|
||||
|
||||
k_start = get_range_view(k_start, seq_len_q)
|
||||
k_end = get_range_view(k_end, seq_len_q)
|
||||
|
||||
block_q = 64
|
||||
block_k = 64
|
||||
|
||||
return jt.triton_call(
|
||||
q.base,
|
||||
k.base,
|
||||
v.base,
|
||||
bias.base,
|
||||
mask.base,
|
||||
k_start.base,
|
||||
k_end.base,
|
||||
q.offset,
|
||||
k.offset,
|
||||
v.offset,
|
||||
*q.strides,
|
||||
*k.strides,
|
||||
*v.strides,
|
||||
*bias.strides,
|
||||
*mask.strides,
|
||||
k_start.strides,
|
||||
k_end.strides,
|
||||
*jt.utils.strides_from_shape(q.shape), # out strides.
|
||||
num_heads_q,
|
||||
num_heads_kv,
|
||||
seq_len_q,
|
||||
seq_len_k,
|
||||
kernel=_fwd_kernel,
|
||||
name="triton_flash_attention",
|
||||
out_shape=jax.ShapeDtypeStruct(q.shape, q.dtype),
|
||||
grid=(num_heads_q, triton.cdiv(seq_len_q, block_q), batch_size),
|
||||
num_stages=2,
|
||||
num_warps=4,
|
||||
is_causal=is_causal,
|
||||
use_attention_mask=(mask.size != 0),
|
||||
use_k_start=(k_start.size != 0),
|
||||
use_k_end=(k_end.size != 0),
|
||||
use_bias=(bias.size != 0),
|
||||
sm_scale=logits_scale,
|
||||
block_q=block_q,
|
||||
block_k=block_k,
|
||||
head_dim=head_dim,
|
||||
use_mask_q=(seq_len_q % block_q != 0),
|
||||
use_mask_k=(seq_len_k % block_q != 0),
|
||||
bias_bcast_sq=(bias.strides[-2] == 0),
|
||||
mask_bcast_sq=(mask.strides[-2] == 0),
|
||||
dot_fn_qk=triton_utils.get_tl_dot_fn(q_k_dot_precision),
|
||||
dot_fn_kv=triton_utils.get_tl_dot_fn(weights_v_dot_precision),
|
||||
).reshape(orig_q_shape)
|
||||
|
||||
|
||||
def _as_batched_array_view(x, axis_size):
|
||||
batched_shape = (axis_size,) + x.shape
|
||||
batched_strides = (x.base.size // axis_size,) + x.strides
|
||||
return dataclasses.replace(x, shape=batched_shape, strides=batched_strides)
|
||||
|
||||
|
||||
def _fwd_vmap_rule(
|
||||
axis_size, in_batched, *args, fn: jax.custom_batching.custom_vmap
|
||||
):
|
||||
"""`vmap` rule for Triton FlashAttention forward op."""
|
||||
q, k, v, bias, mask, k_start, k_end = args
|
||||
(
|
||||
q_batched,
|
||||
k_batched,
|
||||
v_batched,
|
||||
bias_batched,
|
||||
mask_batched,
|
||||
k_start_batched,
|
||||
k_end_batched,
|
||||
) = in_batched
|
||||
|
||||
if q_batched.base:
|
||||
q = _as_batched_array_view(q, axis_size)
|
||||
if k_batched.base:
|
||||
k = _as_batched_array_view(k, axis_size)
|
||||
if v_batched.base:
|
||||
v = _as_batched_array_view(v, axis_size)
|
||||
|
||||
# Triton op requires `q`, `k`, `v` batch dims to be identical.
|
||||
if q_batched.base and k_batched.base and v_batched.base:
|
||||
if bias is not None and not bias_batched:
|
||||
bias = jax.lax.broadcast_to_rank(bias, bias.ndim + 1)
|
||||
if mask is not None and not mask_batched:
|
||||
mask = jax.lax.broadcast_to_rank(mask, mask.ndim + 1)
|
||||
if k_start is not None and not k_start_batched:
|
||||
k_start = jax.lax.broadcast_to_rank(k_start, k_start.ndim + 1)
|
||||
if k_end is not None and not k_end_batched:
|
||||
k_end = jax.lax.broadcast_to_rank(k_end, k_end.ndim + 1)
|
||||
out = fn(q, k, v, bias, mask, k_start, k_end)
|
||||
out_batched = True
|
||||
return out, out_batched
|
||||
|
||||
# Fallback to sequential loop.
|
||||
q, k, v = map(jnp.asarray, (q, k, v))
|
||||
in_batched = [
|
||||
q_batched.base,
|
||||
k_batched.base,
|
||||
v_batched.base,
|
||||
bias_batched,
|
||||
mask_batched,
|
||||
k_start_batched,
|
||||
k_end_batched,
|
||||
]
|
||||
|
||||
def f(q, k, v, *args, **kwargs):
|
||||
q, k, v = map(array_view.ArrayView, (q, k, v))
|
||||
return fn.fun(q, k, v, *args, **kwargs)
|
||||
|
||||
sequential_vmap = jax.custom_batching.sequential_vmap(f)
|
||||
return sequential_vmap.vmap_rule(axis_size, in_batched, q, k, v, *args[3:])
|
||||
|
||||
|
||||
def _decompose_mask(mask, q, k, q_indices, k_indices):
|
||||
"""Decomposes `mask` into a mask array, `is_causal`, `k_start` and `k_end`."""
|
||||
if mask is None:
|
||||
return None, False, None, None
|
||||
|
||||
is_causal = False
|
||||
k_start = None
|
||||
k_end = None
|
||||
if q_indices is None and k_indices is None:
|
||||
mask, is_causal, k_start, k_end = mask.take("is_causal", "k_start", "k_end")
|
||||
if k_start is not None:
|
||||
k_start = jax.lax.broadcast_to_rank(k_start, 2)
|
||||
if k_end is not None:
|
||||
k_end = jax.lax.broadcast_to_rank(k_end, 2)
|
||||
if is_causal: # Fold is_causal into k_end
|
||||
k_end = jnp.minimum(k_end, jnp.arange(1, q.shape[-3] + 1))
|
||||
is_causal = False
|
||||
|
||||
q_len_or_indices = q.shape[-3] if q_indices is None else q_indices
|
||||
k_len_or_indices = k.shape[-3] if k_indices is None else k_indices
|
||||
return (
|
||||
mask.as_array(q_len_or_indices, k_len_or_indices),
|
||||
is_causal,
|
||||
k_start,
|
||||
k_end,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class TritonFlashAttention(base.DotProductAttention):
|
||||
"""Triton FlashAttention implementation."""
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def _fwd(
|
||||
self,
|
||||
q: Float[array_view.ArrayView, "*B T H D"],
|
||||
k: Float[array_view.ArrayView, "*B t h D"],
|
||||
v: Float[array_view.ArrayView, "*B t h D"],
|
||||
bias: Float[Array, "*#B #H #T #t"] | None,
|
||||
*,
|
||||
q_k_dot_precision: precision_lib.DotPrecision,
|
||||
logits_dtype: jnp.dtype,
|
||||
logits_scale: float,
|
||||
mask: base.Mask | None,
|
||||
weights_v_dot_precision: precision_lib.DotPrecision,
|
||||
q_indices: Int[Array, "*#B #H T"] | None = None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
if logits_dtype != jnp.float32:
|
||||
raise ValueError("`logits_dtype` must be float32.")
|
||||
|
||||
kwargs = dict(
|
||||
logits_scale=logits_scale,
|
||||
q_k_dot_precision=q_k_dot_precision,
|
||||
weights_v_dot_precision=weights_v_dot_precision,
|
||||
)
|
||||
|
||||
def attend_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
mask_,
|
||||
q_indices,
|
||||
k_indices,
|
||||
):
|
||||
|
||||
mask, is_causal, k_start, k_end = _decompose_mask(
|
||||
mask_, q, k, q_indices, k_indices
|
||||
)
|
||||
|
||||
fwd_closed_kwargs = dict(
|
||||
is_causal=is_causal,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
fwd_closed = functools.partial(_fwd, **fwd_closed_kwargs)
|
||||
fwd_closed = jax.custom_batching.custom_vmap(fwd_closed)
|
||||
fwd_closed.def_vmap(functools.partial(_fwd_vmap_rule, fn=fwd_closed))
|
||||
|
||||
return fwd_closed(q, k, v, bias, mask, k_start, k_end)
|
||||
|
||||
return attend_fwd(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias,
|
||||
mask,
|
||||
q_indices,
|
||||
k_indices,
|
||||
)
|
||||
@@ -1,140 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""XLA implementation of scaled dot-product attention."""
|
||||
|
||||
import dataclasses
|
||||
|
||||
from alphafold3.jax.attention import attention_base as base
|
||||
from alphafold3.jax.common import array_view
|
||||
from alphafold3.jax.common import precision as precision_lib
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Float, Int # pylint: disable=g-multiple-import,g-importing-member
|
||||
import typeguard
|
||||
|
||||
|
||||
def _get_precision(
|
||||
backend: str, precision: precision_lib.DotPrecision
|
||||
) -> jax.lax.Precision:
|
||||
if backend == "gpu" and precision == precision_lib.DotPrecision.F32_F32:
|
||||
return jax.lax.Precision.HIGHEST
|
||||
return jax.lax.Precision.DEFAULT
|
||||
|
||||
|
||||
def einsum_with_dot_precision(
|
||||
subscript: str,
|
||||
a: jax.Array,
|
||||
b: jax.Array,
|
||||
*,
|
||||
precision: precision_lib.DotPrecision,
|
||||
) -> jax.Array:
|
||||
"""Evaluate `fn` with the given precision."""
|
||||
result = jnp.einsum(
|
||||
subscript,
|
||||
a.astype(precision.operand_dtype),
|
||||
b.astype(precision.operand_dtype),
|
||||
precision=_get_precision(jax.default_backend().lower(), precision),
|
||||
preferred_element_type=precision.accumulator_dtype,
|
||||
)
|
||||
assert result.dtype == precision.accumulator_dtype
|
||||
return result
|
||||
|
||||
|
||||
def _softmax(x: jax.Array) -> jax.Array:
|
||||
"""Computes softmax."""
|
||||
# Always perform reductions in at least f32 precision.
|
||||
dtype = jnp.promote_types(x.dtype, jnp.float32)
|
||||
x_max = jnp.max(x.astype(dtype), axis=-1, keepdims=True)
|
||||
unnormalized = jnp.exp(x - x_max)
|
||||
denom = jnp.sum(unnormalized, axis=-1, keepdims=True)
|
||||
return unnormalized / denom
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def _attend(
|
||||
q: Float[array_view.ArrayView, "*B T H D"],
|
||||
k: Float[array_view.ArrayView, "*B t #H D"],
|
||||
v: Float[array_view.ArrayView, "*B t #H D"],
|
||||
*,
|
||||
q_k_dot_precision: precision_lib.DotPrecision,
|
||||
logits_dtype: jnp.dtype,
|
||||
logits_scale: float,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None,
|
||||
mask: base.Mask | None,
|
||||
weights_v_dot_precision: precision_lib.DotPrecision,
|
||||
q_indices: Int[Array, "*#B #H T"] | None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
"""Computes attention."""
|
||||
logits = einsum_with_dot_precision(
|
||||
"...qhd,...khd->...hqk", q, k, precision=q_k_dot_precision
|
||||
).astype(logits_dtype)
|
||||
|
||||
logits *= logits_scale
|
||||
|
||||
if bias is not None:
|
||||
logits += bias
|
||||
|
||||
if mask is not None:
|
||||
q_len_or_indices = q.shape[-3] if q_indices is None else q_indices
|
||||
k_len_or_indices = k.shape[-3] if k_indices is None else k_indices
|
||||
mask = mask.as_array(q_len_or_indices, k_len_or_indices)
|
||||
|
||||
if mask is not None:
|
||||
mask_value = float(jnp.finfo(logits.dtype).min)
|
||||
|
||||
logits = jnp.where(jnp.asarray(mask), logits, mask_value)
|
||||
|
||||
weights = _softmax(logits)
|
||||
|
||||
weights = weights.astype(v.dtype)
|
||||
out = einsum_with_dot_precision(
|
||||
"...hqk,...khd->...qhd", weights, v, precision=weights_v_dot_precision
|
||||
).astype(q.dtype)
|
||||
return out
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class XlaDotProductAttention(base.DotProductAttention):
|
||||
"""XLA dot product attention function."""
|
||||
|
||||
_: dataclasses.KW_ONLY
|
||||
|
||||
def _fwd(
|
||||
self,
|
||||
q: Float[array_view.ArrayView, "*B T H D"],
|
||||
k: Float[array_view.ArrayView, "*B t #H D"],
|
||||
v: Float[array_view.ArrayView, "*B t #H D"],
|
||||
*,
|
||||
q_k_dot_precision: precision_lib.DotPrecision,
|
||||
logits_dtype: jnp.dtype,
|
||||
logits_scale: float,
|
||||
bias: Float[Array, "*#B #H #T #t"] | None,
|
||||
mask: base.Mask | None,
|
||||
weights_v_dot_precision: precision_lib.DotPrecision,
|
||||
q_indices: Int[Array, "*#B #H T"] | None = None,
|
||||
k_indices: Int[Array, "*#B #H t"] | None = None,
|
||||
) -> Float[Array, "*B T H D"]:
|
||||
|
||||
return _attend(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
bias=bias,
|
||||
mask=mask,
|
||||
q_indices=q_indices,
|
||||
k_indices=k_indices,
|
||||
q_k_dot_precision=q_k_dot_precision,
|
||||
logits_dtype=logits_dtype,
|
||||
logits_scale=logits_scale,
|
||||
weights_v_dot_precision=weights_v_dot_precision,
|
||||
)
|
||||
@@ -1,405 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Array view class and utilities."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
import dataclasses
|
||||
import math
|
||||
import operator
|
||||
from types import EllipsisType # pylint: disable=g-importing-member
|
||||
from typing import Any, Self, TypeAlias, TypeVar
|
||||
|
||||
import jax
|
||||
import jax.experimental
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
from jax.typing import ArrayLike # pylint: disable=g-importing-member
|
||||
from jaxtyping import Int # pylint: disable=g-importing-member
|
||||
import numpy as np
|
||||
|
||||
ArrayT: TypeAlias = Any
|
||||
ScalarInt: TypeAlias = (
|
||||
Int[ArrayT, ""] | Int[np.generic, ""] | Int[jnp.generic, ""]
|
||||
)
|
||||
|
||||
Indexer: TypeAlias = int | ScalarInt | slice | pl.Slice | EllipsisType
|
||||
|
||||
|
||||
@jax.tree_util.register_pytree_node_class
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ArrayView:
|
||||
"""A strided view of a JAX array."""
|
||||
|
||||
base: jax.Array
|
||||
_: dataclasses.KW_ONLY
|
||||
# These are set by `__post_init__` so `None` value is never seen after init.
|
||||
shape: tuple[int, ...] = None # type: ignore
|
||||
strides: tuple[int, ...] = None # type: ignore
|
||||
offset: int | ScalarInt = 0
|
||||
flatten_base: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
if self.shape is None:
|
||||
object.__setattr__(self, "shape", self.base.shape)
|
||||
|
||||
if self.strides is None:
|
||||
object.__setattr__(self, "strides", pl.strides_from_shape(self.shape))
|
||||
|
||||
if len(self.shape) != len(self.strides):
|
||||
raise ValueError("`shape` and `strides` must have the same length.")
|
||||
|
||||
# Within `jax.vjp`, we can get non-`Array` values here (such as `object`).
|
||||
if isinstance(self.base, jax.Array):
|
||||
if isinstance(self.offset, int):
|
||||
if not (0 <= self.offset < max(self.base.size, 1)):
|
||||
raise ValueError("Invalid `offset`.")
|
||||
|
||||
if self.flatten_base:
|
||||
if len(self.base.shape) != 1:
|
||||
object.__setattr__(self, "base", self.base.reshape((-1,)))
|
||||
|
||||
def tree_flatten(self):
|
||||
if isinstance(self.offset, int):
|
||||
return (self.base,), (self.offset, self.shape, self.strides)
|
||||
return (self.base, self.offset), (self.shape, self.strides)
|
||||
|
||||
@classmethod
|
||||
def tree_unflatten(cls, aux, children) -> Self:
|
||||
base, offset, shape, strides = (*children, *aux)
|
||||
return cls(base, shape=shape, strides=strides, offset=offset)
|
||||
|
||||
@property
|
||||
def dtype(self) -> jnp.dtype:
|
||||
return self.base.dtype
|
||||
|
||||
@property
|
||||
def size(self) -> int:
|
||||
return math.prod(self.shape)
|
||||
|
||||
@property
|
||||
def ndim(self) -> int:
|
||||
return len(self.shape)
|
||||
|
||||
@property
|
||||
def T(self) -> Self: # pylint: disable=invalid-name
|
||||
return self.transpose()
|
||||
|
||||
@property
|
||||
def _index_dtype(self) -> jax.typing.DTypeLike:
|
||||
i32_max = jnp.iinfo(jnp.int32).max
|
||||
return jnp.int32 if (self.base.size <= i32_max) else jnp.int64
|
||||
|
||||
@property
|
||||
def offsets(self) -> jax.Array:
|
||||
"""Returns array of offsets into `base` for each element."""
|
||||
with jax.experimental.enable_x64():
|
||||
idxs = jnp.indices(self.shape, sparse=True, dtype=self._index_dtype)
|
||||
return self.offset + sum(s * idx for s, idx in zip(self.strides, idxs))
|
||||
|
||||
def astype(self, dtype: jax.typing.DTypeLike) -> Self:
|
||||
return self._replace(base=self.base.astype(dtype))
|
||||
|
||||
def broadcast_to_rank(self, rank: int) -> Self:
|
||||
"""Returns a new view with the specified rank."""
|
||||
if rank < self.ndim:
|
||||
raise ValueError(f"Cannot broadcast to lower rank: {rank} < {self.ndim}.")
|
||||
|
||||
shape = (1,) * (rank - self.ndim) + self.shape
|
||||
strides = (0,) * (rank - self.ndim) + self.strides
|
||||
return self._replace(shape=shape, strides=strides)
|
||||
|
||||
def broadcast_to(self, shape: tuple[int, ...]) -> Self:
|
||||
"""Returns a new view with the specified shape."""
|
||||
view = self.broadcast_to_rank(len(shape))
|
||||
strides = []
|
||||
for dim_size, stride, target_size in zip(
|
||||
view.shape, view.strides, shape, strict=True
|
||||
):
|
||||
if dim_size == target_size:
|
||||
strides.append(stride)
|
||||
elif dim_size == 1:
|
||||
strides.append(0)
|
||||
else:
|
||||
raise ValueError(f"Cannot broadcast {self.shape} to {shape}.")
|
||||
return self._replace(shape=shape, strides=strides)
|
||||
|
||||
def collapse(
|
||||
self, start: int, stop: int | None = None, *, allow_copy: bool = False
|
||||
) -> Self:
|
||||
"""Returns a new view with the axis range collapsed into one axis."""
|
||||
lo, hi, _ = slice(start, stop).indices(self.ndim)
|
||||
if hi < lo:
|
||||
raise ValueError(
|
||||
"Invalid dimension range passed to collapse: "
|
||||
f"{self.shape} [{start}:{stop}]"
|
||||
)
|
||||
shape = self.shape[:lo] + (-1,) + self.shape[hi:]
|
||||
return self.reshape(shape, allow_copy=allow_copy)
|
||||
|
||||
def reshape(self, shape: Sequence[int], *, allow_copy: bool = False) -> Self:
|
||||
"""Returns a new view with the specified shape."""
|
||||
try:
|
||||
return self._reshape(tuple(shape))
|
||||
except ValueError:
|
||||
if not allow_copy:
|
||||
raise
|
||||
return type(self)(jnp.array(self)).reshape(shape)
|
||||
|
||||
def _reshape(self, shape: tuple[int, ...]) -> Self:
|
||||
"""Returns a new view with the specified shape."""
|
||||
|
||||
if (num_minus_one_dims := shape.count(-1)) > 0:
|
||||
if num_minus_one_dims > 1:
|
||||
raise ValueError("`shape` may only contain a single `-1` dimension.")
|
||||
pos = shape.index(-1)
|
||||
shape = list(shape)
|
||||
shape[pos] = self.size // math.prod(d for d in shape if d != -1)
|
||||
|
||||
if math.prod(shape) != self.size:
|
||||
raise ValueError("Mismatched number of elements.")
|
||||
|
||||
# Logic copied from `numpy` C++ code.
|
||||
# Remove axes with length 1, to simplify logic below.
|
||||
old_shape = [d for d in self.shape if d != 1]
|
||||
old_strides = [s for i, s in enumerate(self.strides) if self.shape[i] != 1]
|
||||
strides = [0] * len(shape)
|
||||
|
||||
# Axes currently being worked upon.
|
||||
old_start, old_stop = 0, 1
|
||||
new_start, new_stop = 0, 1
|
||||
|
||||
while (old_start < len(old_shape)) and (new_start < len(shape)):
|
||||
old_axes_prod = old_shape[old_start]
|
||||
new_axes_prod = shape[new_start]
|
||||
while old_axes_prod != new_axes_prod:
|
||||
if old_axes_prod < new_axes_prod:
|
||||
old_axes_prod *= old_shape[old_stop]
|
||||
old_stop += 1
|
||||
else:
|
||||
new_axes_prod *= shape[new_stop]
|
||||
new_stop += 1
|
||||
|
||||
# Check if original axes can be combined.
|
||||
for i in range(old_start, old_stop - 1):
|
||||
if old_strides[i] != old_shape[i + 1] * old_strides[i + 1]:
|
||||
raise ValueError("Cannot combine axes non-contiguous in memory.")
|
||||
|
||||
# Calculate new strides.
|
||||
strides[new_stop - 1] = old_strides[old_stop - 1]
|
||||
for i in range(new_stop - 1, new_start, -1):
|
||||
strides[i - 1] = strides[i] * shape[i]
|
||||
|
||||
old_start, old_stop = old_stop, old_stop + 1
|
||||
new_start, new_stop = new_stop, new_stop + 1
|
||||
|
||||
return self._replace(shape=shape, strides=strides)
|
||||
|
||||
def split(
|
||||
self, indices_or_sections: int | Sequence[int], axis: int = 0
|
||||
) -> tuple[Self, ...]:
|
||||
"""Splits the view into multiple slice views."""
|
||||
if isinstance(indices_or_sections, int):
|
||||
if self.shape[axis] % indices_or_sections != 0:
|
||||
raise ValueError("Axis size is not divisible by number of sections.")
|
||||
|
||||
chunk = self.shape[axis] // indices_or_sections
|
||||
indices_or_sections = [i * chunk for i in range(1, indices_or_sections)]
|
||||
|
||||
los = (0, *indices_or_sections)
|
||||
his = (*indices_or_sections, None)
|
||||
slice_prefix = (slice(None),) * _canonicalize_axis(axis, self.ndim)
|
||||
return tuple(self[*slice_prefix, slice(lo, hi)] for lo, hi in zip(los, his))
|
||||
|
||||
def swapaxes(self, axis1: int, axis2: int) -> Self:
|
||||
"""Returns a new view with the specified axis swapped."""
|
||||
axes = list(range(self.ndim))
|
||||
axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
|
||||
return self.transpose(axes)
|
||||
|
||||
def moveaxis(self, source: int, destination: int) -> Self:
|
||||
"""Returns a new view with the specified axis moved."""
|
||||
source, destination = source % self.ndim, destination % self.ndim
|
||||
axes = list(range(self.ndim))
|
||||
del axes[source]
|
||||
axes.insert(destination, source)
|
||||
return self.transpose(axes)
|
||||
|
||||
def transpose(self, axes: Sequence[int] | None = None) -> Self:
|
||||
"""Returns a new view with the specified axes order."""
|
||||
if axes is None:
|
||||
axes = tuple(reversed(range(self.ndim)))
|
||||
if len(axes) != self.ndim:
|
||||
raise ValueError("`axes` must have the same dimensionality as the array.")
|
||||
shape = tuple(self.shape[a] for a in axes)
|
||||
strides = tuple(self.strides[a] for a in axes)
|
||||
return self._replace(shape=shape, strides=strides)
|
||||
|
||||
def __getitem__(self, idxs: Indexer | tuple[Indexer, ...]) -> Self:
|
||||
if not isinstance(idxs, tuple):
|
||||
idxs = (idxs,)
|
||||
|
||||
if len(idxs) > self.ndim:
|
||||
raise ValueError("Too many slice indices.")
|
||||
|
||||
num_ellipses = idxs.count(Ellipsis)
|
||||
if num_ellipses > 1:
|
||||
raise ValueError("Multiple `...` are not supported.")
|
||||
elif num_ellipses == 0:
|
||||
idxs += (Ellipsis,) # `[a:b]` is equivalent to `[a:b, ...]`.
|
||||
|
||||
# Replace `...` with slices that take the entirety of the missing axes.
|
||||
ellipsis_idx = idxs.index(Ellipsis)
|
||||
ellipsis_slices = (slice(None),) * (self.ndim - len(idxs) + 1)
|
||||
idxs = idxs[:ellipsis_idx] + ellipsis_slices + idxs[ellipsis_idx + 1 :]
|
||||
|
||||
shape = []
|
||||
strides = []
|
||||
with jax.experimental.enable_x64():
|
||||
|
||||
def as_index(x):
|
||||
return x.astype(self._index_dtype) if isinstance(x, jax.Array) else x
|
||||
|
||||
offset = as_index(self.offset)
|
||||
|
||||
for idx, dim, stride in zip(idxs, self.shape, self.strides, strict=True):
|
||||
if isinstance(idx, int):
|
||||
if not (-dim <= idx < dim):
|
||||
raise ValueError("Slice index out of range.")
|
||||
offset += stride * (idx % dim)
|
||||
elif isinstance(idx, ScalarInt):
|
||||
offset += stride * as_index(idx)
|
||||
elif isinstance(idx, slice):
|
||||
start, stop, step = idx.indices(dim)
|
||||
if step >= 0:
|
||||
shape.append(pl.cdiv(stop - start, step))
|
||||
else:
|
||||
shape.append(pl.cdiv(start - stop, -step))
|
||||
strides.append(stride * step)
|
||||
offset += stride * start
|
||||
elif isinstance(idx, pl.Slice):
|
||||
shape.append(idx.size)
|
||||
strides.append(stride * idx.stride)
|
||||
offset += stride * as_index(idx.start)
|
||||
else:
|
||||
raise ValueError(f"Unexpected indexer: {idx}")
|
||||
|
||||
return self._replace(shape=shape, strides=strides, offset=offset)
|
||||
|
||||
def _replace(self, **kwargs) -> Self:
|
||||
if "shape" in kwargs:
|
||||
kwargs["shape"] = tuple(kwargs["shape"])
|
||||
if "strides" in kwargs:
|
||||
kwargs["strides"] = tuple(kwargs["strides"])
|
||||
return dataclasses.replace(self, **kwargs)
|
||||
|
||||
def set(self, value: ArrayLike | "ArrayView") -> Self:
|
||||
"""Returns a new view with the views values set to `value`."""
|
||||
if any(s == 0 for s in self.strides):
|
||||
raise ValueError("Cannot set values on a broadcasted array.")
|
||||
|
||||
# Try to just transpose the value, if possible.
|
||||
major_to_minor = np.argsort(-np.array(self.strides), kind="stable")
|
||||
value = jnp.array(value)
|
||||
value_transposed = value.transpose(major_to_minor)
|
||||
if (
|
||||
self.transpose(major_to_minor).strides
|
||||
== ArrayView(value_transposed).strides
|
||||
):
|
||||
base = jax.lax.dynamic_update_slice(
|
||||
self.base, value_transposed.flatten(), (self.offset,)
|
||||
)
|
||||
else:
|
||||
base = self.base.at[self.offsets].set(value)
|
||||
return self._replace(base=base)
|
||||
|
||||
def __jax_array__(self) -> jax.Array:
|
||||
"""Returns values as a dense array."""
|
||||
# Try to express using transpose, slice, and reshape, to encourage XLA to
|
||||
# fuse into other ops, rather than materialising the values. Otherwise,
|
||||
# fall back to using a gather.
|
||||
if (self.ndim == 0) or any(s < 0 for s in self.strides):
|
||||
return self.base[self.offsets]
|
||||
|
||||
major_to_minor = np.argsort(-np.array(self.strides), kind="stable")
|
||||
|
||||
# Construct a shape that gives us the correct strides.
|
||||
bcast_axes = []
|
||||
shape = []
|
||||
for axis in major_to_minor[::-1]: # minor to major
|
||||
stride = self.strides[axis]
|
||||
if stride == 0:
|
||||
bcast_axes.append(axis)
|
||||
shape.append(1)
|
||||
continue
|
||||
|
||||
if stride % math.prod(shape) != 0:
|
||||
raise ValueError("Cannot express as a reshape, then slice.")
|
||||
shape.append(stride // math.prod(shape))
|
||||
|
||||
if self.base.size % math.prod(shape) != 0:
|
||||
return self.base[self.offsets]
|
||||
|
||||
shape = [self.base.size // math.prod(shape), *reversed(shape)]
|
||||
slice_sizes = [
|
||||
*(1 if a in bcast_axes else self.shape[a] for a in major_to_minor),
|
||||
1,
|
||||
]
|
||||
|
||||
if shape[0] == self.shape[major_to_minor[0]]:
|
||||
needs_offset_slice = False
|
||||
elif not isinstance(self.offset, int):
|
||||
needs_offset_slice = True
|
||||
else:
|
||||
start_indices = np.unravel_index(self.offset, shape)
|
||||
end_indices = [s + size for s, size in zip(start_indices, slice_sizes)]
|
||||
needs_offset_slice = any(e > dim for e, dim in zip(end_indices, shape))
|
||||
|
||||
if needs_offset_slice:
|
||||
shape[0] = self.shape[major_to_minor[0]]
|
||||
size = math.prod(shape)
|
||||
# The pad is necessary to ensure that the dynamic slice is in range.
|
||||
vals = jnp.pad(self.base, (0, size))
|
||||
vals = jax.lax.dynamic_slice(vals, (self.offset,), (size,))
|
||||
start_indices = [0] * len(shape)
|
||||
else:
|
||||
vals = self.base
|
||||
start_indices = jnp.unravel_index(self.offset, shape)
|
||||
|
||||
vals = vals.reshape(shape)
|
||||
vals = jax.lax.dynamic_slice(vals, start_indices, slice_sizes)[..., 0]
|
||||
# Move axes from their physical ordering to their logical ordering.
|
||||
vals = vals.transpose(np.argsort(major_to_minor))
|
||||
return jnp.broadcast_to(vals, self.shape)
|
||||
|
||||
|
||||
def as_array_view(x: jax.Array | ArrayView) -> ArrayView:
|
||||
return x if isinstance(x, ArrayView) else ArrayView(x)
|
||||
|
||||
|
||||
T = TypeVar("T", jax.Array, ArrayView)
|
||||
|
||||
|
||||
def zeros_like(x: T) -> T:
|
||||
if isinstance(x, ArrayView):
|
||||
return x._replace(base=jnp.zeros_like(x.base))
|
||||
return jnp.zeros_like(x)
|
||||
|
||||
|
||||
def _canonicalize_axis(axis, num_dims) -> int:
|
||||
"""Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims)."""
|
||||
axis = operator.index(axis)
|
||||
if not -num_dims <= axis < num_dims:
|
||||
raise ValueError(
|
||||
f"axis {axis} is out of bounds for array of dimension {num_dims}"
|
||||
)
|
||||
if axis < 0:
|
||||
axis = axis + num_dims
|
||||
return axis
|
||||
@@ -1,92 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Precision classes and utilities."""
|
||||
|
||||
import enum
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
|
||||
@enum.unique
|
||||
class DotPrecision(enum.Enum):
|
||||
"""Precision for `dot` operation.
|
||||
|
||||
Naming scheme: {OPERAND_DTYPE}_{ACCUMULATOR_DTYPE}[_{NUM_PASSES}x]
|
||||
"""
|
||||
|
||||
BF16_F32 = "bf16_f32"
|
||||
|
||||
# GPU only precisions.
|
||||
F32_F32 = "f32_f32" # Full f32 precision (doesn't use TensorCores).
|
||||
TF32_F32 = "tf32_f32" # Equivalent to `DEFAULT`/`HIGH` on GPU.
|
||||
TF32_F32_3X = "tf32_f32_3x"
|
||||
F16_F16 = "f16_f16"
|
||||
F16_F32 = "f16_f32"
|
||||
|
||||
@property
|
||||
def operand_dtype(self) -> jnp.dtype:
|
||||
match self:
|
||||
case DotPrecision.BF16_F32:
|
||||
return jnp.bfloat16
|
||||
case DotPrecision.F16_F16 | DotPrecision.F16_F32:
|
||||
return jnp.float16
|
||||
case _:
|
||||
return jnp.float32
|
||||
|
||||
@property
|
||||
def accumulator_dtype(self) -> jnp.dtype:
|
||||
return jnp.float16 if (self == DotPrecision.F16_F16) else jnp.float32
|
||||
|
||||
|
||||
_JAX_GPU_PRECISION_MAP = {
|
||||
(jnp.float16, jax.lax.Precision.DEFAULT): DotPrecision.F16_F32,
|
||||
(jnp.bfloat16, jax.lax.Precision.DEFAULT): DotPrecision.BF16_F32,
|
||||
(jnp.float32, jax.lax.Precision.DEFAULT): DotPrecision.TF32_F32,
|
||||
(jnp.float32, jax.lax.Precision.HIGH): DotPrecision.TF32_F32,
|
||||
(jnp.float32, jax.lax.Precision.HIGHEST): DotPrecision.F32_F32,
|
||||
}
|
||||
|
||||
_JAX_CPU_PRECISION_MAP = {
|
||||
(jnp.float16, jax.lax.Precision.DEFAULT): DotPrecision.F16_F32,
|
||||
(jnp.bfloat16, jax.lax.Precision.DEFAULT): DotPrecision.F32_F32,
|
||||
(jnp.float32, jax.lax.Precision.DEFAULT): DotPrecision.F32_F32,
|
||||
(jnp.float32, jax.lax.Precision.HIGH): DotPrecision.F32_F32,
|
||||
(jnp.float32, jax.lax.Precision.HIGHEST): DotPrecision.F32_F32,
|
||||
}
|
||||
|
||||
|
||||
def _create_jax_precision_map():
|
||||
precision_map = {}
|
||||
for (dtype, jax_precision), dot_precision in _JAX_GPU_PRECISION_MAP.items():
|
||||
precision_map[("gpu", jnp.dtype(dtype), jax_precision)] = dot_precision
|
||||
for (dtype, jax_precision), dot_precision in _JAX_CPU_PRECISION_MAP.items():
|
||||
precision_map[("cpu", jnp.dtype(dtype), jax_precision)] = dot_precision
|
||||
return precision_map
|
||||
|
||||
|
||||
_JAX_PRECISION_MAP = _create_jax_precision_map()
|
||||
|
||||
|
||||
def get_equivalent_dot_precision(
|
||||
a_dtype: jnp.dtype, b_dtype: jnp.dtype, jax_precision: jax.lax.Precision
|
||||
) -> DotPrecision:
|
||||
"""Returns `DotPrecision` replicating default XLA behaviour."""
|
||||
if a_dtype != b_dtype:
|
||||
raise ValueError("Cannot infer precision if operand types differ.")
|
||||
|
||||
backend = jax.default_backend().lower()
|
||||
if (jax_precision != jax.lax.Precision.DEFAULT) and (a_dtype != jnp.float32):
|
||||
raise ValueError(
|
||||
"`jax.lax.Precision` values other than `DEFAULT` only have an effect if"
|
||||
" the operand type is `float32`."
|
||||
)
|
||||
return _JAX_PRECISION_MAP[(backend, a_dtype, jax_precision)]
|
||||
@@ -1,125 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Triton utils."""
|
||||
|
||||
from collections.abc import Callable, Mapping
|
||||
|
||||
from alphafold3.jax.common import precision as precision_lib
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
|
||||
_JNP_TO_TL_DTYPES: Mapping[jnp.dtype, tl.dtype] = {
|
||||
jnp.bool_: tl.int1,
|
||||
jnp.int8: tl.int8,
|
||||
jnp.int16: tl.int16,
|
||||
jnp.int32: tl.int32,
|
||||
jnp.int64: tl.int64,
|
||||
jnp.uint8: tl.uint8,
|
||||
jnp.uint16: tl.uint16,
|
||||
jnp.uint32: tl.uint32,
|
||||
jnp.uint64: tl.uint64,
|
||||
jnp.float16: tl.float16,
|
||||
jnp.bfloat16: tl.bfloat16,
|
||||
jnp.float32: tl.float32,
|
||||
jnp.float64: tl.float64,
|
||||
}
|
||||
|
||||
|
||||
def jnp_to_tl_dtype(jnp_dtype: jnp.dtype) -> tl.dtype:
|
||||
return _JNP_TO_TL_DTYPES[jnp_dtype]
|
||||
|
||||
|
||||
def get_tl_dot_fn(
|
||||
precision: precision_lib.DotPrecision,
|
||||
) -> Callable[..., tl.tensor]:
|
||||
"""Returns a tl `dot` implementation with the specified precision.
|
||||
|
||||
Args:
|
||||
precision: The `dot` precision.
|
||||
"""
|
||||
if not is_precision_supported(precision):
|
||||
raise ValueError(f'Unsupported dot precision: {precision}')
|
||||
|
||||
if precision == precision_lib.DotPrecision.TF32_F32_3X:
|
||||
return _dot_tf32_f32_3x
|
||||
|
||||
in_dtype = jnp_to_tl_dtype(precision.operand_dtype)
|
||||
out_dtype = jnp_to_tl_dtype(precision.accumulator_dtype)
|
||||
allow_tf32 = precision == precision_lib.DotPrecision.TF32_F32
|
||||
|
||||
@tl.core.extern
|
||||
def _dot_fn(
|
||||
a: tl.core.tensor,
|
||||
b: tl.core.tensor,
|
||||
*,
|
||||
trans_a: bool = False,
|
||||
trans_b: bool = False,
|
||||
_builder,
|
||||
):
|
||||
if in_dtype == tl.float32:
|
||||
tl.static_assert(a.dtype == tl.float32, _builder=_builder)
|
||||
tl.static_assert(b.dtype == tl.float32, _builder=_builder)
|
||||
else:
|
||||
tl.static_assert(a.dtype.is_standard_floating(), _builder=_builder)
|
||||
tl.static_assert(b.dtype.is_standard_floating(), _builder=_builder)
|
||||
a = a.to(in_dtype, _builder=_builder)
|
||||
b = b.to(in_dtype, _builder=_builder)
|
||||
a = tl.trans(a, _builder=_builder) if trans_a else a
|
||||
b = tl.trans(b, _builder=_builder) if trans_b else b
|
||||
return tl.dot(
|
||||
a, b, allow_tf32=allow_tf32, out_dtype=out_dtype, _builder=_builder
|
||||
)
|
||||
|
||||
return _dot_fn
|
||||
|
||||
|
||||
def is_precision_supported(precision: precision_lib.DotPrecision) -> bool:
|
||||
return precision in {
|
||||
precision_lib.DotPrecision.F32_F32,
|
||||
precision_lib.DotPrecision.TF32_F32,
|
||||
precision_lib.DotPrecision.F16_F32,
|
||||
precision_lib.DotPrecision.BF16_F32,
|
||||
precision_lib.DotPrecision.TF32_F32_3X,
|
||||
}
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _dot_tf32_f32_3x(a, b, trans_a=False, trans_b=False):
|
||||
"""Perform the 3-pass tf32 dot function."""
|
||||
tl.static_assert(a.dtype == tl.float32)
|
||||
tl.static_assert(b.dtype == tl.float32)
|
||||
a_ = (a.to(tl.uint32, bitcast=True) & 0xFFFFE000).to(tl.float32, bitcast=True)
|
||||
b_ = (b.to(tl.uint32, bitcast=True) & 0xFFFFE000).to(tl.float32, bitcast=True)
|
||||
a_err = a - a_
|
||||
b_err = b - b_
|
||||
if trans_a:
|
||||
a_ = tl.trans(a_)
|
||||
a_err = tl.trans(a_err)
|
||||
if trans_b:
|
||||
b_ = tl.trans(b_)
|
||||
b_err = tl.trans(b_err)
|
||||
# Add smallest terms first for better accuracy.
|
||||
return tl.dot(a_, b_, out_dtype=tl.float32) + (
|
||||
tl.dot(a_, b_err, out_dtype=tl.float32)
|
||||
+ tl.dot(a_err, b_, out_dtype=tl.float32)
|
||||
)
|
||||
|
||||
|
||||
def has_triton_support() -> bool:
|
||||
"""Returns True if Triton is supported by the default JAX device."""
|
||||
if jax.default_backend() != 'gpu':
|
||||
return False
|
||||
|
||||
# Only currently supported for Ampere and above.
|
||||
return float(jax.devices()[0].compute_capability) >= 8.0
|
||||
@@ -1,126 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Pallas block load / store utilities."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from alphafold3.jax.common import array_view
|
||||
import jax
|
||||
import jax.experimental
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping
|
||||
from jaxtyping import Int # pylint: disable=g-importing-member
|
||||
import numpy as np
|
||||
import typeguard
|
||||
|
||||
ArrayT: TypeAlias = Any
|
||||
ScalarInt: TypeAlias = (
|
||||
Int[ArrayT, ""] | Int[np.generic, ""] | Int[jnp.generic, ""]
|
||||
)
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def load_block(
|
||||
ref,
|
||||
idx: Sequence[int | ScalarInt],
|
||||
*,
|
||||
block_shape: Sequence[int | None],
|
||||
other=None,
|
||||
**kwargs,
|
||||
) -> jax.Array:
|
||||
"""Loads a block from the given `ref`, masking where necessary."""
|
||||
idx, mask = _get_block_indexer_and_mask(ref, idx, block_shape=block_shape)
|
||||
if isinstance(ref, array_view.ArrayView):
|
||||
idx = ref[idx].offsets
|
||||
ref = ref.base
|
||||
other = None if mask is None else other
|
||||
with jax.experimental.enable_x64():
|
||||
return pl.load(ref, idx, mask=mask, other=other, **kwargs)
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def store_block(
|
||||
ref,
|
||||
val: jax.Array,
|
||||
idx: Sequence[int | ScalarInt],
|
||||
*,
|
||||
block_shape: Sequence[int | None] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Stores a block from the given `ref`, masking where necessary."""
|
||||
if block_shape is None:
|
||||
block_shape = val.shape
|
||||
idx, mask = _get_block_indexer_and_mask(ref, idx, block_shape=block_shape)
|
||||
if isinstance(ref, array_view.ArrayView):
|
||||
idx = ref[idx].offsets
|
||||
ref = ref.base
|
||||
with jax.experimental.enable_x64():
|
||||
pl.store(ref, idx, val.astype(ref.dtype), mask=mask, **kwargs)
|
||||
|
||||
|
||||
def in_bounds_mask(
|
||||
idx: Sequence[int | slice | pl.Slice | jax.Array],
|
||||
shape: Sequence[int],
|
||||
*,
|
||||
check: Sequence[bool] | None = None,
|
||||
) -> jax.Array | None:
|
||||
"""Returns a boolean mask denoting which indices are within bounds.
|
||||
|
||||
Args:
|
||||
idx: Indices for each dimension.
|
||||
shape: Shape designating the valid bounds.
|
||||
check: Whether or not to check bounds in each dimension. Useful for ignoring
|
||||
indices known to be in bounds. Defaults to all True.
|
||||
"""
|
||||
if check is None:
|
||||
check = [True] * len(shape)
|
||||
|
||||
# Remove `int` indexed dims (mask shape must match slice result shape).
|
||||
shape = [dim for i, dim in enumerate(shape) if not isinstance(idx[i], int)]
|
||||
check = [chk for i, chk in enumerate(check) if not isinstance(idx[i], int)]
|
||||
idx = [idx for idx in idx if not isinstance(idx, int)]
|
||||
|
||||
mask = None
|
||||
for i, (dim_idx, dim, chk) in enumerate(zip(idx, shape, check, strict=True)):
|
||||
if not chk:
|
||||
continue
|
||||
|
||||
if isinstance(dim_idx, slice):
|
||||
dim_idx = pl.Slice.from_slice(dim_idx, dim)
|
||||
if isinstance(dim_idx, pl.Slice):
|
||||
dim_idx = dim_idx.start + dim_idx.stride * jnp.arange(dim_idx.size)
|
||||
if dim_idx.ndim != 1:
|
||||
raise NotImplementedError("Only one-dimensional indices are supported.")
|
||||
|
||||
bcast_axes = [a for a in range(len(shape)) if a != i]
|
||||
dim_mask = jnp.expand_dims(dim_idx < dim, bcast_axes)
|
||||
mask = dim_mask if mask is None else (mask & dim_mask)
|
||||
return mask
|
||||
|
||||
|
||||
def _get_block_indexer_and_mask(
|
||||
ref, idx: Sequence[int | ScalarInt], *, block_shape: Sequence[int | None]
|
||||
) -> tuple[tuple[int | slice | pl.Slice, ...], jax.Array | None]:
|
||||
"""Return indices and mask for loading / storing a block."""
|
||||
shape = ref.shape
|
||||
idxs = []
|
||||
check = []
|
||||
for dim, block_idx, block_dim in zip(shape, idx, block_shape, strict=True):
|
||||
if block_dim is None:
|
||||
idxs.append(block_idx)
|
||||
check.append(False)
|
||||
else:
|
||||
idxs.append(pl.dslice(block_dim * block_idx, block_dim))
|
||||
check.append(dim % block_dim != 0)
|
||||
|
||||
return tuple(idxs), in_bounds_mask(idxs, shape, check=check)
|
||||
@@ -1,124 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Public API for gated linear unit functions."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import typing
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from alphafold3.jax.common import array_view
|
||||
from alphafold3.jax.common import triton_utils
|
||||
from alphafold3.jax.gated_linear_unit import gated_linear_unit_base
|
||||
from alphafold3.jax.gated_linear_unit import matmul_ext
|
||||
import jax
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Float # pylint: disable=g-importing-member,g-multiple-import
|
||||
import typeguard
|
||||
|
||||
Implementation: TypeAlias = Literal['xla', 'triton']
|
||||
|
||||
|
||||
class PallasGatedLinearUnit(gated_linear_unit_base.GatedLinearUnit):
|
||||
"""Pallas gated linear unit."""
|
||||
|
||||
def _fwd(self, x, weight, *, activation, precision):
|
||||
weight_view = array_view.ArrayView(weight)
|
||||
return self.apply_vmap_rule_forward(
|
||||
matmul_ext.gated_linear_unit,
|
||||
activation=activation,
|
||||
precision=precision,
|
||||
)(
|
||||
x,
|
||||
weight_view[:, 1],
|
||||
weight_view[:, 0],
|
||||
)
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def gated_linear_unit(
|
||||
x: Float[Array, '*B M K'],
|
||||
weight: Float[Array, 'K 2 N'],
|
||||
*,
|
||||
activation: Callable[[jax.Array], jax.Array] | None = None,
|
||||
precision: jax.lax.Precision | None = None,
|
||||
implementation: Implementation | None = None,
|
||||
) -> Float[Array, '*B M N']:
|
||||
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
|
||||
|
||||
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
|
||||
|
||||
This is SwiGLU when `activation=jax.nn.swish`, GEGLU when
|
||||
`activation=jax.nn.gelu`, REGLU when `activation=jax.nn.relu`, and GLU when
|
||||
`activation=jax.nn.sigmoid` (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Args:
|
||||
x: the input array.
|
||||
weight: the combined weight array.
|
||||
activation: optional activation function.
|
||||
precision: specifies the matrix multiplication precision. Either `None`
|
||||
(default), which means the default precision for the backend, or a
|
||||
`jax.lax.Precision` enum.
|
||||
implementation: if `None` (default), an implementation is automatically
|
||||
chosen. 'xla' will use standard XLA and work on any platform, and 'triton'
|
||||
will use a fused Triton GPU kernel. Only a subset of data types, shapes
|
||||
and GPUs are supported by 'triton', with an exception thrown in this case.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: if `implementation='triton'` does not support a given
|
||||
input or device.
|
||||
ValueError: if the arguments are invalid.
|
||||
|
||||
Returns:
|
||||
The output array.
|
||||
"""
|
||||
|
||||
match implementation:
|
||||
case 'triton':
|
||||
if not triton_utils.has_triton_support():
|
||||
raise NotImplementedError('Triton not supported on this platform.')
|
||||
case _:
|
||||
...
|
||||
|
||||
if x.dtype.name != weight.dtype.name:
|
||||
raise ValueError(
|
||||
f'Input and weight must have the same dtype. {x.dtype} !='
|
||||
f' {weight.dtype}'
|
||||
)
|
||||
|
||||
if implementation is not None:
|
||||
named_args = typing.get_args(Implementation)
|
||||
if implementation not in named_args:
|
||||
raise ValueError(
|
||||
f'Unsupported named implementation. Must be one of {named_args}.'
|
||||
)
|
||||
|
||||
if implementation is None or implementation == 'triton':
|
||||
try:
|
||||
return PallasGatedLinearUnit()(
|
||||
x=x,
|
||||
weight=weight,
|
||||
activation=activation,
|
||||
precision=precision,
|
||||
)
|
||||
# When `implementation=None`, we must catch any exception, and use XLA
|
||||
# as a fallback. As we rely on a third-party library (Triton), it might
|
||||
# not be possible to enumerate all possible exceptions that could be
|
||||
# thrown, hence catching the broadest possible one.
|
||||
except Exception as e: # pylint: disable=broad-exception-caught
|
||||
if implementation == 'triton':
|
||||
raise e
|
||||
|
||||
return gated_linear_unit_base.gated_linear_unit_xla(
|
||||
x=x,
|
||||
weight=weight,
|
||||
activation=activation,
|
||||
precision=precision,
|
||||
)
|
||||
@@ -1,130 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Common types for gated linear unit kernels."""
|
||||
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import Any
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Float # pylint: disable=g-importing-member,g-multiple-import
|
||||
import typeguard
|
||||
|
||||
|
||||
class GatedLinearUnit(abc.ABC):
|
||||
"""Gated linear unit."""
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: Float[Array, '*B M K'],
|
||||
weight: Float[Array, 'K 2 N'],
|
||||
*,
|
||||
activation: Callable[[jax.Array], jax.Array] | None = None,
|
||||
precision: jax.lax.Precision | None = None,
|
||||
**kwargs,
|
||||
) -> Float[Array, '*B M N']:
|
||||
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
|
||||
|
||||
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
|
||||
|
||||
Args:
|
||||
x: the input array.
|
||||
weight: the combined weight array.
|
||||
activation: optional activation function.
|
||||
precision: specifies the matrix multiplication precision. Either `None`
|
||||
(default), which means the default precision for the backend, or a
|
||||
`jax.lax.Precision` enum.
|
||||
**kwargs: additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
The output array.
|
||||
"""
|
||||
return self._fwd(
|
||||
x, weight, activation=activation, precision=precision, **kwargs
|
||||
)
|
||||
|
||||
# Default vmap rule.
|
||||
@property
|
||||
def vmap_rule_forward(self) -> Callable[..., Any]:
|
||||
def _vmap_rule(
|
||||
axis_size, in_batched, *args, fn: jax.custom_batching.custom_vmap
|
||||
):
|
||||
sequential_vmap = jax.custom_batching.sequential_vmap(fn.fun)
|
||||
return sequential_vmap.vmap_rule(axis_size, in_batched, *args)
|
||||
|
||||
return _vmap_rule
|
||||
|
||||
def apply_vmap_rule_forward(
|
||||
self, fn: Callable[..., Any], **kwargs
|
||||
) -> jax.custom_batching.custom_vmap:
|
||||
fn_closed = functools.partial(fn, **kwargs)
|
||||
fn_closed = jax.custom_batching.custom_vmap(fn_closed)
|
||||
vmap_rule = functools.partial(self.vmap_rule_forward, fn=fn_closed)
|
||||
fn_closed.def_vmap(vmap_rule)
|
||||
return fn_closed
|
||||
|
||||
@abc.abstractmethod
|
||||
def _fwd(
|
||||
self,
|
||||
x: Float[Array, '*B M K'],
|
||||
weight: Float[Array, 'K 2 N'],
|
||||
*,
|
||||
activation: Callable[[jax.Array], jax.Array] | None,
|
||||
precision: jax.lax.Precision | None,
|
||||
) -> Float[Array, '*B M N']:
|
||||
"""Gated linear unit."""
|
||||
...
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def gated_linear_unit_xla(
|
||||
x: Float[Array, '*B M K'],
|
||||
weight: Float[Array, 'K 2 N'],
|
||||
*,
|
||||
activation: Callable[[jax.Array], jax.Array] | None = None,
|
||||
precision: jax.lax.Precision | None = None,
|
||||
) -> Float[Array, '*B M N']:
|
||||
"""Applies a gated linear unit (https://arxiv.org/abs/1612.08083).
|
||||
|
||||
Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`.
|
||||
|
||||
This is SwiGLU when `activation=jax.nn.swish`, GEGLU when
|
||||
`activation=jax.nn.gelu`, REGLU when `activation=jax.nn.relu`, and GLU when
|
||||
`activation=jax.nn.sigmoid` (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Args:
|
||||
x: the input array.
|
||||
weight: the combined weight array.
|
||||
activation: optional activation function.
|
||||
precision: specifies the matrix multiplication precision. Either `None`
|
||||
(default), which means the default precision for the backend, or a
|
||||
`jax.lax.Precision` enum.
|
||||
|
||||
Returns:
|
||||
The output array.
|
||||
"""
|
||||
|
||||
weight_reshaped = jax.lax.collapse(
|
||||
weight, start_dimension=-2, stop_dimension=None
|
||||
)
|
||||
assert weight_reshaped.ndim == 2
|
||||
|
||||
y = jnp.dot(x, weight_reshaped, precision=precision)
|
||||
|
||||
# Apply activation and compute product of FP8/FP16/BF16 in FP32.
|
||||
y = y.astype(jnp.promote_types(x.dtype, jnp.float32))
|
||||
a, b = jnp.split(y, 2, axis=-1)
|
||||
out = a * b if activation is None else activation(a) * b
|
||||
out = out.astype(x.dtype)
|
||||
return out
|
||||
@@ -1,78 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Auto-tuned configs for matmul."""
|
||||
|
||||
import dataclasses
|
||||
import functools
|
||||
import math
|
||||
|
||||
import jax
|
||||
from jax.experimental import pallas as pl
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True, kw_only=True)
|
||||
class Config:
|
||||
block_m: int
|
||||
block_n: int
|
||||
block_k: int
|
||||
num_warps: int
|
||||
num_stages: int
|
||||
|
||||
|
||||
@functools.cache
|
||||
def _get_best_block_size(
|
||||
m: int, n: int, k: int, core_count: int
|
||||
) -> tuple[int, int, int]:
|
||||
"""Returns the best block size for the given shape."""
|
||||
min_block_dim = 32
|
||||
block_m = min(max(min_block_dim, pl.next_power_of_2(m)), 128)
|
||||
block_n = min(max(min_block_dim, pl.next_power_of_2(n)), 256)
|
||||
block_n = min(block_n, (128 * 128) // block_m)
|
||||
block_k = 32
|
||||
split_k = 1
|
||||
num_blocks = pl.cdiv(m, block_m) * pl.cdiv(n, block_n)
|
||||
while num_blocks < core_count:
|
||||
if block_m > min_block_dim:
|
||||
block_m //= 2
|
||||
num_blocks = pl.cdiv(m, block_m) * pl.cdiv(n, block_n)
|
||||
elif split_k * block_k < pl.next_power_of_2(k):
|
||||
split_k *= 2
|
||||
num_blocks *= 2
|
||||
else:
|
||||
break
|
||||
return block_m, block_n, block_k
|
||||
|
||||
|
||||
def _abstractify(x):
|
||||
return jax.api_util.shaped_abstractify(x) if isinstance(x, jax.Array) else x
|
||||
|
||||
|
||||
def get_config(
|
||||
x: jax.Array, w: jax.Array, core_count: int | None = None
|
||||
) -> Config:
|
||||
"""Returns a config for the given args."""
|
||||
if core_count is None:
|
||||
core_count = jax.devices()[0].core_count
|
||||
x = _abstractify(x)
|
||||
w = _abstractify(w)
|
||||
m, k = math.prod(x.shape[:-1]), x.shape[-1]
|
||||
n = w.shape[1]
|
||||
if n >= m: # Prefer `block_n` > `block_m`.
|
||||
block_m, block_n, block_k = _get_best_block_size(m, n, k, core_count)
|
||||
else:
|
||||
block_n, block_m, block_k = _get_best_block_size(n, m, k, core_count)
|
||||
return Config(
|
||||
block_m=block_m,
|
||||
block_n=block_n // 2, # Halve `block_n` as we read two `w` blocks.
|
||||
block_k=block_k,
|
||||
num_warps=4,
|
||||
num_stages=4,
|
||||
)
|
||||
@@ -1,273 +0,0 @@
|
||||
# Copyright 2024 DeepMind Technologies Limited
|
||||
#
|
||||
# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
|
||||
# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
|
||||
#
|
||||
# To request access to the AlphaFold 3 model parameters, follow the process set
|
||||
# out at https://github.com/google-deepmind/alphafold3. You may only use these
|
||||
# if received directly from Google. Use is subject to terms of use available at
|
||||
# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
|
||||
|
||||
"""Extended matmul ops."""
|
||||
|
||||
from collections.abc import Callable
|
||||
import functools
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from alphafold3.jax.common import array_view
|
||||
from alphafold3.jax.common import triton_utils
|
||||
from alphafold3.jax.gated_linear_unit import block
|
||||
from alphafold3.jax.gated_linear_unit import matmul_config
|
||||
import jax
|
||||
from jax._src.state import discharge
|
||||
from jax.experimental import pallas as pl
|
||||
import jax.numpy as jnp
|
||||
import jaxtyping
|
||||
from jaxtyping import Array, Float, Int # pylint: disable=g-importing-member,g-multiple-import
|
||||
import numpy as np
|
||||
import typeguard
|
||||
|
||||
ArrayView = array_view.ArrayView
|
||||
PyTree: TypeAlias = Any
|
||||
ArrayT: TypeAlias = Any
|
||||
ScalarInt: TypeAlias = (
|
||||
Int[ArrayT, ''] | Int[np.generic, ''] | Int[jnp.generic, '']
|
||||
)
|
||||
|
||||
|
||||
def _get_group_cache_usage(
|
||||
group_size_m, num_blocks_m, num_blocks_n, block_m_bytes, block_n_bytes
|
||||
) -> int:
|
||||
"""Returns the cache usage in bytes for the given group size."""
|
||||
num_live_progs = jax.devices()[0].core_count
|
||||
num_live_blocks_n = min(pl.cdiv(num_live_progs, group_size_m), num_blocks_n)
|
||||
num_live_groups = pl.cdiv(num_live_progs, group_size_m * num_live_blocks_n)
|
||||
num_live_blocks_m = min(num_live_groups * group_size_m, num_blocks_m)
|
||||
return num_live_blocks_m * block_m_bytes + num_live_blocks_n * block_n_bytes
|
||||
|
||||
|
||||
def _get_pids(
|
||||
pid, num_blocks_m, num_blocks_n, group_size_m
|
||||
) -> tuple[ScalarInt, ScalarInt]:
|
||||
"""Returns the program IDs in each grid axis."""
|
||||
# Use `floor_divide` and `remainder` (instead of lax.div and lax.rem)
|
||||
# to handle dtypes: pid (int32) vs. num_blocks_n (int64) when `jax_enable_x64`
|
||||
# is set.
|
||||
if group_size_m == 1:
|
||||
return jnp.floor_divide(pid, num_blocks_n), jnp.remainder(pid, num_blocks_n)
|
||||
|
||||
num_progs_in_group = group_size_m * num_blocks_n
|
||||
group_start_m = jnp.floor_divide(pid, num_progs_in_group) * group_size_m
|
||||
group_size_m = jnp.minimum(num_blocks_m - group_start_m, group_size_m)
|
||||
pid_m = group_start_m + jnp.remainder(pid, group_size_m)
|
||||
pid_n = jnp.floor_divide(jnp.remainder(pid, num_progs_in_group), group_size_m)
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
def _get_best_pids(
|
||||
pid, *, m, n, block_m, block_n, a_dtype_bytes, b_dtype_bytes
|
||||
) -> tuple[ScalarInt, ScalarInt]:
|
||||
"""Returns the grouped program IDs that minimize cache usage."""
|
||||
num_blocks_m = pl.cdiv(m, block_m)
|
||||
num_blocks_n = pl.cdiv(n, block_n)
|
||||
block_m_bytes = block_m * a_dtype_bytes
|
||||
block_n_bytes = block_n * b_dtype_bytes
|
||||
|
||||
num_live_progs = jax.devices()[0].core_count
|
||||
|
||||
def group_size_m_usage(group_size_m):
|
||||
return _get_group_cache_usage(
|
||||
group_size_m, num_blocks_m, num_blocks_n, block_m_bytes, block_n_bytes
|
||||
)
|
||||
|
||||
group_size_m = min(
|
||||
range(1, min(num_live_progs, num_blocks_m) + 1), key=group_size_m_usage
|
||||
)
|
||||
|
||||
def group_size_n_usage(group_size_n):
|
||||
return _get_group_cache_usage(
|
||||
group_size_n, num_blocks_n, num_blocks_m, block_n_bytes, block_m_bytes
|
||||
)
|
||||
|
||||
group_size_n = min(
|
||||
range(1, min(num_live_progs, num_blocks_n) + 1), key=group_size_n_usage
|
||||
)
|
||||
|
||||
if group_size_m_usage(group_size_m) <= group_size_n_usage(group_size_n):
|
||||
pid_m, pid_n = _get_pids(pid, num_blocks_m, num_blocks_n, group_size_m)
|
||||
else:
|
||||
pid_n, pid_m = _get_pids(pid, num_blocks_n, num_blocks_m, group_size_n)
|
||||
return pid_m, pid_n
|
||||
|
||||
|
||||
def _apply_epilogue(
|
||||
epilogue: Callable[..., jax.Array], x: jax.Array, args: PyTree
|
||||
) -> jax.Array:
|
||||
"""Applies the epilogue to the output."""
|
||||
# Convert array view arguments to JAX arrays. This means that we can use the
|
||||
# array view slices, rather than the gather that discharging state gives us.
|
||||
is_leaf = lambda x: isinstance(x, ArrayView)
|
||||
args_flat, args_tree = jax.tree.flatten((x, args), is_leaf=is_leaf)
|
||||
args_flat = tuple(map(jnp.array, args_flat))
|
||||
|
||||
def epilogue_wrapper(refs):
|
||||
x_ref, arg_refs = args_tree.unflatten(refs)
|
||||
x_ref[:] = epilogue(x_ref[:], arg_refs, 0, 0)
|
||||
|
||||
return discharge.run_state_reference(epilogue_wrapper)(args_flat)[0]
|
||||
|
||||
|
||||
def _gated_linear_unit_kernel(
|
||||
x_ref,
|
||||
w_ref,
|
||||
v_ref,
|
||||
_, # Destination, aliased with `out_ref`.
|
||||
epilogue_in_refs,
|
||||
out_ref,
|
||||
*,
|
||||
block_m,
|
||||
block_n,
|
||||
block_k,
|
||||
activation,
|
||||
precision,
|
||||
epilogue,
|
||||
):
|
||||
"""Pallas GLU kernel."""
|
||||
m = x_ref.shape[0]
|
||||
n = w_ref.shape[1]
|
||||
pid_m, pid_n = _get_best_pids(
|
||||
pl.program_id(0),
|
||||
m=m,
|
||||
n=n,
|
||||
block_m=block_m,
|
||||
block_n=block_n,
|
||||
a_dtype_bytes=jnp.dtype(x_ref.dtype).itemsize,
|
||||
b_dtype_bytes=jnp.dtype(w_ref.dtype).itemsize * 2, # Two blocks.
|
||||
)
|
||||
|
||||
def body(i, acc):
|
||||
x = block.load_block(x_ref, (pid_m, i), block_shape=(block_m, block_k))
|
||||
w = block.load_block(w_ref, (i, pid_n), block_shape=(block_k, block_n))
|
||||
v = block.load_block(v_ref, (i, pid_n), block_shape=(block_k, block_n))
|
||||
acc[0] += pl.dot(x, w.astype(x.dtype), precision=precision)
|
||||
acc[1] += pl.dot(x, v.astype(x.dtype), precision=precision)
|
||||
return acc
|
||||
|
||||
num_iters = pl.cdiv(x_ref.shape[-1], block_k)
|
||||
acc0 = jnp.zeros((block_m, block_n), dtype=jnp.float32)
|
||||
acc1 = jnp.zeros((block_m, block_n), dtype=jnp.float32)
|
||||
proj, gates = jax.lax.fori_loop(0, num_iters, body, init_val=[acc0, acc1])
|
||||
|
||||
proj = proj.astype(x_ref.dtype).astype(jnp.float32)
|
||||
gates = gates.astype(x_ref.dtype).astype(jnp.float32)
|
||||
|
||||
out = proj * (gates if activation is None else activation(gates))
|
||||
|
||||
if epilogue is not None:
|
||||
out = epilogue(out, epilogue_in_refs, pid_m, pid_n)
|
||||
|
||||
block.store_block(out_ref, out, (pid_m, pid_n))
|
||||
|
||||
|
||||
def _gated_linear_unit(
|
||||
x: Float[Array | ArrayView, 'M K'],
|
||||
weights_projection: Float[Array | ArrayView, 'K N'],
|
||||
weights_gate: Float[Array | ArrayView, 'K N'],
|
||||
*,
|
||||
dst: Float[ArrayView, 'M N'] | None = None,
|
||||
activation: Callable[[jax.Array], jax.Array] | None,
|
||||
epilogue: Any, # Callable[..., Any] | None - breaks `typed`.
|
||||
epilogue_args: PyTree,
|
||||
precision: jax.lax.Precision | None,
|
||||
) -> jax.Array: # Float[Array, 'M N'] | Float[Array, 'N M']
|
||||
"""Applies a gated linear unit (arxiv.org/abs/1612.08083)."""
|
||||
if epilogue is None and epilogue_args is not None:
|
||||
raise ValueError('`epilogue_args` is specified but `epilogue` is None.')
|
||||
|
||||
name = 'pallas_glu'
|
||||
if activation is not None:
|
||||
name += f'_{getattr(activation, "__name__", repr(activation))}'
|
||||
if epilogue is not None:
|
||||
name += f'_{getattr(epilogue, "__name__", repr(epilogue))}'
|
||||
|
||||
w = weights_projection
|
||||
config = matmul_config.get_config(x, w)
|
||||
|
||||
m = x.shape[0]
|
||||
n = w.shape[1]
|
||||
kernel = functools.partial(
|
||||
_gated_linear_unit_kernel,
|
||||
block_m=config.block_m,
|
||||
block_n=config.block_n,
|
||||
block_k=config.block_k,
|
||||
activation=activation,
|
||||
precision=precision,
|
||||
epilogue=epilogue,
|
||||
)
|
||||
|
||||
if dst is None:
|
||||
input_output_aliases = {}
|
||||
else:
|
||||
input_output_aliases = {3: 0}
|
||||
|
||||
compiler_params = dict(
|
||||
triton=dict(num_warps=config.num_warps, num_stages=config.num_stages)
|
||||
)
|
||||
|
||||
return pl.pallas_call(
|
||||
kernel,
|
||||
name=name,
|
||||
grid=(pl.cdiv(m, config.block_m) * pl.cdiv(n, config.block_n),),
|
||||
out_shape=jax.ShapeDtypeStruct((m, n), x.dtype) if dst is None else dst,
|
||||
input_output_aliases=input_output_aliases,
|
||||
compiler_params=compiler_params,
|
||||
backend='triton',
|
||||
)(x, weights_projection, weights_gate, dst, epilogue_args)
|
||||
|
||||
|
||||
@jaxtyping.jaxtyped(typechecker=typeguard.typechecked)
|
||||
def gated_linear_unit(
|
||||
x: Float[Array | ArrayView, '*B M K'],
|
||||
weights_projection: Float[Array | ArrayView, 'K N'],
|
||||
weights_gate: Float[Array | ArrayView, 'K N'],
|
||||
*,
|
||||
activation: Callable[[jax.Array], jax.Array] | None = None,
|
||||
precision: jax.lax.Precision | None = None,
|
||||
) -> Float[Array | ArrayView, '*B M N']:
|
||||
"""Applies a gated linear unit (arxiv.org/abs/1612.08083).
|
||||
|
||||
Args:
|
||||
x: Input activations.
|
||||
weights_projection: Weights for linear projection.
|
||||
weights_gate: Weights for gates.
|
||||
activation: Optional activation function.
|
||||
precision: Specifies the precision of the matmuls.
|
||||
|
||||
Returns:
|
||||
`(x @ weights_projection) * activation(x @ weights_gate)`
|
||||
"""
|
||||
|
||||
supported_dtypes = {'float16', 'bfloat16', 'float32'}
|
||||
if x.dtype.name not in supported_dtypes:
|
||||
raise NotImplementedError(
|
||||
f'Triton kernel does not support input datatype {x.dtype.name}. Must be'
|
||||
f' one of {supported_dtypes}.'
|
||||
)
|
||||
|
||||
if not triton_utils.has_triton_support():
|
||||
raise NotImplementedError('Triton kernel not supported on current device.')
|
||||
|
||||
*batch, m, _ = x.shape
|
||||
n = weights_projection.shape[1]
|
||||
x = array_view.as_array_view(x).collapse(start=0, stop=-1)
|
||||
|
||||
return _gated_linear_unit(
|
||||
x,
|
||||
weights_projection,
|
||||
weights_gate,
|
||||
dst=None,
|
||||
activation=activation,
|
||||
precision=precision,
|
||||
epilogue=None,
|
||||
epilogue_args=None,
|
||||
).reshape(batch + [m, n])
|
||||
@@ -14,13 +14,14 @@ from collections.abc import Sequence
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from alphafold3.common import base_config
|
||||
from alphafold3.jax.attention import attention
|
||||
|
||||
import tokamax
|
||||
|
||||
_Shape2DType: TypeAlias = tuple[int | None, int | None]
|
||||
|
||||
|
||||
class GlobalConfig(base_config.BaseConfig):
|
||||
"""Global configuration for the AlphaFold3 model."""
|
||||
|
||||
bfloat16: Literal['all', 'none', 'intermediate'] = 'all'
|
||||
final_init: Literal['zeros', 'linear'] = 'zeros'
|
||||
pair_attention_chunk_size: Sequence[_Shape2DType] = ((1536, 128), (None, 32))
|
||||
@@ -29,4 +30,6 @@ class GlobalConfig(base_config.BaseConfig):
|
||||
(None, 1024),
|
||||
)
|
||||
# Note: flash_attention_implementation = 'xla' means no flash attention.
|
||||
flash_attention_implementation: attention.Implementation = 'triton'
|
||||
flash_attention_implementation: tokamax.DotProductAttentionImplementation = (
|
||||
'triton'
|
||||
)
|
||||
|
||||
@@ -11,13 +11,13 @@
|
||||
"""Diffusion transformer model."""
|
||||
|
||||
from alphafold3.common import base_config
|
||||
from alphafold3.jax.gated_linear_unit import gated_linear_unit
|
||||
from alphafold3.model import model_config
|
||||
from alphafold3.model.atom_layout import atom_layout
|
||||
from alphafold3.model.components import haiku_modules as hm
|
||||
import haiku as hk
|
||||
import jax
|
||||
from jax import numpy as jnp
|
||||
import tokamax
|
||||
|
||||
|
||||
def adaptive_layernorm(x, single_cond, name):
|
||||
@@ -97,9 +97,7 @@ def transition_block(
|
||||
name=f'{name}ffw_transition1',
|
||||
)
|
||||
weights = jnp.reshape(weights, (len(weights), 2, num_intermediates))
|
||||
c = gated_linear_unit.gated_linear_unit(
|
||||
x=x, weight=weights, implementation=None, activation=jax.nn.swish
|
||||
)
|
||||
c = tokamax.gated_linear_unit(x=x, weights=weights, activation=jax.nn.swish)
|
||||
else:
|
||||
x = hm.Linear(
|
||||
num_intermediates * 2, initializer='relu', name=f'{name}ffw_transition1'
|
||||
|
||||
@@ -14,8 +14,6 @@ from collections.abc import Sequence
|
||||
from typing import Literal
|
||||
|
||||
from alphafold3.common import base_config
|
||||
from alphafold3.jax.attention import attention
|
||||
from alphafold3.jax.gated_linear_unit import gated_linear_unit
|
||||
from alphafold3.model import model_config
|
||||
from alphafold3.model.components import haiku_modules as hm
|
||||
from alphafold3.model.components import mapping
|
||||
@@ -23,6 +21,7 @@ from alphafold3.model.network import diffusion_transformer
|
||||
import haiku as hk
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
import tokamax
|
||||
|
||||
|
||||
def get_shard_size(
|
||||
@@ -68,8 +67,8 @@ class TransitionBlock(hk.Module):
|
||||
name='transition1',
|
||||
)
|
||||
weights = jnp.reshape(weights, (len(weights), 2, num_intermediate))
|
||||
c = gated_linear_unit.gated_linear_unit(
|
||||
x=act, weight=weights, implementation=None, activation=jax.nn.swish
|
||||
c = tokamax.gated_linear_unit(
|
||||
x=act, weights=weights, activation=jax.nn.swish
|
||||
)
|
||||
else:
|
||||
act = hm.Linear(
|
||||
@@ -172,7 +171,7 @@ class GridSelfAttention(hk.Module):
|
||||
# Dot product attention requires the bias term to have a batch dimension.
|
||||
bias = jnp.expand_dims(bias, 0)
|
||||
|
||||
weighted_avg = attention.dot_product_attention(
|
||||
weighted_avg = tokamax.dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@@ -289,11 +288,8 @@ class TriangleMultiplication(hk.Module):
|
||||
)
|
||||
weights_glu = jnp.stack([weights_gate, weights_projection], axis=1)
|
||||
|
||||
projection = gated_linear_unit.gated_linear_unit(
|
||||
x=act,
|
||||
weight=weights_glu,
|
||||
activation=jax.nn.sigmoid,
|
||||
implementation=None,
|
||||
projection = tokamax.gated_linear_unit(
|
||||
act, weights_glu, activation=jax.nn.sigmoid
|
||||
)
|
||||
projection = jnp.transpose(projection, (2, 0, 1))
|
||||
projection *= mask
|
||||
|
||||
Reference in New Issue
Block a user