Files
foldingdiff/bin/sample_plotting_only.py
2022-09-27 21:54:42 -07:00

156 lines
5.2 KiB
Python

"""
Plot distribution plots, CDF, and ramachandran plots for sampled proteins
Meant as a way to re-plot the sampled proteins after sampling has been done.
Assumes directory structure (e.g., contains sampled_angles folder)
Usage: python sample_potting_only.py <dirname_with_results>
"""
from glob import glob
import os, sys
import re
import logging
import json
from pathlib import Path
from typing import *
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import torch
# Import data loading code from main training script
from sample import (
FT_NAME_MAP,
build_datasets,
plot_distribution_overlap,
plot_ramachandran,
)
from foldingdiff import custom_metrics as cm
from foldingdiff.angles_and_coords import get_pdb_length
def int_getter(x: str) -> int:
"""Fetches integer value out of a string"""
matches = re.findall(r"[0-9]+", x)
assert len(matches) == 1
return int(matches.pop())
def main(dir_name: Path):
"""Run the script"""
plotdir = dir_name / "plots"
assert plotdir.is_dir()
# Read in the training args
with open(dir_name / "model_snapshot/training_args.json") as source:
training_args = json.load(source)
_, _, test_dset = build_datasets(training_args)
phi_idx = test_dset.feature_names["angles"].index("phi")
psi_idx = test_dset.feature_names["angles"].index("psi")
# Filter by test set sequence length
test_dset_seq_lens = np.array([get_pdb_length(f) for f in test_dset.filenames])
short_enough_idx = np.where(test_dset_seq_lens <= test_dset.pad)[0]
logging.info(
f"{len(short_enough_idx)}/{len(test_dset)} test set seqeunces < {test_dset.pad} residues"
)
select_by_attn = lambda x: x["angles"][x["attn_mask"] != 0]
test_values = [
select_by_attn(test_dset.dset.__getitem__(i, ignore_zero_center=True))
for i in short_enough_idx
]
test_values_stacked = torch.cat(test_values, dim=0).cpu().numpy()
# Read in the sampled angles
sampled_fnames = sorted(
glob(os.path.join(dir_name, "sampled_angles/*.csv.gz")),
key=lambda x: int_getter(os.path.basename(x)),
)
sampled_dfs = []
for fname in sampled_fnames:
df = pd.read_csv(fname, index_col=0)
sampled_dfs.append(df)
assert sampled_dfs
logging.info(f"Found {len(sampled_dfs)} sets of generated angles")
sampled_stacked = np.vstack([df.values for df in sampled_dfs])
# Ramachandran fro training set
plot_ramachandran(
sampled_stacked[:, phi_idx],
sampled_stacked[:, psi_idx],
fname=plotdir / "ramachandran_generated.pdf",
)
# Ramachandran for test set, subsampled to be same length
rng = np.random.default_rng(seed=6489)
ram_idx = rng.choice(
len(test_values_stacked), size=len(sampled_stacked), replace=True
)
plot_ramachandran(
test_values_stacked[ram_idx, phi_idx],
test_values_stacked[ram_idx, psi_idx],
annot_ss=True,
fname=plotdir / "ramachandran_test_annot.pdf",
)
# Plot distribution overlap
multi_fig, multi_axes = plt.subplots(dpi=300, nrows=2, ncols=3, figsize=(13, 6.5))
step_multi_fig, step_multi_axes = plt.subplots(
dpi=300, nrows=2, ncols=3, figsize=(14, 6.5)
)
for i, ft_name in enumerate(test_dset.feature_names["angles"]):
orig_values = test_values_stacked[:, i]
samp_values = sampled_stacked[:, i]
kl = cm.kl_from_empirical(samp_values, orig_values, nbins=200, pseudocount=True)
logging.info(f"Angle {ft_name} KL(generated || test) = {kl}")
ft_name_readable = FT_NAME_MAP[ft_name]
# Plot combo plots
plot_distribution_overlap(
{"Test": orig_values, "Sampled": samp_values},
title=f"{ft_name_readable} distribution, KL={kl:.4f}",
edgecolor="black",
ax=multi_axes.flatten()[i],
show_legend=i == 0,
alpha=0.45,
bins=60,
)
plot_distribution_overlap(
{"Test": orig_values, "Sampled": samp_values},
title=f"{ft_name_readable} CDF",
cumulative=True,
histtype="step",
ax=step_multi_axes.flatten()[i],
show_legend=i == 0,
alpha=0.6,
bins=60,
)
for i in [3, 4, 5]:
multi_axes.flatten()[i].set_xlabel("Angle (rad)", fontsize=14)
step_multi_axes.flatten()[i].set_xlabel("Angle (rad)", fontsize=14)
for i in [0, 3]:
multi_axes.flatten()[i].set_ylabel("Normalized frequency", fontsize=14)
step_multi_axes.flatten()[i].set_ylabel("Cumulative frequency", fontsize=14)
# for i in range(6):
# https://stackoverflow.com/questions/29188757/matplotlib-specify-format-of-floats-for-tick-labels
# multi_axes.flatten()[i].yaxis.set_major_formatter(FormatStrFormatter("%.2g"))
multi_fig.tight_layout()
multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight")
step_multi_fig.tight_layout()
step_multi_fig.savefig(plotdir / "cdf_combined.pdf", bbox_inches="tight")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO)
if len(sys.argv) == 1:
main(Path(os.getcwd()))
else:
main(Path(sys.argv[1]))