latest training code

This commit is contained in:
Rohith
2022-07-29 16:10:14 -07:00
parent f7679cb34c
commit 6eeeb900c6
2 changed files with 8 additions and 12 deletions

View File

@@ -1604,11 +1604,7 @@ class DatasetSMComplex(data.Dataset):
def __getitem__(self, index):
ID = self.IDs[index]
sel_idx = np.random.randint(0, len(self.item_dict[ID]))
# skip PDB entries with BeF2 ligands (re-sample instead); they behave oddly with RDKit
item = self.item_dict[ID][sel_idx][0]
while item[0] in ["1xhf", "1l5y", "4ukd"]:
sel_idx = np.random.randint(0, len(self.item_dict[ID]))
item = self.item_dict[ID][sel_idx][0]
out = self.loader(
self.item_dict[ID][sel_idx][0],
self.item_dict[ID][sel_idx][2],

View File

@@ -22,7 +22,7 @@ from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedu
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
#torch.autograd.set_detect_anomaly(True)
torch.autograd.set_detect_anomaly(True)
torch.manual_seed(5924)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
@@ -733,7 +733,7 @@ class Trainer():
# load model
loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler,
self.model_name, gpu, resume_train=True)
self.model_name, gpu, suffix="best", resume_train=True)
if (self.eval):
# run protein/NA prediction (TEMPLATED)
@@ -742,13 +742,13 @@ class Trainer():
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
# run protein/NA prediction (NON-TEMPLATED)
_, _, _ = self.valid_ppi_cycle(
ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
#_, _, _ = self.valid_ppi_cycle(
# ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
# run RNA prediction
#_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True)
_, _, _ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, 0, verbose=True)
dist.destroy_process_group()
return
@@ -816,7 +816,7 @@ class Trainer():
'valid_loss': valid_loss,
'valid_acc': valid_acc,
'best_loss': best_valid_loss},
self.checkpoint_fn(self.model_name, 'last'))
self.checkpoint_fn(self.model_name, str(epoch)))
dist.destroy_process_group()