Files
foldingdiff/bin/sample.py
2022-07-25 18:19:47 +00:00

105 lines
3.3 KiB
Python

"""
Script to sample from a trained diffusion model
"""
import os, sys
import logging
import json
from pathlib import Path
from typing import List, Optional
import torch
from torch.nn import functional as F
from transformers import BertConfig
# Resolve the "protdiff" directory that sits next to this script's own
# directory (../protdiff relative to this file) to an absolute path.
SRC_DIR = (Path(os.path.dirname(os.path.abspath(__file__))) / "../protdiff").resolve()
# Fail fast at import time if that directory does not exist.
assert SRC_DIR.is_dir()
# Make the directory importable; presumably the bare module imports that
# follow (modelling, beta_schedules, sampling, utils) live there -- TODO confirm.
sys.path.append(str(SRC_DIR))
import modelling
import beta_schedules
import sampling
import utils
def sample(
    num: int,
    dset_obj,
    model_path: str,
    config_json: Optional[str] = None,
    seed: int = 6489,
) -> List[torch.Tensor]:
    """
    Sample from the given model

    The checkpoint at ``model_path`` is loaded, the variance schedule is
    rebuilt from the training config, and ``sampling.sample`` is called in
    batches of at most 512 until ``num`` samples have been requested.

    Args:
        num: Total number of samples to draw; split into batches via
            ``utils.num_to_groups(num, 512)``.
        dset_obj: Object exposing a ``sample_length()`` method. It is called
            once per requested sample to choose that sample's sequence length.
        model_path: Path to the checkpoint handed to
            ``BertForDiffusion.load_from_checkpoint``.
        config_json: Path to a JSON file containing at least the keys
            ``"variance_schedule"`` and ``"timesteps"``. When ``None``, a file
            named ``config.json`` in the same directory as ``model_path`` is
            used.
        seed: Value passed to ``torch.manual_seed`` right before sampling.

    Returns:
        A flat list built by extending with each batch returned from
        ``sampling.sample``.

    Raises:
        AssertionError: If ``dset_obj`` has no ``sample_length`` attribute or
            the config JSON file does not exist.
    """
    assert hasattr(
        dset_obj, "sample_length"
    ), "Passed dataset object must have a sample_length attribute"

    # Load in the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): hidden_size and position_embedding_type are hard-coded here;
    # presumably they must match the values the checkpoint was trained with.
    cfg = BertConfig(hidden_size=144, position_embedding_type="relative_key_query")
    model = modelling.BertForDiffusion.load_from_checkpoint(
        checkpoint_path=model_path, config=cfg
    )
    model.eval()
    model.to(device)

    # Reproduce the variance schedules based on the config json
    if config_json is None:
        # Try to find a default config next to the checkpoint file.
        # dirname() takes exactly one argument; the file name is joined after.
        config_json = os.path.join(os.path.dirname(model_path), "config.json")
    assert os.path.isfile(
        config_json
    ), f"Could not automatically find config at {config_json}"
    with open(config_json) as source:
        model_config = json.load(source)

    betas = beta_schedules.get_variance_schedule(
        model_config["variance_schedule"], model_config["timesteps"]
    )

    # Calculate posterior variance
    alphas = 1.0 - betas
    # corresponds to bar alpha, product up till t of the first t 1-B terms
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Shift right by one step, filling the first slot with 1.0, so entry t
    # holds the cumulative product up to t-1.
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
    # Posterior variance, higher variance with greater t. The first entry is
    # zero because alphas_cumprod_prev starts at 1.0.
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # Sample
    # batch 128 ~ 9GB GPU memory, batch 512 ~ 38GB GPU memory
    torch.manual_seed(seed)
    samps = []
    for bs in utils.num_to_groups(num, 512):
        # One length per sample in this batch, drawn from the dataset object.
        seq_lens = [dset_obj.sample_length() for _ in range(bs)]
        s = sampling.sample(
            model,
            seq_lens=seq_lens,
            seq_max_len=model.config.max_position_embeddings,
            betas=betas,
            posterior_variance=posterior_variance,
            timesteps=model_config["timesteps"],
            batch_size=bs,
        )
        samps.extend(s)
    # Results are returned as a flat list and deliberately not stacked into a
    # single tensor -- presumably because per-sample lengths differ; TODO confirm.
    return samps
def main():
    """
    Draw 10 samples from a hard-coded checkpoint and print each one's shape.

    The checkpoint and training-args paths are fixed, machine-specific
    locations; the dataset supplying sample lengths is the toy train split of
    ``CathConsecutiveAnglesDataset``.
    """
    # Imported here rather than at module level, as in the original script.
    import datasets

    checkpoint_file = "/home/t-kevinwu/projects/protein_diffusion/models_initial/1000_timesteps_linear_variance_schedule_64_batch_size_0.0001_lr_0.5_gradient_clip/lightning_logs/version_0/checkpoints/epoch=9-step=1990.ckpt"
    training_args_file = "/home/t-kevinwu/projects/protein_diffusion/models_initial/1000_timesteps_linear_variance_schedule_64_batch_size_0.0001_lr_0.5_gradient_clip/training_args.json"

    length_source = datasets.CathConsecutiveAnglesDataset(split="train", toy=True)
    generated = sample(
        num=10,
        dset_obj=length_source,
        model_path=checkpoint_file,
        config_json=training_args_file,
    )
    for tensor in generated:
        print(tensor.shape)
# Script entry point: only runs when the file is executed directly, not on import.
if __name__ == "__main__":
    # Configure the root logger to emit INFO-level (and above) messages.
    logging.basicConfig(level=logging.INFO)
    main()