diff --git a/RF2_allatom/data_loader.py b/RF2_allatom/data_loader.py index 271e8f0..f85f6bf 100644 --- a/RF2_allatom/data_loader.py +++ b/RF2_allatom/data_loader.py @@ -19,12 +19,17 @@ compl_dir = "/projects/ml/RoseTTAComplex" na_dir = "/home/dimaio/TrRosetta/nucleic" fb_dir = "/projects/ml/TrRosetta/fb_af" mol_dir = "/projects/ml/ligand_datasets/mmcif_parse_wlig" + if not os.path.exists(base_dir): - # training on blue - base_dir = "/gscratch2/PDB-2021AUG02" - compl_dir = "/gscratch2/RoseTTAComplex" - na_dir = "/gscratch2/nucleic" - fb_dir = "/gscratch2/fb_af1" + # training on AWS + #base_dir = "/gscratch2/PDB-2021AUG02" + #compl_dir = "/gscratch2/RoseTTAComplex" + #na_dir = "/gscratch2/nucleic" + #fb_dir = "/gscratch2/fb_af1" + base_dir = "/data/databases/PDB-2021AUG02" + fb_dir = "/data/databases/fb_af" + compl_dir = "/data/databases/RoseTTAComplex" + mol_dir = "/home/rohith" def set_data_loader_params(args): PARAMS = { diff --git a/RF2_allatom/train_multi_EMA.py b/RF2_allatom/train_multi_EMA.py index e30b58d..257bbf4 100644 --- a/RF2_allatom/train_multi_EMA.py +++ b/RF2_allatom/train_multi_EMA.py @@ -594,31 +594,31 @@ class Trainer(): loader_complex, valid_neg, self.loader_param, negative=True ) - valid_na_compl_set = DatasetNAComplex( - list(valid_na_compl.keys())[:self.n_valid_na_compl], - loader_na_complex, valid_na_compl, - self.loader_param, negative=False, native_NA_frac=1.0 - ) - valid_na_neg_set = DatasetNAComplex( - list(valid_na_neg.keys())[:self.n_valid_na_neg], - loader_na_complex, valid_na_neg, - self.loader_param, negative=True, native_NA_frac=1.0 - ) - valid_na_from_scratch_compl_set = DatasetNAComplex( - list(valid_na_compl.keys())[:self.n_valid_na_compl], - loader_na_complex, valid_na_compl, - self.loader_param, negative=False, native_NA_frac=0.0 - ) - valid_na_from_scratch_neg_set = DatasetNAComplex( - list(valid_na_neg.keys())[:self.n_valid_na_neg], - loader_na_complex, valid_na_neg, - self.loader_param, negative=True, native_NA_frac=0.0 - ) - valid_rna_set = DatasetRNA( - list(valid_rna.keys())[:self.n_valid_rna], - loader_rna, valid_rna, - self.loader_param - ) +# valid_na_compl_set = DatasetNAComplex( +# list(valid_na_compl.keys())[:self.n_valid_na_compl], +# loader_na_complex, valid_na_compl, +# self.loader_param, negative=False, native_NA_frac=1.0 +# ) +# valid_na_neg_set = DatasetNAComplex( +# list(valid_na_neg.keys())[:self.n_valid_na_neg], +# loader_na_complex, valid_na_neg, +# self.loader_param, negative=True, native_NA_frac=1.0 +# ) +# valid_na_from_scratch_compl_set = DatasetNAComplex( +# list(valid_na_compl.keys())[:self.n_valid_na_compl], +# loader_na_complex, valid_na_compl, +# self.loader_param, negative=False, native_NA_frac=0.0 +# ) +# valid_na_from_scratch_neg_set = DatasetNAComplex( +# list(valid_na_neg.keys())[:self.n_valid_na_neg], +# loader_na_complex, valid_na_neg, +# self.loader_param, negative=True, native_NA_frac=0.0 +# ) +# valid_rna_set = DatasetRNA( +# list(valid_rna.keys())[:self.n_valid_rna], +# loader_rna, valid_rna, +# self.loader_param +# ) valid_sm_compl_set = DatasetSMComplex( list(valid_sm_compl.keys())[:self.n_valid_sm_compl], loader_sm_compl, valid_sm_compl, @@ -650,11 +650,11 @@ class Trainer(): valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank) valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank) valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank) - valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank) - valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank) - valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank) - valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank) - valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank) +# valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank) +# valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank) +# valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank) +# valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank) +# valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank) valid_sm_compl_sampler = data.distributed.DistributedSampler(valid_sm_compl_set, num_replicas=world_size, rank=rank) train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM) @@ -662,11 +662,11 @@ class Trainer(): valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM) valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM) valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM) - valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM) - valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM) - valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM) - valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM) - valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM) +# valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM) +# valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM) +# valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM) +# valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM) +# valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM) valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set, sampler=valid_sm_compl_sampler, **LOAD_PARAM) # move some global data to cuda device