feat(#256,#258): add diagnostics plotting script

This commit is contained in:
Dima
2026-04-09 14:55:42 +02:00
parent 40c68b5738
commit 0bd4d4771e
4 changed files with 266 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
"""Helpers for plotting AlphaPulldown diagnostic figures."""
from __future__ import annotations
from pathlib import Path
from typing import Iterable
import matplotlib
import numpy as np
matplotlib.use("Agg", force=True)
from matplotlib import pyplot as plt
from af2plots.plotter import plotter
from colabfold.plot import plot_msa_v2
from alphapulldown.utils.lightweight_pickles import extract_feature_dict, load_lightweight_pickle
def _normalise_stem(path: str | Path) -> str:
input_path = Path(path)
name = input_path.name
for suffix in (".pkl.xz", ".pkl.gz", ".pkl", ".json", ".xz", ".gz"):
if name.endswith(suffix):
name = name[: -len(suffix)]
break
return name or input_path.stem
def _ensure_output_dir(output_dir: str | Path) -> Path:
destination = Path(output_dir)
destination.mkdir(parents=True, exist_ok=True)
return destination
def _infer_asym_id_from_result_pickle(result_pickle: str | Path) -> tuple[list[int], int] | None:
payload = load_lightweight_pickle(result_pickle)
if not isinstance(payload, dict):
return None
seqs = payload.get("seqs")
if not isinstance(seqs, list) or not all(isinstance(sequence, str) for sequence in seqs):
return None
asym_id: list[int] = []
for index, sequence in enumerate(seqs, start=1):
asym_id.extend([index] * len(sequence))
return asym_id, len(seqs)
def _ensure_chain_metadata(parsed_models: dict[str, dict]) -> None:
for model_data in parsed_models.values():
if "asym_id" in model_data and "assembly_num_chains" in model_data:
continue
inferred = _infer_asym_id_from_result_pickle(model_data["fn"])
if inferred is None:
continue
asym_id, assembly_num_chains = inferred
model_data["asym_id"] = np.asarray(asym_id, dtype=np.int32)
model_data["assembly_num_chains"] = assembly_num_chains
def save_msa_coverage_plot(
feature_pickle: str | Path,
output_dir: str | Path,
*,
dpi: int = 100,
output_stem: str | None = None,
) -> Path:
"""Save a ColabFold-style MSA coverage plot from a feature pickle."""
payload = load_lightweight_pickle(feature_pickle)
feature_dict = extract_feature_dict(payload)
destination = _ensure_output_dir(output_dir)
plot_module = plot_msa_v2(feature_dict, dpi=dpi)
output_path = destination / f"{output_stem or _normalise_stem(feature_pickle)}_msa_coverage.png"
plot_module.savefig(output_path, bbox_inches="tight")
plot_module.close()
return output_path
def save_prediction_plots(
prediction_dir: str | Path,
output_dir: str | Path,
*,
dpi: int = 100,
) -> list[Path]:
"""Save pLDDT, PAE, and distogram plots from a prediction directory."""
prediction_root = Path(prediction_dir)
destination = _ensure_output_dir(output_dir)
af2_plotter = plotter()
parsed_models = af2_plotter.parse_model_pickles(str(prediction_root))
_ensure_chain_metadata(parsed_models)
output_prefix = destination / prediction_root.name
written_paths: list[Path] = []
pae_figure = af2_plotter.plot_predicted_alignment_error(parsed_models, dpi=dpi)
pae_path = output_prefix.with_name(f"{output_prefix.name}_pae.png")
pae_figure.savefig(pae_path, bbox_inches="tight")
plt.close(pae_figure)
written_paths.append(pae_path)
plddt_figure = af2_plotter.plot_plddts(parsed_models, dpi=dpi)
plddt_path = output_prefix.with_name(f"{output_prefix.name}_plddt.png")
plddt_figure.savefig(plddt_path, bbox_inches="tight")
plt.close(plddt_figure)
written_paths.append(plddt_path)
distogram_result = af2_plotter.plot_distogram(parsed_models, dpi=dpi)
if distogram_result is not None:
distogram_figure, _ = distogram_result
distogram_path = output_prefix.with_name(f"{output_prefix.name}_distogram.png")
distogram_figure.savefig(distogram_path, bbox_inches="tight")
plt.close(distogram_figure)
written_paths.append(distogram_path)
return written_paths
def plot_inputs(
inputs: Iterable[str | Path],
*,
output_dir: str | Path | None = None,
dpi: int = 100,
) -> list[Path]:
"""Dispatch plotting based on the provided input paths."""
written_paths: list[Path] = []
for raw_input in inputs:
input_path = Path(raw_input)
destination = Path(output_dir) if output_dir is not None else input_path.parent
if input_path.is_dir():
if list(input_path.glob("result*.pkl")):
written_paths.extend(save_prediction_plots(input_path, destination, dpi=dpi))
continue
feature_pickle = input_path / "features.pkl"
if feature_pickle.exists():
written_paths.append(
save_msa_coverage_plot(
feature_pickle,
destination,
dpi=dpi,
output_stem=input_path.name,
)
)
continue
raise FileNotFoundError(
f"{input_path} does not contain result*.pkl files or a features.pkl file"
)
written_paths.append(save_msa_coverage_plot(input_path, destination, dpi=dpi))
return written_paths

View File

@@ -0,0 +1,55 @@
#!/usr/bin/env python3
"""Generate AlphaPulldown diagnostic plots similar to ColabFold outputs."""
from __future__ import annotations
import argparse
from pathlib import Path
from alphapulldown.analysis_pipeline.diagnostics import plot_inputs
def build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(
description=(
"Write MSA coverage, pLDDT, PAE, and distogram plots from "
"AlphaPulldown feature pickles or prediction directories."
)
)
parser.add_argument(
"inputs",
nargs="+",
help=(
"Feature pickles, directories containing features.pkl, or "
"prediction directories containing result*.pkl files."
),
)
parser.add_argument(
"--output_dir",
default=None,
help="Directory to write plots into. Defaults to the parent directory of each input.",
)
parser.add_argument(
"--dpi",
type=int,
default=100,
help="Matplotlib DPI to use for saved plots.",
)
return parser
def main(argv: list[str] | None = None) -> list[Path]:
parser = build_parser()
args = parser.parse_args(argv)
written_paths = plot_inputs(
args.inputs,
output_dir=args.output_dir,
dpi=args.dpi,
)
for path in written_paths:
print(path)
return written_paths
if __name__ == "__main__":
main()

View File

@@ -112,6 +112,7 @@ script-files = [
"./alphapulldown/scripts/generate_alphafold_server_json.py",
"./alphapulldown/analysis_pipeline/create_notebook.py",
"./alphapulldown/analysis_pipeline/get_good_inter_pae.py",
"./alphapulldown/analysis_pipeline/plot_diagnostics.py",
"./alphapulldown/scripts/rename_colab_search_a3m.py",
"./alphapulldown/scripts/prepare_seq_names.py",
"./alphapulldown/scripts/generate_crosslink_pickle.py",

View File

@@ -0,0 +1,54 @@
from pathlib import Path
from alphapulldown.analysis_pipeline.diagnostics import (
plot_inputs,
save_msa_coverage_plot,
save_prediction_plots,
)
TEST_DATA = Path(__file__).resolve().parents[1] / "test_data"
def test_save_prediction_plots_writes_core_diagnostics(tmp_path):
written_paths = save_prediction_plots(
TEST_DATA / "predictions" / "TEST_homo_2er",
tmp_path,
)
assert [path.name for path in written_paths] == [
"TEST_homo_2er_pae.png",
"TEST_homo_2er_plddt.png",
"TEST_homo_2er_distogram.png",
]
for path in written_paths:
assert path.exists()
assert path.stat().st_size > 0
def test_save_msa_coverage_plot_supports_monomer_pickles(tmp_path):
output_path = save_msa_coverage_plot(
TEST_DATA / "features" / "A0A024R1R8.pkl",
tmp_path,
)
assert output_path.name == "A0A024R1R8_msa_coverage.png"
assert output_path.exists()
assert output_path.stat().st_size > 0
def test_plot_inputs_accepts_feature_directories_and_prediction_dirs(tmp_path):
written_paths = plot_inputs(
[
TEST_DATA / "predictions" / "af_vs_ap" / "A0A024R1R8",
TEST_DATA / "predictions" / "TEST_homo_2er",
],
output_dir=tmp_path,
)
assert sorted(path.name for path in written_paths) == [
"A0A024R1R8_msa_coverage.png",
"TEST_homo_2er_distogram.png",
"TEST_homo_2er_pae.png",
"TEST_homo_2er_plddt.png",
]