Add plots of variance schedules

This commit is contained in:
Kevin Wu
2022-07-13 22:40:02 +00:00
parent 7b6361c759
commit e6200e3975
5 changed files with 42 additions and 1 deletions

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,7 +1,12 @@
"""
Describe beta schedules
"""
from typing import Literal
import os
from typing import Literal, get_args
import numpy as np
from matplotlib import pyplot as plt
import torch
SCHEDULES = Literal["linear", "cosine", "quadratic"]
@@ -45,3 +50,30 @@ def get_variance_schedule(keyword: SCHEDULES, timesteps: int, **kwargs) -> torch
return quadratic_beta_schedule(timesteps, **kwargs)
else:
raise ValueError(f"Unrecognized variance schedule: {keyword}")
def plot_variance_schedule(
fname: str, keyword: SCHEDULES, timesteps: int = 1000, **kwargs
):
"""
Plot the given variance schedule
"""
variance_vals = get_variance_schedule(
keyword=keyword, timesteps=timesteps, **kwargs
).numpy()
fig, ax = plt.subplots(dpi=300)
ax.plot(np.arange(timesteps), variance_vals)
fig.savefig(fname)
if __name__ == "__main__":
from plotting import PLOT_DIR
var_plot_dir = os.path.join(PLOT_DIR, "variance_schedules")
if not os.path.isdir(var_plot_dir):
os.makedirs(var_plot_dir)
for s in get_args(SCHEDULES):
plot_variance_schedule(
os.path.join(var_plot_dir, f"{s}_var_schedule.pdf"), keyword=s
)

9
protdiff/plotting.py Normal file
View File

@@ -0,0 +1,9 @@
"""
Utility functions for plotting
"""
import os
PLOT_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "plots")
assert os.path.isdir(PLOT_DIR)