diff --git a/bin/sample.py b/bin/sample.py index 7a773df..7cc8a10 100644 --- a/bin/sample.py +++ b/bin/sample.py @@ -11,7 +11,10 @@ from typing import * import numpy as np import pandas as pd +import mpl_scatter_density from matplotlib import pyplot as plt +from astropy.visualization import LogStretch +from astropy.visualization.mpl_normalize import ImageNormalize import torch @@ -93,14 +96,29 @@ def write_preds_pdb_folder( def plot_ramachandran( - phi_values, psi_values, fname: str, annot_ss: bool = False, title: str = "" + phi_values, + psi_values, + fname: str, + annot_ss: bool = False, + title: str = "", + plot_type: Literal["kde", "density_heatmap"] = "density_heatmap", ): """Create Ramachandran plot for phi_psi""" - fig = plotting.plot_joint_kde( - phi_values, - psi_values, - ) - ax = fig.axes[0] + if plot_type == "kde": + fig = plotting.plot_joint_kde( + phi_values, + psi_values, + ) + ax = fig.axes[0] + ax.set_xlim(-3.67, 3.67) + ax.set_ylim(-3.67, 3.67) + elif plot_type == "density_heatmap": + fig = plt.figure(dpi=800) + ax = fig.add_subplot(1, 1, 1, projection="scatter_density") + norm = ImageNormalize(vmin=0.0, vmax=650, stretch=LogStretch()) + ax.scatter_density(phi_values, psi_values, norm=norm, cmap=plt.cm.Blues) + else: + raise NotImplementedError(f"Cannot plot type: {plot_type}") if annot_ss: # https://matplotlib.org/stable/tutorials/text/annotations.html ram_annot_arrows = dict( @@ -110,7 +128,7 @@ def plot_ramachandran( r"$\alpha$ helix, LH", xy=(1.2, 0.5), xycoords="data", - xytext=(2.0, 1.2), + xytext=(1.7, 1.2), textcoords="data", arrowprops=ram_annot_arrows, horizontalalignment="left", @@ -121,7 +139,7 @@ def plot_ramachandran( r"$\alpha$ helix, RH", xy=(-1.1, -0.6), xycoords="data", - xytext=(-1.9, -1.9), + xytext=(-1.7, -1.9), textcoords="data", arrowprops=ram_annot_arrows, horizontalalignment="right", @@ -132,18 +150,15 @@ def plot_ramachandran( r"$\beta$ sheet", xy=(-1.67, 2.25), xycoords="data", - xytext=(-0.9, 3.33), + xytext=(-0.9, 
2.9), textcoords="data", arrowprops=ram_annot_arrows, horizontalalignment="left", verticalalignment="center", fontsize=14, ) - ax.set_xlabel("$\phi$", fontsize=14) - ax.set_ylabel("$\psi$", fontsize=14) - ax.set( - xlim=(-3.67, 3.67), ylim=(-3.67, 3.67) - ) + ax.set_xlabel("$\phi$ (radians)", fontsize=14) + ax.set_ylabel("$\psi$ (radians)", fontsize=14) if title: ax.set_title(title, fontsize=16) fig.savefig(fname, bbox_inches="tight") diff --git a/bin/sample_plotting_only.py b/bin/sample_plotting_only.py index 5e6d996..f379223 100644 --- a/bin/sample_plotting_only.py +++ b/bin/sample_plotting_only.py @@ -31,6 +31,7 @@ from sample import ( from foldingdiff import custom_metrics as cm from foldingdiff.angles_and_coords import get_pdb_length + def int_getter(x: str) -> int: """Fetches integer value out of a string""" matches = re.findall(r"[0-9]+", x) @@ -54,21 +55,16 @@ def main(dir_name: Path): # Filter by test set sequence length test_dset_seq_lens = np.array([get_pdb_length(f) for f in test_dset.filenames]) short_enough_idx = np.where(test_dset_seq_lens <= test_dset.pad)[0] - logging.info(f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues") + logging.info( + f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues" + ) - # Ramachandran for test set select_by_attn = lambda x: x["angles"][x["attn_mask"] != 0] test_values = [ select_by_attn(test_dset.dset.__getitem__(i, ignore_zero_center=True)) for i in short_enough_idx ] test_values_stacked = torch.cat(test_values, dim=0).cpu().numpy() - plot_ramachandran( - test_values_stacked[:, phi_idx], - test_values_stacked[:, psi_idx], - annot_ss=True, - fname=plotdir / "ramachandran_test_annot.pdf", - ) # Read in the sampled angles sampled_fnames = sorted( @@ -83,11 +79,23 @@ def main(dir_name: Path): logging.info(f"Found {len(sampled_dfs)} sets of generated angles") sampled_stacked = np.vstack([df.values for df in sampled_dfs]) + # Ramachandran for 
training set plot_ramachandran( sampled_stacked[:, phi_idx], sampled_stacked[:, psi_idx], fname=plotdir / "ramachandran_generated.pdf", ) + # Ramachandran for test set, subsampled to be same length + rng = np.random.default_rng(seed=6489) + ram_idx = rng.choice( + len(test_values_stacked), size=len(sampled_stacked), replace=True + ) + plot_ramachandran( + test_values_stacked[ram_idx, phi_idx], + test_values_stacked[ram_idx, psi_idx], + annot_ss=True, + fname=plotdir / "ramachandran_test_annot.pdf", + ) # Plot distribution overlap multi_fig, multi_axes = plt.subplots(dpi=300, nrows=2, ncols=3, figsize=(13, 6.5))