This commit is contained in:
Dima
2026-03-27 15:29:48 +01:00
parent cbfaf51b0e
commit 3e6fc368cd
4 changed files with 267 additions and 0 deletions

View File

@@ -0,0 +1,38 @@
import pickle
import numpy as np
import alphapulldown.utils.distogram_parser as distogram_parser_module
def test_get_contacts_returns_empty_list_when_no_pickles_exist(monkeypatch, tmp_path):
monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
parser = distogram_parser_module.distogram_parser()
assert parser.get_contacts("ignored") == []
def test_get_contacts_extracts_top_ranked_inter_chain_contact(monkeypatch, tmp_path):
logits = np.full((4, 4, 3), -10.0, dtype=np.float32)
logits[0, 2, 0] = 10.0
logits[2, 0, 0] = 10.0
payload = {
"ranking_confidence": 0.9,
"seqs": ["AA", "BB"],
"distogram": {
"bin_edges": np.array([4.0, 8.0, 12.0], dtype=np.float32),
"logits": logits,
},
}
with open(tmp_path / "result_model.pkl", "wb") as handle:
pickle.dump(payload, handle)
monkeypatch.setattr(distogram_parser_module, "datadir", str(tmp_path), raising=False)
parser = distogram_parser_module.distogram_parser()
contacts = parser.get_contacts("ignored", distance=9, pbtycutoff=0.5, cross_only=True)
assert len(contacts) == 1
assert contacts[0][0] == (1, "A")
assert contacts[0][1] == (1, "B")
assert contacts[0][2] > 0.99

View File

@@ -0,0 +1,118 @@
from pathlib import Path
import pytest
from alphapulldown.utils import file_handling
def test_temp_fasta_file_writes_contents_and_cleans_up():
with file_handling.temp_fasta_file(">seq\nACDE\n") as fasta_path:
path = Path(fasta_path)
assert path.is_file()
assert path.read_text(encoding="utf-8") == ">seq\nACDE\n"
assert not path.exists()
def test_convert_fasta_description_to_protein_name_sanitizes_symbols():
protein_name = file_handling.convert_fasta_description_to_protein_name(
">sp|P12345|Protein A:chain;1?"
)
assert protein_name == "sp_P12345_Protein_A_chain_1_"
def test_parse_fasta_parses_multiple_sequences_and_skips_blank_lines():
sequences, descriptions = file_handling.parse_fasta(
">protein one\nACD\n\nEF\n>protein|two\nGHI\n"
)
assert sequences == ["ACDEF", "GHI"]
assert descriptions == ["protein_one", "protein_two"]
def test_iter_seqs_yields_sequences_from_multiple_files(tmp_path):
fasta_a = tmp_path / "a.fasta"
fasta_b = tmp_path / "b.fasta"
fasta_a.write_text(">first\nAAA\n", encoding="utf-8")
fasta_b.write_text(">second\nBBB\n", encoding="utf-8")
records = list(file_handling.iter_seqs([str(fasta_a), str(fasta_b)]))
assert records == [("AAA", "first"), ("BBB", "second")]
def test_make_dir_monomer_dictionary_maps_files_to_their_source_dirs(tmp_path):
dir_a = tmp_path / "a"
dir_b = tmp_path / "b"
dir_a.mkdir()
dir_b.mkdir()
(dir_a / "proteinA.pkl").write_text("", encoding="utf-8")
(dir_b / "proteinB.pkl.xz").write_text("", encoding="utf-8")
result = file_handling.make_dir_monomer_dictionary([str(dir_a), str(dir_b)])
assert result == {
"proteinA.pkl": str(dir_a),
"proteinB.pkl.xz": str(dir_b),
}
def test_parse_csv_file_raises_for_missing_fasta():
with pytest.raises(FileNotFoundError):
file_handling.parse_csv_file("missing.csv", ["missing.fasta"], "/templates")
def test_parse_csv_file_returns_unique_records_without_clustering(tmp_path):
fasta = tmp_path / "proteins.fasta"
fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
csv_path = tmp_path / "description.csv"
csv_path.write_text(
"proteinA,template1.cif,A\nproteinA,template2.cif,B\ninvalid,row\nmissing,template3.cif,C\n",
encoding="utf-8",
)
result = file_handling.parse_csv_file(
str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=False
)
assert result == [
{
"protein": "proteinA.template1.cif.A",
"sequence": "ACDE",
"templates": [str(tmp_path / "templates" / "template1.cif")],
"chains": ["A"],
},
{
"protein": "proteinA.template2.cif.B",
"sequence": "ACDE",
"templates": [str(tmp_path / "templates" / "template2.cif")],
"chains": ["B"],
},
]
def test_parse_csv_file_clusters_multiple_templates_per_protein(tmp_path):
fasta = tmp_path / "proteins.fasta"
fasta.write_text(">proteinA\nACDE\n", encoding="utf-8")
csv_path = tmp_path / "description.csv"
csv_path.write_text(
"proteinA,template1.cif,A\nproteinA,template2.cif,B\n",
encoding="utf-8",
)
result = file_handling.parse_csv_file(
str(csv_path), [str(fasta)], str(tmp_path / "templates"), cluster=True
)
assert result == [
{
"protein": "proteinA",
"sequence": "ACDE",
"templates": [
str(tmp_path / "templates" / "template1.cif"),
str(tmp_path / "templates" / "template2.cif"),
],
"chains": ["A", "B"],
}
]

View File

@@ -0,0 +1,17 @@
import numpy as np
from alphapulldown.utils.plotting import plot_pae_from_matrix
def test_plot_pae_from_matrix_writes_png_file(tmp_path):
output = tmp_path / "pae.png"
plot_pae_from_matrix(
seqs=["AA", "BBB"],
pae_matrix=np.arange(25, dtype=float).reshape(5, 5),
figure_name=str(output),
ranking=2,
)
assert output.is_file()
assert output.stat().st_size > 0

View File

@@ -0,0 +1,94 @@
import gzip
import json
import pickle
from pathlib import Path
from alphapulldown.utils import post_modelling
def test_compress_file_gzips_and_removes_original(tmp_path):
file_path = tmp_path / "result.pkl"
file_path.write_bytes(b"payload")
gz_path = post_modelling.compress_file(str(file_path))
assert gz_path == str(file_path) + ".gz"
assert not file_path.exists()
with gzip.open(gz_path, "rb") as handle:
assert handle.read() == b"payload"
def test_compress_result_pickles_only_compresses_pickle_files(tmp_path):
pickle_path = tmp_path / "result.pkl"
text_path = tmp_path / "notes.txt"
pickle_path.write_bytes(b"payload")
text_path.write_text("keep", encoding="utf-8")
post_modelling.compress_result_pickles(str(tmp_path))
assert not pickle_path.exists()
assert (tmp_path / "result.pkl.gz").is_file()
assert text_path.is_file()
def test_remove_keys_from_pickle_updates_pickle_in_place(tmp_path):
file_path = tmp_path / "result.pkl"
with open(file_path, "wb") as handle:
pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
post_modelling.remove_keys_from_pickle(str(file_path), ["distogram", "missing"])
with open(file_path, "rb") as handle:
payload = pickle.load(handle)
assert payload == {"keep": 1, "masked_msa": 3}
def test_remove_irrelevant_pickles_keeps_only_best_pickle(tmp_path):
best = tmp_path / "result_model_1.pkl"
other = tmp_path / "result_model_2.pkl"
note = tmp_path / "note.txt"
best.write_bytes(b"best")
other.write_bytes(b"other")
note.write_text("keep", encoding="utf-8")
post_modelling.remove_irrelevant_pickles(str(tmp_path), best.name)
assert best.is_file()
assert not other.exists()
assert note.is_file()
def test_post_prediction_process_remove_keys_then_compress_and_prune(tmp_path):
ranking_debug = tmp_path / "ranking_debug.json"
ranking_debug.write_text(json.dumps({"order": ["model_1"]}), encoding="utf-8")
result_best = tmp_path / "result_model_1.pkl"
result_other = tmp_path / "result_model_2.pkl"
with open(result_best, "wb") as handle:
pickle.dump({"keep": 1, "distogram": 2, "masked_msa": 3}, handle)
with open(result_other, "wb") as handle:
pickle.dump({"keep": 4, "aligned_confidence_probs": 5}, handle)
post_modelling.post_prediction_process(
str(tmp_path),
compress_pickles=True,
remove_pickles=True,
remove_keys=True,
)
assert not result_best.exists()
assert not result_other.exists()
gz_best = tmp_path / "result_model_1.pkl.gz"
assert gz_best.is_file()
with gzip.open(gz_best, "rb") as handle:
payload = pickle.load(handle)
assert payload == {"keep": 1}
def test_post_prediction_process_handles_missing_ranking_debug(tmp_path):
lonely_pickle = tmp_path / "result_model_1.pkl"
lonely_pickle.write_bytes(b"payload")
post_modelling.post_prediction_process(str(tmp_path), compress_pickles=True)
assert lonely_pickle.is_file()