From 3e6fc368cd3202be44f7715486e4aa4f7affdd83 Mon Sep 17 00:00:00 2001
From: Dima <33123184+DimaMolod@users.noreply.github.com>
Date: Fri, 27 Mar 2026 15:29:48 +0100
Subject: [PATCH] New.

---

Notes:
    Review notes. This text sits between the three-dash line and the
    diffstat, so it is neither part of the commit message nor part of any
    hunk. Every hunk line below is unchanged.

    All four files are new pytest modules under test/unit/.

    test_distogram_parser.py (2 tests)
      Both tests set `datadir` on the distogram_parser module to tmp_path
      with monkeypatch.setattr(..., raising=False), build
      distogram_parser() with no arguments and call
      get_contacts("ignored", ...).
      * Empty directory: get_contacts must return [].
      * One pickle, result_model.pkl, holding ranking_confidence 0.9,
        seqs ["AA", "BB"], bin_edges [4, 8, 12] and logits of shape
        (4, 4, 3) that are -10 everywhere except +10 in bin 0 at [0, 2]
        and [2, 0]. With distance=9, pbtycutoff=0.5, cross_only=True the
        test expects exactly one contact whose first two items are
        (1, "A") and (1, "B") and whose third item is > 0.99.

    test_file_handling.py (8 tests)
      * temp_fasta_file: the yielded path is a file with the given text
        inside the `with` block and no longer exists after it.
      * convert_fasta_description_to_protein_name:
        ">sp|P12345|Protein A:chain;1?" -> "sp_P12345_Protein_A_chain_1_".
      * parse_fasta: a blank line inside a record does not split the
        sequence ("ACD", "", "EF" -> "ACDEF"); descriptions come back as
        "protein_one" and "protein_two".
      * iter_seqs: two one-record FASTA files yield
        ("AAA", "first") then ("BBB", "second").
      * make_dir_monomer_dictionary: maps each file name (including the
        ".pkl.xz" one) to the directory string it was found in.
      * parse_csv_file, missing input: expects FileNotFoundError.
      * parse_csv_file, cluster=False: one record per valid CSV row, named
        "<protein>.<template>.<chain>"; the two-column row and the row
        whose protein is absent from the FASTA produce no record.
      * parse_csv_file, cluster=True: rows for the same protein collapse
        into one record with both templates and both chains, in CSV order.

    test_plotting.py (1 test)
      * plot_pae_from_matrix with seqs ["AA", "BBB"], a 5x5 float matrix
        and ranking=2 must leave a non-empty file at exactly the path
        passed as figure_name.

    test_post_modelling.py (6 tests)
      * compress_file: returns "<path>.gz", removes the original, and the
        gzip content equals the original bytes.
      * compress_result_pickles: result.pkl becomes result.pkl.gz while
        notes.txt is left alone.
      * remove_keys_from_pickle: drops "distogram", tolerates the absent
        key "missing", and rewrites the same file.
      * remove_irrelevant_pickles: keeps the named best pickle and the
        non-pickle file, deletes the other pickle.
      * post_prediction_process with ranking_debug.json naming model_1 and
        all three flags set: only result_model_1.pkl.gz remains of the two
        pickles, and it unpickles to {"keep": 1}.
      * post_prediction_process without ranking_debug.json and with
        compress_pickles=True: the pickle must still exist uncompressed.

    NOTE(review): the subject "New." does not say what the commit does;
      consider naming the four test modules it adds.
    NOTE(review): `from pathlib import Path` in test_post_modelling.py is
      not referenced anywhere in that file.
    NOTE(review): the missing-fasta test passes the relative names
      "missing.csv" and "missing.fasta"; presumably it relies on neither
      existing in the directory pytest is run from. Confirm.

 test/unit/test_distogram_parser.py |  38 ++++++++++
 test/unit/test_file_handling.py    | 118 +++++++++++++++++++++++++++++
 test/unit/test_plotting.py         |  17 +++++
 test/unit/test_post_modelling.py   |  94 +++++++++++++++++++++++
 4 files changed, 267 insertions(+)
 create mode 100644 test/unit/test_distogram_parser.py
 create mode 100644 test/unit/test_file_handling.py
 create mode 100644 test/unit/test_plotting.py
 create mode 100644 test/unit/test_post_modelling.py

diff --git a/test/unit/test_distogram_parser.py b/test/unit/test_distogram_parser.py
new file mode 100644
index 00000000..4ca0bc07
--- /dev/null
+++ b/test/unit/test_distogram_parser.py
@@ -0,0 +1,38 @@
+import pickle
+
+import numpy as np
+
+import alphapulldown.utils.distogram_parser as distogram_parser_module
+
+
+def test_get_contacts_returns_empty_list_when_no_pickles_exist(monkeypatch, tmp_path):
+    monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
+
+    parser = distogram_parser_module.distogram_parser()
+
+    assert parser.get_contacts("ignored") == []
+
+
+def test_get_contacts_extracts_top_ranked_inter_chain_contact(monkeypatch, tmp_path):
+    logits = np.full((4, 4, 3), -10.0, dtype=np.float32)
+    logits[0, 2, 0] = 10.0
+    logits[2, 0, 0] = 10.0
+    payload = {
+        "ranking_confidence": 0.9,
+        "seqs": ["AA", "BB"],
+        "distogram": {
+            "bin_edges": np.array([4.0, 8.0, 12.0], dtype=np.float32),
+            "logits": logits,
+        },
+    }
+    with open(tmp_path / "result_model.pkl", "wb") as handle:
+        pickle.dump(payload, handle)
+    monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
+
+    parser = distogram_parser_module.distogram_parser()
+    contacts = parser.get_contacts("ignored", distance=9, pbtycutoff=0.5, cross_only=True)
+
+    assert len(contacts) == 1
+    assert contacts[0][0] == (1, "A")
+    assert contacts[0][1] == (1, "B")
+    assert contacts[0][2] > 0.99
diff --git a/test/unit/test_file_handling.py b/test/unit/test_file_handling.py
new file mode 100644
index 00000000..8489a875
--- /dev/null
+++ b/test/unit/test_file_handling.py
@@ -0,0 +1,118 @@
+from pathlib import Path
+
+import pytest
+
+from alphapulldown.utils import file_handling
+
+
+def test_temp_fasta_file_writes_contents_and_cleans_up():
+    with file_handling.temp_fasta_file(">seq\nACDE\n") as fasta_path:
+        path = Path(fasta_path)
+        assert path.is_file()
+        assert path.read_text(encoding="utf-8") == ">seq\nACDE\n"
+
+    assert not path.exists()
+
+
+def test_convert_fasta_description_to_protein_name_sanitizes_symbols():
+    protein_name = file_handling.convert_fasta_description_to_protein_name(
+        ">sp|P12345|Protein A:chain;1?"
+    )
+
+    assert protein_name == "sp_P12345_Protein_A_chain_1_"
+
+
+def test_parse_fasta_parses_multiple_sequences_and_skips_blank_lines():
+    sequences, descriptions = file_handling.parse_fasta(
+        ">protein one\nACD\n\nEF\n>protein|two\nGHI\n"
+    )
+
+    assert sequences == ["ACDEF", "GHI"]
+    assert descriptions == ["protein_one", "protein_two"]
+
+
+def test_iter_seqs_yields_sequences_from_multiple_files(tmp_path):
+    fasta_a = tmp_path / "a.fasta"
+    fasta_b = tmp_path / "b.fasta"
+    fasta_a.write_text(">first\nAAA\n", encoding="utf-8")
+    fasta_b.write_text(">second\nBBB\n", encoding="utf-8")
+
+    records = list(file_handling.iter_seqs([str(fasta_a), str(fasta_b)]))
+
+    assert records == [("AAA", "first"), ("BBB", "second")]
+
+
+def test_make_dir_monomer_dictionary_maps_files_to_their_source_dirs(tmp_path):
+    dir_a = tmp_path / "a"
+    dir_b = tmp_path / "b"
+    dir_a.mkdir()
+    dir_b.mkdir()
+    (dir_a / "proteinA.pkl").write_text("", encoding="utf-8")
+    (dir_b / "proteinB.pkl.xz").write_text("", encoding="utf-8")
+
+    result = file_handling.make_dir_monomer_dictionary([str(dir_a), str(dir_b)])
+
+    assert result == {
+        "proteinA.pkl": str(dir_a),
+        "proteinB.pkl.xz": str(dir_b),
+    }
+
+
+def test_parse_csv_file_raises_for_missing_fasta():
+    with pytest.raises(FileNotFoundError):
+        file_handling.parse_csv_file("missing.csv", ["missing.fasta"], "/templates")
+
+
+def test_parse_csv_file_returns_unique_records_without_clustering(tmp_path):
+    fasta = tmp_path / "proteins.fasta"
+    fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
+    csv_path = tmp_path / "description.csv"
+    csv_path.write_text(
+        "proteinA,template1.cif,A\nproteinA,template2.cif,B\ninvalid,row\nmissing,template3.cif,C\n",
+        encoding="utf-8",
+    )
+
+    result = file_handling.parse_csv_file(
+        str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=False
+    )
+
+    assert result == [
+        {
+            "protein": "proteinA.template1.cif.A",
+            "sequence": "ACDE",
+            "templates": [str(tmp_path / "templates" / "template1.cif")],
+            "chains": ["A"],
+        },
+        {
+            "protein": "proteinA.template2.cif.B",
+            "sequence": "ACDE",
+            "templates": [str(tmp_path / "templates" / "template2.cif")],
+            "chains": ["B"],
+        },
+    ]
+
+
+def test_parse_csv_file_clusters_multiple_templates_per_protein(tmp_path):
+    fasta = tmp_path / "proteins.fasta"
+    fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
+    csv_path = tmp_path / "description.csv"
+    csv_path.write_text(
+        "proteinA,template1.cif,A\nproteinA,template2.cif,B\n",
+        encoding="utf-8",
+    )
+
+    result = file_handling.parse_csv_file(
+        str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=True
+    )
+
+    assert result == [
+        {
+            "protein": "proteinA",
+            "sequence": "ACDE",
+            "templates": [
+                str(tmp_path / "templates" / "template1.cif"),
+                str(tmp_path / "templates" / "template2.cif"),
+            ],
+            "chains": ["A", "B"],
+        }
+    ]
diff --git a/test/unit/test_plotting.py b/test/unit/test_plotting.py
new file mode 100644
index 00000000..a461fc43
--- /dev/null
+++ b/test/unit/test_plotting.py
@@ -0,0 +1,17 @@
+import numpy as np
+
+from alphapulldown.utils.plotting import plot_pae_from_matrix
+
+
+def test_plot_pae_from_matrix_writes_png_file(tmp_path):
+    output = tmp_path / "pae.png"
+
+    plot_pae_from_matrix(
+        seqs=["AA", "BBB"],
+        pae_matrix=np.arange(25, dtype=float).reshape(5, 5),
+        figure_name=str(output),
+        ranking=2,
+    )
+
+    assert output.is_file()
+    assert output.stat().st_size > 0
diff --git a/test/unit/test_post_modelling.py b/test/unit/test_post_modelling.py
new file mode 100644
index 00000000..9d6a3b17
--- /dev/null
+++ b/test/unit/test_post_modelling.py
@@ -0,0 +1,94 @@
+import gzip
+import json
+import pickle
+from pathlib import Path
+
+from alphapulldown.utils import post_modelling
+
+
+def test_compress_file_gzips_and_removes_original(tmp_path):
+    file_path = tmp_path / "result.pkl"
+    file_path.write_bytes(b"payload")
+
+    gz_path = post_modelling.compress_file(str(file_path))
+
+    assert gz_path == str(file_path) + ".gz"
+    assert not file_path.exists()
+    with gzip.open(gz_path, "rb") as handle:
+        assert handle.read() == b"payload"
+
+
+def test_compress_result_pickles_only_compresses_pickle_files(tmp_path):
+    pickle_path = tmp_path / "result.pkl"
+    text_path = tmp_path / "notes.txt"
+    pickle_path.write_bytes(b"payload")
+    text_path.write_text("keep", encoding="utf-8")
+
+    post_modelling.compress_result_pickles(str(tmp_path))
+
+    assert not pickle_path.exists()
+    assert (tmp_path / "result.pkl.gz").is_file()
+    assert text_path.is_file()
+
+
+def test_remove_keys_from_pickle_updates_pickle_in_place(tmp_path):
+    file_path = tmp_path / "result.pkl"
+    with open(file_path, "wb") as handle:
+        pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
+
+    post_modelling.remove_keys_from_pickle(str(file_path), ["distogram", "missing"])
+
+    with open(file_path, "rb") as handle:
+        payload = pickle.load(handle)
+
+    assert payload == {"keep": 1, "masked_msa": 3}
+
+
+def test_remove_irrelevant_pickles_keeps_only_best_pickle(tmp_path):
+    best = tmp_path / "result_model_1.pkl"
+    other = tmp_path / "result_model_2.pkl"
+    note = tmp_path / "note.txt"
+    best.write_bytes(b"best")
+    other.write_bytes(b"other")
+    note.write_text("keep", encoding="utf-8")
+
+    post_modelling.remove_irrelevant_pickles(str(tmp_path), best.name)
+
+    assert best.is_file()
+    assert not other.exists()
+    assert note.is_file()
+
+
+def test_post_prediction_process_remove_keys_then_compress_and_prune(tmp_path):
+    ranking_debug = tmp_path / "ranking_debug.json"
+    ranking_debug.write_text(json.dumps({"order": ["model_1"]}), encoding="utf-8")
+    result_best = tmp_path / "result_model_1.pkl"
+    result_other = tmp_path / "result_model_2.pkl"
+    with open(result_best, "wb") as handle:
+        pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
+    with open(result_other, "wb") as handle:
+        pickle.dump({"keep": 4, "aligned_confidence_probs": 5}, handle)
+
+    post_modelling.post_prediction_process(
+        str(tmp_path),
+        compress_pickles=True,
+        remove_pickles=True,
+        remove_keys=True,
+    )
+
+    assert not result_best.exists()
+    assert not result_other.exists()
+    gz_best = tmp_path / "result_model_1.pkl.gz"
+    assert gz_best.is_file()
+    with gzip.open(gz_best, "rb") as handle:
+        payload = pickle.load(handle)
+    assert payload == {"keep": 1}
+
+
+def test_post_prediction_process_handles_missing_ranking_debug(tmp_path):
+    lonely_pickle = tmp_path / "result_model_1.pkl"
+    lonely_pickle.write_bytes(b"payload")
+
+    post_modelling.post_prediction_process(str(tmp_path), compress_pickles=True)
+
+    assert lonely_pickle.is_file()