From bb3d30ce57fbd697c72be2523c1b8df721f3d9de Mon Sep 17 00:00:00 2001 From: Kevin Wu Date: Sun, 28 Aug 2022 12:50:44 -0700 Subject: [PATCH] Reimplement plotting validation distances at timestep t to dynamically use feature names --- protdiff/plotting.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/protdiff/plotting.py b/protdiff/plotting.py index 6b7126d..d54e84d 100644 --- a/protdiff/plotting.py +++ b/protdiff/plotting.py @@ -58,24 +58,32 @@ def plot_val_dists_at_t( assert vals["t"].item() == t, f"Unexpected values of t: {vals['t']} != {t}" retval.append(select_by_attn(vals)) vals_flat = torch.vstack(retval).numpy() - assert len(vals_flat.shape) == 2 + assert vals_flat.ndim == 2 + + ft_names = dset.feature_names["angles"] + n_fts = len(ft_names) + assert vals_flat.shape[1] == n_fts fig, axes = plt.subplots( - 2, 2, sharex=share_axes, sharey=share_axes, dpi=300, figsize=(9, 7) + nrows=1, + ncols=n_fts, + sharex=share_axes, + sharey=share_axes, + dpi=300, + figsize=(2.6 * n_fts, 2.5), ) - for i, ax in enumerate(axes.flatten()): - val_name = ["dist", "omega", "theta", "phi"][i] + for i, (ax, ft_name) in enumerate(zip(axes, ft_names)): # Plot the values vals = vals_flat[:, i] sns.histplot(vals, ax=ax) - if val_name != "dist": + if "dist" not in ft_name: if zero_center_angles: ax.axvline(np.pi, color="tab:orange") ax.axvline(-np.pi, color="tab:orange") else: ax.axvline(0, color="tab:orange") ax.axvline(2 * np.pi, color="tab:orange") - ax.set(title=f"Timestep {t} - {val_name}") + ax.set(title=f"Timestep {t} - {ft_name}") if fname is not None: fig.savefig(fname, bbox_inches="tight") return fig