mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Reimplement plotting validation distances at timestep t to dynamically use feature names
This commit is contained in:
@@ -58,24 +58,32 @@ def plot_val_dists_at_t(
|
||||
assert vals["t"].item() == t, f"Unexpected values of t: {vals['t']} != {t}"
|
||||
retval.append(select_by_attn(vals))
|
||||
vals_flat = torch.vstack(retval).numpy()
|
||||
assert len(vals_flat.shape) == 2
|
||||
assert vals_flat.ndim == 2
|
||||
|
||||
ft_names = dset.feature_names["angles"]
|
||||
n_fts = len(ft_names)
|
||||
assert vals_flat.shape[1] == n_fts
|
||||
|
||||
fig, axes = plt.subplots(
|
||||
2, 2, sharex=share_axes, sharey=share_axes, dpi=300, figsize=(9, 7)
|
||||
nrows=1,
|
||||
ncols=n_fts,
|
||||
sharex=share_axes,
|
||||
sharey=share_axes,
|
||||
dpi=300,
|
||||
figsize=(2.6 * n_fts, 2.5),
|
||||
)
|
||||
for i, ax in enumerate(axes.flatten()):
|
||||
val_name = ["dist", "omega", "theta", "phi"][i]
|
||||
for i, (ax, ft_name) in enumerate(zip(axes, ft_names)):
|
||||
# Plot the values
|
||||
vals = vals_flat[:, i]
|
||||
sns.histplot(vals, ax=ax)
|
||||
if val_name != "dist":
|
||||
if "dist" not in ft_name:
|
||||
if zero_center_angles:
|
||||
ax.axvline(np.pi, color="tab:orange")
|
||||
ax.axvline(-np.pi, color="tab:orange")
|
||||
else:
|
||||
ax.axvline(0, color="tab:orange")
|
||||
ax.axvline(2 * np.pi, color="tab:orange")
|
||||
ax.set(title=f"Timestep {t} - {val_name}")
|
||||
ax.set(title=f"Timestep {t} - {ft_name}")
|
||||
if fname is not None:
|
||||
fig.savefig(fname, bbox_inches="tight")
|
||||
return fig
|
||||
|
||||
Reference in New Issue
Block a user