Reimplement plotting validation distances at timestep t to dynamically use feature names

This commit is contained in:
Kevin Wu
2022-08-28 12:50:44 -07:00
parent 97c7bde6d6
commit bb3d30ce57

View File

@@ -58,24 +58,32 @@ def plot_val_dists_at_t(
assert vals["t"].item() == t, f"Unexpected values of t: {vals['t']} != {t}"
retval.append(select_by_attn(vals))
vals_flat = torch.vstack(retval).numpy()
assert len(vals_flat.shape) == 2
assert vals_flat.ndim == 2
ft_names = dset.feature_names["angles"]
n_fts = len(ft_names)
assert vals_flat.shape[1] == n_fts
fig, axes = plt.subplots(
2, 2, sharex=share_axes, sharey=share_axes, dpi=300, figsize=(9, 7)
nrows=1,
ncols=n_fts,
sharex=share_axes,
sharey=share_axes,
dpi=300,
figsize=(2.6 * n_fts, 2.5),
)
for i, ax in enumerate(axes.flatten()):
val_name = ["dist", "omega", "theta", "phi"][i]
for i, (ax, ft_name) in enumerate(zip(axes, ft_names)):
# Plot the values
vals = vals_flat[:, i]
sns.histplot(vals, ax=ax)
if val_name != "dist":
if "dist" not in ft_name:
if zero_center_angles:
ax.axvline(np.pi, color="tab:orange")
ax.axvline(-np.pi, color="tab:orange")
else:
ax.axvline(0, color="tab:orange")
ax.axvline(2 * np.pi, color="tab:orange")
ax.set(title=f"Timestep {t} - {val_name}")
ax.set(title=f"Timestep {t} - {ft_name}")
if fname is not None:
fig.savefig(fname, bbox_inches="tight")
return fig