Improvements to TMscore script interface

This commit is contained in:
Kevin Wu
2023-10-22 21:30:20 -07:00
parent fb5bc22929
commit e3c18d1a11

View File

@@ -14,6 +14,7 @@ import logging
from typing import *
import numpy as np
import pandas as pd
logging.basicConfig(level=logging.INFO)
@@ -101,6 +102,11 @@ def match_files(
pattern = re.compile("^" + k + r"[\-\_]+.*")
for k2 in [k2 for k2 in ref_files_map if pattern.match(k2)]:
retval[query_files_map[k]].append(ref_files_map[k2])
elif strategy == "suffix":
for k in ref_files_map:
pattern = re.compile("^" + k + r"[\-\_]+.*")
for k2 in [k2 for k2 in query_files_map if pattern.match(k2)]:
retval[query_files_map[k2]].append(ref_files_map[k])
else:
raise ValueError(f"Unknown strategy {strategy}")
return retval
@@ -111,12 +117,18 @@ def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser()
parser.add_argument("-q", "--query", type=str, nargs="+", help="Query files")
parser.add_argument("-r", "--ref", type=str, nargs="+", help="Reference files")
parser.add_argument("-o", "--output", type=str, help="Output file")
parser.add_argument(
"-o",
"--output",
type=str,
required=True,
help="Output csv file to write TMscores to",
)
parser.add_argument(
"-s",
"--strat",
type=str,
choices=["exact", "prefix"],
choices=["exact", "prefix", "suffix"],
default="exact",
help="Strategy for matching query and reference files",
)
@@ -128,14 +140,26 @@ def main():
args = parse_args()
query2refs = match_files(args.query, args.ref, args.strat)
logging.info(
f"Matched {len(query2refs)} pdb files to {len(list(itertools.chain.from_iterable(query2refs.values())))} reference files"
)
with mp.Pool(processes=mp.cpu_count()) as pool:
out = list(pool.starmap(max_tm_across_refs, query2refs.items()))
tmscores, _best_matching = zip(*out)
tmscores, best_matching = zip(*out)
logging.info(f"Mean TM-score: {np.nanmean(tmscores):.3f}")
logging.info(f"Num >= 0.5: {np.sum(np.array(tmscores) >= 0.5)} / {len(tmscores)}")
out_table = pd.DataFrame(
data={
"query": list(query2refs.keys()),
"tmscore": tmscores,
"matching": best_matching,
}
)
out_table.to_csv(args.output, index=False)
if __name__ == "__main__":
main()