fix: add distogram fallback for older af2plots

This commit is contained in:
Dima
2026-04-09 15:51:05 +02:00
parent d668a3342b
commit 39409d9387
2 changed files with 110 additions and 1 deletions

View File

@@ -131,6 +131,87 @@ def _parse_prediction_pickles(prediction_dir: str | Path) -> dict[str, dict]:
return parsed_models
def _softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of ``logits`` taken along ``axis``.

    The maximum along ``axis`` is subtracted before exponentiating so the
    largest exponent is zero; this keeps ``np.exp`` from overflowing on large
    logits and does not change the normalised result.
    """
    peak = np.max(logits, axis=axis, keepdims=True)
    weights = np.exp(logits - peak)
    normaliser = np.sum(weights, axis=axis, keepdims=True)
    return weights / normaliser
def _plot_distogram_fallback(
    parsed_models: dict[str, dict],
    *,
    dpi: int = 100,
    distance: float = 8.0,
) -> tuple[plt.Figure, list[str]] | None:
    """Draw a contact-probability map from the top model's distogram.

    Used when the installed plotter does not provide ``plot_distogram``.

    Args:
        parsed_models: Mapping of model name to its parsed prediction dict.
            The model whose ``"rank"`` equals 1 is used; if there is none,
            the model with the highest ``"ptm"`` is used instead.
        dpi: Resolution passed to ``plt.subplots``.
        distance: Contact cutoff. It is clipped to the range [3, 20] before
            use, and the clipped value is what the colourbar label reports.

    Returns:
        ``(figure, [])`` on success, or ``None`` when no model is available
        or the selected model has no usable ``"distogram"`` entry (missing
        ``"logits"``/``"bin_edges"``, wrong dimensionality, or a bin count
        that does not match the number of edges).

    Raises:
        KeyError: If no model has rank 1 and some model lacks ``"ptm"``.
    """
    # Look for the rank-1 model first and only fall back to the pTM scan when
    # there is none. Passing the ``max(...)`` call as the default of ``next``
    # would evaluate it unconditionally and raise on models without "ptm"
    # even though a rank-1 model is available.
    top_model = next(
        (model for model in parsed_models.values() if model.get("rank") == 1),
        None,
    )
    if top_model is None:
        top_model = max(
            parsed_models.values(),
            key=lambda model: float(model["ptm"]),
            default=None,
        )
    if top_model is None:
        return None

    predicted_distogram = top_model.get("distogram")
    if not isinstance(predicted_distogram, dict):
        return None
    logits = predicted_distogram.get("logits")
    bin_edges = predicted_distogram.get("bin_edges")
    if logits is None or bin_edges is None:
        return None

    logits_array = np.asarray(logits)
    bin_edges_array = np.asarray(bin_edges)
    if logits_array.ndim != 3 or bin_edges_array.ndim != 1:
        return None
    # The mask built below has one entry per edge plus one for the open-ended
    # last bin, and it indexes the last logits axis; a length mismatch would
    # otherwise surface as an IndexError instead of a clean "no plot".
    if logits_array.shape[-1] != bin_edges_array.size + 1:
        return None

    clipped_distance = float(np.clip(distance, 3.0, 20.0))
    # Treat each edge as the upper bound of its bin; the final bin is unbounded.
    upper_bounds = np.concatenate([bin_edges_array, [np.inf]])
    threshold_mask = upper_bounds < clipped_distance
    if not np.any(threshold_mask):
        # Every bin ends at or beyond the cutoff: keep the closest bin so the
        # summed probability is not identically zero.
        threshold_mask[0] = True
    contact_probabilities = _softmax(logits_array, axis=-1)[..., threshold_mask].sum(axis=-1)

    figure, axis = plt.subplots(figsize=(8, 8), dpi=dpi)
    image = axis.imshow(
        contact_probabilities,
        cmap="coolwarm",
        vmin=0.0,
        vmax=1.0,
        extent=(0, contact_probabilities.shape[0], contact_probabilities.shape[0], 0),
    )
    colorbar = figure.colorbar(image, ax=axis, fraction=0.046, pad=0.04)
    # Report the cutoff that was actually applied (clipped, not truncated) so
    # the label cannot disagree with the plotted data.
    colorbar.ax.set_ylabel(f"Probability(distance<{clipped_distance:g}A)")
    axis.set_title("Predicted contacts")
    axis.set_xlabel("Residue number")
    axis.set_ylabel("Residue number")

    asym_id = top_model.get("asym_id")
    assembly_num_chains = top_model.get("assembly_num_chains")
    if asym_id is not None and assembly_num_chains is not None:
        asym_id_array = np.asarray(asym_id)
        # One dashed separator per chain except the last one.
        for chain_index in range(int(assembly_num_chains) - 1):
            # Chains are matched against asym_id values starting at 1.
            chain_positions = np.where(asym_id_array == (chain_index + 1))[0]
            if chain_positions.size == 0:
                continue
            # NOTE(review): the line is drawn at the last index of the chain;
            # with this ``extent`` the visual boundary may be one unit
            # further along - confirm against the upstream plotter.
            chain_cut = int(chain_positions.max())
            axis.axvline(x=chain_cut, ls="--", c="k", lw=1)
            axis.axhline(y=chain_cut, ls="--", c="k", lw=1)

    return figure, []
def _plot_distogram_compat(
    af2_plotter: object,
    parsed_models: dict[str, dict],
    *,
    dpi: int = 100,
) -> tuple[plt.Figure, list[str]] | None:
    """Plot the distogram with the plotter's own method when it has one.

    If ``af2_plotter`` exposes a callable ``plot_distogram`` attribute, it is
    invoked with ``parsed_models`` and ``dpi`` and its result is returned
    unchanged. Otherwise the call is routed to ``_plot_distogram_fallback``.
    """
    native_method = getattr(af2_plotter, "plot_distogram", None)
    if not callable(native_method):
        # Plotter without distogram support: use the local implementation.
        return _plot_distogram_fallback(parsed_models, dpi=dpi)
    return native_method(parsed_models, dpi=dpi)
def save_msa_coverage_plot(
feature_pickle: str | Path,
output_dir: str | Path,
@@ -179,7 +260,7 @@ def save_prediction_plots(
plt.close(plddt_figure)
written_paths.append(plddt_path)
distogram_result = af2_plotter.plot_distogram(parsed_models, dpi=dpi)
distogram_result = _plot_distogram_compat(af2_plotter, parsed_models, dpi=dpi)
if distogram_result is not None:
distogram_figure, _ = distogram_result
distogram_path = output_prefix.with_name(f"{output_prefix.name}_distogram.png")

View File

@@ -2,6 +2,7 @@ import gzip
from pathlib import Path
import shutil
import alphapulldown.analysis_pipeline.diagnostics as diagnostics
from alphapulldown.analysis_pipeline.diagnostics import (
plot_inputs,
save_msa_coverage_plot,
@@ -74,3 +75,30 @@ def test_plot_inputs_accepts_gzip_compressed_prediction_dirs(tmp_path):
"compressed_prediction_pae.png",
"compressed_prediction_plddt.png",
]
def test_save_prediction_plots_falls_back_when_af2plots_has_no_distogram(monkeypatch, tmp_path):
    """A plotter without ``plot_distogram`` must still produce a distogram PNG."""
    # Capture the real plotter before it is patched out, so the stand-in can
    # delegate the plots it does support.
    upstream_plotter_cls = diagnostics.plotter

    class LegacyPlotter:
        """Stand-in that forwards PAE and pLDDT plots and defines no ``plot_distogram``."""

        def __init__(self):
            self._delegate = upstream_plotter_cls()

        def plot_predicted_alignment_error(self, *args, **kwargs):
            return self._delegate.plot_predicted_alignment_error(*args, **kwargs)

        def plot_plddts(self, *args, **kwargs):
            return self._delegate.plot_plddts(*args, **kwargs)

    monkeypatch.setattr(diagnostics, "plotter", LegacyPlotter)

    prediction_dir = TEST_DATA / "predictions" / "TEST_homo_2er"
    produced_paths = save_prediction_plots(prediction_dir, tmp_path)

    expected_names = [
        "TEST_homo_2er_pae.png",
        "TEST_homo_2er_plddt.png",
        "TEST_homo_2er_distogram.png",
    ]
    assert [path.name for path in produced_paths] == expected_names