# Mirror of https://github.com/Saoge123/PocketFlow.git
# Synced 2026-06-04 12:44:22 +08:00 (78 lines, 3.0 KiB, Python)
import os

import torch
from easydict import EasyDict

from pocket_flow.gdbp_model import PocketFlow
from pocket_flow.utils import Experiment, LoadDataset
from pocket_flow.utils.transform import *
from pocket_flow.utils.data import ComplexData, torchify_dict

# --- Process-wide runtime setup --------------------------------------------
# Select the 'file_system' strategy for sharing tensors between processes
# (e.g. DataLoader workers) instead of the platform default.
torch.multiprocessing.set_sharing_strategy('file_system')

# Both variables are read by the CUDA runtime when it is initialised, so they
# are set here, before any CUDA work is done.
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous: errors surface at
# the offending call, at the cost of throughput.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# Restrict the devices visible to this process to GPUs 0 and 1.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

# --- Model hyper-parameters -------------------------------------------------
# Each sub-network has its own EasyDict; all of them are nested into `config`
# under the keys 'encoder', 'focal_net', 'atom_flow', 'pos_predictor',
# 'pos_filter' and 'edge_flow'.
encoder_cfg = EasyDict({
    'edge_channels': 8,
    'num_interactions': 6,
    'num_heads': 4,
    'knn': 16,
    'cutoff': 10.0,
})

focal_net_cfg = EasyDict({
    'hidden_dim_sca': 64,
    'hidden_dim_vec': 8,
})

atom_flow_cfg = EasyDict({
    'hidden_dim_sca': 64,
    'hidden_dim_vec': 8,
    'num_flow_layers': 6,
})

pos_predictor_cfg = EasyDict({
    'num_filters': [64, 64],
    'n_component': 3,
})

pos_filter_cfg = EasyDict({
    'edge_channels': 8,
    'num_filters': [64, 16],
})

# NOTE(review): 'num_bond_types' is 3 here but 4 in the top-level `config`
# below -- confirm the difference is intentional.
edge_flow_cfg = EasyDict({
    'edge_channels': 16,
    'num_filters': [64, 8],
    'num_bond_types': 3,
    'num_heads': 4,
    'cutoff': 10.0,
    'num_flow_layers': 6,
})

# Top-level configuration combining global settings with the sub-configs.
config = EasyDict({
    'deq_coeff': 0.9,
    'hidden_channels': 64,
    'hidden_channels_vec': 16,
    'use_conv1d': False,
    'num_bond_types': 4,
    'bottleneck': (8, 2),
    'protein_atom_feature_dim': 27,
    'ligand_atom_feature_dim': 15,
    'num_atom_type': 9,
    'msg_annealing': True,
    'encoder': encoder_cfg,
    'atom_flow': atom_flow_cfg,
    'pos_predictor': pos_predictor_cfg,
    'edge_flow': edge_flow_cfg,
    'focal_net': focal_net_cfg,
    'pos_filter': pos_filter_cfg,
})

# --- Data transform pipeline ------------------------------------------------
# Atomic numbers passed to both the ligand featurizer and the focal masker
# (C, N, O, F, P, S, Cl, Br, I). Kept in one place so the two stay in sync;
# each consumer receives its own copy of the list.
_LIGAND_ATOMIC_NUMBERS = [6, 7, 8, 9, 15, 16, 17, 35, 53]

protein_featurizer = FeaturizeProteinAtom()
ligand_featurizer = FeaturizeLigandAtom(
    atomic_numbers=list(_LIGAND_ATOMIC_NUMBERS)
)
traj_fn = LigandTrajectory(perm_type='mix', num_atom_type=9)
focal_masker = FocalMaker(
    r=6.0, num_work=16, atomic_numbers=list(_LIGAND_ATOMIC_NUMBERS)
)
atom_composer = AtomComposer(
    knn=16,
    num_workers=16,
    graph_type='knn',
    radius=10.0,
    use_protein_bond=False,
)
# Bundle the trajectory, focal-mask and graph-composition steps into one
# transform.
combine = Combine(traj_fn, focal_masker, atom_composer, lig_only=True)

# Steps are listed in the order they are handed to TrajCompose.
# NOTE(review): `collate_fn` is not defined in this file; presumably it comes
# from the star import of pocket_flow.utils.transform -- confirm.
transform = TrajCompose([
    RefineData(),
    LigandCountNeighbors(),
    protein_featurizer,
    ligand_featurizer,
    combine,
    collate_fn,
])

# --- Dataset ----------------------------------------------------------------
# Open the pre-training .lmdb file; `transform` is passed to LoadDataset.
dataset = LoadDataset(
    'pretrain_data/ZINC_PretrainingDataset.lmdb',
    transform=transform,
)
print('Num data:', len(dataset))

# Split into training and validation subsets with val_num=10000; the fixed
# random_seed is presumably there to make the shuffled split reproducible.
train_set, valid_set = LoadDataset.split(
    dataset,
    val_num=10000,
    shuffle=True,
    random_seed=0,
)

# --- Model, optimizer and LR schedule ---------------------------------------
# Build the model from `config` and move it to the first visible GPU.
model = PocketFlow(config).to('cuda:0')
print(model.get_parameter_number())

# NOTE(review): beta1=0.99 differs from Adam's default of 0.9 -- confirm this
# is intentional.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=2.e-4,
    weight_decay=0,
    betas=(0.99, 0.999),
)
# Multiply the learning rate by 0.6 after 10 scheduler steps without
# improvement of the monitored metric, never going below 1e-5.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    factor=0.6,
    patience=10,
    min_lr=1.e-5,
)

# --- Training run -----------------------------------------------------------
# Wire model, data, optimizer and scheduler into the Experiment driver.
exp = Experiment(
    model,
    train_set,
    optimizer,
    valid_set=valid_set,
    scheduler=scheduler,
    device='cuda:0',
    data_parallel=False,
    use_amp=False,
)

# Start training. The first positional argument (1000000) is presumably the
# total number of training steps -- confirm against Experiment.fit_step.
exp.fit_step(
    1000000,
    valid_per_step=5000,
    train_batch_size=256,
    valid_batch_size=512,
    print_log=False,
    with_tb=True,
    logdir='./pretraining_log',
    schedule_key='loss',
    num_workers=16,
    pin_memory=False,
    follow_batch=[],
    exclude_keys=[],
    collate_fn=None,
    max_edge_num_in_batch=2000000,
)