Updates to Ramachandran plot

This commit is contained in:
Kevin Wu
2022-09-27 21:54:42 -07:00
parent 1074327185
commit 17898dee6c
2 changed files with 45 additions and 22 deletions

View File

@@ -11,7 +11,10 @@ from typing import *
import numpy as np
import pandas as pd
import mpl_scatter_density
from matplotlib import pyplot as plt
from astropy.visualization import LogStretch
from astropy.visualization.mpl_normalize import ImageNormalize
import torch
@@ -93,14 +96,29 @@ def write_preds_pdb_folder(
def plot_ramachandran(
phi_values, psi_values, fname: str, annot_ss: bool = False, title: str = ""
phi_values,
psi_values,
fname: str,
annot_ss: bool = False,
title: str = "",
plot_type: Literal["kde", "density_heatmap"] = "density_heatmap",
):
"""Create Ramachandran plot for phi_psi"""
fig = plotting.plot_joint_kde(
phi_values,
psi_values,
)
ax = fig.axes[0]
if plot_type == "kde":
fig = plotting.plot_joint_kde(
phi_values,
psi_values,
)
ax = fig.axes[0]
ax.set_xlim(-3.67, 3.67)
ax.set_ylim(-3.67, 3.67)
elif plot_type == "density_heatmap":
fig = plt.figure(dpi=800)
ax = fig.add_subplot(1, 1, 1, projection="scatter_density")
norm = ImageNormalize(vmin=0.0, vmax=650, stretch=LogStretch())
ax.scatter_density(phi_values, psi_values, norm=norm, cmap=plt.cm.Blues)
else:
raise NotImplementedError(f"Cannot plot type: {plot_type}")
if annot_ss:
# https://matplotlib.org/stable/tutorials/text/annotations.html
ram_annot_arrows = dict(
@@ -110,7 +128,7 @@ def plot_ramachandran(
r"$\alpha$ helix, LH",
xy=(1.2, 0.5),
xycoords="data",
xytext=(2.0, 1.2),
xytext=(1.7, 1.2),
textcoords="data",
arrowprops=ram_annot_arrows,
horizontalalignment="left",
@@ -121,7 +139,7 @@ def plot_ramachandran(
r"$\alpha$ helix, RH",
xy=(-1.1, -0.6),
xycoords="data",
xytext=(-1.9, -1.9),
xytext=(-1.7, -1.9),
textcoords="data",
arrowprops=ram_annot_arrows,
horizontalalignment="right",
@@ -132,18 +150,15 @@ def plot_ramachandran(
r"$\beta$ sheet",
xy=(-1.67, 2.25),
xycoords="data",
xytext=(-0.9, 3.33),
xytext=(-0.9, 2.9),
textcoords="data",
arrowprops=ram_annot_arrows,
horizontalalignment="left",
verticalalignment="center",
fontsize=14,
)
ax.set_xlabel("$\phi$", fontsize=14)
ax.set_ylabel("$\psi$", fontsize=14)
ax.set(
xlim=(-3.67, 3.67), ylim=(-3.67, 3.67)
)
ax.set_xlabel("$\phi$ (radians)", fontsize=14)
ax.set_ylabel("$\psi$ (radians)", fontsize=14)
if title:
ax.set_title(title, fontsize=16)
fig.savefig(fname, bbox_inches="tight")

View File

@@ -31,6 +31,7 @@ from sample import (
from foldingdiff import custom_metrics as cm
from foldingdiff.angles_and_coords import get_pdb_length
def int_getter(x: str) -> int:
"""Fetches integer value out of a string"""
matches = re.findall(r"[0-9]+", x)
@@ -54,21 +55,16 @@ def main(dir_name: Path):
# Filter by test set sequence length
test_dset_seq_lens = np.array([get_pdb_length(f) for f in test_dset.filenames])
short_enough_idx = np.where(test_dset_seq_lens <= test_dset.pad)[0]
logging.info(f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues")
logging.info(
f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues"
)
# Ramachandran for test set
select_by_attn = lambda x: x["angles"][x["attn_mask"] != 0]
test_values = [
select_by_attn(test_dset.dset.__getitem__(i, ignore_zero_center=True))
for i in short_enough_idx
]
test_values_stacked = torch.cat(test_values, dim=0).cpu().numpy()
plot_ramachandran(
test_values_stacked[:, phi_idx],
test_values_stacked[:, psi_idx],
annot_ss=True,
fname=plotdir / "ramachandran_test_annot.pdf",
)
# Read in the sampled angles
sampled_fnames = sorted(
@@ -83,11 +79,23 @@ def main(dir_name: Path):
logging.info(f"Found {len(sampled_dfs)} sets of generated angles")
sampled_stacked = np.vstack([df.values for df in sampled_dfs])
# Ramachandran fro training set
plot_ramachandran(
sampled_stacked[:, phi_idx],
sampled_stacked[:, psi_idx],
fname=plotdir / "ramachandran_generated.pdf",
)
# Ramachandran for test set, subsampled to be same length
rng = np.random.default_rng(seed=6489)
ram_idx = rng.choice(
len(test_values_stacked), size=len(sampled_stacked), replace=True
)
plot_ramachandran(
test_values_stacked[ram_idx, phi_idx],
test_values_stacked[ram_idx, psi_idx],
annot_ss=True,
fname=plotdir / "ramachandran_test_annot.pdf",
)
# Plot distribution overlap
multi_fig, multi_axes = plt.subplots(dpi=300, nrows=2, ncols=3, figsize=(13, 6.5))