mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Support for gradient clipping
This commit is contained in:
@@ -72,6 +72,7 @@ def train(
|
||||
results_dir: str = "./results",
|
||||
timesteps: int = 1000,
|
||||
variance_schedule: SCHEDULES = "linear",
|
||||
gradient_clip: float = 0.0,
|
||||
batch_size: int = 128,
|
||||
lr: float = 1e-4,
|
||||
epochs: int = 200,
|
||||
@@ -112,6 +113,7 @@ def train(
|
||||
|
||||
trainer = pl.Trainer(
|
||||
default_root_dir=results_folder,
|
||||
gradient_clip_val=gradient_clip,
|
||||
max_epochs=epochs,
|
||||
check_val_every_n_epoch=1,
|
||||
callbacks=[
|
||||
|
||||
@@ -17,5 +17,9 @@
|
||||
"lr": [
|
||||
1e-3,
|
||||
1e-4
|
||||
],
|
||||
"gradient_clip": [
|
||||
0.0,
|
||||
0.5
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user