Support for gradient clipping

This commit is contained in:
Kevin Wu
2022-07-08 19:31:37 +00:00
parent b4f9033e78
commit 510df0935b
2 changed files with 6 additions and 0 deletions

View File

@@ -72,6 +72,7 @@ def train(
results_dir: str = "./results",
timesteps: int = 1000,
variance_schedule: SCHEDULES = "linear",
gradient_clip: float = 0.0,
batch_size: int = 128,
lr: float = 1e-4,
epochs: int = 200,
@@ -112,6 +113,7 @@ def train(
trainer = pl.Trainer(
default_root_dir=results_folder,
gradient_clip_val=gradient_clip,
max_epochs=epochs,
check_val_every_n_epoch=1,
callbacks=[

View File

@@ -17,5 +17,9 @@
"lr": [
1e-3,
1e-4
],
"gradient_clip": [
0.0,
0.5
]
}