# Files
# cgflow/experiments/scripts/analysis2_sampling.py
# 2025-07-15 20:48:43 -04:00
#
# 46 lines
# 1.3 KiB
# Python

import os
import random
import time
import numpy as np
import torch
from tqdm import tqdm
from synthflow.config import Config, init_empty
from synthflow.pocket_specific.sampler import PocketSpecificSampler
def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN flags.

    Args:
        seed: Value passed to every random number generator seeded here.
    """
    # PyTorch generators (default generator, current CUDA device, all CUDA devices).
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # NumPy's global generator and Python's built-in `random` module.
    np.random.seed(seed)
    random.seed(seed)

    # Request deterministic cuDNN kernels and turn off cuDNN autotuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
if __name__ == "__main__":
    # Benchmark script: for each `fm-{step}` checkpoint group, time one call to
    # `sampler.sample(...)` per available seed checkpoint and print the mean
    # and standard deviation of the per-sample wall time.

    # Number of samples requested per timed call; also the divisor used to
    # turn the elapsed time into a per-sample figure.
    num_samples = 128

    # NOTE: Create sampler
    config = init_empty(Config())
    config.algo.num_from_policy = 64  # batch size
    config.algo.action_subsampling.sampling_ratio = 0.1
    device = "cuda"

    # CUDA kernels are launched asynchronously, so the device is synchronized
    # around the timed region to make sure queued GPU work is included.
    sync_cuda = device.startswith("cuda") and torch.cuda.is_available()

    for step in [10, 20, 40, 60, 80, 100]:
        times = []
        for seed in tqdm([0, 1, 2], leave=False):
            # NOTE(review): the RNG seed is fixed at 0 for every checkpoint,
            # while `seed` only selects the checkpoint directory. Confirm this
            # is intended rather than `set_seed(seed)`.
            set_seed(0)
            ckpt_path = f"./logs/ana2-fm-nsteps/rebuttal-fm-nsteps/ALDH1/fm-{step}/seed-{seed}/model_state.pt"
            if not os.path.exists(ckpt_path):
                print("Checkpoint not found:", ckpt_path)
                continue
            sampler = PocketSpecificSampler(config, ckpt_path, device)

            # NOTE: Run
            if sync_cuda:
                torch.cuda.synchronize()
            # perf_counter is monotonic and high-resolution, unlike time.time().
            st = time.perf_counter()
            sampler.sample(num_samples)
            if sync_cuda:
                torch.cuda.synchronize()
            end = time.perf_counter()
            times.append((end - st) / num_samples)

        if not times:
            # Every checkpoint for this step was missing; np.mean/np.std on an
            # empty list would only yield nan plus RuntimeWarnings.
            print(step, "no checkpoints found; skipping timing summary")
            continue
        print(step, np.mean(times), np.std(times))