fix conflict

This commit is contained in:
Rohith Krishna
2022-07-18 21:14:37 -07:00
2 changed files with 45 additions and 40 deletions

View File

@@ -19,12 +19,17 @@ compl_dir = "/projects/ml/RoseTTAComplex"
na_dir = "/home/dimaio/TrRosetta/nucleic"
fb_dir = "/projects/ml/TrRosetta/fb_af"
mol_dir = "/projects/ml/ligand_datasets/mmcif_parse_wlig"
if not os.path.exists(base_dir):
# training on blue
base_dir = "/gscratch2/PDB-2021AUG02"
compl_dir = "/gscratch2/RoseTTAComplex"
na_dir = "/gscratch2/nucleic"
fb_dir = "/gscratch2/fb_af1"
# training on AWS
#base_dir = "/gscratch2/PDB-2021AUG02"
#compl_dir = "/gscratch2/RoseTTAComplex"
#na_dir = "/gscratch2/nucleic"
#fb_dir = "/gscratch2/fb_af1"
base_dir = "/data/databases/PDB-2021AUG02"
fb_dir = "/data/databases/fb_af"
compl_dir = "/data/databases/RoseTTAComplex"
mol_dir = "/home/rohith"
def set_data_loader_params(args):
PARAMS = {

View File

@@ -594,31 +594,31 @@ class Trainer():
loader_complex, valid_neg,
self.loader_param, negative=True
)
valid_na_compl_set = DatasetNAComplex(
list(valid_na_compl.keys())[:self.n_valid_na_compl],
loader_na_complex, valid_na_compl,
self.loader_param, negative=False, native_NA_frac=1.0
)
valid_na_neg_set = DatasetNAComplex(
list(valid_na_neg.keys())[:self.n_valid_na_neg],
loader_na_complex, valid_na_neg,
self.loader_param, negative=True, native_NA_frac=1.0
)
valid_na_from_scratch_compl_set = DatasetNAComplex(
list(valid_na_compl.keys())[:self.n_valid_na_compl],
loader_na_complex, valid_na_compl,
self.loader_param, negative=False, native_NA_frac=0.0
)
valid_na_from_scratch_neg_set = DatasetNAComplex(
list(valid_na_neg.keys())[:self.n_valid_na_neg],
loader_na_complex, valid_na_neg,
self.loader_param, negative=True, native_NA_frac=0.0
)
valid_rna_set = DatasetRNA(
list(valid_rna.keys())[:self.n_valid_rna],
loader_rna, valid_rna,
self.loader_param
)
# valid_na_compl_set = DatasetNAComplex(
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
# loader_na_complex, valid_na_compl,
# self.loader_param, negative=False, native_NA_frac=1.0
# )
# valid_na_neg_set = DatasetNAComplex(
# list(valid_na_neg.keys())[:self.n_valid_na_neg],
# loader_na_complex, valid_na_neg,
# self.loader_param, negative=True, native_NA_frac=1.0
# )
# valid_na_from_scratch_compl_set = DatasetNAComplex(
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
# loader_na_complex, valid_na_compl,
# self.loader_param, negative=False, native_NA_frac=0.0
# )
# valid_na_from_scratch_neg_set = DatasetNAComplex(
# list(valid_na_neg.keys())[:self.n_valid_na_neg],
# loader_na_complex, valid_na_neg,
# self.loader_param, negative=True, native_NA_frac=0.0
# )
# valid_rna_set = DatasetRNA(
# list(valid_rna.keys())[:self.n_valid_rna],
# loader_rna, valid_rna,
# self.loader_param
# )
valid_sm_compl_set = DatasetSMComplex(
list(valid_sm_compl.keys())[:self.n_valid_sm_compl],
loader_sm_compl, valid_sm_compl,
@@ -650,11 +650,11 @@ class Trainer():
valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank)
valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank)
# valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
# valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
# valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
# valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank)
# valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank)
valid_sm_compl_sampler = data.distributed.DistributedSampler(valid_sm_compl_set, num_replicas=world_size, rank=rank)
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)
@@ -662,11 +662,11 @@ class Trainer():
valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM)
valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM)
# valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
# valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
# valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
# valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM)
# valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM)
valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set, sampler=valid_sm_compl_sampler, **LOAD_PARAM)
# move some global data to cuda device