diff --git a/.gitmodules b/.gitmodules index ec361d14..389f61cd 100644 --- a/.gitmodules +++ b/.gitmodules @@ -17,4 +17,4 @@ [submodule "alphafold3"] path = alphafold3 url = git@github.com:KosinskiLab/alphafold3.git - branch = ap-gapped-discontinuous-chains + branch = ap-gapped-discontinuous-chains-pre-tokamax diff --git a/alphafold3 b/alphafold3 index 7d5a657d..8420ad39 160000 --- a/alphafold3 +++ b/alphafold3 @@ -1 +1 @@ -Subproject commit 7d5a657db333eb56edb5e2c0b2b7cf2441c8f5c4 +Subproject commit 8420ad39e9988ca02178294cd3f7b7420b88cabf diff --git a/alphapulldown/folding_backend/alphafold3_backend.py b/alphapulldown/folding_backend/alphafold3_backend.py index 39f9043e..e299b1db 100644 --- a/alphapulldown/folding_backend/alphafold3_backend.py +++ b/alphapulldown/folding_backend/alphafold3_backend.py @@ -23,12 +23,12 @@ import alphafold3.cpp import haiku as hk import jax import numpy as np -import tokamax from alphafold3.common import base_config from alphafold3.common import folding_input from alphafold3.constants import chemical_components from alphafold3.data import featurisation from alphafold3.data import parsers as af3_parsers +from alphafold3.jax.attention import attention from alphafold3.model import features, params, post_processing from alphafold3.model import model from alphafold3.model.components import utils @@ -413,7 +413,7 @@ class AlphaFold3Backend(FoldingBackend): def make_model_config( *, model_class: type[ModelT] = MyNewModel, - flash_attention_implementation: tokamax.DotProductAttentionImplementation, + flash_attention_implementation: attention.Implementation, num_diffusion_samples: int = 5, num_recycles: int = 10, return_embeddings: bool = False, @@ -458,7 +458,7 @@ class AlphaFold3Backend(FoldingBackend): model_class=MyNewModel, config=make_model_config( flash_attention_implementation=typing.cast( - tokamax.DotProductAttentionImplementation, flash_attention_implementation + attention.Implementation, flash_attention_implementation ), num_diffusion_samples=num_diffusion_samples, num_recycles=num_recycles,