diff --git a/bin/mds_structures.py b/bin/mds_structures.py index 9679d75..a2acd2b 100644 --- a/bin/mds_structures.py +++ b/bin/mds_structures.py @@ -69,9 +69,21 @@ def main(): ) logging.info(f"Computing TMscore on {len(fnames)} structures") - pdist_df = get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm) - mds = MDS(n_components=2, dissimilarity="precomputed", n_jobs=-1, random_state=SEED) - embedding = pd.DataFrame(mds.fit_transform(pdist_df.values), index=pdist_df.index) + # in the dissimilarity matrix, larger values should indicate more distant points + # therefore we want to do 1 - TMscore (since larger values are closer in TMscore space) + pdist_df = 1.0 - get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm) + mds = MDS( + n_components=2, + dissimilarity="precomputed", + metric=False, # TMscores do not respect triangle inequality + n_jobs=-1, + random_state=SEED, + ) + embedding = pd.DataFrame( + mds.fit_transform(pdist_df.values), + index=pdist_df.index, + columns=["MDS1", "MDS2"], + ) format_strings = { "Number helices": "{x:.1f}", @@ -79,17 +91,23 @@ def main(): # For a variety of coloring keys, compute/read the scores and color scatter # plot by the scores. for k, v in { - "null": None, "Max training TM": args.trainingtm, "scTM": args.sctm, - "length": lambda x: len_pdb_structure(x), + "length": len_pdb_structure, "Number helices": lambda x: count_structures_in_pdb(x, "psea")[0], "Number sheets": lambda x: count_structures_in_pdb(x, "psea")[1], + "null": None, }.items(): if v is None or v: logging.info(f"Coloring by {k} scores") + figsize = (6.4, 4.8) + annot_points = False if v is None: scores = None + figsize = (12.8, 9.6) + annot_points = True + # If we are doing the null, the plot very big and label each + # point with the text id elif callable(v): fname_to_key = lambda f: os.path.basename(f).split(".")[0] scores = { @@ -104,7 +122,7 @@ def main(): else: raise ValueError(f"Invalid value for {k}: {v}") - fig, ax = plt.subplots(dpi=300) + fig, ax = plt.subplots(figsize=figsize, dpi=300) points = ax.scatter( embedding.iloc[:, 0], embedding.iloc[:, 1], @@ -113,6 +131,13 @@ def main(): cmap="RdYlBu", alpha=0.9, ) + if annot_points: + for i in range(len(embedding)): + ax.annotate( + embedding.index[i], + (embedding.iloc[i, 0], embedding.iloc[i, 1]), + fontsize=6, + ) ax.set( xlabel="MDS 1", ylabel="MDS 2", @@ -130,7 +155,7 @@ def main(): fraction=0.08, pad=0.04, location="right", - format=format_strings.get(k, None), + # format=format_strings.get(k, None), ) cbar.ax.set_ylabel(k, fontsize=12)