Add initial RFD3 Files and passing tests

* Add initial files

* add files

* Move projects.aa_design -> rfd3

* Make format

* Delete test files

* Add configs

* Mc

* Fixed tests

* remove test files
Jasper Butcher
2025-11-11 10:07:43 -08:00
committed by GitHub
parent 36f668b5ae
commit 5e7b739ed3
206 changed files with 26143 additions and 21 deletions


@@ -14,8 +14,8 @@ clean:
## Format src directory using black
format:
- ruff format src models tests projects
- ruff check --fix src models tests projects
+ ruff format src models tests
+ ruff check --fix src models tests
#################################################################################
# Self Documenting Commands #


@@ -40,8 +40,7 @@ git clone https://github.com/RosettaCommons/modelforge.git \
&& uv python install 3.12 \
&& uv venv --python 3.12 \
&& source .venv/bin/activate \
- && uv pip install -e . \
- && uv pip install -e ./models/rf3
+ && uv pip install -e ".[rf3]"
```
@@ -71,22 +70,19 @@ ModelForge uses a multi-package architecture:
#### For Users (Single Model)
- To use a model, you must first install `modelhub` in editable mode, then install the specific model:
+ Use `pip install modelforge[<model>]` semantics (or `uv pip install -e ".[<model>]"` when developing locally) to pull in model-specific dependencies and the associated CLI in one step:
```bash
- # Install modelhub first (required)
+ # Install modelforge with RF3 ready to go
+ uv pip install -e ".[rf3]"
+ # Install only the core utilities (no models)
uv pip install -e .
- # Then install RF3
- uv pip install -e ./models/rf3
- # Future: Install other models
- # uv pip install -e ./models/other_model
+ # Future models follow the same pattern
+ # uv pip install -e ".[other_model]"
```
- > [!IMPORTANT]
- > You must install `modelhub` with `-e` (editable mode) first, before installing any model packages. This ensures both packages are installed in editable mode for proper development workflow.
#### For Core Developers (Multiple Packages)
Install both `modelhub` and models in editable mode for development:


@@ -23,6 +23,8 @@ classifiers = [
]
dependencies = [
# Core functionality shared across all models
"modelforge",
# CLI
"typer>=0.9.0,<1",
# RF3-specific ML dependencies


@@ -28,7 +28,6 @@ def run_inference(cfg: DictConfig) -> None:
"""Execute RF3 inference pipeline."""
# Extract run() parameters from config
# Preserve string inputs, convert other sequence-like inputs to a Python list (None -> [])
inputs_param = cfg.inputs if isinstance(cfg.inputs, str) else list(cfg.inputs or [])
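The normalization applied to `inputs_param` can be exercised in isolation. The following is a minimal sketch that uses plain Python values in place of the Hydra `DictConfig` field; the helper name `normalize_inputs` is hypothetical and does not appear in the commit.

```python
from __future__ import annotations

from collections.abc import Iterable


def normalize_inputs(inputs: str | Iterable[str] | None) -> str | list[str]:
    """Hypothetical helper mirroring the one-liner in run_inference."""
    # A string is itself iterable, so it must be checked first;
    # otherwise list("a.json") would split it into single characters.
    if isinstance(inputs, str):
        return inputs
    # None (or any empty sequence) collapses to an empty list.
    return list(inputs or [])


single = normalize_inputs("a.json")
several = normalize_inputs(("a.json", "b.json"))
missing = normalize_inputs(None)
print(single, several, missing)
```

Checking for `str` before calling `list()` is the important design choice: it keeps a single path intact while still accepting tuples or config list nodes.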

1
models/rfd3/.gitignore vendored Normal file

@@ -0,0 +1 @@
tests/test_data

110
models/rfd3/README.md Normal file

@@ -0,0 +1,110 @@
# De novo Design of Biomolecular Interactions with RFdiffusion3
<p align="center">
<img src="docs/.assets/trajectory.png" alt="All-atom diffusion with RFD3">
</p>
## Installation, Setup, and a Basic Design
### A. Installation using `uv`
```bash
git clone https://github.com/RosettaCommons/modelforge.git \
&& cd modelforge \
&& uv python install 3.12 \
&& uv venv --python 3.12 \
&& source .venv/bin/activate \
&& uv pip install -e . \
&& uv pip install -e ./models/rfd3
```
> [!IMPORTANT]
> You must install `modelhub` (the root package) with `-e` first, then install `rfd3`. This ensures both packages are in editable mode for a proper development workflow.
### B. Download model weights for RF3
```bash
wget http://files.ipd.uw.edu/pub/rf3/rf3_latest.pt
```
If you're looking for the 9/21 model (e.g., for benchmarking against other models with the same date cutoff):
```bash
wget http://files.ipd.uw.edu/pub/rf3/rf3_921.pt
```
The inference API is otherwise identical.
### C. Run a test prediction
```bash
rf3 fold inputs='tests/data/5vht_from_json.json'
```
To use a specific checkpoint, pass `ckpt_path`:
```bash
rf3 fold inputs='tests/data/5vht_from_json.json' ckpt_path='/path/to/rf3_921.pt'
```
**Setup**
RFD3 currently requires specific branches for `cifutils` and `datahub` so we recommend cloning the branch with submodules:
```bash
git clone -b aa_design/main git@github.com:baker-laboratory/modelhub.git
git submodule init
git submodule update --init
export PROJECT_PATH="$(pwd)/projects/aa_design"
```
Files for RFD3 live under this folder (`projects/aa_design`) and wrap the components of the AF3 reproduction under `src/modelhub`.
The AF3 reproduction might not work on this branch, since it is not currently kept up to date.
If you run `inference.py` as a script (via `./src/inference.py`), the shebang line should take care of the submodule paths for you (provided you cloned the submodules). Otherwise, add the following to your environment:
```
chmod +x src/modelhub/*.py
```
## Inference:
The following checkpoint is updated continuously (see channel https://chat.ipd.uw.edu/ipd/channels/rfdiffusion3):
```bash
cur_ckpt=/projects/ml/aa_design/models/rfd3_latest.ckpt
```
To run inference, use:
```bash
./src/modelhub/inference.py out_dir=logs/inference_outs/demo/0 ckpt_path=$cur_ckpt inputs=projects/aa_design/tests/test_data/demo.json print_config=True dump_trajectories=True
```
The additional arguments here are added for verbosity: aligning trajectory structures, printing the config, and dumping trajectories are all turned off by default.
For full details on how to specify inputs, see the `input.md` documentation.
## PPI-Design
See `input.md`. Please feel free to reach out to me (Rafi Brent) if you have any questions or concerns. I'm happy to help out however I can!
## Enzyme-Design
See `input.md`. Feel free to send jbutch or jfunk21 a message if in doubt!
## Symmetric-Design
See `symmetry.md`. Please feel free to reach out to aimura or heisen if you have any questions!
## Training (w & w/o WandB):
Add `export PROJECT_PATH=$(pwd)/projects/aa_design` to `scripts/slurm/launch.sh`, where `$(pwd)` is the repository's absolute path.
You will also want to add your cifutils, datahub, and modelhub (`$(pwd)`) paths to `launch.sh`.
To launch a training run, use:
```
sbatch -J rfd3-full-sparse launch.sh
```
Optionally ensure your `WANDB_API_KEY` is an environment variable. You can disable wandb by including the following at the top of your experiment config:
```yaml
defaults:
- override /logger: csv # turns off wandb logger
```
## Conditioning Pipeline
Both inference and validation pass arguments to `create_atom_array_from_design_specification` to create an atom array with all the information needed to run inference.
This is then passed through the same processing pipeline as in training, with `is_inference=True` (pipeline in `./projects/aa_design/transforms/pipelines.py`).
<p align="center">
<img src="docs/.assets/pipeline.png" alt="Atom14 Design Pipelines">
<figcaption>Overview of important transforms in the Atom14 conditioning pipeline.
</figcaption>
</p>


@@ -0,0 +1,5 @@
defaults:
- train_logging
- metrics_logging
- dump_validation_structures
- _self_


@@ -0,0 +1,44 @@
defaults:
- train_logging
- dump_validation_structures
- _self_
# Validation metrics:
store_validation_metrics_in_df_callback:
  _target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
  save_dir: ${paths.output_dir}/val_metrics
dump_validation_structures_callback:
  dump_predictions: True
  dump_prediction_metadata_json: True
  dump_trajectories: False
  dump_denoised_trajectories_only: False
  one_model_per_file: True
  dump_every_n: 4
  align_trajectories: False
  verbose: False
log_pipelines_results_callback:
  _target_: rfd3.callbacks.LogPipelinesResultsCallback
  metrics_save_dir: ${paths.output_dir}/val_metrics
# Other:
log_design_validation_metrics_callback:
  _target_: rfd3.callbacks.LogDesignValidationMetricsCallback
log_learning_rate_callback:
  log_every_n: 25 # default 10
log_af3_training_losses_callback:
  log_full_batch_losses: False
  log_every_n: 25 # default 10
log_train_variable_callback:
  _target_: rfd3.callbacks.LogTrainVariableCallback
  log_every_n: 1
  variables:
    - hbond_total_count
    - hbond_subsample_atoms
    - hbond_total_atoms


@@ -0,0 +1,8 @@
dump_validation_structures_callback:
  _target_: modelhub.callbacks.dump_validation_structures.DumpValidationStructuresCallback
  save_dir: ${paths.output_dir}/val_structures
  dump_predictions: False
  one_model_per_file: False
  dump_trajectories: False
  align_trajectories: True
  verbose: False


@@ -0,0 +1,14 @@
store_validation_metrics_in_df_callback:
  _target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback
  save_dir: ${paths.output_dir}/val_metrics
  metrics_to_save: "all"
log_af3_validation_metrics_callback:
  _target_: modelhub.callbacks.metrics_logging.LogAF3ValidationMetricsCallback
  metrics_to_log:
    # Only logs if present in the metric output dictionary
    # Must be subset of metrics_to_save
    - by_type_lddt
    - all_atom_lddt
    - distogram_loss
    - distogram_comparisons


@@ -0,0 +1,15 @@
defaults:
- design_callbacks
- _self_
dump_validation_structures_callback:
  dump_every_n: 5
  extra_fields:
    - is_motif_atom_with_fixed_seq
run_pipelines_callback:
  _target_: rfd3.callbacks.RunPipelinesCallback
  save_dir: ${paths.output_dir}
  run_every_n_epochs: 5
  pipelines_config_path: ${paths.root_dir}/rfd3/configs/pipelines/aa_ppi_design_af3_validation.yaml
  pipelines_script_path: ${paths.root_dir}/lib/pipelines/pipelines/pipeline.py


@@ -0,0 +1,14 @@
log_af3_training_losses_callback:
  _target_: modelhub.callbacks.train_logging.LogAF3TrainingLossesCallback
  log_every_n: 10
  log_full_batch_losses: true
log_learning_rate_callback:
  _target_: modelhub.callbacks.train_logging.LogLearningRateCallback
  log_every_n: 10
log_model_parameters_callback:
  _target_: modelhub.callbacks.train_logging.LogModelParametersCallback
log_dataset_sampling_ratios_callback:
  _target_: modelhub.callbacks.train_logging.LogDatasetSamplingRatiosCallback


@@ -0,0 +1,15 @@
train:
dataloader_params:
# These parameters will be unpacked as kwargs for the DataLoader
batch_size: 1
num_workers: 2
prefetch_factor: 3
n_fallback_retries: 4
val:
dataloader_params:
# These parameters will be unpacked as kwargs for the DataLoader
batch_size: 1
num_workers: 2
prefetch_factor: 3
n_fallback_retries: 0 # Disable fallback retries for validation


@@ -0,0 +1,11 @@
defaults:
- default
train:
  dataloader_params:
    num_workers: 2
    prefetch_factor: 6
val:
  dataloader_params:
    num_workers: 2
    prefetch_factor: 6


@@ -0,0 +1,47 @@
# base training dataset for training AF3 design models (atom14 variants):
# protein subsampling only.
defaults:
- design_base
# Validation sets
# - val/sm_binder_hbonds@val.sm_binder_hbonds
# - val/unindexed@val.unindexed
- val/mcsa_41_short_rigid@val.mcsa_41_short_rigid
- val/bcov_ppi_easy_medium@val.bcov_ppi_easy_medium
- val/unconditional@val.unconditional
# - val/sm_binder_hbonds_short@val.sm_binder_hbonds_short
# - val/unconditional_deep@val.unconditional_deep
# - val/indexed@val.indexed
# Training Sets
- train/pdb/dna_binder@train.dna_binder
- train/pdb/na_complex_distillation@train.na_complex_distillation
- train/pdb/free_dna_dds@train.free_dna_dds
- train/bcov_ppi_distillation@train
- _self_
val:
  # Small datasets;
  unconditional:
    dataset:
      eval_every_n: 1
  # Medium:
  mcsa_41_short_rigid:
    dataset:
      eval_every_n: 1
train:
  # Additional datasets beyond the base training dataset
  dna_binder:
    probability: 0.0
  na_complex_distillation:
    probability: 0.0
  free_dna_dds:
    probability: 0.0
  bcov_ppi_distillation:
    probability: 0.0
  interdomain_distillation:
    probability: 0.0


@@ -0,0 +1,3 @@
_target_: rfd3.transforms.training_conditions.SubtypeCondition
frequency: 1.0
subtype: ["is_dna", "is_rna"]


@@ -0,0 +1,29 @@
_target_: rfd3.transforms.training_conditions.IslandCondition
frequency: 1.0
name: island
# Island sampling (`is_motif_token` assignment)
island_sampling_kwargs:
  island_len_min: 1
  island_len_max: 25
  n_islands_min: 2
  n_islands_max: 5
# Subgraph / within-token sampling (`is_motif_atom` assignment)
p_diffuse_motif_sidechains: 0.80 # 80% probability of diffusing sidechains
p_diffuse_subgraph_atoms: 0.0 # 0% probability of sampling subgraph atoms (defaults to fully fixed)
subgraph_sampling_kwargs: # see tipatom
  residue_p_seed_furthest_from_o: null
  residue_n_bond_expectation: null
  residue_p_fix_all: null
  hetatom_n_bond_expectation: null
  hetatom_p_fix_all: null
hetatom_p_fix_all: null
# Sets `is_motif_atom_with_fixed_seq`
p_fix_motif_sequence: 0.2 # probability that sequence is fixed for all motifs during training
# Sets `is_motif_atom_with_fixed_coord`
p_fix_motif_coordinates: 1.0 # Of the atoms that are sampled, should their coordinates be fixed?
# Sets `is_motif_atom_with_unindexed`
p_unindex_motif_tokens: 0.0 # probability of unindexing all motif atoms


@@ -0,0 +1,2 @@
_target_: rfd3.transforms.training_conditions.PPICondition
frequency: 1.0


@@ -0,0 +1,16 @@
defaults:
- island
- _self_
frequency: 1.0
name: sequence_design
island_sampling_kwargs:
  island_len_min: 99999
  island_len_max: 999999999
p_diffuse_motif_sidechains: 1.0
p_fix_motif_sequence: 0.0
p_fix_motif_coordinates: 1.0
p_unindex_motif_tokens: 0.0


@@ -0,0 +1,28 @@
defaults:
- island
- _self_
frequency: 1.0
name: tipatom
# Island sampling (`is_motif_token` assignment)
island_sampling_kwargs:
  island_len_min: 1
  island_len_max: 1
  n_islands_min: 2
  n_islands_max: 9
# Subgraph / within-token sampling (`is_motif_atom` assignment)
p_diffuse_motif_sidechains: 0.0 # 0% probability of diffusing sidechains
p_diffuse_subgraph_atoms: 1.0
subgraph_sampling_kwargs:
  residue_p_seed_furthest_from_o: 0.7
  residue_n_bond_expectation: 3
  residue_p_fix_all: 0.05
  hetatom_n_bond_expectation: 8
  hetatom_p_fix_all: 0.5
p_fix_motif_sequence: 0.7
p_fix_motif_coordinates: 1.0
p_unindex_motif_tokens: 0.5


@@ -0,0 +1,21 @@
# Unconditional that fixes non-protein targets
defaults:
- island
- _self_
frequency: 1.0
name: unconditional
island_sampling_kwargs:
  island_len_min: 0
  island_len_max: 0
  n_islands_min: 0
  n_islands_max: 0
# Conditional assignments won't matter for protein regions since always diffused:
p_diffuse_motif_sidechains: 0.0
p_diffuse_subgraph_atoms: 0.0
p_fix_motif_sequence: 0.0
p_fix_motif_coordinates: 0.0
p_unindex_motif_tokens: 0.0


@@ -0,0 +1,106 @@
# base training dataset for training AF3 design models (atom14 variants):
# protein subsampling only.
defaults:
# Grab datasets
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
- train/pdb/rfd3_train_pn_unit@train.pdb.sub_datasets.pn_unit
- train/rfd3_monomer_distillation@train
- train/af2db_interdomain_distillation@train
# Customized validation datasets
- val/unconditional@val.unconditional
- val/unconditional_deep@val.unconditional_deep
- val/indexed@val.indexed
# Customized train masks
- conditions/unconditional@global_transform_args.train_conditions.unconditional
- conditions/island@global_transform_args.train_conditions.island
- conditions/tipatom@global_transform_args.train_conditions.tipatom
- conditions/sequence_design@global_transform_args.train_conditions.sequence_design
- conditions/ppi@global_transform_args.train_conditions.ppi
- _self_
# Create a dictionary used for transform arguments
pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
# Base config overrides:
diffusion_batch_size_train: 32
diffusion_batch_size_inference: 8
crop_size: 256
n_recycles_train: 1
n_recycles_validation: 1
max_atoms_in_crop: 2560 # 256 x 10; 256 x 14 doesn't work since we pad up
# Global transform arguments are necessary for arguments shared between training and inference
global_transform_args:
n_atoms_per_token: 14
central_atom: CB
sigma_perturb: 1.0
sigma_perturb_com: 0.0
association_scheme: atom14 # atom14, atom14-new or null
center_option: all # options are ["all", "motif", "diffuse"]
# Reference conformer policy
generate_conformers: True
generate_conformers_for_non_protein_only: True
provide_reference_conformer_when_unmasked: False
ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details
provide_elements_for_unindexed_components: False
use_element_for_atom_names_of_atomized_tokens: False
# PPI Cropping
keep_full_binder_in_spatial_crop: False
max_binder_length: 170
# PPI Hotspots
max_ppi_hotspots_frac_to_provide: 0.2
ppi_hotspot_max_distance: 4.5
# Secondary structure features
max_ss_frac_to_provide: 0.4
min_ss_island_len: 1
max_ss_island_len: 10
train_conditions:
unconditional:
frequency: 5.0
sequence_design:
frequency: 2.0
island:
frequency: 1.0
tipatom:
frequency: 0.0
ppi:
frequency: 0.0
# Used to create simple boolean flags for downstream conditioning
meta_conditioning_probabilities:
calculate_rasa: 0.0 # not currently used in the default model anyway
keep_protein_motif_rasa: 0.2
mirror_input: 0.0
calculate_hbonds: 0.0
# Conditioning token scheme probabilities
unindex_leak_global_index: 0.33
unindex_insert_random_break: 0.33
unindex_remove_random_break: 0.33
# Probability of adding 1d secondary structure conditioning
add_1d_ss_features: 0.0
add_global_is_non_loopy_feature: 0.0
featurize_plddt: 0.0
# PPI
add_ppi_hotspots: 0
full_binder_crop: 0
# Dataset probabilities
train:
pdb:
probability: 0.75
monomer_distillation:
probability: 0.25
interdomain_distillation:
probability: 0.0
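The `train_conditions` frequencies in this config (5.0, 2.0, 1.0, 0.0, 0.0) do not sum to one, which suggests they act as relative weights. The sketch below assumes exactly that and shows the sampling probabilities such weights would imply. The normalization is an assumption made for illustration and is not taken from the `rfd3` source; `condition_probabilities` is a hypothetical helper.

```python
def condition_probabilities(frequencies: dict[str, float]) -> dict[str, float]:
    """Normalize relative frequencies into probabilities (assumed behavior)."""
    total = sum(frequencies.values())
    if total <= 0:
        raise ValueError("at least one condition needs a positive frequency")
    return {name: freq / total for name, freq in frequencies.items()}


# Values copied from the train_conditions block of this config.
design_base = {
    "unconditional": 5.0,
    "sequence_design": 2.0,
    "island": 1.0,
    "tipatom": 0.0,
    "ppi": 0.0,
}

probs = condition_probabilities(design_base)
for name, p in probs.items():
    print(f"{name}: {p:.3f}")
```

Under this reading, a frequency of 0.0 disables a condition entirely while leaving its config in place for child configs to re-enable.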


@@ -0,0 +1,98 @@
# Base training dataset for training AF3 design models on the DNA binder task specifically:
defaults:
- base
- train/pdb/dna_binder@train.dna_binder
- train/pdb/rfd3_train_pn_unit@train.rfd3_train_pn_unit
- train/pdb/na_complex_distillation@train.na_complex_distillation
- train/pdb/free_dna_dds@train.free_dna_dds
- val/dna_binder_short@val.dna_binder
- conditions/unconditional@global_transform_args.train_conditions.unconditional
- _self_
pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
global_transform_args:
n_atoms_per_token: 14
central_atom: CB
sigma_perturb: 1.0
sigma_perturb_com: 0.0
association_scheme: atom14 # atom14, atom14-new or null
center_option: "all" # options are ["all", "motif", "diffuse"]
# Reference conformer policy
generate_conformers: False
generate_conformers_for_non_protein_only: False
provide_reference_conformer_when_unmasked: False
ground_truth_conformer_policy: IGNORE # Other options: REPLACE, ADD, FALLBACK. See atomworks.enums for details
provide_elements_for_unindexed_components: False
# PPI Cropping
keep_full_binder_in_spatial_crop: False
max_binder_length: 170
# PPI Hotspots
max_ppi_hotspots_frac_to_provide: 0.2
ppi_hotspot_max_distance: 4.5
# Secondary structure features
max_ss_frac_to_provide: 0.4
min_ss_island_len: 1
max_ss_island_len: 10
train_conditions:
unconditional:
frequency: 1.0
sequence_design:
frequency: 0.0
island:
frequency: 0.0
tipatom:
frequency: 0.0
ppi:
frequency: 0.0
meta_conditioning_probabilities:
calculate_rasa: 0.0
keep_protein_motif_rasa: 0.0
calculate_hbonds: 1.0
hbond_subsample: 0.5
# Conditioning token scheme probabilities
unindex_leak_global_index: 0.0
unindex_insert_random_break: 0.0
unindex_remove_random_break: 0.0
# Probability of adding 1d secondary structure conditioning
add_1d_ss_features: 0.0
# Probability of adding global secondary structure conditioning
add_global_is_non_loopy_feature: 0.0
# PPI
add_ppi_hotspots: 0
full_binder_crop: 0
# Base config overrides:
seed: 42
diffusion_batch_size_train: 32
diffusion_batch_size_inference: 8
crop_size: 256
n_recycles_train: 1
n_recycles_validation: 1
n_msa: null
max_atoms_in_crop: 2560 # 256 x 10; 256 x 14 doesn't work since we pad up
train:
dna_binder:
probability: 0.55
rfd3_train_pn_unit:
probability: 0.2
na_complex_distillation:
probability: 0.19
free_dna_dds:
probability: 0.06


@@ -0,0 +1,15 @@
defaults:
- design_base
- _self_
global_transform_args:
  train_conditions:
    unconditional:
      frequency: 0.2
    sequence_design:
      frequency: 0.1
    island:
      frequency: 0.7
    tipatom:
      frequency: 0.0


@@ -0,0 +1,146 @@
defaults:
- base
# Interface dataset only
- train/pdb/rfd3_train_interface@train.pdb.sub_datasets.interface
# PPI validation dataset
- val/bcov_ppi_easy_medium@val.bcov_ppi_easy_medium
# Customized train masks
- conditions/unconditional@global_transform_args.train_conditions.unconditional
- conditions/island@global_transform_args.train_conditions.island
- conditions/tipatom@global_transform_args.train_conditions.tipatom
- conditions/sequence_design@global_transform_args.train_conditions.sequence_design
- conditions/ppi@global_transform_args.train_conditions.ppi
- _self_
pipeline_target: rfd3.transforms.pipelines.build_atom14_base_pipeline
# Global transform arguments are necessary for arguments shared between training and inference
global_transform_args:
allowed_types: ['is_protein', 'is_ligand']
n_atoms_per_token: 14
central_atom: CB
sigma_perturb: 4.0
sigma_perturb_com: 3.0
association_scheme: atom14-new # atom14, atom14-new or null
center_option: all # options are ["all", "motif", "diffuse"]
# Give ground-truth reference conformer for residues with unmasked sequences
generate_conformers: True
generate_conformers_for_non_protein_only: False
provide_reference_conformer_when_unmasked: True
ground_truth_conformer_policy: REPLACE
provide_elements_for_unindexed_components: False
use_element_for_atom_names_of_atomized_tokens: False
# PPI Cropping
keep_full_binder_in_spatial_crop: False
max_binder_length: 170
# PPI Hotspots
max_ppi_hotspots_frac_to_provide: 0.2
ppi_hotspot_max_distance: 4.5
# Secondary structure features
max_ss_frac_to_provide: 0.4
min_ss_island_len: 1
max_ss_island_len: 10
train_conditions:
unconditional:
frequency: 0.1
sequence_design:
frequency: 0.0
island:
frequency: 0.1
tipatom:
frequency: 0.0
ppi:
frequency: 0.8
# Used to create simple boolean flags for downstream conditioning
meta_conditioning_probabilities:
calculate_rasa: 0.0 # not currently used in the default model anyway
keep_protein_motif_rasa: 0.2
mirror_input: 0.0
featurize_plddt: 0.0
# Conditioning token scheme probabilities
unindex_leak_global_index: 0.33
unindex_insert_random_break: 0.33
unindex_remove_random_break: 0.33
# Probability of adding 1d secondary structure conditioning
add_1d_ss_features: 0.0
add_global_is_non_loopy_feature: 0.0
# PPI
add_ppi_hotspots: 0
full_binder_crop: 0
# Hbonds
calculate_hbonds: 0.0 # No hbonds by default
hbond_subsample: 0.5 # Same as in DNA/sm cases
# Base config overrides:
seed: 42
diffusion_batch_size_train: 32
diffusion_batch_size_inference: 8
crop_size: 256
n_recycles_train: 1
n_recycles_validation: 1
n_msa: null
max_atoms_in_crop: 5000
train:
pdb:
probability: 1.0
sub_datasets:
interface:
dataset:
dataset:
data: ${paths.data.pdb_data_dir}/interfaces_df_structure_clustered_tm_0.8_singletons_removed_len_lt_25_removed.parquet
filters:
# filters common across all PDB datasets
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
# ppi specific filters
- "pn_unit_1_type.astype('str').str.match('${chain_type_info_to_regex:PROTEINS}')"
- "pn_unit_2_type.astype('str').str.match('${chain_type_info_to_regex:PROTEINS}')"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
# interface specific columns
- pn_unit_1_iid
- pn_unit_2_iid
- pn_unit_1_non_polymer_res_names
- pn_unit_2_non_polymer_res_names
- is_inter_molecule
- all_pn_unit_iids_after_processing
- involves_loi
# necessary if filtering by pn_unit_type
- pn_unit_1_type
- pn_unit_2_type
transform:
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0


@@ -0,0 +1,3 @@
defaults:
- ppi # NOTE: This will also port over the bcov val dataset, which should be removed in the launch script
- val/ppi_inference@val.ppi_inference


@@ -0,0 +1,10 @@
defaults:
- ppi
- train/bcov_ppi_distillation@train
- _self_
train:
  pdb:
    probability: 0.5
  bcov_ppi_distillation:
    probability: 0.5


@@ -0,0 +1,10 @@
defaults:
- ppi
- train/af2db_interdomain_distillation@train
- _self_
train:
  pdb:
    probability: 0.5
  interdomain_distillation:
    probability: 0.5


@@ -0,0 +1,23 @@
# base training dataset for training AF3 design models (atom14 variants):
# protein subsampling only.
defaults:
- design_base
- val/sm_binder_hbonds@val.sm_binder_hbonds
- val/sm_binder_hbonds_short@val.sm_binder_hbonds_short
- _self_
# Datasets:
train:
  pdb:
    probability: 0.9
    sub_datasets:
      interface:
        dataset:
          transform:
            crop_spatial_probability: 1.0
  # Should we do monomer for hbond? probably not, maybe small sampling so model learns something?
  monomer_distillation:
    probability: 0.1


@@ -0,0 +1,40 @@
defaults:
- pdb/base_transform_args@interdomain_distillation
- _self_
interdomain_distillation:
dataset:
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames:
- pn_unit_1_iid
- pn_unit_2_iid
path_colname: path
dataset:
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af2db_interdomain_distillation
data: /projects/ml/datahub/dfs/af2db_interdomain_dset.parquet
id_column: example_id
columns_to_load:
- example_id
- path
- pn_unit_1_iid
- pn_unit_2_iid
- cluster
transform:
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0
weights:
_target_: atomworks.ml.samplers.calculate_weights_by_inverse_cluster_size
cluster_column: cluster
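The `weights` entry points at `calculate_weights_by_inverse_cluster_size` with `cluster_column: cluster`. The general idea of inverse-cluster-size weighting can be sketched in plain Python as below. This illustrates the technique only and is not the `atomworks` implementation; the helper name and the toy cluster labels are made up.

```python
from collections import Counter


def inverse_cluster_size_weights(clusters: list[str]) -> list[float]:
    """Give each example a weight of 1 / (number of examples in its cluster)."""
    sizes = Counter(clusters)
    return [1.0 / sizes[cluster] for cluster in clusters]


# Toy data: three examples share cluster "c1", one example is alone in "c2".
clusters = ["c1", "c1", "c1", "c2"]
weights = inverse_cluster_size_weights(clusters)

# Summing the weights per cluster shows that every cluster carries equal mass.
per_cluster: dict[str, float] = {}
for cluster, weight in zip(clusters, weights):
    per_cluster[cluster] = per_cluster.get(cluster, 0.0) + weight
print(weights, per_cluster)
```

With weights of this form, a sampler draws each cluster about equally often, so large families of near-duplicate structures do not dominate training.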


@@ -0,0 +1,38 @@
defaults:
- pdb/base_transform_args@bcov_ppi_distillation
- _self_
bcov_ppi_distillation:
dataset:
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames:
- pn_unit_1_iid
- pn_unit_2_iid
path_colname: path
dataset:
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: bcov_ppi_distillation
data: /projects/ml/datahub/dfs/bcov_ppi_dset_af3_preds.parquet
id_column: example_id
columns_to_load:
- example_id
- path
- pn_unit_1_iid
- pn_unit_2_iid
transform:
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0


@@ -0,0 +1,45 @@
defaults:
- base
dataset:
dataset_parser:
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
dataset:
name: interface
data: ${paths.data.pdb_data_dir}/interfaces_df_train.parquet
filters:
# filters common across all PDB datasets
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
# interface specific columns
- pn_unit_1_iid
- pn_unit_2_iid
- pn_unit_1_non_polymer_res_names
- pn_unit_2_non_polymer_res_names
- is_inter_molecule
- all_pn_unit_iids_after_processing
- involves_loi
transform:
# interface-specific Transform pipeline parameters
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0


@@ -0,0 +1,41 @@
defaults:
- base
dataset:
dataset_parser:
_target_: atomworks.ml.datasets.parsers.PNUnitsDFParser
dataset:
name: pn_unit
data: ${paths.data.pdb_data_dir}/pn_units_df_train.parquet
filters:
# filters common across all PDB datasets
- "deposition_date < '2021-09-30'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# pn_unit specific filters
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
- total_num_atoms_in_unprocessed_assembly
# pn_unit specific columns
- q_pn_unit_iid
- q_pn_unit_non_polymer_res_names
- all_pn_unit_iids_after_processing
- q_pn_unit_is_loi
transform:
# pn_unit-specific Transform pipeline parameters
crop_contiguous_probability: 0.3333333333333333
crop_spatial_probability: 0.6666666666666667


@@ -0,0 +1,14 @@
# Adds weights to the sampler
defaults:
- base_no_weights
- _self_
weights:
  _target_: atomworks.ml.samplers.calculate_weights_for_pdb_dataset_df
  beta: 0.5
  alphas:
    a_prot: 3.0 # 3 for AF-3
    a_nuc: 0.0 # 3 for AF-3
    a_ligand: 1.0 # 1 for AF-3
    a_loi: 5.0 # 5 for AF-3
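The `beta` and `alphas` values follow the naming of the AlphaFold 3 dataset weighting, where an example's weight is proportional to `beta / cluster_size` times an alpha-weighted count of its protein, nucleic acid, and ligand chains. The sketch below works through that arithmetic with the values from this config. It is an illustration under that assumption, not the `atomworks` function named in `_target_`; `af3_style_weight` is a hypothetical helper, and `a_loi` is left out because this diff does not show how it enters the weight.

```python
def af3_style_weight(
    n_prot: int,
    n_nuc: int,
    n_ligand: int,
    cluster_size: int,
    beta: float = 0.5,
    a_prot: float = 3.0,
    a_nuc: float = 0.0,
    a_ligand: float = 1.0,
) -> float:
    """Unnormalized sampling weight in the AF3 style (assumed form)."""
    if cluster_size < 1:
        raise ValueError("cluster_size must be at least 1")
    composition = a_prot * n_prot + a_nuc * n_nuc + a_ligand * n_ligand
    return (beta / cluster_size) * composition


# Two protein chains, one nucleic acid chain, one ligand, in a cluster of four.
weight = af3_style_weight(n_prot=2, n_nuc=1, n_ligand=1, cluster_size=4)
print(weight)
```

With `a_nuc` set to 0.0, nucleic acid chains add nothing to the weight, so under this reading such examples are sampled only through their protein and ligand content.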


@@ -0,0 +1,19 @@
defaults:
- base_transform_args
- _self_
dataset:
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
cif_parser_args:
cache_dir: null
load_from_cache: false
save_to_cache: false
dataset:
_target_: atomworks.ml.datasets.datasets.PandasDataset
# we will use the example_id as the unique column
id_column: example_id
transform:
# common Transform pipeline components for all PDB datasets
_target_: ${datasets.pipeline_target}
is_inference: False


@@ -0,0 +1,59 @@
# All required training args
defaults:
- _self_
dataset:
transform:
_target_: ${datasets.pipeline_target}
is_inference: False
return_atom_array: False
# Model
sigma_perturb: ${datasets.global_transform_args.sigma_perturb}
sigma_perturb_com: ${datasets.global_transform_args.sigma_perturb_com}
sigma_data: ${model.net.diffusion_module.sigma_data}
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
central_atom: ${datasets.global_transform_args.central_atom}
n_atoms_per_token: ${datasets.global_transform_args.n_atoms_per_token}
association_scheme: ${datasets.global_transform_args.association_scheme}
center_option: ${datasets.global_transform_args.center_option}
# Conformers
generate_conformers: ${datasets.global_transform_args.generate_conformers}
generate_conformers_for_non_protein_only: ${datasets.global_transform_args.generate_conformers_for_non_protein_only}
provide_reference_conformer_when_unmasked: ${datasets.global_transform_args.provide_reference_conformer_when_unmasked}
ground_truth_conformer_policy: ${datasets.global_transform_args.ground_truth_conformer_policy}
provide_elements_for_unindexed_components: ${datasets.global_transform_args.provide_elements_for_unindexed_components}
use_element_for_atom_names_of_atomized_tokens: ${datasets.global_transform_args.use_element_for_atom_names_of_atomized_tokens}
residue_cache_dir: ${paths.data.residue_cache_dir}
# Conditions
train_conditions: ${datasets.global_transform_args.train_conditions}
meta_conditioning_probabilities: ${datasets.global_transform_args.meta_conditioning_probabilities}
# PPI Hypers
keep_full_binder_in_spatial_crop: ${datasets.global_transform_args.keep_full_binder_in_spatial_crop}
max_binder_length: ${datasets.global_transform_args.max_binder_length}
max_ppi_hotspots_frac_to_provide: ${datasets.global_transform_args.max_ppi_hotspots_frac_to_provide}
ppi_hotspot_max_distance: ${datasets.global_transform_args.ppi_hotspot_max_distance}
# 1D SS hypers
max_ss_frac_to_provide: ${datasets.global_transform_args.max_ss_frac_to_provide}
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
# Cropping
crop_size: ${datasets.crop_size}
max_atoms_in_crop: ${datasets.max_atoms_in_crop}
allowed_types: ALL
crop_spatial_probability: ???
crop_contiguous_probability: ???
dna_contact_crop_probability: 0.0
crop_center_cutoff_distance: 15.0
zero_occ_on_exposure_after_cropping: False
b_factor_min: null
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}

View File

@@ -0,0 +1,45 @@
defaults:
- pdb_base
- _self_
dataset:
dataset_parser:
_target_: atomworks.ml.datasets.parsers.InterfacesDFParser
dataset:
name: dna_binder
data: /projects/ml/aa_design/NA_binder_interfaces_df_train.parquet
filters:
# filters common across all PDB datasets
- "resolution < 9.0"
- "cluster.notnull()"
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
columns_to_load:
# columns common across all PDB datasets
- example_id
- pdb_id
- assembly_id
- deposition_date
- resolution
- num_polymer_pn_units
- method
- cluster
- n_prot
- n_nuc
- n_ligand
- n_peptide
# interface specific columns
- pn_unit_1_iid
- pn_unit_2_iid
- pn_unit_1_non_polymer_res_names
- pn_unit_2_non_polymer_res_names
- is_inter_molecule
- all_pn_unit_iids_after_processing
- involves_loi
transform:
crop_contiguous_probability: 0.4
crop_spatial_probability: 0.0
dna_contact_crop_probability: 0.6

View File

@@ -0,0 +1,17 @@
# Free DNA generated from TODO
defaults:
- base_no_weights
- _self_
dataset:
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
dataset:
name: free_dna_dds
data: /projects/ml/prot_dna/hexamer_rebuild/free_dna_dds.csv
transform:
crop_contiguous_probability: 0.25
crop_spatial_probability: 0.75

View File

@@ -0,0 +1,20 @@
defaults:
- base_no_weights
- _self_
dataset:
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
dataset:
name: tf_distillation
data: /projects/ml/prot_dna/transcriptionFactor_distillation_rf3.newDL.csv
columns_to_load:
- example_id
- path
transform:
crop_contiguous_probability: 0.4
crop_spatial_probability: 0.0
dna_contact_crop_probability: 0.6

View File

@@ -0,0 +1,11 @@
# Base config for all PDB datasets
defaults:
- base
- _self_
dataset:
# All PDB datasets load from this cache:
cif_parser_args:
cache_dir: ${paths.data.cif_cache_dir}
load_from_cache: True
save_to_cache: False

View File

@@ -0,0 +1,10 @@
# Inherits from af3_train_interface and pdb_base
defaults:
- af3_train_interface
- pdb_base
- _self_
dataset:
transform:
crop_contiguous_probability: 0.0
crop_spatial_probability: 1.0

View File

@@ -0,0 +1,10 @@
defaults:
- af3_train_pn_unit
- pdb_base
- _self_
dataset:
transform:
# pn_unit-specific Transform pipeline parameters
crop_contiguous_probability: 0.3333333333333333
crop_spatial_probability: 0.6666666666666667

View File

@@ -0,0 +1,38 @@
defaults:
- pdb/base_transform_args@monomer_distillation
- _self_
monomer_distillation:
dataset:
_target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper
save_failed_examples_to_dir: ${paths.data.failed_examples_dir}
# Explicitly do not load from or save to the cache:
# the dataset is too big, and the structures are small
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
# metadata dataset
dataset:
_target_: atomworks.ml.datasets.datasets.PandasDataset
name: af2fb_distillation
id_column: example_id
data: ${paths.data.monomer_distillation_parquet_dir}/af2_distillation_facebook.parquet
columns_to_load:
- example_id
- path
# metadata parser
dataset_parser:
_target_: atomworks.ml.datasets.parsers.GenericDFParser
pn_unit_iid_colnames: null
transform:
_target_: ${datasets.pipeline_target}
is_inference: False
# protein_msa_dirs: [{"dir": "${paths.data.monomer_distillation_data_dir}/msa", "extension": ".a3m", "directory_depth": 2}]
# rna_msa_dirs: []
crop_contiguous_probability: 0.25
crop_spatial_probability: 0.75
b_factor_min: 70

View File

@@ -0,0 +1,21 @@
defaults:
- design_base
- val/unindexed@val.unindexed
- val/mcsa_41_short_rigid@val.mcsa_41_short_rigid
- _self_
diffusion_batch_size_train: 24
global_transform_args:
train_conditions:
unconditional:
frequency: 3.0
sequence_design:
frequency: 0.5
island:
frequency: 3.0
p_unindex_motif_tokens: 0.5
tipatom:
frequency: 3.0
p_unindex_motif_tokens: 0.5

View File

@@ -0,0 +1,9 @@
defaults:
- design_validation_base
- val_examples/bcov_ppi_easy_medium_with_ori@dataset.data
- _self_
dataset:
eval_every_n: 1
name: bcov-ppi-easy-medium

View File

@@ -0,0 +1 @@
/projects/ml/aa_design/benchmarks

View File

@@ -0,0 +1,51 @@
dataset:
_target_: rfd3.inference.datasets.ContigJsonDataset
# Required parameters for each inheriting dataset
data: ??? # Path to json file
name: ??? # Name for displaying and saving files
eval_every_n: ??? # Evaluate on this dataset every n epochs
subset_to_keys: null # Specific keys in json to keep, ignores all others.
# NB: Used for parsing input files (not for atom_array reloading anymore)
cif_parser_args:
cache_dir: null
load_from_cache: False
save_to_cache: False
add_missing_atoms: False
# Common Transform pipeline components for all PDB datasets
transform:
_target_: ${datasets.pipeline_target}
is_inference: True
return_atom_array: True
diffusion_batch_size: ${datasets.diffusion_batch_size_train}
sigma_data: ${model.net.diffusion_module.sigma_data}
central_atom: ${datasets.global_transform_args.central_atom}
n_atoms_per_token: ${datasets.global_transform_args.n_atoms_per_token}
association_scheme: ${datasets.global_transform_args.association_scheme}
center_option: ${datasets.global_transform_args.center_option}
# Conformers
generate_conformers: ${datasets.global_transform_args.generate_conformers}
generate_conformers_for_non_protein_only: ${datasets.global_transform_args.generate_conformers_for_non_protein_only}
provide_reference_conformer_when_unmasked: ${datasets.global_transform_args.provide_reference_conformer_when_unmasked}
ground_truth_conformer_policy: ${datasets.global_transform_args.ground_truth_conformer_policy}
provide_elements_for_unindexed_components: ${datasets.global_transform_args.provide_elements_for_unindexed_components}
use_element_for_atom_names_of_atomized_tokens: ${datasets.global_transform_args.use_element_for_atom_names_of_atomized_tokens}
residue_cache_dir: ${paths.data.residue_cache_dir}
# PPI
keep_full_binder_in_spatial_crop: ${datasets.global_transform_args.keep_full_binder_in_spatial_crop}
max_binder_length: ${datasets.global_transform_args.max_binder_length}
max_ppi_hotspots_frac_to_provide: ${datasets.global_transform_args.max_ppi_hotspots_frac_to_provide}
ppi_hotspot_max_distance: ${datasets.global_transform_args.ppi_hotspot_max_distance}
# Secondary structure
max_ss_frac_to_provide: ${datasets.global_transform_args.max_ss_frac_to_provide}
min_ss_island_len: ${datasets.global_transform_args.min_ss_island_len}
max_ss_island_len: ${datasets.global_transform_args.max_ss_island_len}
# Other dataset-specific parameters
atom_1d_features: ${model.net.token_initializer.atom_1d_features}
token_1d_features: ${model.net.token_initializer.token_1d_features}

View File

@@ -0,0 +1,9 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/dna_binder.json
name: dna_binder_design
eval_every_n: 1

View File

@@ -0,0 +1,13 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.root_dir}/tests/dna.json
name: dna_binder_design
eval_every_n: 10
subset_to_keys:
- 7rte_sequence_only
- 7rte_with_structure

View File

@@ -0,0 +1,13 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.root_dir}/rfd3/tests/test_data/dna.json
name: dna_binder_design
eval_every_n: 1
subset_to_keys:
- 7rte_sequence_only
- 7rte_with_structure

View File

@@ -0,0 +1,9 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/indexed.json
name: indexed-design
eval_every_n: 8

View File

@@ -0,0 +1,9 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/mcsa_41.json
name: woodys-benchmark
eval_every_n: 16

View File

@@ -0,0 +1,16 @@
defaults:
- mcsa_41
- _self_
dataset:
name: woodys-benchmark-short
eval_every_n: 4
subset_to_keys:
- M0630_1j79
- M0555_1f8r
- M0711_2esd
- M0058_1cju
- M0907_1rbl
- M0157_1qh5

View File

@@ -0,0 +1,10 @@
defaults:
- unindexed
- _self_
dataset:
name: rigid-ligand-enzymes
eval_every_n: 1
data: ${paths.data.design_benchmark_data_dir}/mcsa_41_short_rigid.json

View File

@@ -0,0 +1,7 @@
defaults:
- unconditional
- _self_
dataset:
name: ppi_inference
data: ??? # This is a required override, specifying a path to the dataset json or yaml file.

View File

@@ -0,0 +1,13 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/sm_binder_hbonds.json
eval_every_n: 5
name: sm_binder_hbonds-design
subset_to_keys:
- FAD
- IAI
- OQO
- SAM

View File

@@ -0,0 +1,15 @@
defaults:
- sm_binder_hbonds
- _self_
dataset:
eval_every_n: 1
data: ${paths.data.design_benchmark_data_dir}/sm_binder_hbonds_sampled.json
name: sm_binder_hbonds-design-short
subset_to_keys:
- FAD_1
- FAD_2
- FAD_3
- IAI_1
- IAI_2
- IAI_3

View File

@@ -0,0 +1,9 @@
defaults:
- design_validation_base
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/monomer.json
subset_to_keys: null # Specific keys in json to keep, ignores all others.
name: unconditional-design
eval_every_n: 1

View File

@@ -0,0 +1,9 @@
defaults:
- unconditional
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/unconditional_deep.json
name: unconditional-design-deep
eval_every_n: 8

View File

@@ -0,0 +1,8 @@
defaults:
- unconditional
- _self_
dataset:
data: ${paths.data.design_benchmark_data_dir}/unindexed.json
name: unindexed-design

View File

@@ -0,0 +1,151 @@
global_args:
infer_ori_strategy: hotspots
insulinr:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/insulin_target.pdb
contig: 100-100,/0,B1-150
contig_atoms: '{}'
length: 250-250
redesign_motif_sidechains: false
atom_level_hotspots:
B59:
CG,CZ: 1
B83:
CG,CZ: 1
B91:
CG,CZ: 1
pdl1:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/5o45_pdl1.pdb
contig: 100-100,/0,B1-115
contig_atoms: '{}'
length: 215-215
redesign_motif_sidechains: false
atom_level_hotspots:
B40:
CG,CZ: 1
B99:
CG,SD: 1
B107:
CG,CZ: 1
vegfr:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/vegfr_2x1w_and_af2_B.pdb
contig: 100-100,/0,B1-200
contig_atoms: '{}'
length: 300-300
redesign_motif_sidechains: false
atom_level_hotspots:
B13:
CG1,CG2: 1
B15:
CG,CZ: 1
B43:
CG,CZ: 1
B75:
CG,SD: 1
B89:
CD1,CG2: 1
B91:
CG1,CG2: 1
B187:
CG,CD1: 1
rbd:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/COVID19_target.pdb
contig: 100-100,/0,B1-195
contig_atoms: '{}'
length: 295-295
redesign_motif_sidechains: false
atom_level_hotspots:
B89:
CG,CZ: 1
B121:
CG,CZ: 1
B123:
CG,CD1: 1
B124:
CG,CZ: 1
B141:
CG,CZ: 1
B157:
CG,CZ: 1
B163:
CG,CZ: 1
B165:
CG,CZ: 1
B173:
CG,CZ: 1
cd28:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/cd28_1yjd_B.pdb
contig: 100-100,/0,B1-118
contig_atoms: '{}'
length: 218-218
redesign_motif_sidechains: false
atom_level_hotspots:
B51:
CG,CZ: 1
B61:
CG,CZ: 1
B99:
CG,SD: 1
B104:
CG,CZ: 1
il2ra:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/il2ra_1z92_B.pdb
contig: 100-100,/0,B1-122
contig_atoms: '{}'
length: 222-222
redesign_motif_sidechains: false
atom_level_hotspots:
B3:
CG,CD1: 1
B26:
CG,SD: 1
B43:
CG,CD1: 1
B44:
CG,CZ: 1
B46:
CG,CD1: 1
il10ra:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/il10rb_1lqs_B.pdb
contig: 100-100,/0,B1-207
contig_atoms: '{}'
length: 307-307
redesign_motif_sidechains: false
atom_level_hotspots:
B39:
CG,CD1: 1
B50:
CD1,CG2: 1
B59:
CG,CZ: 1
B63:
CA,CB: 1
B64:
CG1,CG2: 1
B66:
CG,CD1: 1
tie2:
input: /projects/ml/aa_design/benchmarks/bcov_af3_ppi_benchmark/tie2_2gy5_official_B.pdb
contig: 100-100,/0,B1-188
contig_atoms: '{}'
length: 288-288
redesign_motif_sidechains: false
atom_level_hotspots:
B132:
CG1,CG2: 1
B134:
CG,CZ: 1
B135:
CG,CD: 1
B139:
CG,CZ: 1
B140:
CD1,CG2: 1
B154:
CG1,CG2: 1
B156:
CG,CD1: 1
B167:
CG,CZ: 1
B172:
CD1,CG2: 1

View File

@@ -0,0 +1,7 @@
defaults:
- bcov_ppi_easy_medium_with_ori
- _self_
global_args:
use_ss_conditioning: true
spoof_helical_bundle_ss_conditioning: true

View File

@@ -0,0 +1,28 @@
defaults:
- bcov_ppi_easy_medium_with_ori
- _self_
insulinr:
contig: 65-120,/0,B1-150
length: 215-270
pdl1:
contig: 65-120,/0,B1-115
length: 215-215
vegfr:
contig: 65-120,/0,B1-200
length: 300-300
rbd:
contig: 65-120,/0,B1-195
length: 295-295
cd28:
contig: 65-120,/0,B1-118
length: 218-218
il2ra:
contig: 65-120,/0,B1-122
length: 222-222
il10ra:
contig: 65-120,/0,B1-207
length: 307-307
tie2:
contig: 65-120,/0,B1-188
length: 288-288

View File

@@ -0,0 +1,212 @@
defaults:
- bcov_ppi_easy_medium_with_ori
- _self_
# Note: I have removed the more "internal" hotspot atoms on residues that also have
# hbonds to avoid giving conflicting information with the h-bond conditioning.
# For targets with no polar hotspots, I have inserted h-bond conditioning to nearby
# residues.
insulinr:
hbond_donors:
B113:
NH1,NH2: 1
B116:
NZ: 1
B139:
NE2: 1
hbond_acceptors:
B115:
OE2: 1 # OE1 is already h-bonded
B139:
ND1: 1
atom_level_hotspots:
B59:
CG,CZ: 1
B83:
CG,CZ: 1
B91:
CG,CZ: 1
pdl1:
hbond_donors:
B107:
OH: 1
B109:
NH1,NH2: 1
hbond_acceptors:
B40:
OH: 1
B107:
OH: 1
atom_level_hotspots:
B40:
CZ: 1
B99:
CG,SD: 1
B107:
CZ: 1
vegfr:
hbond_donors:
B15:
OH: 1
B43:
OH: 1
hbond_acceptors:
B9: # ASP
OD1,OD2: 1
B15:
OH: 1
B43:
OH: 1
atom_level_hotspots:
B13:
CG1,CG2: 1
B15:
CZ: 1
B43:
CZ: 1
B75:
CG,SD: 1
B89:
CD1,CG2: 1
B91:
CG1,CG2: 1
B187:
CG,CD1: 1
rbd:
hbond_donors:
B89:
OH: 1
B121:
OH: 1
B141:
OH: 1
hbond_acceptors:
B157:
OH: 1
B163:
OH: 1
B173:
OH: 1
atom_level_hotspots:
B89:
CZ: 1
B121:
CZ: 1
B123:
CG,CD1: 1
B124:
CG,CZ: 1
B141:
CZ: 1
B157:
CZ: 1
B163:
CZ: 1
B165:
CG,CZ: 1
B173:
CZ: 1
cd28:
hbond_donors:
B51:
OH: 1
B61:
OH: 1
B104:
OH: 1
hbond_acceptors:
B51:
OH: 1
B61:
OH: 1
B104:
OH: 1
atom_level_hotspots:
B51:
CZ: 1
B61:
CZ: 1
B99:
CG,SD: 1
B104:
CZ: 1
il2ra:
hbond_donors:
B44:
OH: 1
hbond_acceptors:
B44:
OH: 1
B77: # HIS
ND1: 1
B2: # GLU
OE1,OE2: 1
atom_level_hotspots:
B3:
CG,CD1: 1
B26:
CG,SD: 1
B43:
CG,CD1: 1
B44:
CZ: 1
B46:
CG,CD1: 1
il10ra:
hbond_donors:
B48: # ASN
ND2: 1
B59:
OH: 1
hbond_acceptors:
B48: # ASN
OD1: 1
B59:
OH: 1
atom_level_hotspots:
B39:
CG,CD1: 1
B50:
CD1,CG2: 1
B59:
CZ: 1
B63:
CA,CB: 1
B64:
CG1,CG2: 1
B66:
CG,CD1: 1
tie2:
hbond_donors:
B134:
OH: 1
B135:
NZ: 1
B167:
OH: 1
hbond_acceptors:
B134:
OH: 1
B167:
OH: 1
atom_level_hotspots:
B132:
CG1,CG2: 1
B134:
CZ: 1
B135:
CD: 1
B139:
CG,CZ: 1
B140:
CD1,CG2: 1
B154:
CG1,CG2: 1
B156:
CG,CD1: 1
B167:
CZ: 1
B172:
CD1,CG2: 1

View File

@@ -0,0 +1,64 @@
# @package _global_
defaults:
- override /logger: null
# default debugging setup, runs 1 full epoch
# other debugging configs can inherit from this one
# overwrite task name so debugging logs are stored in separate folder
task_name: "debug"
extras:
ignore_warnings: False
enforce_tags: False
# sets level of all command line loggers to 'DEBUG'
# https://hydra.cc/docs/tutorials/basic/running_your_app/logging/
hydra:
job_logging:
root:
level: DEBUG
# use the below to also set hydra loggers to 'DEBUG'
verbose: True
# Print example ID before forward pass
callbacks:
print_example_id_before_forward_pass:
_target_: modelhub.callbacks.train_logging.PrintExampleIDBeforeForwardPassCallback
dataloader:
train:
dataloader_params:
batch_size: 1
num_workers: 0 # debuggers don't like multiprocessing -- work on main thread
pin_memory: False # disable pinned (page-locked) host memory
prefetch_factor: null # must be null for num_workers=0
n_fallback_retries: 0 # disable fallback retries for debugging
val:
dataloader_params:
batch_size: 1
num_workers: 0
pin_memory: False
prefetch_factor: null # must be null for num_workers=0
datasets:
crop_size: 100 # set small crop size for quick debugging
diffusion_batch_size_train: 1
diffusion_batch_size_inference: 1
n_recycles_train: 1
n_recycles_validation: 1
n_msa: 128
key_to_balance: null # otherwise big examples will be processed first
trainer:
devices_per_node: 1
limit_train_batches: 1
limit_val_batches: 1
validate_every_n_epochs: 1
# Set tags to help identify debugging runs
tags:
- debug

View File

@@ -0,0 +1,21 @@
# @package _global_
# See: https://hydra.cc/docs/patterns/configuring_experiments/
# to execute this experiment run:
# python train.py +debug=train_single_example [any other arguments]
defaults:
- default
- gpu
datasets:
# you can add specific example IDs here to load a subset of the dataset (training)
subset_to_example_ids:
- "{['pdb', 'pn_units']}{3px1}{1}{['A_3']}"
val: null
tags:
- debug
- train
- specific-examples

View File

@@ -0,0 +1,9 @@
# @package _global_
# Inference engine config for development
hydra:
searchpath:
- pkg://configs
defaults:
- inference_engine: dev
- _self_

View File

@@ -0,0 +1,12 @@
# @package _global_
name: train-main
tags:
- atom14
- trunkless
- local-attn
- unet
- ${datasets.global_transform_args.central_atom}
# Optional Pre-trained Design model
ckpt_path: ${paths.data.design_model_weight_dir}/rfd3_latest.ckpt

View File

@@ -0,0 +1,14 @@
# @package _global_
defaults:
- /debug/default
- override /logger: csv
name: a14-debug
tags:
- debug
project: test
ckpt_path: null
trainer:
prevalidate: True

View File

@@ -0,0 +1,71 @@
# @package _global_
defaults:
# - uncond-compact-skip-recycle-permute
- override /datasets: ppi
- override /callbacks: ppi_design_callbacks
# - override /logger: csv # turn off wandb
- _self_
name: ppi
tags:
- atom14
- ppi
- ${datasets.global_transform_args.central_atom}
trainer:
prevalidate: False
checkpoint_every_n_epochs: 5
validate_every_n_epochs: 5
precision: bf16-mixed
read_sequence_from_sequence_head: False
cleanup_virtual_atoms: True
cleanup_guideposts: True
compute_non_clash_metrics_for_diffused_region_only: True
# convert_non_protein_designed_res_to_ala: True
# set_input_seq_for_motif_atoms_with_fixed_seq: True
loss:
verbose_diffusion_loss: null
simple_diffusion_loss:
_target_: rfd3.metrics.losses.SimpleDiffusionLoss
sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
lddt_weight: 0.25
alpha_virtual_atom: 1.0 #0.35
lp_weight: 0.0
unindexed_norm_p: 1.0
alpha_unindexed_diffused: 1.0
unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False
metrics:
ppi_metrics:
_target_: rfd3.metrics.design_metrics.PPIMetrics
distance_cutoff: 4.5
backbone_metrics:
compute_for_diffused_region_only: True
general_metrics:
compute_for_diffused_region_only: True
compute_ss_adherence_if_possible: True
# skip_optimizer_loading: True
# HACK: Seems like we're rolling back the pipelines submodule, so will need to specify local paths
callbacks:
run_pipelines_callback:
pipelines_config_path: ${paths.root_dir}/rfd3/configs/pipelines/aa_ppi_design_af3_validation.yaml
pipelines_script_path: /home/rib7/protocols/pipelines/pipelines/pipeline.py
model:
net:
inference_sampler:
center_option: ${datasets.global_transform_args.center_option}
step_scale: 3 # Slightly reduces diversity and significantly ups alanine content in exchange for much better pass rates
# ckpt_path: /net/scratch/rib7/training/logs/train/unet-indexed/latest_2025-04-17_17-14/ckpt/epoch-0110.ckpt
# ckpt_path: /projects/ml/aa_design/models/uncond-compact-skip-recycle-permute-epoch-0140.ckpt
# ckpt_path: ${paths.data.design_model_weight_dir}/main.ckpt
ckpt_path: ${paths.data.design_model_weight_dir}/rfd3_latest.ckpt

View File

@@ -0,0 +1,47 @@
# @package _global_
defaults:
- ppi
- ../datasets/val/val_examples/bcov_ppi_easy_medium_with_ori@datasets.val.bcov_ppi_easy_medium.dataset.data
- _self_
name: ppi_hdm
tags:
- atom14
- ppi
- atom-level-hotspots
- ${datasets.global_transform_args.central_atom}
- 07-07-branch
- sl2-base
- full-binder-crop
- crop-size-384
- diffused-centered
datasets: # NOTE: in contrast to the old branch, we're using the full atom14 association scheme here
diffusion_batch_size_train: 8
crop_size: 384
global_transform_args:
keep_full_binder_in_spatial_crop: True
max_binder_length: 256
sigma_perturb: 5 # Such that the ori token isn't fixed too closely
center_option: diffuse
meta_conditioning_probabilities:
add_ppi_hotspots: 0.75
full_binder_crop: 1.0 # Note that this leaks the sequence length
train:
pdb:
sub_datasets:
interface:
dataset:
transform:
max_atoms_in_crop: 5000
model:
net:
inference_sampler:
center_option: diffuse
diffusion_module:
diffusion_atom_encoder:
c_atom_1d_features: 395
atom_1d_features:
is_atom_level_hotspot: 1

View File

@@ -0,0 +1,21 @@
# @package _global_
defaults:
- ppi_hdm
- override /datasets: ppi_with_interdomain
- _self_
name: ppi_hdm_dist
tags:
- atom14
- ppi
- atom-level-hotspots
- ${datasets.global_transform_args.central_atom}
- sl2-base
- full-binder-crop
- crop-size-384
- diffused-centered
- motif-t-zero
- seq-head-detached
- glycine-bugfix
- interdomain-distillation

View File

@@ -0,0 +1,35 @@
# @package _global_
defaults:
- ppi_hdm
- ../datasets/val/val_examples/bpem_ori_hb@datasets.val.bcov_ppi_easy_medium.dataset.data
- _self_
name: ppi_hdm_hb
tags:
- atom14
- ppi
- atom-level-hotspots
- ${datasets.global_transform_args.central_atom}
- 05-23-branch
- sl2-base
- full-binder-crop
- crop-size-384
- diffused-centered
- motif-t-zero
- seq-head-detached
- hbond-conditioning
model:
net:
diffusion_module:
diffusion_atom_encoder:
c_atom_1d_features: 397
atom_1d_features:
active_donor: 1
active_acceptor: 1
datasets:
global_transform_args:
meta_conditioning_probabilities:
calculate_hbonds: 0.3 # Fraction of the time we give hbond conditioning
hbond_subsample: 0.5 # Fraction of the time we subsample the hbonds vs giving all of them

View File

@@ -0,0 +1,35 @@
# @package _global_
defaults:
- ppi_hdm
- ../datasets/val/val_examples/bcov_ppi_easy_medium_with_ori_spoof_helical_bundle@datasets.val.bcov_ppi_easy_medium.dataset.data
- _self_
name: ppi_hdm_ss
tags:
- atom14
- ppi
- atom-level-hotspots
- secondary-structure-conditioning
- ${datasets.global_transform_args.central_atom}
- 05-23-branch
- sl2-base
- full-binder-crop
- crop-size-384
- diffused-centered
- motif-t-zero
- seq-head-detached
datasets:
global_transform_args:
meta_conditioning_probabilities:
add_1d_ss_features: 0.3
model:
net:
diffusion_module:
diffusion_atom_encoder:
c_atom_1d_features: 398
atom_1d_features:
is_helix_conditioning: 1
is_sheet_conditioning: 1
is_loop_conditioning: 1

View File

@@ -0,0 +1,31 @@
# @package _global_
defaults:
- ppi_hdm
- _self_
name: ppi_hmm
tags:
- atom14
- ppi
- atom-level-hotspots
- ${datasets.global_transform_args.central_atom}
- 07-07-branch
- sl2-base
- full-binder-crop
- crop-size-384
- motif-centered
datasets:
train:
pdb:
sub_datasets:
interface:
dataset:
transform:
center_option: motif
model:
net:
inference_sampler:
center_option: motif

View File

@@ -0,0 +1,42 @@
# @package _global_
# New base model configuration since last large refactor
# includes
defaults:
- a14-base-train
- override /datasets: dna_base
- override /model: full
- _self_
name: rfd3-NA-full-ft-glycine-fixed
ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
datasets:
diffusion_batch_size_train: 16
crop_size: 256
max_atoms_in_crop: 2560 # 256 * 10
global_transform_args:
association_scheme: atom14-new
sigma_perturb: 4.0 # 3.0
sigma_perturb_com: 3.0
# Down-weight monomer distillation
trainer:
prevalidate: False
loss:
verbose_diffusion_loss: null
simple_diffusion_loss:
_target_: rfd3.metrics.losses.SimpleDiffusionLoss
sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
lddt_weight: 0.25
alpha_virtual_atom: 0.35
lp_weight: 0.0
unindexed_norm_p: 1.0
alpha_unindexed_diffused: 1.0
unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False

View File

@@ -0,0 +1,18 @@
# @package _global_
# Ablation experiment base for RFD3
defaults:
- rfd3-ablation-control
- _self_
name: rfd3-ablation-blt
model:
net:
diffusion_module:
downcast:
method: mean
upcast:
method: broadcast

View File

@@ -0,0 +1,21 @@
# @package _global_
defaults:
- rfd3-ablation-control
- _self_
name: rfd3-ablation-clustering
# Reset the training parquets to the regular ones without structure-clustering
datasets:
train:
pdb:
sub_datasets:
pn_unit:
dataset:
dataset:
data: ${paths.data.pdb_data_dir}/pn_units_df_train.parquet
interface:
dataset:
dataset:
data: ${paths.data.pdb_data_dir}/interfaces_df_train.parquet

View File

@@ -0,0 +1,60 @@
# @package _global_
# Ablation experiment base for RFD3
#
defaults:
- rfd3-full
- _self_
name: rfd3-ablation-control
tags:
- rfd3
- ablation
# Fresh weights
ckpt_path: null
datasets:
# Reset training conditions:
global_transform_args:
# Turn off reference conformer generation
generate_conformers: False
train_conditions:
unconditional:
frequency: 99999
island:
frequency: 0.0
sequence_design:
frequency: 0.0
tipatom:
frequency: 0.0
meta_conditioning_probabilities:
calculate_rasa: 0.0
keep_protein_motif_rasa: 0.0
calculate_hbonds: 0.0
hbond_subsample: 0.0
# fully indexed training
unindex_leak_global_index: 0.0
unindex_insert_random_break: 0.0
unindex_remove_random_break: 0.0
# Reset all validation sets except the unconditionals
val:
indexed: null
unindexed: null
sm_binder_hbonds: null
mcsa_41_short_rigid: null
unconditional:
dataset:
eval_every_n: 1
# Large datasets
unconditional_deep:
dataset:
eval_every_n: 8

View File

@@ -0,0 +1,16 @@
# @package _global_
defaults:
- rfd3-ablation-control
- _self_
name: rfd3-ablation-monomerdistill
# Turn off monomer distillation: sample training examples from the PDB datasets only
datasets:
train:
pdb:
probability: 1.0
monomer_distillation:
probability: 0.0

View File

@@ -0,0 +1,18 @@
# @package _global_
defaults:
- rfd3-ablation-control
- _self_
name: rfd3-ablation-recycling
model:
net:
diffusion_module:
n_recycle: 1
inference_sampler:
num_timesteps: 200
trainer:
n_recycles_train: 1

View File

@@ -0,0 +1,12 @@
# @package _global_
defaults:
- rfd3-full-sparse
- _self_
model:
net:
diffusion_module:
c_token: 384
upcast:
n_split: 6

View File

@@ -0,0 +1,24 @@
# @package _global_
# New base model configuration since last large refactor
# includes
defaults:
- rfd3-full
- _self_
name: rfd3-ablation-recycling
model:
net:
diffusion_module:
n_recycle: 2
use_local_token_attention: True
diffusion_transformer:
n_local_tokens: 32
n_keys: 128
inference_sampler:
num_timesteps: 100
trainer:
n_recycles_train: 2

View File

@@ -0,0 +1,42 @@
# @package _global_
# New base model configuration since last large refactor
# includes
defaults:
- rfd3-full-sparse
- _self_
name: rfd3-print
tags: [print-model]
ckpt_path: null
datasets:
global_transform_args:
train_conditions:
unconditional:
frequency: 2.0
island:
frequency: 2.0
sequence_design:
frequency: 0.5
tipatom:
frequency: 5.0
ppi:
frequency: 2.0
train:
pdb:
probability: 0.05
monomer_distillation:
probability: 0.85
# Added in later stages of training:
# interdomain_distillation:
# probability: 0.0
bcov_ppi_distillation:
probability: 0.10
dna_binder: # 6 : 4 : 1
probability: 0.0
na_complex_distillation:
probability: 0.0
free_dna_dds:
probability: 0.0

View File

@@ -0,0 +1,42 @@
# @package _global_
# New base model configuration since last large refactor
# includes
defaults:
- rfd3-full-stage0
- _self_
name: rfd3-print
tags: [print-model]
ckpt_path: null
datasets:
global_transform_args:
train_conditions:
unconditional:
frequency: 2.0
island:
frequency: 2.0
sequence_design:
frequency: 0.5
tipatom:
frequency: 5.0
ppi:
frequency: 5.0
train:
pdb:
probability: 0.20
monomer_distillation:
probability: 0.50
interdomain_distillation:
probability: 0.10
bcov_ppi_distillation:
probability: 0.00
# Added in later stages of training:
dna_binder: # 6 : 4 : 1
probability: 0.14
na_complex_distillation:
probability: 0.05
free_dna_dds:
probability: 0.01

View File

@@ -0,0 +1,187 @@
# @package _global_
defaults:
- override /model: rfd3_base
- override /datasets: all
- _self_
name: rfd3-print
tags: [print-model]
ckpt_path: null
datasets:
diffusion_batch_size_train: 16
crop_size: 384
max_atoms_in_crop: 3840 # 384 * 10
global_transform_args:
association_scheme: dense
sigma_perturb: 2.0 # 4.0
sigma_perturb_com: 1.0 # 3.0
center_option: diffuse # options are ["all", "motif", "diffuse"]
generate_conformers: True
generate_conformers_for_non_protein_only: True
provide_elements_for_unindexed_components: True
use_element_for_atom_names_of_atomized_tokens: True # TODO: correct name, implies unindexed do too
train_conditions:
unconditional:
frequency: 2.0
p_fix_motif_coordinates: 0.5
island:
frequency: 2.0
p_unindex_motif_tokens: 0.5
island_sampling_kwargs:
island_len_max: 12
# Give some love to other kinds of conditioning @magnusb
p_fix_motif_coordinates: 0.9
sequence_design:
frequency: 0.1
# For ChemNet-style sampling < 1.0
p_fix_motif_coordinates: 0.8
p_fix_motif_sequence: 0.1
tipatom:
frequency: 5.0
island_sampling_kwargs:
n_islands_max: 12
subgraph_sampling_kwargs:
residue_p_seed_furthest_from_o: 0.8
residue_n_bond_expectation: 3.0
residue_p_fix_all: 0.05
hetatom_n_bond_expectation: 8
hetatom_p_fix_all: 0.5
p_unindex_motif_tokens: 0.95
ppi:
frequency: 0.0
meta_conditioning_probabilities:
calculate_hbonds: 0.2
calculate_rasa: 0.6
keep_protein_motif_rasa: 0.1 # Small to prevent noisy input to model
hbond_subsample: 0.5
# fully indexed training
unindex_leak_global_index: 0.10
unindex_insert_random_break: 0.10
unindex_remove_random_break: 0.10
# unindex_remove_atom_names: 0.05
# Probability of adding 1d secondary structure conditioning
add_1d_ss_features: 0.1
featurize_plddt: 0.9 # Applied for monomer distillation only
add_global_is_non_loopy_feature: 0.99
# PPI
add_ppi_hotspots: 0.75
full_binder_crop: 0.75
# Down-weight monomer distillation
train:
pdb:
probability: 0.20
sub_datasets:
pn_unit:
dataset:
transform:
# Orig 1/3 - 2/3 split -> bias to spatial
crop_contiguous_probability: 0.25
crop_spatial_probability: 0.75
# Modify: date & clustering parquet
dataset:
# data: ${paths.data.pdb_data_dir}/pn_units_df_structure_clustered_tm_0.8_singletons_removed_len_lt_25_removed.parquet
# data: ${paths.data.pdb_data_dir}/pn_units_df_structure_clustered_tm_0.8.parquet
# data: ${paths.data.pdb_data_dir}/pn_units_df_train_structure_clustered_2021_cutoff_tm_0.8.parquet
# data: ${paths.data.pdb_data_dir}/pn_units_df.parquet
# data: /projects/ml/datahub/dfs/af3_splits/2024_12_16/pn_units_df_train.parquet
filters:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# pn_unit specific filters
- "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
interface:
dataset:
# Modify: date & clustering parquet
dataset:
# data: ${paths.data.pdb_data_dir}/interfaces_df_structure_clustered_tm_0.8_singletons_removed_len_lt_25_removed.parquet
# data: ${paths.data.pdb_data_dir}/interfaces_df_structure_clustered_tm_0.8.parquet
# data: ${paths.data.pdb_data_dir}/interfaces_df_train_structure_clustered_2021_cutoff_tm_0.8.parquet
# data: ${paths.data.pdb_data_dir}/interfaces_df.parquet
# data: /projects/ml/datahub/dfs/af3_splits/2024_12_16/interfaces_df_train.parquet
filters:
# filters common across all PDB datasets
- 'pdb_id not in ["7rte", "7m5w", "7n5u"]'
- 'pdb_id not in ["3di3", "5o45", "1z92", "2gy5", "4zxb"]'
- "deposition_date < '2024-12-16'"
- "resolution < 9.0"
- "num_polymer_pn_units <= 300"
- "cluster.notnull()"
# interface specific filters
- "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${resolve_import:atomworks.constants,AF3_EXCLUDED_LIGANDS_REGEX}', regex=True))"
- "is_inter_molecule"
monomer_distillation:
probability: 0.70
dataset:
transform:
# Orig 1/4 - 3/4 split -> bias to spatial
crop_contiguous_probability: 0.25
crop_spatial_probability: 0.75
# Added in third stage of training:
interdomain_distillation:
probability: 0.10
dna_binder: # 6 : 4 : 1
probability: 0.0
na_complex_distillation:
probability: 0.0
free_dna_dds:
probability: 0.0
trainer:
error_if_grad_nonfinite: False
n_recycles_train: 2
prevalidate: False
validate_every_n_epochs: 4
precision: bf16-mixed
loss:
verbose_diffusion_loss: null
simple_diffusion_loss:
_target_: rfd3.metrics.losses.SimpleDiffusionLoss
sigma_data: ${model.net.diffusion_module.sigma_data}
weight: 4.0
lddt_weight: 0.25
alpha_virtual_atom: 1.0
alpha_polar_residues: 1.0
lp_weight: 0.0
unindexed_norm_p: 1.0
alpha_unindexed_diffused: 1.0
unindexed_t_alpha: 0.75
normalize_virtual_atom_weight: False
alpha_ligand: 10.0
# callbacks:
# activations_tracking_callback:
# _target_: modelhub.callbacks.health_logging.ActivationsGradientsWeightsTracker
# log_freq: 100
# keep_cache: True # --> WARNING: Do not run this in a production run, this will lead to a memory leak! Meant for debugging.
# activations_tracking_callback:
# _target_: modelhub.callbacks.health_logging.ActivationsGradientsWeightsTracker
# log_freq: 100
# keep_cache: True # --> WARNING: Do not run this in a production run, this will lead to a memory leak! Meant for debugging.

View File

@@ -0,0 +1,10 @@
# @package _global_
defaults:
- _self_
datasets:
global_transform_args:
train_conditions:
unconditional:
frequency: 9999999999

View File

@@ -0,0 +1,14 @@
# @package _global_
defaults:
- override /datasets: unindexed
- _self_
name: train-unindexed-main
datasets:
train_conditions:
tipatom:
p_unindex_motif_tokens: 1.0
island:
p_unindex_motif_tokens: 1.0

View File

@@ -0,0 +1,31 @@
# @package _global_
defaults:
- uncond-main
# - override /trainer: aa_design_with_sequence # This has been absorbed into the default aa_design config
# HACK: This dataset config no longer exists. I am simply removing it to overwrite with the ppi dataset
# - override /datasets: uncond_index_mix
- override /model: compact-base
- _self_
name: uncond-compact
trainer:
loss:
verbose_diffusion_loss:
alpha_motif: 1.0
alpha_ca_atom: 1.0
alpha_virtual_atom: 1.0
alpha_fixed_motif: 1.0
align_prediction: False
lddt_weight: 0.50
use_motif_aligned_loss: False
sequence_head_loss:
weight: 0.5
max_t: 2
datasets:
diffusion_batch_size_train: 32
diffusion_batch_size_inference: 10
crop_size: 256

View File

@@ -0,0 +1,53 @@
# @package _global_
defaults:
# - /debug/default
# - override /logger: csv # turn off wandb
# - override /datasets: indexed_conditioned
# - override /trainer: aa_design_with_sequence # This has been absorbed into the default aa_design config
- override /model: unet-recycle
- _self_
name: main
tags:
- atom14
- unet
- uncond
- ${datasets.global_transform_args.central_atom}
trainer:
prevalidate: False
skip_optimizer_loading: False
loss:
verbose_diffusion_loss:
alpha_ca_atom: 1.0
sequence_head_loss:
weight: 0.5
max_t: 2
ckpt_path: null
# validation settings:
callbacks:
dump_validation_structures_callback:
dump_predictions: True
one_model_per_file: True
dump_trajectories: False
dump_every_n: 1
datasets:
diffusion_batch_size_inference: 10
model:
net:
diffusion_module:
use_sequence_head: True
# c_token: 384
# diffusion_transformer:
# n_block: 12
# inference_sampler:
# num_timesteps: 200
# noise_scale: 0.75
# step_scale: 1.5
# num_timesteps: 100

View File

@@ -0,0 +1,18 @@
# https://hydra.cc/docs/configure_hydra/intro/
# enable color logging (requires `colorlog` to be installed)
# defaults:
# - override hydra_logging: colorlog
# - override job_logging: colorlog
# output directory, generated dynamically on each run
run:
dir: ${paths.log_dir}/${task_name}/${name}/${now:%Y-%m-%d}_${now:%H-%M}_JOB_${oc.env:SLURM_JOB_ID,default}
# ... this is where the log file is written (i.e. the program's output)
job_logging:
handlers:
file:
# Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
filename: ${hydra.runtime.output_dir}/experiment.log
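
For orientation, the `run.dir` pattern above can be approximated in plain Python. This is only a sketch: the function name `build_run_dir` and the example values are hypothetical, and in a real run Hydra and OmegaConf resolve `${now:...}` and `${oc.env:...}` themselves.

```python
import os
from datetime import datetime
from pathlib import PurePosixPath


def build_run_dir(log_dir, task_name, name, now=None, env=None):
    """Approximate ${paths.log_dir}/${task_name}/${name}/<date>_<time>_JOB_<job id>."""
    now = now or datetime.now()
    env = os.environ if env is None else env
    # ${oc.env:SLURM_JOB_ID,default} falls back to the literal string "default"
    job_id = env.get("SLURM_JOB_ID", "default")
    # Mirrors ${now:%Y-%m-%d}_${now:%H-%M}_JOB_<job id>
    stamp = f"{now:%Y-%m-%d}_{now:%H-%M}_JOB_{job_id}"
    return str(PurePosixPath(log_dir) / task_name / name / stamp)


# Hypothetical values for paths.log_dir, task_name and name
example = build_run_dir(
    "logs", "train", "main",
    now=datetime(2025, 11, 11, 10, 7),
    env={},  # no SLURM_JOB_ID set, so the fallback is used
)
print(example)
```

Outside of a SLURM job, every run directory ends in `_JOB_default`, so two runs started within the same minute would share a directory.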

View File

@@ -0,0 +1,7 @@
defaults:
- override job_logging: disabled
- override hydra_logging: disabled
output_subdir: null
run:
dir: .

View File

@@ -0,0 +1,9 @@
# @package _global_
hydra:
searchpath:
- pkg://configs
defaults:
- inference_engine: rfdiffusion3
- _self_

View File

@@ -0,0 +1,16 @@
# @package _global_
defaults:
- /hydra: no_logging
# Parameters for RFD3InferenceEngine.__init__()
ckpt_path: ???
num_nodes: 1
devices_per_node: 1
# Parameters for RFD3InferenceEngine.run()
inputs: ???
out_dir: ???
dump_predictions: true
dump_trajectories: false
one_model_per_file: false

View File

@@ -0,0 +1,20 @@
# @package _global_
defaults:
- rfdiffusion3
- _self_
diffusion_batch_size: 8
n_batches: 1
seed: 42
dump_trajectories: True
print_config: True
skip_existing: False
cleanup_guideposts: False
cleanup_virtual_atoms: False
output_full_json: True
inference_sampler:
gamma_0: 0.0
out_dir: ./logs/benchmark
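
Because `_self_` comes after `rfdiffusion3` in the defaults list, the values in this file take precedence over the inherited ones. The sketch below imitates that precedence with plain dictionaries. It is not Hydra's actual composition code: `merge_config` is a hypothetical helper, and it assumes that `rfdiffusion3` refers to the default inference config that follows in this diff.

```python
from copy import deepcopy


def merge_config(base, override):
    """Recursively merge `override` into a copy of `base`; override wins on conflicts."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested mappings are merged key by key rather than replaced wholesale
            merged[key] = merge_config(merged[key], value)
        else:
            merged[key] = deepcopy(value)
    return merged


# Subset of the assumed default inference config
rfdiffusion3 = {
    "seed": None,
    "skip_existing": True,
    "dump_trajectories": False,
    "inference_sampler": {"gamma_0": 0.6, "num_timesteps": 200},
}
# Subset of this file, which lists `rfdiffusion3` before `_self_`
benchmark = {
    "seed": 42,
    "skip_existing": False,
    "dump_trajectories": True,
    "inference_sampler": {"gamma_0": 0.0},
}

composed = merge_config(rfdiffusion3, benchmark)
print(composed)
```

Only `gamma_0` changes inside `inference_sampler`; sibling keys such as `num_timesteps` keep their inherited values.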

View File

@@ -0,0 +1,78 @@
# @package _global_
defaults:
- base
- _self_
_target_: rfd3.inference.engine.RFD3InferenceEngine
out_dir: ???
inputs: ??? # null, json, pdb or
ckpt_path: /projects/ml/aa_design/models/rfd3_latest.ckpt
json_keys_subset: null
skip_existing: True
seed: null # if null, a seed integer is sampled based on the timestamp
#########################################################
# Design spec args: overrides args from input json
specification: {}
#########################################################
# Diffusion args
diffusion_batch_size: 8
n_batches: 1
# Inference sampler args | set to None to use the default in the checkpoint's config
inference_sampler:
kind: "default" # "default" or "symmetry" to choose the sampler
# Classifier-free guidance args:
cfg_features: # set to 0 in the reference CFG step
- active_donor
- active_acceptor
- ref_atomwise_rasa
use_classifier_free_guidance: False
cfg_t_max: null # max t to apply cfg guidance
cfg_scale: 1.5
center_option: "all" # Options are ["all", "motif", "diffuse"]
move_noise_to_reset_com: False # Reset the COM of the diffuse region after the re-noising operation in each diffusion step
s_trans: 1.0 # Translational noise scale for augmentation during inference
fraction_of_steps_to_fix_motif: 0.0 # Fraction of diffusion steps during which the model does not move the motif, e.g. with 10 steps, a value of 0.2 keeps the motif fixed for the last 2 steps.
skip_few_diffusion_steps: False # Choose to skip some diffusion steps based on the noise scheme
inference_noise_scaling_factor: 1.0
allow_realignment: False
zero_drift_noise: False
use_frame_guidance: False
# Diffusion args:
num_timesteps: 200
step_scale: 1.5 # 1.5 - 1.0 | Higher values lead to less diverse, more designable structures
noise_scale: 1.003
p: 7
gamma_0: 0.6 # Previously 1.0 | 0.0 for ODE sampling
gamma_min: 1.0
gamma_min2: 0.0
s_jitter_origin: 0.0 # Sigma of gaussian noise to jitter the motif offset (equivalent to ORI token Jitter)
# Saving args
cleanup_guideposts: True
cleanup_virtual_atoms: True
read_sequence_from_sequence_head: True
output_full_json: True
# Prefix to add to all output samples
# Default: None -> f'{jsonfilebasename}_{jsonkey}_{batch}_{model}'
# Otherwise: string -> f'{string}{jsonkey}_{batch}_{model}'
# e.g. Empty string -> f'{jsonkey}_{batch}_{model}'
# e.g. Chunk string -> f'{chunkprefix_}{jsonkey}_{batch}_{model}' (pipelines usage)
global_prefix: null
dump_prediction_metadata_json: True
dump_trajectories: False
one_model_per_file: True
align_trajectory_structures: False
# Additional args
print_config: False
prevalidate_inputs: True
low_memory_mode: False # False for standard mode, True for memory-efficient tokenization mode
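
The naming rules described in the `global_prefix` comments above can be sketched as a small function. This illustrates the comments only and is not the engine's actual code: `sample_name` and its parameter names are hypothetical, as are the example values.

```python
from typing import Optional


def sample_name(
    global_prefix: Optional[str],
    json_file_basename: str,
    json_key: str,
    batch: int,
    model: int,
) -> str:
    """Build an output sample name following the `global_prefix` comments."""
    suffix = f"{json_key}_{batch}_{model}"
    if global_prefix is None:
        # Default: prefix with the input JSON's base name
        return f"{json_file_basename}_{suffix}"
    # Any string (including "") is prepended verbatim, with no separator added
    return f"{global_prefix}{suffix}"


print(sample_name(None, "designs", "binder", 0, 3))       # designs_binder_0_3
print(sample_name("", "designs", "binder", 0, 3))         # binder_0_3
print(sample_name("chunk7_", "designs", "binder", 0, 3))  # chunk7_binder_0_3
```

A non-null prefix is concatenated without a separator, so a pipeline that wants one has to include the trailing underscore itself, as in the chunk example.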

View File

@@ -0,0 +1,6 @@
# https://lightning.ai/docs/fabric/latest/api/generated/lightning.fabric.loggers.CSVLogger.html#lightning.fabric.loggers.CSVLogger
csv:
_target_: lightning.fabric.loggers.CSVLogger
root_dir: ${paths.output_dir}
flush_logs_every_n_steps: 1

View File

@@ -0,0 +1,3 @@
defaults:
- wandb
- csv

Some files were not shown because too many files have changed in this diff