diff --git a/bin/train.py b/bin/train.py index e89826a..5d2ebe5 100644 --- a/bin/train.py +++ b/bin/train.py @@ -72,6 +72,7 @@ def train( results_dir: str = "./results", timesteps: int = 1000, variance_schedule: SCHEDULES = "linear", + gradient_clip: float = 0.0, batch_size: int = 128, lr: float = 1e-4, epochs: int = 200, @@ -112,6 +113,7 @@ def train( trainer = pl.Trainer( default_root_dir=results_folder, + gradient_clip_val=gradient_clip, max_epochs=epochs, check_val_every_n_epoch=1, callbacks=[ diff --git a/hyperparam_jsons/bert_angle_diffusion_configs.json b/hyperparam_jsons/bert_angle_diffusion_configs.json index 19776a4..211260e 100644 --- a/hyperparam_jsons/bert_angle_diffusion_configs.json +++ b/hyperparam_jsons/bert_angle_diffusion_configs.json @@ -17,5 +17,9 @@ "lr": [ 1e-3, 1e-4 + ], + "gradient_clip": [ + 0.0, + 0.5 ] } \ No newline at end of file