From a9ced98137d6eaa9e2777530dc1e089a436aefe3 Mon Sep 17 00:00:00 2001 From: jbutch Date: Tue, 2 Dec 2025 02:32:40 -0800 Subject: [PATCH] Fixes to training --- .../configs/callbacks/design_callbacks.yaml | 18 ------------------ .../configs/callbacks/metrics_logging.yaml | 16 ++++++++++++++++ 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/models/rfd3/configs/callbacks/design_callbacks.yaml b/models/rfd3/configs/callbacks/design_callbacks.yaml index ea1a33b..309b492 100644 --- a/models/rfd3/configs/callbacks/design_callbacks.yaml +++ b/models/rfd3/configs/callbacks/design_callbacks.yaml @@ -2,24 +2,6 @@ defaults: - train_logging - _self_ -# Validation metrics: -dump_validation_structures_callback: - _target_: rfd3.trainer.dump_validation_structures.DumpValidationStructuresCallback - save_dir: ${paths.output_dir}/val_structures - dump_predictions: True - dump_prediction_metadata_json: True - dump_trajectories: False - dump_denoised_trajectories_only: False - - one_model_per_file: True - dump_every_n: 4 - align_trajectories: False - verbose: False - -# Other: -log_design_validation_metrics_callback: - _target_: rfd3.callbacks.LogDesignValidationMetricsCallback - log_learning_rate_callback: log_every_n: 25 # default 10 diff --git a/models/rfd3/configs/callbacks/metrics_logging.yaml b/models/rfd3/configs/callbacks/metrics_logging.yaml index c6ed3c2..56d35cd 100644 --- a/models/rfd3/configs/callbacks/metrics_logging.yaml +++ b/models/rfd3/configs/callbacks/metrics_logging.yaml @@ -2,3 +2,19 @@ store_validation_metrics_in_df_callback: _target_: modelhub.callbacks.metrics_logging.StoreValidationMetricsInDFCallback save_dir: ${paths.output_dir}/val_metrics metrics_to_save: "all" + +dump_validation_structures_callback: + _target_: rfd3.trainer.dump_validation_structures.DumpValidationStructuresCallback + save_dir: ${paths.output_dir}/val_structures + dump_predictions: True + dump_prediction_metadata_json: True + dump_trajectories: False + dump_denoised_trajectories_only: False + + one_model_per_file: True + dump_every_n: 4 + align_trajectories: False + verbose: False + +log_design_validation_metrics_callback: + _target_: rfd3.callbacks.LogDesignValidationMetricsCallback