mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Code to generate a combo plot for generated angles
This commit is contained in:
@@ -11,7 +11,6 @@ from typing import *
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib import pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
import torch
|
||||
|
||||
@@ -92,6 +91,7 @@ def plot_distribution_overlap(
|
||||
ft_name: str,
|
||||
fname: str = "",
|
||||
ax=None,
|
||||
show_legend: bool = True,
|
||||
):
|
||||
"""
|
||||
Plot the distribution overlap between the training and sampled values
|
||||
@@ -102,7 +102,7 @@ def plot_distribution_overlap(
|
||||
fig, ax = plt.subplots(dpi=300)
|
||||
_n, bins, _pbatches = ax.hist(
|
||||
train_values,
|
||||
bins=40,
|
||||
bins=50,
|
||||
density=True,
|
||||
label="Training",
|
||||
color="tab:blue",
|
||||
@@ -119,7 +119,8 @@ def plot_distribution_overlap(
|
||||
edgecolor="black",
|
||||
)
|
||||
ax.set(title=f"Sampled distribution - {ft_name}")
|
||||
ax.legend()
|
||||
if show_legend:
|
||||
ax.legend()
|
||||
if fname:
|
||||
fig.savefig(fname, bbox_inches="tight")
|
||||
|
||||
@@ -247,8 +248,11 @@ def main() -> None:
|
||||
snapshot_dfs, ith_pdb_dir, basename_prefix=f"generated_{i}_timestep_"
|
||||
)
|
||||
|
||||
# Generate histograms of sampled angles
|
||||
# Generate histograms of sampled angles -- separate plots, and a combined plot
|
||||
# For calculating angle distributions
|
||||
multi_fig, multi_axes = plt.subplots(
|
||||
dpi=300, nrows=2, ncols=3, figsize=(16, 7), sharex=True
|
||||
)
|
||||
final_sampled_stacked = np.vstack(final_sampled)
|
||||
for i, ft_name in enumerate(train_dset.feature_names["angles"]):
|
||||
orig_values = train_values_stacked[:, i]
|
||||
@@ -256,6 +260,14 @@ def main() -> None:
|
||||
plot_distribution_overlap(
|
||||
orig_values, samp_values, ft_name, fname=plotdir / f"dist_{ft_name}.pdf"
|
||||
)
|
||||
plot_distribution_overlap(
|
||||
orig_values,
|
||||
samp_values,
|
||||
ft_name,
|
||||
ax=multi_axes.flatten()[i],
|
||||
show_legend=i == 0,
|
||||
)
|
||||
multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight")
|
||||
|
||||
# Generate ramachandran plot for sampled angles
|
||||
plotting.plot_joint_kde(
|
||||
|
||||
Reference in New Issue
Block a user