diff --git a/bin/sample.py b/bin/sample.py index 76b6173..80af3ff 100644 --- a/bin/sample.py +++ b/bin/sample.py @@ -136,7 +136,7 @@ def compute_training_tm_scores( samp_name = os.path.splitext(os.path.basename(fname))[0] tm_score = tmalign.max_tm_across_refs( fname, - train_dset.dset.filenames, + train_dset.filenames, n_threads=nthreads, ) all_tm_scores[samp_name] = tm_score diff --git a/bin/tmscore_training.py b/bin/tmscore_training.py new file mode 100644 index 0000000..9f8ba53 --- /dev/null +++ b/bin/tmscore_training.py @@ -0,0 +1,53 @@ +""" +Compute the maximum TM score against training set +""" +import logging +import os, sys +from glob import glob +from pathlib import Path +import argparse + +from sample import compute_training_tm_scores +from datasets import CathCanonicalAnglesDataset + +SRC_DIR = (Path(os.path.dirname(os.path.abspath(__file__))) / "../protdiff").resolve() +assert SRC_DIR.is_dir() +sys.path.append(str(SRC_DIR)) + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "-d", + "--dirname", + type=str, + default=os.path.join(os.getcwd(), "sampled_pdb"), + help="Directory of generated PDB structures", + ) + return parser + + +def main(): + """Run the script""" + parser = build_parser() + args = parser.parse_args() + + assert os.path.isdir(args.dirname) + generated_pdbs = glob(os.path.join(args.dirname, "*.pdb")) + assert generated_pdbs + logging.info(f"Found {len(generated_pdbs)} generated structures") + + # we only need the filenames from the training dataset so it doesn't really matter + # what specific parameters we use to initialize it. The only important parameters are + # min_length, which is default to 40 and likely unchanged + train_dset = CathCanonicalAnglesDataset(split="train") + + # Calculate scores + compute_training_tm_scores(generated_pdbs, train_dset, Path(args.dirname)) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/protdiff/datasets.py b/protdiff/datasets.py index daff167..4235fb6 100644 --- a/protdiff/datasets.py +++ b/protdiff/datasets.py @@ -694,6 +694,11 @@ class NoisedAnglesDataset(Dataset): def pad(self): """Pas through the pad property of wrapped dset""" return self.dset.pad + + @property + def filenames(self): + """Pass through the filenames property of the wrapped dset""" + return self.dset.filenames def sample_length(self, *args, **kwargs): return self.dset.sample_length(*args, **kwargs)