Visual tweaks

This commit is contained in:
Kevin Wu
2023-02-09 18:15:39 -08:00
parent ec2e5e9902
commit 02dcfa85fe

View File

@@ -69,9 +69,21 @@ def main():
)
logging.info(f"Computing TMscore on {len(fnames)} structures")
pdist_df = get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
mds = MDS(n_components=2, dissimilarity="precomputed", n_jobs=-1, random_state=SEED)
embedding = pd.DataFrame(mds.fit_transform(pdist_df.values), index=pdist_df.index)
# in the dissimilarity matrix, larger values should indicate more distant points
# therefore we want to do 1 - TMscore (since larger values are closer in TMscore space)
pdist_df = 1.0 - get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
mds = MDS(
n_components=2,
dissimilarity="precomputed",
metric=False, # TMscores do not respect triangle inequality
n_jobs=-1,
random_state=SEED,
)
embedding = pd.DataFrame(
mds.fit_transform(pdist_df.values),
index=pdist_df.index,
columns=["MDS1", "MDS2"],
)
format_strings = {
"Number helices": "{x:.1f}",
@@ -79,17 +91,23 @@ def main():
# For a variety of coloring keys, compute/read the scores and color scatter
# plot by the scores.
for k, v in {
"null": None,
"Max training TM": args.trainingtm,
"scTM": args.sctm,
"length": lambda x: len_pdb_structure(x),
"length": len_pdb_structure,
"Number helices": lambda x: count_structures_in_pdb(x, "psea")[0],
"Number sheets": lambda x: count_structures_in_pdb(x, "psea")[1],
"null": None,
}.items():
if v is None or v:
logging.info(f"Coloring by {k} scores")
figsize = (6.4, 4.8)
annot_points = False
if v is None:
scores = None
figsize = (12.8, 9.6)
annot_points = True
# If we are doing the null, the plot very big and label each
# point with the text id
elif callable(v):
fname_to_key = lambda f: os.path.basename(f).split(".")[0]
scores = {
@@ -104,7 +122,7 @@ def main():
else:
raise ValueError(f"Invalid value for {k}: {v}")
fig, ax = plt.subplots(dpi=300)
fig, ax = plt.subplots(figsize=figsize, dpi=300)
points = ax.scatter(
embedding.iloc[:, 0],
embedding.iloc[:, 1],
@@ -113,6 +131,13 @@ def main():
cmap="RdYlBu",
alpha=0.9,
)
if annot_points:
for i in range(len(embedding)):
ax.annotate(
embedding.index[i],
(embedding.iloc[i, 0], embedding.iloc[i, 1]),
fontsize=6,
)
ax.set(
xlabel="MDS 1",
ylabel="MDS 2",
@@ -130,7 +155,7 @@ def main():
fraction=0.08,
pad=0.04,
location="right",
format=format_strings.get(k, None),
# format=format_strings.get(k, None),
)
cbar.ax.set_ylabel(k, fontsize=12)