mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-05 22:37:33 +08:00
105 lines
3.3 KiB
Python
105 lines
3.3 KiB
Python
"""
|
|
Script to sample from a trained diffusion model
|
|
"""
|
|
import os, sys
|
|
import logging
|
|
import json
|
|
from pathlib import Path
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch.nn import functional as F
|
|
|
|
from transformers import BertConfig
|
|
|
|
# Resolve the ``protdiff`` directory that sits one level up from this script's
# own directory and append it to the import search path, so that the plain
# ``import`` statements following this block can be resolved from there.
SRC_DIR = Path(os.path.abspath(__file__)).parent.joinpath("../protdiff").resolve()
assert SRC_DIR.is_dir()
sys.path.append(str(SRC_DIR))
|
|
import modelling
|
|
import beta_schedules
|
|
import sampling
|
|
import utils
|
|
|
|
|
|
def sample(
    num: int,
    dset_obj,
    model_path: str,
    config_json: Optional[str] = None,
    seed: int = 6489,
    batch_size: int = 512,
) -> List[torch.Tensor]:
    """
    Sample from the given model

    Loads a ``BertForDiffusion`` checkpoint, rebuilds the variance schedule
    described by the training config JSON, derives the posterior variance
    from it, and then draws samples batch by batch.

    Args:
        num: Total number of samples requested; it is split into batches by
            ``utils.num_to_groups``.
        dset_obj: Object with a ``sample_length()`` method. The method is
            called once per sample to pick that sample's sequence length.
        model_path: Path of the checkpoint handed to ``load_from_checkpoint``.
        config_json: Path of a JSON file providing the ``"variance_schedule"``
            and ``"timesteps"`` keys. When ``None``, ``config.json`` in the
            directory containing ``model_path`` is used.
        seed: Value passed to ``torch.manual_seed`` right before sampling.
        batch_size: Second argument to ``utils.num_to_groups`` (presumably the
            maximum number of samples per batch -- confirm against ``utils``).
            Defaults to 512, the value that used to be hard-coded.

    Returns:
        A list holding everything returned by ``sampling.sample`` across all
        batches, in batch order.

    Raises:
        AssertionError: If ``dset_obj`` has no ``sample_length`` attribute, or
            if ``config_json`` was not given and the default config file does
            not exist.
    """
    assert hasattr(
        dset_obj, "sample_length"
    ), "Passed dataset object must have a sample_length attribute"

    # Load in the model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cfg = BertConfig(hidden_size=144, position_embedding_type="relative_key_query")
    model = modelling.BertForDiffusion.load_from_checkpoint(
        checkpoint_path=model_path, config=cfg
    )
    model.eval()
    model.to(device)

    # Reproduce the variance schedules based on the config json
    if config_json is None:
        # Try to find a default config next to the checkpoint. dirname() takes
        # only the path; the file name is joined onto its result afterwards.
        config_json = os.path.join(os.path.dirname(model_path), "config.json")
        assert os.path.isfile(
            config_json
        ), f"Could not automatically find config at {config_json}"
    with open(config_json) as source:
        model_config = json.load(source)
    betas = beta_schedules.get_variance_schedule(
        model_config["variance_schedule"], model_config["timesteps"]
    )

    # Calculate posterior variance
    alphas = 1.0 - betas
    # corresponds to bar alpha, product up till t of the first t 1-B terms
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    # Same sequence shifted right by one step, with 1.0 filling the first slot
    alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)

    # Posterior variance, higher variance with greater t
    posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

    # Sample
    # batch 128 ~ 9GB GPU memory, batch 512 ~ 38GB GPU memory
    torch.manual_seed(seed)
    samps = []
    for bs in utils.num_to_groups(num, batch_size):
        # One independently drawn sequence length per sample in this batch
        seq_lens = [dset_obj.sample_length() for _ in range(bs)]
        s = sampling.sample(
            model,
            seq_lens=seq_lens,
            seq_max_len=model.config.max_position_embeddings,
            betas=betas,
            posterior_variance=posterior_variance,
            timesteps=model_config["timesteps"],
            batch_size=bs,
        )
        samps.extend(s)
    return samps
|
|
|
|
|
|
def main():
    """
    Draw ten samples from a hard-coded checkpoint and print each one's shape
    """
    # Imported here rather than at module level; presumably this is the
    # project's own ``datasets`` module found via SRC_DIR -- TODO confirm.
    import datasets

    checkpoint_path = "/home/t-kevinwu/projects/protein_diffusion/models_initial/1000_timesteps_linear_variance_schedule_64_batch_size_0.0001_lr_0.5_gradient_clip/lightning_logs/version_0/checkpoints/epoch=9-step=1990.ckpt"
    training_args_path = "/home/t-kevinwu/projects/protein_diffusion/models_initial/1000_timesteps_linear_variance_schedule_64_batch_size_0.0001_lr_0.5_gradient_clip/training_args.json"

    # The dataset object is only used by sample() as a source of sequence lengths
    length_source = datasets.CathConsecutiveAnglesDataset(split="train", toy=True)
    for sampled in sample(10, length_source, checkpoint_path, training_args_path):
        print(sampled.shape)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed only when the file is run as a script: configure the root
    # logger to show INFO-level records, then run the sampling demo.
    logging.basicConfig(level=logging.INFO)
    main()
|