# File: boltzgen/config/train/boltzgen.yaml
# Snapshot: 2025-10-26 20:27:38 +00:00 (580 lines, 16 KiB, YAML, executable file)

# Training task class instantiated from this config.
_target_: boltzgen.task.train.train.Training

# Trainer settings; the keys match PyTorch Lightning `Trainer` constructor
# arguments (presumably forwarded to it by the training task).
trainer:
  accelerator: gpu
  devices: 8
  precision: bf16-mixed
  gradient_clip_val: 10.0
  accumulate_grad_batches: 1
  max_epochs: -1  # -1: no epoch limit
  num_sanity_val_steps: 3
  log_every_n_steps: 1
# Weights & Biases logging settings.
wandb:
  group: boltzgen
  project: boltzgen
  entity: yourwandb  # looks like a placeholder: set your own W&B entity

# Run name. Kept at the config root because the analyze-task config
# interpolates `${name}`, and OmegaConf resolves that absolute interpolation
# against the root, not against `wandb`.
name: a_big_run_resume3

slurm: true
output: workdir
strict_loading: false
resume: null
pretrained: null
debug: false  # also interpolated as `${debug}` by the analyze-task config
save_every_n_train_steps: 2500
disable_checkpoint: false
matmul_precision: null
save_top_k: -1
# Training data configuration.
data:
  # Mixture of training datasets; the per-dataset `prob` values
  # (0.6 + 0.3 + 0.03 + 0.04 + 0.03) sum to 1.0.
  datasets:
    # Experimental structures (validation group "RCSB")
    - _target_: boltzgen.task.train.data.DatasetConfig
      target_dir: ./training_data/targets
      msa_dir: ./training_data/msa
      prob: 0.6
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.date.DateFilter
          date: "2023-06-01"
          ref: released
        - _target_: boltzgen.data.filter.dynamic.resolution.ResolutionFilter
          resolution: 9.0
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      split: ./data/pdb_sequences/boltz2/validation_ids_boltz2_all.txt
      symmetry_correction: false
      # presumably matched against the model validators' `val_names`
      val_group: "RCSB"

    # AFDB Distillation Data
    - _target_: boltzgen.task.train.data.DatasetConfig
      manifest_path: ./training_data/afdb/afdb_manifest_foldseek_c75_confidence.json
      target_dir: ./training_data/afdb/targets
      msa_dir: ./training_data/afdb/msa
      prob: 0.3
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
          composition_op: "AND"
          metrics: ["confidence_score"]
          compare_ops: ["greater"]
          thresholds: [70]
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      symmetry_correction: true
      override_method: "AFDB"
      override_bfactor: true

    # Protein-Ligand Distillation Data
    - _target_: boltzgen.task.train.data.DatasetConfig
      target_dir: ./training_data/protein_ligand/targets
      msa_dir: ./training_data/protein_ligand/msa
      moldir: ./training_data/protein_ligand/mols
      prob: 0.03
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
          composition_op: "AND"
          metrics: ["complex_ipde", "complex_pde", "iptm"]
          compare_ops: ["lesser", "lesser", "greater"]
          thresholds: [1.5, 1.5, 0.9]
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
        beta_chain: 0.05
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      symmetry_correction: true
      override_method: "BOLTZ-1"

    # RNA Distillation Data
    - _target_: boltzgen.task.train.data.DatasetConfig
      target_dir: ./training_data/rna/targets
      msa_dir: ./training_data/rna/msa
      prob: 0.04
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
          composition_op: "OR"
          metrics: ["complex_pde"]
          compare_ops: ["lesser"]
          thresholds: [2.0]
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      symmetry_correction: true
      override_method: "BOLTZ-1"

    # Protein-DNA Distillation Data
    - _target_: boltzgen.task.train.data.DatasetConfig
      target_dir: ./training_data/protein_dna/targets
      msa_dir: ./training_data/protein_dna/msa
      prob: 0.03
      filters:
        - _target_: boltzgen.data.filter.dynamic.size.SizeFilter
          min_chains: 1
          max_chains: 300
        - _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
          composition_op: "AND"
          metrics: ["complex_ipde", "complex_pde", "iptm"]
          compare_ops: ["lesser", "lesser", "greater"]
          thresholds: [1.0, 2.0, 0.7]
      sampler:
        _target_: boltzgen.data.sample.cluster.ClusterSampler
        beta_chain: 0.05
      cropper:
        _target_: boltzgen.data.crop.multimer.MultimerCropper
        neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40]
      symmetry_correction: true
      override_method: "BOLTZ-1"

  tokenizer:
    _target_: boltzgen.data.tokenize.tokenizer.Tokenizer
    atomize_modified_residues: false
  featurizer:
    _target_: boltzgen.data.feature.featurizer.Featurizer

  moldir: ./training_data/mols
  max_tokens: 512
  max_atoms: 5120
  max_seqs: 4096
  pad_to_max_tokens: true
  pad_to_max_atoms: true
  pad_to_max_seqs: true
  samples_per_epoch: 100000
  batch_size: 1
  num_workers: 2
  random_seed: 42
  pin_memory: false
  overfit: null
  return_train_symmetries: false
  return_val_symmetries: false
  # NOTE(review): presumably must agree with model.atoms_per_window_queries
  # and model.num_bins (currently 32 and 64 in both places) — confirm.
  atoms_per_window_queries: 32
  min_dist: 2.0
  max_dist: 22.0
  num_bins: 64
  single_sequence_prop_training: 0.1
  msa_sampling_training: true

  # Design
  design: true
  backbone_only: false
  atom14: true  # interpolated by the model validators as ${data.atom14}
  atom37: false  # interpolated by the model validators as ${data.atom37}
  selector:
    _target_: boltzgen.data.select.protein.ProteinSelector
    design_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 14, 16, 18]
    substructure_neighborhood_sizes: [2, 4, 6, 8, 10, 12, 24]
    structure_condition_prob: 0.4
    distance_noise_std: 1
    run_selection: true
    specify_binding_sites: true
    ss_condition_prob: 0.1
    select_all: false

  # Design datasets (monomer_split / ligand_split are interpolated by the
  # refolding validator as ${data.monomer_split} / ${data.ligand_split})
  monomer_split: data/pdb_sequences/val_monomers_boltzgen_min50_max220.txt
  monomer_target_dir: ./training_data/targets
  monomer_target_structure_condition: true
  monomer_seq_len: 100
  ligand_split: data/pdb_sequences/val_ccd_pdb_pairs_boltzgen.txt
  ligand_target_dir: ./training_data/targets
  ligand_seq_len: 100
# Model configuration for the design model being trained.
# Loss weights below are written with a decimal point (e.g. 1.0e-4, not
# 1e-4): YAML 1.1 loaders such as PyYAML only resolve exponent floats that
# contain a dot and would otherwise read them as strings. OmegaConf yields
# the same float either way.
model:
  _target_: boltzgen.model.models.boltz.Boltz
  atom_s: 128
  atom_z: 16
  token_s: 384
  token_z: 128
  num_bins: 64
  atom_feature_dim: 388
  atoms_per_window_queries: 32
  atoms_per_window_keys: 128
  use_miniformer: false
  ema: true
  ema_decay: 0.999
  exclude_ions_from_lddt: true
  num_val_datasets: 1  # New
  ignore_ckpt_shape_mismatch: false  # New
  aggregate_distogram: true  # New
  bond_type_feature: true
  predict_bfactor: true
  checkpoint_diffusion_conditioning: true
  use_kernels: true

  validators:
    - _target_: boltzgen.model.validation.design.DesignValidator
      val_names: ["RCSB"]
      confidence_prediction: ${model.confidence_prediction}
      atom14: ${data.atom14}
      atom37: ${data.atom37}

  masker_args:
    mask: true
    mask_backbone: false
    mask_disto: true

  embedder_args:
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    add_mol_type_feat: true
    add_method_conditioning: true
    add_modified_flag: true
    add_cyclic_flag: true
    add_design_mask_flag: true
    add_binding_specification: true
    add_ss_specification: true

  # NOTE(review): placed at model level (not under embedder_args) — confirm.
  freeze_template_weights: true
  use_templates: true
  template_args:
    template_dim: 64
    template_blocks: 2
    activation_checkpointing: false

  use_token_distances: true
  token_distance_args:
    token_distance_dim: 64
    token_distance_blocks: 2
    use_token_distance_feats: true
    distance_gaussian_dim: 32
    activation_checkpointing: true

  msa_args:
    msa_s: 64
    msa_blocks: 4
    msa_dropout: 0.15
    z_dropout: 0.25
    miniformer_blocks: false
    pairwise_head_width: 32
    pairwise_num_heads: 4
    use_paired_feature: true
    activation_checkpointing: true

  pairformer_args:
    num_blocks: 64
    num_heads: 16
    dropout: 0.25
    post_layer_norm: false
    activation_checkpointing: true

  score_model_args:
    sigma_data: 16  # same value as diffusion_process_args.sigma_data
    dim_fourier: 256
    atom_encoder_depth: 3
    atom_encoder_heads: 4
    # token level args
    token_layers: 1
    token_transformer_depth: 24
    token_transformer_heads: 16
    diffusion_pairformer_args:
      num_blocks: 0
      num_heads: 2
      dropout: 0
      # NOTE(review): nested here rather than in score_model_args — confirm.
      use_s_to_z: false
    atom_decoder_depth: 3
    atom_decoder_heads: 4
    conditioning_transition_layers: 2
    transformer_post_ln: false
    activation_checkpointing: true

  confidence_prediction: false
  structure_prediction_training: true

  training_args:
    recycling_steps: 3
    sampling_steps: 20
    diffusion_multiplicity: 32
    diffusion_samples: 1
    confidence_loss_weight: 1.0e-4
    diffusion_loss_weight: 4.0
    distogram_loss_weight: 3.0e-2
    bfactor_loss_weight: 1.0e-3
    adam_beta_1: 0.9
    adam_beta_2: 0.95
    adam_eps: 0.00000001
    lr_scheduler: af3
    base_lr: 0.0
    max_lr: 0.0005
    lr_warmup_no_steps: 1000
    lr_start_decay_after_n_steps: 50000
    lr_decay_every_n_steps: 50000
    lr_decay_factor: 0.95
    weight_decay: 0.003
    weight_decay_exclude: true

  validation_args:
    recycling_steps: 3
    sampling_steps: 200
    diffusion_samples: 1
    symmetry_correction: false

  diffusion_process_args:
    sigma_min: 0.0004  # min noise level
    sigma_max: 160.0  # max noise level
    sigma_data: 16.0  # standard deviation of data distribution
    rho: 7  # controls the sampling schedule
    P_mean: -1.2  # mean of log-normal distribution from which noise is drawn for training
    P_std: 1.5  # standard deviation of log-normal distribution from which noise is drawn for training
    gamma_0: 0.8
    gamma_min: 1.0
    noise_scale: 1.0
    step_scale: 1.0
    mse_rotational_alignment: true
    coordinate_augmentation: true
    alignment_reverse_diff: true
    synchronize_sigmas: false

  diffusion_loss_args:
    add_smooth_lddt_loss: true
    add_bond_loss: false
    nucleotide_loss_weight: 5.0
    ligand_loss_weight: 10.0

  # Validator that refolds generated designs with a separate folding model
  # (loaded from `folding_checkpoint`) and analyzes the results.
  refolding_validator:
    _target_: boltzgen.model.validation.refolding.RefoldingValidator
    val_names: ["RCSB"]
    step_scale: 1.5
    noise_scale: 0.75
    atom14: ${data.atom14}
    atom37: ${data.atom37}
    val_monomer: ${data.monomer_split}
    val_ligand: ${data.ligand_split}

    analyze_task:
      _target_: boltzgen.task.analyze.analyze.Analyze
      name: ${name}  # absolute interpolation: needs a root-level `name` key
      debug: ${debug}
      design_dir: null
      num_processes: 1
      # Common metrics to compute
      affinity_metrics: false
      allatom_fold_metrics: true
      backbone_fold_metrics: true
      noncovalents_original: false
      noncovalents_refolded: false
      delta_sasa_original: false
      delta_sasa_refolded: false
      largest_hydrophobic: false
      largest_hydrophobic_refolded: false
      run_clustering: false
      # Liability analysis
      liability_analysis: false
      liability_modality: peptide
      liability_peptide_type: linear
      # Uncommon metrics
      diversity_original: true
      diversity_refolded: true
      diversity_per_target_original: false
      diversity_per_target_refolded: false
      novelty_original: false
      novelty_refolded: false
      novelty_per_target_original: false
      novelty_per_target_refolded: false
      wandb: null
      data:
        _target_: boltzgen.task.predict.data_from_generated.FromGeneratedDataModule
        cfg:
          _target_: boltzgen.task.predict.data_from_generated.DataConfig
          tokenizer:
            _target_: boltzgen.data.tokenize.tokenizer.Tokenizer
            atomize_modified_residues: false
          featurizer:
            _target_: boltzgen.data.feature.featurizer.Featurizer
          suffix: ".cif"
          suffix_metadata: ".npz"
          suffix_native: "_native.cif"
          samples_per_target: 1
          num_targets: 100000000
          moldir: ./training_data/mols
          # NOTE(review): the four keys below are placed in `cfg`; confirm
          # they are not FromGeneratedDataModule arguments instead.
          batch_size: 1
          num_workers: 4
          pin_memory: true
          return_native: true

    # NOTE(review): placed on the validator (not on analyze_task) — confirm.
    predict_task: null
    folding_checkpoint: ./training_data/boltz2_fold.ckpt
    folding_args:
      recycling_steps: 3
      sampling_steps: 200
      diffusion_samples: 1

    # Hyperparameters for the separate folding model used for refolding.
    folding_model_args:
      atom_s: 128
      atom_z: 16
      token_s: 384
      token_z: 128
      num_bins: 64
      atom_feature_dim: 388
      atoms_per_window_queries: 32
      atoms_per_window_keys: 128
      compile_pairformer: false
      compile_templates: false
      compile_msa: false
      use_miniformer: false
      ema: true
      ema_decay: 0.999
      exclude_ions_from_lddt: true
      num_val_datasets: 4
      ignore_ckpt_shape_mismatch: false
      aggregate_distogram: true
      bond_type_feature: true
      conditioning_cutoff_min: 4.0
      conditioning_cutoff_max: 20.0
      use_templates: true
      predict_bfactor: true
      checkpoint_diffusion_conditioning: false
      use_kernels: true
      validators: null

      embedder_args:
        atom_encoder_depth: 3
        atom_encoder_heads: 4
        add_mol_type_feat: true
        add_method_conditioning: true
        add_modified_flag: true
        add_cyclic_flag: true

      msa_args:
        msa_s: 64
        msa_blocks: 4
        msa_dropout: 0.15
        z_dropout: 0.25
        miniformer_blocks: false
        pairwise_head_width: 32
        pairwise_num_heads: 4
        use_paired_feature: true
        activation_checkpointing: false

      template_args:
        template_dim: 64
        template_blocks: 2
        activation_checkpointing: false

      pairformer_args:
        num_blocks: 64
        num_heads: 16
        dropout: 0.25
        post_layer_norm: false
        activation_checkpointing: false

      score_model_args:
        sigma_data: 16
        dim_fourier: 256
        atom_encoder_depth: 3
        atom_encoder_heads: 4
        token_transformer_depth: 24
        token_transformer_heads: 16
        atom_decoder_depth: 3
        atom_decoder_heads: 4
        conditioning_transition_layers: 2
        transformer_post_ln: false
        activation_checkpointing: false

      confidence_prediction: false
      affinity_prediction: false
      structure_prediction_training: true

      affinity_model_args:
        num_dist_bins: 64
        max_dist: 22
        no_trunk_feats: false
        add_s_to_z_prod: false
        add_s_input_to_s: false
        # NOTE(review): nested under affinity_model_args — confirm it is not
        # a top-level folding_model_args key.
        confidence_args:
          num_plddt_bins: 50
          num_pde_bins: 64
          num_pae_bins: 64

      training_args:
        recycling_steps: 3
        sampling_steps: 20
        diffusion_multiplicity: 48
        diffusion_samples: 1
        affinity_loss_weight: 3.0e-3
        confidence_loss_weight: 1.0e-4
        diffusion_loss_weight: 4.0
        distogram_loss_weight: 3.0e-2
        bfactor_loss_weight: 1.0e-3
        adam_beta_1: 0.9
        adam_beta_2: 0.95
        adam_eps: 0.00000001
        lr_scheduler: af3
        base_lr: 0.0
        max_lr: 0.001
        lr_warmup_no_steps: 1000
        lr_start_decay_after_n_steps: 50000
        lr_decay_every_n_steps: 50000
        lr_decay_factor: 0.95
        weight_decay: 0.003
        weight_decay_exclude: true

      validation_args:
        recycling_steps: 3
        sampling_steps: 200
        diffusion_samples: 5
        symmetry_correction: false

      diffusion_process_args:
        sigma_min: 0.0004  # min noise level
        sigma_max: 160.0  # max noise level
        sigma_data: 16.0  # standard deviation of data distribution
        rho: 7  # controls the sampling schedule
        P_mean: -1.2  # mean of log-normal distribution from which noise is drawn for training
        P_std: 1.5  # standard deviation of log-normal distribution from which noise is drawn for training
        gamma_0: 0.8
        gamma_min: 1.0
        noise_scale: 1.0
        step_scale: 1.0
        mse_rotational_alignment: true
        coordinate_augmentation: true
        alignment_reverse_diff: true
        synchronize_sigmas: false

      diffusion_loss_args:
        add_smooth_lddt_loss: true
        add_bond_loss: false
        nucleotide_loss_weight: 5.0
        ligand_loss_weight: 10.0