diff --git a/bin/train.py b/bin/train.py index a4fbb31..003392d 100644 --- a/bin/train.py +++ b/bin/train.py @@ -109,6 +109,7 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None: def get_train_valid_test_sets( + dataset_key:str = "cath", angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles", max_seq_len: int = 512, min_seq_len: int = 0, @@ -144,6 +145,7 @@ def get_train_valid_test_sets( logging.info(f"Creating data splits: {splits}") clean_dsets = [ clean_dset_class( + pdbs=dataset_key, split=s, pad=max_seq_len, min_length=min_seq_len, @@ -286,6 +288,7 @@ def train( # Controls output results_dir: str = "./results", # Controls data loading and noising process + dataset_key: str = "cath", # cath, alphafold, or a directory containing pdb files angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles", max_seq_len: int = 512, min_seq_len: int = 0, # 0 means no filtering based on min sequence length @@ -340,6 +343,7 @@ def train( # Get datasets and wrap them in dataloaders dsets = get_train_valid_test_sets( + dataset_key=dataset_key, angles_definitions=angles_definitions, max_seq_len=max_seq_len, min_seq_len=min_seq_len, diff --git a/foldingdiff/datasets.py b/foldingdiff/datasets.py index a0c6da9..72e78d4 100644 --- a/foldingdiff/datasets.py +++ b/foldingdiff/datasets.py @@ -246,7 +246,7 @@ class CathCanonicalAnglesDataset(Dataset): fnames = glob.glob(os.path.join(CATH_DIR, "dompdb", "*")) assert fnames, f"No files found in {CATH_DIR}/dompdb" elif pdbs == "alphafold": - pdbs = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz")) + fnames = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz")) assert fnames, f"No files found in {ALPHAFOLD_DIR}" else: raise ValueError(f"Unknown pdb set: {pdbs}")