mirror of
https://github.com/tsa87/cgflow.git
synced 2026-06-04 12:14:22 +08:00
29 lines
1.0 KiB
Python
29 lines
1.0 KiB
Python
import wandb
|
|
|
|
from synthflow.config import Config, init_empty
|
|
from synthflow.pocket_conditional.trainer_unidock import UniDock_MultiPocket_Trainer
|
|
|
|
if __name__ == "__main__":
|
|
"""Example of how this trainer can be run"""
|
|
wandb.init(project="cgflow-update", group="sbdd-unidock")
|
|
|
|
config = init_empty(Config())
|
|
config.env_dir = "./data/envs/stock"
|
|
config.log_dir = "./logs/rebuttal-multipocket/sbdd_unidock"
|
|
config.print_every = 10
|
|
config.checkpoint_every = 500
|
|
config.store_all_checkpoints = True
|
|
|
|
config.algo.train_random_action_prob = 0.1
|
|
config.algo.action_subsampling.sampling_ratio = 0.1 # stock
|
|
|
|
config.cgflow.ckpt_path = "../weights/crossdocked2020_till_end.ckpt"
|
|
config.cgflow.use_predicted_pose = True
|
|
config.cgflow.num_inference_steps = 100
|
|
|
|
config.task.pocket_conditional.pocket_dir = "./data/CrossDocked2020/"
|
|
config.task.pocket_conditional.train_key = "./data/CrossDocked2020/train_keys.csv"
|
|
|
|
trainer = UniDock_MultiPocket_Trainer(config)
|
|
trainer.run()
|