mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
149 lines
4.4 KiB
Python
149 lines
4.4 KiB
Python
"""
|
|
Run hierarchical clustering on the pairwise (1 - TM-score) distance matrix between all pairs of files
|
|
"""
|
|
|
|
import os, sys
|
|
import re
|
|
import json
|
|
import logging
|
|
from glob import glob
|
|
from pathlib import Path
|
|
import itertools
|
|
import argparse
|
|
from typing import *
|
|
import multiprocessing as mp
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import scipy.spatial as sp, scipy.cluster.hierarchy as hc
|
|
import seaborn as sns
|
|
|
|
from train import get_train_valid_test_sets
|
|
|
|
from foldingdiff import tmalign
|
|
|
|
# :)
# Fixed seed so random test-subset sampling is reproducible across runs.
# The hex digits are ASCII-encoded text; they are parsed as a hexadecimal
# float and reduced modulo 10000 to give a small integer seed.
_SEED_HEX = "2254616977616e2069732061206672656520636f756e74727922"
SEED = int(float.fromhex(_SEED_HEX) % 10000)
|
|
|
|
|
|
def int_getter(x: str) -> int:
    """
    Fetch the single integer embedded in a string.

    Used as a sort key so that numbered files sort numerically rather than
    lexically.

    :param x: String expected to contain exactly one run of ASCII digits.
    :return: That run of digits parsed as an int.
    :raises ValueError: If x contains zero or more than one run of digits.
    """
    matches = re.findall(r"[0-9]+", x)
    # Explicit check rather than `assert`, which is stripped under `python -O`
    if len(matches) != 1:
        raise ValueError(
            f"Expected exactly one integer in {x!r}, found {len(matches)}"
        )
    return int(matches[0])
|
|
|
|
|
|
def get_pairwise_tmscores(
    fnames: Collection[str],
    sctm_scores_json: Optional[str] = None,
    sctm_threshold: float = 0.5,
) -> pd.DataFrame:
    """
    Get the pairwise TM scores across all fnames.

    :param fnames: Paths to the PDB files to compare.
    :param sctm_scores_json: Optional path to a JSON file mapping each file's
        basename (without extension) to its scTM score. If given, only files
        whose score is >= sctm_threshold are kept; every file in fnames must
        have an entry, otherwise a KeyError is raised.
    :param sctm_threshold: Minimum scTM score for a file to be kept when
        sctm_scores_json is given.
    :return: Symmetric DataFrame whose rows and columns are the kept files'
        basenames (without extension), holding tmalign.run_tmalign for each
        pair; the diagonal is 1.0.
    """

    def bname_getter(x: str) -> str:
        return os.path.splitext(os.path.basename(x))[0]

    if sctm_scores_json:
        with open(sctm_scores_json) as source:
            sctm_scores = json.load(source)
        fnames = [
            f for f in fnames if sctm_scores[bname_getter(f)] >= sctm_threshold
        ]
        logging.info(f"{len(fnames)} structures have scTM scores >= {sctm_threshold}")
    else:
        fnames = list(fnames)  # need positional indexing below

    n = len(fnames)
    # Log after filtering so the count reflects the work actually done
    logging.info(f"Computing pairwise distances between {n} pdb files")

    # Upper-triangle indices enumerate each unordered pair (i < j) exactly once
    rows, cols = np.triu_indices(n, k=1)
    pairs = [(fnames[i], fnames[j]) for i, j in zip(rows, cols)]
    # Context manager guarantees worker cleanup even if a comparison raises
    with mp.Pool(mp.cpu_count()) as pool:
        values = pool.starmap(tmalign.run_tmalign, pairs, chunksize=25)

    # Fill both triangles positionally: O(1) numpy ops instead of one .loc
    # call per pair, and unaffected by duplicate basenames. The diagonal
    # (each structure against itself) stays at 1.0, so the result is
    # symmetric by construction.
    tmscores = np.ones((n, n), dtype=float)
    tmscores[rows, cols] = values
    tmscores[cols, rows] = values

    bnames = [bname_getter(f) for f in fnames]
    return pd.DataFrame(tmscores, index=bnames, columns=bnames)
|
|
|
|
|
|
def build_parser() -> argparse.ArgumentParser:
    """
    Build the command-line parser for this script.

    Exactly one of --dirname / --testsubset must be given; argparse reports a
    usage error if neither or both are supplied.
    """
    parser = argparse.ArgumentParser(
        usage=__doc__, formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    # The two input sources are alternatives, and one of them is mandatory
    g = parser.add_mutually_exclusive_group(required=True)
    g.add_argument("--dirname", type=str, help="Directory of PDB files to analyze")
    g.add_argument("--testsubset", type=int, help="Subset of test set sequences to run")
    parser.add_argument(
        "--sctm", type=str, required=False, default="", help="scTM scores to filter by"
    )
    parser.add_argument(
        "-o",
        "--output",
        type=str,
        default="tmscore_hclust.pdf",
        help="PDF file to write output clustering plot",
    )
    return parser
|
|
|
|
|
|
def _select_fnames(args: argparse.Namespace) -> List[str]:
    """Return the PDB file paths to cluster, chosen via --dirname or --testsubset."""
    if args.dirname:
        # Order files by the single integer embedded in each basename
        fnames = sorted(
            glob(os.path.join(args.dirname, "*.pdb")),
            key=lambda x: int_getter(os.path.basename(x)),
        )
        if not fnames:
            raise FileNotFoundError(f"{args.dirname} does not contain any pdb files")
        return fnames
    if args.testsubset:
        # We only care about fnames here
        *_, test_subset = get_train_valid_test_sets(
            max_seq_len=128,
            min_seq_len=50,
            seq_trim_strategy="discard",
        )
        # Fixed seed so the same random subset is drawn on every run
        rng = np.random.default_rng(SEED)
        idx = rng.choice(
            len(test_subset.filenames), size=args.testsubset, replace=False
        )
        return [test_subset.filenames[i] for i in idx]
    raise NotImplementedError("One of --dirname or --testsubset must be given")


def main():
    """
    Run the script: select PDB files, compute pairwise TM-score distances,
    cluster them with average linkage and save a clustermap to args.output.

    :raises ValueError: If fewer than two structures remain to cluster.
    """
    parser = build_parser()
    args = parser.parse_args()

    fnames = _select_fnames(args)

    # TMscore of 1 = perfect match --> 0 distance, so need 1.0 - tmscore
    pdist_df = 1.0 - get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
    # Linkage needs at least one pair; fail clearly instead of deep in scipy
    if pdist_df.shape[0] < 2:
        raise ValueError(
            f"Need at least 2 structures to cluster, got {pdist_df.shape[0]}"
        )

    # https://stackoverflow.com/questions/38705359/how-to-give-sns-clustermap-a-precomputed-distance-matrix
    # https://stackoverflow.com/questions/57308725/pass-distance-matrix-to-seaborn-clustermap
    m = "average"  # Trippe uses average here
    # squareform converts the square matrix to the condensed form linkage expects
    linkage = hc.linkage(
        sp.distance.squareform(pdist_df), method=m, optimal_ordering=False
    )

    # Reuse the precomputed linkage for both axes so rows and columns share
    # the same ordering
    c = sns.clustermap(
        pdist_df,
        row_linkage=linkage,
        col_linkage=linkage,
        method=None,
        row_cluster=True,
        col_cluster=True,
        vmin=0.0,
        vmax=1.0,
        xticklabels=False,
        yticklabels=False,
        cbar_kws={"label": r"$d(x, y) = 1 - \mathrm{TMscore}(x, y)$"},
    )
    c.savefig(args.output)
|
|
|
|
|
|
if __name__ == "__main__":
    # Configure the root logger at INFO so the script's logging.info progress
    # messages are printed, then run the CLI.
    logging.basicConfig(level="INFO")
    main()
|