mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
latest training code
This commit is contained in:
@@ -1604,11 +1604,7 @@ class DatasetSMComplex(data.Dataset):
|
||||
def __getitem__(self, index):
|
||||
ID = self.IDs[index]
|
||||
sel_idx = np.random.randint(0, len(self.item_dict[ID]))
|
||||
# remove pdbs with BeF2 ligands, oddly behaved with rdkit
|
||||
item = self.item_dict[ID][sel_idx][0]
|
||||
while item[0] in ["1xhf", "1l5y", "4ukd"]:
|
||||
sel_idx = np.random.randint(0, len(self.item_dict[ID]))
|
||||
item = self.item_dict[ID][sel_idx][0]
|
||||
|
||||
out = self.loader(
|
||||
self.item_dict[ID][sel_idx][0],
|
||||
self.item_dict[ID][sel_idx][2],
|
||||
|
||||
@@ -22,7 +22,7 @@ from scheduler import get_linear_schedule_with_warmup, get_stepwise_decay_schedu
|
||||
import torch.distributed as dist
|
||||
import torch.multiprocessing as mp
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
#torch.autograd.set_detect_anomaly(True)
|
||||
torch.autograd.set_detect_anomaly(True)
|
||||
torch.manual_seed(5924)
|
||||
torch.backends.cudnn.benchmark = False
|
||||
torch.backends.cudnn.deterministic = True
|
||||
@@ -733,7 +733,7 @@ class Trainer():
|
||||
|
||||
# load model
|
||||
loaded_epoch, best_valid_loss = self.load_model(ddp_model, optimizer, scheduler, scaler,
|
||||
self.model_name, gpu, resume_train=True)
|
||||
self.model_name, gpu, suffix="best", resume_train=True)
|
||||
|
||||
if (self.eval):
|
||||
# run protein/NA prediction (TEMPLATED)
|
||||
@@ -742,13 +742,13 @@ class Trainer():
|
||||
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
|
||||
# run protein/NA prediction (NON-TEMPLATED)
|
||||
_, _, _ = self.valid_ppi_cycle(
|
||||
ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
|
||||
rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
#_, _, _ = self.valid_ppi_cycle(
|
||||
# ddp_model, valid_na_from_scratch_compl_loader, valid_na_from_scratch_neg_loader,
|
||||
# rank, gpu, world_size, 0, header="NA", report_interface=False, verbose=True)
|
||||
|
||||
# run RNA prediction
|
||||
#_,_,_ = self.valid_pdb_cycle(ddp_model, valid_rna_loader, rank, gpu, world_size, 0, verbose=True)
|
||||
|
||||
_, _, _ = self.valid_pdb_cycle(ddp_model, valid_sm_compl_loader, rank, gpu, world_size, 0, verbose=True)
|
||||
dist.destroy_process_group()
|
||||
return
|
||||
|
||||
@@ -816,7 +816,7 @@ class Trainer():
|
||||
'valid_loss': valid_loss,
|
||||
'valid_acc': valid_acc,
|
||||
'best_loss': best_valid_loss},
|
||||
self.checkpoint_fn(self.model_name, 'last'))
|
||||
self.checkpoint_fn(self.model_name, str(epoch)))
|
||||
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user