mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Add plots of variance schedules
This commit is contained in:
BIN
plots/variance_schedules/cosine_var_schedule.pdf
Normal file
BIN
plots/variance_schedules/cosine_var_schedule.pdf
Normal file
Binary file not shown.
BIN
plots/variance_schedules/linear_var_schedule.pdf
Normal file
BIN
plots/variance_schedules/linear_var_schedule.pdf
Normal file
Binary file not shown.
BIN
plots/variance_schedules/quadratic_var_schedule.pdf
Normal file
BIN
plots/variance_schedules/quadratic_var_schedule.pdf
Normal file
Binary file not shown.
@@ -1,7 +1,12 @@
|
||||
"""
|
||||
Describe beta schedules
|
||||
"""
|
||||
from typing import Literal
|
||||
import os
|
||||
from typing import Literal, get_args
|
||||
|
||||
import numpy as np
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import torch
|
||||
|
||||
SCHEDULES = Literal["linear", "cosine", "quadratic"]
|
||||
@@ -45,3 +50,30 @@ def get_variance_schedule(keyword: SCHEDULES, timesteps: int, **kwargs) -> torch
|
||||
return quadratic_beta_schedule(timesteps, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized variance schedule: {keyword}")
|
||||
|
||||
|
||||
def plot_variance_schedule(
|
||||
fname: str, keyword: SCHEDULES, timesteps: int = 1000, **kwargs
|
||||
):
|
||||
"""
|
||||
Plot the given variance schedule
|
||||
"""
|
||||
variance_vals = get_variance_schedule(
|
||||
keyword=keyword, timesteps=timesteps, **kwargs
|
||||
).numpy()
|
||||
|
||||
fig, ax = plt.subplots(dpi=300)
|
||||
ax.plot(np.arange(timesteps), variance_vals)
|
||||
fig.savefig(fname)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from plotting import PLOT_DIR
|
||||
|
||||
var_plot_dir = os.path.join(PLOT_DIR, "variance_schedules")
|
||||
if not os.path.isdir(var_plot_dir):
|
||||
os.makedirs(var_plot_dir)
|
||||
for s in get_args(SCHEDULES):
|
||||
plot_variance_schedule(
|
||||
os.path.join(var_plot_dir, f"{s}_var_schedule.pdf"), keyword=s
|
||||
)
|
||||
|
||||
9
protdiff/plotting.py
Normal file
9
protdiff/plotting.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Utility functions for plotting
|
||||
"""
|
||||
|
||||
|
||||
import os
|
||||
|
||||
PLOT_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "plots")
|
||||
assert os.path.isdir(PLOT_DIR)
|
||||
Reference in New Issue
Block a user