mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 21:34:32 +08:00
Updates to Ramachandran plot
This commit is contained in:
@@ -11,7 +11,10 @@ from typing import *
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import mpl_scatter_density
|
||||
from matplotlib import pyplot as plt
|
||||
from astropy.visualization import LogStretch
|
||||
from astropy.visualization.mpl_normalize import ImageNormalize
|
||||
|
||||
import torch
|
||||
|
||||
@@ -93,14 +96,29 @@ def write_preds_pdb_folder(
|
||||
|
||||
|
||||
def plot_ramachandran(
|
||||
phi_values, psi_values, fname: str, annot_ss: bool = False, title: str = ""
|
||||
phi_values,
|
||||
psi_values,
|
||||
fname: str,
|
||||
annot_ss: bool = False,
|
||||
title: str = "",
|
||||
plot_type: Literal["kde", "density_heatmap"] = "density_heatmap",
|
||||
):
|
||||
"""Create Ramachandran plot for phi_psi"""
|
||||
fig = plotting.plot_joint_kde(
|
||||
phi_values,
|
||||
psi_values,
|
||||
)
|
||||
ax = fig.axes[0]
|
||||
if plot_type == "kde":
|
||||
fig = plotting.plot_joint_kde(
|
||||
phi_values,
|
||||
psi_values,
|
||||
)
|
||||
ax = fig.axes[0]
|
||||
ax.set_xlim(-3.67, 3.67)
|
||||
ax.set_ylim(-3.67, 3.67)
|
||||
elif plot_type == "density_heatmap":
|
||||
fig = plt.figure(dpi=800)
|
||||
ax = fig.add_subplot(1, 1, 1, projection="scatter_density")
|
||||
norm = ImageNormalize(vmin=0.0, vmax=650, stretch=LogStretch())
|
||||
ax.scatter_density(phi_values, psi_values, norm=norm, cmap=plt.cm.Blues)
|
||||
else:
|
||||
raise NotImplementedError(f"Cannot plot type: {plot_type}")
|
||||
if annot_ss:
|
||||
# https://matplotlib.org/stable/tutorials/text/annotations.html
|
||||
ram_annot_arrows = dict(
|
||||
@@ -110,7 +128,7 @@ def plot_ramachandran(
|
||||
r"$\alpha$ helix, LH",
|
||||
xy=(1.2, 0.5),
|
||||
xycoords="data",
|
||||
xytext=(2.0, 1.2),
|
||||
xytext=(1.7, 1.2),
|
||||
textcoords="data",
|
||||
arrowprops=ram_annot_arrows,
|
||||
horizontalalignment="left",
|
||||
@@ -121,7 +139,7 @@ def plot_ramachandran(
|
||||
r"$\alpha$ helix, RH",
|
||||
xy=(-1.1, -0.6),
|
||||
xycoords="data",
|
||||
xytext=(-1.9, -1.9),
|
||||
xytext=(-1.7, -1.9),
|
||||
textcoords="data",
|
||||
arrowprops=ram_annot_arrows,
|
||||
horizontalalignment="right",
|
||||
@@ -132,18 +150,15 @@ def plot_ramachandran(
|
||||
r"$\beta$ sheet",
|
||||
xy=(-1.67, 2.25),
|
||||
xycoords="data",
|
||||
xytext=(-0.9, 3.33),
|
||||
xytext=(-0.9, 2.9),
|
||||
textcoords="data",
|
||||
arrowprops=ram_annot_arrows,
|
||||
horizontalalignment="left",
|
||||
verticalalignment="center",
|
||||
fontsize=14,
|
||||
)
|
||||
ax.set_xlabel("$\phi$", fontsize=14)
|
||||
ax.set_ylabel("$\psi$", fontsize=14)
|
||||
ax.set(
|
||||
xlim=(-3.67, 3.67), ylim=(-3.67, 3.67)
|
||||
)
|
||||
ax.set_xlabel("$\phi$ (radians)", fontsize=14)
|
||||
ax.set_ylabel("$\psi$ (radians)", fontsize=14)
|
||||
if title:
|
||||
ax.set_title(title, fontsize=16)
|
||||
fig.savefig(fname, bbox_inches="tight")
|
||||
|
||||
@@ -31,6 +31,7 @@ from sample import (
|
||||
from foldingdiff import custom_metrics as cm
|
||||
from foldingdiff.angles_and_coords import get_pdb_length
|
||||
|
||||
|
||||
def int_getter(x: str) -> int:
|
||||
"""Fetches integer value out of a string"""
|
||||
matches = re.findall(r"[0-9]+", x)
|
||||
@@ -54,21 +55,16 @@ def main(dir_name: Path):
|
||||
# Filter by test set sequence length
|
||||
test_dset_seq_lens = np.array([get_pdb_length(f) for f in test_dset.filenames])
|
||||
short_enough_idx = np.where(test_dset_seq_lens <= test_dset.pad)[0]
|
||||
logging.info(f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues")
|
||||
logging.info(
|
||||
f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues"
|
||||
)
|
||||
|
||||
# Ramachandran for test set
|
||||
select_by_attn = lambda x: x["angles"][x["attn_mask"] != 0]
|
||||
test_values = [
|
||||
select_by_attn(test_dset.dset.__getitem__(i, ignore_zero_center=True))
|
||||
for i in short_enough_idx
|
||||
]
|
||||
test_values_stacked = torch.cat(test_values, dim=0).cpu().numpy()
|
||||
plot_ramachandran(
|
||||
test_values_stacked[:, phi_idx],
|
||||
test_values_stacked[:, psi_idx],
|
||||
annot_ss=True,
|
||||
fname=plotdir / "ramachandran_test_annot.pdf",
|
||||
)
|
||||
|
||||
# Read in the sampled angles
|
||||
sampled_fnames = sorted(
|
||||
@@ -83,11 +79,23 @@ def main(dir_name: Path):
|
||||
logging.info(f"Found {len(sampled_dfs)} sets of generated angles")
|
||||
|
||||
sampled_stacked = np.vstack([df.values for df in sampled_dfs])
|
||||
# Ramachandran for generated (sampled) angles
|
||||
plot_ramachandran(
|
||||
sampled_stacked[:, phi_idx],
|
||||
sampled_stacked[:, psi_idx],
|
||||
fname=plotdir / "ramachandran_generated.pdf",
|
||||
)
|
||||
# Ramachandran for test set, subsampled to be same length
|
||||
rng = np.random.default_rng(seed=6489)
|
||||
ram_idx = rng.choice(
|
||||
len(test_values_stacked), size=len(sampled_stacked), replace=True
|
||||
)
|
||||
plot_ramachandran(
|
||||
test_values_stacked[ram_idx, phi_idx],
|
||||
test_values_stacked[ram_idx, psi_idx],
|
||||
annot_ss=True,
|
||||
fname=plotdir / "ramachandran_test_annot.pdf",
|
||||
)
|
||||
|
||||
# Plot distribution overlap
|
||||
multi_fig, multi_axes = plt.subplots(dpi=300, nrows=2, ncols=3, figsize=(13, 6.5))
|
||||
|
||||
Reference in New Issue
Block a user