# Mirror of https://github.com/HannesStark/boltzgen.git
# Synced 2026-06-04 11:54:23 +08:00
# 376 lines, 9.9 KiB, YAML, executable file
# Root object described by this file. `_target_` holds the dotted import path
# of the class to build.
# NOTE(review): assumes the file is loaded through Hydra/OmegaConf
# instantiation -- confirm against the training entry point.
_target_: boltzgen.task.train.train.Training

# Trainer settings.
# NOTE(review): the key names match PyTorch Lightning `Trainer` arguments and
# are presumably forwarded there by the Training task -- confirm.
trainer:
  accelerator: cuda
  devices: 4
  precision: 32
  gradient_clip_val: 10.0
  accumulate_grad_batches: 1
  max_epochs: 5
  num_sanity_val_steps: 1
  log_every_n_steps: 1

# Experiment-tracking identifiers (group / project / entity).
# NOTE(review): `entity: yourwandb` reads like a placeholder -- presumably it
# must be replaced with a real account or team before logging; confirm.
wandb:
  group: boltzgen
  project: boltzgen
  entity: yourwandb

# Top-level run settings. `name` and `debug` are re-used further down in this
# file through the `${name}` and `${debug}` interpolations, so they must stay
# at the root of the document.
name: if_lr_scheduler
output: workdir
strict_loading: false
resume: null
debug: false
save_every_n_train_steps: 2500
disable_checkpoint: false
matmul_precision: null
save_top_k: -1

# Training data configuration.
# NOTE(review): the nesting below was rebuilt from a copy of this file whose
# indentation had been stripped. Keys that other sections reference as
# `${data.<key>}` (atom14, atom37, backbone_only, monomer_split, ligand_split)
# are known to sit directly under `data`. Every other parent/child boundary
# follows the file's blank-line grouping and should be checked against the
# upstream file.
data:
  datasets:
    - _target_: boltzgen.task.train.data.DatasetConfig
      target_dir: ./training_data/targets
      msa_dir: ./training_data/msa
      prob: 1
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.date.DateFilter
          # Quoted so the value stays a string rather than a YAML date.
          date: "2023-06-01"
          ref: released
        - _target_: boltzgen.data.filter.dynamic.resolution.ResolutionFilter
          resolution: 9.0
        - _target_: boltzgen.data.filter.dynamic.min_protein_residues.MinProteinResiduesFilter
          min_residues: 5
        - _target_: boltzgen.data.filter.dynamic.pdb_id_txtfile.FilterIDFromTXT
          paths:
            - data/exclude_ids/fibril.txt
            - data/exclude_ids/transmembrane.txt
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes:
          [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      # NOTE(review): split / symmetry_correction / val_group are placed on
      # the dataset entry (not on the cropper) -- confirm.
      split: ./data/pdb_sequences/boltz2/validation_ids_boltz2_all.txt
      symmetry_correction: false
      val_group: "RCSB"

  tokenizer:
    _target_: boltzgen.data.tokenize.tokenizer.Tokenizer
    atomize_modified_residues: false
  featurizer:
    _target_: boltzgen.data.feature.featurizer.Featurizer
  # NOTE(review): `moldir` and the keys after it are treated as data-level
  # settings, not featurizer arguments -- confirm.
  moldir: ./training_data/mols
  max_tokens: 1024
  max_atoms: 8192
  max_seqs: 1
  pad_to_max_tokens: true
  pad_to_max_atoms: true
  pad_to_max_seqs: true
  samples_per_epoch: 600000
  batch_size: 2
  num_workers: 32
  random_seed: 42
  pin_memory: true
  overfit: null
  return_train_symmetries: false
  return_val_symmetries: false
  compute_frames: false

  # NOTE(review): atoms_per_window_queries and num_bins carry the same values
  # as the keys of the same name under `model`; presumably they must be kept
  # in sync -- confirm.
  atoms_per_window_queries: 32
  min_dist: 2.0
  max_dist: 22.0
  num_bins: 64
  single_sequence_prop_training: 0.05
  msa_sampling_training: true

  # Design
  design: true
  backbone_only: true
  atom14: false
  atom37: false
  # Follows the model-level switch (model.inverse_fold is true in this file).
  inverse_fold: ${model.inverse_fold}
  use_msa: false
  selector:
    _target_: boltzgen.data.select.protein.ProteinSelector
    # NOTE(review): every key up to the next blank line is assumed to be a
    # selector argument -- confirm where the selector mapping ends.
    design_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18]
    substructure_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 24]
    structure_condition_prob: 0.5
    distance_noise_std: 1
    run_selection: true
    specify_binding_sites: false
    ss_condition_prob: 0
    select_all: true
    complete_structure_mask: true

  # Design datasets
  # monomer_split and ligand_split are read by the refolding validator via
  # `${data.monomer_split}` / `${data.ligand_split}`.
  monomer_split: data/pdb_sequences/val_monomers_boltzgen_min50_max220.txt
  monomer_target_dir: ./training_data/targets
  monomer_target_structure_condition: true
  monomer_seq_len: 100

  ligand_split: null
  ligand_target_dir: ./training_data/targets
  ligand_seq_len: 100

# Model configuration.
# NOTE(review): the nesting below was rebuilt from a copy of this file whose
# indentation had been stripped. Keys referenced elsewhere as `${model.<key>}`
# (atom_s, atom_z, token_s, token_z, inverse_fold, confidence_prediction,
# inverse_fold_args.pair_dim) are known to sit where shown. Every other
# parent/child boundary follows the file's blank-line grouping and should be
# checked against the upstream file.
model:
  _target_: boltzgen.model.models.boltz.Boltz
  atom_s: 128
  atom_z: 16
  token_s: 384
  # Resolves to inverse_fold_args.pair_dim (128 below).
  token_z: ${model.inverse_fold_args.pair_dim}
  num_bins: 64
  atom_feature_dim: 388
  atoms_per_window_queries: 32
  atoms_per_window_keys: 128
  use_miniformer: true
  ema: true
  ema_decay: 0.999
  exclude_ions_from_lddt: true
  num_val_datasets: 1  # New
  ignore_ckpt_shape_mismatch: false  # New
  aggregate_distogram: true  # New
  bond_type_feature: true
  predict_bfactor: true
  predict_res_type: true
  checkpoint_diffusion_conditioning: false
  # Read back by data.inverse_fold and by both validators.
  inverse_fold: true
  inverse_fold_args:
    # The first four sizes mirror the model-level values of the same name.
    atom_s: ${model.atom_s}
    atom_z: ${model.atom_z}
    token_s: ${model.token_s}
    token_z: ${model.token_z}
    node_dim: 128
    pair_dim: 128
    hidden_dim: 128
    dropout: 0.1
    softmax_dropout: 0.2
    num_encoder_layers: 6
    num_decoder_layers: 3
    autoregressive: true
    transformation_scale_factor: 1.0
    inverse_fold_noise: 0.2
    topk: 30
    num_heads: 4
    # Written as lowercase `true` like every other boolean in this file.
    enable_input_embedder: true
    sampling_temperature: -1.0

  validators:
    - _target_: boltzgen.model.validation.design.DesignValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}
      atom14: ${data.atom14}
      atom37: ${data.atom37}
      backbone_only: ${data.backbone_only}
      inverse_fold: ${model.inverse_fold}

  masker_args:
    mask: true
    mask_backbone: false
    mask_disto: false

  embedder_args:
    atom_encoder_depth: 1
    atom_encoder_heads: 4
    add_mol_type_feat: true
    add_method_conditioning: true
    add_modified_flag: true
    add_cyclic_flag: true
    add_design_mask_flag: false
    add_binding_specification: false
    add_ss_specification: false

  use_token_distances: false
  token_distance_args:
    token_distance_dim: ${model.inverse_fold_args.pair_dim}
    token_distance_blocks: 0
    use_token_distance_feats: true
    distance_gaussian_dim: 32
    disable_token_distance_transition: true
    use_relative_position_encoding: true

  # MSA module is not used in inverse folding
  msa_args:
    msa_s: 2
    msa_blocks: 0
    msa_dropout: 0
    z_dropout: 0
    miniformer_blocks: true
    pairwise_head_width: 2
    pairwise_num_heads: 1
    use_paired_feature: true
    activation_checkpointing: false

  pairformer_args:
    num_blocks: 2
    num_heads: 16
    dropout: 0.25
    post_layer_norm: false
    activation_checkpointing: false

  score_model_args:
    sigma_data: 16
    dim_fourier: 256
    atom_encoder_depth: 3
    atom_encoder_heads: 4

    # token level args
    token_layers: 1
    token_transformer_depth: 3
    token_transformer_heads: 16
    diffusion_pairformer_args:
      num_blocks: 0
      num_heads: 2
      dropout: 0
      use_s_to_z: false

    # NOTE(review): the keys below are assumed to belong to score_model_args
    # (not to diffusion_pairformer_args) -- confirm.
    atom_decoder_depth: 3
    atom_decoder_heads: 4
    conditioning_transition_layers: 2
    transformer_post_ln: false
    activation_checkpointing: false

  # Read back by the design validator via `${model.confidence_prediction}`.
  confidence_prediction: false
  affinity_prediction: false
  structure_prediction_training: true
  affinity_model_args:
    num_dist_bins: 64
    max_dist: 22
    no_trunk_feats: false
    add_s_to_z_prod: false
    add_s_input_to_s: false

  confidence_args:
    num_plddt_bins: 50
    num_pde_bins: 64
    num_pae_bins: 64

  training_args:
    recycling_steps: 0
    sampling_steps: 20
    diffusion_multiplicity: 2
    diffusion_samples: 1
    # Exponent floats are written with a decimal point: plain YAML 1.1
    # loaders (e.g. PyYAML) read dot-less forms such as `3e-3` as strings.
    affinity_loss_weight: 3.0e-3
    confidence_loss_weight: 1.0e-4
    diffusion_loss_weight: 4.0
    distogram_loss_weight: 3.0e-2
    bfactor_loss_weight: 1.0e-3
    res_type_loss_weight: 1
    adam_beta_1: 0.9
    adam_beta_2: 0.95
    adam_eps: 0.00000001
    lr_scheduler: onecycle
    base_lr: 0.0
    max_lr: 0.001
    weight_decay: 0.003
    weight_decay_exclude: true

  validation_args:
    recycling_steps: 0
    sampling_steps: 200
    diffusion_samples: 1
    symmetry_correction: false

  diffusion_process_args:
    sigma_min: 0.0004  # min noise level
    sigma_max: 160.0  # max noise level
    sigma_data: 16.0  # standard deviation of data distribution
    rho: 7  # controls the sampling schedule
    P_mean: -1.2  # mean of log-normal distribution from which noise is drawn for training
    P_std: 1.5  # standard deviation of log-normal distribution from which noise is drawn for training
    gamma_0: 0.8
    gamma_min: 1.0
    noise_scale: 1.0
    step_scale: 1.0
    mse_rotational_alignment: true
    coordinate_augmentation: true
    alignment_reverse_diff: true
    synchronize_sigmas: false

  diffusion_loss_args:
    add_smooth_lddt_loss: true
    add_bond_loss: false
    nucleotide_loss_weight: 5.0
    ligand_loss_weight: 10.0

  # NOTE(review): `refolding_validator` is placed under `model`, with
  # `analyze_task` and the `folding_*` keys as its children. None of these
  # parent/child links is pinned by an interpolation -- confirm all three
  # against the upstream file.
  refolding_validator:
    _target_: boltzgen.model.validation.refolding.RefoldingValidator
    val_names: ["RCSB"]
    step_scale: 1.5
    noise_scale: 0.75
    atom14: ${data.atom14}
    atom37: ${data.atom37}
    val_monomer: ${data.monomer_split}
    val_ligand: ${data.ligand_split}
    inverse_fold: ${model.inverse_fold}
    analyze_task:
      _target_: boltzgen.task.analyze.analyze.Analyze
      # Both values come from the top-level `name` / `debug` keys.
      name: ${name}
      debug: ${debug}
      design_dir: null
      num_processes: 1

      # Common metrics to compute
      affinity_metrics: false
      allatom_fold_metrics: true
      backbone_fold_metrics: true
      noncovalents_original: false
      noncovalents_refolded: false
      delta_sasa_original: false
      delta_sasa_refolded: false
      largest_hydrophobic: false
      largest_hydrophobic_refolded: false
      run_clustering: false

      # Liability analysis
      liability_analysis: false
      liability_modality: peptide
      liability_peptide_type: linear

      # Uncommon metrics
      diversity_original: true
      diversity_refolded: true
      diversity_per_target_original: false
      diversity_per_target_refolded: false
      novelty_original: false
      novelty_refolded: false
      novelty_per_target_original: false
      novelty_per_target_refolded: false

      wandb: null

      data:
        _target_: boltzgen.task.predict.data_from_generated.FromGeneratedDataModule
        cfg:
          _target_: boltzgen.task.predict.data_from_generated.DataConfig
          tokenizer:
            _target_: boltzgen.data.tokenize.tokenizer.Tokenizer
            atomize_modified_residues: false
          featurizer:
            _target_: boltzgen.data.feature.featurizer.Featurizer

          suffix: .cif
          suffix_metadata: .npz
          suffix_native: _native.cif
          samples_per_target: 1
          num_targets: 100000000
          moldir: ./training_data/mols

        # NOTE(review): the loader settings below are assumed to be
        # arguments of the data module rather than fields of `cfg` --
        # confirm.
        batch_size: 1
        num_workers: 1
        pin_memory: false
        target_templates: true
        return_native: true

    folding_checkpoint: ./training_data/boltz2_fold.ckpt

    folding_args:
      recycling_steps: 3
      sampling_steps: 200
      diffusion_samples: 1

    folding_model_args:
      validators: null