Support for specifying dataset in training script

This commit is contained in:
Kevin Wu
2022-12-08 11:04:36 -08:00
parent b3c87e0206
commit 09f9f15e37
2 changed files with 5 additions and 1 deletions

View File

@@ -109,6 +109,7 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
def get_train_valid_test_sets(
dataset_key:str = "cath",
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
max_seq_len: int = 512,
min_seq_len: int = 0,
@@ -144,6 +145,7 @@ def get_train_valid_test_sets(
logging.info(f"Creating data splits: {splits}")
clean_dsets = [
clean_dset_class(
pdbs=dataset_key,
split=s,
pad=max_seq_len,
min_length=min_seq_len,
@@ -286,6 +288,7 @@ def train(
# Controls output
results_dir: str = "./results",
# Controls data loading and noising process
dataset_key: str = "cath", # cath, alhpafold, or a directory containing pdb files
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
max_seq_len: int = 512,
min_seq_len: int = 0, # 0 means no filtering based on min sequence length
@@ -340,6 +343,7 @@ def train(
# Get datasets and wrap them in dataloaders
dsets = get_train_valid_test_sets(
dataset_key=dataset_key,
angles_definitions=angles_definitions,
max_seq_len=max_seq_len,
min_seq_len=min_seq_len,

View File

@@ -246,7 +246,7 @@ class CathCanonicalAnglesDataset(Dataset):
fnames = glob.glob(os.path.join(CATH_DIR, "dompdb", "*"))
assert fnames, f"No files found in {CATH_DIR}/dompdb"
elif pdbs == "alphafold":
pdbs = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz"))
fnames = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz"))
assert fnames, f"No files found in {ALPHAFOLD_DIR}"
else:
raise ValueError(f"Unknown pdb set: {pdbs}")