diff --git a/plots/variance_schedules/cosine_var_schedule.pdf b/plots/variance_schedules/cosine_var_schedule.pdf index a3d9f22..3133894 100644 Binary files a/plots/variance_schedules/cosine_var_schedule.pdf and b/plots/variance_schedules/cosine_var_schedule.pdf differ diff --git a/plots/variance_schedules/linear_var_schedule.pdf b/plots/variance_schedules/linear_var_schedule.pdf index 2daa588..70562cb 100644 Binary files a/plots/variance_schedules/linear_var_schedule.pdf and b/plots/variance_schedules/linear_var_schedule.pdf differ diff --git a/plots/variance_schedules/quadratic_var_schedule.pdf b/plots/variance_schedules/quadratic_var_schedule.pdf index 8469237..b148fdf 100644 Binary files a/plots/variance_schedules/quadratic_var_schedule.pdf and b/plots/variance_schedules/quadratic_var_schedule.pdf differ diff --git a/protdiff/beta_schedules.py b/protdiff/beta_schedules.py index ac3a9e5..0918e30 100644 --- a/protdiff/beta_schedules.py +++ b/protdiff/beta_schedules.py @@ -86,10 +86,11 @@ def plot_variance_schedule( logging.info( f"Plotting {keyword} variance schedule with {timesteps} timesteps, ranging from {np.min(variance_vals)}-{np.max(variance_vals)}" ) - + alpha_beta_vals = compute_alphas(variance_vals) fig, ax = plt.subplots(dpi=300) - ax.plot(np.arange(timesteps), variance_vals) - fig.savefig(fname) + for k, v in alpha_beta_vals.items(): + ax.plot(np.arange(timesteps), v.numpy(), label=k, alpha=0.7) + fig.savefig(fname + "_alpha_beta.pdf") if __name__ == "__main__":