More plotting, including CDF

This commit is contained in:
Kevin Wu
2022-09-15 14:30:58 -07:00
parent b494b067d0
commit bead02473a

View File

@@ -93,9 +93,13 @@ def plot_distribution_overlap(
fname: str = "",
ax=None,
show_legend: bool = True,
**kwargs,
):
"""
Plot the distribution overlap between the training and sampled values.
Additional keyword arguments are passed through to ax.hist; for example, specify
histtype='step', cumulative=True
to get a CDF plot.
"""
# Plot the distribution overlap
logging.info(f"Plotting distribution overlap for {ft_name}")
@@ -106,18 +110,18 @@ def plot_distribution_overlap(
bins=50,
density=True,
label="Training",
color="tab:blue",
# color="tab:blue",
alpha=0.6,
edgecolor="black",
**kwargs,
)
ax.hist(
sampled_values,
bins=bins,
density=True,
label="Sampled",
color="tab:orange",
# color="tab:orange",
alpha=0.6,
edgecolor="black",
**kwargs,
)
ax.set(title=f"Sampled distribution - {ft_name}")
if show_legend:
@@ -267,9 +271,12 @@ def main() -> None:
model,
train_dset,
n=10,
sweep_lengths=(50, train_dset.dset.pad),
# sweep_lengths=(50, train_dset.dset.pad),
sweep_lengths=(50, 52), # Dummy values
batch_size=args.batchsize,
)
else:
raise ValueError(f"Unrecognized length mode: {args.lengths}")
final_sampled = [s[-1] for s in sampled]
sampled_dfs = [
pd.DataFrame(s, columns=train_dset.feature_names["angles"])
@@ -316,6 +323,9 @@ def main() -> None:
multi_fig, multi_axes = plt.subplots(
dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
)
step_multi_fig, step_multi_axes = plt.subplots(
dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
)
final_sampled_stacked = np.vstack(final_sampled)
for i, ft_name in enumerate(train_dset.feature_names["angles"]):
orig_values = train_values_stacked[:, i]
@@ -323,6 +333,14 @@ def main() -> None:
plot_distribution_overlap(
orig_values, samp_values, ft_name, fname=plotdir / f"dist_{ft_name}.pdf"
)
plot_distribution_overlap(
orig_values,
samp_values,
ft_name,
histtype="step",
cumulative=True,
fname=plotdir / f"cdf_{ft_name}.pdf",
)
plot_distribution_overlap(
orig_values,
samp_values,
@@ -330,7 +348,17 @@ def main() -> None:
ax=multi_axes.flatten()[i],
show_legend=i == 0,
)
plot_distribution_overlap(
orig_values,
samp_values,
ft_name,
cumulative=True,
histtype="step",
ax=step_multi_axes.flatten()[i],
show_legend=i == 0,
)
multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight")
step_multi_fig.savefig(plotdir / "cdf_combined.pdf", bbox_inches="tight")
# Generate ramachandran plot for sampled angles
plotting.plot_joint_kde(