mirror of
https://github.com/microsoft/foldingdiff.git
synced 2026-06-04 13:30:33 +08:00
Visual tweaks
This commit is contained in:
@@ -69,9 +69,21 @@ def main():
|
||||
)
|
||||
logging.info(f"Computing TMscore on {len(fnames)} structures")
|
||||
|
||||
pdist_df = get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
|
||||
mds = MDS(n_components=2, dissimilarity="precomputed", n_jobs=-1, random_state=SEED)
|
||||
embedding = pd.DataFrame(mds.fit_transform(pdist_df.values), index=pdist_df.index)
|
||||
# in the dissimilarity matrix, larger values should indicate more distant points
|
||||
# therefore we want to do 1 - TMscore (since larger values are closer in TMscore space)
|
||||
pdist_df = 1.0 - get_pairwise_tmscores(fnames, sctm_scores_json=args.sctm)
|
||||
mds = MDS(
|
||||
n_components=2,
|
||||
dissimilarity="precomputed",
|
||||
metric=False, # TMscores do not respect triangle inequality
|
||||
n_jobs=-1,
|
||||
random_state=SEED,
|
||||
)
|
||||
embedding = pd.DataFrame(
|
||||
mds.fit_transform(pdist_df.values),
|
||||
index=pdist_df.index,
|
||||
columns=["MDS1", "MDS2"],
|
||||
)
|
||||
|
||||
format_strings = {
|
||||
"Number helices": "{x:.1f}",
|
||||
@@ -79,17 +91,23 @@ def main():
|
||||
# For a variety of coloring keys, compute/read the scores and color scatter
|
||||
# plot by the scores.
|
||||
for k, v in {
|
||||
"null": None,
|
||||
"Max training TM": args.trainingtm,
|
||||
"scTM": args.sctm,
|
||||
"length": lambda x: len_pdb_structure(x),
|
||||
"length": len_pdb_structure,
|
||||
"Number helices": lambda x: count_structures_in_pdb(x, "psea")[0],
|
||||
"Number sheets": lambda x: count_structures_in_pdb(x, "psea")[1],
|
||||
"null": None,
|
||||
}.items():
|
||||
if v is None or v:
|
||||
logging.info(f"Coloring by {k} scores")
|
||||
figsize = (6.4, 4.8)
|
||||
annot_points = False
|
||||
if v is None:
|
||||
scores = None
|
||||
figsize = (12.8, 9.6)
|
||||
annot_points = True
|
||||
# If we are doing the null, the plot very big and label each
|
||||
# point with the text id
|
||||
elif callable(v):
|
||||
fname_to_key = lambda f: os.path.basename(f).split(".")[0]
|
||||
scores = {
|
||||
@@ -104,7 +122,7 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Invalid value for {k}: {v}")
|
||||
|
||||
fig, ax = plt.subplots(dpi=300)
|
||||
fig, ax = plt.subplots(figsize=figsize, dpi=300)
|
||||
points = ax.scatter(
|
||||
embedding.iloc[:, 0],
|
||||
embedding.iloc[:, 1],
|
||||
@@ -113,6 +131,13 @@ def main():
|
||||
cmap="RdYlBu",
|
||||
alpha=0.9,
|
||||
)
|
||||
if annot_points:
|
||||
for i in range(len(embedding)):
|
||||
ax.annotate(
|
||||
embedding.index[i],
|
||||
(embedding.iloc[i, 0], embedding.iloc[i, 1]),
|
||||
fontsize=6,
|
||||
)
|
||||
ax.set(
|
||||
xlabel="MDS 1",
|
||||
ylabel="MDS 2",
|
||||
@@ -130,7 +155,7 @@ def main():
|
||||
fraction=0.08,
|
||||
pad=0.04,
|
||||
location="right",
|
||||
format=format_strings.get(k, None),
|
||||
# format=format_strings.get(k, None),
|
||||
)
|
||||
cbar.ax.set_ylabel(k, fontsize=12)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user