diff --git a/foldingdiff/tmalign.py b/foldingdiff/tmalign.py index 8377621..880d467 100644 --- a/foldingdiff/tmalign.py +++ b/foldingdiff/tmalign.py @@ -14,6 +14,7 @@ import logging from typing import * import numpy as np +import pandas as pd logging.basicConfig(level=logging.INFO) @@ -101,6 +102,11 @@ def match_files( pattern = re.compile("^" + k + r"[\-\_]+.*") for k2 in [k2 for k2 in ref_files_map if pattern.match(k2)]: retval[query_files_map[k]].append(ref_files_map[k2]) + elif strategy == "suffix": + for k in ref_files_map: + pattern = re.compile("^" + k + r"[\-\_]+.*") + for k2 in [k2 for k2 in query_files_map if pattern.match(k2)]: + retval[query_files_map[k2]].append(ref_files_map[k]) else: raise ValueError(f"Unknown strategy {strategy}") return retval @@ -111,12 +117,18 @@ def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("-q", "--query", type=str, nargs="+", help="Query files") parser.add_argument("-r", "--ref", type=str, nargs="+", help="Reference files") - parser.add_argument("-o", "--output", type=str, help="Output file") + parser.add_argument( + "-o", + "--output", + type=str, + required=True, + help="Output csv file to write TMscores to", + ) parser.add_argument( "-s", "--strat", type=str, - choices=["exact", "prefix"], + choices=["exact", "prefix", "suffix"], default="exact", help="Strategy for matching query and reference files", ) @@ -128,14 +140,26 @@ def main(): args = parse_args() query2refs = match_files(args.query, args.ref, args.strat) + logging.info( + f"Matched {len(query2refs)} pdb files to {len(list(itertools.chain.from_iterable(query2refs.values())))} reference files" + ) with mp.Pool(processes=mp.cpu_count()) as pool: out = list(pool.starmap(max_tm_across_refs, query2refs.items())) - tmscores, _best_matching = zip(*out) + tmscores, best_matching = zip(*out) logging.info(f"Mean TM-score: {np.nanmean(tmscores):.3f}") logging.info(f"Num >= 0.5: {np.sum(np.array(tmscores) >= 0.5)} / {len(tmscores)}") + out_table = pd.DataFrame( + data={ + "query": list(query2refs.keys()), + "tmscore": tmscores, + "matching": best_matching, + } + ) + out_table.to_csv(args.output, index=False) + if __name__ == "__main__": main()