Point AF3 gap branch to pre-Tokamax submodule

This commit is contained in:
Dima
2026-03-25 15:28:08 +01:00
parent 78398ee705
commit ae8bdbeaa8
3 changed files with 5 additions and 5 deletions

2
.gitmodules vendored
View File

@@ -17,4 +17,4 @@
[submodule "alphafold3"]
path = alphafold3
url = git@github.com:KosinskiLab/alphafold3.git
branch = ap-gapped-discontinuous-chains
branch = ap-gapped-discontinuous-chains-pre-tokamax

View File

@@ -23,12 +23,12 @@ import alphafold3.cpp
import haiku as hk
import jax
import numpy as np
import tokamax
from alphafold3.common import base_config
from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.data import featurisation
from alphafold3.data import parsers as af3_parsers
from alphafold3.jax.attention import attention
from alphafold3.model import features, params, post_processing
from alphafold3.model import model
from alphafold3.model.components import utils
@@ -413,7 +413,7 @@ class AlphaFold3Backend(FoldingBackend):
def make_model_config(
*,
model_class: type[ModelT] = MyNewModel,
flash_attention_implementation: tokamax.DotProductAttentionImplementation,
flash_attention_implementation: attention.Implementation,
num_diffusion_samples: int = 5,
num_recycles: int = 10,
return_embeddings: bool = False,
@@ -458,7 +458,7 @@ class AlphaFold3Backend(FoldingBackend):
model_class=MyNewModel,
config=make_model_config(
flash_attention_implementation=typing.cast(
tokamax.DotProductAttentionImplementation, flash_attention_implementation
attention.Implementation, flash_attention_implementation
),
num_diffusion_samples=num_diffusion_samples,
num_recycles=num_recycles,