mirror of
https://github.com/RosettaCommons/foundry.git
synced 2026-06-04 13:24:22 +08:00
fix conflict
This commit is contained in:
@@ -19,12 +19,17 @@ compl_dir = "/projects/ml/RoseTTAComplex"
|
||||
na_dir = "/home/dimaio/TrRosetta/nucleic"
|
||||
fb_dir = "/projects/ml/TrRosetta/fb_af"
|
||||
mol_dir = "/projects/ml/ligand_datasets/mmcif_parse_wlig"
|
||||
|
||||
if not os.path.exists(base_dir):
|
||||
# training on blue
|
||||
base_dir = "/gscratch2/PDB-2021AUG02"
|
||||
compl_dir = "/gscratch2/RoseTTAComplex"
|
||||
na_dir = "/gscratch2/nucleic"
|
||||
fb_dir = "/gscratch2/fb_af1"
|
||||
# training on AWS
|
||||
#base_dir = "/gscratch2/PDB-2021AUG02"
|
||||
#compl_dir = "/gscratch2/RoseTTAComplex"
|
||||
#na_dir = "/gscratch2/nucleic"
|
||||
#fb_dir = "/gscratch2/fb_af1"
|
||||
base_dir = "/data/databases/PDB-2021AUG02"
|
||||
fb_dir = "/data/databases/fb_af"
|
||||
compl_dir = "/data/databases/RoseTTAComplex"
|
||||
mol_dir = "/home/rohith"
|
||||
|
||||
def set_data_loader_params(args):
|
||||
PARAMS = {
|
||||
|
||||
@@ -594,31 +594,31 @@ class Trainer():
|
||||
loader_complex, valid_neg,
|
||||
self.loader_param, negative=True
|
||||
)
|
||||
valid_na_compl_set = DatasetNAComplex(
|
||||
list(valid_na_compl.keys())[:self.n_valid_na_compl],
|
||||
loader_na_complex, valid_na_compl,
|
||||
self.loader_param, negative=False, native_NA_frac=1.0
|
||||
)
|
||||
valid_na_neg_set = DatasetNAComplex(
|
||||
list(valid_na_neg.keys())[:self.n_valid_na_neg],
|
||||
loader_na_complex, valid_na_neg,
|
||||
self.loader_param, negative=True, native_NA_frac=1.0
|
||||
)
|
||||
valid_na_from_scratch_compl_set = DatasetNAComplex(
|
||||
list(valid_na_compl.keys())[:self.n_valid_na_compl],
|
||||
loader_na_complex, valid_na_compl,
|
||||
self.loader_param, negative=False, native_NA_frac=0.0
|
||||
)
|
||||
valid_na_from_scratch_neg_set = DatasetNAComplex(
|
||||
list(valid_na_neg.keys())[:self.n_valid_na_neg],
|
||||
loader_na_complex, valid_na_neg,
|
||||
self.loader_param, negative=True, native_NA_frac=0.0
|
||||
)
|
||||
valid_rna_set = DatasetRNA(
|
||||
list(valid_rna.keys())[:self.n_valid_rna],
|
||||
loader_rna, valid_rna,
|
||||
self.loader_param
|
||||
)
|
||||
# valid_na_compl_set = DatasetNAComplex(
|
||||
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
|
||||
# loader_na_complex, valid_na_compl,
|
||||
# self.loader_param, negative=False, native_NA_frac=1.0
|
||||
# )
|
||||
# valid_na_neg_set = DatasetNAComplex(
|
||||
# list(valid_na_neg.keys())[:self.n_valid_na_neg],
|
||||
# loader_na_complex, valid_na_neg,
|
||||
# self.loader_param, negative=True, native_NA_frac=1.0
|
||||
# )
|
||||
# valid_na_from_scratch_compl_set = DatasetNAComplex(
|
||||
# list(valid_na_compl.keys())[:self.n_valid_na_compl],
|
||||
# loader_na_complex, valid_na_compl,
|
||||
# self.loader_param, negative=False, native_NA_frac=0.0
|
||||
# )
|
||||
# valid_na_from_scratch_neg_set = DatasetNAComplex(
|
||||
# list(valid_na_neg.keys())[:self.n_valid_na_neg],
|
||||
# loader_na_complex, valid_na_neg,
|
||||
# self.loader_param, negative=True, native_NA_frac=0.0
|
||||
# )
|
||||
# valid_rna_set = DatasetRNA(
|
||||
# list(valid_rna.keys())[:self.n_valid_rna],
|
||||
# loader_rna, valid_rna,
|
||||
# self.loader_param
|
||||
# )
|
||||
valid_sm_compl_set = DatasetSMComplex(
|
||||
list(valid_sm_compl.keys())[:self.n_valid_sm_compl],
|
||||
loader_sm_compl, valid_sm_compl,
|
||||
@@ -650,11 +650,11 @@ class Trainer():
|
||||
valid_homo_sampler = data.distributed.DistributedSampler(valid_homo_set, num_replicas=world_size, rank=rank)
|
||||
valid_compl_sampler = data.distributed.DistributedSampler(valid_compl_set, num_replicas=world_size, rank=rank)
|
||||
valid_neg_sampler = data.distributed.DistributedSampler(valid_neg_set, num_replicas=world_size, rank=rank)
|
||||
valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
|
||||
valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
|
||||
valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
|
||||
valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank)
|
||||
valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_compl_sampler = data.distributed.DistributedSampler(valid_na_compl_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_neg_sampler = data.distributed.DistributedSampler(valid_na_neg_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_from_scratch_compl_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_compl_set, num_replicas=world_size, rank=rank)
|
||||
# valid_na_from_scratch_neg_sampler = data.distributed.DistributedSampler(valid_na_from_scratch_neg_set, num_replicas=world_size, rank=rank)
|
||||
# valid_rna_sampler = data.distributed.DistributedSampler(valid_rna_set, num_replicas=world_size, rank=rank)
|
||||
valid_sm_compl_sampler = data.distributed.DistributedSampler(valid_sm_compl_set, num_replicas=world_size, rank=rank)
|
||||
|
||||
train_loader = data.DataLoader(train_set, sampler=train_sampler, batch_size=self.batch_size, **LOAD_PARAM)
|
||||
@@ -662,11 +662,11 @@ class Trainer():
|
||||
valid_homo_loader = data.DataLoader(valid_homo_set, sampler=valid_homo_sampler, **LOAD_PARAM)
|
||||
valid_compl_loader = data.DataLoader(valid_compl_set, sampler=valid_compl_sampler, **LOAD_PARAM)
|
||||
valid_neg_loader = data.DataLoader(valid_neg_set, sampler=valid_neg_sampler, **LOAD_PARAM)
|
||||
valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
|
||||
valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
|
||||
valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
|
||||
valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM)
|
||||
valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM)
|
||||
# valid_na_compl_loader = data.DataLoader(valid_na_compl_set, sampler=valid_na_compl_sampler, **LOAD_PARAM)
|
||||
# valid_na_neg_loader = data.DataLoader(valid_na_neg_set, sampler=valid_na_neg_sampler, **LOAD_PARAM)
|
||||
# valid_na_from_scratch_compl_loader = data.DataLoader(valid_na_from_scratch_compl_set, sampler=valid_na_from_scratch_compl_sampler, **LOAD_PARAM)
|
||||
# valid_na_from_scratch_neg_loader = data.DataLoader(valid_na_from_scratch_neg_set, sampler=valid_na_from_scratch_neg_sampler, **LOAD_PARAM)
|
||||
# valid_rna_loader = data.DataLoader(valid_rna_set, sampler=valid_rna_sampler, **LOAD_PARAM)
|
||||
valid_sm_compl_loader = data.DataLoader(valid_sm_compl_set, sampler=valid_sm_compl_sampler, **LOAD_PARAM)
|
||||
|
||||
# move some global data to cuda device
|
||||
|
||||
Reference in New Issue
Block a user