mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 21:34:32 +08:00
475 lines
16 KiB
Python
475 lines
16 KiB
Python
"""
Script to sample from a trained diffusion model
"""
|
|
import multiprocessing
|
|
import os, sys
|
|
import argparse
|
|
import logging
|
|
import json
|
|
from pathlib import Path
|
|
from typing import *
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import mpl_scatter_density
|
|
from matplotlib import pyplot as plt
|
|
from astropy.visualization import LogStretch
|
|
from astropy.visualization.mpl_normalize import ImageNormalize
|
|
|
|
import torch
|
|
from huggingface_hub import snapshot_download
|
|
|
|
# Import data loading code from main training script
|
|
from train import get_train_valid_test_sets
|
|
from annot_secondary_structures import make_ss_cooccurrence_plot
|
|
|
|
from foldingdiff import modelling
|
|
from foldingdiff import sampling
|
|
from foldingdiff import plotting
|
|
from foldingdiff.datasets import AnglesEmptyDataset, NoisedAnglesDataset
|
|
from foldingdiff.angles_and_coords import create_new_chain_nerf
|
|
from foldingdiff import utils
|
|
|
|
# :)
# Default random seed, used as the default of the --seed CLI option.
# The hex literal is parsed as a floating-point number, reduced modulo 10000,
# and truncated to an int, so the result is a fixed integer in [0, 9999].
SEED = int(
    float.fromhex("54616977616e20697320616e20696e646570656e64656e7420636f756e747279")
    % 10000
)
|
|
|
|
# Maps dataset feature names to LaTeX math-mode labels. main() looks up every
# angle feature name here when building plot titles, so a feature name missing
# from this table raises KeyError there.
FT_NAME_MAP = {
    "phi": r"$\phi$",
    "psi": r"$\psi$",
    "omega": r"$\omega$",
    "tau": r"$\theta_1$",
    "CA:C:1N": r"$\theta_2$",
    "C:1N:1CA": r"$\theta_3$",
}
|
|
|
|
|
|
def build_datasets(
    model_dir: Path, load_actual: bool = True
) -> Tuple[NoisedAnglesDataset, NoisedAnglesDataset, NoisedAnglesDataset]:
    """
    Rebuild the datasets described by the training arguments saved with a model.

    The arguments are read from ``training_args.json`` inside ``model_dir``.

    * ``load_actual=True``: the real datasets are rebuilt through
      ``get_train_valid_test_sets`` and returned as a (train, valid, test) tuple.
    * ``load_actual=False``: a single ``AnglesEmptyDataset`` placeholder is wrapped
      in three ``NoisedAnglesDataset`` objects, which are returned as a list of
      length three. This skips loading real data while exposing the same API,
      which is enough for generation. If ``training_mean_offset.npy`` exists in
      ``model_dir`` it is loaded and passed as the placeholder's mean offset;
      otherwise ``None`` is passed.
    """
    with open(model_dir / "training_args.json") as handle:
        cfg = json.load(handle)

    if not load_actual:
        offset_path = model_dir / "training_mean_offset.npy"
        shell = AnglesEmptyDataset(
            feature_set_key=cfg["angles_definitions"],
            pad=cfg["max_seq_len"],
            mean_offset=np.load(offset_path) if offset_path.exists() else None,
        )
        # All three wrappers share the same placeholder dataset.
        wrapped = []
        for _ in range(3):
            if cfg["angles_definitions"] == "cart-coords":
                key = "coords"
            else:
                key = "angles"
            wrapped.append(
                NoisedAnglesDataset(
                    dset=shell,
                    dset_key=key,
                    timesteps=cfg["timesteps"],
                    exhaustive_t=False,
                    beta_schedule=cfg["variance_schedule"],
                    nonangular_variance=1.0,
                    angular_variance=cfg["variance_scale"],
                )
            )
        return wrapped

    # Translate the saved training arguments into dataset-builder keywords.
    builder_kwargs = {
        "timesteps": cfg["timesteps"],
        "variance_schedule": cfg["variance_schedule"],
        "max_seq_len": cfg["max_seq_len"],
        "min_seq_len": cfg["min_seq_len"],
        "var_scale": cfg["variance_scale"],
        "syn_noiser": cfg["syn_noiser"],
        "exhaustive_t": cfg["exhaustive_validation_t"],
        "single_angle_debug": cfg["single_angle_debug"],
        "single_time_debug": cfg["single_timestep_debug"],
        "toy": cfg["subset"],
        "angles_definitions": cfg["angles_definitions"],
        "train_only": False,
    }
    train, valid, test = get_train_valid_test_sets(**builder_kwargs)
    logging.info(
        f"Training dset contains features: {train.feature_names} - angular {train.feature_is_angular}"
    )
    return train, valid, test
|
|
|
|
|
|
def write_preds_pdb_folder(
    final_sampled: Sequence[pd.DataFrame],
    outdir: str,
    basename_prefix: str = "generated_",
    threads: int = multiprocessing.cpu_count(),
) -> List[str]:
    """
    Convert each sampled angle table into a PDB file inside ``outdir``.

    ``outdir`` is created if needed. The i-th entry of ``final_sampled`` is
    written to ``<outdir>/<basename_prefix><i>.pdb`` by calling
    ``create_new_chain_nerf(path, sample)`` in a multiprocessing pool of
    ``threads`` workers.

    Returns the list of values returned by ``create_new_chain_nerf``, in input
    order (presumably the paths of the files written -- confirm against that
    helper).
    """
    os.makedirs(outdir, exist_ok=True)
    logging.info(
        f"Writing sampled angles as PDB files to {outdir} using {threads} threads"
    )

    # One (destination path, sample) job per structure.
    jobs = []
    for idx, angles_df in enumerate(final_sampled):
        destination = os.path.join(outdir, f"{basename_prefix}{idx}.pdb")
        jobs.append((destination, angles_df))

    # Fan the jobs out over worker processes; starmap preserves input order.
    with multiprocessing.Pool(threads) as workers:
        written = workers.starmap(create_new_chain_nerf, jobs)
    return written
|
|
|
|
|
|
def plot_ramachandran(
    phi_values,
    psi_values,
    fname: str,
    annot_ss: bool = False,
    title: str = "",
    plot_type: Literal["kde", "density_heatmap"] = "density_heatmap",
):
    """
    Create a Ramachandran (phi vs. psi) plot and save it to ``fname``.

    Parameters:
        phi_values, psi_values: matching collections of angles; the axis labels
            state radians.
        fname: destination passed to ``fig.savefig``.
        annot_ss: if true, draw arrows labelling the left- and right-handed
            alpha helix regions and the beta sheet region.
        title: optional axes title; skipped when empty.
        plot_type: ``"kde"`` draws via ``plotting.plot_joint_kde`` and clamps
            both axes to [-3.67, 3.67]; ``"density_heatmap"`` draws a
            log-stretched scatter-density image.

    Raises:
        NotImplementedError: if ``plot_type`` is not one of the two options.
    """
    if plot_type == "kde":
        fig = plotting.plot_joint_kde(
            phi_values,
            psi_values,
        )
        ax = fig.axes[0]
        ax.set_xlim(-3.67, 3.67)
        ax.set_ylim(-3.67, 3.67)
    elif plot_type == "density_heatmap":
        fig = plt.figure(dpi=800)
        # NOTE(review): the "scatter_density" projection is presumably registered
        # by the module-level `import mpl_scatter_density` -- confirm.
        ax = fig.add_subplot(1, 1, 1, projection="scatter_density")
        norm = ImageNormalize(vmin=0.0, vmax=650, stretch=LogStretch())
        ax.scatter_density(phi_values, psi_values, norm=norm, cmap=plt.cm.Blues)
    else:
        raise NotImplementedError(f"Cannot plot type: {plot_type}")

    if annot_ss:
        # https://matplotlib.org/stable/tutorials/text/annotations.html
        ram_annot_arrows = dict(
            facecolor="black", shrink=0.05, headwidth=6.0, width=1.5
        )
        # (label, arrow tip, text position, horizontal alignment), in data coords
        annotations = (
            (r"$\alpha$ helix, LH", (1.2, 0.5), (1.7, 1.2), "left"),
            (r"$\alpha$ helix, RH", (-1.1, -0.6), (-1.7, -1.9), "right"),
            (r"$\beta$ sheet", (-1.67, 2.25), (-0.9, 2.9), "left"),
        )
        for label, tip, text_pos, halign in annotations:
            ax.annotate(
                label,
                xy=tip,
                xycoords="data",
                xytext=text_pos,
                textcoords="data",
                arrowprops=ram_annot_arrows,
                horizontalalignment=halign,
                verticalalignment="center",
                fontsize=14,
            )

    # Raw strings: "\p" is an invalid escape sequence in a normal string literal
    # (SyntaxWarning on Python 3.12+). The resulting text is unchanged.
    ax.set_xlabel(r"$\phi$ (radians)", fontsize=14)
    ax.set_ylabel(r"$\psi$ (radians)", fontsize=14)
    if title:
        ax.set_title(title, fontsize=16)
    fig.savefig(fname, bbox_inches="tight")
|
|
|
|
|
|
def plot_distribution_overlap(
    values_dicts: Dict[str, np.ndarray],
    title: str = "Sampled distribution",
    fname: str = "",
    bins: int = 50,
    ax=None,
    show_legend: bool = True,
    **kwargs,
):
    """
    Plot overlaid, density-normalised histograms of several named value sets.

    Parameters:
        values_dicts: maps a legend label to the values to histogram. Entries
            whose value is ``None`` are skipped.
        title: axes title; skipped when empty.
        fname: if non-empty, the figure is saved to this path.
        bins: ``bins`` argument for the first histogram drawn. Every later
            histogram reuses the bin edges returned by the previous one, so all
            series share the same binning.
        ax: axes to draw on. When ``None`` a new figure (dpi=300) is created.
        show_legend: whether to call ``ax.legend()``.
        **kwargs: forwarded to ``ax.hist``; for example
            ``histtype='step', cumulative=True`` gives a CDF plot.
    """
    if ax is None:
        fig, ax = plt.subplots(dpi=300)
    else:
        # Previously `fig` was left unbound on this path, so passing both `ax`
        # and `fname` crashed at savefig. Save the figure that owns the axes.
        fig = ax.figure

    for k, v in values_dicts.items():
        if v is None:
            continue
        # Rebinding `bins` to the returned edges keeps later series aligned.
        _n, bins, _patches = ax.hist(
            v,
            bins=bins,
            label=k,
            density=True,
            **kwargs,
        )
    if title:
        ax.set_title(title, fontsize=16)
    if show_legend:
        ax.legend()
    if fname:
        fig.savefig(fname, bbox_inches="tight")
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """
    Construct the command-line parser for the sampling script.

    The module docstring is used as the usage text, and option defaults are
    shown in ``--help`` through ``ArgumentDefaultsHelpFormatter``.
    """
    cli = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # (flags, add_argument keywords), registered in this order so the help
    # output lists the options in the same sequence.
    option_specs = [
        (
            ("-m", "--model"),
            {
                "type": str,
                "default": "wukevin/foldingdiff_cath",
                "help": "Path to model directory, or a repo identifier on huggingface hub. Should contain training_args.json, config.json, and models folder at a minimum.",
            },
        ),
        (
            ("--outdir", "-o"),
            {
                "type": str,
                "default": os.getcwd(),
                "help": "Path to output directory",
            },
        ),
        (
            ("--num", "-n"),
            {
                "type": int,
                "default": 10,
                "help": "Number of examples to generate *per length*",
            },
        ),
        (
            ("-l", "--lengths"),
            {
                "type": int,
                "nargs": 2,
                "default": [50, 128],
                "help": "Range of lengths to sample from",
            },
        ),
        (
            ("-b", "--batchsize"),
            {
                "type": int,
                "default": 512,
                "help": "Batch size to use when sampling. 256 consumes ~2GB of GPU memory, 512 ~3.5GB",
            },
        ),
        (
            ("--fullhistory",),
            {
                "action": "store_true",
                "help": "Store full history, not just final structure",
            },
        ),
        (
            ("--testcomparison",),
            {
                "action": "store_true",
                "help": "Run comparison against test set",
            },
        ),
        (
            ("--nopsea",),
            {"action": "store_true", "help": "Skip PSEA calculations"},
        ),
        (
            ("--seed",),
            {"type": int, "default": SEED, "help": "Random seed"},
        ),
        (
            ("--device",),
            {"type": str, "default": "cuda:0", "help": "Device to use"},
        ),
    ]
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli
|
|
|
|
|
|
def main() -> None:
    """
    Run the script: sample structures from a trained model and write results.

    Everything is written below the ``--outdir`` directory, which must be empty:

    * ``plots/``: angle histograms and CDFs (per feature and combined),
      Ramachandran plots, and (unless ``--nopsea``) secondary structure
      co-occurrence plots.
    * ``model_snapshot/``: passed as ``copy_to`` when loading the model.
    * ``sampled_angles/``: one ``generated_<i>.csv.gz`` per sample.
    * ``sampled_pdb/``: one ``generated_<i>.pdb`` per sample.
    * With ``--fullhistory``, a ``sample_history/generated_<i>/`` folder under
      both ``sampled_angles/`` and ``sampled_pdb/`` with one file per timestep.

    Raises:
        FileExistsError: if the output directory is not empty.
        NotADirectoryError: if the downloaded model path is not a directory.
        ValueError: if the requested length range is not increasing or exceeds
            the dataset's padded length.
    """
    parser = build_parser()
    args = parser.parse_args()

    logging.info(f"Creating {args.outdir}")
    os.makedirs(args.outdir, exist_ok=True)
    outdir = Path(args.outdir)
    # Be extra cautious so we don't overwrite any results. This is an explicit
    # raise rather than an assert so the guard survives `python -O`.
    if os.listdir(outdir):
        raise FileExistsError(f"Expected {outdir} to be empty!")

    # Download the model if it was given on modelhub
    if utils.is_huggingface_hub_id(args.model):
        logging.info(f"Detected huggingface repo ID {args.model}")
        dl_path = snapshot_download(args.model)  # Caching is automatic
        if not os.path.isdir(dl_path):
            raise NotADirectoryError(
                f"Downloaded model path is not a directory: {dl_path}"
            )
        logging.info(f"Using downloaded model at {dl_path}")
        args.model = dl_path

    plotdir = outdir / "plots"
    os.makedirs(plotdir, exist_ok=True)

    # Load the dataset based on training args. Real data is only loaded when a
    # comparison against the test set was requested.
    train_dset, _, test_dset = build_datasets(
        Path(args.model), load_actual=args.testcomparison
    )
    phi_idx = test_dset.feature_names["angles"].index("phi")
    psi_idx = test_dset.feature_names["angles"].index("psi")

    def select_by_attn(x):
        """Keep only the angle rows whose attention mask entry is non-zero."""
        return x["angles"][x["attn_mask"] != 0]

    if args.testcomparison:
        # Fetch values for the test distribution
        test_values = [
            select_by_attn(test_dset.dset.__getitem__(i, ignore_zero_center=True))
            for i in range(len(test_dset))
        ]
        test_values_stacked = torch.cat(test_values, dim=0).cpu().numpy()

        # Plot ramachandran plot for the test distribution
        # Default figure size is 6.4x4.8 inches
        plot_ramachandran(
            test_values_stacked[:, phi_idx],
            test_values_stacked[:, psi_idx],
            annot_ss=True,
            fname=plotdir / "ramachandran_test_annot.pdf",
        )
    else:
        test_values_stacked = None

    # Load the model
    model_snapshot_dir = outdir / "model_snapshot"
    model = modelling.BertForDiffusionBase.from_dir(
        args.model, copy_to=model_snapshot_dir
    ).to(torch.device(args.device))

    # Checks on the requested length sweep
    sweep_min_len, sweep_max_len = args.lengths
    if sweep_min_len >= sweep_max_len:
        raise ValueError(
            f"Minimum length {sweep_min_len} must be smaller than maximum length {sweep_max_len}"
        )
    if sweep_max_len > train_dset.dset.pad:
        raise ValueError(
            f"Maximum length {sweep_max_len} exceeds dataset padded length {train_dset.dset.pad}"
        )

    # Perform sampling
    torch.manual_seed(args.seed)
    sampled = sampling.sample(
        model,
        train_dset,
        n=args.num,
        sweep_lengths=(sweep_min_len, sweep_max_len),
        batch_size=args.batchsize,
    )
    # The last snapshot of each sampled series is the final structure.
    final_sampled = [s[-1] for s in sampled]
    sampled_dfs = [
        pd.DataFrame(s, columns=train_dset.feature_names["angles"])
        for s in final_sampled
    ]

    # Write the raw sampled items to csv files
    sampled_angles_folder = outdir / "sampled_angles"
    os.makedirs(sampled_angles_folder, exist_ok=True)
    logging.info(f"Writing sampled angles to {sampled_angles_folder}")
    for i, s in enumerate(sampled_dfs):
        s.to_csv(sampled_angles_folder / f"generated_{i}.csv.gz")
    # Write the sampled angles as pdb files
    pdb_files = write_preds_pdb_folder(sampled_dfs, outdir / "sampled_pdb")

    # If full history is specified, create a separate directory and write those files
    if args.fullhistory:
        # Write the angles
        full_history_angles_dir = sampled_angles_folder / "sample_history"
        os.makedirs(full_history_angles_dir)
        full_history_pdb_dir = outdir / "sampled_pdb/sample_history"
        os.makedirs(full_history_pdb_dir)
        # sampled is a list of np arrays
        for i, sampled_series in enumerate(sampled):
            snapshot_dfs = [
                pd.DataFrame(snapshot, columns=train_dset.feature_names["angles"])
                for snapshot in sampled_series
            ]
            # Write the angles
            ith_angle_dir = full_history_angles_dir / f"generated_{i}"
            os.makedirs(ith_angle_dir, exist_ok=True)
            for timestep, snapshot_df in enumerate(snapshot_dfs):
                snapshot_df.to_csv(
                    ith_angle_dir / f"generated_{i}_timestep_{timestep}.csv.gz"
                )
            # Write the pdb files
            ith_pdb_dir = full_history_pdb_dir / f"generated_{i}"
            write_preds_pdb_folder(
                snapshot_dfs, ith_pdb_dir, basename_prefix=f"generated_{i}_timestep_"
            )

    # Generate histograms of sampled angles -- separate plots, and a combined plot
    # For calculating angle distributions
    multi_fig, multi_axes = plt.subplots(
        dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
    )
    step_multi_fig, step_multi_axes = plt.subplots(
        dpi=300, nrows=2, ncols=3, figsize=(14, 6), sharex=True
    )
    final_sampled_stacked = np.vstack(final_sampled)
    for i, ft_name in enumerate(test_dset.feature_names["angles"]):
        # Test values are only available when --testcomparison was given;
        # plot_distribution_overlap skips series whose values are None.
        orig_values = (
            test_values_stacked[:, i] if test_values_stacked is not None else None
        )
        samp_values = final_sampled_stacked[:, i]

        ft_name_readable = FT_NAME_MAP[ft_name]

        # Plot single plots
        plot_distribution_overlap(
            {"Test": orig_values, "Sampled": samp_values},
            title=f"Sampled angle distribution - {ft_name_readable}",
            fname=plotdir / f"dist_{ft_name}.pdf",
        )
        plot_distribution_overlap(
            {"Test": orig_values, "Sampled": samp_values},
            title=f"Sampled angle CDF - {ft_name_readable}",
            histtype="step",
            cumulative=True,
            fname=plotdir / f"cdf_{ft_name}.pdf",
        )

        # Plot combo plots; only the first panel carries a legend
        plot_distribution_overlap(
            {"Test": orig_values, "Sampled": samp_values},
            title=f"Sampled angle distribution - {ft_name_readable}",
            ax=multi_axes.flatten()[i],
            show_legend=i == 0,
        )
        plot_distribution_overlap(
            {"Test": orig_values, "Sampled": samp_values},
            title=f"Sampled angle CDF - {ft_name_readable}",
            cumulative=True,
            histtype="step",
            ax=step_multi_axes.flatten()[i],
            show_legend=i == 0,
        )
    multi_fig.savefig(plotdir / "dist_combined.pdf", bbox_inches="tight")
    step_multi_fig.savefig(plotdir / "cdf_combined.pdf", bbox_inches="tight")

    # Generate ramachandran plot for sampled angles
    plot_ramachandran(
        final_sampled_stacked[:, phi_idx],
        final_sampled_stacked[:, psi_idx],
        fname=plotdir / "ramachandran_generated.pdf",
    )

    # Generate plots of secondary structure co-occurrence
    if not args.nopsea:
        make_ss_cooccurrence_plot(
            pdb_files,
            str(outdir / "plots" / "ss_cooccurrence_sampled.pdf"),
            threads=multiprocessing.cpu_count(),
        )
        if args.testcomparison:
            make_ss_cooccurrence_plot(
                test_dset.filenames,
                str(outdir / "plots" / "ss_cooccurrence_test.pdf"),
                max_seq_len=test_dset.dset.pad,
                threads=multiprocessing.cpu_count(),
            )
|
|
|
|
|
|
# Script entry point: enable INFO-level logging so progress messages are shown,
# then run the sampling pipeline.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()
|