Files
foundry/configs/model/components/af3_net.yaml
Nathaniel Corley 5a492032d5 refactor: new modelhub (#109)
* Initial commit of chiral changes

Initial checkin of chiral feature code

Add chiral metric

* Update the way chiral features are incorporated into the model

Move initialization to new func

use default pytorch reset parameters

fix initialization for chirals

config

rename argument of confidence head

fix initialization for chirals

* refactor: src nest, rename rf2aa to modelhub

* refactor: initial commit without projects

* Initial commit of chiral changes

* Initial checkin of chiral feature code

* Add chiral metric

* Remove option for double residual connection. Add kq_norm options to base (20250125) config.

* Restoring flag

* config

* rename argument of confidence head

* Update the way chiral features are incorporated into the model

* config

* rename argument of confidence head

* Update the way chiral features are incorporated into the model

* Initial commit of chiral changes

Initial checkin of chiral feature code

Add chiral metric

* Update the way chiral features are incorporated into the model

Move initialization to new func

use default pytorch reset parameters

fix initialization for chirals

config

rename argument of confidence head

fix initialization for chirals

* refactor: new modelhub

---------

Co-authored-by: fdimaio <dimaio@uw.edu>
Co-authored-by: HaotianZhangAI4Science <haotianzhang@zju.edu.cn>
2025-04-08 13:33:17 -07:00

177 lines
4.4 KiB
YAML

# Model architecture
# NOTE(review): presumably instantiated via Hydra (`_target_`); the
# `${model.net.*}` interpolations throughout this file assume this config
# is mounted at `model.net` in the composed config — confirm.
_target_: modelhub.model.AF3.AF3
# +---------- Channel dimensions ----------+
c_s: 384
c_z: 128
c_atom: 128
c_atompair: 16
# Input single-representation width, interpolated via ${model.net.c_s_inputs}
# by the MSA subsample embedder and the diffusion conditioning.
# Presumably 384 (input atom_attention_encoder.c_token) + 32 (restype)
# + 32 (profile) + 1 (deletion_mean) = 449 — TODO confirm.
c_s_inputs: 449
# +---------- Feature embedding ----------+
# NOTE(review): indentation restored from a whitespace-stripped copy. The
# atom_transformer path is confirmed by the l_max interpolations in the
# diffusion module; nesting diffusion_transformer inside atom_transformer is
# inferred — confirm against the AtomTransformer constructor.
feature_initializer:
  # InputFeatureEmbedder
  input_feature_embedder:
    features:
      - restype
      - profile
      - deletion_mean
    atom_attention_encoder:
      c_token: 384
      # Presumably the summed width of atom_1d_features:
      # ref_pos (3) + ref_charge (1) + ref_mask (1) + ref_element (128)
      # + ref_atom_name_chars (4 x 64) = 389 — TODO confirm.
      c_atom_1d_features: 389
      c_tokenpair: ${model.net.c_z}
      atom_1d_features:
        - ref_pos
        - ref_charge
        - ref_mask
        - ref_element
        - ref_atom_name_chars
      atom_transformer:
        n_queries: 32
        n_keys: 128
        # Plain integer literal: `40_000` is an int only under YAML 1.1
        # resolvers and becomes a string under the YAML 1.2 core schema.
        # Re-used by the diffusion module's encoder/decoder via interpolation.
        l_max: 40000  # does not matter
        diffusion_transformer:
          n_block: 3
          diffusion_transformer_block:
            n_head: 4
            no_residual_connection_between_attention_and_transition: true
            kq_norm: false
  # RelativePositionEncoding
  relative_position_encoding:
    r_max: 32
    s_max: 2
# +---------- Recycler ----------+
# NOTE(review): indentation restored from a whitespace-stripped copy;
# msa_module's placement under recycler is confirmed by the
# ${model.net.recycler.msa_module.c_m} interpolations below.
recycler:
  # Pairformer
  n_pairformer_blocks: 48
  pairformer_block:
    p_drop: 0.25
    triangle_multiplication:
      d_hidden: 128
    triangle_attention:
      n_head: 4
      d_hidden: 32
    attention_pair_bias:
      n_head: 16
  # TemplateEmbedder
  template_embedder:
    n_block: 2
    raw_template_dim: 108
    c: 64
    p_drop: 0.25
  # MSA module
  msa_module:
    n_block: 4
    c_m: 64
    p_drop_msa: 0.15
    p_drop_pair: 0.25
    msa_subsample_embedder:
      num_sequences: 1024
      dim_raw_msa: 34
      c_s_inputs: ${model.net.c_s_inputs}
      c_msa_embed: ${model.net.recycler.msa_module.c_m}
    outer_product:
      c_msa_embed: ${model.net.recycler.msa_module.c_m}
      c_outer_product: 32
      c_out: ${model.net.c_z}
    msa_pair_weighted_averaging:
      n_heads: 8
      c_weighted_average: 32
      c_msa_embed: ${model.net.recycler.msa_module.c_m}
      c_z: ${model.net.c_z}
      separate_gate_for_every_channel: true
    msa_transition:
      n: 4
      c: ${model.net.recycler.msa_module.c_m}
    triangle_multiplication_outgoing:
      d_pair: ${model.net.c_z}
      d_hidden: 128
      bias: true
    triangle_multiplication_incoming:
      d_pair: ${model.net.c_z}
      d_hidden: 128
      bias: true
    triangle_attention_starting:
      d_pair: ${model.net.c_z}
      n_head: 4
      d_hidden: 32
      p_drop: 0.0  # Has no effect; TODO: remove
    triangle_attention_ending:
      d_pair: ${model.net.c_z}
      n_head: 4
      d_hidden: 32
      p_drop: 0.0  # Has no effect; TODO: remove
    pair_transition:
      n: 4
      c: ${model.net.c_z}
# +---------- Diffusion module ----------+
# NOTE(review): indentation restored from a whitespace-stripped copy. The
# 3-block diffusion_transformer entries are assumed to nest inside each
# atom_transformer; the 24-block one is the module's own token-level
# transformer — confirm against the constructors.
diffusion_module:
  sigma_data: 16
  c_token: 768
  f_pred: edm
  diffusion_conditioning:
    c_s_inputs: ${model.net.c_s_inputs}
    c_t_embed: 256
    relative_position_encoding:
      r_max: 32
      s_max: 2
  atom_attention_encoder:
    c_tokenpair: ${model.net.c_z}
    c_atom_1d_features: 389
    atom_1d_features:
      - ref_pos
      - ref_charge
      - ref_mask
      - ref_element
      - ref_atom_name_chars
    atom_transformer:
      n_queries: 32
      n_keys: 128
      l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
      diffusion_transformer:
        n_block: 3
        diffusion_transformer_block:
          n_head: 4
          no_residual_connection_between_attention_and_transition: true
          kq_norm: false
    broadcast_trunk_feats_on_1dim_old: false
    use_chiral_features: true
  diffusion_transformer:
    n_block: 24
    diffusion_transformer_block:
      n_head: 16
      no_residual_connection_between_attention_and_transition: true
      kq_norm: true
  atom_attention_decoder:
    atom_transformer:
      n_queries: 32
      n_keys: 128
      l_max: ${model.net.feature_initializer.input_feature_embedder.atom_attention_encoder.atom_transformer.l_max}
      diffusion_transformer:
        n_block: 3
        diffusion_transformer_block:
          n_head: 4
          no_residual_connection_between_attention_and_transition: true
          kq_norm: false
# +---------- Distogram head ----------+
distogram_head:
  bins: 65
# +---------- Inference sampler ----------+
inference_sampler:
  solver: "af3"
  num_timesteps: 200
  min_t: 0
  max_t: 1
  sigma_data: ${model.net.diffusion_module.sigma_data}
  # Written with a decimal point: the YAML 1.1 float regex (e.g. plain
  # PyYAML) requires one, so a bare `4e-4` would load as a string there.
  s_min: 4.0e-4
  s_max: 160
  p: 7
  gamma_0: 0.8
  gamma_min: 1.0
  noise_scale: 1.003
  step_scale: 1.5