def _softmax(logits: np.ndarray, axis: int = -1) -> np.ndarray:
    """Return the softmax of *logits* along *axis*.

    The maximum along *axis* is subtracted before exponentiating so large
    logits do not overflow; this does not change the resulting probabilities.
    """
    logits = np.asarray(logits)
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exponentiated = np.exp(shifted)
    return exponentiated / np.sum(exponentiated, axis=axis, keepdims=True)


def _ptm_or_lowest(model: dict) -> float:
    """Return the model's ``ptm`` as a float, or ``-inf`` if absent or non-numeric."""
    try:
        return float(model.get("ptm"))
    except (TypeError, ValueError):
        return float("-inf")


def _select_top_model(parsed_models: dict[str, dict]) -> dict | None:
    """Pick the model to plot: the first with ``rank == 1``, else the highest ``ptm``.

    Returns ``None`` when *parsed_models* is empty.
    """
    for model in parsed_models.values():
        if model.get("rank") == 1:
            return model
    # The ptm ranking is only computed when no rank-1 model exists, so a model
    # without a "ptm" entry can no longer break the rank-1 path.
    return max(parsed_models.values(), key=_ptm_or_lowest, default=None)


def _plot_distogram_fallback(
    parsed_models: dict[str, dict],
    *,
    dpi: int = 100,
    distance: float = 8.0,
) -> tuple[plt.Figure, list[str]] | None:
    """Plot a contact-probability map from the top model's distogram.

    The top model is the first one with ``rank == 1``; if there is none, the
    model with the highest ``ptm`` is used. Its ``"distogram"`` entry must be
    a dict holding 3-D ``"logits"`` and 1-D ``"bin_edges"``, where the last
    logits axis has ``len(bin_edges) + 1`` entries. The logits are turned into
    probabilities with a softmax over the last axis, and the probabilities of
    all bins whose upper edge lies below the distance cutoff are summed.

    Args:
        parsed_models: Mapping of model name to parsed prediction data.
        dpi: Resolution of the created figure.
        distance: Contact cutoff; clipped to the range [3, 20] before use.

    Returns:
        ``(figure, [])`` on success, or ``None`` when there is no model or the
        distogram data is missing or has an unexpected shape.
    """
    top_model = _select_top_model(parsed_models)
    if top_model is None:
        return None

    predicted_distogram = top_model.get("distogram")
    if not isinstance(predicted_distogram, dict):
        return None

    logits = predicted_distogram.get("logits")
    bin_edges = predicted_distogram.get("bin_edges")
    if logits is None or bin_edges is None:
        return None

    logits_array = np.asarray(logits)
    bin_edges_array = np.asarray(bin_edges)
    if logits_array.ndim != 3 or bin_edges_array.ndim != 1:
        return None

    # The last bin is open-ended, hence the extra +inf upper bound.
    upper_bounds = np.concatenate([bin_edges_array, [np.inf]])
    if logits_array.shape[-1] != upper_bounds.shape[0]:
        # Bin count and logits disagree; treat as unusable data like the
        # other guards instead of failing inside the boolean indexing below.
        return None

    cutoff = float(np.clip(distance, 3.0, 20.0))
    threshold_mask = upper_bounds < cutoff
    if not np.any(threshold_mask):
        # No bin ends below the cutoff: fall back to the closest (first) bin.
        threshold_mask[0] = True

    contact_probabilities = _softmax(logits_array, axis=-1)[..., threshold_mask].sum(axis=-1)
    n_rows, n_cols = contact_probabilities.shape

    figure, axis = plt.subplots(figsize=(8, 8), dpi=dpi)
    image = axis.imshow(
        contact_probabilities,
        cmap="coolwarm",
        vmin=0.0,
        vmax=1.0,
        extent=(0, n_cols, n_rows, 0),
    )
    colorbar = figure.colorbar(image, ax=axis, fraction=0.046, pad=0.04)
    # Label with the cutoff that was actually applied (after clipping).
    colorbar.ax.set_ylabel(f"Probability(distance<{cutoff:g}A)")
    axis.set_title("Predicted contacts")
    axis.set_xlabel("Residue number")
    axis.set_ylabel("Residue number")

    asym_id = top_model.get("asym_id")
    assembly_num_chains = top_model.get("assembly_num_chains")
    if asym_id is not None and assembly_num_chains is not None:
        asym_id_array = np.asarray(asym_id)
        # One dashed separator per chain except the last; chains are looked
        # up by the ids 1 .. assembly_num_chains - 1.
        for chain_index in range(int(assembly_num_chains) - 1):
            chain_positions = np.where(asym_id_array == (chain_index + 1))[0]
            if chain_positions.size == 0:
                continue
            # NOTE(review): the line sits at the last index of the chain, not
            # one past it -- confirm this matches the upstream plotter.
            chain_cut = int(chain_positions.max())
            axis.axvline(x=chain_cut, ls="--", c="k", lw=1)
            axis.axhline(y=chain_cut, ls="--", c="k", lw=1)

    return figure, []


def _plot_distogram_compat(
    af2_plotter: object,
    parsed_models: dict[str, dict],
    *,
    dpi: int = 100,
) -> tuple[plt.Figure, list[str]] | None:
    """Plot the distogram with *af2_plotter* if it can, else with the fallback.

    ``save_prediction_plots`` calls this instead of calling
    ``af2_plotter.plot_distogram`` directly, so plotter objects without a
    callable ``plot_distogram`` attribute still produce a distogram plot.

    Returns:
        Whatever ``af2_plotter.plot_distogram(parsed_models, dpi=dpi)``
        returns when that method exists, otherwise the result of
        ``_plot_distogram_fallback(parsed_models, dpi=dpi)``.
    """
    plot_distogram = getattr(af2_plotter, "plot_distogram", None)
    if callable(plot_distogram):
        return plot_distogram(parsed_models, dpi=dpi)
    return _plot_distogram_fallback(parsed_models, dpi=dpi)
def test_save_prediction_plots_falls_back_when_af2plots_has_no_distogram(monkeypatch, tmp_path):
    """A plotter lacking ``plot_distogram`` must still yield a distogram PNG."""
    upstream_plotter_factory = diagnostics.plotter

    class LegacyPlotter:
        """Wraps the real plotter but exposes only the PAE and pLDDT methods."""

        def __init__(self):
            self._delegate = upstream_plotter_factory()

        def plot_predicted_alignment_error(self, *args, **kwargs):
            delegate_method = self._delegate.plot_predicted_alignment_error
            return delegate_method(*args, **kwargs)

        def plot_plddts(self, *args, **kwargs):
            delegate_method = self._delegate.plot_plddts
            return delegate_method(*args, **kwargs)

    # Swap the module-level plotter for the stand-in without plot_distogram.
    monkeypatch.setattr(diagnostics, "plotter", LegacyPlotter)

    prediction_dir = TEST_DATA / "predictions" / "TEST_homo_2er"
    written_paths = save_prediction_plots(prediction_dir, tmp_path)

    expected_names = [
        "TEST_homo_2er_pae.png",
        "TEST_homo_2er_plddt.png",
        "TEST_homo_2er_distogram.png",
    ]
    produced_names = [written.name for written in written_paths]
    assert produced_names == expected_names