mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
More plotting, including CDF
This commit is contained in:
@@ -93,9 +93,13 @@ def plot_distribution_overlap(
|
||||
fname: str = "",
|
||||
ax=None,
|
||||
show_legend: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Plot the distribution overlap between the training and sampled values
|
||||
Additional arguments are given to ax.hist; for example, can specify
|
||||
histtype='step', cumulative=True
|
||||
to get a CDF plot
|
||||
"""
|
||||
# Plot the distribution overlap
|
||||
logging.info(f"Plotting distribution overlap for {ft_name}")
|
||||
@@ -106,18 +110,18 @@ def plot_distribution_overlap(
|
||||
bins=50,
|
||||
density=True,
|
||||
label="Training",
|
||||
color="tab:blue",
|
||||
# color="tab:blue",
|
||||
alpha=0.6,
|
||||
edgecolor="black",
|
||||
**kwargs,
|
||||
)
|
||||
ax.hist(
|
||||
sampled_values,
|
||||
bins=bins,
|
||||
density=True,
|
||||
label="Sampled",
|
||||
color="tab:orange",
|
||||
# color="tab:orange",
|
||||
alpha=0.6,
|
||||
edgecolor="black",
|
||||
**kwargs,
|
||||
)
|
||||
ax.set(title=f"Sampled distribution - {ft_name}")
|
||||
if show_legend:
|
||||
@@ -267,9 +271,12 @@ def main() -> None:
|
||||
model,
|
||||
train_dset,
|
||||
n=10,
|
||||
sweep_lengths=(50, train_dset.dset.pad),
|
||||
# sweep_lengths=(50, train_dset.dset.pad),
|
||||
sweep_lengths=(50, 52), # Dummy values
|
||||
batch_size=args.batchsize,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unrecognized length mode: {args.lengths}")
|
||||
final_sampled = [s[-1] for s in sampled]
|
||||
sampled_dfs = [
|
||||
pd.DataFrame(s, columns=train_dset.feature_names["angles"])
|
||||
@@ -316,6 +323,9 @@ def main() -> None:
|
||||
multi_fig, multi_axes = plt.subplots(
|
||||
dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
|
||||
)
|
||||
step_multi_fig, step_multi_axes = plt.subplots(
|
||||
dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
|
||||
)
|
||||
final_sampled_stacked = np.vstack(final_sampled)
|
||||
for i, ft_name in enumerate(train_dset.feature_names["angles"]):
|
||||
orig_values = train_values_stacked[:, i]
|
||||
@@ -323,6 +333,14 @@ def main() -> None:
|
||||
plot_distribution_overlap(
|
||||
orig_values, samp_values, ft_name, fname=plotdir / f"dist_{ft_name}.pdf"
|
||||
)
|
||||
plot_distribution_overlap(
|
||||
orig_values,
|
||||
samp_values,
|
||||
ft_name,
|
||||
histtype="step",
|
||||
cumulative=True,
|
||||
fname=plotdir / f"cdf_{ft_name}.pdf",
|
||||
)
|
||||
plot_distribution_overlap(
|
||||
orig_values,
|
||||
samp_values,
|
||||
@@ -330,7 +348,17 @@ def main() -> None:
|
||||
ax=multi_axes.flatten()[i],
|
||||
show_legend=i == 0,
|
||||
)
|
||||
plot_distribution_overlap(
|
||||
orig_values,
|
||||
samp_values,
|
||||
ft_name,
|
||||
cumulative=True,
|
||||
histtype="step",
|
||||
ax=step_multi_axes.flatten()[i],
|
||||
show_legend=i == 0,
|
||||
)
|
||||
multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight")
|
||||
step_multi_fig.savefig(plotdir / "cdf_combined.pdf", bbox_inches="tight")
|
||||
|
||||
# Generate ramachandran plot for sampled angles
|
||||
plotting.plot_joint_kde(
|
||||
|
||||
Reference in New Issue
Block a user