diff --git a/bin/sample.py b/bin/sample.py index d836dec..f417e10 100644 --- a/bin/sample.py +++ b/bin/sample.py @@ -93,9 +93,13 @@ def plot_distribution_overlap( fname: str = "", ax=None, show_legend: bool = True, + **kwargs, ): """ Plot the distribution overlap between the training and sampled values + Additional arguments are given to ax.hist; for example, can specify + histtype='step', cumulative=True + to get a CDF plot """ # Plot the distribution overlap logging.info(f"Plotting distribution overlap for {ft_name}") @@ -106,18 +110,18 @@ def plot_distribution_overlap( bins=50, density=True, label="Training", - color="tab:blue", + # color="tab:blue", alpha=0.6, - edgecolor="black", + **kwargs, ) ax.hist( sampled_values, bins=bins, density=True, label="Sampled", - color="tab:orange", + # color="tab:orange", alpha=0.6, - edgecolor="black", + **kwargs, ) ax.set(title=f"Sampled distribution - {ft_name}") if show_legend: @@ -267,9 +271,12 @@ def main() -> None: model, train_dset, n=10, - sweep_lengths=(50, train_dset.dset.pad), + # sweep_lengths=(50, train_dset.dset.pad), + sweep_lengths=(50, 52), # Dummy values batch_size=args.batchsize, ) + else: + raise ValueError(f"Unrecognized length mode: {args.lengths}") final_sampled = [s[-1] for s in sampled] sampled_dfs = [ pd.DataFrame(s, columns=train_dset.feature_names["angles"]) @@ -316,6 +323,9 @@ def main() -> None: multi_fig, multi_axes = plt.subplots( dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True ) + step_multi_fig, step_multi_axes = plt.subplots( + dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True + ) final_sampled_stacked = np.vstack(final_sampled) for i, ft_name in enumerate(train_dset.feature_names["angles"]): orig_values = train_values_stacked[:, i] @@ -323,6 +333,14 @@ def main() -> None: plot_distribution_overlap( orig_values, samp_values, ft_name, fname=plotdir / f"dist_{ft_name}.pdf" ) + plot_distribution_overlap( + orig_values, + samp_values, + ft_name, + histtype="step", + cumulative=True, + fname=plotdir / f"cdf_{ft_name}.pdf", + ) plot_distribution_overlap( orig_values, samp_values, @@ -330,7 +348,17 @@ def main() -> None: ax=multi_axes.flatten()[i], show_legend=i == 0, ) + plot_distribution_overlap( + orig_values, + samp_values, + ft_name, + cumulative=True, + histtype="step", + ax=step_multi_axes.flatten()[i], + show_legend=i == 0, + ) multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight") + step_multi_fig.savefig(plotdir / "cdf_combined.pdf", bbox_inches="tight") # Generate ramachandran plot for sampled angles plotting.plot_joint_kde(