mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Support for specifying dataset in training script
This commit is contained in:
@@ -109,6 +109,7 @@ def plot_kl_divergence(train_dset, plots_folder: Path) -> None:
|
||||
|
||||
|
||||
def get_train_valid_test_sets(
|
||||
dataset_key:str = "cath",
|
||||
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0,
|
||||
@@ -144,6 +145,7 @@ def get_train_valid_test_sets(
|
||||
logging.info(f"Creating data splits: {splits}")
|
||||
clean_dsets = [
|
||||
clean_dset_class(
|
||||
pdbs=dataset_key,
|
||||
split=s,
|
||||
pad=max_seq_len,
|
||||
min_length=min_seq_len,
|
||||
@@ -286,6 +288,7 @@ def train(
|
||||
# Controls output
|
||||
results_dir: str = "./results",
|
||||
# Controls data loading and noising process
|
||||
dataset_key: str = "cath", # cath, alhpafold, or a directory containing pdb files
|
||||
angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
|
||||
max_seq_len: int = 512,
|
||||
min_seq_len: int = 0, # 0 means no filtering based on min sequence length
|
||||
@@ -340,6 +343,7 @@ def train(
|
||||
|
||||
# Get datasets and wrap them in dataloaders
|
||||
dsets = get_train_valid_test_sets(
|
||||
dataset_key=dataset_key,
|
||||
angles_definitions=angles_definitions,
|
||||
max_seq_len=max_seq_len,
|
||||
min_seq_len=min_seq_len,
|
||||
|
||||
@@ -246,7 +246,7 @@ class CathCanonicalAnglesDataset(Dataset):
|
||||
fnames = glob.glob(os.path.join(CATH_DIR, "dompdb", "*"))
|
||||
assert fnames, f"No files found in {CATH_DIR}/dompdb"
|
||||
elif pdbs == "alphafold":
|
||||
pdbs = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz"))
|
||||
fnames = glob.glob(os.path.join(ALPHAFOLD_DIR, "*.pdb.gz"))
|
||||
assert fnames, f"No files found in {ALPHAFOLD_DIR}"
|
||||
else:
|
||||
raise ValueError(f"Unknown pdb set: {pdbs}")
|
||||
|
||||
Reference in New Issue
Block a user