mirror of
https://github.com/HannesStark/boltzgen.git
synced 2026-06-04 11:54:23 +08:00
580 lines
16 KiB
YAML
Executable File
580 lines
16 KiB
YAML
Executable File
_target_: boltzgen.task.train.train.Training
|
|
|
|
trainer:
|
|
accelerator: gpu
|
|
devices: 8
|
|
precision: bf16-mixed
|
|
gradient_clip_val: 10.0
|
|
accumulate_grad_batches: 1
|
|
max_epochs: -1
|
|
num_sanity_val_steps: 3
|
|
log_every_n_steps: 1
|
|
|
|
wandb:
|
|
group: boltzgen
|
|
project: boltzgen
|
|
entity: yourwandb
|
|
|
|
name: a_big_run_resume3
|
|
slurm: true
|
|
output: workdir
|
|
strict_loading: false
|
|
resume: null
|
|
pretrained: null
|
|
debug: false
|
|
save_every_n_train_steps: 2500
|
|
disable_checkpoint: false
|
|
matmul_precision: null
|
|
save_top_k: -1
|
|
|
|
data:
|
|
datasets:
|
|
- _target_: boltzgen.task.train.data.DatasetConfig
|
|
target_dir: ./training_data/targets
|
|
msa_dir: ./training_data/msa
|
|
prob: 0.6
|
|
filters:
|
|
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
|
min_chains: 1
|
|
max_chains: 300
|
|
- _target_: boltzgen.data.filter.dynamic.date.DateFilter
|
|
date: "2023-06-01"
|
|
ref: released
|
|
- _target_: boltzgen.data.filter.dynamic.resolution.ResolutionFilter
|
|
resolution: 9.0
|
|
sampler:
|
|
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
|
cropper:
|
|
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
|
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
|
split: ./data/pdb_sequences/boltz2/validation_ids_boltz2_all.txt
|
|
symmetry_correction: false
|
|
val_group: "RCSB"
|
|
|
|
# AFDB Distillation Data
|
|
- _target_: boltzgen.task.train.data.DatasetConfig
|
|
manifest_path: ./training_data/afdb/afdb_manifest_foldseek_c75_confidence.json
|
|
target_dir: ./training_data/afdb/targets
|
|
msa_dir: ./training_data/afdb/msa
|
|
prob: 0.3
|
|
filters:
|
|
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
|
min_chains: 1
|
|
max_chains: 300
|
|
- _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
|
|
composition_op: "AND"
|
|
metrics: ["confidence_score"]
|
|
compare_ops: ["greater"]
|
|
thresholds: [70]
|
|
sampler:
|
|
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
|
cropper:
|
|
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
|
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
|
symmetry_correction: true
|
|
override_method: "AFDB"
|
|
override_bfactor: true
|
|
|
|
# Protein-Ligand Distillation Data
|
|
- _target_: boltzgen.task.train.data.DatasetConfig
|
|
target_dir: ./training_data/protein_ligand/targets
|
|
msa_dir: ./training_data/protein_ligand/msa
|
|
moldir: ./training_data/protein_ligand/mols
|
|
prob: 0.03
|
|
filters:
|
|
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
|
min_chains: 1
|
|
max_chains: 300
|
|
- _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
|
|
composition_op: "AND"
|
|
metrics: ["complex_ipde", "complex_pde", "iptm"]
|
|
compare_ops: ["lesser", "lesser", "greater"]
|
|
thresholds: [1.5, 1.5, 0.9]
|
|
sampler:
|
|
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
|
beta_chain: 0.05
|
|
cropper:
|
|
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
|
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
|
symmetry_correction: true
|
|
override_method: "BOLTZ-1"
|
|
|
|
# RNA Distillation Data
|
|
- _target_: boltzgen.task.train.data.DatasetConfig
|
|
target_dir: ./training_data/rna/targets
|
|
msa_dir: ./training_data/rna/msa
|
|
prob: 0.04
|
|
filters:
|
|
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
|
min_chains: 1
|
|
max_chains: 300
|
|
- _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
|
|
composition_op: "OR"
|
|
metrics: ["complex_pde"]
|
|
compare_ops: ["lesser"]
|
|
thresholds: [2.0]
|
|
sampler:
|
|
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
|
cropper:
|
|
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
|
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
|
symmetry_correction: true
|
|
override_method: "BOLTZ-1"
|
|
|
|
# Protein-DNA Distillation Data
|
|
- _target_: boltzgen.task.train.data.DatasetConfig
|
|
target_dir: ./training_data/protein_dna/targets
|
|
msa_dir: ./training_data/protein_dna/msa
|
|
prob: 0.03
|
|
filters:
|
|
- _target_: boltzgen.data.filter.dynamic.size.SizeFilter
|
|
min_chains: 1
|
|
max_chains: 300
|
|
- _target_: boltzgen.data.filter.dynamic.confidence.ConfidenceFilter
|
|
composition_op: "AND"
|
|
metrics: ["complex_ipde", "complex_pde", "iptm"]
|
|
compare_ops: ["lesser", "lesser", "greater"]
|
|
thresholds: [1.0, 2.0, 0.7]
|
|
sampler:
|
|
_target_: boltzgen.data.sample.cluster.ClusterSampler
|
|
beta_chain: 0.05
|
|
cropper:
|
|
_target_: boltzgen.data.crop.multimer.MultimerCropper
|
|
neighborhood_sizes: [ 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40 ]
|
|
symmetry_correction: true
|
|
override_method: "BOLTZ-1"
|
|
|
|
|
|
tokenizer:
|
|
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
|
|
atomize_modified_residues: false
|
|
featurizer:
|
|
_target_: boltzgen.data.feature.featurizer.Featurizer
|
|
moldir: ./training_data/mols
|
|
max_tokens: 512
|
|
max_atoms: 5120
|
|
max_seqs: 4096
|
|
pad_to_max_tokens: true
|
|
pad_to_max_atoms: true
|
|
pad_to_max_seqs: true
|
|
samples_per_epoch: 100000
|
|
batch_size: 1
|
|
num_workers: 2
|
|
random_seed: 42
|
|
pin_memory: false
|
|
overfit: null
|
|
return_train_symmetries: false
|
|
return_val_symmetries: false
|
|
|
|
|
|
atoms_per_window_queries: 32
|
|
min_dist: 2.0
|
|
max_dist: 22.0
|
|
num_bins: 64
|
|
single_sequence_prop_training: 0.1
|
|
msa_sampling_training: true
|
|
|
|
|
|
# Design
|
|
design: true
|
|
backbone_only: false
|
|
atom14: true
|
|
atom37: false
|
|
selector:
|
|
_target_: boltzgen.data.select.protein.ProteinSelector
|
|
design_neighborhood_sizes: [2, 4, 6,8,10,12,14,16,18]
|
|
substructure_neighborhood_sizes: [2,4,6,8,10,12,24]
|
|
structure_condition_prob: 0.4
|
|
distance_noise_std: 1
|
|
run_selection: true
|
|
specify_binding_sites: true
|
|
ss_condition_prob: 0.1
|
|
select_all: false
|
|
|
|
# Design datasets
|
|
monomer_split: data/pdb_sequences/val_monomers_boltzgen_min50_max220.txt
|
|
monomer_target_dir: ./training_data/targets
|
|
monomer_target_structure_condition: true
|
|
monomer_seq_len: 100
|
|
|
|
ligand_split: data/pdb_sequences/val_ccd_pdb_pairs_boltzgen.txt
|
|
ligand_target_dir: ./training_data/targets
|
|
ligand_seq_len: 100
|
|
|
|
|
|
model:
|
|
_target_: boltzgen.model.models.boltz.Boltz
|
|
atom_s: 128
|
|
atom_z: 16
|
|
token_s: 384
|
|
token_z: 128
|
|
num_bins: 64
|
|
atom_feature_dim: 388
|
|
atoms_per_window_queries: 32
|
|
atoms_per_window_keys: 128
|
|
use_miniformer: false
|
|
ema: true
|
|
ema_decay: 0.999
|
|
exclude_ions_from_lddt: true
|
|
num_val_datasets: 1 # New
|
|
ignore_ckpt_shape_mismatch: false # New
|
|
aggregate_distogram: true # New
|
|
bond_type_feature: true
|
|
predict_bfactor: true
|
|
checkpoint_diffusion_conditioning: true
|
|
use_kernels: true
|
|
|
|
|
|
validators:
|
|
- _target_: boltzgen.model.validation.design.DesignValidator
|
|
val_names: ["RCSB"]
|
|
confidence_prediction: ${model.confidence_prediction}
|
|
atom14: ${data.atom14}
|
|
atom37: ${data.atom37}
|
|
|
|
masker_args:
|
|
mask: true
|
|
mask_backbone: false
|
|
mask_disto: true
|
|
|
|
embedder_args:
|
|
atom_encoder_depth: 3
|
|
atom_encoder_heads: 4
|
|
add_mol_type_feat: true
|
|
add_method_conditioning: true
|
|
add_modified_flag: true
|
|
add_cyclic_flag: true
|
|
add_design_mask_flag: true
|
|
add_binding_specification: true
|
|
add_ss_specification: true
|
|
|
|
freeze_template_weights: true
|
|
use_templates: true
|
|
template_args:
|
|
template_dim: 64
|
|
template_blocks: 2
|
|
activation_checkpointing: false
|
|
|
|
|
|
use_token_distances: true
|
|
token_distance_args:
|
|
token_distance_dim: 64
|
|
token_distance_blocks: 2
|
|
use_token_distance_feats: true
|
|
distance_gaussian_dim: 32
|
|
activation_checkpointing: true
|
|
|
|
|
|
msa_args:
|
|
msa_s: 64
|
|
msa_blocks: 4
|
|
msa_dropout: 0.15
|
|
z_dropout: 0.25
|
|
miniformer_blocks: false
|
|
pairwise_head_width: 32
|
|
pairwise_num_heads: 4
|
|
use_paired_feature: true
|
|
activation_checkpointing: true
|
|
|
|
|
|
pairformer_args:
|
|
num_blocks: 64
|
|
num_heads: 16
|
|
dropout: 0.25
|
|
post_layer_norm: false
|
|
activation_checkpointing: true
|
|
|
|
|
|
score_model_args:
|
|
sigma_data: 16
|
|
dim_fourier: 256
|
|
atom_encoder_depth: 3
|
|
atom_encoder_heads: 4
|
|
|
|
# token level args
|
|
token_layers: 1
|
|
token_transformer_depth: 24
|
|
token_transformer_heads: 16
|
|
diffusion_pairformer_args:
|
|
num_blocks: 0
|
|
num_heads: 2
|
|
dropout: 0
|
|
use_s_to_z: false
|
|
|
|
|
|
|
|
atom_decoder_depth: 3
|
|
atom_decoder_heads: 4
|
|
conditioning_transition_layers: 2
|
|
transformer_post_ln: false
|
|
activation_checkpointing: true
|
|
|
|
confidence_prediction: false
|
|
structure_prediction_training: true
|
|
|
|
training_args:
|
|
recycling_steps: 3
|
|
sampling_steps: 20
|
|
diffusion_multiplicity: 32
|
|
diffusion_samples: 1
|
|
confidence_loss_weight: 1e-4
|
|
diffusion_loss_weight: 4.0
|
|
distogram_loss_weight: 3e-2
|
|
bfactor_loss_weight: 1e-3
|
|
adam_beta_1: 0.9
|
|
adam_beta_2: 0.95
|
|
adam_eps: 0.00000001
|
|
lr_scheduler: af3
|
|
base_lr: 0.0
|
|
max_lr: 0.0005
|
|
lr_warmup_no_steps: 1000
|
|
lr_start_decay_after_n_steps: 50000
|
|
lr_decay_every_n_steps: 50000
|
|
lr_decay_factor: 0.95
|
|
weight_decay: 0.003
|
|
weight_decay_exclude: true
|
|
|
|
validation_args:
|
|
recycling_steps: 3
|
|
sampling_steps: 200
|
|
diffusion_samples: 1
|
|
symmetry_correction: false
|
|
|
|
diffusion_process_args:
|
|
sigma_min: 0.0004 # min noise level
|
|
sigma_max: 160.0 # max noise level
|
|
sigma_data: 16.0 # standard deviation of data distribution
|
|
rho: 7 # controls the sampling schedule
|
|
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
|
|
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
|
|
gamma_0: 0.8
|
|
gamma_min: 1.0
|
|
noise_scale: 1.0
|
|
step_scale: 1.0
|
|
mse_rotational_alignment: true
|
|
coordinate_augmentation: true
|
|
alignment_reverse_diff: true
|
|
synchronize_sigmas: false
|
|
|
|
diffusion_loss_args:
|
|
add_smooth_lddt_loss: true
|
|
add_bond_loss: false
|
|
nucleotide_loss_weight: 5.0
|
|
ligand_loss_weight: 10.0
|
|
|
|
refolding_validator:
|
|
_target_: boltzgen.model.validation.refolding.RefoldingValidator
|
|
val_names: ["RCSB"]
|
|
step_scale: 1.5
|
|
noise_scale: 0.75
|
|
atom14: ${data.atom14}
|
|
atom37: ${data.atom37}
|
|
val_monomer: ${data.monomer_split}
|
|
val_ligand: ${data.ligand_split}
|
|
analyze_task:
|
|
_target_: boltzgen.task.analyze.analyze.Analyze
|
|
name: ${name}
|
|
debug: ${debug}
|
|
design_dir: null
|
|
num_processes: 1
|
|
|
|
# Common metrics to compute
|
|
affinity_metrics: false
|
|
allatom_fold_metrics: true
|
|
backbone_fold_metrics: true
|
|
noncovalents_original: false
|
|
noncovalents_refolded: false
|
|
delta_sasa_original: false
|
|
delta_sasa_refolded: false
|
|
largest_hydrophobic: false
|
|
largest_hydrophobic_refolded: false
|
|
run_clustering: false
|
|
|
|
# Liability analysis
|
|
liability_analysis: false
|
|
liability_modality: peptide
|
|
liability_peptide_type: linear
|
|
|
|
# Uncommon metrics
|
|
diversity_original: true
|
|
diversity_refolded: true
|
|
diversity_per_target_original: false
|
|
diversity_per_target_refolded: false
|
|
novelty_original: false
|
|
novelty_refolded: false
|
|
novelty_per_target_original: false
|
|
novelty_per_target_refolded: false
|
|
|
|
wandb: null
|
|
|
|
data:
|
|
_target_: boltzgen.task.predict.data_from_generated.FromGeneratedDataModule
|
|
cfg:
|
|
_target_: boltzgen.task.predict.data_from_generated.DataConfig
|
|
tokenizer:
|
|
_target_: boltzgen.data.tokenize.tokenizer.Tokenizer
|
|
atomize_modified_residues: false
|
|
featurizer:
|
|
_target_: boltzgen.data.feature.featurizer.Featurizer
|
|
|
|
suffix: .cif
|
|
suffix_metadata: .npz
|
|
suffix_native: _native.cif
|
|
samples_per_target: 1
|
|
num_targets: 100000000
|
|
moldir: ./training_data/mols
|
|
|
|
batch_size: 1
|
|
num_workers: 4
|
|
pin_memory: true
|
|
return_native: true
|
|
predict_task: null
|
|
|
|
folding_checkpoint: ./training_data/boltz2_fold.ckpt
|
|
|
|
folding_args:
|
|
recycling_steps: 3
|
|
sampling_steps: 200
|
|
diffusion_samples: 1
|
|
|
|
folding_model_args:
|
|
atom_s: 128
|
|
atom_z: 16
|
|
token_s: 384
|
|
token_z: 128
|
|
num_bins: 64
|
|
atom_feature_dim: 388
|
|
atoms_per_window_queries: 32
|
|
atoms_per_window_keys: 128
|
|
compile_pairformer: false
|
|
compile_templates: false
|
|
compile_msa: false
|
|
use_miniformer: false
|
|
ema: true
|
|
ema_decay: 0.999
|
|
exclude_ions_from_lddt: true
|
|
num_val_datasets: 4
|
|
ignore_ckpt_shape_mismatch: false
|
|
aggregate_distogram: true
|
|
bond_type_feature: true
|
|
conditioning_cutoff_min: 4.0
|
|
conditioning_cutoff_max: 20.0
|
|
use_templates: true
|
|
predict_bfactor: true
|
|
checkpoint_diffusion_conditioning: false
|
|
use_kernels: true
|
|
|
|
validators: null
|
|
|
|
embedder_args:
|
|
atom_encoder_depth: 3
|
|
atom_encoder_heads: 4
|
|
add_mol_type_feat: true
|
|
add_method_conditioning: true
|
|
add_modified_flag: true
|
|
add_cyclic_flag: true
|
|
|
|
msa_args:
|
|
msa_s: 64
|
|
msa_blocks: 4
|
|
msa_dropout: 0.15
|
|
z_dropout: 0.25
|
|
miniformer_blocks: false
|
|
pairwise_head_width: 32
|
|
pairwise_num_heads: 4
|
|
use_paired_feature: true
|
|
activation_checkpointing: false
|
|
|
|
|
|
template_args:
|
|
template_dim: 64
|
|
template_blocks: 2
|
|
activation_checkpointing: false
|
|
|
|
|
|
pairformer_args:
|
|
num_blocks: 64
|
|
num_heads: 16
|
|
dropout: 0.25
|
|
post_layer_norm: false
|
|
activation_checkpointing: false
|
|
|
|
|
|
score_model_args:
|
|
sigma_data: 16
|
|
dim_fourier: 256
|
|
atom_encoder_depth: 3
|
|
atom_encoder_heads: 4
|
|
token_transformer_depth: 24
|
|
token_transformer_heads: 16
|
|
atom_decoder_depth: 3
|
|
atom_decoder_heads: 4
|
|
conditioning_transition_layers: 2
|
|
transformer_post_ln: false
|
|
activation_checkpointing: false
|
|
|
|
confidence_prediction: false
|
|
affinity_prediction: false
|
|
structure_prediction_training: true
|
|
affinity_model_args:
|
|
num_dist_bins: 64
|
|
max_dist: 22
|
|
no_trunk_feats: false
|
|
add_s_to_z_prod: false
|
|
add_s_input_to_s: false
|
|
confidence_args:
|
|
num_plddt_bins: 50
|
|
num_pde_bins: 64
|
|
num_pae_bins: 64
|
|
|
|
training_args:
|
|
recycling_steps: 3
|
|
sampling_steps: 20
|
|
diffusion_multiplicity: 48
|
|
diffusion_samples: 1
|
|
affinity_loss_weight: 3e-3
|
|
confidence_loss_weight: 1e-4
|
|
diffusion_loss_weight: 4.0
|
|
distogram_loss_weight: 3e-2
|
|
bfactor_loss_weight: 1e-3
|
|
adam_beta_1: 0.9
|
|
adam_beta_2: 0.95
|
|
adam_eps: 0.00000001
|
|
lr_scheduler: af3
|
|
base_lr: 0.0
|
|
max_lr: 0.001
|
|
lr_warmup_no_steps: 1000
|
|
lr_start_decay_after_n_steps: 50000
|
|
lr_decay_every_n_steps: 50000
|
|
lr_decay_factor: 0.95
|
|
weight_decay: 0.003
|
|
weight_decay_exclude: true
|
|
|
|
validation_args:
|
|
recycling_steps: 3
|
|
sampling_steps: 200
|
|
diffusion_samples: 5
|
|
symmetry_correction: false
|
|
|
|
diffusion_process_args:
|
|
sigma_min: 0.0004 # min noise level
|
|
sigma_max: 160.0 # max noise level
|
|
sigma_data: 16.0 # standard deviation of data distribution
|
|
rho: 7 # controls the sampling schedule
|
|
P_mean: -1.2 # mean of log-normal distribution from which noise is drawn for training
|
|
P_std: 1.5 # standard deviation of log-normal distribution from which noise is drawn for training
|
|
gamma_0: 0.8
|
|
gamma_min: 1.0
|
|
noise_scale: 1.0
|
|
step_scale: 1.0
|
|
mse_rotational_alignment: true
|
|
coordinate_augmentation: true
|
|
alignment_reverse_diff: true
|
|
synchronize_sigmas: false
|
|
|
|
diffusion_loss_args:
|
|
add_smooth_lddt_loss: true
|
|
add_bond_loss: false
|
|
nucleotide_loss_weight: 5.0
|
|
ligand_loss_weight: 10.0
|