mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 21:34:32 +08:00
365 lines
12 KiB
Python
365 lines
12 KiB
Python
"""
|
|
Code for sampling from diffusion models
|
|
"""
|
|
import json
|
|
import os
|
|
import multiprocessing as mp
|
|
from pathlib import Path
|
|
import tempfile
|
|
import logging
|
|
from typing import *
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
|
|
import torch
|
|
from torch import nn
|
|
from torch.utils.data import default_collate
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from foldingdiff import datasets as dsets
|
|
from foldingdiff import beta_schedules, modelling, utils, sampling, tmalign
|
|
from foldingdiff import angles_and_coords as ac
|
|
|
|
|
|
@torch.no_grad()
def p_sample(
    model: nn.Module,
    x: torch.Tensor,
    t: torch.Tensor,
    seq_lens: Sequence[int],
    t_index: torch.Tensor,
    betas: torch.Tensor,
) -> torch.Tensor:
    """
    Take one reverse-diffusion step from the current noisy batch.

    The output is NOT wrapped into any periodic range, so feeding it straight
    back in over many steps can drift off the manifold; p_sample_loop applies a
    modulo after each call for that reason.

    Args:
        model: noise predictor, invoked as ``model(x, t, attention_mask=mask)``.
        x: current noisy batch; its first two dimensions size the attention
            mask (presumably (batch, seq_len, n_features) -- confirm with callers).
        t: timestep tensor handed to the model; all of its entries must be equal.
        seq_lens: number of leading positions to unmask in each batch row.
        t_index: accepted for interface compatibility but never read -- the step
            index actually used is re-derived from the single value in ``t``.
        betas: variance schedule, expanded via ``beta_schedules.compute_alphas``.

    Returns:
        The predicted mean when the step is 0; otherwise the mean plus standard
        normal noise scaled by ``sqrt(posterior_variance)`` at that step.
    """
    schedule = beta_schedules.compute_alphas(betas)
    inv_sqrt_alphas = 1.0 / torch.sqrt(schedule["alphas"])

    # The whole batch must share one timestep; that value is the schedule index
    distinct_steps = torch.unique(t)
    assert len(distinct_steps) == 1, f"Got multiple values for t: {distinct_steps}"
    step = distinct_steps.item()

    # 1.0 over the first seq_lens[row] positions of each row, 0.0 elsewhere
    mask = torch.zeros(x.shape[:2], device=x.device)
    for row, n_valid in enumerate(seq_lens):
        mask[row, :n_valid] = 1.0

    # Equation 11 in the paper: turn the predicted noise into the mean
    predicted_noise = model(x, t, attention_mask=mask)
    denom = schedule["sqrt_one_minus_alphas_cumprod"][step]
    model_mean = inv_sqrt_alphas[step] * (x - betas[step] * predicted_noise / denom)

    if step == 0:
        # The last step adds no noise
        return model_mean

    # Algorithm 2 line 4: perturb the mean with posterior noise
    sigma = torch.sqrt(schedule["posterior_variance"][step])
    return model_mean + sigma * torch.randn_like(x)
|
|
|
|
|
|
@torch.no_grad()
def p_sample_loop(
    model: nn.Module,
    lengths: Sequence[int],
    noise: torch.Tensor,
    timesteps: int,
    betas: torch.Tensor,
    is_angle: Optional[Union[bool, Sequence[bool]]] = None,
    disable_pbar: bool = False,
) -> torch.Tensor:
    """
    Run the full reverse-diffusion chain starting from ``noise``, keeping every step.

    ``noise`` is moved to the model's device and p_sample is applied for
    t = timesteps - 1 down to 0. After each step, features flagged as angular
    are wrapped back into the -pi..pi range with utils.modulo_with_wrapped_range.

    Args:
        model: noise predictor; the device of its first parameter is used.
        lengths: valid sequence length of each batch row (forwarded to p_sample).
        noise: starting tensor, indexed as (batch, seq_len, n_ft).
        timesteps: number of reverse steps to run.
        betas: variance schedule forwarded to p_sample.
        is_angle: a single bool applying to every feature, or one bool per
            feature (last dimension of ``noise``). ``None`` (the default) means
            ``[False, True, True, True]``.
        disable_pbar: hide the tqdm progress bar.

    Returns:
        CPU tensor of shape (timesteps, batch_size, seq_len, n_ft); entry k is
        the sample produced by the step at t = timesteps - 1 - k.

    Raises:
        ValueError: if ``is_angle`` is a sequence whose length differs from the
            number of features in ``noise``.
    """
    # None sentinel instead of a mutable list default shared across calls
    if is_angle is None:
        is_angle = [False, True, True, True]

    device = next(model.parameters()).device
    b = noise.shape[0]
    img = noise.to(device)
    # Report metrics on starting noise
    # amin and amax support reducing on multiple dimensions
    logging.info(
        f"Starting from noise {noise.shape} with angularity {is_angle} and range {torch.amin(img, dim=(0, 1))} - {torch.amax(img, dim=(0, 1))} using {device}"
    )

    # Resolve which features to wrap once, before spending any model evaluations
    if isinstance(is_angle, (bool, np.bool_)):
        wrap_everything = bool(is_angle)
        angular_cols: List[int] = []
    else:
        if len(is_angle) != img.shape[-1]:
            raise ValueError(
                f"is_angle has {len(is_angle)} entries but noise has {img.shape[-1]} features"
            )
        wrap_everything = False
        angular_cols = [j for j, flag in enumerate(is_angle) if flag]

    imgs = []
    for i in tqdm(
        reversed(range(0, timesteps)),
        desc="sampling loop time step",
        total=timesteps,
        disable=disable_pbar,
    ):
        # Shape is (batch, seq_len, n_ft)
        img = p_sample(
            model=model,
            x=img,
            t=torch.full((b,), i, device=device, dtype=torch.long),  # time vector
            seq_lens=lengths,
            t_index=i,
            betas=betas,
        )

        # Wrap angular features so the chain stays on the circle
        if wrap_everything:
            img = utils.modulo_with_wrapped_range(
                img, range_min=-torch.pi, range_max=torch.pi
            )
        for j in angular_cols:
            img[:, :, j] = utils.modulo_with_wrapped_range(
                img[:, :, j], range_min=-torch.pi, range_max=torch.pi
            )
        imgs.append(img.cpu())
    return torch.stack(imgs)
|
|
|
|
|
|
def sample(
    model: nn.Module,
    train_dset: dsets.NoisedAnglesDataset,
    n: int = 10,
    sweep_lengths: Optional[Tuple[int, int]] = (50, 128),
    batch_size: int = 512,
    feature_key: str = "angles",
    disable_pbar: bool = False,
    trim_to_length: bool = True,  # Trim padding regions to reduce memory
) -> List[np.ndarray]:
    """
    Generate samples from ``model``, using ``train_dset`` for noise and schedule.

    Which lengths get generated:
      * ``sweep_lengths=(lo, hi)``: ``n`` items for every length in
        ``range(lo, hi)`` (upper bound excluded); ``lo`` must be less than ``hi``.
      * ``sweep_lengths=None``: ``n`` lengths drawn via ``train_dset.sample_length()``.

    Items are denoised in chunks of at most ``batch_size`` with p_sample_loop.
    If the wrapped dataset exposes a non-None ``get_masked_means()``, that offset
    is added back to every sample and the angular features are re-wrapped into
    the -pi..pi range.

    ``train_dset`` must provide:
      - sample_noise, timesteps, alpha_beta_terms (from NoisedAnglesDataset)
      - feature_is_angular, pad (from the dataset wrapped by NoisedAnglesDataset)
      - optionally sample_length(), needed only when sweep_lengths is None

    Returns:
        One array per item, shape (timesteps, seq_len, n_features), trimmed to
        that item's own length.

    Raises:
        ValueError: if ``sweep_lengths`` is given and its minimum is not less
            than its maximum.
    """
    if sweep_lengths is None:
        lengths = [train_dset.sample_length() for _ in range(n)]
    else:
        sweep_min, sweep_max = sweep_lengths
        if not sweep_min < sweep_max:
            raise ValueError(
                f"Minimum length {sweep_min} must be less than maximum {sweep_max}"
            )
        logging.info(
            f"Sweeping from {sweep_min}-{sweep_max} with {n} examples at each length"
        )
        # n consecutive copies of each length, in increasing order
        lengths = [
            length for length in range(sweep_min, sweep_max) for _ in range(n)
        ]

    chunks = [
        lengths[start : start + batch_size]
        for start in range(0, len(lengths), batch_size)
    ]
    logging.info(f"Sampling {len(lengths)} items in batches of size {batch_size}")

    samples: List[np.ndarray] = []
    for chunk in chunks:
        # The noise template spans the full padded length (train_dset.pad)...
        blank = torch.zeros(
            (len(chunk), train_dset.pad, model.n_inputs), dtype=torch.float32
        )
        noise = train_dset.sample_noise(blank)
        if trim_to_length:
            # ...and is optionally cut to the longest item in this chunk
            noise = noise[:, : max(chunk), :]

        # (timesteps, batch_size, seq_len, n_ft)
        trajectories = p_sample_loop(
            model=model,
            lengths=chunk,
            noise=noise,
            timesteps=train_dset.timesteps,
            betas=train_dset.alpha_beta_terms["betas"],
            is_angle=train_dset.feature_is_angular[feature_key],
            disable_pbar=disable_pbar,
        )
        # Split per item and drop padding: each is (timesteps, seq_len, n_ft)
        samples.extend(
            trajectories[:, k, :length, :].numpy() for k, length in enumerate(chunk)
        )

    # get_masked_means is used rather than a raw means attribute because only
    # the features active in this dataset should be shifted
    has_offset = (
        hasattr(train_dset, "dset")
        and hasattr(train_dset.dset, "get_masked_means")
        and train_dset.dset.get_masked_means() is not None
    )
    if not has_offset:
        return samples

    logging.info(
        f"Shifting predicted values by original offset: {train_dset.dset.get_masked_means()}"
    )
    samples = [s + train_dset.dset.get_masked_means() for s in samples]
    # Adding the offset can push angles across the circle boundary; wrap again
    angular_idx = np.where(train_dset.feature_is_angular[feature_key])[0]
    for arr in samples:
        arr[..., angular_idx] = utils.modulo_with_wrapped_range(
            arr[..., angular_idx], range_min=-np.pi, range_max=np.pi
        )
    return samples
|
|
|
|
|
|
def sample_simple(
    model_dir: str, n: int = 10, sweep_lengths: Tuple[int, int] = (50, 128)
) -> List[pd.DataFrame]:
    """
    Load a trained model plus a data-free dataset stub and sample from it.

    Thin wrapper around ``sample``, primarily for gradio integration.

    Args:
        model_dir: local directory containing ``training_args.json`` and the
            model, or a HuggingFace Hub id that is downloaded first.
        n: number of samples per length in the sweep.
        sweep_lengths: ``(min, max)`` lengths to generate; ``max`` is excluded.

    Returns:
        One DataFrame per sample, holding the last step of that sample's
        trajectory, with the dataset's feature names as columns.
    """
    if utils.is_huggingface_hub_id(model_dir):
        model_dir = snapshot_download(model_dir)
    assert os.path.isdir(model_dir)

    with open(
        os.path.join(model_dir, "training_args.json"), encoding="utf-8"
    ) as source:
        training_args = json.load(source)

    model = modelling.BertForDiffusionBase.from_dir(model_dir)
    if torch.cuda.is_available():
        model = model.to("cuda:0")

    # Models trained on Cartesian coordinates keep their features under "coords".
    # This must inspect the recorded setting, not the whole args dict.
    # NOTE(review): assumes training stores the representation under
    # "angles_definitions" with value "cart-coords"; confirm against train.py.
    dset_key = (
        "coords"
        if training_args.get("angles_definitions") == "cart-coords"
        else "angles"
    )

    dummy_dset = dsets.AnglesEmptyDataset.from_dir(model_dir)
    dummy_noised_dset = dsets.NoisedAnglesDataset(
        dset=dummy_dset,
        dset_key=dset_key,
        timesteps=training_args["timesteps"],
        exhaustive_t=False,
        beta_schedule=training_args["variance_schedule"],
        nonangular_variance=1.0,
        angular_variance=training_args["variance_scale"],
    )

    sampled = sample(
        model,
        dummy_noised_dset,
        n=n,
        sweep_lengths=sweep_lengths,
        feature_key=dset_key,
        disable_pbar=True,
    )
    # Keep only the final denoising step of each trajectory
    return [
        pd.DataFrame(s[-1], columns=dummy_noised_dset.feature_names[dset_key])
        for s in sampled
    ]
|
|
|
|
|
|
def _score_angles(
    reconst_angles: pd.DataFrame, truth_angles: pd.DataFrame, truth_coords_pdb: str
) -> Tuple[float, float]:
    """
    Rebuild backbones from two angle tables and TM-score the reconstruction.

    Both angle sets are written as PDB files into a temporary directory
    (removed on exit) via ``ac.create_new_chain_nerf``. The reconstructed
    structure is then aligned with ``tmalign.run_tmalign`` against:
      1. the structure rebuilt from ``truth_angles``, and
      2. the existing coordinate file ``truth_coords_pdb``.

    Returns:
        ``(score vs. truth angles, score vs. truth coordinate file)``.
    """
    with tempfile.TemporaryDirectory() as workdir:
        scratch = Path(workdir)
        truth_pdb = ac.create_new_chain_nerf(str(scratch / "truth.pdb"), truth_angles)
        reconst_pdb = ac.create_new_chain_nerf(
            str(scratch / "reconst.pdb"), reconst_angles
        )
        # Both alignments must run while the temporary files still exist
        return (
            tmalign.run_tmalign(reconst_pdb, truth_pdb),
            tmalign.run_tmalign(reconst_pdb, truth_coords_pdb),
        )
|
|
|
|
|
|
@torch.no_grad()
def get_reconstruction_error(
    model: nn.Module, dset, noise_timesteps: int = 250, bs: int = 512
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Score how well the model undoes ``noise_timesteps`` steps of noise on every
    item of ``dset``.

    Each item is fetched with ``dset.__getitem__(idx, use_t_val=noise_timesteps)``.
    Its ``"corrupted"`` tensor is denoised with ``noise_timesteps`` reverse
    p_sample steps, wrapping with utils.modulo_with_wrapped_range after each.
    The reconstructions are then TM-scored in a process pool via _score_angles.

    Args:
        model: noise predictor; it is put into eval mode, and batches are placed
            on the device of its first parameter.
        dset: dataset whose ``__getitem__(idx, use_t_val=...)`` returns a dict of
            tensors including "corrupted", "lengths" and "angles". It must also
            expose ``filenames`` and ``alpha_beta_terms["betas"]``.
        noise_timesteps: timestep used to corrupt items, and the number of
            reverse steps run.
        bs: number of items denoised together.

    Returns:
        ``(scores, coord_scores)``: per-item TM-scores of the reconstruction
        against the structure rebuilt from the true angles, and against the
        item's file in ``dset.filenames``. Both are empty arrays when ``dset``
        is empty.
    """
    device = next(model.parameters()).device
    model.eval()

    recont_angle_sets: List[pd.DataFrame] = []
    truth_angle_sets: List[pd.DataFrame] = []
    truth_pdb_files = []
    for idx_batch in tqdm(utils.seq_to_groups(list(range(len(dset))), bs)):
        batch = default_collate(
            [
                {
                    k: v.to(device)
                    for k, v in dset.__getitem__(idx, use_t_val=noise_timesteps).items()
                }
                for idx in idx_batch
            ]
        )
        img = batch["corrupted"].clone()
        assert img.ndim == 3
        # Flatten instead of squeeze(): squeeze() turns a one-item batch into a
        # 0-d tensor, which cannot be iterated below
        lengths = batch["lengths"].reshape(-1)

        # Record the actual files containing raw coordinates
        truth_pdb_files.extend(dset.filenames[idx] for idx in idx_batch)

        # Run the diffusion model for noise_timesteps steps
        for t in tqdm(reversed(range(noise_timesteps)), total=noise_timesteps):
            img = p_sample(
                model=model,
                x=img,
                t=torch.full(
                    (len(idx_batch),), fill_value=t, dtype=torch.long, device=device
                ),
                seq_lens=lengths,
                t_index=t,
                betas=dset.alpha_beta_terms["betas"],
            )
            img = utils.modulo_with_wrapped_range(img)

        # Finished reconstruction, subset to lengths and add to running list
        for i, l in enumerate(lengths.tolist()):
            recont_angle_sets.append(
                pd.DataFrame(img[i, :l].cpu().numpy(), columns=ac.EXHAUSTIVE_ANGLES)
            )
            truth_angle_sets.append(
                pd.DataFrame(
                    batch["angles"][i, :l].cpu().numpy(), columns=ac.EXHAUSTIVE_ANGLES
                )
            )

    if not recont_angle_sets:
        # Nothing to score; unpacking zip(*[]) below would raise otherwise
        return np.array([]), np.array([])

    # Get the reconstruction error as a TM score
    n_procs = mp.cpu_count()
    logging.info(
        f"Calculating TM scores for reconstruction error with {n_procs} processes"
    )
    # The context manager tears the pool down even if scoring raises. starmap
    # has already collected every result by the time the block exits.
    with mp.Pool(processes=n_procs) as pool:
        results = pool.starmap(
            _score_angles,
            zip(recont_angle_sets, truth_angle_sets, truth_pdb_files),
            chunksize=10,
        )
    scores, coord_scores = zip(*results)
    return np.array(scores), np.array(coord_scores)
|
|
|
|
|
|
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # Smoke test: one sample of length 50 (the sweep's upper bound is excluded)
    # from the "wukevin/foldingdiff_cath" checkpoint
    sampled_dfs = sample_simple("wukevin/foldingdiff_cath", n=1, sweep_lengths=(50, 51))
    for df in sampled_dfs:
        print(df.shape)
        print(df)
|