diff --git a/plots/variance_schedules/cosine_var_schedule.pdf b/plots/variance_schedules/cosine_var_schedule.pdf new file mode 100644 index 0000000..a3d9f22 Binary files /dev/null and b/plots/variance_schedules/cosine_var_schedule.pdf differ diff --git a/plots/variance_schedules/linear_var_schedule.pdf b/plots/variance_schedules/linear_var_schedule.pdf new file mode 100644 index 0000000..2daa588 Binary files /dev/null and b/plots/variance_schedules/linear_var_schedule.pdf differ diff --git a/plots/variance_schedules/quadratic_var_schedule.pdf b/plots/variance_schedules/quadratic_var_schedule.pdf new file mode 100644 index 0000000..8469237 Binary files /dev/null and b/plots/variance_schedules/quadratic_var_schedule.pdf differ diff --git a/protdiff/beta_schedules.py b/protdiff/beta_schedules.py index 9cfe1a0..8a5eff0 100644 --- a/protdiff/beta_schedules.py +++ b/protdiff/beta_schedules.py @@ -1,7 +1,12 @@ """ Describe beta schedules """ -from typing import Literal +import os +from typing import Literal, get_args + +import numpy as np +from matplotlib import pyplot as plt + import torch SCHEDULES = Literal["linear", "cosine", "quadratic"] @@ -45,3 +50,30 @@ def get_variance_schedule(keyword: SCHEDULES, timesteps: int, **kwargs) -> torch return quadratic_beta_schedule(timesteps, **kwargs) else: raise ValueError(f"Unrecognized variance schedule: {keyword}") + + +def plot_variance_schedule( + fname: str, keyword: SCHEDULES, timesteps: int = 1000, **kwargs +): + """ + Plot the given variance schedule + """ + variance_vals = get_variance_schedule( + keyword=keyword, timesteps=timesteps, **kwargs + ).numpy() + + fig, ax = plt.subplots(dpi=300) + ax.plot(np.arange(timesteps), variance_vals) + fig.savefig(fname) + + +if __name__ == "__main__": + from plotting import PLOT_DIR + + var_plot_dir = os.path.join(PLOT_DIR, "variance_schedules") + if not os.path.isdir(var_plot_dir): + os.makedirs(var_plot_dir) + for s in get_args(SCHEDULES): + plot_variance_schedule( + os.path.join(var_plot_dir, f"{s}_var_schedule.pdf"), keyword=s + ) diff --git a/protdiff/plotting.py b/protdiff/plotting.py new file mode 100644 index 0000000..5475486 --- /dev/null +++ b/protdiff/plotting.py @@ -0,0 +1,9 @@ +""" +Utility functions for plotting +""" + + +import os + +PLOT_DIR = os.path.join(os.path.dirname(os.path.dirname(__file__)), "plots") +assert os.path.isdir(PLOT_DIR)