mirror of
https://github.com/tsa87/cgflow.git
synced 2026-06-04 12:14:22 +08:00
46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
import os
|
|
import random
|
|
import time
|
|
|
|
import numpy as np
|
|
import torch
|
|
from tqdm import tqdm
|
|
|
|
from synthflow.config import Config, init_empty
|
|
from synthflow.pocket_specific.sampler import PocketSpecificSampler
|
|
|
|
|
|
def set_seed(seed: int):
    """Seed every random number generator this script relies on.

    The same ``seed`` is handed to Python's ``random`` module, NumPy's
    global generator, and PyTorch (the default generator plus the current
    and all CUDA devices). cuDNN is additionally switched to deterministic
    mode with autotuning (``benchmark``) turned off.

    Args:
        seed: Integer seed forwarded unchanged to each generator.
    """
    # The generators are independent of one another, so they can be seeded
    # in a single pass.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
|
|
|
|
|
|
if __name__ == "__main__":
    # Benchmark: average wall-clock sampling time per sample for checkpoints
    # stored under fm-{step}/seed-{seed}, reported per step value.

    # NOTE: Create sampler
    config = init_empty(Config())
    config.algo.num_from_policy = 64  # batch size
    config.algo.action_subsampling.sampling_ratio = 0.1
    device = "cuda"
    # Passed to sampler.sample() and used to normalise the elapsed time.
    num_samples = 128

    for step in [10, 20, 40, 60, 80, 100]:
        times = []
        for seed in tqdm([0, 1, 2], leave=False):
            # NOTE(review): the RNG seed is fixed to 0 for every checkpoint;
            # `seed` only selects which checkpoint is loaded. Confirm this is
            # intended rather than set_seed(seed).
            set_seed(0)
            ckpt_path = f"./logs/ana2-fm-nsteps/rebuttal-fm-nsteps/ALDH1/fm-{step}/seed-{seed}/model_state.pt"
            if not os.path.exists(ckpt_path):
                print("Checkpoint not found:", ckpt_path)
                continue
            sampler = PocketSpecificSampler(config, ckpt_path, device)

            # NOTE: Run
            # CUDA kernels launch asynchronously, so drain pending work before
            # and after the timed region to measure the full sampling cost.
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            st = time.perf_counter()  # monotonic, high-resolution clock
            sampler.sample(num_samples)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            end = time.perf_counter()
            times.append((end - st) / num_samples)

        if not times:
            # Every checkpoint for this step was missing; np.mean([]) would
            # only yield nan plus a RuntimeWarning.
            print(step, "no checkpoints found; skipping statistics")
            continue
        print(step, np.mean(times), np.std(times))
|